use half::f16;
use oxillama_gguf::GgufTensorType;
use crate::dispatch::KernelDispatcher;
use crate::error::{QuantError, QuantResult};
/// Number of values grouped into a single quantization block; inputs to the
/// quantizers must have a length that is a multiple of this.
const BLOCK_SIZE: usize = 32;
/// Encoded size of one Q4_0 block: a 2-byte `f16` scale followed by 16 bytes
/// holding two 4-bit codes each.
const Q4_0_BLOCK_BYTES: usize = 18;
/// Encoded size of one Q8_0 block: a 2-byte `f16` scale followed by one
/// signed byte per value.
const Q8_0_BLOCK_BYTES: usize = 34;
/// Quantizes `f32` values into a contiguous buffer of Q4_0 blocks.
///
/// The input is walked in consecutive groups of `BLOCK_SIZE` values; each
/// group is encoded by `encode_q4_0_block` and the results are appended in
/// input order. An empty input yields an empty buffer.
///
/// # Errors
///
/// Returns [`QuantError::DimensionMismatch`] when `data.len()` is not a
/// multiple of `BLOCK_SIZE`. `expected` holds the next multiple of
/// `BLOCK_SIZE` above the actual length, `got` the actual length.
pub fn quantize_f32_to_q4_0(data: &[f32]) -> QuantResult<Vec<u8>> {
    let leftover = data.len() % BLOCK_SIZE;
    if leftover != 0 {
        return Err(QuantError::DimensionMismatch {
            // Round the length up to the next whole block.
            expected: data.len() - leftover + BLOCK_SIZE,
            got: data.len(),
        });
    }

    let mut encoded = Vec::with_capacity(data.len() / BLOCK_SIZE * Q4_0_BLOCK_BYTES);
    for block in data.chunks_exact(BLOCK_SIZE) {
        encode_q4_0_block(block, &mut encoded);
    }
    Ok(encoded)
}
/// Quantizes `f32` values into a contiguous buffer of Q8_0 blocks.
///
/// The input is split into consecutive groups of `BLOCK_SIZE` values, each of
/// which is encoded by `encode_q8_0_block` and appended to the output in
/// input order. An empty input yields an empty buffer.
///
/// # Errors
///
/// Returns [`QuantError::DimensionMismatch`] when `data.len()` is not a
/// multiple of `BLOCK_SIZE`. `expected` holds the next multiple of
/// `BLOCK_SIZE` above the actual length, `got` the actual length.
pub fn quantize_f32_to_q8_0(data: &[f32]) -> QuantResult<Vec<u8>> {
    let whole_blocks = data.len() / BLOCK_SIZE;
    // A length that does not survive the divide/multiply round trip has a
    // trailing partial block.
    if whole_blocks * BLOCK_SIZE != data.len() {
        return Err(QuantError::DimensionMismatch {
            expected: (whole_blocks + 1) * BLOCK_SIZE,
            got: data.len(),
        });
    }

    let mut encoded = Vec::with_capacity(whole_blocks * Q8_0_BLOCK_BYTES);
    data.chunks_exact(BLOCK_SIZE)
        .for_each(|block| encode_q8_0_block(block, &mut encoded));
    Ok(encoded)
}
/// Quantizes half-precision values, given as raw `f16` bit patterns, to Q4_0.
///
/// Every `u16` is reinterpreted as an `f16`, widened to `f32`, and the
/// resulting buffer is handed to [`quantize_f32_to_q4_0`].
///
/// # Errors
///
/// Propagates any error returned by [`quantize_f32_to_q4_0`].
pub fn quantize_f16_to_q4_0(data: &[u16]) -> QuantResult<Vec<u8>> {
    let mut widened: Vec<f32> = Vec::with_capacity(data.len());
    widened.extend(data.iter().map(|&bits| f16::from_bits(bits).to_f32()));
    quantize_f32_to_q4_0(&widened)
}
/// Quantizes half-precision values, given as raw `f16` bit patterns, to Q8_0.
///
/// Every `u16` is reinterpreted as an `f16`, widened to `f32`, and the
/// resulting buffer is handed to [`quantize_f32_to_q8_0`].
///
/// # Errors
///
/// Propagates any error returned by [`quantize_f32_to_q8_0`].
pub fn quantize_f16_to_q8_0(data: &[u16]) -> QuantResult<Vec<u8>> {
    let widen = |&bits: &u16| f16::from_bits(bits).to_f32();
    let widened: Vec<f32> = data.iter().map(widen).collect();
    quantize_f32_to_q8_0(&widened)
}
/// Dequantizes `n_elements` values of the given tensor type into `f32`.
///
/// The kernel for `tensor_type` is looked up through a fresh
/// `KernelDispatcher`, then `data` is decoded one block at a time. The number
/// of blocks is `n_elements` divided by the type's block size, rounded up;
/// the output is cut back to exactly `n_elements` values. Bytes in `data`
/// beyond the blocks that are needed are ignored.
///
/// # Errors
///
/// * Whatever `get_kernel` returns for `tensor_type`; this lookup runs before
///   the `n_elements == 0` shortcut.
/// * [`QuantError::BufferTooSmall`] when `data` holds fewer bytes than the
///   required number of whole blocks.
/// * Any error reported by the kernel's `dequant_block`.
pub fn dequantize_to_f32(
    data: &[u8],
    tensor_type: GgufTensorType,
    n_elements: usize,
) -> QuantResult<Vec<f32>> {
    let dispatcher = KernelDispatcher::new();
    let kernel = dispatcher.get_kernel(tensor_type)?;
    let elems_per_block = tensor_type.block_size();
    let bytes_per_block = tensor_type.block_bytes();

    if n_elements == 0 {
        return Ok(Vec::new());
    }

    let block_count = n_elements.div_ceil(elems_per_block);
    let needed = block_count * bytes_per_block;
    let available = data.len();
    if available < needed {
        return Err(QuantError::BufferTooSmall { needed, available });
    }

    // Decode into a buffer padded to whole blocks so the final, possibly
    // partial block still has a full-sized destination slice.
    let mut decoded = vec![0.0f32; block_count * elems_per_block];
    for (index, dst) in decoded.chunks_exact_mut(elems_per_block).enumerate() {
        let start = index * bytes_per_block;
        kernel.dequant_block(&data[start..start + bytes_per_block], dst)?;
    }

    // Drop the padding values past the requested element count.
    decoded.truncate(n_elements);
    Ok(decoded)
}
/// Encodes one block of values as Q4_0 and appends the bytes to `out`.
///
/// Bytes written:
/// * 2 bytes: the scale `d` converted to `f16`, as little-endian bits;
/// * `BLOCK_SIZE / 2` bytes: two 4-bit codes per byte, with value `2 * k` in
///   the low nibble and value `2 * k + 1` in the high nibble of byte `k`.
///
/// `d` is the value of largest magnitude (sign kept, first one wins on ties)
/// divided by -8. Each code is `x / d + 8.5` truncated to an integer and
/// clamped to `0..=15`, so the extreme value itself maps to code 0. When `d`
/// is zero, every code is 8.
///
/// `values` is expected to hold exactly `BLOCK_SIZE` elements (checked in
/// debug builds only).
///
/// NOTE(review): ggml's reference Q4_0 packs value `j` with value `j + 16` in
/// one byte, not adjacent pairs. Confirm that the Q4_0 dequant kernel used
/// with this encoder expects the adjacent-pair layout written here.
fn encode_q4_0_block(values: &[f32], out: &mut Vec<u8>) {
    debug_assert_eq!(values.len(), BLOCK_SIZE);

    // Signed value with the largest magnitude in the block.
    let mut extreme = 0.0f32;
    for &candidate in values {
        if candidate.abs() > extreme.abs() {
            extreme = candidate;
        }
    }

    let scale = extreme / -8.0;
    // An all-zero block has no meaningful inverse scale.
    let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
    out.extend_from_slice(&f16::from_f32(scale).to_bits().to_le_bytes());

    let to_nibble = |x: f32| ((x * inv_scale + 8.5) as i32).clamp(0, 15) as u8;
    for first in (0..BLOCK_SIZE).step_by(2) {
        let low = to_nibble(values[first]);
        let high = to_nibble(values[first + 1]);
        out.push(low | (high << 4));
    }
}
/// Encodes one block of values as Q8_0 and appends the bytes to `out`.
///
/// Bytes written:
/// * 2 bytes: the scale `d` converted to `f16`, as little-endian bits;
/// * one byte per input value: the signed 8-bit code stored as its
///   two's-complement byte.
///
/// `d` is the largest absolute value in the block divided by 127. Each code
/// is `x / d` rounded to the nearest integer and clamped to `-128..=127`.
/// When `d` is zero, every code is 0.
///
/// `values` is expected to hold exactly `BLOCK_SIZE` elements (checked in
/// debug builds only).
fn encode_q8_0_block(values: &[f32], out: &mut Vec<u8>) {
    debug_assert_eq!(values.len(), BLOCK_SIZE);

    let peak = values.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
    let scale = peak / 127.0;
    // An all-zero block has no meaningful inverse scale.
    let inv_scale = if scale == 0.0 { 0.0 } else { 1.0 / scale };
    out.extend_from_slice(&f16::from_f32(scale).to_bits().to_le_bytes());

    out.extend(values.iter().map(|&x| {
        let rounded = (x * inv_scale).round() as i32;
        rounded.clamp(-128, 127) as i8 as u8
    }));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute element-wise difference between two slices, compared
    /// pairwise up to the length of the shorter one.
    fn max_abs_error(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max)
    }

    /// Asserts that a quantizer result is a `DimensionMismatch` error.
    fn assert_dimension_mismatch(result: QuantResult<Vec<u8>>) {
        assert!(result.is_err());
        match result {
            Err(QuantError::DimensionMismatch { .. }) => {}
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    // ---- Q4_0 round trips -------------------------------------------------

    /// A single block ramping from -1.0 to 1.0 survives a Q4_0 round trip.
    #[test]
    fn q4_0_round_trip_small_values() {
        let data: Vec<f32> = (0..32).map(|i| (i as f32 / 15.5) - 1.0).collect();

        let packed = quantize_f32_to_q4_0(&data).expect("quantize failed");
        assert_eq!(packed.len(), Q4_0_BLOCK_BYTES);

        let restored =
            dequantize_to_f32(&packed, GgufTensorType::Q4_0, 32).expect("dequantize failed");
        assert_eq!(restored.len(), 32);

        let err = max_abs_error(&data, &restored);
        assert!(err < 0.15, "Q4_0 round-trip error too large: {err}");
    }

    /// Four consecutive blocks are encoded and decoded independently.
    #[test]
    fn q4_0_round_trip_multiple_blocks() {
        let data: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        let packed = quantize_f32_to_q4_0(&data).expect("quantize failed");
        assert_eq!(packed.len(), 4 * Q4_0_BLOCK_BYTES);

        let restored =
            dequantize_to_f32(&packed, GgufTensorType::Q4_0, 128).expect("dequantize failed");
        assert_eq!(restored.len(), 128);

        let err = max_abs_error(&data, &restored);
        assert!(err < 0.15, "Q4_0 multi-block error: {err}");
    }

    // ---- Q8_0 round trips -------------------------------------------------

    /// A single block ramping from -1.0 to 1.0 survives a Q8_0 round trip.
    #[test]
    fn q8_0_round_trip_small_values() {
        let data: Vec<f32> = (0..32).map(|i| (i as f32 / 15.5) - 1.0).collect();

        let packed = quantize_f32_to_q8_0(&data).expect("quantize failed");
        assert_eq!(packed.len(), Q8_0_BLOCK_BYTES);

        let restored =
            dequantize_to_f32(&packed, GgufTensorType::Q8_0, 32).expect("dequantize failed");
        assert_eq!(restored.len(), 32);

        let err = max_abs_error(&data, &restored);
        assert!(err < 0.02, "Q8_0 round-trip error too large: {err}");
    }

    /// Four consecutive blocks are encoded and decoded independently.
    #[test]
    fn q8_0_round_trip_multiple_blocks() {
        let data: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.01).collect();

        let packed = quantize_f32_to_q8_0(&data).expect("quantize failed");
        assert_eq!(packed.len(), 4 * Q8_0_BLOCK_BYTES);

        let restored =
            dequantize_to_f32(&packed, GgufTensorType::Q8_0, 128).expect("dequantize failed");
        assert_eq!(restored.len(), 128);

        let err = max_abs_error(&data, &restored);
        assert!(err < 0.01, "Q8_0 multi-block error: {err}");
    }

    // ---- f16 entry points -------------------------------------------------

    /// Quantizing via the f16 path stays close to quantizing the f32 source.
    #[test]
    fn f16_to_q4_0_round_trip() {
        let f32_data: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) * 0.1).collect();
        let f16_data: Vec<u16> = f32_data
            .iter()
            .map(|&v| f16::from_f32(v).to_bits())
            .collect();

        let q_via_f16 = quantize_f16_to_q4_0(&f16_data).expect("f16 quantize failed");
        let q_via_f32 = quantize_f32_to_q4_0(&f32_data).expect("f32 quantize failed");

        let r16 =
            dequantize_to_f32(&q_via_f16, GgufTensorType::Q4_0, 32).expect("deq f16 failed");
        let r32 =
            dequantize_to_f32(&q_via_f32, GgufTensorType::Q4_0, 32).expect("deq f32 failed");

        let err = max_abs_error(&r16, &r32);
        assert!(err < 0.25, "F16 vs F32 path divergence: {err}");
    }

    /// f16 input quantized to Q8_0 decodes back close to the f32 source.
    #[test]
    fn f16_to_q8_0_round_trip() {
        let f32_data: Vec<f32> = (0..32).map(|i| (i as f32 - 16.0) * 0.05).collect();
        let f16_data: Vec<u16> = f32_data
            .iter()
            .map(|&v| f16::from_f32(v).to_bits())
            .collect();

        let packed = quantize_f16_to_q8_0(&f16_data).expect("f16→q8_0 failed");
        let restored = dequantize_to_f32(&packed, GgufTensorType::Q8_0, 32).expect("deq failed");

        let err = max_abs_error(&f32_data, &restored);
        assert!(err < 0.03, "F16→Q8_0 error: {err}");
    }

    // ---- Input validation -------------------------------------------------

    /// One value past a whole block is rejected.
    #[test]
    fn q4_0_rejects_unaligned_input() {
        assert_dimension_mismatch(quantize_f32_to_q4_0(&[0.0f32; 33]));
    }

    /// One value short of a whole block is rejected.
    #[test]
    fn q8_0_rejects_unaligned_input() {
        assert_dimension_mismatch(quantize_f32_to_q8_0(&[0.0f32; 31]));
    }

    // ---- Edge cases -------------------------------------------------------

    /// An all-zero block decodes to all zeros.
    #[test]
    fn q4_0_zero_input() {
        let packed = quantize_f32_to_q4_0(&[0.0f32; 32]).expect("zero quantize failed");
        let restored = dequantize_to_f32(&packed, GgufTensorType::Q4_0, 32).expect("deq failed");
        for v in restored {
            assert!(v.abs() < f32::EPSILON, "expected 0.0, got {v}");
        }
    }

    /// An all-zero block decodes to all zeros.
    #[test]
    fn q8_0_zero_input() {
        let packed = quantize_f32_to_q8_0(&[0.0f32; 32]).expect("zero quantize failed");
        let restored = dequantize_to_f32(&packed, GgufTensorType::Q8_0, 32).expect("deq failed");
        for v in restored {
            assert!(v.abs() < f32::EPSILON, "expected 0.0, got {v}");
        }
    }

    /// Signs are preserved and the peak value stays within 15% for Q4_0.
    #[test]
    fn q4_0_large_values() {
        let mut data = [0.0f32; 32];
        data[0] = 1000.0;
        data[1] = -1000.0;
        data[31] = 500.0;

        let packed = quantize_f32_to_q4_0(&data).expect("quantize failed");
        let restored = dequantize_to_f32(&packed, GgufTensorType::Q4_0, 32).expect("deq failed");

        assert!(restored[0] > 0.0, "expected positive, got {}", restored[0]);
        assert!(restored[1] < 0.0, "expected negative, got {}", restored[1]);
        assert!(
            restored[31] > 0.0,
            "expected positive, got {}",
            restored[31]
        );

        let rel_err_0 = (restored[0] - 1000.0).abs() / 1000.0;
        assert!(
            rel_err_0 < 0.15,
            "Q4_0 large-value relative error: {rel_err_0}"
        );
    }

    /// Signs are preserved and the peak value stays within 2% for Q8_0.
    #[test]
    fn q8_0_large_values() {
        let mut data = [0.0f32; 32];
        data[0] = 1000.0;
        data[1] = -1000.0;
        data[31] = 500.0;

        let packed = quantize_f32_to_q8_0(&data).expect("quantize failed");
        let restored = dequantize_to_f32(&packed, GgufTensorType::Q8_0, 32).expect("deq failed");

        assert!(restored[0] > 0.0);
        assert!(restored[1] < 0.0);
        assert!(restored[31] > 0.0);

        let rel_err_0 = (restored[0] - 1000.0).abs() / 1000.0;
        assert!(
            rel_err_0 < 0.02,
            "Q8_0 large-value relative error: {rel_err_0}"
        );
    }

    /// Empty inputs are accepted everywhere and produce empty outputs.
    #[test]
    fn empty_input_produces_empty_output() {
        let q4 = quantize_f32_to_q4_0(&[]).expect("empty q4_0 failed");
        assert!(q4.is_empty());

        let q8 = quantize_f32_to_q8_0(&[]).expect("empty q8_0 failed");
        assert!(q8.is_empty());

        let deq = dequantize_to_f32(&[], GgufTensorType::Q4_0, 0).expect("empty deq failed");
        assert!(deq.is_empty());
    }
}