candle_core/quantized/
ggml_file.rs

1//! Support for the GGML file format.
2
3use super::{k_quants, GgmlDType, QStorage};
4use crate::{Device, Result};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7
8// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.h#L37
/// Raw magic word identifying the GGML container family, decoded from the
/// little-endian `u32` at the start of the file. Each variant's constant
/// (see `TryFrom<u32>` below) spells the four ASCII letters of its tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Magic {
    Ggjt,
    Ggla,
    Ggmf,
    Ggml,
    Ggsn,
}
17
18impl TryFrom<u32> for Magic {
19    type Error = crate::Error;
20    fn try_from(value: u32) -> Result<Self> {
21        let magic = match value {
22            0x67676a74 => Self::Ggjt,
23            0x67676c61 => Self::Ggla,
24            0x67676d66 => Self::Ggmf,
25            0x67676d6c => Self::Ggml,
26            0x6767736e => Self::Ggsn,
27            _ => crate::bail!("unknown magic {value:08x}"),
28        };
29        Ok(magic)
30    }
31}
32
/// Magic word combined with the format version that follows it in the file.
/// The original GGML container carries no version field at all, hence the
/// dedicated `GgmlUnversioned` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionedMagic {
    GgmlUnversioned,
    GgmfV1,
    GgjtV1,
    GgjtV2,
    GgjtV3,
}
41
42impl VersionedMagic {
43    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
44        let magic = reader.read_u32::<LittleEndian>()?;
45        let magic = Magic::try_from(magic)?;
46        if magic == Magic::Ggml {
47            return Ok(Self::GgmlUnversioned);
48        }
49        let version = reader.read_u32::<LittleEndian>()?;
50        let versioned_magic = match (magic, version) {
51            (Magic::Ggmf, 1) => Self::GgmfV1,
52            (Magic::Ggjt, 1) => Self::GgjtV1,
53            (Magic::Ggjt, 2) => Self::GgjtV2,
54            (Magic::Ggjt, 3) => Self::GgjtV3,
55            _ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
56        };
57        Ok(versioned_magic)
58    }
59
60    fn align32(&self) -> bool {
61        match self {
62            Self::GgmlUnversioned | Self::GgmfV1 => false,
63            Self::GgjtV1 | Self::GgjtV2 | Self::GgjtV3 => true,
64        }
65    }
66}
67
/// Model hyper-parameters stored in the GGML header, in file order.
/// Field meanings mirror llama.cpp's hparams struct — names suggest the
/// usual llama quantities (vocab size, embedding dim, head/layer counts);
/// only the read order is established by this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HParams {
    pub n_vocab: u32,
    pub n_embd: u32,
    pub n_mult: u32,
    pub n_head: u32,
    pub n_layer: u32,
    pub n_rot: u32,
    // File/quantization type tag — interpreted elsewhere, not in this module.
    pub ftype: u32,
}
78
79impl HParams {
80    fn read<R: std::io::Read>(reader: &mut R) -> Result<Self> {
81        let n_vocab = reader.read_u32::<LittleEndian>()?;
82        let n_embd = reader.read_u32::<LittleEndian>()?;
83        let n_mult = reader.read_u32::<LittleEndian>()?;
84        let n_head = reader.read_u32::<LittleEndian>()?;
85        let n_layer = reader.read_u32::<LittleEndian>()?;
86        let n_rot = reader.read_u32::<LittleEndian>()?;
87        let ftype = reader.read_u32::<LittleEndian>()?;
88        Ok(Self {
89            n_vocab,
90            n_embd,
91            n_mult,
92            n_head,
93            n_layer,
94            n_rot,
95            ftype,
96        })
97    }
98}
99
/// The vocabulary stored in the file.
#[derive(Debug, Clone, PartialEq)]
pub struct Vocab {
    // One (token bytes, score) pair per vocabulary entry, in the order they
    // appear in the file (presumably token-id order — see llama.cpp loader).
    pub token_score_pairs: Vec<(Vec<u8>, f32)>,
}
104
105impl Vocab {
106    fn read<R: std::io::Read>(reader: &mut R, n_vocab: usize) -> Result<Self> {
107        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L556
108        let mut token_score_pairs = Vec::with_capacity(n_vocab);
109        for _index in 0..n_vocab {
110            let len = reader.read_u32::<LittleEndian>()? as usize;
111            let mut word = vec![0u8; len];
112            reader.read_exact(&mut word)?;
113            let score = reader.read_f32::<LittleEndian>()?;
114            token_score_pairs.push((word, score))
115        }
116        Ok(Self { token_score_pairs })
117    }
118}
119
120fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
121    raw_data: &[u8],
122    size_in_bytes: usize,
123    dims: Vec<usize>,
124    device: &Device,
125) -> Result<super::QTensor> {
126    let raw_data_ptr = raw_data.as_ptr();
127    let n_blocks = size_in_bytes / std::mem::size_of::<T>();
128    let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
129    let data: QStorage = match device {
130        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
131        Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
132        Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
133    };
134    super::QTensor::new(data, dims)
135}
136
137/// Creates a [Tensor] from a raw GGML tensor.
138pub fn qtensor_from_ggml(
139    ggml_dtype: GgmlDType,
140    raw_data: &[u8],
141    dims: Vec<usize>,
142    device: &Device,
143) -> Result<super::QTensor> {
144    let tensor_elems = dims.iter().product::<usize>();
145    let block_size = ggml_dtype.block_size();
146    if tensor_elems % block_size != 0 {
147        crate::bail!(
148            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
149        )
150    }
151    let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
152
153    match ggml_dtype {
154        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
155        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
156        GgmlDType::BF16 => from_raw_data::<half::bf16>(raw_data, size_in_bytes, dims, device),
157        GgmlDType::Q4_0 => {
158            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
159        }
160        GgmlDType::Q4_1 => {
161            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
162        }
163        GgmlDType::Q5_0 => {
164            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
165        }
166        GgmlDType::Q5_1 => {
167            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
168        }
169        GgmlDType::Q8_0 => {
170            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
171        }
172        GgmlDType::Q2K => {
173            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
174        }
175        GgmlDType::Q3K => {
176            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
177        }
178        GgmlDType::Q4K => {
179            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
180        }
181        GgmlDType::Q5K => {
182            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
183        }
184        GgmlDType::Q6K => {
185            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
186        }
187        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
188    }
189}
190
/// Reads one tensor record (header + raw block data) from `reader` and
/// converts it into a named [`super::QTensor`] on `device`.
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
    reader: &mut R,
    magic: VersionedMagic,
    device: &Device,
) -> Result<(String, super::QTensor)> {
    // Per-tensor header: dimension count, name length, then the dtype tag.
    let n_dims = reader.read_u32::<LittleEndian>()?;
    let name_len = reader.read_u32::<LittleEndian>()?;
    let ggml_dtype = reader.read_u32::<LittleEndian>()?;
    let ggml_dtype = GgmlDType::from_u32(ggml_dtype)?;
    let mut dims = vec![0u32; n_dims as usize];
    reader.read_u32_into::<LittleEndian>(&mut dims)?;
    // The dimensions are stored in reverse order, see for example:
    // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/convert.py#L969
    dims.reverse();
    let mut name = vec![0u8; name_len as usize];
    reader.read_exact(&mut name)?;
    // Lossy conversion: invalid UTF-8 in a name becomes U+FFFD rather than
    // an error.
    let name = String::from_utf8_lossy(&name).into_owned();

    // GGJT containers pad tensor data to the next 32-byte boundary.
    if magic.align32() {
        let pos = reader.stream_position()?;
        reader.seek(std::io::SeekFrom::Current(((32 - pos % 32) % 32) as i64))?;
    }
    let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
    let tensor_elems = dims.iter().product::<usize>();
    // Integer division truncates when tensor_elems is not a multiple of the
    // block size, but qtensor_from_ggml rejects exactly that case below.
    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
    // TODO: Mmap version to avoid copying the data around?
    let mut raw_data = vec![0u8; size_in_bytes];
    reader.read_exact(&mut raw_data)?;
    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
        Ok(tensor) => Ok((name, tensor)),
        Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
    }
}
224
/// A fully parsed GGML file: header magic, hyper-parameters, vocabulary and
/// all tensors.
pub struct Content {
    pub magic: VersionedMagic,
    pub hparams: HParams,
    pub vocab: Vocab,
    // Tensors keyed by their name as stored in the file.
    pub tensors: HashMap<String, super::QTensor>,
    // Device the tensor data was loaded onto.
    pub device: Device,
}
232
233impl Content {
234    pub fn read<R: std::io::Seek + std::io::Read>(
235        reader: &mut R,
236        device: &Device,
237    ) -> Result<Content> {
238        // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
239        let last_position = reader.seek(std::io::SeekFrom::End(0))?;
240        reader.seek(std::io::SeekFrom::Start(0))?;
241        let magic = VersionedMagic::read(reader)?;
242        let hparams = HParams::read(reader)?;
243        let vocab = Vocab::read(reader, hparams.n_vocab as usize)?;
244        let mut tensors = HashMap::new();
245
246        while reader.stream_position()? != last_position {
247            let (name, tensor) = read_one_tensor(reader, magic, device)?;
248            tensors.insert(name, tensor);
249        }
250        let device = device.clone();
251        Ok(Self {
252            magic,
253            hparams,
254            vocab,
255            tensors,
256            device,
257        })
258    }
259
260    pub fn remove(&mut self, name: &str) -> Result<super::QTensor> {
261        match self.tensors.remove(name) {
262            None => crate::bail!("cannot find tensor with name '{name}'"),
263            Some(tensor) => Ok(tensor),
264        }
265    }
266}