use serde::{Deserialize, Serialize};
/// Number of `f32` values that share a single quantization scale.
pub const BLOCK_SIZE: usize = 64;
/// Block-quantized 4-bit representation of an `f32` slice, as produced by
/// `quantize_4bit`.
///
/// Each block of `BLOCK_SIZE` values stores one `f32` scale, and every value
/// is packed as a signed 4-bit nibble, two per byte, with the earlier value
/// in the high nibble.
///
/// NOTE(review): fields are public and the type is `Deserialize`, so a
/// deserialized instance may be internally inconsistent (`len` vs `data`
/// length); `dequantize_4bit` will panic on such input — confirm callers
/// validate after deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Quantized4Bit {
// Per-block scale factor (max |v| / 7.0 over the block).
pub scales: Vec<f32>,
// Packed 4-bit values, two per byte.
pub data: Vec<u8>,
// Number of original f32 values; needed because `data` rounds up to whole bytes.
pub len: usize,
}
impl Quantized4Bit {
/// Total bytes consumed by this representation: 4 bytes per block scale
/// plus the packed nibble data.
pub fn memory_bytes(&self) -> usize {
self.scales.len() * 4 + self.data.len()
}
/// Ratio of the original size (`len * 4` bytes, one `f32` per value) to
/// the compressed size reported by `memory_bytes`.
///
/// NOTE(review): for an empty quantization both sides are 0 and this
/// returns NaN (0.0 / 0.0) — confirm no caller evaluates the ratio of an
/// empty tensor. Body left untouched because the `contract` attribute ties
/// it to the "compression_ratio" equation of "quantization-v1".
#[provable_contracts_macros::contract("quantization-v1", equation = "compression_ratio")]
pub fn compression_ratio(&self) -> f32 {
let original_bytes = self.len * 4; let compressed_bytes = self.memory_bytes();
original_bytes as f32 / compressed_bytes as f32
}
}
/// Quantizes `values` to signed 4-bit integers with one scale per
/// [`BLOCK_SIZE`]-value block.
///
/// Each block's scale is `max |v| / 7.0`, so quantized levels span `[-7, 7]`;
/// values are packed two per byte with the earlier value in the high nibble.
/// Round-trips through `dequantize_4bit`.
pub fn quantize_4bit(values: &[f32]) -> Quantized4Bit {
    let len = values.len();
    let num_blocks = len.div_ceil(BLOCK_SIZE);
    let mut scales = Vec::with_capacity(num_blocks);
    // Two nibbles per byte; odd lengths round up to a whole byte.
    let mut data = Vec::with_capacity(len.div_ceil(2));
    for block in values.chunks(BLOCK_SIZE) {
        let max_abs = block
            .iter()
            .map(|v| v.abs())
            .max_by(f32::total_cmp)
            .unwrap_or(0.0);
        // Fix: an all-zero block used to produce scale == 0.0, making
        // `quantize_value` compute 0.0 / 0.0 = NaN and relying on the
        // saturating `NaN as i8 == 0` cast for correctness. A tiny positive
        // scale yields the same zero nibbles without going through NaN.
        let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1e-8 };
        scales.push(scale);
        for (i, &val) in block.iter().enumerate() {
            let nibble = (quantize_value(val, scale) as u8) & 0x0F;
            // BLOCK_SIZE is even, so per-block parity equals global parity
            // and the packing stays continuous across blocks.
            if i.is_multiple_of(2) {
                // Even position: open a new byte, value in the high nibble.
                data.push(nibble << 4);
            } else {
                // Odd position: fill the low nibble of the byte just pushed.
                let last_idx = data.len() - 1;
                data[last_idx] |= nibble;
            }
        }
    }
    Quantized4Bit { scales, data, len }
}
/// Reconstructs the `f32` values encoded by `quantize_4bit`.
///
/// Each 4-bit two's-complement nibble is sign-extended to an `i8` and
/// multiplied by its block's scale; output order matches the input order.
pub fn dequantize_4bit(quantized: &Quantized4Bit) -> Vec<f32> {
    // Sign-extend the low 4 bits of `nibble` (two's complement) to i8.
    fn sign_extend(nibble: u8) -> i8 {
        if nibble & 0x08 == 0 {
            nibble as i8
        } else {
            (nibble | 0xF0) as i8
        }
    }

    let mut result = Vec::with_capacity(quantized.len);
    for (block_idx, &scale) in quantized.scales.iter().enumerate() {
        let start = block_idx * BLOCK_SIZE;
        let end = (start + BLOCK_SIZE).min(quantized.len);
        let block_len = end - start;
        for offset in 0..block_len {
            let global_idx = start + offset;
            // Two values per byte: even global indices sit in the high nibble.
            let byte = quantized.data[global_idx / 2];
            let nibble = if global_idx.is_multiple_of(2) {
                (byte >> 4) & 0x0F
            } else {
                byte & 0x0F
            };
            result.push(f32::from(sign_extend(nibble)) * scale);
        }
    }
    result
}
/// Maps `val` onto the signed 4-bit grid defined by `scale`.
///
/// The ratio is clamped to `[-7.0, 7.0]` before rounding, so the result is
/// always in `[-7, 7]`; a NaN ratio (e.g. from a zero scale) becomes 0 via
/// the saturating `as` cast.
fn quantize_value(val: f32, scale: f32) -> i8 {
    (val / scale).clamp(-7.0, 7.0).round() as i8
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Round-tripping mixed-sign values keeps the relative error within the
    // 4-bit quantization budget.
    #[test]
    fn test_quantize_dequantize_round_trip() {
        let values = vec![1.0, -2.0, 3.5, -4.2, 0.5, -0.8, 2.1, -1.5];
        let dequantized = dequantize_4bit(&quantize_4bit(&values));
        assert_eq!(dequantized.len(), values.len());
        for (original, deq) in values.iter().zip(&dequantized) {
            let error = (original - deq).abs();
            let relative_error = error / original.abs().max(1e-6);
            assert!(
                relative_error < 0.3,
                "Relative error too large: {original} vs {deq} (error: {error}, rel_error: {relative_error})"
            );
        }
    }

    // An all-zero block must come back as exact zeros.
    #[test]
    fn test_quantize_zeros() {
        let zeros = vec![0.0_f32; 64];
        for val in dequantize_4bit(&quantize_4bit(&zeros)) {
            assert_abs_diff_eq!(val, 0.0, epsilon = 1e-6);
        }
    }

    // A constant block quantizes to the max level and reconstructs closely.
    #[test]
    fn test_quantize_uniform() {
        let ones = vec![1.0_f32; 64];
        for val in dequantize_4bit(&quantize_4bit(&ones)) {
            assert_abs_diff_eq!(val, 1.0, epsilon = 0.2);
        }
    }

    // Integers in [-7, 7] sit near the quantization grid itself.
    #[test]
    fn test_quantize_range() {
        let values: Vec<f32> = (-7_i8..=7).map(f32::from).collect();
        let roundtrip = dequantize_4bit(&quantize_4bit(&values));
        for (original, deq) in values.iter().zip(&roundtrip) {
            assert_abs_diff_eq!(original, deq, epsilon = 0.5);
        }
    }

    // Inputs longer than one block produce one scale per block.
    #[test]
    fn test_quantize_multiple_blocks() {
        let values: Vec<f32> = (0..200).map(|i| (i as f32 * 0.1).sin()).collect();
        let quantized = quantize_4bit(&values);
        assert_eq!(dequantize_4bit(&quantized).len(), values.len());
        assert_eq!(quantized.scales.len(), 200_usize.div_ceil(BLOCK_SIZE));
    }

    // 4-bit packing should compress well past 6x vs raw f32.
    #[test]
    fn test_memory_savings() {
        let values = vec![1.0_f32; 1024];
        let quantized = quantize_4bit(&values);
        let original_bytes = values.len() * 4;
        let compressed_bytes = quantized.memory_bytes();
        let compression = original_bytes as f32 / compressed_bytes as f32;
        assert!(compression > 6.0, "Compression ratio {compression} should be > 6.0");
    }

    // The struct's own ratio method agrees with the expected compression.
    #[test]
    fn test_compression_ratio() {
        let ratio = quantize_4bit(&vec![1.5_f32; 1024]).compression_ratio();
        assert!(ratio > 6.0, "Compression ratio {ratio} should be > 6.0");
    }

    // Per-block scaling keeps small magnitudes accurate in absolute terms.
    #[test]
    fn test_quantize_small_values() {
        let values = [0.001_f32, 0.002, 0.003, 0.004];
        let dequantized = dequantize_4bit(&quantize_4bit(&values));
        for (original, deq) in values.iter().zip(&dequantized) {
            let error = (original - deq).abs();
            assert!(error < 0.001, "Error {error} too large for small value");
        }
    }

    // One shared scale per block: large values get relative accuracy,
    // small ones only a coarse absolute bound.
    #[test]
    fn test_quantize_mixed_magnitudes() {
        let values = vec![10.0, 1.0, -5.0, 0.5, 7.5, -2.0];
        let dequantized = dequantize_4bit(&quantize_4bit(&values));
        assert_eq!(dequantized.len(), values.len());
        for (original, deq) in values.iter().zip(&dequantized) {
            let error = (original - deq).abs();
            if original.abs() >= 1.0 {
                let relative_error = error / original.abs();
                assert!(
                    relative_error < 0.5,
                    "Relative error {relative_error} too large for {original} vs {deq} (error: {error})"
                );
            } else {
                assert!(
                    error < 1.5,
                    "Absolute error {error} too large for small value {original} vs {deq}"
                );
            }
        }
    }

    // Odd lengths round-trip without dropping the trailing value.
    #[test]
    fn test_quantize_odd_length() {
        let values: Vec<f32> = (0..77).map(|i| i as f32 * 0.5).collect();
        assert_eq!(dequantize_4bit(&quantize_4bit(&values)).len(), 77);
    }

    // 128 values => 64 packed bytes and 2 block scales.
    #[test]
    fn test_quantized_data_size() {
        let quantized = quantize_4bit(&[1.0_f32; 128]);
        assert_eq!(quantized.data.len(), 64);
        assert_eq!(quantized.scales.len(), 2);
    }
}