// citadel_vector/vendored/prism/quantize.rs
use super::point::PointStore;
2
/// Scalar-quantized (one byte per component) copy of a point set.
///
/// Each point occupies `dim` consecutive bytes in `codes`; `mins` and
/// `scales` hold the per-dimension offset and step that were used to map the
/// original `f32` components onto `0..=255`.
pub struct SQ8Store {
    /// Quantized components, `dim` bytes per point, stored back to back.
    codes: Vec<u8>,
    /// Per-dimension offset subtracted from a component before scaling.
    mins: Vec<f32>,
    /// Per-dimension step size: the width of one code unit.
    scales: Vec<f32>,
    /// Number of components per point.
    dim: usize,
}
11
12impl Drop for SQ8Store {
13 fn drop(&mut self) {
14 use zeroize::Zeroize;
17 self.codes.zeroize();
18 self.mins.zeroize();
19 self.scales.zeroize();
20 }
21}
22
23impl SQ8Store {
24 pub fn from_parts(codes: Vec<u8>, mins: Vec<f32>, scales: Vec<f32>, dim: usize) -> Self {
27 Self {
28 codes,
29 mins,
30 scales,
31 dim,
32 }
33 }
34
35 pub fn codes(&self) -> &[u8] {
36 &self.codes
37 }
38
39 pub fn mins(&self) -> &[f32] {
40 &self.mins
41 }
42
43 pub fn scales(&self) -> &[f32] {
44 &self.scales
45 }
46
47 pub fn dim(&self) -> usize {
48 self.dim
49 }
50
51 pub fn build(store: &PointStore) -> Self {
53 let n = store.len;
54 let dim = store.dim;
55
56 let all_integer_byte = (0..n).all(|i| {
57 store
58 .vector(i as u32)
59 .iter()
60 .all(|&v| (0.0..=255.0).contains(&v) && v == v.round())
61 });
62
63 if all_integer_byte {
64 let mut codes = vec![0u8; n * dim];
65 for i in 0..n {
66 let vec = store.vector(i as u32);
67 let off = i * dim;
68 for d in 0..dim {
69 codes[off + d] = vec[d] as u8;
70 }
71 }
72 return Self {
73 codes,
74 mins: vec![0.0; dim],
75 scales: vec![1.0; dim],
76 dim,
77 };
78 }
79
80 let sample_n = n.min(10_000);
82
83 let (mins, maxs) = if sample_n >= 200 {
84 let mut mins = vec![0.0f32; dim];
85 let mut maxs = vec![0.0f32; dim];
86 for d in 0..dim {
87 let mut sample: Vec<f32> = (0..sample_n)
90 .map(|s| {
91 let idx = ((s as u64 * n as u64) / sample_n as u64) as usize;
92 store.vector(idx.min(n - 1) as u32)[d]
93 })
94 .collect();
95 sample.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
96 let lo = sample_n / 200;
97 let hi = sample_n.saturating_sub(1 + sample_n / 200);
98 mins[d] = sample[lo];
99 maxs[d] = sample[hi.max(lo + 1).min(sample_n - 1)];
100 }
101 (mins, maxs)
102 } else {
103 let mut mins = vec![f32::MAX; dim];
104 let mut maxs = vec![f32::MIN; dim];
105 for i in 0..n {
106 let vec = store.vector(i as u32);
107 for d in 0..dim {
108 mins[d] = mins[d].min(vec[d]);
109 maxs[d] = maxs[d].max(vec[d]);
110 }
111 }
112 (mins, maxs)
113 };
114
115 let scales: Vec<f32> = mins
116 .iter()
117 .zip(maxs.iter())
118 .map(|(&mn, &mx)| {
119 let range = mx - mn;
120 if range > 0.0 {
121 range / 255.0
122 } else {
123 1.0
124 }
125 })
126 .collect();
127
128 let mut codes = vec![0u8; n * dim];
129 for i in 0..n {
130 let vec = store.vector(i as u32);
131 let off = i * dim;
132 for d in 0..dim {
133 let val = (vec[d] - mins[d]) / scales[d];
134 codes[off + d] = val.round().clamp(0.0, 255.0) as u8;
135 }
136 }
137
138 Self {
139 codes,
140 mins,
141 scales,
142 dim,
143 }
144 }
145
146 #[inline]
148 pub fn code(&self, id: u32) -> &[u8] {
149 let start = id as usize * self.dim;
150 &self.codes[start..start + self.dim]
151 }
152
153 pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
155 query
156 .iter()
157 .enumerate()
158 .map(|(d, &v)| {
159 let val = (v - self.mins[d]) / self.scales[d];
160 val.round().clamp(0.0, 255.0) as u8
161 })
162 .collect()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// The extreme byte values 0 and 255 come back unchanged.
    #[test]
    fn test_sq8_roundtrip() {
        let store = PointStore::from_parts(
            vec![0.0, 0.0, 0.0, 255.0, 255.0, 255.0],
            3,
            vec![vec![0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(0), &[0, 0, 0]);
        assert_eq!(sq8.code(1), &[255, 255, 255]);
    }

    /// A value in the middle of the byte range keeps its code.
    #[test]
    fn test_sq8_midpoint() {
        let store = PointStore::from_parts(
            vec![0.0, 0.0, 255.0, 255.0, 128.0, 128.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(2)[0], 128);
    }

    /// Integer-valued input in `0..=255` is stored verbatim, with zero
    /// offsets and unit scales.
    #[test]
    fn test_sq8_identity_quantization() {
        let store = PointStore::from_parts(
            vec![10.0, 200.0, 50.0, 150.0, 0.0, 255.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(0), &[10, 200]);
        assert_eq!(sq8.code(1), &[50, 150]);
        assert_eq!(sq8.code(2), &[0, 255]);
        assert_eq!(sq8.mins, vec![0.0, 0.0]);
        assert_eq!(sq8.scales, vec![1.0, 1.0]);
    }

    /// Input outside the byte range (and non-integer) is rescaled so the
    /// per-dimension minimum maps to 0 and the maximum to 255.
    #[test]
    fn test_sq8_non_identity_for_float_data() {
        let store = PointStore::from_parts(vec![0.0, 0.0, 1000.0, 500.5], 2, vec![vec![0, 0]]);
        let sq8 = SQ8Store::build(&store);
        assert_eq!(sq8.code(0), &[0, 0]);
        assert_eq!(sq8.code(1), &[255, 255]);
    }

    /// With more points than the 10 000-vector sample cap, and a count that
    /// is not a multiple of it, the sample must still reach the last points.
    #[test]
    fn test_sq8_sampling_covers_tail_when_n_not_multiple_of_sample() {
        let n = 12_500;
        // The first 10 000 values stay within about [0.25, 1.25); the tail
        // sits far above at 100.5.
        let vectors: Vec<f32> = (0..n)
            .map(|i| {
                if i < 10_000 {
                    (i % 100) as f32 / 100.0 + 0.25
                } else {
                    100.5
                }
            })
            .collect();
        let store = PointStore::from_parts(vectors, 1, vec![vec![0; n]]);
        let sq8 = SQ8Store::build(&store);
        assert!(
            sq8.scales[0] > 0.3,
            "scale {} must reflect the tail range ~[0,100], not the prefix [0,1]",
            sq8.scales[0]
        );
        assert_eq!(sq8.code((n - 1) as u32)[0], 255);
    }

    /// Distances computed on the codes keep the nearest point nearest.
    #[test]
    fn test_sq8_distance_ranking() {
        use super::super::distance;
        let store = PointStore::from_parts(
            vec![0.0, 0.0, 100.0, 100.0, 200.0, 200.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let sq8 = SQ8Store::build(&store);
        let q = sq8.quantize_query(&[90.0, 90.0]);
        let d0 = distance::l2_sq8(&q, sq8.code(0));
        let d1 = distance::l2_sq8(&q, sq8.code(1));
        let d2 = distance::l2_sq8(&q, sq8.code(2));
        assert!(d1 < d0, "point 1 should be closer than point 0");
        assert!(d1 < d2, "point 1 should be closer than point 2");
    }
}