Skip to main content

luci/vector/
quantize.rs

1//! Scalar quantization for vector search.
2//!
3//! Per-dimension linear quantization: maps float32 values to a reduced
4//! integer representation using per-dimension min/max computed across
5//! the dataset. Distances are computed asymmetrically: stored vectors
6//! at reduced precision, query vector at float32.
7//!
8//! Currently implements `Int8` only. The user-facing [`QuantizationType`]
9//! enum lives in `luci-mapping` and includes recognized-but-unimplemented
10//! variants (`Int4`, `Bbq`); those are rejected at mapping parse time so
11//! they cannot reach this layer. See [[code-must-not-lie]],
12//! [[optimization-knn-int8-quantization]], and [[quantization]].
13
14use super::DistanceMetric;
15
16/// Int8 quantized vectors with per-dimension calibration data.
17pub struct QuantizedVectors {
18    pub dims: usize,
19    pub num_vectors: usize,
20    /// Flat quantized data: `data[i * dims .. (i+1) * dims]` = vector i.
21    pub data: Vec<u8>,
22    /// Per-dimension minimum values.
23    pub mins: Vec<f32>,
24    /// Per-dimension scale: `(max - min) / 255.0`.
25    pub scales: Vec<f32>,
26    /// Per-vector precomputed norms (approximate, for cosine distance).
27    pub norms: Vec<f32>,
28    /// Distance metric.
29    pub metric: DistanceMetric,
30}
31
32impl QuantizedVectors {
33    /// Quantize a set of float32 vectors to int8.
34    pub fn quantize(vectors: &[Vec<f32>], metric: DistanceMetric) -> Self {
35        if vectors.is_empty() {
36            return Self {
37                dims: 0,
38                num_vectors: 0,
39                data: Vec::new(),
40                mins: Vec::new(),
41                scales: Vec::new(),
42                norms: Vec::new(),
43                metric,
44            };
45        }
46
47        let dims = vectors[0].len();
48        let num_vectors = vectors.len();
49
50        // Compute per-dimension min and max
51        let mut mins = vec![f32::MAX; dims];
52        let mut maxs = vec![f32::MIN; dims];
53        for v in vectors {
54            for d in 0..dims {
55                if v[d] < mins[d] {
56                    mins[d] = v[d];
57                }
58                if v[d] > maxs[d] {
59                    maxs[d] = v[d];
60                }
61            }
62        }
63
64        // Compute scales (avoid division by zero for constant dimensions)
65        let scales: Vec<f32> = (0..dims)
66            .map(|d| {
67                let range = maxs[d] - mins[d];
68                if range == 0.0 { 0.0 } else { range / 255.0 }
69            })
70            .collect();
71
72        // Quantize all vectors
73        let mut data = vec![0u8; num_vectors * dims];
74        let mut norms = vec![0.0f32; num_vectors];
75
76        for (i, v) in vectors.iter().enumerate() {
77            let offset = i * dims;
78            let mut norm_sq = 0.0f32;
79            for d in 0..dims {
80                let q = if scales[d] == 0.0 {
81                    128u8 // midpoint for constant dimensions
82                } else {
83                    ((v[d] - mins[d]) / scales[d]).round().clamp(0.0, 255.0) as u8
84                };
85                data[offset + d] = q;
86                // Approximate dequantized value for norm computation
87                let dequant = mins[d] + q as f32 * scales[d];
88                norm_sq += dequant * dequant;
89            }
90            norms[i] = norm_sq.sqrt();
91        }
92
93        Self {
94            dims,
95            num_vectors,
96            data,
97            mins,
98            scales,
99            norms,
100            metric,
101        }
102    }
103
104    /// Get the quantized vector for index `idx`.
105    #[inline]
106    pub fn get(&self, idx: usize) -> &[u8] {
107        let start = idx * self.dims;
108        &self.data[start..start + self.dims]
109    }
110
111    /// Compute asymmetric distance between quantized stored vector and
112    /// float32 query vector.
113    ///
114    /// For cosine, the query is expected to be unit-length (caller
115    /// normalizes at entry per [[optimize-cosine-norm-precompute]]).
116    /// `stored_norm` is the norm of the *dequantized* stored vector,
117    /// which approximates 1.0 but drifts due to int8 quantization
118    /// rounding — keeping it cancels that drift in the score.
119    #[inline]
120    pub fn asymmetric_distance(&self, idx: usize, query: &[f32]) -> f32 {
121        match self.metric {
122            DistanceMetric::Cosine => self.asymmetric_cosine(idx, query),
123            DistanceMetric::DotProduct => self.asymmetric_dot(idx, query),
124            DistanceMetric::L2 => self.asymmetric_l2(idx, query),
125        }
126    }
127
128    /// Asymmetric cosine distance: quantized stored × float32 query.
129    /// Caller guarantees the query is unit-length, so the cosine
130    /// denominator collapses to `stored_norm`.
131    fn asymmetric_cosine(&self, idx: usize, query: &[f32]) -> f32 {
132        let quantized = self.get(idx);
133        let mut dot = 0.0f32;
134        for d in 0..self.dims {
135            let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
136            dot += dequant * query[d];
137        }
138        let stored_norm = self.norms[idx];
139        if stored_norm == 0.0 {
140            1.0
141        } else {
142            1.0 - dot / stored_norm
143        }
144    }
145
146    /// Asymmetric negative dot product: quantized stored × float32 query.
147    fn asymmetric_dot(&self, idx: usize, query: &[f32]) -> f32 {
148        let quantized = self.get(idx);
149        let mut dot = 0.0f32;
150        for d in 0..self.dims {
151            let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
152            dot += dequant * query[d];
153        }
154        -dot
155    }
156
157    /// Asymmetric L2 distance: quantized stored × float32 query.
158    fn asymmetric_l2(&self, idx: usize, query: &[f32]) -> f32 {
159        let quantized = self.get(idx);
160        let mut sum_sq = 0.0f32;
161        for d in 0..self.dims {
162            let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
163            let diff = dequant - query[d];
164            sum_sq += diff * diff;
165        }
166        sum_sq.sqrt()
167    }
168
169    /// Serialize to bytes.
170    pub fn to_bytes(&self) -> Vec<u8> {
171        let mut buf = Vec::new();
172        buf.extend_from_slice(&(self.dims as u32).to_le_bytes());
173        buf.extend_from_slice(&(self.num_vectors as u32).to_le_bytes());
174        buf.push(self.metric as u8);
175        // Mins
176        for &m in &self.mins {
177            buf.extend_from_slice(&m.to_le_bytes());
178        }
179        // Scales
180        for &s in &self.scales {
181            buf.extend_from_slice(&s.to_le_bytes());
182        }
183        // Norms
184        for &n in &self.norms {
185            buf.extend_from_slice(&n.to_le_bytes());
186        }
187        // Quantized data
188        buf.extend_from_slice(&self.data);
189        buf
190    }
191
192    /// Deserialize from bytes.
193    pub fn from_bytes(data: &[u8]) -> Self {
194        let dims = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize;
195        let num_vectors = u32::from_le_bytes(data[4..8].try_into().unwrap()) as usize;
196        let metric = DistanceMetric::from_byte(data[8]);
197        let mut pos = 9;
198
199        let mut mins = vec![0.0f32; dims];
200        for d in 0..dims {
201            mins[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
202            pos += 4;
203        }
204
205        let mut scales = vec![0.0f32; dims];
206        for d in 0..dims {
207            scales[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
208            pos += 4;
209        }
210
211        let mut norms = vec![0.0f32; num_vectors];
212        for i in 0..num_vectors {
213            norms[i] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
214            pos += 4;
215        }
216
217        let qdata = data[pos..pos + num_vectors * dims].to_vec();
218
219        Self {
220            dims,
221            num_vectors,
222            data: qdata,
223            mins,
224            scales,
225            norms,
226            metric,
227        }
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` is within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() <= tol,
            "expected {} (tolerance {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn quantize_round_trip() {
        let vectors = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);

        assert_eq!(qv.dims, 3);
        assert_eq!(qv.num_vectors, 3);
        // The per-dimension minimum maps to code 0 ...
        assert_eq!(qv.get(0), &[0, 0, 0]);
        // ... and the per-dimension maximum to code 255.
        assert_eq!(qv.get(2), &[255, 255, 255]);
    }

    #[test]
    fn asymmetric_cosine_close_to_exact() {
        // Inputs are unit-length — production code normalizes the query
        // before reaching this kernel under the v0.7.2 invariant.
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);

        let query = vec![1.0, 0.0, 0.0];

        let d0 = qv.asymmetric_distance(0, &query);
        let d1 = qv.asymmetric_distance(1, &query);
        let d2 = qv.asymmetric_distance(2, &query);

        // Ordering: same direction < 45 degrees < orthogonal.
        assert!(d0 < d2, "d0={d0} should be < d2={d2}");
        assert!(d2 < d1, "d2={d2} should be < d1={d1}");

        // Values: the exact cosine distances are 0, 1 and 1 - 1/sqrt(2).
        assert_close(d0, 0.0, 1e-6);
        assert_close(d1, 1.0, 1e-6);
        assert_close(d2, 1.0 - std::f32::consts::FRAC_1_SQRT_2, 1e-4);
    }

    #[test]
    fn asymmetric_l2_close_to_exact() {
        let vectors = vec![vec![0.0, 0.0], vec![3.0, 4.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);

        let query = vec![0.0, 0.0];
        assert_eq!(qv.asymmetric_distance(0, &query), 0.0);
        assert_close(qv.asymmetric_distance(1, &query), 5.0, 1e-4);
    }

    #[test]
    fn asymmetric_dot_is_negated() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::DotProduct);

        let query = vec![1.0, 1.0];
        // Vector 0 sits at the minimum of both dimensions, so it
        // reconstructs exactly: -(1 + 2).
        assert_eq!(qv.asymmetric_distance(0, &query), -3.0);
        assert_close(qv.asymmetric_distance(1, &query), -7.0, 1e-4);
    }

    #[test]
    fn serialization_round_trip() {
        let vectors = vec![vec![1.5, -2.3, 0.7, 4.1], vec![-0.5, 3.2, 1.1, -1.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        let bytes = qv.to_bytes();
        let qv2 = QuantizedVectors::from_bytes(&bytes);

        assert_eq!(qv.dims, qv2.dims);
        assert_eq!(qv.num_vectors, qv2.num_vectors);
        assert_eq!(qv.data, qv2.data);
        assert_eq!(qv.mins, qv2.mins);
        assert_eq!(qv.scales, qv2.scales);
        // Norms and the metric are serialized too and must survive.
        assert_eq!(qv.norms, qv2.norms);
        assert_eq!(qv.metric as u8, qv2.metric as u8);
    }

    #[test]
    fn empty_vectors() {
        let qv = QuantizedVectors::quantize(&[], DistanceMetric::Cosine);
        assert_eq!(qv.num_vectors, 0);
        assert_eq!(qv.dims, 0);
        assert!(qv.data.is_empty());
        assert!(qv.norms.is_empty());
    }

    #[test]
    fn constant_dimension() {
        // All vectors have same value in one dimension
        let vectors = vec![vec![1.0, 5.0], vec![2.0, 5.0], vec![3.0, 5.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        // Second dimension is constant — should get midpoint (128)
        assert_eq!(qv.get(0)[1], 128);
        assert_eq!(qv.get(1)[1], 128);
        assert_eq!(qv.get(2)[1], 128);
        // Its scale is zero, so it still reconstructs the original value.
        assert_eq!(qv.scales[1], 0.0);
        assert_eq!(qv.asymmetric_distance(0, &[1.0, 5.0]), 0.0);
    }
}