citadel_vector/vendored/prism/
quantize.rs1use super::point::PointStore;
2
/// Point set compressed to one byte per component (8-bit scalar quantization).
///
/// When produced by [`SQ8Store::build`], component `d` of a point is stored as
/// `round((value - mins[d]) / scales[d])` clamped to `0..=255`.
pub struct SQ8Store {
    /// Quantized components, laid out as `dim` consecutive bytes per point.
    codes: Vec<u8>,
    /// Per-dimension offset subtracted from a value before scaling.
    mins: Vec<f32>,
    /// Per-dimension divisor applied after the offset has been subtracted.
    scales: Vec<f32>,
    /// Number of components per point.
    dim: usize,
}
11
12impl Drop for SQ8Store {
13 fn drop(&mut self) {
14 use zeroize::Zeroize;
17 self.codes.zeroize();
18 self.mins.zeroize();
19 self.scales.zeroize();
20 }
21}
22
23impl SQ8Store {
24 pub fn from_parts(codes: Vec<u8>, mins: Vec<f32>, scales: Vec<f32>, dim: usize) -> Self {
27 Self {
28 codes,
29 mins,
30 scales,
31 dim,
32 }
33 }
34
35 pub fn codes(&self) -> &[u8] {
36 &self.codes
37 }
38
39 pub fn mins(&self) -> &[f32] {
40 &self.mins
41 }
42
43 pub fn scales(&self) -> &[f32] {
44 &self.scales
45 }
46
47 pub fn dim(&self) -> usize {
48 self.dim
49 }
50
51 pub fn build(store: &PointStore) -> Self {
53 let n = store.len;
54 let dim = store.dim;
55
56 let all_integer_byte = (0..n).all(|i| {
57 store
58 .vector(i as u32)
59 .iter()
60 .all(|&v| (0.0..=255.0).contains(&v) && v == v.round())
61 });
62
63 if all_integer_byte {
64 let mut codes = vec![0u8; n * dim];
65 for i in 0..n {
66 let vec = store.vector(i as u32);
67 let off = i * dim;
68 for d in 0..dim {
69 codes[off + d] = vec[d] as u8;
70 }
71 }
72 return Self {
73 codes,
74 mins: vec![0.0; dim],
75 scales: vec![1.0; dim],
76 dim,
77 };
78 }
79
80 let sample_n = n.min(10_000);
82 let step = (n / sample_n).max(1);
83
84 let (mins, maxs) = if sample_n >= 200 {
85 let mut mins = vec![0.0f32; dim];
86 let mut maxs = vec![0.0f32; dim];
87 for d in 0..dim {
88 let mut sample: Vec<f32> = (0..sample_n)
89 .map(|s| store.vector(((s * step).min(n - 1)) as u32)[d])
90 .collect();
91 sample.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
92 let lo = sample_n / 200;
93 let hi = sample_n.saturating_sub(1 + sample_n / 200);
94 mins[d] = sample[lo];
95 maxs[d] = sample[hi.max(lo + 1).min(sample_n - 1)];
96 }
97 (mins, maxs)
98 } else {
99 let mut mins = vec![f32::MAX; dim];
100 let mut maxs = vec![f32::MIN; dim];
101 for i in 0..n {
102 let vec = store.vector(i as u32);
103 for d in 0..dim {
104 mins[d] = mins[d].min(vec[d]);
105 maxs[d] = maxs[d].max(vec[d]);
106 }
107 }
108 (mins, maxs)
109 };
110
111 let scales: Vec<f32> = mins
112 .iter()
113 .zip(maxs.iter())
114 .map(|(&mn, &mx)| {
115 let range = mx - mn;
116 if range > 0.0 {
117 range / 255.0
118 } else {
119 1.0
120 }
121 })
122 .collect();
123
124 let mut codes = vec![0u8; n * dim];
125 for i in 0..n {
126 let vec = store.vector(i as u32);
127 let off = i * dim;
128 for d in 0..dim {
129 let val = (vec[d] - mins[d]) / scales[d];
130 codes[off + d] = val.round().clamp(0.0, 255.0) as u8;
131 }
132 }
133
134 Self {
135 codes,
136 mins,
137 scales,
138 dim,
139 }
140 }
141
142 #[inline]
144 pub fn code(&self, id: u32) -> &[u8] {
145 let start = id as usize * self.dim;
146 &self.codes[start..start + self.dim]
147 }
148
149 pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
151 query
152 .iter()
153 .enumerate()
154 .map(|(d, &v)| {
155 let val = (v - self.mins[d]) / self.scales[d];
156 val.round().clamp(0.0, 255.0) as u8
157 })
158 .collect()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Components that are whole numbers in `0..=255` come back unchanged.
    #[test]
    fn test_sq8_roundtrip() {
        let points = PointStore::from_parts(
            vec![0.0, 0.0, 0.0, 255.0, 255.0, 255.0],
            3,
            vec![vec![0; 2]],
        );
        let quantized = SQ8Store::build(&points);

        let lowest = quantized.code(0);
        let highest = quantized.code(1);
        assert_eq!(lowest, &[0, 0, 0]);
        assert_eq!(highest, &[255, 255, 255]);
    }

    /// A value in the middle of the byte range keeps its exact code.
    #[test]
    fn test_sq8_midpoint() {
        let points = PointStore::from_parts(
            vec![0.0, 0.0, 255.0, 255.0, 128.0, 128.0],
            2,
            vec![vec![0; 3]],
        );
        let quantized = SQ8Store::build(&points);

        let first_component = quantized.code(2)[0];
        assert_eq!(first_component, 128);
    }

    /// Byte-valued input takes the identity path: zero offsets, unit scales.
    #[test]
    fn test_sq8_identity_quantization() {
        let points = PointStore::from_parts(
            vec![10.0, 200.0, 50.0, 150.0, 0.0, 255.0],
            2,
            vec![vec![0; 3]],
        );
        let quantized = SQ8Store::build(&points);

        assert_eq!(quantized.code(0), &[10, 200]);
        assert_eq!(quantized.code(1), &[50, 150]);
        assert_eq!(quantized.code(2), &[0, 255]);
        assert_eq!(quantized.mins(), &[0.0_f32, 0.0][..]);
        assert_eq!(quantized.scales(), &[1.0_f32, 1.0][..]);
    }

    /// Input outside the byte range is rescaled so that the per-dimension
    /// extremes land on codes 0 and 255.
    #[test]
    fn test_sq8_non_identity_for_float_data() {
        let points = PointStore::from_parts(vec![0.0, 0.0, 1000.0, 500.5], 2, vec![vec![0; 2]]);
        let quantized = SQ8Store::build(&points);

        assert_eq!(quantized.code(0), &[0, 0]);
        assert_eq!(quantized.code(1), &[255, 255]);
    }

    /// Distances computed on codes rank the nearest point first.
    #[test]
    fn test_sq8_distance_ranking() {
        use super::super::distance;

        let points = PointStore::from_parts(
            vec![0.0, 0.0, 100.0, 100.0, 200.0, 200.0],
            2,
            vec![vec![0; 3]],
        );
        let quantized = SQ8Store::build(&points);
        let q = quantized.quantize_query(&[90.0, 90.0]);

        let [d0, d1, d2] = [0u32, 1, 2].map(|id| distance::l2_sq8(&q, quantized.code(id)));
        assert!(d1 < d0, "point 1 should be closer than point 0");
        assert!(d1 < d2, "point 1 should be closer than point 2");
    }
}