gguf_rs/
lib.rs

//! This crate provides functionality for decoding and working with GGUF files.
//!
//! GGUF files are binary files that contain key-value metadata and tensors.
//! The `GGUFContainer` struct represents a GGUF file container and provides methods for decoding and accessing the data.
//! The `GGUFModel` struct represents the decoded GGUF data, including the key-value metadata and tensors.
//! The `Tensor` struct represents a tensor in the GGUF file, including its name, kind, offset, size, and shape.
//! The `ByteOrder` enum represents the byte order of the GGUF file (little endian or big endian).
//! The `Version` enum represents the version of the GGUF file (v1, v2, or v3).
//! The `MetadataValueType` enum represents the value type of the metadata in the GGUF file.
//! The `GGMLType` enum represents the GGML type of a tensor in the GGUF file.
//!
//! # Example
//!
//! ```
//! use gguf_rs::get_gguf_container;
//!
//! fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut container = get_gguf_container("tests/test-le-v3.gguf")?;
//!     let model = container.decode()?;
//!
//!     println!("GGUF version: {}", model.get_version());
//!
//!     for tensor in model.tensors() {
//!         println!("Tensor name: {}", tensor.name);
//!         println!("Tensor kind: {}", tensor.kind);
//!         println!("Tensor shape: {:?}", tensor.shape);
//!     }
//!
//!     Ok(())
//! }
//! ```
32use anyhow::{Result, anyhow};
33use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
34#[cfg(feature = "debug")]
35use log::debug;
36use serde::{Deserialize, Serialize};
37use serde_json::Value;
38use std::{borrow::Borrow, collections::BTreeMap, fmt::Display};
39
// Each magic is compared against the first four bytes of a file decoded as a little-endian
// `i32`; the values are written out from their ASCII tags so the intent is visible.

/// Magic constant for `ggml` files (unversioned): the tag `ggml` packed most-significant byte first.
pub const FILE_MAGIC_GGML: i32 = i32::from_be_bytes(*b"ggml");
/// Magic constant for `ggml` files (versioned, ggmf).
pub const FILE_MAGIC_GGMF: i32 = i32::from_be_bytes(*b"ggmf");
/// Magic constant for `ggml` files (versioned, ggjt).
pub const FILE_MAGIC_GGJT: i32 = i32::from_be_bytes(*b"ggjt");
/// Magic constant for `ggla` files (LoRA adapter).
pub const FILE_MAGIC_GGLA: i32 = i32::from_be_bytes(*b"ggla");
/// Magic constant for `gguf` files: the bytes `GGUF` decoded as a little-endian `i32`.
pub const FILE_MAGIC_GGUF_LE: i32 = i32::from_le_bytes(*b"GGUF");
/// Magic constant for `gguf` files: the bytes `GGUF` decoded as a big-endian `i32`.
pub const FILE_MAGIC_GGUF_BE: i32 = i32::from_be_bytes(*b"GGUF");
51
// GGUF format versions this crate can decode, as stored in the header's version field.
const GGUF_VERSION_V1: i32 = 1;
const GGUF_VERSION_V2: i32 = 2;
const GGUF_VERSION_V3: i32 = 3;

// Decimal scales used when abbreviating parameter counts (`human_number`).
const THOUSAND: u64 = 1_000;
const MILLION: u64 = THOUSAND * THOUSAND;
const BILLION: u64 = MILLION * THOUSAND;
59
60/// Convert a number to a human-readable string.
61fn human_number(value: u64) -> String {
62    match value {
63        _ if value > BILLION => format!("{:.0}B", value as f64 / BILLION as f64),
64        _ if value > MILLION => format!("{:.0}M", value as f64 / MILLION as f64),
65        _ if value > THOUSAND => format!("{:.0}K", value as f64 / THOUSAND as f64),
66        _ => format!("{}", value),
67    }
68}
69
/// Convert a `general.file_type` metadata value to a human-readable string.
///
/// The numbering follows the `general.file_type` enumeration of the GGUF spec (the same values
/// as llama.cpp's `llama_ftype`). Codes for formats that were removed upstream are marked
/// `(UNSUPPORTED)`; codes not listed map to `"unknown"`.
///
/// GGUF spec: <https://github.com/ggerganov/ggml/blob/master/docs/gguf.md>
fn file_type(ft: u64) -> String {
    let name = match ft {
        0 => "All F32",
        1 => "Mostly F16",
        2 => "Mostly Q4_0",
        3 => "Mostly Q4_1",
        4 => "Mostly Q4_1 Some F16",
        5 => "Mostly Q4_2 (UNSUPPORTED)",
        6 => "Mostly Q4_3 (UNSUPPORTED)",
        7 => "Mostly Q8_0",
        8 => "Mostly Q5_0",
        9 => "Mostly Q5_1",
        10 => "Mostly Q2_K",
        11 => "Mostly Q3_K_S",
        12 => "Mostly Q3_K_M",
        13 => "Mostly Q3_K_L",
        14 => "Mostly Q4_K_S",
        15 => "Mostly Q4_K_M",
        16 => "Mostly Q5_K_S",
        17 => "Mostly Q5_K_M",
        18 => "Mostly Q6_K",
        19 => "Mostly IQ2_XXS",
        20 => "Mostly IQ2_XS",
        21 => "Mostly Q2_K_S",
        22 => "Mostly IQ3_XS",
        23 => "Mostly IQ3_XXS",
        24 => "Mostly IQ1_S",
        25 => "Mostly IQ4_NL",
        26 => "Mostly IQ3_S",
        27 => "Mostly IQ3_M",
        28 => "Mostly IQ2_S",
        29 => "Mostly IQ2_M",
        30 => "Mostly IQ4_XS",
        31 => "Mostly IQ1_M",
        32 => "Mostly BF16",
        33 => "Mostly Q4_0_4_4 (UNSUPPORTED)",
        34 => "Mostly Q4_0_4_8 (UNSUPPORTED)",
        35 => "Mostly Q4_0_8_8 (UNSUPPORTED)",
        36 => "Mostly TQ1_0",
        37 => "Mostly TQ2_0",
        _ => "unknown",
    };
    name.to_string()
}
103
/// Endianness of the multi-byte fields in a GGUF stream.
///
/// Defaults to [`ByteOrder::LE`].
#[derive(Clone, Debug, Default)]
pub enum ByteOrder {
    /// Little endian.
    #[default]
    LE,
    /// Big endian.
    BE,
}
111
112/// Version of the GGUF file.
113#[derive(Debug, Clone)]
114pub enum Version {
115    V1(V1),
116    V2(V2),
117    V3(V3),
118}
119
120/// Version 1 of the GGUF file.
121#[derive(Debug, Deserialize, Default, Clone)]
122pub struct V1 {
123    num_tensor: u32,
124    num_kv: u32,
125}
126
127/// Version 2 of the GGUF file.
128#[derive(Debug, Deserialize, Default, Clone)]
129pub struct V2 {
130    num_tensor: u64,
131    num_kv: u64,
132}
133
134/// Version 3 of the GGUF file.
135#[derive(Debug, Deserialize, Default, Clone)]
136pub struct V3 {
137    num_tensor: u64,
138    num_kv: u64,
139}
140
141/// GGUF file container.
142pub struct GGUFContainer {
143    bo: ByteOrder,
144    version: Version,
145    reader: Box<dyn std::io::Read + 'static>,
146}
147
148impl GGUFContainer {
149    /// Create a new `GGUFContainer` from a byte order and a reader.
150    /// The reader must implement the `std::io::Read` trait.
151    /// ```
152    /// use gguf_rs::{get_gguf_container};
153    /// use std::fs::File;
154    ///
155    /// fn main() -> Result<(), Box<dyn std::error::Error>> {
156    ///     let mut container = get_gguf_container("tests/test-le-v3.gguf")?;
157    ///     let model = container.decode()?;
158    ///
159    ///     println!("GGUF version: {}", model.get_version());
160    ///
161    ///     for tensor in model.tensors() {
162    ///         println!("Tensor name: {}", tensor.name);
163    ///         println!("Tensor kind: {}", tensor.kind);
164    ///         println!("Tensor shape: {:?}", tensor.shape);
165    ///     }
166    ///
167    ///     Ok(())
168    /// }
169    pub fn new(bo: ByteOrder, reader: Box<dyn std::io::Read>) -> Self {
170        Self {
171            bo,
172            version: Version::V1(V1::default()),
173            reader,
174        }
175    }
176
177    /// Get the version of the GGUF file.
178    pub fn get_version(&self) -> String {
179        match &self.version {
180            Version::V1(_) => String::from("v1"),
181            Version::V2(_) => String::from("v2"),
182            Version::V3(_) => String::from("v3"),
183        }
184    }
185
186    /// Decode the GGUF file and return a `GGUFModel`.
187    pub fn decode(&mut self) -> Result<GGUFModel> {
188        let version = match self.bo {
189            ByteOrder::LE => self.reader.read_i32::<LittleEndian>()?,
190            ByteOrder::BE => self.reader.read_i32::<BigEndian>()?,
191        };
192
193        #[cfg(feature = "debug")]
194        {
195            debug!("version {}", version);
196        }
197
198        match version {
199            GGUF_VERSION_V1 => {
200                let mut buffer: [u32; 2] = [0; 2];
201                match self.bo {
202                    ByteOrder::LE => self.reader.read_u32_into::<LittleEndian>(&mut buffer)?,
203                    ByteOrder::BE => self.reader.read_u32_into::<BigEndian>(&mut buffer)?,
204                };
205
206                self.version = Version::V1(V1 {
207                    num_tensor: buffer[0],
208                    num_kv: buffer[1],
209                });
210            }
211            GGUF_VERSION_V2 | GGUF_VERSION_V3 => {
212                let mut buffer: [u64; 2] = [0; 2];
213                match self.bo {
214                    ByteOrder::LE => self.reader.read_u64_into::<LittleEndian>(&mut buffer)?,
215                    ByteOrder::BE => self.reader.read_u64_into::<BigEndian>(&mut buffer)?,
216                };
217
218                if version == GGUF_VERSION_V2 {
219                    self.version = Version::V2(V2 {
220                        num_tensor: buffer[0],
221                        num_kv: buffer[1],
222                    });
223                } else {
224                    self.version = Version::V3(V3 {
225                        num_tensor: buffer[0],
226                        num_kv: buffer[1],
227                    });
228                }
229            }
230            invalid_version => {
231                return Err(anyhow!(
232                    "invalid version {}, only support version: 1 | 2 | 3",
233                    invalid_version
234                ));
235            }
236        };
237
238        let mut model = GGUFModel {
239            kv: BTreeMap::new(),
240            tensors: Vec::new(),
241            parameters: 0,
242
243            bo: self.bo.clone(),
244            version: self.version.clone(),
245        };
246
247        model.decode(&mut self.reader)?;
248        Ok(model)
249    }
250}
251
/// Tensor descriptor read from the tensor-info section of a GGUF file.
#[derive(Clone, Debug)]
pub struct Tensor {
    /// Tensor name as stored in the file.
    pub name: String,
    /// Raw GGML type id of the tensor's elements (see `GGMLType`).
    pub kind: u32,
    /// Data offset as stored in the tensor descriptor.
    pub offset: u64,
    /// Byte size of the tensor data, computed from the shape and the element type.
    pub size: u64,
    /// Number of elements in each dimension. The decoder always fills four entries and leaves
    /// unused trailing dimensions at 1.
    pub shape: Vec<u64>,
}
262
263/// GGUF model.
264pub struct GGUFModel {
265    kv: BTreeMap<String, Value>,
266    tensors: Vec<Tensor>,
267    parameters: u64,
268
269    bo: ByteOrder,
270    version: Version,
271}
272
/// Type tag of a GGUF metadata value; the discriminant is the raw tag stored in the file.
#[derive(Debug)]
pub enum MetadataValueType {
    /// Unsigned 8-bit integer.
    Uint8 = 0,
    /// Signed 8-bit integer.
    Int8 = 1,
    /// Unsigned 16-bit integer.
    Uint16 = 2,
    /// Signed 16-bit integer.
    Int16 = 3,
    /// Unsigned 32-bit integer.
    Uint32 = 4,
    /// Signed 32-bit integer.
    Int32 = 5,
    /// 32-bit float.
    Float32 = 6,
    /// Boolean stored as one byte.
    Bool = 7,
    /// Length-prefixed string.
    String = 8,
    /// Typed, length-prefixed array.
    Array = 9,
    /// Unsigned 64-bit integer.
    Uint64 = 10,
    /// Signed 64-bit integer.
    Int64 = 11,
    /// 64-bit float.
    Float64 = 12,
}
289
290impl TryFrom<u32> for MetadataValueType {
291    type Error = anyhow::Error;
292
293    fn try_from(value: u32) -> Result<Self, Self::Error> {
294        Ok(match value {
295            0 => MetadataValueType::Uint8,
296            1 => MetadataValueType::Int8,
297            2 => MetadataValueType::Uint16,
298            3 => MetadataValueType::Int16,
299            4 => MetadataValueType::Uint32,
300            5 => MetadataValueType::Int32,
301            6 => MetadataValueType::Float32,
302            7 => MetadataValueType::Bool,
303            8 => MetadataValueType::String,
304            9 => MetadataValueType::Array,
305            10 => MetadataValueType::Uint64,
306            11 => MetadataValueType::Int64,
307            12 => MetadataValueType::Float64,
308            _ => return Err(anyhow!("unsupport metadata value type")),
309        })
310    }
311}
312
313/// GGML type of a tensor in the GGUF file.
314#[derive(Debug, Serialize)]
315pub enum GGMLType {
316    F32 = 0,
317    F16 = 1,
318    Q4_0 = 2,
319    Q4_1 = 3,
320    Q4_2 = 4, // Unsupported
321    Q4_3 = 5, // Unsupported
322    Q5_0 = 6,
323    Q5_1 = 7,
324    Q8_0 = 8,
325    Q8_1 = 9,
326    Q2_K = 10,
327    Q3_K = 11,
328    Q4_K = 12,
329    Q5_K = 13,
330    Q6_K = 14,
331    Q8_K = 15,
332    IQ2_XXS = 16,
333    IQ2_XS = 17,
334    IQ3_XXS = 18,
335    IQ1_S = 19,
336    IQ4_NL = 20,
337    IQ3_S = 21,
338    IQ2_S = 22,
339    IQ4_XS = 23,
340    I8 = 24,
341    I16 = 25,
342    I32 = 26,
343    I64 = 27,
344    F64 = 28,
345    IQ1_M = 29,
346    BF16 = 30,
347    Q4_0_4_4 = 31, // Unsupported
348    Q4_0_4_8 = 32, // Unsupported
349    Q4_0_8_8 = 33, // Unsupported
350    TQ1_0 = 34,
351    TQ2_0 = 35,
352    IQ4_NL_4_4 = 36, // Unsupported
353    IQ4_NL_4_8 = 37, // Unsupported
354    IQ4_NL_8_8 = 38, // Unsupported
355    Count = 39,
356}
357
358impl Display for GGMLType {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        match self {
361            GGMLType::F32 => write!(f, "F32"),
362            GGMLType::F16 => write!(f, "F16"),
363            GGMLType::Q4_0 => write!(f, "Q4_0"),
364            GGMLType::Q4_1 => write!(f, "Q4_1"),
365            GGMLType::Q4_2 => write!(f, "Q4_2 (UNSUPPORTED)"),
366            GGMLType::Q4_3 => write!(f, "Q4_3 (UNSUPPORTED)"),
367            GGMLType::Q5_0 => write!(f, "Q5_0"),
368            GGMLType::Q5_1 => write!(f, "Q5_1"),
369            GGMLType::Q8_0 => write!(f, "Q8_0"),
370            GGMLType::Q8_1 => write!(f, "Q8_1"),
371            GGMLType::Q2_K => write!(f, "Q2_K"),
372            GGMLType::Q3_K => write!(f, "Q3_K"),
373            GGMLType::Q4_K => write!(f, "Q4_K"),
374            GGMLType::Q5_K => write!(f, "Q5_K"),
375            GGMLType::Q6_K => write!(f, "Q6_K"),
376            GGMLType::Q8_K => write!(f, "Q8_K"),
377            GGMLType::IQ2_XXS => write!(f, "IQ2_XXS"),
378            GGMLType::IQ2_XS => write!(f, "IQ2_XS"),
379            GGMLType::IQ3_XXS => write!(f, "IQ3_XXS"),
380            GGMLType::IQ1_S => write!(f, "IQ1_S"),
381            GGMLType::IQ4_NL => write!(f, "IQ4_NL"),
382            GGMLType::IQ3_S => write!(f, "IQ3_S"),
383            GGMLType::IQ2_S => write!(f, "IQ2_S"),
384            GGMLType::IQ4_XS => write!(f, "IQ4_XS"),
385            GGMLType::I8 => write!(f, "I8"),
386            GGMLType::I16 => write!(f, "I16"),
387            GGMLType::I32 => write!(f, "I32"),
388            GGMLType::I64 => write!(f, "I64"),
389            GGMLType::F64 => write!(f, "F64"),
390            GGMLType::IQ1_M => write!(f, "IQ1_M"),
391            GGMLType::BF16 => write!(f, "BF16"),
392            GGMLType::Q4_0_4_4 => write!(f, "Q4_0_4_4 (UNSUPPORTED)"),
393            GGMLType::Q4_0_4_8 => write!(f, "Q4_0_4_8 (UNSUPPORTED)"),
394            GGMLType::Q4_0_8_8 => write!(f, "Q4_0_8_8 (UNSUPPORTED)"),
395            GGMLType::TQ1_0 => write!(f, "TQ1_0"),
396            GGMLType::TQ2_0 => write!(f, "TQ2_0"),
397            GGMLType::IQ4_NL_4_4 => write!(f, "IQ4_NL_4_4 (UNSUPPORTED)"),
398            GGMLType::IQ4_NL_4_8 => write!(f, "IQ4_NL_4_8 (UNSUPPORTED)"),
399            GGMLType::IQ4_NL_8_8 => write!(f, "IQ4_NL_8_8 (UNSUPPORTED)"),
400            GGMLType::Count => write!(f, "Count"),
401        }
402    }
403}
404
405impl TryFrom<u32> for GGMLType {
406    type Error = anyhow::Error;
407
408    fn try_from(value: u32) -> std::prelude::v1::Result<Self, Self::Error> {
409        Ok(match value {
410            0 => GGMLType::F32,
411            1 => GGMLType::F16,
412            2 => GGMLType::Q4_0,
413            3 => GGMLType::Q4_1,
414            6 => GGMLType::Q5_0,
415            7 => GGMLType::Q5_1,
416            8 => GGMLType::Q8_0,
417            9 => GGMLType::Q8_1,
418            10 => GGMLType::Q2_K,
419            11 => GGMLType::Q3_K,
420            12 => GGMLType::Q4_K,
421            13 => GGMLType::Q5_K,
422            14 => GGMLType::Q6_K,
423            15 => GGMLType::Q8_K,
424            16 => GGMLType::IQ2_XXS,
425            17 => GGMLType::IQ2_XS,
426            18 => GGMLType::IQ3_XXS,
427            19 => GGMLType::IQ1_S,
428            20 => GGMLType::IQ4_NL,
429            21 => GGMLType::IQ3_S,
430            22 => GGMLType::IQ2_S,
431            23 => GGMLType::IQ4_XS,
432            24 => GGMLType::I8,
433            25 => GGMLType::I16,
434            26 => GGMLType::I32,
435            27 => GGMLType::I64,
436            28 => GGMLType::F64,
437            29 => GGMLType::IQ1_M,
438            30 => GGMLType::BF16,
439            31 => GGMLType::Q4_0_4_4,
440            32 => GGMLType::Q4_0_4_8,
441            33 => GGMLType::Q4_0_8_8,
442            34 => GGMLType::TQ1_0,
443            35 => GGMLType::TQ2_0,
444            36 => GGMLType::IQ4_NL_4_4,
445            37 => GGMLType::IQ4_NL_4_8,
446            38 => GGMLType::IQ4_NL_8_8,
447            39 => GGMLType::Count,
448            _ => return Err(anyhow!("invalid GGML type")),
449        })
450    }
451}
452
453impl GGUFModel {
454    /// Decode the GGUF file.
455    pub(crate) fn decode(&mut self, mut reader: impl std::io::Read) -> Result<()> {
456        // decode kv
457        for _i in 0..self.num_kv() {
458            let key = self.read_string(&mut reader)?;
459            let value_type: MetadataValueType = self.read_u32(&mut reader)?.try_into()?;
460            let value = match value_type {
461                MetadataValueType::Uint8 => Value::from(self.read_u8(&mut reader)?),
462                MetadataValueType::Int8 => Value::from(self.read_i8(&mut reader)?),
463                MetadataValueType::Uint16 => Value::from(self.read_u16(&mut reader)?),
464                MetadataValueType::Int16 => Value::from(self.read_i16(&mut reader)?),
465                MetadataValueType::Uint32 => Value::from(self.read_u32(&mut reader)?),
466                MetadataValueType::Int32 => Value::from(self.read_i32(&mut reader)?),
467                MetadataValueType::Float32 => Value::from(self.read_f32(&mut reader)?),
468                MetadataValueType::Bool => Value::from(self.read_bool(&mut reader)?),
469                MetadataValueType::String => Value::from(self.read_string(&mut reader)?),
470                MetadataValueType::Array => Value::from(self.read_array(&mut reader, 3)?),
471                MetadataValueType::Uint64 => Value::from(self.read_u64(&mut reader)?),
472                MetadataValueType::Int64 => Value::from(self.read_i64(&mut reader)?),
473                MetadataValueType::Float64 => Value::from(self.read_f64(&mut reader)?),
474            };
475            #[cfg(feature = "debug")]
476            {
477                debug!(
478                    "kv [{}] vtype {:?} key={}, value={}",
479                    _i, value_type, key, value
480                );
481            }
482            self.kv.insert(key, value);
483        }
484
485        // decode tensors
486        for _ in 0..self.num_tensor() {
487            let name = self.read_string(&mut reader)?;
488            let dims = self.read_u32(&mut reader)?;
489            let mut shape = [1; 4];
490            for i in 0..dims {
491                shape[i as usize] = self.read_u64(&mut reader)?;
492            }
493
494            let kind = self.read_u32(&mut reader)?;
495            let offset = self.read_u64(&mut reader)?;
496            let block_size = match kind {
497                _ if kind < 2 => 1,
498                _ if kind < 10 => 32,
499                _ => 256,
500            };
501            let ggml_type_kind: GGMLType = kind.try_into()?;
502            let type_size = match ggml_type_kind {
503                GGMLType::F32 => 4,
504                GGMLType::F16 => 2,
505                GGMLType::Q4_0 => 2 + block_size / 2,
506                GGMLType::Q4_1 => 2 + 2 + block_size / 2,
507                GGMLType::Q4_2 => 0,
508                GGMLType::Q4_3 => 0,
509                GGMLType::Q5_0 => 2 + 4 + block_size / 2,
510                GGMLType::Q5_1 => 2 + 2 + 4 + block_size / 2,
511                GGMLType::Q8_0 => 2 + block_size,
512                GGMLType::Q8_1 => 4 + 4 + block_size,
513                GGMLType::Q2_K => block_size / 16 + block_size / 4 + 2 + 2,
514                GGMLType::Q3_K => block_size / 8 + block_size / 4 + 12 + 2,
515                GGMLType::Q4_K => 2 + 2 + 12 + block_size / 2,
516                GGMLType::Q5_K => 2 + 2 + 12 + block_size / 8 + block_size / 2,
517                GGMLType::Q6_K => block_size / 2 + block_size / 4 + block_size / 16 + 2,
518                GGMLType::Q8_K => 4 + block_size + block_size / 16 * 2,
519                GGMLType::IQ2_XXS => 2 + block_size / 8 * 2,
520                GGMLType::IQ2_XS => 2 + block_size / 8 * 2 + block_size / 32,
521                GGMLType::IQ3_XXS => 2 + 3 * (block_size / 8),
522                GGMLType::IQ1_S => 2 + block_size / 8 + block_size / 16,
523                GGMLType::IQ4_NL => 2 + 16,
524                GGMLType::IQ3_S => 2 + 13 * (block_size / 32) + block_size / 64,
525                GGMLType::IQ2_S => 2 + block_size / 4 + block_size / 16,
526                GGMLType::IQ4_XS => 2 + 2 + block_size / 64 + block_size / 2,
527                GGMLType::I8 => 1,
528                GGMLType::I16 => 2,
529                GGMLType::I32 => 4,
530                GGMLType::I64 => 8,
531                GGMLType::F64 => 8,
532                GGMLType::IQ1_M => block_size / 8 + block_size / 16 + block_size / 32,
533                GGMLType::BF16 => 2,
534                GGMLType::IQ4_NL_4_4 => 0,
535                GGMLType::IQ4_NL_4_8 => 0,
536                GGMLType::IQ4_NL_8_8 => 0,
537                GGMLType::TQ1_0 => 2 + block_size / 64 + (block_size - 4 * block_size / 64) / 5,
538                GGMLType::TQ2_0 => 2 + block_size / 4,
539                GGMLType::Q4_0_4_4 => 0,
540                GGMLType::Q4_0_4_8 => 0,
541                GGMLType::Q4_0_8_8 => 0,
542                GGMLType::Count => unreachable!("GGMLType::Count is not a real data format"),
543            };
544
545            let parameters = shape[0] * shape[1] * shape[2] * shape[3];
546            let size = parameters * type_size / block_size;
547
548            self.tensors.push(Tensor {
549                name,
550                kind,
551                offset,
552                size,
553                shape: shape.to_vec(),
554            });
555
556            self.parameters += parameters;
557        }
558
559        Ok(())
560    }
561
562    fn read_u8(&self, mut reader: impl std::io::Read) -> Result<u8> {
563        Ok(reader.read_u8()?)
564    }
565
566    fn read_u32(&self, mut reader: impl std::io::Read) -> Result<u32> {
567        Ok(match self.bo {
568            ByteOrder::LE => reader.read_u32::<LittleEndian>()?,
569            ByteOrder::BE => reader.read_u32::<BigEndian>()?,
570        })
571    }
572
573    fn read_f32(&self, mut reader: impl std::io::Read) -> Result<f32> {
574        Ok(match self.bo {
575            ByteOrder::LE => reader.read_f32::<LittleEndian>()?,
576            ByteOrder::BE => reader.read_f32::<BigEndian>()?,
577        })
578    }
579
580    fn read_f64(&self, mut reader: impl std::io::Read) -> Result<f64> {
581        Ok(match self.bo {
582            ByteOrder::LE => reader.read_f64::<LittleEndian>()?,
583            ByteOrder::BE => reader.read_f64::<BigEndian>()?,
584        })
585    }
586
587    fn read_u64(&self, mut reader: impl std::io::Read) -> Result<u64> {
588        Ok(match self.bo {
589            ByteOrder::LE => reader.read_u64::<LittleEndian>()?,
590            ByteOrder::BE => reader.read_u64::<BigEndian>()?,
591        })
592    }
593
594    fn read_i8(&self, mut reader: impl std::io::Read) -> Result<i8> {
595        Ok(reader.read_i8()?)
596    }
597
598    fn read_u16(&self, mut reader: impl std::io::Read) -> Result<u16> {
599        Ok(match self.bo {
600            ByteOrder::LE => reader.read_u16::<LittleEndian>()?,
601            ByteOrder::BE => reader.read_u16::<BigEndian>()?,
602        })
603    }
604
605    fn read_i16(&self, mut reader: impl std::io::Read) -> Result<i16> {
606        Ok(match self.bo {
607            ByteOrder::LE => reader.read_i16::<LittleEndian>()?,
608            ByteOrder::BE => reader.read_i16::<BigEndian>()?,
609        })
610    }
611
612    fn read_i32(&self, mut reader: impl std::io::Read) -> Result<i32> {
613        Ok(match self.bo {
614            ByteOrder::LE => reader.read_i32::<LittleEndian>()?,
615            ByteOrder::BE => reader.read_i32::<BigEndian>()?,
616        })
617    }
618
619    fn read_i64(&self, mut reader: impl std::io::Read) -> Result<i64> {
620        Ok(match self.bo {
621            ByteOrder::LE => reader.read_i64::<LittleEndian>()?,
622            ByteOrder::BE => reader.read_i64::<BigEndian>()?,
623        })
624    }
625
626    fn read_bool(&self, mut reader: impl std::io::Read) -> Result<bool> {
627        Ok(reader.read_u8()? != 0)
628    }
629
630    fn read_string(&self, mut reader: impl std::io::Read) -> Result<String> {
631        let name_len = self.read_version_size(&mut reader)?;
632        let mut buffer = vec![0; name_len as usize];
633        reader.read_exact(&mut buffer)?;
634        Ok(String::from_utf8_lossy(&buffer).to_string())
635    }
636
637    fn read_array(&self, mut reader: impl std::io::Read, read_count: usize) -> Result<Vec<Value>> {
638        let mut data = Vec::new();
639        let item_type: MetadataValueType = self.read_u32(&mut reader)?.try_into()?;
640        let array_len = self.read_version_size(&mut reader)?;
641        for _ in 0..array_len {
642            let value = match item_type {
643                MetadataValueType::Uint8 => Value::from(self.read_u8(&mut reader)?),
644                MetadataValueType::Int8 => Value::from(self.read_i8(&mut reader)?),
645                MetadataValueType::Uint16 => Value::from(self.read_u16(&mut reader)?),
646                MetadataValueType::Int16 => Value::from(self.read_i16(&mut reader)?),
647                MetadataValueType::Uint32 => Value::from(self.read_u32(&mut reader)?),
648                MetadataValueType::Int32 => Value::from(self.read_i32(&mut reader)?),
649                MetadataValueType::Float32 => Value::from(self.read_f32(&mut reader)?),
650                MetadataValueType::Bool => Value::from(self.read_bool(&mut reader)?),
651                MetadataValueType::String => Value::from(self.read_string(&mut reader)?),
652                MetadataValueType::Uint64 => Value::from(self.read_u64(&mut reader)?),
653                MetadataValueType::Int64 => Value::from(self.read_i64(&mut reader)?),
654                MetadataValueType::Float64 => Value::from(self.read_f64(&mut reader)?),
655                _ => return Err(anyhow!("Unsupport item value type: Array")),
656            };
657
658            if read_count > 0 && data.len() < read_count {
659                data.push(value);
660            }
661        }
662
663        Ok(data)
664    }
665
666    fn read_version_size(&self, mut reader: impl std::io::Read) -> Result<u64> {
667        Ok(match self.version.borrow() {
668            Version::V1(_) => self.read_u32(&mut reader)? as u64,
669            Version::V2(_) => self.read_u64(&mut reader)?,
670            Version::V3(_) => self.read_u64(&mut reader)?,
671        })
672    }
673
674    /// Get the version of the GGUF file.
675    pub fn get_version(&self) -> String {
676        match &self.version {
677            Version::V1(_) => String::from("v1"),
678            Version::V2(_) => String::from("v2"),
679            Version::V3(_) => String::from("v3"),
680        }
681    }
682
683    /// Get the number of key-value pairs in the GGUF file.
684    pub fn num_kv(&self) -> u64 {
685        match &self.version {
686            Version::V1(v1) => v1.num_kv as u64,
687            Version::V2(v2) => v2.num_kv,
688            Version::V3(v3) => v3.num_kv,
689        }
690    }
691
692    /// Get the number of tensors in the GGUF file.
693    pub fn num_tensor(&self) -> u64 {
694        match &self.version {
695            Version::V1(v1) => v1.num_tensor as u64,
696            Version::V2(v2) => v2.num_tensor,
697            Version::V3(v3) => v3.num_tensor,
698        }
699    }
700
701    /// Get the model family of the GGUF file.
702    pub fn model_family(&self) -> String {
703        let arch = self
704            .kv
705            .get("general.architecture")
706            .cloned()
707            .unwrap_or(Value::from("unknown"));
708
709        match arch {
710            Value::String(arch) => arch,
711            _ => String::from("unknown"),
712        }
713    }
714
715    /// Get the number of parameters in the GGUF file.
716    pub fn model_parameters(&self) -> String {
717        if self.parameters > 0 {
718            human_number(self.parameters)
719        } else {
720            String::from("unknown")
721        }
722    }
723
724    /// Get the file type of the GGUF file.
725    pub fn file_type(&self) -> String {
726        if let Some(ft) = self.kv.get("general.file_type") {
727            file_type(ft.as_u64().unwrap())
728        } else {
729            String::from("unknown")
730        }
731    }
732
733    /// Get the key-value metadata of the GGUF file.
734    pub fn metadata(&self) -> &BTreeMap<String, Value> {
735        &self.kv
736    }
737
738    /// Get the tensors of the GGUF file.
739    pub fn tensors(&self) -> &Vec<Tensor> {
740        &self.tensors
741    }
742}
743
744/// Get a `GGUFContainer` from a file.
745pub fn get_gguf_container(file: &str) -> Result<GGUFContainer> {
746    if !std::path::Path::new(file).exists() {
747        return Err(anyhow!("file not found"));
748    }
749
750    let mut reader = std::fs::File::open(file)?;
751    let byte_le = reader.read_i32::<LittleEndian>()?;
752    match byte_le {
753        FILE_MAGIC_GGML => Err(anyhow!("unsupport ggml format")),
754        FILE_MAGIC_GGMF => Err(anyhow!("unsupport ggmf format")),
755        FILE_MAGIC_GGJT => Err(anyhow!("unsupport ggjt format")),
756        FILE_MAGIC_GGLA => Err(anyhow!("unsupport ggla format")),
757        FILE_MAGIC_GGUF_LE => Ok(GGUFContainer::new(ByteOrder::LE, Box::new(reader))),
758        FILE_MAGIC_GGUF_BE => Ok(GGUFContainer::new(ByteOrder::BE, Box::new(reader))),
759        _ => Err(anyhow!("invalid file magic")),
760    }
761}
762
#[cfg(test)]
mod tests {
    use super::get_gguf_container;
    use serde_json::json;

    /// Decodes the little-endian v3 fixture and checks the summary accessors and the full
    /// metadata map.
    #[test]
    fn test_read_le_v3_gguf() {
        let model = get_gguf_container("tests/test-le-v3.gguf")
            .and_then(|mut container| container.decode())
            .unwrap();

        let expected_kv = json!({
            "general.architecture": "llama",
            "llama.block_count": 12,
            "general.alignment": 64,
            "answer": 42,
            "answer_in_float": 42.0
        });

        assert_eq!(model.get_version(), "v3");
        assert_eq!(model.model_family(), "llama");
        assert_eq!(model.file_type(), "unknown");
        assert_eq!(model.model_parameters(), "192");
        assert_eq!(serde_json::to_value(model.kv).unwrap(), expected_kv);
    }
}