use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue, TensorEntry, TensorType};
use crate::quantize::{q1_0_g128_size_bytes, quantize_q1_0_g128};
use crate::quantize_int8::quantize_per_channel;
/// Target encoding applied to tensors when exporting to GGUF.
///
/// The variant chosen in [`ExportConfig::format`] is used for every tensor
/// that is not forced to stay in FP32; see `encode_tensor` for the exact
/// encoder invoked for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExportFormat {
    /// Raw little-endian `f32` values, no quantization.
    Float32,
    /// Encoded with `quantize_q1_0_g128`; written as tensor type `Q1_0G128`.
    Q1_0G128,
    /// Encoded with `quantize_per_channel`, using the first shape dimension
    /// as the channel count.
    Int8PerChannel,
    /// Encoded with `quantize_tq2_0_g128`; written as tensor type `TQ2_0_g128`.
    TernaryG128,
    /// Encoded as `BlockFP8E4M3` blocks; written as tensor type `F8_E4M3`.
    FP8E4M3,
    /// Encoded as `BlockFP8E5M2` blocks; written as tensor type `F8_E5M2`.
    FP8E5M2,
    /// Encoded as `BlockQ4_0` blocks; written as tensor type `Q4_0`.
    Q4_0,
    /// Encoded as `BlockQ8_0` blocks; written as tensor type `Q8_0`.
    Q8_0,
    /// Encoded as `BlockQ4K` super-blocks; written as tensor type `Q4_K`.
    Q4K,
    /// Encoded as `BlockQ5K` super-blocks; written as tensor type `Q5_K`.
    Q5K,
    /// Encoded as `BlockQ6K` super-blocks; written as tensor type `Q6_K`.
    Q6K,
}
/// Settings that control how a set of weight tensors is exported.
#[derive(Debug, Clone)]
pub struct ExportConfig {
    /// Encoding used for every tensor that is not kept in FP32.
    pub format: ExportFormat,
    /// Model name; written to the `general.name` metadata key on export.
    pub model_name: String,
    /// Model version; written to the `general.version` metadata key on export.
    pub model_version: String,
    /// Optional free-form text; written to `general.description` when present.
    pub description: Option<String>,
    /// Optional allow-list: when `Some`, tensors whose name is not listed
    /// stay in FP32. `None` places no restriction.
    pub quantize_layers: Option<Vec<String>>,
    /// Tensor names (exact match) that are always exported as FP32.
    pub fp32_layers: Vec<String>,
}

impl ExportConfig {
    /// Creates a configuration for `format` with version `"1.0.0"`, no
    /// description, no allow-list and no FP32 exceptions.
    pub fn new(format: ExportFormat, model_name: &str) -> Self {
        ExportConfig {
            format,
            model_name: String::from(model_name),
            model_version: String::from("1.0.0"),
            description: None,
            quantize_layers: None,
            fp32_layers: vec![],
        }
    }

    /// Replaces the list of tensor names that must stay in FP32.
    pub fn with_fp32_layers(self, layers: Vec<String>) -> Self {
        Self {
            fp32_layers: layers,
            ..self
        }
    }

    /// Sets the model description.
    pub fn with_description(self, desc: &str) -> Self {
        Self {
            description: Some(desc.to_owned()),
            ..self
        }
    }

    /// Returns the suggested FP32 exception list: the token embedding,
    /// output norm and output projection weights.
    pub fn default_fp32_exceptions() -> Vec<String> {
        ["token_embd.weight", "output_norm.weight", "output.weight"]
            .iter()
            .copied()
            .map(String::from)
            .collect()
    }
}
/// A named FP32 weight tensor awaiting export.
///
/// `data` and `shape` are stored as given; no check is made that
/// `data.len()` equals the product of `shape`.
#[derive(Debug, Clone, PartialEq)]
pub struct WeightTensor {
    /// Tensor name, matched exactly against the export configuration's layer lists.
    pub name: String,
    /// Flat FP32 values.
    pub data: Vec<f32>,
    /// Dimension sizes.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Builds a tensor, copying `name` into an owned `String`.
    pub fn new(name: &str, data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            name: name.to_string(),
            data,
            shape,
        }
    }

    /// Number of elements implied by `shape` (product of all dimensions).
    ///
    /// An empty `shape` yields `1`, the product of an empty sequence.
    pub fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }

    /// Size in bytes of `data` when stored as `f32`.
    pub fn memory_bytes_f32(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }
}
/// Errors returned by the GGUF export functions in this file.
#[derive(Debug, thiserror::Error)]
pub enum ExportError {
/// A quantizer rejected a tensor; `name` is the tensor name and `reason`
/// is the quantizer's error rendered as text.
#[error("Quantization error for tensor '{name}': {reason}")]
QuantizeError { name: String, reason: String },
/// The GGUF writer failed while serializing; carries the writer's error text.
#[error("GGUF write error: {0}")]
WriteError(String),
/// The tensor slice passed to the exporter was empty.
#[error("No tensors to export")]
Empty,
}
fn should_keep_fp32(name: &str, config: &ExportConfig) -> bool {
if config.fp32_layers.iter().any(|exc| name == exc.as_str()) {
return true;
}
if let Some(ref allowed) = config.quantize_layers {
if !allowed.iter().any(|a| name == a.as_str()) {
return true;
}
}
false
}
fn encode_tensor(
tensor: &WeightTensor,
config: &ExportConfig,
) -> Result<(Vec<u8>, TensorType), ExportError> {
let effective_format = if should_keep_fp32(&tensor.name, config) {
ExportFormat::Float32
} else {
config.format
};
match effective_format {
ExportFormat::Float32 => {
let bytes: Vec<u8> = tensor.data.iter().flat_map(|f| f.to_le_bytes()).collect();
Ok((bytes, TensorType::F32))
}
ExportFormat::Q1_0G128 => {
use crate::quantize::GROUP_SIZE;
let remainder = tensor.data.len() % GROUP_SIZE;
let bytes = if remainder == 0 {
quantize_q1_0_g128(&tensor.data).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?
} else {
let mut padded = tensor.data.clone();
padded.resize(tensor.data.len() + GROUP_SIZE - remainder, 0.0);
quantize_q1_0_g128(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?
};
Ok((bytes, TensorType::Q1_0G128))
}
ExportFormat::Int8PerChannel => {
let num_channels = tensor.shape.first().copied().unwrap_or(1).max(1);
let int8 = quantize_per_channel(&tensor.data, num_channels).map_err(|e| {
ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
}
})?;
let mut bytes: Vec<u8> = Vec::with_capacity(int8.data.len() + int8.scales.len() * 4);
for &q in &int8.data {
bytes.push(q as u8);
}
for &s in &int8.scales {
bytes.extend_from_slice(&s.to_le_bytes());
}
Ok((bytes, TensorType::F32))
}
ExportFormat::TernaryG128 => {
let bytes =
crate::quantize_ternary::quantize_tq2_0_g128(&tensor.data).map_err(|e| {
ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
}
})?;
Ok((bytes, TensorType::TQ2_0_g128))
}
ExportFormat::FP8E4M3 => {
use oxibonsai_core::quant_fp8::{BlockFP8E4M3, QK_FP8};
let remainder = tensor.data.len() % QK_FP8;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_FP8 - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks =
BlockFP8E4M3::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * oxibonsai_core::quant_fp8::BLOCK_FP8_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::F8_E4M3))
}
ExportFormat::FP8E5M2 => {
use oxibonsai_core::quant_fp8::{BlockFP8E5M2, QK_FP8};
let remainder = tensor.data.len() % QK_FP8;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_FP8 - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks =
BlockFP8E5M2::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * oxibonsai_core::quant_fp8::BLOCK_FP8_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::F8_E5M2))
}
ExportFormat::Q4_0 => {
use oxibonsai_core::quant_std::{BlockQ4_0, BLOCK_Q4_0_BYTES, QK_Q4_0};
let remainder = tensor.data.len() % QK_Q4_0;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_Q4_0 - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks = BlockQ4_0::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * BLOCK_Q4_0_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::Q4_0))
}
ExportFormat::Q8_0 => {
use oxibonsai_core::quant_std::{BlockQ8_0, BLOCK_Q8_0_BYTES, QK_Q8_0};
let remainder = tensor.data.len() % QK_Q8_0;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_Q8_0 - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks = BlockQ8_0::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * BLOCK_Q8_0_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::Q8_0))
}
ExportFormat::Q4K => {
use oxibonsai_core::quant_k::{BlockQ4K, BLOCK_Q4_K_BYTES, QK_K};
let remainder = tensor.data.len() % QK_K;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_K - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks = BlockQ4K::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * BLOCK_Q4_K_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::Q4_K))
}
ExportFormat::Q5K => {
use oxibonsai_core::quant_k::QK_K;
use oxibonsai_core::quant_k_ext::{BlockQ5K, BLOCK_Q5K_BYTES};
let remainder = tensor.data.len() % QK_K;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_K - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks = BlockQ5K::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * BLOCK_Q5K_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::Q5_K))
}
ExportFormat::Q6K => {
use oxibonsai_core::quant_k::QK_K;
use oxibonsai_core::quant_k_ext::{BlockQ6K, BLOCK_Q6K_BYTES};
let remainder = tensor.data.len() % QK_K;
let padded: std::borrow::Cow<[f32]> = if remainder == 0 {
std::borrow::Cow::Borrowed(&tensor.data)
} else {
let pad = QK_K - remainder;
let mut v = tensor.data.clone();
v.resize(tensor.data.len() + pad, 0.0_f32);
std::borrow::Cow::Owned(v)
};
let blocks = BlockQ6K::quantize(&padded).map_err(|e| ExportError::QuantizeError {
name: tensor.name.clone(),
reason: e.to_string(),
})?;
let byte_len = blocks.len() * BLOCK_Q6K_BYTES;
let block_bytes: &[u8] =
unsafe { std::slice::from_raw_parts(blocks.as_ptr() as *const u8, byte_len) };
Ok((block_bytes.to_vec(), TensorType::Q6_K))
}
}
}
/// Serializes `tensors` into an in-memory GGUF file.
///
/// Metadata written, in order: `general.name`, `general.version`,
/// `general.description` (only when set), `general.quantization_version`
/// (a string naming `config.format`), then every entry of `arch_metadata`.
/// Tensors with empty `data` are skipped; all others are encoded with
/// `encode_tensor` and added in input order.
///
/// # Errors
///
/// * [`ExportError::Empty`] when `tensors` is an empty slice.
/// * [`ExportError::QuantizeError`] when encoding a tensor fails.
/// * [`ExportError::WriteError`] when the GGUF writer fails to serialize.
pub fn export_to_gguf(
    tensors: &[WeightTensor],
    config: &ExportConfig,
    arch_metadata: &[(String, MetadataWriteValue)],
) -> Result<Vec<u8>, ExportError> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    if tensors.is_empty() {
        return Err(ExportError::Empty);
    }

    let mut writer = GgufWriter::new();
    writer.add_metadata(
        "general.name",
        MetadataWriteValue::Str(config.model_name.clone()),
    );
    writer.add_metadata(
        "general.version",
        MetadataWriteValue::Str(config.model_version.clone()),
    );
    if let Some(ref desc) = config.description {
        writer.add_metadata("general.description", MetadataWriteValue::Str(desc.clone()));
    }

    let quant_str = match config.format {
        ExportFormat::Float32 => "F32",
        ExportFormat::Q1_0G128 => "Q1_0G128",
        ExportFormat::Int8PerChannel => "INT8_PER_CHANNEL",
        ExportFormat::TernaryG128 => "TQ2_0_g128",
        ExportFormat::FP8E4M3 => "F8_E4M3",
        ExportFormat::FP8E5M2 => "F8_E5M2",
        ExportFormat::Q4_0 => "Q4_0",
        ExportFormat::Q8_0 => "Q8_0",
        ExportFormat::Q4K => "Q4_K",
        ExportFormat::Q5K => "Q5_K",
        ExportFormat::Q6K => "Q6_K",
    };
    // NOTE(review): the GGUF specification describes
    // `general.quantization_version` as a uint32; a format-name string is
    // stored here instead. Kept as-is because the tests in this file look for
    // the string — confirm what downstream readers expect.
    writer.add_metadata(
        "general.quantization_version",
        MetadataWriteValue::Str(quant_str.to_string()),
    );

    for (key, val) in arch_metadata {
        writer.add_metadata(key, val.clone());
    }

    for tensor in tensors {
        if tensor.data.is_empty() {
            continue;
        }
        let (mut bytes, tensor_type) = encode_tensor(tensor, config)?;

        // Packed Int8 blobs (int8 values followed by f32 scales) are tagged as
        // F32 by the encoder, so their shape is rewritten as a flat f32 count.
        let is_packed_int8 = config.format == ExportFormat::Int8PerChannel
            && !should_keep_fp32(&tensor.name, config);
        let shape: Vec<u64> = if is_packed_int8 {
            // Zero-pad to a whole number of f32 elements first. Without this,
            // integer division drops up to three trailing bytes from the
            // declared size whenever the blob length is not a multiple of four.
            bytes.resize(bytes.len().next_multiple_of(F32_BYTES), 0);
            vec![(bytes.len() / F32_BYTES) as u64]
        } else {
            tensor.shape.iter().map(|&d| d as u64).collect()
        };

        writer.add_tensor(TensorEntry {
            name: tensor.name.clone(),
            shape,
            tensor_type,
            data: bytes,
        });
    }

    writer
        .to_bytes()
        .map_err(|e| ExportError::WriteError(e.to_string()))
}
pub fn estimate_export_size(tensors: &[WeightTensor], config: &ExportConfig) -> usize {
tensors
.iter()
.map(|t| {
if t.data.is_empty() {
return 0;
}
let effective_format = if should_keep_fp32(&t.name, config) {
ExportFormat::Float32
} else {
config.format
};
match effective_format {
ExportFormat::Float32 => t.data.len() * 4,
ExportFormat::Q1_0G128 => q1_0_g128_size_bytes(t.data.len()),
ExportFormat::Int8PerChannel => {
let num_channels = t.shape.first().copied().unwrap_or(1).max(1);
t.data.len() + num_channels * 4
}
ExportFormat::TernaryG128 => {
crate::quantize_ternary::tq2_0_g128_size_bytes(t.data.len())
}
ExportFormat::FP8E4M3 | ExportFormat::FP8E5M2 => {
let num_blocks = t.data.len().div_ceil(oxibonsai_core::quant_fp8::QK_FP8);
num_blocks * oxibonsai_core::quant_fp8::BLOCK_FP8_BYTES
}
ExportFormat::Q4_0 => {
let num_blocks = t.data.len().div_ceil(oxibonsai_core::quant_std::QK_Q4_0);
num_blocks * oxibonsai_core::quant_std::BLOCK_Q4_0_BYTES
}
ExportFormat::Q8_0 => {
let num_blocks = t.data.len().div_ceil(oxibonsai_core::quant_std::QK_Q8_0);
num_blocks * oxibonsai_core::quant_std::BLOCK_Q8_0_BYTES
}
ExportFormat::Q4K => {
let num_blocks = t.data.len().div_ceil(oxibonsai_core::quant_k::QK_K);
num_blocks * oxibonsai_core::quant_k::BLOCK_Q4_K_BYTES
}
ExportFormat::Q5K => {
let num_blocks = t.data.len().div_ceil(oxibonsai_core::quant_k::QK_K);
num_blocks * oxibonsai_core::quant_k_ext::BLOCK_Q5K_BYTES
}
ExportFormat::Q6K => {
let num_blocks = t.data.len().div_ceil(oxibonsai_core::quant_k::QK_K);
num_blocks * oxibonsai_core::quant_k_ext::BLOCK_Q6K_BYTES
}
}
})
.sum()
}
/// Summary figures produced by `export_stats` for a planned export.
#[derive(Debug, Clone)]
pub struct ExportStats {
// Total number of tensors passed in.
pub num_tensors: usize,
// Tensors that would be encoded with a non-FP32 format.
pub quantized_tensors: usize,
// Tensors that would stay FP32 (by exception, allow-list, or Float32 format).
pub fp32_tensors: usize,
// Size of all tensor data at four bytes per value.
pub original_bytes: usize,
// Estimated encoded size, as returned by `estimate_export_size`.
pub exported_bytes: usize,
// `original_bytes / exported_bytes`, or 1.0 when `exported_bytes` is zero.
pub compression_ratio: f32,
}
/// Computes summary statistics for exporting `tensors` under `config`
/// without encoding anything.
///
/// A tensor counts as FP32 when the configured format is `Float32` or when
/// `should_keep_fp32` selects it; every other tensor counts as quantized.
/// The compression ratio is the FP32 size divided by the estimated export
/// size, and is reported as `1.0` when the estimate is zero.
pub fn export_stats(tensors: &[WeightTensor], config: &ExportConfig) -> ExportStats {
    let stays_fp32 = |t: &WeightTensor| {
        config.format == ExportFormat::Float32 || should_keep_fp32(&t.name, config)
    };

    let fp32_tensors = tensors.iter().filter(|&t| stays_fp32(t)).count();
    let quantized_tensors = tensors.len() - fp32_tensors;
    let original_bytes: usize = tensors.iter().map(|t| t.data.len() * 4).sum();
    let exported_bytes = estimate_export_size(tensors, config);

    let compression_ratio = match exported_bytes {
        0 => 1.0,
        exported => original_bytes as f32 / exported as f32,
    };

    ExportStats {
        num_tensors: tensors.len(),
        quantized_tensors,
        fp32_tensors,
        original_bytes,
        exported_bytes,
        compression_ratio,
    }
}
// Unit tests for the exporter: configuration helpers, size estimation,
// statistics, and GGUF output for each supported format.
#[cfg(test)]
mod tests {
use super::*;
// The default FP32 exception list contains exactly the three expected names.
#[test]
fn test_export_config_default_fp32_exceptions() {
let exceptions = ExportConfig::default_fp32_exceptions();
assert!(exceptions.contains(&"token_embd.weight".to_string()));
assert!(exceptions.contains(&"output_norm.weight".to_string()));
assert!(exceptions.contains(&"output.weight".to_string()));
assert_eq!(exceptions.len(), 3);
}
// Element count comes from the shape; byte size from the data length.
#[test]
fn test_weight_tensor_num_elements() {
let t = WeightTensor::new("test", vec![0.0; 256], vec![16, 16]);
assert_eq!(t.num_elements(), 256);
assert_eq!(t.memory_bytes_f32(), 1024);
}
// Float32 estimate is four bytes per value.
#[test]
fn test_estimate_export_size_fp32() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Float32, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(size, 256 * 4);
}
// Q1_0G128 estimate: 256 values are expected to take 2 * 18 bytes.
#[test]
fn test_estimate_export_size_q1_0() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Q1_0G128, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size,
2 * 18,
"Q1_0 size for 256 weights should be {}",
2 * 18
);
}
// Q1_0G128 statistics report compression and count the tensor as quantized.
#[test]
fn test_export_stats_compression_ratio() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::Q1_0G128, "m");
let stats = export_stats(&tensors, &config);
assert!(
stats.compression_ratio > 1.0,
"Q1_0 should compress better than FP32"
);
assert_eq!(stats.quantized_tensors, 1);
assert_eq!(stats.fp32_tensors, 0);
}
// A Q1_0G128 export starts with the GGUF magic number.
#[test]
fn test_export_to_gguf_basic() {
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
vec![1.0; 128],
vec![128],
)];
let config =
ExportConfig::new(ExportFormat::Q1_0G128, "test-model").with_description("unit test");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let magic = u32::from_le_bytes(bytes[0..4].try_into().expect("slice"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
}
// Float32 export keeps the raw little-endian value bytes in the output.
#[test]
fn test_export_fp32_tensor_unchanged() {
let data: Vec<f32> = (0..4).map(|i| i as f32).collect();
let tensors = vec![WeightTensor::new("w", data.clone(), vec![4])];
let config = ExportConfig::new(ExportFormat::Float32, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let needle = 3.0_f32.to_le_bytes();
let found = bytes.windows(4).any(|w| w == needle.as_slice());
assert!(found, "float 3.0 should be present in the exported bytes");
}
// Tensors with no data are left out; the count is read from bytes 8..16
// of the output.
#[test]
fn test_export_skips_empty_tensors() {
let tensors = vec![
WeightTensor::new("good", vec![1.0; 128], vec![128]),
WeightTensor::new("empty", vec![], vec![0]),
];
let config = ExportConfig::new(ExportFormat::Float32, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let tensor_count = u64::from_le_bytes(bytes[8..16].try_into().expect("slice"));
assert_eq!(tensor_count, 1, "empty tensor should be skipped");
}
// TernaryG128 estimate: 128 values are expected to take 34 bytes.
#[test]
fn test_estimate_export_size_ternary_g128() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 128], vec![128])];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size, 34,
"128-weight tensor in TernaryG128 should be 34 bytes"
);
}
// TernaryG128 estimate: 256 values are expected to take 68 bytes.
#[test]
fn test_estimate_export_size_ternary_g128_two_blocks() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size, 68,
"256-weight tensor in TernaryG128 should be 68 bytes"
);
}
// TernaryG128 statistics report compression and count the tensor as quantized.
#[test]
fn test_export_stats_ternary_g128_compression() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m");
let stats = export_stats(&tensors, &config);
assert!(
stats.compression_ratio > 1.0,
"TernaryG128 should compress better than FP32"
);
assert_eq!(stats.quantized_tensors, 1);
assert_eq!(stats.fp32_tensors, 0);
}
// A TernaryG128 export starts with the GGUF magic number.
#[test]
fn test_export_to_gguf_ternary_g128_basic() {
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
vec![1.0; 128],
vec![128],
)];
let config = ExportConfig::new(ExportFormat::TernaryG128, "ternary-model");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("export");
let magic = u32::from_le_bytes(bytes[0..4].try_into().expect("slice"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
}
// FP32 exceptions are honoured when the format is TernaryG128.
#[test]
fn test_ternary_g128_fp32_exception_tensors_stay_fp32() {
let tensors = vec![
WeightTensor::new("blk.0.attn_q.weight", vec![1.0; 128], vec![128]),
WeightTensor::new("output_norm.weight", vec![1.0; 128], vec![128]),
];
let config = ExportConfig::new(ExportFormat::TernaryG128, "m")
.with_fp32_layers(vec!["output_norm.weight".to_string()]);
let stats = export_stats(&tensors, &config);
assert_eq!(stats.fp32_tensors, 1, "output_norm.weight should stay FP32");
assert_eq!(
stats.quantized_tensors, 1,
"attn_q.weight should be ternary-quantized"
);
}
// FP8E4M3 export: checks the magic number and that the file is at least as
// large as the raw block data (no dequantization is performed here).
#[test]
fn test_export_fp8_e4m3_roundtrip() {
let n_weights = 128usize;
let n_blocks = n_weights / oxibonsai_core::quant_fp8::QK_FP8;
let expected_bytes = n_blocks * oxibonsai_core::quant_fp8::BLOCK_FP8_BYTES;
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
vec![1.0; n_weights],
vec![n_weights],
)];
let config = ExportConfig::new(ExportFormat::FP8E4M3, "fp8-e4m3-model");
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("FP8E4M3 export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic slice"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
assert!(
gguf_bytes.len() >= expected_bytes,
"GGUF file too small: {} < {}",
gguf_bytes.len(),
expected_bytes,
);
}
// FP8E5M2 export: same checks as the E4M3 test above.
#[test]
fn test_export_fp8_e5m2_roundtrip() {
let n_weights = 64usize;
let n_blocks = n_weights / oxibonsai_core::quant_fp8::QK_FP8;
let expected_bytes = n_blocks * oxibonsai_core::quant_fp8::BLOCK_FP8_BYTES;
let tensors = vec![WeightTensor::new(
"blk.0.ffn_gate.weight",
vec![2.0; n_weights],
vec![n_weights],
)];
let config = ExportConfig::new(ExportFormat::FP8E5M2, "fp8-e5m2-model");
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("FP8E5M2 export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic slice"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
assert!(
gguf_bytes.len() >= expected_bytes,
"GGUF file too small: {} < {}",
gguf_bytes.len(),
expected_bytes,
);
}
// FP8 size estimates for one and eight blocks, plus the resulting statistics.
#[test]
fn test_export_fp8_size_estimate() {
let tensors_32 = vec![WeightTensor::new("w", vec![1.0; 32], vec![32])];
let config_e4m3 = ExportConfig::new(ExportFormat::FP8E4M3, "m");
let config_e5m2 = ExportConfig::new(ExportFormat::FP8E5M2, "m");
assert_eq!(
estimate_export_size(&tensors_32, &config_e4m3),
34,
"32 weights in FP8E4M3 → 1 block → 34 bytes"
);
assert_eq!(
estimate_export_size(&tensors_32, &config_e5m2),
34,
"32 weights in FP8E5M2 → 1 block → 34 bytes"
);
let tensors_256 = vec![WeightTensor::new("w", vec![1.0; 256], vec![256])];
assert_eq!(
estimate_export_size(&tensors_256, &config_e4m3),
8 * 34,
"256 weights → 8 blocks → 272 bytes"
);
let stats = export_stats(&tensors_256, &config_e4m3);
assert!(
stats.compression_ratio > 1.0,
"FP8E4M3 should compress better than FP32"
);
assert!(
stats.compression_ratio > 3.0,
"FP8E4M3 compression ratio should be > 3.0, got {}",
stats.compression_ratio
);
assert_eq!(stats.quantized_tensors, 1);
assert_eq!(stats.fp32_tensors, 0);
}
// FP32 exceptions are honoured when the format is FP8E4M3.
#[test]
fn test_fp8_fp32_exception_tensors_stay_fp32() {
let tensors = vec![
WeightTensor::new("blk.0.attn_q.weight", vec![1.0; 64], vec![64]),
WeightTensor::new("output_norm.weight", vec![1.0; 64], vec![64]),
];
let config = ExportConfig::new(ExportFormat::FP8E4M3, "m")
.with_fp32_layers(vec!["output_norm.weight".to_string()]);
let stats = export_stats(&tensors, &config);
assert_eq!(stats.fp32_tensors, 1, "output_norm.weight should stay FP32");
assert_eq!(
stats.quantized_tensors, 1,
"attn_q.weight should be FP8-quantized"
);
}
// Q4_0: export sanity checks, then a quantize/dequantize pass directly on
// the block type with a max-error bound of 15% of the input range.
#[test]
fn test_export_q4_0_roundtrip() {
use oxibonsai_core::quant_std::{BlockQ4_0, BLOCK_Q4_0_BYTES, QK_Q4_0};
let n = 64usize;
let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.25 - 8.0).collect();
let config = ExportConfig::new(ExportFormat::Q4_0, "q4-0-model");
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
input.clone(),
vec![n],
)];
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("Q4_0 export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
let expected_raw = (n / QK_Q4_0) * BLOCK_Q4_0_BYTES;
assert!(
gguf_bytes.len() >= expected_raw,
"GGUF file ({} bytes) must cover at least {} raw data bytes",
gguf_bytes.len(),
expected_raw,
);
let blocks = BlockQ4_0::quantize(&input).expect("Q4_0 quantize");
assert_eq!(blocks.len(), n / QK_Q4_0, "block count matches");
let mut output = vec![0.0f32; n];
BlockQ4_0::dequant(&blocks, &mut output).expect("Q4_0 dequant");
let max_range = input.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let max_err = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
let threshold = max_range * 0.15;
assert!(
max_err <= threshold,
"Q4_0 roundtrip max error {max_err} > threshold {threshold} (max_range={max_range})"
);
}
// Q8_0: same structure as the Q4_0 test, with a 1% error bound.
#[test]
fn test_export_q8_0_roundtrip() {
use oxibonsai_core::quant_std::{BlockQ8_0, BLOCK_Q8_0_BYTES, QK_Q8_0};
let n = 64usize;
let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5 - 16.0).collect();
let config = ExportConfig::new(ExportFormat::Q8_0, "q8-0-model");
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
input.clone(),
vec![n],
)];
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("Q8_0 export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
let expected_raw = (n / QK_Q8_0) * BLOCK_Q8_0_BYTES;
assert!(
gguf_bytes.len() >= expected_raw,
"GGUF file ({} bytes) must cover at least {} raw Q8_0 data bytes",
gguf_bytes.len(),
expected_raw,
);
let blocks = BlockQ8_0::quantize(&input).expect("Q8_0 quantize");
assert_eq!(blocks.len(), n / QK_Q8_0, "block count matches");
let mut output = vec![0.0f32; n];
BlockQ8_0::dequant(&blocks, &mut output).expect("Q8_0 dequant");
let max_range = input.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let max_err = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
let threshold = max_range * 0.01;
assert!(
max_err <= threshold,
"Q8_0 roundtrip max error {max_err} > threshold {threshold} (max_range={max_range})"
);
}
// Q4K: sine-wave input; error bound is max(8% of range, 0.1).
#[test]
fn test_export_q4k_roundtrip() {
use oxibonsai_core::quant_k::{BlockQ4K, BLOCK_Q4_K_BYTES, QK_K};
let n = 512usize;
let input: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.1 - 25.6).sin()).collect();
let config = ExportConfig::new(ExportFormat::Q4K, "q4k-model");
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
input.clone(),
vec![n],
)];
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("Q4K export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
let expected_raw = (n / QK_K) * BLOCK_Q4_K_BYTES;
assert!(
gguf_bytes.len() >= expected_raw,
"GGUF file ({} bytes) must cover at least {} raw Q4K data bytes",
gguf_bytes.len(),
expected_raw,
);
let blocks = BlockQ4K::quantize(&input).expect("Q4K quantize");
assert_eq!(blocks.len(), n / QK_K, "Q4K block count matches");
let mut output = vec![0.0f32; n];
BlockQ4K::dequant(&blocks, &mut output).expect("Q4K dequant");
let max_range = input.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let max_err = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
let threshold = (max_range * 0.08).max(0.1);
assert!(
max_err <= threshold,
"Q4K roundtrip max error {max_err} > threshold {threshold}"
);
}
// Q5K: cosine-wave input; error bound is max(8% of range, 0.1).
#[test]
fn test_export_q5k_roundtrip() {
use oxibonsai_core::quant_k::QK_K;
use oxibonsai_core::quant_k_ext::{BlockQ5K, BLOCK_Q5K_BYTES};
let n = 512usize;
let input: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.07 - 17.9).cos()).collect();
let config = ExportConfig::new(ExportFormat::Q5K, "q5k-model");
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
input.clone(),
vec![n],
)];
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("Q5K export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
let expected_raw = (n / QK_K) * BLOCK_Q5K_BYTES;
assert!(
gguf_bytes.len() >= expected_raw,
"GGUF file ({} bytes) must cover at least {} raw Q5K data bytes",
gguf_bytes.len(),
expected_raw,
);
let blocks = BlockQ5K::quantize(&input).expect("Q5K quantize");
assert_eq!(blocks.len(), n / QK_K, "Q5K block count matches");
let mut output = vec![0.0f32; n];
BlockQ5K::dequant(&blocks, &mut output).expect("Q5K dequant");
let max_range = input.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let max_err = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
let threshold = (max_range * 0.08).max(0.1);
assert!(
max_err <= threshold,
"Q5K roundtrip max error {max_err} > threshold {threshold}"
);
}
// Q6K: linear-ramp input; error bound is max(5% of range, 0.1).
#[test]
fn test_export_q6k_roundtrip() {
use oxibonsai_core::quant_k::QK_K;
use oxibonsai_core::quant_k_ext::{BlockQ6K, BLOCK_Q6K_BYTES};
let n = 512usize;
let input: Vec<f32> = (0..n).map(|i| (i as f32) * 0.05 - 12.8).collect();
let config = ExportConfig::new(ExportFormat::Q6K, "q6k-model");
let tensors = vec![WeightTensor::new(
"blk.0.attn_q.weight",
input.clone(),
vec![n],
)];
let gguf_bytes = export_to_gguf(&tensors, &config, &[]).expect("Q6K export");
let magic = u32::from_le_bytes(gguf_bytes[0..4].try_into().expect("magic"));
assert_eq!(magic, 0x4655_4747, "expected GGUF magic");
let expected_raw = (n / QK_K) * BLOCK_Q6K_BYTES;
assert!(
gguf_bytes.len() >= expected_raw,
"GGUF file ({} bytes) must cover at least {} raw Q6K data bytes",
gguf_bytes.len(),
expected_raw,
);
let blocks = BlockQ6K::quantize(&input).expect("Q6K quantize");
assert_eq!(blocks.len(), n / QK_K, "Q6K block count matches");
let mut output = vec![0.0f32; n];
BlockQ6K::dequant(&blocks, &mut output).expect("Q6K dequant");
let max_range = input.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let max_err = input
.iter()
.zip(output.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
let threshold = (max_range * 0.05).max(0.1);
assert!(
max_err <= threshold,
"Q6K roundtrip max error {max_err} > threshold {threshold}"
);
}
// Size estimates for the standard and K-quant block formats.
#[test]
fn test_estimate_export_size_q4_0() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 64], vec![64])];
let config = ExportConfig::new(ExportFormat::Q4_0, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(size, 2 * 18, "Q4_0: 64 weights → 2 blocks → 36 bytes");
}
#[test]
fn test_estimate_export_size_q8_0() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 64], vec![64])];
let config = ExportConfig::new(ExportFormat::Q8_0, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(size, 2 * 34, "Q8_0: 64 weights → 2 blocks → 68 bytes");
}
#[test]
fn test_estimate_export_size_q4k() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::Q4K, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size,
2 * 144,
"Q4K: 512 weights → 2 super-blocks → 288 bytes"
);
}
#[test]
fn test_estimate_export_size_q5k() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::Q5K, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size,
2 * 176,
"Q5K: 512 weights → 2 super-blocks → 352 bytes"
);
}
#[test]
fn test_estimate_export_size_q6k() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 512], vec![512])];
let config = ExportConfig::new(ExportFormat::Q6K, "m");
let size = estimate_export_size(&tensors, &config);
assert_eq!(
size,
2 * 210,
"Q6K: 512 weights → 2 super-blocks → 420 bytes"
);
}
// The format-name string written to the metadata must appear somewhere in
// the exported bytes (simple substring scan).
#[test]
fn test_export_format_type_name_q4_0() {
let tensors = vec![WeightTensor::new("blk.0.w", vec![1.0; 64], vec![64])];
let config = ExportConfig::new(ExportFormat::Q4_0, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("Q4_0 export");
let needle = b"Q4_0";
let found = bytes.windows(needle.len()).any(|w| w == needle);
assert!(
found,
"GGUF metadata should contain \"Q4_0\" quantization string"
);
}
#[test]
fn test_export_format_type_name_q4k() {
let tensors = vec![WeightTensor::new("blk.0.w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Q4K, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("Q4K export");
let needle = b"Q4_K";
let found = bytes.windows(needle.len()).any(|w| w == needle);
assert!(
found,
"GGUF metadata should contain \"Q4_K\" quantization string"
);
}
#[test]
fn test_export_format_type_name_q5k() {
let tensors = vec![WeightTensor::new("blk.0.w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Q5K, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("Q5K export");
let needle = b"Q5_K";
let found = bytes.windows(needle.len()).any(|w| w == needle);
assert!(
found,
"GGUF metadata should contain \"Q5_K\" quantization string"
);
}
#[test]
fn test_export_format_type_name_q6k() {
let tensors = vec![WeightTensor::new("blk.0.w", vec![1.0; 256], vec![256])];
let config = ExportConfig::new(ExportFormat::Q6K, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("Q6K export");
let needle = b"Q6_K";
let found = bytes.windows(needle.len()).any(|w| w == needle);
assert!(
found,
"GGUF metadata should contain \"Q6_K\" quantization string"
);
}
#[test]
fn test_export_format_type_name_q8_0() {
let tensors = vec![WeightTensor::new("blk.0.w", vec![1.0; 64], vec![64])];
let config = ExportConfig::new(ExportFormat::Q8_0, "m");
let bytes = export_to_gguf(&tensors, &config, &[]).expect("Q8_0 export");
let needle = b"Q8_0";
let found = bytes.windows(needle.len()).any(|w| w == needle);
assert!(
found,
"GGUF metadata should contain \"Q8_0\" quantization string"
);
}
// Estimated quantized sizes are smaller than the Float32 estimate.
#[test]
fn test_q4_0_produces_smaller_output_than_float32() {
let tensors = vec![WeightTensor::new("w", vec![1.0; 32], vec![32])];
let config_q4 = ExportConfig::new(ExportFormat::Q4_0, "m");
let config_f32 = ExportConfig::new(ExportFormat::Float32, "m");
let q4_size = estimate_export_size(&tensors, &config_q4);
let f32_size = estimate_export_size(&tensors, &config_f32);
assert_eq!(q4_size, 18, "Q4_0 32 weights = 18 bytes");
assert_eq!(f32_size, 128, "Float32 32 weights = 128 bytes");
assert!(
q4_size < f32_size,
"Q4_0 ({q4_size} bytes) must be smaller than Float32 ({f32_size} bytes)"
);
}
#[test]
fn test_q8_0_compression_vs_float32() {
let tensors = vec![WeightTensor::new("w", vec![0.5f32; 32], vec![32])];
let config_q8 = ExportConfig::new(ExportFormat::Q8_0, "m");
let config_f32 = ExportConfig::new(ExportFormat::Float32, "m");
let q8_size = estimate_export_size(&tensors, &config_q8);
let f32_size = estimate_export_size(&tensors, &config_f32);
assert!(
q8_size < f32_size,
"Q8_0 ({q8_size} bytes) must be smaller than Float32 ({f32_size} bytes)"
);
}
// FP32 exceptions are honoured for each of the standard and K-quant formats.
#[test]
fn test_new_formats_fp32_exception_respected() {
let tensors = vec![
WeightTensor::new("blk.0.attn_q.weight", vec![1.0; 64], vec![64]),
WeightTensor::new("output_norm.weight", vec![1.0; 64], vec![64]),
];
let fp32_exceptions = vec!["output_norm.weight".to_string()];
for fmt in &[
ExportFormat::Q4_0,
ExportFormat::Q8_0,
ExportFormat::Q4K,
ExportFormat::Q5K,
ExportFormat::Q6K,
] {
let config = ExportConfig::new(*fmt, "m").with_fp32_layers(fp32_exceptions.clone());
let stats = export_stats(&tensors, &config);
assert_eq!(
stats.fp32_tensors, 1,
"output_norm.weight must stay FP32 for format {fmt:?}"
);
assert_eq!(
stats.quantized_tensors, 1,
"attn_q.weight must be quantized for format {fmt:?}"
);
}
}
}