Skip to main content

citadel_vector/vendored/prism/
quantize.rs

1use super::point::PointStore;
2
/// Scalar-quantized (8-bit) vector store. 4x bandwidth reduction vs f32,
/// identity quantization for native u8 data (SIFT, YFCC).
///
/// Every component is stored as a single byte: component `d` of a vector with value `v`
/// is encoded as `round((v - mins[d]) / scales[d])`, clamped to `0..=255`.
pub struct SQ8Store {
    /// Quantized bytes of all points laid out back to back; point `id` occupies
    /// `codes[id * dim..(id + 1) * dim]`.
    codes: Vec<u8>,
    /// Per-dimension offset subtracted from a component before it is scaled.
    mins: Vec<f32>,
    /// Per-dimension step size a shifted component is divided by.
    scales: Vec<f32>,
    /// Number of components (and therefore code bytes) per point.
    dim: usize,
}
11
12impl Drop for SQ8Store {
13    fn drop(&mut self) {
14        // Lossy 8-bit reconstructions of decrypted vectors; zero on drop alongside the
15        // full-precision PointStore so no vector residue outlives the region key.
16        use zeroize::Zeroize;
17        self.codes.zeroize();
18        self.mins.zeroize();
19        self.scales.zeroize();
20    }
21}
22
23impl SQ8Store {
24    /// Reassemble from persisted parts (the ANN segment decode path). Additive
25    /// to the vendored algorithm: construction semantics are untouched.
26    pub fn from_parts(codes: Vec<u8>, mins: Vec<f32>, scales: Vec<f32>, dim: usize) -> Self {
27        Self {
28            codes,
29            mins,
30            scales,
31            dim,
32        }
33    }
34
35    pub fn codes(&self) -> &[u8] {
36        &self.codes
37    }
38
39    pub fn mins(&self) -> &[f32] {
40        &self.mins
41    }
42
43    pub fn scales(&self) -> &[f32] {
44        &self.scales
45    }
46
47    pub fn dim(&self) -> usize {
48        self.dim
49    }
50
51    /// Build SQ8 codes. Uses identity quantization for integer [0,255] data.
52    pub fn build(store: &PointStore) -> Self {
53        let n = store.len;
54        let dim = store.dim;
55
56        let all_integer_byte = (0..n).all(|i| {
57            store
58                .vector(i as u32)
59                .iter()
60                .all(|&v| (0.0..=255.0).contains(&v) && v == v.round())
61        });
62
63        if all_integer_byte {
64            let mut codes = vec![0u8; n * dim];
65            for i in 0..n {
66                let vec = store.vector(i as u32);
67                let off = i * dim;
68                for d in 0..dim {
69                    codes[off + d] = vec[d] as u8;
70                }
71            }
72            return Self {
73                codes,
74                mins: vec![0.0; dim],
75                scales: vec![1.0; dim],
76                dim,
77            };
78        }
79
80        // Percentile-clipped quantization (p0.5..p99.5).
81        let sample_n = n.min(10_000);
82
83        let (mins, maxs) = if sample_n >= 200 {
84            let mut mins = vec![0.0f32; dim];
85            let mut maxs = vec![0.0f32; dim];
86            for d in 0..dim {
87                // Spread ids across the full range (a floored stride covers
88                // only a prefix when n is not a multiple of sample_n).
89                let mut sample: Vec<f32> = (0..sample_n)
90                    .map(|s| {
91                        let idx = ((s as u64 * n as u64) / sample_n as u64) as usize;
92                        store.vector(idx.min(n - 1) as u32)[d]
93                    })
94                    .collect();
95                sample.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
96                let lo = sample_n / 200;
97                let hi = sample_n.saturating_sub(1 + sample_n / 200);
98                mins[d] = sample[lo];
99                maxs[d] = sample[hi.max(lo + 1).min(sample_n - 1)];
100            }
101            (mins, maxs)
102        } else {
103            let mut mins = vec![f32::MAX; dim];
104            let mut maxs = vec![f32::MIN; dim];
105            for i in 0..n {
106                let vec = store.vector(i as u32);
107                for d in 0..dim {
108                    mins[d] = mins[d].min(vec[d]);
109                    maxs[d] = maxs[d].max(vec[d]);
110                }
111            }
112            (mins, maxs)
113        };
114
115        let scales: Vec<f32> = mins
116            .iter()
117            .zip(maxs.iter())
118            .map(|(&mn, &mx)| {
119                let range = mx - mn;
120                if range > 0.0 {
121                    range / 255.0
122                } else {
123                    1.0
124                }
125            })
126            .collect();
127
128        let mut codes = vec![0u8; n * dim];
129        for i in 0..n {
130            let vec = store.vector(i as u32);
131            let off = i * dim;
132            for d in 0..dim {
133                let val = (vec[d] - mins[d]) / scales[d];
134                codes[off + d] = val.round().clamp(0.0, 255.0) as u8;
135            }
136        }
137
138        Self {
139            codes,
140            mins,
141            scales,
142            dim,
143        }
144    }
145
146    /// Get the quantized code for point id.
147    #[inline]
148    pub fn code(&self, id: u32) -> &[u8] {
149        let start = id as usize * self.dim;
150        &self.codes[start..start + self.dim]
151    }
152
153    /// Quantize a f32 query vector to u8.
154    pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
155        query
156            .iter()
157            .enumerate()
158            .map(|(d, &v)| {
159                let val = (v - self.mins[d]) / self.scales[d];
160                val.round().clamp(0.0, 255.0) as u8
161            })
162            .collect()
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PointStore` from flat component data with `dim` components per point.
    ///
    /// The third `from_parts` argument is a single inner vector holding one zero per point,
    /// which is what every fixture in this module passes.
    /// NOTE(review): the meaning of that argument is not visible from this file.
    fn points(data: Vec<f32>, dim: usize) -> PointStore {
        let count = data.len() / dim;
        PointStore::from_parts(data, dim, vec![vec![0; count]])
    }

    /// The extremes 0 and 255 come back as the extreme codes.
    #[test]
    fn test_sq8_roundtrip() {
        let quantized = SQ8Store::build(&points(vec![0.0, 0.0, 0.0, 255.0, 255.0, 255.0], 3));
        assert_eq!(quantized.code(0), &[0, 0, 0]);
        assert_eq!(quantized.code(1), &[255, 255, 255]);
    }

    /// A value of 128 between the extremes 0 and 255 is coded as 128.
    #[test]
    fn test_sq8_midpoint() {
        let quantized = SQ8Store::build(&points(vec![0.0, 0.0, 255.0, 255.0, 128.0, 128.0], 2));
        assert_eq!(quantized.code(2)[0], 128);
    }

    /// Whole-number data in [0, 255] is stored verbatim with offset 0 and scale 1.
    #[test]
    fn test_sq8_identity_quantization() {
        let quantized = SQ8Store::build(&points(vec![10.0, 200.0, 50.0, 150.0, 0.0, 255.0], 2));
        let expected: [&[u8]; 3] = [&[10, 200], &[50, 150], &[0, 255]];
        for (id, want) in expected.iter().enumerate() {
            assert_eq!(quantized.code(id as u32), *want);
        }
        assert_eq!(quantized.mins, vec![0.0, 0.0]);
        assert_eq!(quantized.scales, vec![1.0, 1.0]);
    }

    /// Data outside the byte range is rescaled so min maps to 0 and max to 255.
    #[test]
    fn test_sq8_non_identity_for_float_data() {
        let quantized = SQ8Store::build(&points(vec![0.0, 0.0, 1000.0, 500.5], 2));
        assert_eq!(quantized.code(0), &[0, 0]);
        assert_eq!(quantized.code(1), &[255, 255]);
    }

    #[test]
    fn test_sq8_sampling_covers_tail_when_n_not_multiple_of_sample() {
        // n = 12,500 (> 10k sample cap, not a multiple of it): the first 10k points sit in
        // [0.25, 1.24], the last 2.5k at 100.5. A prefix-only sample would estimate the
        // range from the first 10k points alone.
        let n: usize = 12_500;
        let components: Vec<f32> = (0..n)
            .map(|i| {
                if i < 10_000 {
                    (i % 100) as f32 / 100.0 + 0.25
                } else {
                    100.5
                }
            })
            .collect();
        let quantized = SQ8Store::build(&points(components, 1));
        assert!(
            quantized.scales[0] > 0.3,
            "scale {} must reflect the tail range ~[0,100], not the prefix [0,1]",
            quantized.scales[0]
        );
        assert_eq!(quantized.code((n - 1) as u32)[0], 255);
    }

    /// Quantized distances keep the nearest point nearest for a query close to point 1.
    #[test]
    fn test_sq8_distance_ranking() {
        use super::super::distance;
        let quantized =
            SQ8Store::build(&points(vec![0.0, 0.0, 100.0, 100.0, 200.0, 200.0], 2));
        let query_code = quantized.quantize_query(&[90.0, 90.0]);
        let to_first = distance::l2_sq8(&query_code, quantized.code(0));
        let to_second = distance::l2_sq8(&query_code, quantized.code(1));
        let to_third = distance::l2_sq8(&query_code, quantized.code(2));
        assert!(to_second < to_first, "point 1 should be closer than point 0");
        assert!(to_second < to_third, "point 1 should be closer than point 2");
    }
}