citadel_vector/vendored/prism/
quantize.rs1use super::point::PointStore;
2
/// Scalar-quantized copy of a point set: every vector component is stored as
/// a single `u8` code, together with the per-dimension parameters that were
/// used to produce the codes.
pub struct SQ8Store {
    /// Codes of all points laid out back to back; the codes of point `id`
    /// occupy `codes[id * dim..(id + 1) * dim]`.
    codes: Vec<u8>,
    /// Per-dimension offset subtracted from a component before scaling.
    mins: Vec<f32>,
    /// Per-dimension divisor applied after subtracting `mins`; one code step
    /// corresponds to `scales[d]` in the original value space.
    scales: Vec<f32>,
    /// Number of components (and therefore codes) per point.
    dim: usize,
}
11
12impl Drop for SQ8Store {
13 fn drop(&mut self) {
14 use zeroize::Zeroize;
17 self.codes.zeroize();
18 self.mins.zeroize();
19 self.scales.zeroize();
20 }
21}
22
23impl SQ8Store {
24 pub fn build(store: &PointStore) -> Self {
26 let n = store.len;
27 let dim = store.dim;
28
29 let all_integer_byte = (0..n).all(|i| {
31 store
32 .vector(i as u32)
33 .iter()
34 .all(|&v| (0.0..=255.0).contains(&v) && v == v.round())
35 });
36
37 if all_integer_byte {
39 let mut codes = vec![0u8; n * dim];
40 for i in 0..n {
41 let vec = store.vector(i as u32);
42 let off = i * dim;
43 for d in 0..dim {
44 codes[off + d] = vec[d] as u8;
45 }
46 }
47 return Self {
48 codes,
49 mins: vec![0.0; dim],
50 scales: vec![1.0; dim],
51 dim,
52 };
53 }
54
55 let sample_n = n.min(10_000);
57 let step = (n / sample_n).max(1);
58
59 let (mins, maxs) = if sample_n >= 200 {
60 let mut mins = vec![0.0f32; dim];
61 let mut maxs = vec![0.0f32; dim];
62 for d in 0..dim {
63 let mut sample: Vec<f32> = (0..sample_n)
64 .map(|s| store.vector(((s * step).min(n - 1)) as u32)[d])
65 .collect();
66 sample.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
67 let lo = sample_n / 200;
68 let hi = sample_n.saturating_sub(1 + sample_n / 200);
69 mins[d] = sample[lo];
70 maxs[d] = sample[hi.max(lo + 1).min(sample_n - 1)];
71 }
72 (mins, maxs)
73 } else {
74 let mut mins = vec![f32::MAX; dim];
75 let mut maxs = vec![f32::MIN; dim];
76 for i in 0..n {
77 let vec = store.vector(i as u32);
78 for d in 0..dim {
79 mins[d] = mins[d].min(vec[d]);
80 maxs[d] = maxs[d].max(vec[d]);
81 }
82 }
83 (mins, maxs)
84 };
85
86 let scales: Vec<f32> = mins
87 .iter()
88 .zip(maxs.iter())
89 .map(|(&mn, &mx)| {
90 let range = mx - mn;
91 if range > 0.0 {
92 range / 255.0
93 } else {
94 1.0
95 }
96 })
97 .collect();
98
99 let mut codes = vec![0u8; n * dim];
100 for i in 0..n {
101 let vec = store.vector(i as u32);
102 let off = i * dim;
103 for d in 0..dim {
104 let val = (vec[d] - mins[d]) / scales[d];
105 codes[off + d] = val.round().clamp(0.0, 255.0) as u8;
106 }
107 }
108
109 Self {
110 codes,
111 mins,
112 scales,
113 dim,
114 }
115 }
116
117 #[inline]
119 pub fn code(&self, id: u32) -> &[u8] {
120 let start = id as usize * self.dim;
121 &self.codes[start..start + self.dim]
122 }
123
124 pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
126 query
127 .iter()
128 .enumerate()
129 .map(|(d, &v)| {
130 let val = (v - self.mins[d]) / self.scales[d];
131 val.round().clamp(0.0, 255.0) as u8
132 })
133 .collect()
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Both ends of the byte range come back as codes 0 and 255.
    #[test]
    fn test_sq8_roundtrip() {
        let points = PointStore::from_parts(
            vec![0.0, 0.0, 0.0, 255.0, 255.0, 255.0],
            3,
            vec![vec![0, 0]],
        );
        let quantized = SQ8Store::build(&points);
        assert_eq!(quantized.code(0), &[0, 0, 0]);
        assert_eq!(quantized.code(1), &[255, 255, 255]);
    }

    /// A value in the middle of the byte range keeps its position.
    #[test]
    fn test_sq8_midpoint() {
        let points = PointStore::from_parts(
            vec![0.0, 0.0, 255.0, 255.0, 128.0, 128.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let quantized = SQ8Store::build(&points);
        let third = quantized.code(2);
        assert_eq!(third[0], 128);
    }

    /// Integer-valued byte data is stored verbatim with unit parameters.
    #[test]
    fn test_sq8_identity_quantization() {
        let points = PointStore::from_parts(
            vec![10.0, 200.0, 50.0, 150.0, 0.0, 255.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let quantized = SQ8Store::build(&points);

        let expected_codes = [(0u32, [10u8, 200]), (1, [50, 150]), (2, [0, 255])];
        for (id, expected) in expected_codes {
            assert_eq!(quantized.code(id), &expected);
        }
        assert_eq!(quantized.mins, vec![0.0, 0.0]);
        assert_eq!(quantized.scales, vec![1.0, 1.0]);
    }

    /// Data outside the byte range is rescaled so its extremes hit 0 and 255.
    #[test]
    fn test_sq8_non_identity_for_float_data() {
        let points = PointStore::from_parts(vec![0.0, 0.0, 1000.0, 500.5], 2, vec![vec![0, 0]]);
        let quantized = SQ8Store::build(&points);
        assert_eq!(quantized.code(0), &[0, 0]);
        assert_eq!(quantized.code(1), &[255, 255]);
    }

    /// Distances computed on codes rank the nearest stored point first.
    #[test]
    fn test_sq8_distance_ranking() {
        use super::super::distance;

        let points = PointStore::from_parts(
            vec![0.0, 0.0, 100.0, 100.0, 200.0, 200.0],
            2,
            vec![vec![0, 0, 0]],
        );
        let quantized = SQ8Store::build(&points);
        let query_code = quantized.quantize_query(&[90.0, 90.0]);

        // Distances from the query to points 0, 1 and 2, in that order.
        let dists: Vec<_> = (0..3u32)
            .map(|id| distance::l2_sq8(&query_code, quantized.code(id)))
            .collect();

        assert!(dists[1] < dists[0], "point 1 should be closer than point 0");
        assert!(dists[1] < dists[2], "point 1 should be closer than point 2");
    }
}