use std::io::Write;
/// Type tags for GGUF metadata values.
///
/// The numeric discriminant is what gets serialized: the writer emits
/// `GgufType::X as u32` in little-endian form ahead of each metadata value
/// (and, for arrays, once more for the element type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum GgufType {
    /// Tag `0`.
    Uint8 = 0,
    /// Tag `1`.
    Int8 = 1,
    /// Tag `2`.
    Uint16 = 2,
    /// Tag `3`.
    Int16 = 3,
    /// Tag `4`.
    Uint32 = 4,
    /// Tag `5`.
    Int32 = 5,
    /// Tag `6`.
    Float32 = 6,
    /// Tag `7`.
    Bool = 7,
    /// Tag `8`.
    String = 8,
    /// Tag `9`; followed on the wire by the element tag and element count.
    Array = 9,
    /// Tag `10`.
    Uint64 = 10,
    /// Tag `11`.
    Int64 = 11,
    /// Tag `12`.
    Float64 = 12,
}
/// A metadata value that the GGUF writer knows how to serialize.
///
/// Each variant is written as a [`GgufType`] tag followed by its payload;
/// array variants additionally carry an element tag and a length.
#[derive(Debug, Clone)]
pub enum MetadataWriteValue {
    /// Unsigned 32-bit integer.
    U32(u32),
    /// Signed 32-bit integer.
    I32(i32),
    /// 32-bit float.
    F32(f32),
    /// 64-bit float.
    F64(f64),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// Boolean, serialized as a single byte (`1` or `0`).
    Bool(bool),
    /// UTF-8 string, serialized with a 64-bit length prefix.
    Str(String),
    /// Array of strings.
    ArrayStr(Vec<String>),
    /// Array of 32-bit floats.
    ArrayF32(Vec<f32>),
    /// Array of unsigned 32-bit integers.
    ArrayU32(Vec<u32>),
}
pub use MetadataWriteValue as MetadataValue;
/// Storage formats a tensor payload can use.
///
/// The discriminant is the numeric type id serialized into each tensor's
/// info record (`tensor_type as u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
#[allow(non_camel_case_types)]
pub enum TensorType {
    /// One element per block, 4 bytes each.
    F32 = 0,
    /// One element per block, 2 bytes each.
    F16 = 1,
    /// 32 elements per 18-byte block.
    Q4_0 = 2,
    /// 256 elements per 66-byte block.
    TQ2_0 = 35,
    /// 128 elements per 18-byte block.
    Q1_0G128 = 41,
    /// 128 elements per 34-byte block.
    TQ2_0_g128 = 42,
}

impl TensorType {
    /// Returns `(elements per block, bytes per block)` for this format.
    ///
    /// Single source of truth for the block geometry so the two public
    /// accessors cannot drift apart.
    fn block_layout(self) -> (usize, usize) {
        match self {
            Self::F32 => (1, 4),
            Self::F16 => (1, 2),
            Self::Q4_0 => (32, 18),
            Self::TQ2_0 => (256, 66),
            Self::Q1_0G128 => (128, 18),
            Self::TQ2_0_g128 => (128, 34),
        }
    }

    /// Number of tensor elements packed into one block.
    pub fn block_size(self) -> usize {
        let (elements, _) = self.block_layout();
        elements
    }

    /// Number of bytes one block occupies.
    pub fn block_bytes(self) -> usize {
        let (_, bytes) = self.block_layout();
        bytes
    }

    /// Byte length of a payload holding `element_count` elements.
    ///
    /// A trailing partial block is rounded up to a whole block.
    pub fn expected_bytes(self, element_count: u64) -> u64 {
        let per_block = self.block_size() as u64;
        let bytes_per_block = self.block_bytes() as u64;
        element_count.div_ceil(per_block) * bytes_per_block
    }
}
/// One tensor queued for serialization: its descriptor plus raw payload.
pub struct TensorEntry {
    /// Tensor name, written with a 64-bit length prefix.
    pub name: String,
    /// Dimension sizes; their product is the tensor's element count.
    pub shape: Vec<u64>,
    /// Storage format of `data`.
    pub tensor_type: TensorType,
    /// Raw payload bytes. The length must equal
    /// `tensor_type.expected_bytes(element count)`, otherwise writing fails
    /// with `WriteError::DataSizeMismatch`.
    pub data: Vec<u8>,
}
/// Accumulates metadata and tensors, then serializes them as a GGUF file.
pub struct GgufWriter {
    /// Metadata key/value pairs, kept in insertion order.
    metadata: Vec<(String, MetadataWriteValue)>,
    /// Tensors, kept in insertion order.
    tensors: Vec<TensorEntry>,
    /// Byte alignment used when padding the output; 32 unless changed via
    /// `set_alignment`.
    alignment: usize,
}
impl GgufWriter {
pub fn new() -> Self {
Self {
metadata: Vec::new(),
tensors: Vec::new(),
alignment: 32,
}
}
pub fn add_metadata(&mut self, key: &str, value: MetadataWriteValue) -> &mut Self {
self.metadata.push((key.to_string(), value));
self
}
pub fn add_tensor(&mut self, entry: TensorEntry) -> &mut Self {
self.tensors.push(entry);
self
}
pub fn set_alignment(&mut self, alignment: usize) -> &mut Self {
self.alignment = alignment;
self
}
pub fn write<W: Write>(&self, out: &mut W) -> Result<usize, WriteError> {
let mut pos: usize = 0;
const DEFAULT_ALIGNMENT: usize = 32;
let has_alignment = self.metadata.iter().any(|(k, _)| k == "general.alignment");
let alignment_entry: Option<(String, MetadataWriteValue)> =
if !has_alignment && self.alignment != DEFAULT_ALIGNMENT {
Some((
"general.alignment".to_string(),
MetadataWriteValue::U32(self.alignment as u32),
))
} else {
None
};
let effective_kv_count =
self.metadata.len() + if alignment_entry.is_some() { 1 } else { 0 };
const GGUF_MAGIC: u32 = 0x4655_4747;
Self::write_le_u32(out, GGUF_MAGIC)?;
pos += 4;
Self::write_le_u32(out, 3)?;
pos += 4;
Self::write_le_u64(out, self.tensors.len() as u64)?;
pos += 8;
Self::write_le_u64(out, effective_kv_count as u64)?;
pos += 8;
if let Some((ref key, ref value)) = alignment_entry {
pos += Self::write_string(out, key)?;
pos += Self::write_metadata_value(out, value)?;
}
for (key, value) in &self.metadata {
pos += Self::write_string(out, key)?;
pos += Self::write_metadata_value(out, value)?;
}
let mut data_offsets: Vec<u64> = Vec::with_capacity(self.tensors.len());
let mut running_offset: u64 = 0;
for entry in &self.tensors {
data_offsets.push(running_offset);
let element_count: u64 = entry.shape.iter().product();
let expected = entry.tensor_type.expected_bytes(element_count);
running_offset += expected;
}
for (idx, entry) in self.tensors.iter().enumerate() {
let element_count: u64 = entry.shape.iter().product();
let expected = entry.tensor_type.expected_bytes(element_count) as usize;
if entry.data.len() != expected {
return Err(WriteError::DataSizeMismatch {
name: entry.name.clone(),
expected,
got: entry.data.len(),
});
}
pos += Self::write_string(out, &entry.name)?;
let n_dims = entry.shape.len() as u32;
Self::write_le_u32(out, n_dims)?;
pos += 4;
for &dim in &entry.shape {
Self::write_le_u64(out, dim)?;
pos += 8;
}
Self::write_le_u32(out, entry.tensor_type as u32)?;
pos += 4;
Self::write_le_u64(out, data_offsets[idx])?;
pos += 8;
}
let pad = Self::pad_to_alignment(out, pos, self.alignment)?;
pos += pad;
for entry in &self.tensors {
out.write_all(&entry.data)
.map_err(|e| WriteError::Io(e.to_string()))?;
pos += entry.data.len();
}
Ok(pos)
}
pub fn to_bytes(&self) -> Result<Vec<u8>, WriteError> {
let mut buf: Vec<u8> = Vec::new();
self.write(&mut buf)?;
Ok(buf)
}
fn write_string<W: Write>(out: &mut W, s: &str) -> Result<usize, WriteError> {
let bytes = s.as_bytes();
Self::write_le_u64(out, bytes.len() as u64)?;
out.write_all(bytes)
.map_err(|e| WriteError::Io(e.to_string()))?;
Ok(8 + bytes.len())
}
fn write_metadata_value<W: Write>(
out: &mut W,
val: &MetadataWriteValue,
) -> Result<usize, WriteError> {
let mut n: usize = 0;
match val {
MetadataWriteValue::U32(v) => {
Self::write_le_u32(out, GgufType::Uint32 as u32)?;
Self::write_le_u32(out, *v)?;
n += 8;
}
MetadataWriteValue::I32(v) => {
Self::write_le_u32(out, GgufType::Int32 as u32)?;
out.write_all(&v.to_le_bytes())
.map_err(|e| WriteError::Io(e.to_string()))?;
n += 8;
}
MetadataWriteValue::F32(v) => {
Self::write_le_u32(out, GgufType::Float32 as u32)?;
Self::write_le_f32(out, *v)?;
n += 8;
}
MetadataWriteValue::F64(v) => {
Self::write_le_u32(out, GgufType::Float64 as u32)?;
out.write_all(&v.to_le_bytes())
.map_err(|e| WriteError::Io(e.to_string()))?;
n += 12;
}
MetadataWriteValue::U64(v) => {
Self::write_le_u32(out, GgufType::Uint64 as u32)?;
Self::write_le_u64(out, *v)?;
n += 12;
}
MetadataWriteValue::Bool(v) => {
Self::write_le_u32(out, GgufType::Bool as u32)?;
out.write_all(&[if *v { 1u8 } else { 0u8 }])
.map_err(|e| WriteError::Io(e.to_string()))?;
n += 5;
}
MetadataWriteValue::Str(s) => {
Self::write_le_u32(out, GgufType::String as u32)?;
n += 4;
n += Self::write_string(out, s)?;
}
MetadataWriteValue::ArrayStr(items) => {
Self::write_le_u32(out, GgufType::Array as u32)?;
Self::write_le_u32(out, GgufType::String as u32)?;
Self::write_le_u64(out, items.len() as u64)?;
n += 16;
for s in items {
n += Self::write_string(out, s)?;
}
}
MetadataWriteValue::ArrayF32(items) => {
Self::write_le_u32(out, GgufType::Array as u32)?;
Self::write_le_u32(out, GgufType::Float32 as u32)?;
Self::write_le_u64(out, items.len() as u64)?;
n += 16;
for &v in items {
Self::write_le_f32(out, v)?;
n += 4;
}
}
MetadataWriteValue::ArrayU32(items) => {
Self::write_le_u32(out, GgufType::Array as u32)?;
Self::write_le_u32(out, GgufType::Uint32 as u32)?;
Self::write_le_u64(out, items.len() as u64)?;
n += 16;
for &v in items {
Self::write_le_u32(out, v)?;
n += 4;
}
}
}
Ok(n)
}
fn write_le_u32<W: Write>(out: &mut W, v: u32) -> Result<(), WriteError> {
out.write_all(&v.to_le_bytes())
.map_err(|e| WriteError::Io(e.to_string()))
}
fn write_le_u64<W: Write>(out: &mut W, v: u64) -> Result<(), WriteError> {
out.write_all(&v.to_le_bytes())
.map_err(|e| WriteError::Io(e.to_string()))
}
fn write_le_f32<W: Write>(out: &mut W, v: f32) -> Result<(), WriteError> {
out.write_all(&v.to_le_bytes())
.map_err(|e| WriteError::Io(e.to_string()))
}
fn pad_to_alignment<W: Write>(
out: &mut W,
pos: usize,
alignment: usize,
) -> Result<usize, WriteError> {
if alignment == 0 {
return Ok(0);
}
let remainder = pos % alignment;
if remainder == 0 {
return Ok(0);
}
let pad = alignment - remainder;
let zeros = vec![0u8; pad];
out.write_all(&zeros)
.map_err(|e| WriteError::Io(e.to_string()))?;
Ok(pad)
}
}
impl Default for GgufWriter {
fn default() -> Self {
Self::new()
}
}
/// Errors returned while serializing a GGUF file.
#[derive(Debug, thiserror::Error)]
pub enum WriteError {
    /// The output sink failed; carries the underlying error's display text.
    #[error("I/O error: {0}")]
    Io(String),

    /// A tensor's payload length does not match the size implied by its
    /// type and shape.
    #[error("Tensor data size mismatch for {name}: expected {expected}, got {got}")]
    DataSizeMismatch {
        /// Name of the offending tensor.
        name: String,
        /// Byte length required by the tensor's type and shape.
        expected: usize,
        /// Byte length actually supplied.
        got: usize,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes the little-endian `u32` stored at `bytes[at..at + 4]`.
    fn le_u32_at(bytes: &[u8], at: usize) -> u32 {
        u32::from_le_bytes(bytes[at..at + 4].try_into().expect("slice"))
    }

    /// Decodes the little-endian `u64` stored at `bytes[at..at + 8]`.
    fn le_u64_at(bytes: &[u8], at: usize) -> u64 {
        u64::from_le_bytes(bytes[at..at + 8].try_into().expect("slice"))
    }

    /// A fresh writer starts with 32-byte alignment.
    #[test]
    fn default_alignment_is_32() {
        assert_eq!(GgufWriter::new().alignment, 32);
    }

    /// `set_alignment` overwrites the stored alignment.
    #[test]
    fn set_alignment_changes_value() {
        let mut writer = GgufWriter::new();
        writer.set_alignment(64);
        assert_eq!(writer.alignment, 64);
    }

    /// With nothing added, the header is magic, version 3, zero tensors and
    /// zero key/value pairs.
    #[test]
    fn empty_file_has_correct_header() {
        let encoded = GgufWriter::new().to_bytes().expect("write failed");
        assert_eq!(le_u32_at(&encoded, 0), 0x4655_4747);
        assert_eq!(le_u32_at(&encoded, 4), 3);
        assert_eq!(le_u64_at(&encoded, 8), 0);
        assert_eq!(le_u64_at(&encoded, 16), 0);
    }

    /// Four F32 elements need 16 bytes; supplying 8 must be rejected.
    #[test]
    fn data_size_mismatch_returns_error() {
        let undersized = TensorEntry {
            name: "bad".to_string(),
            shape: vec![4],
            tensor_type: TensorType::F32,
            data: vec![0u8; 8],
        };
        let mut writer = GgufWriter::new();
        writer.add_tensor(undersized);
        let outcome = writer.to_bytes();
        assert!(matches!(
            outcome,
            Err(WriteError::DataSizeMismatch { .. })
        ));
    }
}