// citadel_vector/vendored/prism/quantize.rs

1use super::point::PointStore;
2
/// Scalar-quantized (8-bit) copy of a vector store: one `u8` code per component,
/// a 4x bandwidth reduction compared with `f32`.
///
/// Component `d` of a point is approximately `mins[d] + code * scales[d]`. Data whose
/// components are already integers in `[0, 255]` (native u8 sets such as SIFT or YFCC)
/// is stored losslessly with `mins = 0` and `scales = 1` (identity quantization).
pub struct SQ8Store {
    /// Flat codes; point `i` occupies `codes[i * dim..(i + 1) * dim]`.
    codes: Vec<u8>,
    /// Per-dimension offset subtracted before scaling.
    mins: Vec<f32>,
    /// Per-dimension step between adjacent code values.
    scales: Vec<f32>,
    /// Number of components per vector.
    dim: usize,
}

12impl Drop for SQ8Store {
13    fn drop(&mut self) {
14        // Lossy 8-bit reconstructions of decrypted vectors; zero on drop alongside the
15        // full-precision PointStore so no vector residue outlives the region key.
16        use zeroize::Zeroize;
17        self.codes.zeroize();
18        self.mins.zeroize();
19        self.scales.zeroize();
20    }
21}
22
23impl SQ8Store {
24    /// Build SQ8 codes. Uses identity quantization for integer [0,255] data.
25    pub fn build(store: &PointStore) -> Self {
26        let n = store.len;
27        let dim = store.dim;
28
29        // Identity path for u8 data
30        let all_integer_byte = (0..n).all(|i| {
31            store
32                .vector(i as u32)
33                .iter()
34                .all(|&v| (0.0..=255.0).contains(&v) && v == v.round())
35        });
36
37        // Lossless: direct cast
38        if all_integer_byte {
39            let mut codes = vec![0u8; n * dim];
40            for i in 0..n {
41                let vec = store.vector(i as u32);
42                let off = i * dim;
43                for d in 0..dim {
44                    codes[off + d] = vec[d] as u8;
45                }
46            }
47            return Self {
48                codes,
49                mins: vec![0.0; dim],
50                scales: vec![1.0; dim],
51                dim,
52            };
53        }
54
55        // Percentile-clipped quantization (p0.5–p99.5)
56        let sample_n = n.min(10_000);
57        let step = (n / sample_n).max(1);
58
59        let (mins, maxs) = if sample_n >= 200 {
60            let mut mins = vec![0.0f32; dim];
61            let mut maxs = vec![0.0f32; dim];
62            for d in 0..dim {
63                let mut sample: Vec<f32> = (0..sample_n)
64                    .map(|s| store.vector(((s * step).min(n - 1)) as u32)[d])
65                    .collect();
66                sample.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
67                let lo = sample_n / 200;
68                let hi = sample_n.saturating_sub(1 + sample_n / 200);
69                mins[d] = sample[lo];
70                maxs[d] = sample[hi.max(lo + 1).min(sample_n - 1)];
71            }
72            (mins, maxs)
73        } else {
74            let mut mins = vec![f32::MAX; dim];
75            let mut maxs = vec![f32::MIN; dim];
76            for i in 0..n {
77                let vec = store.vector(i as u32);
78                for d in 0..dim {
79                    mins[d] = mins[d].min(vec[d]);
80                    maxs[d] = maxs[d].max(vec[d]);
81                }
82            }
83            (mins, maxs)
84        };
85
86        let scales: Vec<f32> = mins
87            .iter()
88            .zip(maxs.iter())
89            .map(|(&mn, &mx)| {
90                let range = mx - mn;
91                if range > 0.0 {
92                    range / 255.0
93                } else {
94                    1.0
95                }
96            })
97            .collect();
98
99        let mut codes = vec![0u8; n * dim];
100        for i in 0..n {
101            let vec = store.vector(i as u32);
102            let off = i * dim;
103            for d in 0..dim {
104                let val = (vec[d] - mins[d]) / scales[d];
105                codes[off + d] = val.round().clamp(0.0, 255.0) as u8;
106            }
107        }
108
109        Self {
110            codes,
111            mins,
112            scales,
113            dim,
114        }
115    }
116
117    /// Get the quantized code for point id.
118    #[inline]
119    pub fn code(&self, id: u32) -> &[u8] {
120        let start = id as usize * self.dim;
121        &self.codes[start..start + self.dim]
122    }
123
124    /// Quantize a f32 query vector to u8.
125    pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
126        query
127            .iter()
128            .enumerate()
129            .map(|(d, &v)| {
130                let val = (v - self.mins[d]) / self.scales[d];
131                val.round().clamp(0.0, 255.0) as u8
132            })
133            .collect()
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sq8_roundtrip() {
        // Integer data at both ends of [0, 255] takes the lossless identity path.
        let store = PointStore::from_parts(
            vec![0.0, 0.0, 0.0, 255.0, 255.0, 255.0],
            3,
            vec![vec![0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(0), &[0, 0, 0]);
        assert_eq!(sq8.code(1), &[255, 255, 255]);
    }

    #[test]
    fn test_sq8_midpoint() {
        // Points (0,0), (255,255), (128,128) are all integers in [0, 255], so the
        // identity path applies and the midpoint is stored unchanged on both dims.
        let store = PointStore::from_parts(
            vec![0.0, 0.0, 255.0, 255.0, 128.0, 128.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(2), &[128, 128]);
    }

    #[test]
    fn test_sq8_identity_quantization() {
        // Integer [0,255] data → identity quantization (lossless)
        let store = PointStore::from_parts(
            vec![10.0, 200.0, 50.0, 150.0, 0.0, 255.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        // Identity: codes are the exact integer values.
        assert_eq!(sq8.code(0), &[10, 200]);
        assert_eq!(sq8.code(1), &[50, 150]);
        assert_eq!(sq8.code(2), &[0, 255]);
        // mins=0, scales=1 for identity
        assert_eq!(sq8.mins, vec![0.0, 0.0]);
        assert_eq!(sq8.scales, vec![1.0, 1.0]);
    }

    #[test]
    fn test_sq8_non_identity_for_float_data() {
        // 500.5 is not an integer and 1000 > 255, so the identity path is rejected.
        // With fewer than 200 points each dim is normalized by its exact min/max.
        let store = PointStore::from_parts(vec![0.0, 0.0, 1000.0, 500.5], 2, vec![vec![0, 0]]);
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(0), &[0, 0]);
        assert_eq!(sq8.code(1), &[255, 255]);
    }

    #[test]
    fn test_sq8_percentile_clipping() {
        // 400 non-integer points 0.5, 1.5, ..., 399.5 on one dim. With >= 200 points the
        // range is the p0.5–p99.5 interval [sample[2], sample[397]] = [2.5, 397.5];
        // values outside it clamp to 0 / 255.
        let n = 400;
        let data: Vec<f32> = (0..n).map(|i| i as f32 + 0.5).collect();
        let store = PointStore::from_parts(data, 1, vec![vec![0; n]]);
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.mins, vec![2.5]);
        assert!((sq8.scales[0] - 395.0 / 255.0).abs() < 1e-6);
        assert_eq!(sq8.code(0), &[0]);
        assert_eq!(sq8.code(1), &[0]);
        assert_eq!(sq8.code(398), &[255]);
        assert_eq!(sq8.code(399), &[255]);
        // Queries are clamped with the same mapping.
        assert_eq!(sq8.quantize_query(&[-1000.0]), vec![0]);
        assert_eq!(sq8.quantize_query(&[1.0e6]), vec![255]);
    }

    #[test]
    fn test_sq8_distance_ranking() {
        use super::super::distance;
        // SQ8 distances must preserve the f32 nearest-neighbour ranking.
        let store = PointStore::from_parts(
            vec![0.0, 0.0, 100.0, 100.0, 200.0, 200.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        let query = sq8.quantize_query(&[90.0, 90.0]);
        let d0 = distance::l2_sq8(&query, sq8.code(0));
        let d1 = distance::l2_sq8(&query, sq8.code(1));
        let d2 = distance::l2_sq8(&query, sq8.code(2));
        // Point 1 (100,100) is closest to query (90,90)
        assert!(d1 < d0, "point 1 should be closer than point 0");
        assert!(d1 < d2, "point 1 should be closer than point 2");
    }
}