use std::{
error::Error,
io::{Seek, Write},
};
use crate::{util, ElementType};
/// Errors that can occur while writing a model file with [`save`].
///
/// `E` is the error type produced by the caller-supplied [`SaveHandler`].
#[derive(Debug, thiserror::Error)]
pub enum SaveError<E>
where
    E: Error,
{
    /// An underlying I/O operation failed.
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    /// A length, count or dimension did not fit into the integer type used on disk.
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    /// The [`SaveHandler`] returned an error; it is exposed as this error's source.
    #[error("implementation error")]
    ImplementationError(#[source] E),
    /// Data supplied for saving violated a requirement of the file format.
    /// The message spells out the condition that was expected to hold.
    #[error("invariant broken: {0}")]
    InvariantBroken(String),
}
/// Supplies the model-specific parts of a file written by [`save`].
///
/// `E` is the implementation's own error type; [`save`] wraps it in
/// `SaveError::ImplementationError`.
pub trait SaveHandler<E>
where
    E: Error,
{
    /// Writes the model's hyperparameters to `writer`.
    ///
    /// [`save`] calls this once, right after the file magic and format version and
    /// before the vocabulary.
    fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), E>;

    /// Returns the header information and raw bytes for the tensor named `tensor_name`.
    ///
    /// [`save`] calls this once per entry of its `tensor_names`, in order.
    fn tensor_data(&mut self, tensor_name: &str) -> Result<TensorSaveInfo, E>;
}
/// Everything [`save`] needs to write a single tensor: its shape, element type and raw bytes.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorSaveInfo {
    /// Number of entries of `dims` that are in use. [`save`] writes `dims[..n_dims]`,
    /// so this must not exceed `dims.len()` (2).
    pub n_dims: usize,
    /// Size of each dimension; only the first `n_dims` entries are written.
    /// For `Q4_0`/`Q4_1` tensors, [`save`] requires `dims[0]` to be a multiple of 64.
    pub dims: [usize; 2],
    /// Element type of the tensor data, written into the tensor header.
    pub element_type: ElementType,
    /// Raw tensor bytes, written verbatim after the header and alignment padding.
    pub data: Vec<u8>,
}
/// Writes a model in the GGJT file format to `writer`.
///
/// The output consists of, in order:
///
/// 1. the GGJT file magic and the format version, each as a `u32`;
/// 2. the hyperparameters, as written by [`SaveHandler::write_hyperparameters`];
/// 3. every entry of `vocabulary`: the token's byte length (`u32`), the token bytes, and its
///    score (`f32`);
/// 4. for each name in `tensor_names`, in order: a header (`n_dims` and the name length as
///    `i32`, the element type as `u32`, the first `n_dims` dimensions as `i32`, then the name
///    bytes), zero padding up to the next 32-byte boundary of the stream, and the tensor data
///    obtained from [`SaveHandler::tensor_data`].
///
/// # Errors
///
/// - [`SaveError::Io`] if writing to `writer` or querying its position fails.
/// - [`SaveError::InvalidIntegerConversion`] if a length or dimension does not fit the
///   integer type used on disk.
/// - [`SaveError::ImplementationError`] if `handler` returns an error.
/// - [`SaveError::InvariantBroken`] if a tensor reports more dimensions than
///   [`TensorSaveInfo::dims`] can hold, or if a `Q4_0`/`Q4_1` tensor's first dimension is
///   not a multiple of 64.
///
/// Any output written before an error is returned is left in `writer`.
pub fn save<E: Error, W: Write + Seek>(
    writer: &mut W,
    handler: &mut dyn SaveHandler<E>,
    vocabulary: &[(Vec<u8>, f32)],
    tensor_names: &[String],
) -> Result<(), SaveError<E>> {
    // Tensor data starts on a multiple of this many bytes from the start of the stream.
    const TENSOR_ALIGNMENT: u64 = 32;
    // Shared source of padding bytes: padding is always shorter than `TENSOR_ALIGNMENT`,
    // so no per-tensor allocation is needed. (`as` is lossless for the constant 32.)
    const ZERO_PADDING: [u8; TENSOR_ALIGNMENT as usize] = [0; TENSOR_ALIGNMENT as usize];

    util::write_u32(writer, crate::FILE_MAGIC_GGJT)?;
    util::write_u32(writer, crate::FORMAT_VERSION)?;
    handler
        .write_hyperparameters(writer)
        .map_err(SaveError::ImplementationError)?;

    for (token, score) in vocabulary {
        util::write_u32(writer, token.len().try_into()?)?;
        writer.write_all(token)?;
        util::write_f32(writer, *score)?;
    }

    for name in tensor_names {
        let TensorSaveInfo {
            n_dims,
            dims,
            element_type,
            data,
        } = handler
            .tensor_data(name)
            .map_err(SaveError::ImplementationError)?;

        // `dims` has a fixed capacity. Reject an out-of-range `n_dims` here instead of
        // panicking when slicing `dims` below.
        if n_dims > dims.len() {
            return Err(SaveError::InvariantBroken(format!(
                "n_dims ({n_dims}) <= {}",
                dims.len()
            )));
        }
        if matches!(element_type, ElementType::Q4_0 | ElementType::Q4_1) && dims[0] % 64 != 0 {
            return Err(SaveError::InvariantBroken(format!("{dims:?}[0] % 64 == 0")));
        }

        util::write_i32(writer, n_dims.try_into()?)?;
        util::write_i32(writer, name.len().try_into()?)?;
        util::write_u32(writer, element_type.into())?;
        for &dim in &dims[..n_dims] {
            util::write_i32(writer, dim.try_into()?)?;
        }
        writer.write_all(name.as_bytes())?;

        // Pad with zeros so that the tensor data begins on an aligned stream offset.
        let misalignment = writer.stream_position()? % TENSOR_ALIGNMENT;
        let padding = if misalignment == 0 {
            0
        } else {
            TENSOR_ALIGNMENT - misalignment
        };
        writer.write_all(&ZERO_PADDING[..usize::try_from(padding)?])?;
        writer.write_all(&data)?;
    }

    Ok(())
}