use std::{
    collections::HashMap,
    fmt::{Display, Formatter},
    fs::File,
    io::{BufRead, BufReader, Read, Seek, SeekFrom},
    path::{Path, PathBuf},
};

use crate::{
    util::{self, FindAllModelFilesError},
    Hyperparameters, KnownModel, ModelParameters, TokenId, Vocabulary,
};
pub use ggml::ContainerType;
use ggml::{
    format::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo},
    Context,
};
use memmap2::Mmap;
use thiserror::Error;

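/// How the tensors of a model are stored on disk (the model's `ftype`).
///
/// A rough sketch of the conversions implemented below; the `llm_base` import
/// path is an assumption about where this type is re-exported:
///
/// ```ignore
/// use llm_base::FileType;
///
/// assert_eq!(i32::from(FileType::MostlyQ4_0), 2);
/// assert_eq!(FileType::try_from(2), Ok(FileType::MostlyQ4_0));
/// assert_eq!(FileType::MostlyF16.to_string(), "f16");
/// ```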
#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)]
pub enum FileType {
    /// All tensors are stored as `f32`.
    F32,
    /// Tensors are mostly stored as `f16`.
    #[default]
    MostlyF16,
    /// Tensors are mostly stored as `Q4_0`.
    MostlyQ4_0,
    /// Tensors are mostly stored as `Q4_1`.
    MostlyQ4_1,
    /// Tensors are mostly stored as `Q4_1`, with some `f16` tensors.
    MostlyQ4_1SomeF16,
    /// Tensors are mostly stored as `Q4_2`.
    MostlyQ4_2,
    /// Tensors are mostly stored as `Q8_0`.
    MostlyQ8_0,
    /// Tensors are mostly stored as `Q5_0`.
    MostlyQ5_0,
    /// Tensors are mostly stored as `Q5_1`.
    MostlyQ5_1,
}
impl From<FileType> for i32 {
    fn from(value: FileType) -> Self {
        match value {
            FileType::F32 => 0,
            FileType::MostlyF16 => 1,
            FileType::MostlyQ4_0 => 2,
            FileType::MostlyQ4_1 => 3,
            FileType::MostlyQ4_1SomeF16 => 4,
            FileType::MostlyQ4_2 => 5,
            FileType::MostlyQ8_0 => 7,
            FileType::MostlyQ5_0 => 8,
            FileType::MostlyQ5_1 => 9,
        }
    }
}
impl TryFrom<i32> for FileType {
    type Error = ();

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(FileType::F32),
            1 => Ok(FileType::MostlyF16),
            2 => Ok(FileType::MostlyQ4_0),
            3 => Ok(FileType::MostlyQ4_1),
            4 => Ok(FileType::MostlyQ4_1SomeF16),
            5 => Ok(FileType::MostlyQ4_2),
            7 => Ok(FileType::MostlyQ8_0),
            8 => Ok(FileType::MostlyQ5_0),
            9 => Ok(FileType::MostlyQ5_1),
            _ => Err(()),
        }
    }
}
impl Display for FileType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            FileType::F32 => write!(f, "f32"),
            FileType::MostlyF16 => write!(f, "f16"),
            FileType::MostlyQ4_0 => write!(f, "q4_0"),
            FileType::MostlyQ4_1 => write!(f, "q4_1"),
            FileType::MostlyQ4_1SomeF16 => write!(f, "q4_1_with_f16"),
            FileType::MostlyQ4_2 => write!(f, "q4_2"),
            FileType::MostlyQ8_0 => write!(f, "q8_0"),
            FileType::MostlyQ5_0 => write!(f, "q5_0"),
            FileType::MostlyQ5_1 => write!(f, "q5_1"),
        }
    }
}

/// Progress reported while a model is being loaded.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum LoadProgress {
    /// The hyperparameters have been loaded.
    HyperparametersLoaded,
    /// The required size for the ggml context has been determined.
    ContextSize {
        /// The estimated size of the context, in bytes.
        bytes: usize,
    },
    /// A tensor has been loaded.
    TensorLoaded {
        /// How many tensors have been loaded so far.
        current_tensor: usize,
        /// The total number of tensors to load.
        tensor_count: usize,
    },
    /// The model has finished loading.
    Loaded {
        /// The size of the model file, in bytes.
        file_size: u64,
        /// The number of tensors that were loaded.
        tensor_count: usize,
    },
}

/// Errors that can occur while loading a model.
#[derive(Error, Debug)]
pub enum LoadError {
    #[error("could not open file {path:?}")]
    OpenFileFailed {
        source: std::io::Error,
        path: PathBuf,
    },
    #[error("no parent path for {path:?}")]
    NoParentPath {
        path: PathBuf,
    },
    #[error("unable to read exactly {bytes} bytes")]
    ReadExactFailed {
        source: std::io::Error,
        bytes: usize,
    },
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    #[error("could not convert bytes to a UTF-8 string")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    #[error("unsupported ftype: {0}")]
    UnsupportedFileType(i32),
    #[error("invalid magic number {magic:#x} for {path:?}")]
    InvalidMagic {
        path: PathBuf,
        magic: u32,
    },
    #[error("invalid file format version {version}")]
    InvalidFormatVersion {
        container_type: ContainerType,
        version: u32,
    },
    #[error("invalid value {ftype} for `f16` in hyperparameters")]
    HyperparametersF16Invalid {
        ftype: i32,
    },
    #[error("unknown tensor `{tensor_name}` in {path:?}")]
    UnknownTensor {
        tensor_name: String,
        path: PathBuf,
    },
    #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")]
    TensorWrongSize {
        tensor_name: String,
        path: PathBuf,
    },
    #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")]
    UnsupportedElementType {
        tensor_name: String,
        ftype: u32,
        path: PathBuf,
    },
    #[error("invariant broken: {invariant} in {path:?}")]
    InvariantBroken {
        path: Option<PathBuf>,
        invariant: String,
    },
    #[error("could not create model from {path:?}")]
    ModelNotCreated {
        path: PathBuf,
    },
    #[error("multipart models are not supported")]
    MultipartNotSupported {
        paths: Vec<PathBuf>,
    },
}
impl From<FindAllModelFilesError> for LoadError {
    fn from(value: FindAllModelFilesError) -> Self {
        match value {
            FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path },
            FindAllModelFilesError::IO(err) => LoadError::Io(err),
        }
    }
}

impl LoadError {
    #[doc(hidden)]
    pub fn from_format_error(value: FormatLoadError<LoadError>, path: PathBuf) -> Self {
        match value {
            FormatLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { path, magic },
            FormatLoadError::InvalidFormatVersion(container_type, version) => {
                LoadError::InvalidFormatVersion {
                    container_type,
                    version,
                }
            }
            FormatLoadError::Io(err) => LoadError::Io(err),
            FormatLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err),
            FormatLoadError::InvalidIntegerConversion(err) => {
                LoadError::InvalidIntegerConversion(err)
            }
            FormatLoadError::ImplementationError(err) => err,
            FormatLoadError::UnsupportedElementType { tensor_name, ftype } => {
                LoadError::UnsupportedElementType {
                    path,
                    tensor_name,
                    ftype,
                }
            }
            FormatLoadError::InvariantBroken(invariant) => LoadError::InvariantBroken {
                path: Some(path),
                invariant,
            },
        }
    }
}

/// Used by models to fetch their tensors from a loader.
pub trait TensorLoader<E: std::error::Error> {
    /// Loads the named tensor, using the dimensions recorded in the model file.
    fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
    /// Loads the named tensor with explicitly specified dimensions `ne`.
    fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, E>;
    /// Finishes loading, returning the ggml context, the loaded tensors, and the mmap (if any).
    fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>);
}

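/// Loads the specified GGML model from disk, reporting progress through
/// `load_progress_callback`.
///
/// A rough usage sketch; `MyModel` stands in for any [`KnownModel`]
/// implementation, and it is assumed (not established in this file) that
/// [`ModelParameters`] implements `Default`:
///
/// ```ignore
/// let model: MyModel = load(
///     std::path::Path::new("/path/to/model.bin"),
///     ModelParameters::default(),
///     load_progress_callback_stdout,
/// )?;
/// ```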
pub fn load<M: KnownModel>(
    path: &Path,
    params: ModelParameters,
    load_progress_callback: impl FnMut(LoadProgress),
) -> Result<M, LoadError> {
    let paths = util::find_all_model_files(path)?;
    if paths.len() != 1 {
        return Err(LoadError::MultipartNotSupported { paths });
    }

    let file = File::open(path).map_err(|e| LoadError::OpenFileFailed {
        source: e,
        path: path.to_owned(),
    })?;
    let mut reader = BufReader::new(&file);

    let mut loader = Loader::new(load_progress_callback);

    ggml::format::load(&mut reader, &mut loader)
        .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;

    let Loader {
        hyperparameters,
        vocabulary,
        tensors,
        mut load_progress_callback,
        container_type,
        ..
    } = loader;

    let use_mmap = params.prefer_mmap && container_type.support_mmap();

    let ctx_size = tensors
        .values()
        .map(|ti| {
            ggml::Tensor::C_TYPE_SIZE
                + ggml::OBJECT_SIZE
                + if use_mmap { 0 } else { ti.calc_size() }
        })
        .sum::<usize>();
    (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size });
    let context = Context::init(ctx_size, !use_mmap);

    let (mmap, file_size) = {
        let file = File::open(path)?;
        let mmap = if use_mmap {
            Some(unsafe { Mmap::map(&file)? })
        } else {
            None
        };
        (mmap, file.metadata()?.len())
    };

    struct MmapCompatibleLoader<'a> {
        path: PathBuf,
        file: File,
        tensors: HashMap<String, TensorLoadInfo>,
        context: Context,
        mmap: Option<Mmap>,
        load_progress_callback: &'a mut dyn FnMut(LoadProgress),
        loaded_tensors: HashMap<String, ggml::Tensor>,
    }
    impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
        fn load(&mut self, name: &str) -> Result<ggml::Tensor, LoadError> {
            let tensor_dims = self
                .tensors
                .get(name)
                .map(|tensor| tensor.dims().to_vec())
                .ok_or(LoadError::UnknownTensor {
                    tensor_name: String::from(name),
                    path: Default::default(),
                })?;
            self.load_manual(name, &tensor_dims)
        }

        fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, LoadError> {
            let info = self
                .tensors
                .get(name)
                .ok_or_else(|| LoadError::UnknownTensor {
                    path: self.path.clone(),
                    tensor_name: name.to_owned(),
                })?;

            let dims = ne.len();
            if dims != info.n_dims {
                return Err(LoadError::InvariantBroken {
                    path: Some(self.path.clone()),
                    invariant: format!(
                        "the tensor {name} should have {} dimensions, not {dims}",
                        info.n_dims
                    ),
                });
            }

            let ctx = &self.context;
            let mut tensor = match dims {
                1 => ctx.new_tensor_1d(info.element_type, ne[0]),
                2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
                3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
                _ => {
                    return Err(LoadError::InvariantBroken {
                        path: Some(self.path.clone()),
                        invariant: format!(
                            "the tensor {name} had an unsupported dimension count: {ne:?}"
                        ),
                    })
                }
            };

            match self.mmap.as_ref() {
                Some(mmap) => unsafe {
                    // With mmap, point the tensor directly at the mapped file contents.
                    let ptr = mmap.as_ptr().offset(info.start_offset as isize);
                    tensor.set_data(ptr as *mut std::ffi::c_void);
                },
                None => {
                    // Without mmap, read the tensor data from the file into the
                    // buffer allocated by the ggml context.
                    let buf: &mut [u8] = unsafe {
                        std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes())
                    };
                    self.file.seek(SeekFrom::Start(info.start_offset))?;
                    self.file.read_exact(buf)?;
                }
            }

            self.loaded_tensors.insert(name.to_owned(), tensor.share());
            (self.load_progress_callback)(LoadProgress::TensorLoaded {
                current_tensor: self.loaded_tensors.len(),
                tensor_count: self.tensors.len(),
            });

            Ok(tensor)
        }

        fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>) {
            (self.context, self.loaded_tensors, self.mmap)
        }
    }

    let tensors_len = tensors.len();
    let tl = MmapCompatibleLoader {
        path: path.to_owned(),
        file,
        tensors,
        context,
        mmap,
        load_progress_callback: &mut load_progress_callback,
        loaded_tensors: Default::default(),
    };

    let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?;

    (load_progress_callback)(LoadProgress::Loaded {
        file_size,
        tensor_count: tensors_len,
    });

    Ok(model)
}

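/// A [`ggml::format::LoadHandler`] that gathers a model's container type,
/// hyperparameters, vocabulary and tensor metadata as the file is read.
///
/// It can also be driven directly when only that metadata is of interest.
/// A rough sketch, where `MyHyperparameters` is a placeholder for a concrete
/// [`Hyperparameters`] implementation:
///
/// ```ignore
/// let path = std::path::Path::new("/path/to/model.bin");
/// let mut loader: Loader<MyHyperparameters, _> = Loader::new(|_progress| {});
/// let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);
/// ggml::format::load(&mut reader, &mut loader)
///     .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;
/// println!("found {} tensors", loader.tensors.len());
/// ```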
pub struct Loader<Hp: Hyperparameters, F: FnMut(LoadProgress)> {
    load_progress_callback: F,

    /// The container type of the model file.
    pub container_type: ContainerType,
    /// The hyperparameters of the model.
    pub hyperparameters: Hp,
    /// The vocabulary of the model.
    pub vocabulary: Vocabulary,
    /// Information about each tensor in the model file.
    pub tensors: HashMap<String, TensorLoadInfo>,
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> Loader<Hp, F> {
    /// Creates a new loader that reports progress through `load_progress_callback`.
    pub fn new(load_progress_callback: F) -> Self {
        Self {
            load_progress_callback,

            container_type: ContainerType::Ggjt,
            hyperparameters: Hp::default(),
            vocabulary: Vocabulary::default(),
            tensors: HashMap::default(),
        }
    }
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> ggml::format::LoadHandler<LoadError>
    for Loader<Hp, F>
{
    fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> {
        self.container_type = container_type;
        Ok(())
    }

    fn vocabulary_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> Result<(), LoadError> {
        let id = match TokenId::try_from(i) {
            Ok(id) => id,
            Err(err) => return Err(LoadError::InvalidIntegerConversion(err)),
        };
        self.vocabulary.push_token(id, token, score);

        Ok(())
    }

    fn read_hyperparameters(
        &mut self,
        reader: &mut dyn BufRead,
    ) -> Result<PartialHyperparameters, LoadError> {
        let hyperparameters = Hp::read_ggml(reader)?;
        let partial = PartialHyperparameters {
            n_vocab: hyperparameters.n_vocabulary(),
        };
        self.hyperparameters = hyperparameters;
        (self.load_progress_callback)(LoadProgress::HyperparametersLoaded);

        Ok(partial)
    }

    fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), LoadError> {
        self.tensors.insert(info.name.clone(), info);
        Ok(())
    }
}

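/// A basic [`LoadProgress`] callback that prints progress to stdout; it can be
/// passed directly as the `load_progress_callback` argument of [`load`].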
pub fn load_progress_callback_stdout(progress: LoadProgress) {
    match progress {
        LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"),
        LoadProgress::ContextSize { bytes } => println!(
            "ggml ctx size = {:.2} MB\n",
            bytes as f64 / (1024.0 * 1024.0)
        ),
        LoadProgress::TensorLoaded {
            current_tensor,
            tensor_count,
            ..
        } => {
            let current_tensor = current_tensor + 1;
            if current_tensor % 8 == 0 {
                println!("Loaded tensor {current_tensor}/{tensor_count}");
            }
        }
        LoadProgress::Loaded {
            file_size: byte_size,
            tensor_count,
        } => {
            println!("Loading of model complete");
            println!(
                "Model size = {:.2} MB / num tensors = {}",
                byte_size as f64 / 1024.0 / 1024.0,
                tensor_count
            );
        }
    };
}