use super::GGufWriter;
use crate::{DEFAULT_ALIGNMENT, GGmlType, GGufFileHeader, GGufMetaDataValueType, pad};
use log::trace;
use std::{
borrow::Borrow,
io::{Result, Write},
time::Instant,
};
/// Writer for the leading part of a GGUF file: the header and the metadata key-value pairs.
///
/// Constructing it writes the header; metadata is then appended with `write_meta_kv` /
/// `write_alignment`, and `finish` hands the underlying writer over to a [`GGufTensorWriter`].
pub struct GGufFileWriter<T: Write> {
/// Low-level writer that performs the actual output.
writer: GGufWriter<T>,
/// Alignment currently in effect. Starts at `DEFAULT_ALIGNMENT` and is replaced by
/// `write_alignment`, or by `write_meta_kv` when the underlying writer reports a new value.
alignment: usize,
}
/// Writer for the tensor part of a GGUF file, obtained from `GGufFileWriter::finish`.
///
/// `write_tensor` emits one tensor info record per call and tracks the running data offset;
/// `finish` then writes the queued payloads.
pub struct GGufTensorWriter<T: Write, U> {
/// Low-level writer taken over from the `GGufFileWriter`.
writer: GGufWriter<T>,
/// Alignment taken over from the `GGufFileWriter`; used for offsets and for data padding.
alignment: usize,
/// Tensor payloads queued by `write_tensor`; stays empty when `write_data` is `false`.
data: Vec<U>,
/// Running byte offset: the end of the most recently declared tensor, starting at 0.
offset: usize,
/// Whether payloads passed to `write_tensor` are kept and later written by `finish`.
write_data: bool,
}
/// A value that can hand out the bytes of a tensor payload on demand.
///
/// The tensor writer only calls [`DataFuture::get`] when it is about to write the payload, so an
/// implementor may defer producing the bytes until then.
pub trait DataFuture {
    /// Returns the payload bytes.
    fn get(&self) -> &[u8];
}

/// Anything that can be borrowed as a byte slice is usable as a payload as-is.
impl<T: Borrow<[u8]>> DataFuture for T {
    #[inline]
    fn get(&self) -> &[u8] {
        // Name the trait instantiation explicitly so the `[u8]` borrow is selected.
        <T as Borrow<[u8]>>::borrow(self)
    }
}
impl<T: Write> GGufFileWriter<T> {
    /// Wraps `writer` and immediately writes `header` to it.
    ///
    /// The alignment starts out as `DEFAULT_ALIGNMENT`.
    ///
    /// # Errors
    ///
    /// Returns the error produced while writing the header.
    #[inline]
    pub fn new(writer: T, header: GGufFileHeader) -> Result<Self> {
        let mut inner = GGufWriter::new(writer);
        inner.write_header(header)?;
        Ok(Self {
            writer: inner,
            alignment: DEFAULT_ALIGNMENT,
        })
    }

    /// Same as [`new`](Self::new), followed by [`write_alignment`](Self::write_alignment)
    /// with `alignment`.
    ///
    /// # Errors
    ///
    /// Returns the error produced while writing the header or the alignment.
    #[inline]
    pub fn with_alignment(writer: T, header: GGufFileHeader, alignment: usize) -> Result<Self> {
        let mut this = Self::new(writer, header)?;
        this.write_alignment(alignment).map(|()| this)
    }

    /// Writes `alignment` through the underlying writer and, on success, adopts it as the
    /// alignment in effect.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying writer; the stored alignment is left unchanged.
    #[inline]
    pub fn write_alignment(&mut self, alignment: usize) -> Result<()> {
        let Self {
            writer,
            alignment: current,
        } = self;
        writer.write_alignment(alignment)?;
        *current = alignment;
        Ok(())
    }

    /// Writes one metadata entry with the given key, value type and raw value bytes.
    ///
    /// When the underlying writer reports an alignment for this entry (returns `Some`), that
    /// value replaces the alignment in effect.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying writer.
    #[inline]
    pub fn write_meta_kv(
        &mut self,
        key: &str,
        ty: GGufMetaDataValueType,
        val: &[u8],
    ) -> Result<()> {
        let reported = self.writer.write_meta_kv(key, ty, val)?;
        // Keep the current alignment unless this entry carried a new one.
        self.alignment = reported.unwrap_or(self.alignment);
        Ok(())
    }

    /// Ends the metadata phase and converts this writer into a [`GGufTensorWriter`].
    ///
    /// The tensor writer inherits the underlying writer and the alignment, and starts with an
    /// empty payload queue at offset 0. `write_data` decides whether payloads given to
    /// `write_tensor` are kept and later written by the tensor writer's `finish`.
    #[inline]
    pub fn finish<U>(self, write_data: bool) -> GGufTensorWriter<T, U> {
        let Self { writer, alignment } = self;
        GGufTensorWriter {
            writer,
            alignment,
            data: Vec::new(),
            offset: 0,
            write_data,
        }
    }
}
impl<T: Write, U: DataFuture> GGufTensorWriter<T, U> {
    /// Writes the info record for one tensor and reserves its slot in the data region.
    ///
    /// The running offset is rounded up by `pad(offset, alignment)`; that aligned value is handed
    /// to the underlying writer as the tensor's offset. The running offset then advances by the
    /// tensor's byte size, which is derived from `ty` and `shape` (not from the length of `data`).
    /// `data` is queued for [`finish`](Self::finish) only when this writer was created with
    /// `write_data == true`; otherwise it is dropped.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying writer instead of panicking. On error neither the
    /// running offset nor the payload queue is modified.
    pub fn write_tensor(&mut self, name: &str, ty: GGmlType, shape: &[u64], data: U) -> Result<()> {
        // Compute the aligned start locally so a failed write leaves `self` untouched.
        let start = self.offset + pad(self.offset, self.alignment);
        self.writer.write_tensor_info(name, shape, ty, start as _)?;
        self.offset = start + ty.size().elements_to_bytes(shape);
        if self.write_data {
            self.data.push(data);
        }
        Ok(())
    }

    /// Writes every queued payload, each preceded by padding to the alignment, and returns the
    /// byte count reported by the underlying writer.
    ///
    /// Payloads are written in the order they were queued. For each one a `trace!` record reports
    /// its size, the time spent obtaining the bytes and the time spent writing them.
    ///
    /// # Errors
    ///
    /// Returns the first error from the underlying writer; payloads after it are not written.
    pub fn finish(self) -> Result<usize> {
        let Self {
            mut writer,
            alignment,
            data,
            ..
        } = self;
        // Pre-render the count so the per-item index can be right-aligned to its width.
        let total = data.len().to_string();
        let width = total.len();
        for (i, future) in data.into_iter().enumerate() {
            let t0 = Instant::now();
            // Obtaining the bytes is timed separately: a `DataFuture` may do its work here.
            let bytes = future.get();
            let t1 = Instant::now();
            writer.write_padding(alignment)?;
            writer.write_data(bytes)?;
            let t2 = Instant::now();
            trace!(
                "data {i:>width$}/{total} size = {} bytes, calculate in {:?}, write in {:?}",
                bytes.len(),
                t1 - t0,
                t2 - t1,
            )
        }
        Ok(writer.written_bytes())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GGufFileHeader;
    use std::io::Cursor;

    /// Header shared by all tests.
    // NOTE(review): arguments are presumably (version, tensor count, metadata count) — confirm
    // against `GGufFileHeader::new`.
    fn create_test_header() -> GGufFileHeader {
        GGufFileHeader::new(3, 0, 0)
    }

    /// Builds a file writer over an in-memory buffer with the shared test header.
    fn new_file_writer() -> GGufFileWriter<Cursor<Vec<u8>>> {
        GGufFileWriter::new(Cursor::new(Vec::new()), create_test_header()).unwrap()
    }

    #[test]
    fn test_file_writer_new() {
        let writer = new_file_writer();
        assert_eq!(writer.alignment, DEFAULT_ALIGNMENT);
    }

    #[test]
    fn test_file_writer_with_alignment() {
        let cursor = Cursor::new(Vec::new());
        let alignment = 64;
        let writer =
            GGufFileWriter::with_alignment(cursor, create_test_header(), alignment).unwrap();
        assert_eq!(writer.alignment, alignment);
    }

    #[test]
    fn test_write_alignment() {
        let mut writer = new_file_writer();
        let new_alignment = 128;
        writer.write_alignment(new_alignment).unwrap();
        assert_eq!(writer.alignment, new_alignment);
    }

    #[test]
    fn test_write_meta_kv() {
        use std::panic::{AssertUnwindSafe, catch_unwind};

        let mut writer = new_file_writer();

        // An ordinary key must be accepted.
        writer
            .write_meta_kv("test.key", GGufMetaDataValueType::U32, &[1, 0, 0, 0])
            .unwrap();

        // Writing the alignment key as a U32 must update the tracked alignment.
        writer
            .write_meta_kv(
                "general.alignment",
                GGufMetaDataValueType::U32,
                &[64, 0, 0, 0],
            )
            .unwrap();
        assert_eq!(writer.alignment, 64);

        // Writing the alignment key with a non-U32 type is expected to panic.
        let result = catch_unwind(AssertUnwindSafe(|| {
            writer
                .write_meta_kv(
                    "general.alignment",
                    GGufMetaDataValueType::String,
                    b"test\0",
                )
                .unwrap();
        }));
        assert!(result.is_err(), "Expected panic for non-u32 value type");
    }

    #[test]
    fn test_finish_and_tensor_writer() {
        let tensor_writer = new_file_writer().finish::<Vec<u8>>(false);
        assert_eq!(tensor_writer.alignment, DEFAULT_ALIGNMENT);
        assert_eq!(tensor_writer.offset, 0);
        assert!(tensor_writer.data.is_empty());
        assert!(!tensor_writer.write_data);
    }

    #[test]
    fn test_tensor_writer_write_tensor() {
        let mut tensor_writer = new_file_writer().finish::<Vec<u8>>(true);
        let shape = [2, 3];
        // 2 * 3 F32 elements occupy 24 bytes.
        let data = vec![0u8; 24];
        tensor_writer
            .write_tensor("test_tensor", GGmlType::F32, &shape, data.clone())
            .unwrap();
        assert_eq!(tensor_writer.offset, 24);
        assert_eq!(tensor_writer.data.len(), 1);
        assert_eq!(tensor_writer.data[0].get(), data.as_slice());
    }

    #[test]
    fn test_tensor_writer_multiple_tensors() {
        let mut tensor_writer = new_file_writer().finish::<Vec<u8>>(true);

        let shape1 = [2, 3];
        // 2 * 3 F32 elements occupy 24 bytes.
        let data1 = vec![0u8; 24];
        tensor_writer
            .write_tensor("tensor1", GGmlType::F32, &shape1, data1)
            .unwrap();

        let shape2 = [4, 4];
        // 4 * 4 F16 elements occupy 32 bytes, matching the expected offset below.
        let data2 = vec![0u8; 32];
        tensor_writer
            .write_tensor("tensor2", GGmlType::F16, &shape2, data2)
            .unwrap();

        assert_eq!(tensor_writer.data.len(), 2);
        // First tensor, padding up to the alignment, then the second tensor.
        let expected_offset = 24 + pad(24, DEFAULT_ALIGNMENT) + 32;
        assert_eq!(tensor_writer.offset, expected_offset);
    }

    #[test]
    fn test_tensor_writer_finish() {
        let mut tensor_writer = new_file_writer().finish::<Vec<u8>>(true);
        let shape = [2, 2];
        // 2 * 2 F32 elements occupy 16 bytes.
        let data = vec![0u8; 16];
        tensor_writer
            .write_tensor("test_tensor", GGmlType::F32, &shape, data)
            .unwrap();
        let bytes_written = tensor_writer.finish().unwrap();
        // The total also covers the header and tensor info, so it must exceed the payload.
        assert!(bytes_written > 16);
    }

    #[test]
    fn test_end_to_end_write_process() {
        let mut writer = new_file_writer();
        writer.write_alignment(64).unwrap();
        writer
            .write_meta_kv(
                "general.architecture",
                GGufMetaDataValueType::String,
                b"llama\0",
            )
            .unwrap();
        writer
            .write_meta_kv(
                "general.name",
                GGufMetaDataValueType::String,
                b"test_model\0",
            )
            .unwrap();
        writer
            .write_meta_kv(
                "llm.context_length",
                GGufMetaDataValueType::U32,
                &4096u32.to_le_bytes(),
            )
            .unwrap();

        let mut tensor_writer = writer.finish::<Vec<u8>>(true);

        let shape1 = [5, 5];
        // 5 * 5 F32 elements occupy 100 bytes.
        let data1 = vec![0u8; 100];
        tensor_writer
            .write_tensor("embeddings", GGmlType::F32, &shape1, data1)
            .unwrap();

        let shape2 = [10, 20];
        // 10 * 20 F32 elements occupy 800 bytes.
        let data2 = vec![0u8; 800];
        tensor_writer
            .write_tensor("weights", GGmlType::F32, &shape2, data2)
            .unwrap();

        let total_bytes = tensor_writer.finish().unwrap();
        // Both payloads alone already exceed 500 bytes.
        assert!(total_bytes > 500);
    }
}