use aprender::format::gguf::GgmlType;
use crate::quant::{GGUF_BLOCK_SIZE, Q4_0, Q8_0};
/// Selects how tensor data is encoded before being written as GGUF tensor bytes.
///
/// Passed to [`quantize_to_gguf_bytes`], which maps each variant to the matching
/// `GgmlType` tag and byte encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgufQuantization {
/// No quantization: the `f32` values are emitted as raw bytes and tagged `GgmlType::F32`.
None,
/// Quantize with `Q4_0`; each encoded block is an `f16` scale followed by 16 data bytes.
Q4_0,
/// Quantize with `Q8_0`; each encoded block is an `f16` scale followed by
/// `GGUF_BLOCK_SIZE` data bytes.
Q8_0,
}
/// Encodes `data` as GGUF tensor bytes using the requested quantization.
///
/// # Arguments
///
/// * `data` - The tensor values to encode.
/// * `quant` - The encoding to apply.
///
/// # Returns
///
/// A tuple of the encoded bytes and the [`GgmlType`] tag describing them:
///
/// * [`GgufQuantization::None`] - each `f32` as 4 little-endian bytes, tagged `GgmlType::F32`.
/// * [`GgufQuantization::Q4_0`] - `Q4_0::quantize` output serialized block by block,
///   tagged `GgmlType::Q4_0`.
/// * [`GgufQuantization::Q8_0`] - `Q8_0::quantize` output serialized block by block,
///   tagged `GgmlType::Q8_0`.
pub fn quantize_to_gguf_bytes(data: &[f32], quant: GgufQuantization) -> (Vec<u8>, GgmlType) {
    match quant {
        GgufQuantization::None => {
            // The quantized encoders write their scales with `to_le_bytes`, so the F32 path
            // must be little-endian as well. A plain reinterpreting cast yields host-endian
            // bytes, which is only correct on little-endian targets; keep that bulk copy
            // there and convert value by value everywhere else.
            let bytes: Vec<u8> = if cfg!(target_endian = "little") {
                bytemuck::cast_slice(data).to_vec()
            } else {
                data.iter().flat_map(|value| value.to_le_bytes()).collect()
            };
            (bytes, GgmlType::F32)
        }
        GgufQuantization::Q4_0 => {
            let quantized = Q4_0::quantize(data);
            (encode_q4_0_blocks(&quantized), GgmlType::Q4_0)
        }
        GgufQuantization::Q8_0 => {
            let quantized = Q8_0::quantize(data);
            (encode_q8_0_blocks(&quantized), GgmlType::Q8_0)
        }
    }
}
/// Serializes a quantized [`Q4_0`] tensor into consecutive 18-byte blocks.
///
/// Each block is the block's scale converted to `f16` (2 bytes, little-endian) followed
/// by 16 bytes copied verbatim from `q.data`. When `q.data` ends inside the last block,
/// the rest of that block is filled with zeros.
///
/// NOTE(review): the data bytes are copied without re-packing. GGML's Q4_0 layout keeps
/// element `j` in the low nibble and element `j + 16` in the high nibble of byte `j`;
/// confirm that `Q4_0::quantize` packs nibbles in that order.
///
/// # Panics
///
/// Panics if `q.scales` has fewer than `q.num_blocks()` entries, or if `q.data` ends
/// before the first byte of any block.
fn encode_q4_0_blocks(q: &Q4_0) -> Vec<u8> {
    /// Number of data bytes stored per block.
    const PACKED_BYTES: usize = 16;
    /// Size of one encoded block: 2-byte scale plus the data bytes.
    const BLOCK_BYTES: usize = 18;

    let block_count = q.num_blocks();
    let mut encoded = Vec::with_capacity(block_count * BLOCK_BYTES);

    for block in 0..block_count {
        let scale = half::f16::from_f32(q.scales[block]).to_le_bytes();
        encoded.extend_from_slice(&scale);

        // Clamp the end of the window so a short final block does not read past `q.data`.
        let first = block * PACKED_BYTES;
        let last = q.data.len().min(first + PACKED_BYTES);
        let packed = &q.data[first..last];
        encoded.extend_from_slice(packed);

        // Zero-fill whatever the window could not supply.
        let padded_len = encoded.len() + (PACKED_BYTES - packed.len());
        encoded.resize(padded_len, 0u8);
    }

    encoded
}
/// Serializes a quantized [`Q8_0`] tensor into consecutive blocks.
///
/// Each block is the block's scale converted to `f16` (2 bytes, little-endian) followed
/// by `GGUF_BLOCK_SIZE` bytes, one per quantized value, each produced by an `as u8` cast.
/// When `q.data` ends inside the last block, the rest of that block is filled with zeros.
///
/// # Panics
///
/// Panics if `q.scales` has fewer than `q.num_blocks()` entries, or if `q.data` ends
/// before the first value of any block.
fn encode_q8_0_blocks(q: &Q8_0) -> Vec<u8> {
    /// Capacity reserved per block (2-byte scale plus 32 data bytes).
    const BLOCK_BYTES: usize = 34;

    let block_count = q.num_blocks();
    let mut encoded = Vec::with_capacity(block_count * BLOCK_BYTES);

    for block in 0..block_count {
        let scale = half::f16::from_f32(q.scales[block]).to_le_bytes();
        encoded.extend_from_slice(&scale);

        // Clamp the end of the window so a short final block does not read past `q.data`.
        let first = block * GGUF_BLOCK_SIZE;
        let last = q.data.len().min(first + GGUF_BLOCK_SIZE);
        let quants = &q.data[first..last];
        encoded.extend(quants.iter().map(|&value| value as u8));

        // Zero-fill whatever the window could not supply.
        let padded_len = encoded.len() + (GGUF_BLOCK_SIZE - quants.len());
        encoded.resize(padded_len, 0u8);
    }

    encoded
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Encoded size of one Q4_0 block: 2-byte f16 scale plus 16 data bytes.
    const Q4_0_BLOCK_BYTES: usize = 18;
    /// Encoded size of one Q8_0 block: 2-byte f16 scale plus 32 data bytes.
    const Q8_0_BLOCK_BYTES: usize = 34;

    /// Builds `[0.0, step, 2.0 * step, ..]` with `len` entries.
    fn ramp(len: usize, step: f32) -> Vec<f32> {
        (0..len).map(|i| i as f32 * step).collect()
    }

    /// Builds one 32-element block holding two huge values of opposite sign.
    fn extreme_block() -> Vec<f32> {
        let mut data = vec![0.0f32; 32];
        data[..2].copy_from_slice(&[1e30, -1e30]);
        data
    }

    /// Decodes the little-endian f16 scale stored in the first two bytes of `bytes`.
    fn leading_scale(bytes: &[u8]) -> f32 {
        half::f16::from_le_bytes([bytes[0], bytes[1]]).to_f32()
    }

    // --- Direct encoder checks ---------------------------------------------------------

    #[test]
    fn test_encode_q4_0_block_size() {
        let quantized = Q4_0::quantize(&ramp(32, 0.1));
        assert_eq!(encode_q4_0_blocks(&quantized).len(), Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_encode_q8_0_block_size() {
        let quantized = Q8_0::quantize(&ramp(32, 0.1));
        assert_eq!(encode_q8_0_blocks(&quantized).len(), Q8_0_BLOCK_BYTES);
    }

    // --- Public entry point, one full block ----------------------------------------------

    #[test]
    fn test_quantize_to_gguf_bytes_none() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let (bytes, dtype) = quantize_to_gguf_bytes(&input, GgufQuantization::None);
        assert_eq!(dtype, GgmlType::F32);
        assert_eq!(bytes.len(), 16);

        let head: [u8; 4] = bytes[0..4].try_into().expect("conversion should succeed");
        let val = f32::from_le_bytes(head);
        assert!((val - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quantize_to_gguf_bytes_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&ramp(32, 0.1), GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_quantize_to_gguf_bytes_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&ramp(32, 0.1), GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), Q8_0_BLOCK_BYTES);
    }

    // --- Empty input ---------------------------------------------------------------------

    #[test]
    fn test_falsify_quantize_empty_data_none() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[], GgufQuantization::None);
        assert_eq!(dtype, GgmlType::F32);
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_falsify_quantize_empty_data_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[], GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert!(
            bytes.is_empty(),
            "empty input must produce empty output, got {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn test_falsify_quantize_empty_data_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[], GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert!(
            bytes.is_empty(),
            "empty input must produce empty output, got {} bytes",
            bytes.len()
        );
    }

    // --- Partial blocks: a lone element still occupies a whole block ---------------------

    #[test]
    fn test_falsify_quantize_single_element_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[42.0], GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_single_element_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&[42.0], GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), Q8_0_BLOCK_BYTES);
    }

    // --- Block boundaries: 33 and 63 elements both need two blocks -----------------------

    #[test]
    fn test_falsify_quantize_33_elements_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&ramp(33, 0.1), GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 2 * Q4_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_33_elements_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&ramp(33, 0.1), GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), 2 * Q8_0_BLOCK_BYTES);
    }

    #[test]
    fn test_falsify_quantize_63_elements_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&ramp(63, 0.01), GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 2 * Q4_0_BLOCK_BYTES);
    }

    // --- Degenerate values ---------------------------------------------------------------

    #[test]
    fn test_falsify_quantize_all_zeros_q4_0() {
        let zeros = [0.0f32; 64];
        let (bytes, dtype) = quantize_to_gguf_bytes(&zeros, GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), 2 * Q4_0_BLOCK_BYTES);
        assert_eq!(leading_scale(&bytes), 0.0, "scale for zero data must be 0");
    }

    #[test]
    fn test_falsify_quantize_all_zeros_q8_0() {
        let zeros = [0.0f32; 32];
        let (bytes, dtype) = quantize_to_gguf_bytes(&zeros, GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), Q8_0_BLOCK_BYTES);
        assert_eq!(leading_scale(&bytes), 0.0, "scale for zero data must be 0");
    }

    #[test]
    fn test_falsify_quantize_extreme_range_q4_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&extreme_block(), GgufQuantization::Q4_0);
        assert_eq!(dtype, GgmlType::Q4_0);
        assert_eq!(bytes.len(), Q4_0_BLOCK_BYTES);
        assert!(
            !leading_scale(&bytes).is_nan(),
            "scale must not be NaN for extreme values"
        );
    }

    #[test]
    fn test_falsify_quantize_extreme_range_q8_0() {
        let (bytes, dtype) = quantize_to_gguf_bytes(&extreme_block(), GgufQuantization::Q8_0);
        assert_eq!(dtype, GgmlType::Q8_0);
        assert_eq!(bytes.len(), Q8_0_BLOCK_BYTES);
        assert!(
            !leading_scale(&bytes).is_nan(),
            "scale must not be NaN for extreme values"
        );
    }

    // --- Exact F32 layout ----------------------------------------------------------------

    #[test]
    fn test_falsify_quantize_f32_exact_byte_layout() {
        use std::f32::consts::{E, PI};

        let (bytes, dtype) = quantize_to_gguf_bytes(&[PI, E], GgufQuantization::None);
        assert_eq!(dtype, GgmlType::F32);
        assert_eq!(bytes.len(), 8);
        assert_eq!(&bytes[0..4], &PI.to_le_bytes());
        assert_eq!(&bytes[4..8], &E.to_le_bytes());
    }

    // --- Property tests: output size is always a whole number of blocks ------------------

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_q4_0_encode_correct_block_count(
            n_elements in 1usize..256,
        ) {
            let quantized = Q4_0::quantize(&vec![1.0f32; n_elements]);
            let encoded = encode_q4_0_blocks(&quantized);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(encoded.len(), expected_blocks * Q4_0_BLOCK_BYTES);
        }

        #[test]
        fn prop_q8_0_encode_correct_block_count(
            n_elements in 1usize..256,
        ) {
            let quantized = Q8_0::quantize(&vec![1.0f32; n_elements]);
            let encoded = encode_q8_0_blocks(&quantized);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(encoded.len(), expected_blocks * Q8_0_BLOCK_BYTES);
        }

        #[test]
        fn prop_falsify_quantize_none_preserves_all_bytes(
            n_elements in 1usize..128,
        ) {
            let data: Vec<f32> = (0..n_elements).map(|i| i as f32 * 0.7 - 50.0).collect();
            let (bytes, dtype) = quantize_to_gguf_bytes(&data, GgufQuantization::None);
            prop_assert_eq!(dtype, GgmlType::F32);
            prop_assert_eq!(bytes.len(), n_elements * 4);

            // The length check above guarantees one 4-byte chunk per input element.
            for (i, (chunk, &expected)) in bytes.chunks_exact(4).zip(data.iter()).enumerate() {
                let actual = f32::from_le_bytes(
                    chunk.try_into().expect("conversion should succeed"),
                );
                prop_assert!(
                    (actual - expected).abs() < f32::EPSILON,
                    "element {i}: expected {expected}, got {actual}"
                );
            }
        }

        #[test]
        fn prop_falsify_quantize_q4_0_byte_size_invariant(
            n_elements in 1usize..512,
        ) {
            let input = vec![0.5f32; n_elements];
            let (bytes, dtype) = quantize_to_gguf_bytes(&input, GgufQuantization::Q4_0);
            prop_assert_eq!(dtype, GgmlType::Q4_0);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(bytes.len(), expected_blocks * Q4_0_BLOCK_BYTES);
        }

        #[test]
        fn prop_falsify_quantize_q8_0_byte_size_invariant(
            n_elements in 1usize..512,
        ) {
            let input = vec![0.5f32; n_elements];
            let (bytes, dtype) = quantize_to_gguf_bytes(&input, GgufQuantization::Q8_0);
            prop_assert_eq!(dtype, GgmlType::Q8_0);
            let expected_blocks = n_elements.div_ceil(GGUF_BLOCK_SIZE);
            prop_assert_eq!(bytes.len(), expected_blocks * Q8_0_BLOCK_BYTES);
        }
    }
}