use std::{
error::Error,
io::{BufRead, Seek, SeekFrom},
};
use crate::{
util::{has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32},
ContainerType, ElementType,
};
/// Errors that can occur while loading a model file.
///
/// `E` is the error type produced by the caller-supplied [`LoadHandler`];
/// handler failures are surfaced via [`LoadError::ImplementationError`].
#[derive(Debug, thiserror::Error)]
pub enum LoadError<E: Error> {
    /// The file's leading magic number matched none of the known container
    /// formats.
    #[error("invalid file magic number: {0}")]
    InvalidMagic(u32),
    /// A versioned container reported a format version this loader does not
    /// support.
    #[error("invalid ggml format: format={0:?} version={1}")]
    InvalidFormatVersion(ContainerType, u32),
    /// An underlying I/O operation failed.
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    /// A length-prefixed string in the file was not valid UTF-8.
    #[error("could not convert bytes to a UTF-8 string")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    /// A value read from the file did not fit in the target integer type.
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    /// The [`LoadHandler`] returned an error from one of its callbacks.
    #[error("implementation error")]
    ImplementationError(#[source] E),
    /// A tensor header specified an element-type id that could not be
    /// decoded.
    #[error("unsupported tensor type {ftype} for tensor {tensor_name}")]
    UnsupportedElementType {
        // Name of the tensor whose element type could not be decoded.
        tensor_name: String,
        // The raw element-type id as stored in the file.
        ftype: u32,
    },
    /// A structural invariant of the file format was violated; the message
    /// describes the expected condition.
    #[error("invariant broken: {0}")]
    InvariantBroken(String),
}
/// Metadata describing a single tensor as read from a model file's tensor
/// header; the tensor's data itself is not loaded.
#[derive(Debug, Clone)]
pub struct TensorLoadInfo {
    /// The tensor's name.
    pub name: String,
    /// Number of dimensions actually in use; only the first `n_dims`
    /// entries of `dims` are meaningful.
    pub n_dims: usize,
    /// The tensor's dimensions (unused trailing slots are left at 1).
    pub dims: [usize; 2],
    /// Total number of elements (product of the used dimensions).
    pub n_elements: usize,
    /// The element type of the tensor's data.
    pub element_type: ElementType,
    /// Absolute byte offset of the tensor's data within the file.
    pub start_offset: u64,
}
impl TensorLoadInfo {
    /// Returns the dimensions that are actually in use (the backing array
    /// always has two slots, but only the first `n_dims` are meaningful).
    pub fn dims(&self) -> &[usize] {
        &self.dims[0..self.n_dims]
    }

    /// Returns the size of the tensor's data in bytes, accounting for
    /// element types that pack multiple elements into one block.
    pub fn calc_size(&self) -> usize {
        data_size(self.element_type, self.dims().iter().product())
    }

    /// Seeks to `start_offset` and reads the tensor's raw data.
    ///
    /// # Errors
    /// Returns any I/O error encountered while seeking or reading.
    pub fn read_data<R: BufRead + Seek>(&self, reader: &mut R) -> std::io::Result<Vec<u8>> {
        // Use `calc_size` so block-quantized element types read exactly the
        // bytes the tensor occupies. `n_elements * type_size` would ignore
        // the block size and over-read into the next tensor's data, since
        // the loader positions tensors using `calc_size` (see load_weights).
        let n_bytes = self.calc_size();
        let mut data = vec![0; n_bytes];
        reader.seek(SeekFrom::Start(self.start_offset))?;
        reader.read_exact(&mut data)?;
        Ok(data)
    }
}
/// Computes the byte size of `n_elements` elements of `element_type`,
/// accounting for element types that pack several elements into one block.
pub(crate) fn data_size(element_type: ElementType, n_elements: usize) -> usize {
    let bytes_per_block = crate::type_size(element_type);
    let elements_per_block = crate::blck_size(element_type);
    (bytes_per_block * n_elements) / elements_per_block
}
/// The subset of hyperparameters the loader itself needs in order to
/// continue parsing; the handler extracts these in
/// [`LoadHandler::read_hyperparameters`].
#[derive(Debug, Clone)]
pub struct PartialHyperparameters {
    /// Number of vocabulary entries that follow the hyperparameters.
    pub n_vocab: usize,
}
/// Callbacks invoked by [`load`] as each section of the file is parsed.
///
/// `E` is the implementor's error type; any error returned from a callback
/// aborts loading and is wrapped in [`LoadError::ImplementationError`].
pub trait LoadHandler<E: Error> {
    /// Called once with the container format detected from the file's magic
    /// number.
    fn container_type(&mut self, container_type: ContainerType) -> Result<(), E>;
    /// Called once per vocabulary entry with its index `i`, raw token bytes,
    /// and score (0.0 for unversioned GGML files, which store no scores).
    fn vocabulary_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> Result<(), E>;
    /// Called once to parse the model's hyperparameters from `reader`;
    /// must consume them and return the fields the loader needs to continue.
    fn read_hyperparameters(
        &mut self,
        reader: &mut dyn BufRead,
    ) -> Result<PartialHyperparameters, E>;
    /// Called once per tensor header with the tensor's metadata; the tensor
    /// data itself is skipped by the loader and can be fetched later via
    /// [`TensorLoadInfo::read_data`].
    fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), E>;
}
/// Loads a GGML-family model file from `reader`, driving `handler`'s
/// callbacks as the container header, hyperparameters, vocabulary, and
/// tensor headers are parsed.
///
/// # Errors
/// Fails on an unrecognized magic number or format version, on any I/O or
/// decoding error, or when the handler returns an error from a callback.
pub fn load<E: Error, R: BufRead + Seek>(
    reader: &mut R,
    handler: &mut impl LoadHandler<E>,
) -> Result<(), LoadError<E>> {
    // Identify the container from the leading magic number.
    let magic = read_u32(reader)?;
    let container_type = match magic {
        crate::FILE_MAGIC_GGMF => ContainerType::Ggmf,
        crate::FILE_MAGIC_GGJT => ContainerType::Ggjt,
        crate::FILE_MAGIC_UNVERSIONED => ContainerType::Ggml,
        other => return Err(LoadError::InvalidMagic(other)),
    };
    handler
        .container_type(container_type)
        .map_err(LoadError::ImplementationError)?;

    // Versioned containers carry an explicit format version right after the
    // magic number; unversioned GGML has none.
    if let ContainerType::Ggmf | ContainerType::Ggjt = container_type {
        let version = read_u32(reader)?;
        if version != crate::FORMAT_VERSION {
            return Err(LoadError::InvalidFormatVersion(container_type, version));
        }
    }

    let hparams = handler
        .read_hyperparameters(reader)
        .map_err(LoadError::ImplementationError)?;

    // Vocabulary: each entry is a length-prefixed token, followed by a score
    // in versioned containers (unversioned GGML stores no scores).
    for index in 0..hparams.n_vocab {
        let token_len = read_u32(reader)?.try_into()?;
        let token = read_bytes_with_len(reader, token_len)?;
        let score = if let ContainerType::Ggml = container_type {
            0.
        } else {
            read_f32(reader)?
        };
        handler
            .vocabulary_token(index, token, score)
            .map_err(LoadError::ImplementationError)?;
    }

    // Only GGJT aligns each tensor's data to an alignment boundary.
    let aligned = matches!(container_type, ContainerType::Ggjt);
    load_weights(reader, handler, aligned)
}
/// Reads tensor headers from `reader` until end-of-file, reporting each
/// tensor's metadata to `handler` and seeking past its data.
///
/// When `align` is set (GGJT containers), each tensor's data starts at the
/// next 32-byte boundary after its header.
fn load_weights<E: Error, R: BufRead + Seek>(
    reader: &mut R,
    handler: &mut impl LoadHandler<E>,
    align: bool,
) -> Result<(), LoadError<E>> {
    while has_data_left(reader)? {
        // Header layout: dimension count, name length, element-type id,
        // then the dimensions themselves, then the name bytes.
        let n_dims: usize = read_i32(reader)?.try_into()?;
        let name_len = read_i32(reader)?;
        let raw_ftype = read_u32(reader)?;

        let mut dims = [1usize, 1];
        let ne_len = dims.len();
        if n_dims > ne_len {
            return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}")));
        }

        // Read the used dimensions and accumulate the element count.
        let mut n_elements: usize = 1;
        for slot in dims.iter_mut().take(n_dims) {
            let dim: usize = read_i32(reader)?.try_into()?;
            *slot = dim;
            n_elements *= dim;
        }

        let name = String::from_utf8(read_bytes_with_len(reader, name_len.try_into()?)?)?;
        let element_type =
            crate::Type::try_from(raw_ftype).map_err(|_| LoadError::UnsupportedElementType {
                tensor_name: name.clone(),
                ftype: raw_ftype,
            })?;

        // 4-bit quantized types require the innermost dimension to be a
        // multiple of 64.
        if matches!(element_type, ElementType::Q4_0 | ElementType::Q4_1) && dims[0] % 64 != 0 {
            return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0")));
        }

        // Data begins either right after the header or, when aligning,
        // rounded up to the next 32-byte boundary.
        let offset_curr = reader.stream_position()?;
        let start_offset: u64 = if align {
            (offset_curr + 31) & !31
        } else {
            offset_curr
        };

        let tensor_info = TensorLoadInfo {
            name,
            dims,
            n_dims,
            n_elements,
            element_type,
            start_offset,
        };
        let n_bytes = tensor_info.calc_size();
        handler
            .tensor_buffer(tensor_info)
            .map_err(LoadError::ImplementationError)?;
        // Skip over the tensor's data to the next header.
        reader.seek(SeekFrom::Start(start_offset + n_bytes as u64))?;
    }
    Ok(())
}