use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue, TensorEntry, TensorType};
use crate::quantize::{q1_0_g128_size_bytes, quantize_q1_0_g128};
use crate::quantize_int8::quantize_per_channel;
/// Encoding applied to weight tensors when they are exported.
///
/// The format chosen in the export configuration is the default for every
/// tensor; individual tensors may still be kept in `Float32` through the
/// configuration's FP32 exception list or quantization allow-list.
///
/// `Eq` and `Hash` are derived so the format can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExportFormat {
    /// Weights are written unmodified as little-endian `f32` values.
    Float32,
    /// Weights are quantized with `quantize_q1_0_g128`; input that is not a
    /// multiple of the quantizer's group size is zero-padded first.
    Q1_0G128,
    /// Weights are quantized per channel to 8-bit integers, followed by one
    /// scale per channel. The first shape dimension is the channel count.
    Int8PerChannel,
    /// Weights are quantized with `quantize_tq2_0_g128`.
    TernaryG128,
}
/// Settings that control how a set of weight tensors is exported to GGUF.
#[derive(Debug, Clone)]
pub struct ExportConfig {
/// Encoding applied to every tensor that is not kept in FP32.
pub format: ExportFormat,
/// Model name, written to the `general.name` metadata key.
pub model_name: String,
/// Model version, written to the `general.version` metadata key.
pub model_version: String,
/// Optional free-form text, written to `general.description` when present.
pub description: Option<String>,
/// Optional allow-list of tensor names (exact match). When `Some`, any
/// tensor whose name is not listed stays FP32; `None` places no restriction.
pub quantize_layers: Option<Vec<String>>,
/// Tensor names (exact match) that are always exported as FP32.
pub fp32_layers: Vec<String>,
}
impl ExportConfig {
    /// Builds a configuration for `format` and `model_name`.
    ///
    /// The version starts at `"1.0.0"`, there is no description, no
    /// quantization allow-list, and no FP32 exceptions.
    pub fn new(format: ExportFormat, model_name: &str) -> Self {
        Self {
            format,
            model_name: String::from(model_name),
            model_version: String::from("1.0.0"),
            description: None,
            quantize_layers: None,
            fp32_layers: Vec::new(),
        }
    }

    /// Replaces the list of tensor names that must stay FP32.
    pub fn with_fp32_layers(self, layers: Vec<String>) -> Self {
        Self {
            fp32_layers: layers,
            ..self
        }
    }

    /// Sets the description that is written into the file metadata.
    pub fn with_description(self, desc: &str) -> Self {
        Self {
            description: Some(desc.to_owned()),
            ..self
        }
    }

    /// Returns a suggested FP32 exception list: the token embedding, the
    /// output norm, and the output projection weights.
    ///
    /// The list is not applied automatically; pass it to
    /// `with_fp32_layers` to use it.
    pub fn default_fp32_exceptions() -> Vec<String> {
        const NAMES: [&str; 3] = ["token_embd.weight", "output_norm.weight", "output.weight"];
        NAMES.iter().map(|&name| String::from(name)).collect()
    }
}
/// A named FP32 weight tensor queued for export.
///
/// `data` is a flat buffer and `shape` lists the dimensions. Nothing in this
/// type checks that the product of `shape` equals `data.len()`; keeping the
/// two consistent is the caller's responsibility.
#[derive(Debug, Clone, PartialEq)]
pub struct WeightTensor {
    /// Tensor name; it is matched exactly against the FP32 exception list and
    /// the quantization allow-list, and written into the exported file.
    pub name: String,
    /// Flat weight values.
    pub data: Vec<f32>,
    /// Tensor dimensions. The INT8 per-channel export uses the first
    /// dimension as the channel count.
    pub shape: Vec<usize>,
}
impl WeightTensor {
    /// Bundles a name, a flat weight buffer, and a shape into a tensor.
    ///
    /// No validation is performed: `shape` and `data` are stored as given.
    pub fn new(name: &str, data: Vec<f32>, shape: Vec<usize>) -> Self {
        let name = name.to_owned();
        Self { name, data, shape }
    }

    /// Element count implied by `shape` (the product of all dimensions).
    ///
    /// An empty shape yields 1, because that is the product of no factors.
    pub fn num_elements(&self) -> usize {
        self.shape.iter().copied().product()
    }

    /// Size in bytes of the weight buffer as stored in memory (4 bytes per
    /// `f32`), based on `data.len()` rather than on `shape`.
    pub fn memory_bytes_f32(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }
}
/// Errors that can occur while exporting tensors to GGUF.
#[derive(Debug, thiserror::Error)]
pub enum ExportError {
/// A quantizer rejected a tensor. `name` is the tensor's name and `reason`
/// is the underlying error rendered as a string.
#[error("Quantization error for tensor '{name}': {reason}")]
QuantizeError { name: String, reason: String },
/// The GGUF writer failed to serialize the file; carries the writer's error
/// rendered as a string.
#[error("GGUF write error: {0}")]
WriteError(String),
/// The exporter was called with an empty tensor slice.
#[error("No tensors to export")]
Empty,
}
fn should_keep_fp32(name: &str, config: &ExportConfig) -> bool {
if config.fp32_layers.iter().any(|exc| name == exc.as_str()) {
return true;
}
if let Some(ref allowed) = config.quantize_layers {
if !allowed.iter().any(|a| name == a.as_str()) {
return true;
}
}
false
}
/// Serializes one tensor into its on-disk byte payload.
///
/// The tensor is written as FP32 when `should_keep_fp32` says so; otherwise
/// `config.format` is applied.
///
/// # Returns
///
/// The encoded bytes together with the `TensorType` tag to record for them.
/// INT8 per-channel payloads have no dedicated tag here and are tagged `F32`;
/// the exporter replaces their shape with `payload_len / 4` so the declared
/// element count covers the payload.
///
/// # Errors
///
/// Returns `ExportError::QuantizeError` carrying the tensor name when the
/// selected quantizer fails.
fn encode_tensor(
    tensor: &WeightTensor,
    config: &ExportConfig,
) -> Result<(Vec<u8>, TensorType), ExportError> {
    // Attaches the tensor name to a quantizer failure.
    fn quantize_error<E: ToString>(name: &str, err: E) -> ExportError {
        ExportError::QuantizeError {
            name: name.to_string(),
            reason: err.to_string(),
        }
    }

    let effective_format = if should_keep_fp32(&tensor.name, config) {
        ExportFormat::Float32
    } else {
        config.format
    };

    match effective_format {
        ExportFormat::Float32 => {
            let bytes: Vec<u8> = tensor.data.iter().flat_map(|f| f.to_le_bytes()).collect();
            Ok((bytes, TensorType::F32))
        }
        ExportFormat::Q1_0G128 => {
            use crate::quantize::GROUP_SIZE;

            // The quantizer is fed whole groups, so a trailing partial group
            // is zero-padded. The copy is only made when padding is needed.
            let remainder = tensor.data.len() % GROUP_SIZE;
            let padded: Vec<f32>;
            let input = if remainder == 0 {
                &tensor.data
            } else {
                let mut grown = tensor.data.clone();
                grown.resize(tensor.data.len() + GROUP_SIZE - remainder, 0.0);
                padded = grown;
                &padded
            };
            let bytes =
                quantize_q1_0_g128(input).map_err(|e| quantize_error(&tensor.name, e))?;
            Ok((bytes, TensorType::Q1_0G128))
        }
        ExportFormat::Int8PerChannel => {
            // First dimension is the channel count; fall back to one channel
            // for shapeless tensors and never pass zero.
            let num_channels = tensor.shape.first().copied().unwrap_or(1).max(1);
            let int8 = quantize_per_channel(&tensor.data, num_channels)
                .map_err(|e| quantize_error(&tensor.name, e))?;

            // Layout: all quantized values (one byte each), then every scale
            // in little-endian order.
            let mut bytes: Vec<u8> =
                Vec::with_capacity(int8.data.len() + int8.scales.len() * 4 + 3);
            bytes.extend(int8.data.iter().map(|&q| q as u8));
            for &s in &int8.scales {
                bytes.extend_from_slice(&s.to_le_bytes());
            }

            // The payload is tagged F32 and its shape is declared as
            // `len / 4`. Without padding, a length that is not a multiple of
            // four would declare fewer bytes than are stored, so the tail of
            // the last scale would fall outside the declared tensor. Padding
            // is appended after the scales; the existing layout is unchanged.
            let pad = (4 - bytes.len() % 4) % 4;
            bytes.resize(bytes.len() + pad, 0);

            Ok((bytes, TensorType::F32))
        }
        ExportFormat::TernaryG128 => {
            let bytes = crate::quantize_ternary::quantize_tq2_0_g128(&tensor.data)
                .map_err(|e| quantize_error(&tensor.name, e))?;
            Ok((bytes, TensorType::TQ2_0_g128))
        }
    }
}
pub fn export_to_gguf(
tensors: &[WeightTensor],
config: &ExportConfig,
arch_metadata: &[(String, MetadataWriteValue)],
) -> Result<Vec<u8>, ExportError> {
if tensors.is_empty() {
return Err(ExportError::Empty);
}
let mut writer = GgufWriter::new();
writer.add_metadata(
"general.name",
MetadataWriteValue::Str(config.model_name.clone()),
);
writer.add_metadata(
"general.version",
MetadataWriteValue::Str(config.model_version.clone()),
);
if let Some(ref desc) = config.description {
writer.add_metadata("general.description", MetadataWriteValue::Str(desc.clone()));
}
let quant_str = match config.format {
ExportFormat::Float32 => "F32",
ExportFormat::Q1_0G128 => "Q1_0G128",
ExportFormat::Int8PerChannel => "INT8_PER_CHANNEL",
ExportFormat::TernaryG128 => "TQ2_0_g128",
};
writer.add_metadata(
"general.quantization_version",
MetadataWriteValue::Str(quant_str.to_string()),
);
for (key, val) in arch_metadata {
writer.add_metadata(key, val.clone());
}
for tensor in tensors {
if tensor.data.is_empty() {
continue;
}
let (bytes, tensor_type) = encode_tensor(tensor, config)?;
let shape: Vec<u64> = if config.format == ExportFormat::Int8PerChannel
&& !should_keep_fp32(&tensor.name, config)
{
vec![(bytes.len() / 4) as u64]
} else {
tensor.shape.iter().map(|&d| d as u64).collect()
};
writer.add_tensor(TensorEntry {
name: tensor.name.clone(),
shape,
tensor_type,
data: bytes,
});
}
writer
.to_bytes()
.map_err(|e| ExportError::WriteError(e.to_string()))
}
/// Estimates the total size in bytes of the encoded tensor payloads.
///
/// Only tensor data is counted; the file header and metadata are not.
/// Tensors with empty `data` contribute nothing. Each remaining tensor is
/// sized by the format it would actually be exported in, so tensors kept in
/// FP32 are counted at 4 bytes per weight.
pub fn estimate_export_size(tensors: &[WeightTensor], config: &ExportConfig) -> usize {
    let mut total = 0usize;

    for tensor in tensors {
        let weight_count = tensor.data.len();
        if weight_count == 0 {
            continue;
        }

        let format = if should_keep_fp32(&tensor.name, config) {
            ExportFormat::Float32
        } else {
            config.format
        };

        total += match format {
            ExportFormat::Float32 => weight_count * 4,
            ExportFormat::Q1_0G128 => q1_0_g128_size_bytes(weight_count),
            ExportFormat::Int8PerChannel => {
                // One byte per weight plus one 4-byte scale per channel.
                let channels = tensor.shape.first().copied().unwrap_or(1).max(1);
                weight_count + channels * 4
            }
            ExportFormat::TernaryG128 => {
                crate::quantize_ternary::tq2_0_g128_size_bytes(weight_count)
            }
        };
    }

    total
}
/// Summary of what an export would produce, computed without encoding.
#[derive(Debug, Clone)]
pub struct ExportStats {
/// Number of tensors passed in, including any with empty data.
pub num_tensors: usize,
/// Tensors that would be written in the configured quantized format.
pub quantized_tensors: usize,
/// Tensors that would be written as FP32.
pub fp32_tensors: usize,
/// Size of all weight buffers as FP32 (4 bytes per weight).
pub original_bytes: usize,
/// Estimated size of the encoded tensor payloads.
pub exported_bytes: usize,
/// `original_bytes / exported_bytes`, or `1.0` when `exported_bytes` is 0.
pub compression_ratio: f32,
}
/// Computes export statistics for `tensors` under `config` without encoding
/// anything.
///
/// A tensor counts as FP32 when the configured format is `Float32` or when
/// `should_keep_fp32` exempts it; every other tensor counts as quantized.
/// Tensors with empty data are still included in the counts. The compression
/// ratio is the FP32 size divided by the estimated exported size, and is
/// reported as `1.0` when the estimate is zero.
pub fn export_stats(tensors: &[WeightTensor], config: &ExportConfig) -> ExportStats {
    let everything_is_fp32 = config.format == ExportFormat::Float32;

    let fp32_tensors = tensors
        .iter()
        .filter(|t| should_keep_fp32(&t.name, config) || everything_is_fp32)
        .count();
    let quantized_tensors = tensors.len() - fp32_tensors;

    let original_bytes: usize = tensors.iter().map(|t| t.data.len() * 4).sum();
    let exported_bytes = estimate_export_size(tensors, config);

    let compression_ratio = match exported_bytes {
        0 => 1.0,
        exported => original_bytes as f32 / exported as f32,
    };

    ExportStats {
        num_tensors: tensors.len(),
        quantized_tensors,
        fp32_tensors,
        original_bytes,
        exported_bytes,
        compression_ratio,
    }
}
// Unit tests for the export configuration, size estimation, statistics, and
// the GGUF serialization entry point.
#[cfg(test)]
mod tests {
use super::*;
// The suggested FP32 exception list contains exactly the three expected names.
#[test]
fn test_export_config_default_fp32_exceptions() {
let exceptions = ExportConfig::default_fp32_exceptions();
assert!(exceptions.contains(&"token_embd.weight".to_string()));
assert!(exceptions.contains(&"output_norm.weight".to_string()));
assert!(exceptions.contains(&"output.weight".to_string()));
assert_eq!(exceptions.len(), 3);
}
// Element count comes from the shape; byte size is 4 bytes per stored weight.
#[test]
fn test_weight_tensor_num_elements() {
let t = WeightTensor::new("test", vec![0.0; 256], vec![16, 16]);
assert_eq!(t.num_elements(), 256);
assert_eq!(t.memory_bytes_f32(), 1024);
}
// FP32 export is estimated at 4 bytes per weight.
#[test]
fn test_estimate_export_size_fp32() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Float32, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(size, 256 * 4);
}
// Q1_0G128 is expected to take 18 bytes per group of 128 weights.
#[test]
fn test_estimate_export_size_q1_0() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Q1_0G128, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size,
2 * 18,
"Q1_0 size for 256 weights should be {}",
2 * 18
);
}
// A quantized export must be smaller than FP32, and the single tensor must
// be counted as quantized.
#[test]
fn test_export_stats_compression_ratio() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::Q1_0G128, "m");
let stats = export_stats(&tensors, &config);
assert!(
stats.compression_ratio > 1.0,
"Q1_0 should compress better than FP32"
);
assert_eq!(stats.quantized_tensors, 1);
assert_eq!(stats.fp32_tensors, 0);
}
// Smoke test: a Q1_0G128 export succeeds and the output starts with the
// GGUF magic number.
#[test]
fn test_export_to_gguf_basic() {
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
vec![1.0; 128],
vec![128],
)];
let config =
ExportConfig::new(ExportFormat::Q1_0G128, "test-model").with_description("unit test");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let magic = u32::from_le_bytes(bytes[0..4].try_into().expect("slice"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
}
// FP32 export keeps the raw little-endian float bytes; the test searches
// the output for the encoding of 3.0.
#[test]
fn test_export_fp32_tensor_unchanged() {
let data: Vec<f32> = (0..4).map(|i| i as f32).collect();
let tensors = vec![WeightTensor::new("w", data.clone(), vec![4])];
let config = ExportConfig::new(ExportFormat::Float32, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let needle = 3.0_f32.to_le_bytes();
let found = bytes.windows(4).any(|w| w == needle.as_slice());
assert!(found, "float 3.0 should be present in the exported bytes");
}
// Tensors with no data are left out of the file; bytes 8..16 of the output
// are read as the tensor count.
#[test]
fn test_export_skips_empty_tensors() {
let tensors = vec![
WeightTensor::new("good", vec![1.0; 128], vec![128]),
WeightTensor::new("empty", vec![], vec![0]),
];
let config = ExportConfig::new(ExportFormat::Float32, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let tensor_count = u64::from_le_bytes(bytes[8..16].try_into().expect("slice"));
assert_eq!(tensor_count, 1, "empty tensor should be skipped");
}
// TernaryG128 is expected to take 34 bytes for one group of 128 weights.
#[test]
fn test_estimate_export_size_ternary_g128() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 128], vec![128])];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size, 34,
"128-weight tensor in TernaryG128 should be 34 bytes"
);
}
// Two groups of 128 weights double the TernaryG128 estimate.
#[test]
fn test_estimate_export_size_ternary_g128_two_blocks() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size, 68,
"256-weight tensor in TernaryG128 should be 68 bytes"
);
}
// A ternary export must be smaller than FP32, and the single tensor must
// be counted as quantized.
#[test]
fn test_export_stats_ternary_g128_compression() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m");
let stats = export_stats(&tensors, &config);
assert!(
stats.compression_ratio > 1.0,
"TernaryG128 should compress better than FP32"
);
assert_eq!(stats.quantized_tensors, 1);
assert_eq!(stats.fp32_tensors, 0);
}
// Smoke test: a TernaryG128 export succeeds and the output starts with the
// GGUF magic number.
#[test]
fn test_export_to_gguf_ternary_g128_basic() {
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
vec![1.0; 128],
vec![128],
)];
let config = ExportConfig::new(ExportFormat::TernaryG128, "ternary-model");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let magic = u32::from_le_bytes(bytes[0..4].try_into().expect("slice"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
}
// A tensor named in the FP32 exception list is counted as FP32 even when
// the configured format is ternary; the other tensor is counted as quantized.
#[test]
fn test_ternary_g128_fp32_exception_tensors_stay_fp32() {
let tensors = vec![
WeightTensor::new("blk.0.attn_q.weight", vec![1.0; 128], vec![128]),
WeightTensor::new("output_norm.weight", vec![1.0; 128], vec![128]),
];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m")
.with_fp32_layers(vec!["output_norm.weight".to_string()]);
let stats = export_stats(&tensors, &config);
assert_eq!(stats.fp32_tensors, 1, "output_norm.weight should stay FP32");
assert_eq!(
stats.quantized_tensors, 1,
"attn_q.weight should be ternary-quantized"
);
}
}