Skip to main content

citadel_vector/vendored/prism/
quantize.rs

1use super::point::PointStore;
2
/// Scalar-quantized (8-bit) vector store. 4x bandwidth reduction vs f32,
/// identity quantization for native u8 data (SIFT, YFCC).
pub struct SQ8Store {
    /// Row-major quantized codes: point `i` occupies
    /// `codes[i * dim..(i + 1) * dim]`, one byte per component.
    codes: Vec<u8>,
    /// Per-dimension offset subtracted from a component before scaling.
    mins: Vec<f32>,
    /// Per-dimension step size; a component `v` is encoded as
    /// `round((v - mins[d]) / scales[d])` clamped to `0..=255`.
    scales: Vec<f32>,
    /// Number of components per vector.
    dim: usize,
}
11
12impl Drop for SQ8Store {
13    fn drop(&mut self) {
14        // Lossy 8-bit reconstructions of decrypted vectors; zero on drop alongside the
15        // full-precision PointStore so no vector residue outlives the region key.
16        use zeroize::Zeroize;
17        self.codes.zeroize();
18        self.mins.zeroize();
19        self.scales.zeroize();
20    }
21}
22
23impl SQ8Store {
24    /// Reassemble from persisted parts (the ANN segment decode path). Additive
25    /// to the vendored algorithm: construction semantics are untouched.
26    pub fn from_parts(codes: Vec<u8>, mins: Vec<f32>, scales: Vec<f32>, dim: usize) -> Self {
27        Self {
28            codes,
29            mins,
30            scales,
31            dim,
32        }
33    }
34
35    pub fn codes(&self) -> &[u8] {
36        &self.codes
37    }
38
39    pub fn mins(&self) -> &[f32] {
40        &self.mins
41    }
42
43    pub fn scales(&self) -> &[f32] {
44        &self.scales
45    }
46
47    pub fn dim(&self) -> usize {
48        self.dim
49    }
50
51    /// Build SQ8 codes. Uses identity quantization for integer [0,255] data.
52    pub fn build(store: &PointStore) -> Self {
53        let n = store.len;
54        let dim = store.dim;
55
56        let all_integer_byte = (0..n).all(|i| {
57            store
58                .vector(i as u32)
59                .iter()
60                .all(|&v| (0.0..=255.0).contains(&v) && v == v.round())
61        });
62
63        if all_integer_byte {
64            let mut codes = vec![0u8; n * dim];
65            for i in 0..n {
66                let vec = store.vector(i as u32);
67                let off = i * dim;
68                for d in 0..dim {
69                    codes[off + d] = vec[d] as u8;
70                }
71            }
72            return Self {
73                codes,
74                mins: vec![0.0; dim],
75                scales: vec![1.0; dim],
76                dim,
77            };
78        }
79
80        // Percentile-clipped quantization (p0.5..p99.5).
81        let sample_n = n.min(10_000);
82        let step = (n / sample_n).max(1);
83
84        let (mins, maxs) = if sample_n >= 200 {
85            let mut mins = vec![0.0f32; dim];
86            let mut maxs = vec![0.0f32; dim];
87            for d in 0..dim {
88                let mut sample: Vec<f32> = (0..sample_n)
89                    .map(|s| store.vector(((s * step).min(n - 1)) as u32)[d])
90                    .collect();
91                sample.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
92                let lo = sample_n / 200;
93                let hi = sample_n.saturating_sub(1 + sample_n / 200);
94                mins[d] = sample[lo];
95                maxs[d] = sample[hi.max(lo + 1).min(sample_n - 1)];
96            }
97            (mins, maxs)
98        } else {
99            let mut mins = vec![f32::MAX; dim];
100            let mut maxs = vec![f32::MIN; dim];
101            for i in 0..n {
102                let vec = store.vector(i as u32);
103                for d in 0..dim {
104                    mins[d] = mins[d].min(vec[d]);
105                    maxs[d] = maxs[d].max(vec[d]);
106                }
107            }
108            (mins, maxs)
109        };
110
111        let scales: Vec<f32> = mins
112            .iter()
113            .zip(maxs.iter())
114            .map(|(&mn, &mx)| {
115                let range = mx - mn;
116                if range > 0.0 {
117                    range / 255.0
118                } else {
119                    1.0
120                }
121            })
122            .collect();
123
124        let mut codes = vec![0u8; n * dim];
125        for i in 0..n {
126            let vec = store.vector(i as u32);
127            let off = i * dim;
128            for d in 0..dim {
129                let val = (vec[d] - mins[d]) / scales[d];
130                codes[off + d] = val.round().clamp(0.0, 255.0) as u8;
131            }
132        }
133
134        Self {
135            codes,
136            mins,
137            scales,
138            dim,
139        }
140    }
141
142    /// Get the quantized code for point id.
143    #[inline]
144    pub fn code(&self, id: u32) -> &[u8] {
145        let start = id as usize * self.dim;
146        &self.codes[start..start + self.dim]
147    }
148
149    /// Quantize a f32 query vector to u8.
150    pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
151        query
152            .iter()
153            .enumerate()
154            .map(|(d, &v)| {
155                let val = (v - self.mins[d]) / self.scales[d];
156                val.round().clamp(0.0, 255.0) as u8
157            })
158            .collect()
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// The two ends of the byte range come back unchanged. The fixture is
    /// all integer bytes, so `build` takes the identity path here.
    #[test]
    fn test_sq8_roundtrip() {
        let points = vec![0.0, 0.0, 0.0, 255.0, 255.0, 255.0];
        let store = PointStore::from_parts(points, 3, vec![vec![0, 0]]);
        let quantized = SQ8Store::build(&store);

        assert_eq!(quantized.code(0), &[0, 0, 0]);
        assert_eq!(quantized.code(1), &[255, 255, 255]);
    }

    /// A mid-range value keeps its position between the extremes. The
    /// fixture is integer-valued, so this also runs through the identity path.
    #[test]
    fn test_sq8_midpoint() {
        let points = vec![0.0, 0.0, 255.0, 255.0, 128.0, 128.0];
        let store = PointStore::from_parts(points, 2, vec![vec![0, 0, 0]]);
        let quantized = SQ8Store::build(&store);

        let third = quantized.code(2);
        assert_eq!(third[0], 128);
    }

    /// Integer data in [0, 255] is stored verbatim with zero offset and
    /// unit scale.
    #[test]
    fn test_sq8_identity_quantization() {
        let points = vec![10.0, 200.0, 50.0, 150.0, 0.0, 255.0];
        let store = PointStore::from_parts(points, 2, vec![vec![0, 0, 0]]);
        let quantized = SQ8Store::build(&store);

        assert_eq!(quantized.code(0), &[10, 200]);
        assert_eq!(quantized.code(1), &[50, 150]);
        assert_eq!(quantized.code(2), &[0, 255]);
        assert_eq!(quantized.mins, vec![0.0, 0.0]);
        assert_eq!(quantized.scales, vec![1.0, 1.0]);
    }

    /// Out-of-range / fractional data is rescaled per dimension so that the
    /// observed minimum maps to 0 and the observed maximum to 255.
    #[test]
    fn test_sq8_non_identity_for_float_data() {
        let points = vec![0.0, 0.0, 1000.0, 500.5];
        let store = PointStore::from_parts(points, 2, vec![vec![0, 0]]);
        let quantized = SQ8Store::build(&store);

        assert_eq!(quantized.code(0), &[0, 0]);
        assert_eq!(quantized.code(1), &[255, 255]);
    }

    /// Quantized L2 distances preserve the nearest-neighbour ordering of a
    /// query against three well-separated points.
    #[test]
    fn test_sq8_distance_ranking() {
        use super::super::distance;

        let points = vec![0.0, 0.0, 100.0, 100.0, 200.0, 200.0];
        let store = PointStore::from_parts(points, 2, vec![vec![0, 0, 0]]);
        let quantized = SQ8Store::build(&store);

        let query = quantized.quantize_query(&[90.0, 90.0]);
        let to_first = distance::l2_sq8(&query, quantized.code(0));
        let to_second = distance::l2_sq8(&query, quantized.code(1));
        let to_third = distance::l2_sq8(&query, quantized.code(2));

        assert!(to_second < to_first, "point 1 should be closer than point 0");
        assert!(to_second < to_third, "point 1 should be closer than point 2");
    }
}