llm_base/loader.rs

use std::{
    collections::HashMap,
    fmt::{Display, Formatter},
    fs::File,
    io::{BufRead, BufReader, Read, Seek, SeekFrom},
    path::{Path, PathBuf},
};

use crate::{
    util::{self, FindAllModelFilesError},
    Hyperparameters, KnownModel, ModelParameters, TokenId, Vocabulary,
};
pub use ggml::ContainerType;
use ggml::{
    format::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo},
    Context,
};
use memmap2::Mmap;
use thiserror::Error;

/// How the tensors are stored in GGML LLM models.
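///
/// A minimal sketch of the round trip between a `FileType` and its on-disk
/// `i32` representation, using the conversions defined below (marked `ignore`
/// since the public re-export path may vary):
///
/// ```ignore
/// assert_eq!(i32::from(FileType::MostlyQ4_0), 2);
/// assert_eq!(FileType::try_from(2), Ok(FileType::MostlyQ4_0));
/// assert_eq!(FileType::MostlyQ4_0.to_string(), "q4_0");
/// ```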
#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)]
pub enum FileType {
    /// All tensors are stored as `f32`.
    F32,
    /// All tensors are mostly stored as `f16`, except for the 1D tensors (32-bit).
    #[default]
    MostlyF16,
    /// All tensors are mostly stored as `Q4_0`, except for the 1D tensors (32-bit).
    MostlyQ4_0,
    /// All tensors are mostly stored as `Q4_1`, except for the 1D tensors (32-bit).
    MostlyQ4_1,
    /// All tensors are mostly stored as `Q4_1`, except for the 1D tensors (32-bit)
    /// and the `tok_embeddings.weight` (f16) and `output.weight` tensors (f16).
    MostlyQ4_1SomeF16,
    /// All tensors are mostly stored as `Q4_2`, except for the 1D tensors (32-bit).
    MostlyQ4_2,
    /// All tensors are mostly stored as `Q8_0`, except for the 1D tensors (32-bit).
    MostlyQ8_0,
    /// All tensors are mostly stored as `Q5_0`, except for the 1D tensors (32-bit).
    MostlyQ5_0,
    /// All tensors are mostly stored as `Q5_1`, except for the 1D tensors (32-bit).
    MostlyQ5_1,
}
impl From<FileType> for i32 {
    fn from(value: FileType) -> Self {
        match value {
            FileType::F32 => 0,
            FileType::MostlyF16 => 1,
            FileType::MostlyQ4_0 => 2,
            FileType::MostlyQ4_1 => 3,
            FileType::MostlyQ4_1SomeF16 => 4,
            FileType::MostlyQ4_2 => 5,
            FileType::MostlyQ8_0 => 7,
            FileType::MostlyQ5_0 => 8,
            FileType::MostlyQ5_1 => 9,
        }
    }
}
impl TryFrom<i32> for FileType {
    type Error = ();

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(FileType::F32),
            1 => Ok(FileType::MostlyF16),
            2 => Ok(FileType::MostlyQ4_0),
            3 => Ok(FileType::MostlyQ4_1),
            4 => Ok(FileType::MostlyQ4_1SomeF16),
            5 => Ok(FileType::MostlyQ4_2),
            7 => Ok(FileType::MostlyQ8_0),
            8 => Ok(FileType::MostlyQ5_0),
            9 => Ok(FileType::MostlyQ5_1),
            _ => Err(()),
        }
    }
}
impl Display for FileType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            FileType::F32 => write!(f, "f32"),
            FileType::MostlyF16 => write!(f, "f16"),
            FileType::MostlyQ4_0 => write!(f, "q4_0"),
            FileType::MostlyQ4_1 => write!(f, "q4_1"),
            FileType::MostlyQ4_1SomeF16 => write!(f, "q4_1_with_f16"),
            FileType::MostlyQ4_2 => write!(f, "q4_2"),
            FileType::MostlyQ8_0 => write!(f, "q8_0"),
            FileType::MostlyQ5_0 => write!(f, "q5_0"),
            FileType::MostlyQ5_1 => write!(f, "q5_1"),
        }
    }
}

/// Each variant represents a step within the process of loading the model.
/// These can be used to report progress to the user.
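///
/// For example, a caller might log tensor progress with a closure along these
/// lines (a sketch; the closure would be passed as a `load_progress_callback`):
///
/// ```ignore
/// let mut callback = |progress: LoadProgress| {
///     if let LoadProgress::TensorLoaded { current_tensor, tensor_count } = progress {
///         println!("loaded tensor {}/{}", current_tensor + 1, tensor_count);
///     }
/// };
/// ```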
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum LoadProgress {
    /// The hyperparameters have been loaded from the model.
    HyperparametersLoaded,
    /// The context has been created.
    ContextSize {
        /// The size of the context, in bytes.
        bytes: usize,
    },
    /// A tensor from the current part has been loaded.
    TensorLoaded {
        /// The current tensor (0-indexed).
        current_tensor: usize,
        /// The total number of tensors.
        tensor_count: usize,
    },
    /// A model part has finished fully loading.
    Loaded {
        /// The number of bytes in the part.
        file_size: u64,
        /// The number of tensors in the part.
        tensor_count: usize,
    },
}

#[derive(Error, Debug)]
/// Errors encountered during the loading process.
pub enum LoadError {
    #[error("could not open file {path:?}")]
    /// A file failed to open.
    OpenFileFailed {
        /// The original error.
        source: std::io::Error,
        /// The path that failed.
        path: PathBuf,
    },
    #[error("no parent path for {path:?}")]
    /// There is no parent path for a given path.
    NoParentPath {
        /// The path without a parent.
        path: PathBuf,
    },
    #[error("unable to read exactly {bytes} bytes")]
    /// Reading exactly `bytes` from a file failed.
    ReadExactFailed {
        /// The original error.
        source: std::io::Error,
        /// The number of bytes that were attempted to be read.
        bytes: usize,
    },
    #[error("non-specific I/O error")]
    /// A non-specific I/O error.
    Io(#[from] std::io::Error),
    #[error("could not convert bytes to a UTF-8 string")]
    /// One of the strings encountered was not valid UTF-8.
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    #[error("invalid integer conversion")]
    /// One of the integers encountered could not be converted to a more appropriate type.
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    #[error("unsupported ftype: {0}")]
    /// The `ftype` hyperparameter had an invalid value.
    UnsupportedFileType(i32),
    #[error("invalid magic number {magic:#x} for {path:?}")]
    /// An invalid magic number was encountered during the loading process.
    InvalidMagic {
        /// The path that failed.
        path: PathBuf,
        /// The magic number that was encountered.
        magic: u32,
    },
    #[error("invalid file format version {version}")]
    /// The version of the format is not supported by this version of `llm`.
    InvalidFormatVersion {
        /// The format that was encountered.
        container_type: ContainerType,
        /// The version that was encountered.
        version: u32,
    },
    #[error("invalid value {ftype} for `f16` in hyperparameters")]
    /// The `f16` hyperparameter had an invalid value.
    HyperparametersF16Invalid {
        /// The format type that was encountered.
        ftype: i32,
    },
    #[error("unknown tensor `{tensor_name}` in {path:?}")]
    /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen
    /// during the model prelude.
    UnknownTensor {
        /// The name of the tensor.
        tensor_name: String,
        /// The path that failed.
        path: PathBuf,
    },
    #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")]
    /// The tensor `tensor_name` did not match its expected size.
    TensorWrongSize {
        /// The name of the tensor.
        tensor_name: String,
        /// The path that failed.
        path: PathBuf,
    },
    /// The tensor `tensor_name` did not have the expected format type.
    #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")]
    UnsupportedElementType {
        /// The name of the tensor.
        tensor_name: String,
        /// The format type that was encountered.
        ftype: u32,
        /// The path that failed.
        path: PathBuf,
    },
    /// An invariant was broken.
    ///
    /// This error is not relevant unless `loader2` is being used.
    #[error("invariant broken: {invariant} in {path:?}")]
    InvariantBroken {
        /// The path that failed.
        path: Option<PathBuf>,
        /// The invariant that was broken.
        invariant: String,
    },
    /// The model could not be created.
    ///
    /// This implies that there were no tensors in the model to be loaded.
    ///
    /// This error is not relevant unless `loader2` is being used.
    #[error("could not create model from {path:?}")]
    ModelNotCreated {
        /// The path that failed.
        path: PathBuf,
    },
    /// Multiple parts of the model were found.
    ///
    /// Multi-part models are not supported. Please convert the model to a single part.
    #[error("multipart models are not supported")]
    MultipartNotSupported {
        /// The paths that were found.
        paths: Vec<PathBuf>,
    },
}
impl From<FindAllModelFilesError> for LoadError {
    fn from(value: FindAllModelFilesError) -> Self {
        match value {
            FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path },
            FindAllModelFilesError::IO(err) => LoadError::Io(err),
        }
    }
}

impl LoadError {
    #[doc(hidden)]
    pub fn from_format_error(value: FormatLoadError<LoadError>, path: PathBuf) -> Self {
        match value {
            FormatLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { path, magic },
            FormatLoadError::InvalidFormatVersion(container_type, version) => {
                LoadError::InvalidFormatVersion {
                    container_type,
                    version,
                }
            }
            FormatLoadError::Io(err) => LoadError::Io(err),
            FormatLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err),
            FormatLoadError::InvalidIntegerConversion(err) => {
                LoadError::InvalidIntegerConversion(err)
            }
            FormatLoadError::ImplementationError(err) => err,
            FormatLoadError::UnsupportedElementType { tensor_name, ftype } => {
                LoadError::UnsupportedElementType {
                    path,
                    tensor_name,
                    ftype,
                }
            }
            FormatLoadError::InvariantBroken(invariant) => LoadError::InvariantBroken {
                path: Some(path),
                invariant,
            },
        }
    }
}

/// Used by models to fetch tensors from a loader.
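///
/// A typical `KnownModel::new` implementation pulls its weights through this
/// trait and then extracts the loader state, roughly like this sketch (the
/// tensor name is illustrative):
///
/// ```ignore
/// let wte = tensor_loader.load("tok_embeddings.weight")?;
/// let (context, tensors, mmap) = tensor_loader.finish();
/// ```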
pub trait TensorLoader<E: std::error::Error> {
    /// Loads the named tensor, using the dimensions recorded in the model file.
    fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
    /// Loads the named tensor with the caller-specified dimensions `ne`.
    fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, E>;
    /// Finishes loading the model, and extracts all of the state from the loader.
    fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>);
}

/// Load a GGML model from `path` and configure it per `params`. The status
/// of the loading process will be reported through `load_progress_callback`.
///
/// Note that the model must be a single-part model, and the model in `path`
/// *must* match the architecture of `M`.
///
/// # Panics
///
/// - If the model does not match the architecture of `M`. This is not checked
///   before execution, so this function will panic partway through loading if
///   the architectures do not match.
///
///   This is a limitation of the GGML format, which does not
///   store any information about the architecture.
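///
/// # Example
///
/// A minimal sketch of calling this function. `llm_llama::Llama` stands in for
/// a concrete `KnownModel` implementation, the path is hypothetical, and
/// `ModelParameters` is assumed to implement `Default`:
///
/// ```ignore
/// use std::path::Path;
///
/// let model: llm_llama::Llama = llm_base::load(
///     Path::new("models/ggml-model-q4_0.bin"),
///     Default::default(),
///     llm_base::load_progress_callback_stdout,
/// )?;
/// ```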
pub fn load<M: KnownModel>(
    path: &Path,
    params: ModelParameters,
    load_progress_callback: impl FnMut(LoadProgress),
) -> Result<M, LoadError> {
    let paths = util::find_all_model_files(path)?;
    if paths.len() != 1 {
        return Err(LoadError::MultipartNotSupported { paths });
    }

    let file = File::open(path).map_err(|e| LoadError::OpenFileFailed {
        source: e,
        path: path.to_owned(),
    })?;
    let mut reader = BufReader::new(&file);

    let mut loader = Loader::new(load_progress_callback);

    ggml::format::load(&mut reader, &mut loader)
        .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;

    let Loader {
        hyperparameters,
        vocabulary,
        tensors,
        mut load_progress_callback,
        container_type,
        ..
    } = loader;

    let use_mmap = params.prefer_mmap && container_type.support_mmap();

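    // Reserve context space for each tensor's ggml object header; the tensor
    // data itself is only counted when it will be copied into the context
    // (i.e. when mmap is not in use).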
    let ctx_size = tensors
        .values()
        .map(|ti| {
            ggml::Tensor::C_TYPE_SIZE
                + ggml::OBJECT_SIZE
                + if use_mmap { 0 } else { ti.calc_size() }
        })
        .sum::<usize>();
    (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size });
    let context = Context::init(ctx_size, !use_mmap);

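    // Reopen the file to measure it and, if requested, map it into memory;
    // with mmap, tensor data is read directly from the mapping rather than
    // copied into the context.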
    let (mmap, file_size) = {
        let file = File::open(path)?;
        let mmap = if use_mmap {
            Some(unsafe { Mmap::map(&file)? })
        } else {
            None
        };
        (mmap, file.metadata()?.len())
    };

    struct MmapCompatibleLoader<'a> {
        path: PathBuf,
        file: File,
        tensors: HashMap<String, TensorLoadInfo>,
        context: Context,
        mmap: Option<Mmap>,
        load_progress_callback: &'a mut dyn FnMut(LoadProgress),
        loaded_tensors: HashMap<String, ggml::Tensor>,
    }
    impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
        fn load(&mut self, name: &str) -> Result<ggml::Tensor, LoadError> {
            let tensor_dims = self
                .tensors
                .get(name)
                .map(|tensor| tensor.dims().to_vec())
                .ok_or_else(|| LoadError::UnknownTensor {
                    tensor_name: String::from(name),
                    path: self.path.clone(),
                })?;
            self.load_manual(name, &tensor_dims)
        }

        fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, LoadError> {
            let info = self
                .tensors
                .get(name)
                .ok_or_else(|| LoadError::UnknownTensor {
                    path: self.path.clone(),
                    tensor_name: name.to_owned(),
                })?;

            let dims = ne.len();
            if dims != info.n_dims {
                return Err(LoadError::InvariantBroken {
                    path: Some(self.path.clone()),
                    invariant: format!(
                        "the tensor {name} should have {} dimensions, not {dims}",
                        info.n_dims
                    ),
                });
            }

            let ctx = &self.context;
            let mut tensor = match dims {
                1 => ctx.new_tensor_1d(info.element_type, ne[0]),
                2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
                3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
                _ => {
                    return Err(LoadError::InvariantBroken {
                        path: Some(self.path.clone()),
                        invariant: format!(
                            "the tensor {name} had an unsupported dimension count: {ne:?}"
                        ),
                    })
                }
            };

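            // Populate the tensor: with mmap, point it directly at the mapped
            // file; otherwise, read the tensor data from disk into the buffer
            // that ggml allocated inside the context.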
            match self.mmap.as_ref() {
                Some(mmap) => unsafe {
                    let ptr = mmap.as_ptr().offset(info.start_offset as isize);
                    tensor.set_data(ptr as *mut std::ffi::c_void);
                },
                None => {
                    let buf: &mut [u8] = unsafe {
                        std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes())
                    };
                    self.file.seek(SeekFrom::Start(info.start_offset))?;
                    self.file.read_exact(buf)?;
                }
            }

            self.loaded_tensors.insert(name.to_owned(), tensor.share());
            (self.load_progress_callback)(LoadProgress::TensorLoaded {
                // `len() - 1` keeps this 0-indexed, matching the
                // `LoadProgress::TensorLoaded` documentation.
                current_tensor: self.loaded_tensors.len() - 1,
                tensor_count: self.tensors.len(),
            });

            Ok(tensor)
        }

        fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>) {
            (self.context, self.loaded_tensors, self.mmap)
        }
    }

    let tensors_len = tensors.len();
    let tl = MmapCompatibleLoader {
        path: path.to_owned(),
        file,
        tensors,
        context,
        mmap,
        load_progress_callback: &mut load_progress_callback,
        loaded_tensors: Default::default(),
    };

    let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?;

    (load_progress_callback)(LoadProgress::Loaded {
        file_size,
        tensor_count: tensors_len,
    });

    Ok(model)
}

/// A GGML format loader for LLMs.
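///
/// It can also be driven directly (via [`ggml::format::load`]) to inspect a
/// model file without building a full model. A sketch, where `Hp` stands in
/// for a concrete `Hyperparameters` type and the path is hypothetical:
///
/// ```ignore
/// let mut loader: Loader<Hp, _> = Loader::new(|_| {});
/// let mut reader = std::io::BufReader::new(std::fs::File::open("model.bin")?);
/// ggml::format::load(&mut reader, &mut loader)?;
/// println!("{} tensors, {:?}", loader.tensors.len(), loader.container_type);
/// ```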
pub struct Loader<Hp: Hyperparameters, F: FnMut(LoadProgress)> {
    // Input
    load_progress_callback: F,

    // Output
    /// The container type of the model.
    pub container_type: ContainerType,
    /// The hyperparameters of the model.
    pub hyperparameters: Hp,
    /// The vocabulary of the model.
    pub vocabulary: Vocabulary,
    /// The tensors of the model.
    pub tensors: HashMap<String, TensorLoadInfo>,
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> Loader<Hp, F> {
    /// Creates a new loader.
    pub fn new(load_progress_callback: F) -> Self {
        Self {
            load_progress_callback,

            container_type: ContainerType::Ggjt,
            hyperparameters: Hp::default(),
            vocabulary: Vocabulary::default(),
            tensors: HashMap::default(),
        }
    }
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> ggml::format::LoadHandler<LoadError>
    for Loader<Hp, F>
{
    fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> {
        self.container_type = container_type;
        Ok(())
    }

    fn vocabulary_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> Result<(), LoadError> {
        let id = match TokenId::try_from(i) {
            Ok(id) => id,
            Err(err) => return Err(LoadError::InvalidIntegerConversion(err)),
        };
        self.vocabulary.push_token(id, token, score);

        Ok(())
    }

    fn read_hyperparameters(
        &mut self,
        reader: &mut dyn BufRead,
    ) -> Result<PartialHyperparameters, LoadError> {
        // NOTE: Field order matters! Data is laid out in the file exactly in this order.
        let hyperparameters = Hp::read_ggml(reader)?;
        let partial = PartialHyperparameters {
            n_vocab: hyperparameters.n_vocabulary(),
        };
        self.hyperparameters = hyperparameters;
        (self.load_progress_callback)(LoadProgress::HyperparametersLoaded);

        Ok(partial)
    }

    fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), LoadError> {
        self.tensors.insert(info.name.clone(), info);
        Ok(())
    }
}

/// An implementation for `load_progress_callback` that outputs to `stdout`.
pub fn load_progress_callback_stdout(progress: LoadProgress) {
    match progress {
        LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"),
        LoadProgress::ContextSize { bytes } => println!(
            "ggml ctx size = {:.2} MB",
            bytes as f64 / (1024.0 * 1024.0)
        ),
        LoadProgress::TensorLoaded {
            current_tensor,
            tensor_count,
            ..
        } => {
            let current_tensor = current_tensor + 1;
            if current_tensor % 8 == 0 {
                println!("Loaded tensor {current_tensor}/{tensor_count}");
            }
        }
        LoadProgress::Loaded {
            file_size: byte_size,
            tensor_count,
        } => {
            println!("Loading of model complete");
            println!(
                "Model size = {:.2} MB / num tensors = {}",
                byte_size as f64 / 1024.0 / 1024.0,
                tensor_count
            );
        }
    };
}