use std::{
collections::HashMap,
fmt::{Display, Formatter},
fs::File,
io::{BufRead, BufReader, Read, Seek, SeekFrom},
path::{Path, PathBuf},
};
use crate::{
util::{self, FindAllModelFilesError},
Hyperparameters, KnownModel, ModelParameters, TokenId, Vocabulary,
};
pub use ggml::ContainerType;
use ggml::{
format::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo},
Context,
};
use memmap2::Mmap;
use thiserror::Error;
/// How the tensors of a model file are encoded ("ftype" in GGML terminology).
///
/// "Mostly" reflects that a file may mix encodings; the variant names the
/// dominant one (per the variant naming — exact per-tensor rules are decided
/// by the converter that wrote the file, not by this crate).
#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)]
pub enum FileType {
    /// All tensors are 32-bit floats.
    F32,
    /// Tensors are mostly 16-bit floats. This is the default.
    #[default]
    MostlyF16,
    /// Tensors are mostly `Q4_0`-encoded.
    MostlyQ4_0,
    /// Tensors are mostly `Q4_1`-encoded.
    MostlyQ4_1,
    /// Tensors are mostly `Q4_1`-encoded, with some 16-bit floats.
    MostlyQ4_1SomeF16,
    /// Tensors are mostly `Q4_2`-encoded.
    MostlyQ4_2,
    /// Tensors are mostly `Q8_0`-encoded.
    MostlyQ8_0,
    /// Tensors are mostly `Q5_0`-encoded.
    MostlyQ5_0,
    /// Tensors are mostly `Q5_1`-encoded.
    MostlyQ5_1,
}
impl From<FileType> for i32 {
fn from(value: FileType) -> Self {
match value {
FileType::F32 => 0,
FileType::MostlyF16 => 1,
FileType::MostlyQ4_0 => 2,
FileType::MostlyQ4_1 => 3,
FileType::MostlyQ4_1SomeF16 => 4,
FileType::MostlyQ4_2 => 5,
FileType::MostlyQ8_0 => 7,
FileType::MostlyQ5_0 => 8,
FileType::MostlyQ5_1 => 9,
}
}
}
impl TryFrom<i32> for FileType {
type Error = ();
fn try_from(value: i32) -> Result<Self, Self::Error> {
match value {
0 => Ok(FileType::F32),
1 => Ok(FileType::MostlyF16),
2 => Ok(FileType::MostlyQ4_0),
3 => Ok(FileType::MostlyQ4_1),
4 => Ok(FileType::MostlyQ4_1SomeF16),
5 => Ok(FileType::MostlyQ4_2),
7 => Ok(FileType::MostlyQ8_0),
8 => Ok(FileType::MostlyQ5_0),
9 => Ok(FileType::MostlyQ5_1),
_ => Err(()),
}
}
}
impl Display for FileType {
    /// Writes the lowercase ggml-style name of the encoding (e.g. `q4_0`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            FileType::F32 => "f32",
            FileType::MostlyF16 => "f16",
            FileType::MostlyQ4_0 => "q4_0",
            FileType::MostlyQ4_1 => "q4_1",
            FileType::MostlyQ4_1SomeF16 => "q4_1_with_f16",
            FileType::MostlyQ4_2 => "q4_2",
            FileType::MostlyQ8_0 => "q8_0",
            FileType::MostlyQ5_0 => "q5_0",
            FileType::MostlyQ5_1 => "q5_1",
        };
        write!(f, "{name}")
    }
}
/// Progress events reported through the callback while a model is loading.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum LoadProgress {
    /// The hyperparameters have been read from the file.
    HyperparametersLoaded,
    /// The required size for the ggml context has been computed.
    ContextSize {
        // Total context size, in bytes.
        bytes: usize,
    },
    /// A tensor's data has been made available to the model.
    TensorLoaded {
        // Running count of tensors loaded so far; the loader in this module
        // reports `loaded_tensors.len()` after insertion, i.e. 1-based.
        current_tensor: usize,
        // Total number of tensors described by the file.
        tensor_count: usize,
    },
    /// Loading completed successfully.
    Loaded {
        // Size of the model file on disk, in bytes.
        file_size: u64,
        // Total number of tensors described by the file.
        tensor_count: usize,
    },
}
#[derive(Error, Debug)]
pub enum LoadError {
#[error("could not open file {path:?}")]
OpenFileFailed {
source: std::io::Error,
path: PathBuf,
},
#[error("no parent path for {path:?}")]
NoParentPath {
path: PathBuf,
},
#[error("unable to read exactly {bytes} bytes")]
ReadExactFailed {
source: std::io::Error,
bytes: usize,
},
#[error("non-specific I/O error")]
Io(#[from] std::io::Error),
#[error("could not convert bytes to a UTF-8 string")]
InvalidUtf8(#[from] std::string::FromUtf8Error),
#[error("invalid integer conversion")]
InvalidIntegerConversion(#[from] std::num::TryFromIntError),
#[error("unsupported f16_: {0}")]
UnsupportedFileType(i32),
#[error("invalid magic number {magic:#x} for {path:?}")]
InvalidMagic {
path: PathBuf,
magic: u32,
},
#[error("invalid file format version {version}")]
InvalidFormatVersion {
container_type: ContainerType,
version: u32,
},
#[error("invalid value {ftype} for `f16` in hyperparameters")]
HyperparametersF16Invalid {
ftype: i32,
},
#[error("unknown tensor `{tensor_name}` in {path:?}")]
UnknownTensor {
tensor_name: String,
path: PathBuf,
},
#[error("the tensor `{tensor_name}` has the wrong size in {path:?}")]
TensorWrongSize {
tensor_name: String,
path: PathBuf,
},
#[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")]
UnsupportedElementType {
tensor_name: String,
ftype: u32,
path: PathBuf,
},
#[error("invariant broken: {invariant} in {path:?}")]
InvariantBroken {
path: Option<PathBuf>,
invariant: String,
},
#[error("could not create model from {path:?}")]
ModelNotCreated {
path: PathBuf,
},
#[error("multipart models are not supported")]
MultipartNotSupported {
paths: Vec<PathBuf>,
},
}
impl From<FindAllModelFilesError> for LoadError {
fn from(value: FindAllModelFilesError) -> Self {
match value {
FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path },
FindAllModelFilesError::IO(err) => LoadError::Io(err),
}
}
}
impl LoadError {
#[doc(hidden)]
pub fn from_format_error(value: FormatLoadError<LoadError>, path: PathBuf) -> Self {
match value {
FormatLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { path, magic },
FormatLoadError::InvalidFormatVersion(container_type, version) => {
LoadError::InvalidFormatVersion {
container_type,
version,
}
}
FormatLoadError::Io(err) => LoadError::Io(err),
FormatLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err),
FormatLoadError::InvalidIntegerConversion(err) => {
LoadError::InvalidIntegerConversion(err)
}
FormatLoadError::ImplementationError(err) => err,
FormatLoadError::UnsupportedElementType { tensor_name, ftype } => {
LoadError::UnsupportedElementType {
path,
tensor_name,
ftype,
}
}
FormatLoadError::InvariantBroken(invariant) => LoadError::InvariantBroken {
path: Some(path),
invariant,
},
}
}
}
/// Used by a model implementation to fetch its tensors during construction.
pub trait TensorLoader<E: std::error::Error> {
    /// Loads the tensor `name`, using the dimensions recorded in the file.
    fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
    /// Loads the tensor `name` with the caller-supplied dimensions `ne`.
    fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, E>;
    /// Consumes the loader, returning the ggml context, all loaded tensors,
    /// and the memory map backing them (if one was used).
    fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>);
}
/// Loads a GGML-format model from `path`.
///
/// Progress is reported through `load_progress_callback` (see
/// [`LoadProgress`]). When `params.prefer_mmap` is set and the container
/// format supports it, tensor data is memory-mapped from the file instead of
/// being copied into the ggml context.
///
/// # Errors
/// Returns [`LoadError::MultipartNotSupported`] for multi-file models, and
/// other [`LoadError`] variants for I/O or format problems.
pub fn load<M: KnownModel>(
    path: &Path,
    params: ModelParameters,
    load_progress_callback: impl FnMut(LoadProgress),
) -> Result<M, LoadError> {
    // Only single-file models are supported; reject multipart models up front.
    let paths = util::find_all_model_files(path)?;
    if paths.len() != 1 {
        return Err(LoadError::MultipartNotSupported { paths });
    }
    let file = File::open(path).map_err(|e| LoadError::OpenFileFailed {
        source: e,
        path: path.to_owned(),
    })?;
    let mut reader = BufReader::new(&file);
    // First pass: parse hyperparameters, vocabulary, and tensor metadata.
    let mut loader = Loader::new(load_progress_callback);
    ggml::format::load(&mut reader, &mut loader)
        .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;
    let Loader {
        hyperparameters,
        vocabulary,
        tensors,
        mut load_progress_callback,
        container_type,
        ..
    } = loader;
    let use_mmap = params.prefer_mmap && container_type.support_mmap();
    // Size the ggml context: per-tensor bookkeeping always lives in the
    // context, but tensor *data* only does when not memory-mapping.
    let ctx_size = tensors
        .values()
        .map(|ti| {
            ggml::Tensor::C_TYPE_SIZE
                + ggml::OBJECT_SIZE
                + if use_mmap { 0 } else { ti.calc_size() }
        })
        .sum::<usize>();
    (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size });
    let context = Context::init(ctx_size, !use_mmap);
    let (mmap, file_size) = {
        let file = File::open(path)?;
        let mmap = if use_mmap {
            // SAFETY: the map is kept alive for as long as the tensors that
            // point into it — it is stored in the loader below and handed
            // back to the caller via `finish()`. As with any file mapping,
            // external modification of the file while mapped is UB that
            // `Mmap::map` cannot prevent.
            Some(unsafe { Mmap::map(&file)? })
        } else {
            None
        };
        (mmap, file.metadata()?.len())
    };
    // Second-pass loader handed to `KnownModel::new`: materializes each
    // tensor either by pointing into the memory map or by reading its data
    // from the file into the ggml context.
    struct MmapCompatibleLoader<'a> {
        path: PathBuf,
        file: File,
        tensors: HashMap<String, TensorLoadInfo>,
        context: Context,
        mmap: Option<Mmap>,
        load_progress_callback: &'a mut dyn FnMut(LoadProgress),
        loaded_tensors: HashMap<String, ggml::Tensor>,
    }
    impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
        fn load(&mut self, name: &str) -> Result<ggml::Tensor, LoadError> {
            // Look up the dimensions recorded in the file, then defer to
            // `load_manual`. The error is constructed lazily and carries the
            // real model path (previously it was built eagerly with an
            // empty `Default::default()` path).
            let tensor_dims = self
                .tensors
                .get(name)
                .map(|tensor| tensor.dims().to_vec())
                .ok_or_else(|| LoadError::UnknownTensor {
                    tensor_name: String::from(name),
                    path: self.path.clone(),
                })?;
            self.load_manual(name, &tensor_dims)
        }
        fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, LoadError> {
            let info = self
                .tensors
                .get(name)
                .ok_or_else(|| LoadError::UnknownTensor {
                    path: self.path.clone(),
                    tensor_name: name.to_owned(),
                })?;
            // The requested dimension count must match the file's metadata.
            let dims = ne.len();
            if dims != info.n_dims {
                return Err(LoadError::InvariantBroken {
                    path: Some(self.path.clone()),
                    invariant: format!(
                        "the tensor {name} should have {} dimensions, not {dims}",
                        info.n_dims
                    ),
                });
            }
            let ctx = &self.context;
            // ggml only exposes fixed-arity constructors, so dispatch on the
            // dimension count; anything above 3-D is unsupported here.
            let mut tensor = match dims {
                1 => ctx.new_tensor_1d(info.element_type, ne[0]),
                2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
                3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
                _ => {
                    return Err(LoadError::InvariantBroken {
                        path: Some(self.path.clone()),
                        invariant: format!(
                            "the tensor {name} had an unsupported dimension count: {ne:?}"
                        ),
                    })
                }
            };
            match self.mmap.as_ref() {
                Some(mmap) => unsafe {
                    // SAFETY: the `Mmap` outlives the tensor (both are
                    // returned together from `finish()`). `start_offset`
                    // comes from the first parsing pass; it is assumed to
                    // lie within the mapped file — TODO(review): confirm it
                    // is validated upstream.
                    let ptr = mmap.as_ptr().offset(info.start_offset as isize);
                    tensor.set_data(ptr as *mut std::ffi::c_void);
                },
                None => {
                    // SAFETY: mmap is off, so the tensor was allocated with
                    // backing storage in the context; `data()` is assumed to
                    // point to `nbytes()` writable bytes owned by it.
                    let buf: &mut [u8] = unsafe {
                        std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes())
                    };
                    self.file.seek(SeekFrom::Start(info.start_offset))?;
                    self.file.read_exact(buf)?;
                }
            }
            self.loaded_tensors.insert(name.to_owned(), tensor.share());
            // `loaded_tensors.len()` after insertion is a 1-based running
            // count of tensors loaded so far.
            (self.load_progress_callback)(LoadProgress::TensorLoaded {
                current_tensor: self.loaded_tensors.len(),
                tensor_count: self.tensors.len(),
            });
            Ok(tensor)
        }
        fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>) {
            (self.context, self.loaded_tensors, self.mmap)
        }
    }
    let tensors_len = tensors.len();
    let tl = MmapCompatibleLoader {
        path: path.to_owned(),
        file,
        tensors,
        context,
        mmap,
        load_progress_callback: &mut load_progress_callback,
        loaded_tensors: Default::default(),
    };
    // Hand everything to the model implementation so it can wire up its tensors.
    let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?;
    (load_progress_callback)(LoadProgress::Loaded {
        file_size,
        tensor_count: tensors_len,
    });
    Ok(model)
}
/// A [ggml load handler](ggml::format::LoadHandler) that accumulates the
/// parsed contents of a model file: hyperparameters, vocabulary, and
/// per-tensor metadata.
pub struct Loader<Hp: Hyperparameters, F: FnMut(LoadProgress)> {
    // Invoked with progress events as the file is parsed.
    load_progress_callback: F,
    /// The container type of the file being loaded.
    pub container_type: ContainerType,
    /// The hyperparameters read from the file.
    pub hyperparameters: Hp,
    /// The vocabulary read from the file.
    pub vocabulary: Vocabulary,
    /// Metadata for each tensor in the file, keyed by tensor name.
    pub tensors: HashMap<String, TensorLoadInfo>,
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> Loader<Hp, F> {
    /// Creates a fresh loader that reports progress through
    /// `load_progress_callback`.
    ///
    /// All collected state starts out at its default value; the container
    /// type is [`ContainerType::Ggjt`] until one is read from the file.
    pub fn new(load_progress_callback: F) -> Self {
        Self {
            container_type: ContainerType::Ggjt,
            hyperparameters: Hp::default(),
            vocabulary: Vocabulary::default(),
            tensors: HashMap::default(),
            load_progress_callback,
        }
    }
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> ggml::format::LoadHandler<LoadError>
    for Loader<Hp, F>
{
    /// Records the container type reported by the format parser.
    fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> {
        self.container_type = container_type;
        Ok(())
    }

    /// Appends token `i` (with its `score`) to the vocabulary.
    fn vocabulary_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> Result<(), LoadError> {
        let id = TokenId::try_from(i).map_err(LoadError::InvalidIntegerConversion)?;
        self.vocabulary.push_token(id, token, score);
        Ok(())
    }

    /// Reads the model-specific hyperparameters and reports progress.
    fn read_hyperparameters(
        &mut self,
        reader: &mut dyn BufRead,
    ) -> Result<PartialHyperparameters, LoadError> {
        self.hyperparameters = Hp::read_ggml(reader)?;
        let n_vocab = self.hyperparameters.n_vocabulary();
        (self.load_progress_callback)(LoadProgress::HyperparametersLoaded);
        Ok(PartialHyperparameters { n_vocab })
    }

    /// Stores the metadata for one tensor, keyed by its name.
    fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), LoadError> {
        self.tensors.insert(info.name.clone(), info);
        Ok(())
    }
}
/// A ready-made progress callback that prints loading progress to stdout.
///
/// Tensor progress is printed once every 8 tensors to keep output manageable.
pub fn load_progress_callback_stdout(progress: LoadProgress) {
    match progress {
        LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"),
        LoadProgress::ContextSize { bytes } => println!(
            "ggml ctx size = {:.2} MB\n",
            bytes as f64 / (1024.0 * 1024.0)
        ),
        LoadProgress::TensorLoaded {
            current_tensor,
            tensor_count,
            ..
        } => {
            // `current_tensor` is already a 1-based running count: the loader
            // in this module reports `loaded_tensors.len()` *after* inserting
            // the tensor. The previous `current_tensor + 1` here double-
            // counted and could print "Loaded tensor N+1/N" for the last
            // tensor, so no further increment is applied.
            if current_tensor % 8 == 0 {
                println!("Loaded tensor {current_tensor}/{tensor_count}");
            }
        }
        LoadProgress::Loaded {
            file_size: byte_size,
            tensor_count,
        } => {
            println!("Loading of model complete");
            println!(
                "Model size = {:.2} MB / num tensors = {}",
                byte_size as f64 / 1024.0 / 1024.0,
                tensor_count
            );
        }
    }
}