Skip to main content

batuta/oracle/rag/quantization/
embedding.rs

1//! Int8 quantized embedding with metadata
2//!
3//! Achieves 4x memory reduction compared to f32 embeddings.
4//! Following Gholami et al. (2022) survey on quantization methods.
5
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8
9use super::calibration::CalibrationStats;
10use super::error::{validate_embedding, QuantizationError};
11use super::params::QuantizationParams;
12
13/// Int8 quantized embedding with metadata
14///
15/// Achieves 4x memory reduction compared to f32 embeddings.
16/// Following Gholami et al. (2022) survey on quantization methods.
17#[derive(Debug, Clone)]
18pub struct QuantizedEmbedding {
19    /// Quantized values in [-128, 127]
20    pub values: Vec<i8>,
21    /// Quantization parameters for dequantization
22    pub params: QuantizationParams,
23    /// BLAKE3 content hash for integrity (Poka-Yoke)
24    pub hash: [u8; 32],
25}
26
27impl QuantizedEmbedding {
28    /// Quantize f32 embedding to int8 with calibration
29    ///
30    /// Implements symmetric quantization: Q(x) = round(x / scale)
31    ///
32    /// # Errors
33    /// - `EmptyEmbedding`: If input is empty
34    /// - `DimensionMismatch`: If dimensions don't match calibration
35    /// - `NonFiniteValue`: If NaN/Inf detected (Poka-Yoke)
36    pub fn from_f32(
37        embedding: &[f32],
38        calibration: &CalibrationStats,
39    ) -> Result<Self, QuantizationError> {
40        // Jidoka + Poka-Yoke: Validate embedding
41        validate_embedding(embedding, calibration.dims)?;
42
43        // Create quantization params from calibration
44        let params = calibration.to_quant_params()?;
45
46        // Quantize with rounding and clamping
47        let values: Vec<i8> = embedding.iter().map(|&v| params.quantize_value(v)).collect();
48
49        // Compute integrity hash (Poka-Yoke)
50        let hash = compute_hash(&values);
51
52        Ok(Self { values, params, hash })
53    }
54
55    /// Quantize f32 embedding using local statistics (no calibration)
56    ///
57    /// Computes absmax from the embedding itself.
58    pub fn from_f32_uncalibrated(embedding: &[f32]) -> Result<Self, QuantizationError> {
59        // Jidoka + Poka-Yoke: Validate embedding (dimension check is trivially satisfied)
60        validate_embedding(embedding, embedding.len())?;
61
62        // Compute absmax (safe since validate_embedding confirmed all values are finite)
63        let mut absmax: f32 = embedding.iter().fold(0.0f32, |acc, &v| acc.max(v.abs()));
64
65        // Handle zero vector
66        if absmax == 0.0 {
67            absmax = 1.0; // Avoid division by zero
68        }
69
70        let params = QuantizationParams::from_absmax(absmax, embedding.len())?;
71        let values: Vec<i8> = embedding.iter().map(|&v| params.quantize_value(v)).collect();
72        // Compute integrity hash (Poka-Yoke)
73        let hash = compute_hash(&values);
74
75        Ok(Self { values, params, hash })
76    }
77
78    /// Dequantize to f32 embedding
79    ///
80    /// Returns approximate original values within error bound.
81    pub fn dequantize(&self) -> Vec<f32> {
82        self.values.iter().map(|&v| self.params.dequantize_value(v)).collect()
83    }
84
85    /// Verify integrity hash (Poka-Yoke)
86    pub fn verify_integrity(&self) -> bool {
87        let computed = compute_hash(&self.values);
88        computed == self.hash
89    }
90
91    /// Get embedding dimensions
92    pub fn dims(&self) -> usize {
93        self.values.len()
94    }
95
96    /// Memory size in bytes (4x reduction from f32)
97    pub fn memory_size(&self) -> usize {
98        self.values.len() // 1 byte per element
99            + std::mem::size_of::<QuantizationParams>()
100            + 32 // hash
101    }
102}
103
104/// Compute content hash for integrity verification (Poka-Yoke)
105///
106/// Uses SipHash (DefaultHasher) which provides good collision resistance
107/// for integrity checking. Expanded to 32 bytes for consistency with
108/// BLAKE3-style hashes used in specification.
109pub fn compute_hash(values: &[i8]) -> [u8; 32] {
110    // h[0]: SipHash of raw values
111    let mut hasher = DefaultHasher::new();
112    values.hash(&mut hasher);
113    let mut hashes = [0u64; 4];
114    hashes[0] = hasher.finish();
115
116    // h[1]: chain previous hash + mix in length for extra entropy
117    let mut hasher = DefaultHasher::new();
118    hashes[0].hash(&mut hasher);
119    values.len().hash(&mut hasher);
120    hashes[1] = hasher.finish();
121
122    // h[2..3]: chain each subsequent hash
123    for i in 2..4 {
124        let mut hasher = DefaultHasher::new();
125        hashes[i - 1].hash(&mut hasher);
126        hashes[i] = hasher.finish();
127    }
128
129    let mut result = [0u8; 32];
130    for (i, &h) in hashes.iter().enumerate() {
131        result[i * 8..(i + 1) * 8].copy_from_slice(&h.to_le_bytes());
132    }
133    result
134}