use serde::{Deserialize, Serialize};
/// Selects how vectors are stored: raw `f32` components or one of two
/// block-based quantized encodings.
///
/// The variants only carry configuration; the size arithmetic that interprets
/// them lives in the inherent `impl` (`bits_per_dim`, `bytes_per_vector`).
///
/// NOTE(review): the variant fields are public to anyone who can name the
/// enum, and `Deserialize` is derived, so values can be created without going
/// through the validating constructors (`Quantization::e8` /
/// `Quantization::e8_checked`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Quantization {
/// No quantization; sized as 4 bytes (one `f32`) per dimension.
None,
/// E8 scheme: dimensions are grouped into blocks of 8 and each block is
/// encoded with `bits_per_block` bits.
/// NOTE(review): presumably refers to the E8 lattice; confirm with the encoder.
E8 {
/// Bits spent on each 8-dimension block. The constructors accept only
/// 8, 10, or 12.
bits_per_block: u8,
/// Presumably enables a Hadamard transform before quantization; this file
/// only stores the flag. TODO confirm against the encoder.
use_hadamard: bool,
/// Seed value; presumably drives the randomized part of the transform.
/// This file only stores it. TODO confirm against the encoder.
random_seed: u64,
},
/// H4 scheme: dimensions are grouped into blocks of 4 and each block is
/// sized as one byte.
/// NOTE(review): `bits_per_dim` uses log2(120) bits per block, which suggests
/// a 120-entry codebook; confirm with the encoder.
H4 {
/// Presumably enables a Hadamard transform before quantization; this file
/// only stores the flag. TODO confirm against the encoder.
use_hadamard: bool,
/// Seed value; presumably drives the randomized part of the transform.
/// This file only stores it. TODO confirm against the encoder.
random_seed: u64,
},
}
impl Default for Quantization {
fn default() -> Self {
Quantization::None
}
}
impl Quantization {
pub fn e8_default() -> Self {
Quantization::E8 {
bits_per_block: 10,
use_hadamard: true,
random_seed: 0xcafef00d,
}
}
pub fn e8(bits_per_block: u8, use_hadamard: bool, random_seed: u64) -> Self {
assert!(
bits_per_block == 8 || bits_per_block == 10 || bits_per_block == 12,
"bits_per_block must be 8, 10, or 12, got {}",
bits_per_block
);
Quantization::E8 {
bits_per_block,
use_hadamard,
random_seed,
}
}
pub fn e8_checked(bits_per_block: u8, use_hadamard: bool, random_seed: u64) -> Result<Self, &'static str> {
if bits_per_block != 8 && bits_per_block != 10 && bits_per_block != 12 {
return Err("bits_per_block must be 8, 10, or 12");
}
Ok(Quantization::E8 {
bits_per_block,
use_hadamard,
random_seed,
})
}
pub fn h4_default() -> Self {
Quantization::H4 {
use_hadamard: true,
random_seed: 0xdeadbeef,
}
}
pub fn h4(use_hadamard: bool, random_seed: u64) -> Self {
Quantization::H4 { use_hadamard, random_seed }
}
pub fn is_enabled(&self) -> bool {
!matches!(self, Quantization::None)
}
pub fn bits_per_dim(&self) -> f32 {
match self {
Quantization::None => 32.0,
Quantization::E8 { bits_per_block, .. } => *bits_per_block as f32 / 8.0,
Quantization::H4 { .. } => (120.0f32).log2() / 4.0,
}
}
pub fn bytes_per_vector(&self, dimension: usize) -> usize {
match self {
Quantization::None => dimension * 4,
Quantization::E8 { bits_per_block, .. } => {
let num_blocks = (dimension + 7) / 8;
let code_bytes = (num_blocks * (*bits_per_block as usize) + 7) / 8;
code_bytes + 4 }
Quantization::H4 { .. } => {
let num_blocks = (dimension + 3) / 4;
num_blocks + 4
}
}
}
pub fn compression_ratio(&self, dimension: usize) -> f32 {
let f32_bytes = dimension * 4;
let quant_bytes = self.bytes_per_vector(dimension);
f32_bytes as f32 / quant_bytes as f32
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector dimension used for the size checks below.
    const DIM: usize = 768;

    /// The default configuration is "no quantization" at full `f32` width.
    #[test]
    fn test_quantization_default() {
        let q = Quantization::default();
        assert!(!q.is_enabled());
        assert_eq!(q.bits_per_dim(), 32.0);
    }

    /// Unquantized storage is 4 bytes per dimension, i.e. ratio 1.
    #[test]
    fn test_none_stores_raw_f32() {
        let q = Quantization::None;
        assert_eq!(q.bytes_per_vector(DIM), DIM * 4);
        assert_eq!(q.compression_ratio(DIM), 1.0);
    }

    /// E8 at 10 bits/block: 96 blocks * 10 bits = 120 bytes, plus 4-byte header.
    #[test]
    fn test_e8_compression() {
        let q = Quantization::e8_default();
        assert!(q.is_enabled());
        let f32_bytes = DIM * 4;
        let e8_bytes = q.bytes_per_vector(DIM);
        assert_eq!(e8_bytes, 124);
        let ratio = f32_bytes as f32 / e8_bytes as f32;
        assert!(ratio > 2.0 && ratio < 30.0, "E8 compression ratio: {}", ratio);
        // The convenience method must agree with the hand-computed ratio.
        assert_eq!(q.compression_ratio(DIM), ratio);
    }

    /// H4: 192 blocks at one byte each, plus 4-byte header.
    #[test]
    fn test_h4_compression() {
        let q = Quantization::h4_default();
        assert!(q.is_enabled());
        let f32_bytes = DIM * 4;
        let h4_bytes = q.bytes_per_vector(DIM);
        assert_eq!(h4_bytes, 196);
        let ratio = f32_bytes as f32 / h4_bytes as f32;
        assert!(ratio > 10.0 && ratio < 25.0, "H4 compression ratio: {:.1}×", ratio);
        assert_eq!(q.compression_ratio(DIM), ratio);
    }

    /// A trailing partial block is charged as a whole block.
    #[test]
    fn test_partial_block_rounds_up() {
        let e8 = Quantization::e8_default();
        // 1 block * 10 bits -> 2 bytes, + 4 header.
        assert_eq!(e8.bytes_per_vector(1), 6);
        // 2 blocks * 10 bits = 20 bits -> 3 bytes, + 4 header.
        assert_eq!(e8.bytes_per_vector(9), 7);
        // 2 blocks -> 2 bytes, + 4 header.
        assert_eq!(Quantization::h4_default().bytes_per_vector(5), 6);
    }

    #[test]
    fn test_bits_per_dim() {
        let q8 = Quantization::e8(8, true, 0);
        let q10 = Quantization::e8(10, true, 0);
        let q12 = Quantization::e8(12, true, 0);
        let qh4 = Quantization::h4_default();
        assert_eq!(q8.bits_per_dim(), 1.0);
        assert_eq!(q10.bits_per_dim(), 1.25);
        assert_eq!(q12.bits_per_dim(), 1.5);
        assert!(
            (qh4.bits_per_dim() - 1.727).abs() < 0.01,
            "H4 bits/dim = {}",
            qh4.bits_per_dim()
        );
    }

    #[test]
    fn test_h4_is_enabled() {
        assert!(Quantization::h4_default().is_enabled());
        assert!(Quantization::h4(false, 0).is_enabled());
    }

    /// The fallible constructor accepts exactly 8, 10, and 12 bits per block.
    #[test]
    fn test_e8_checked_validates_bits() {
        for bits in [8u8, 10, 12] {
            assert_eq!(
                Quantization::e8_checked(bits, true, 7),
                Ok(Quantization::E8 {
                    bits_per_block: bits,
                    use_hadamard: true,
                    random_seed: 7,
                })
            );
        }
        for bits in [0u8, 7, 9, 11, 13, 255] {
            assert!(
                Quantization::e8_checked(bits, true, 7).is_err(),
                "bits_per_block = {} should be rejected",
                bits
            );
        }
    }

    /// The panicking constructor reports the offending value.
    #[test]
    #[should_panic(expected = "bits_per_block must be 8, 10, or 12, got 9")]
    fn test_e8_panics_on_invalid_bits() {
        let _ = Quantization::e8(9, true, 0);
    }
}