1use super::DistanceMetric;
15
16pub struct QuantizedVectors {
18 pub dims: usize,
19 pub num_vectors: usize,
20 pub data: Vec<u8>,
22 pub mins: Vec<f32>,
24 pub scales: Vec<f32>,
26 pub norms: Vec<f32>,
28 pub metric: DistanceMetric,
30}
31
32impl QuantizedVectors {
33 pub fn quantize(vectors: &[Vec<f32>], metric: DistanceMetric) -> Self {
35 if vectors.is_empty() {
36 return Self {
37 dims: 0,
38 num_vectors: 0,
39 data: Vec::new(),
40 mins: Vec::new(),
41 scales: Vec::new(),
42 norms: Vec::new(),
43 metric,
44 };
45 }
46
47 let dims = vectors[0].len();
48 let num_vectors = vectors.len();
49
50 let mut mins = vec![f32::MAX; dims];
52 let mut maxs = vec![f32::MIN; dims];
53 for v in vectors {
54 for d in 0..dims {
55 if v[d] < mins[d] {
56 mins[d] = v[d];
57 }
58 if v[d] > maxs[d] {
59 maxs[d] = v[d];
60 }
61 }
62 }
63
64 let scales: Vec<f32> = (0..dims)
66 .map(|d| {
67 let range = maxs[d] - mins[d];
68 if range == 0.0 { 0.0 } else { range / 255.0 }
69 })
70 .collect();
71
72 let mut data = vec![0u8; num_vectors * dims];
74 let mut norms = vec![0.0f32; num_vectors];
75
76 for (i, v) in vectors.iter().enumerate() {
77 let offset = i * dims;
78 let mut norm_sq = 0.0f32;
79 for d in 0..dims {
80 let q = if scales[d] == 0.0 {
81 128u8 } else {
83 ((v[d] - mins[d]) / scales[d]).round().clamp(0.0, 255.0) as u8
84 };
85 data[offset + d] = q;
86 let dequant = mins[d] + q as f32 * scales[d];
88 norm_sq += dequant * dequant;
89 }
90 norms[i] = norm_sq.sqrt();
91 }
92
93 Self {
94 dims,
95 num_vectors,
96 data,
97 mins,
98 scales,
99 norms,
100 metric,
101 }
102 }
103
104 #[inline]
106 pub fn get(&self, idx: usize) -> &[u8] {
107 let start = idx * self.dims;
108 &self.data[start..start + self.dims]
109 }
110
111 #[inline]
120 pub fn asymmetric_distance(&self, idx: usize, query: &[f32]) -> f32 {
121 match self.metric {
122 DistanceMetric::Cosine => self.asymmetric_cosine(idx, query),
123 DistanceMetric::DotProduct => self.asymmetric_dot(idx, query),
124 DistanceMetric::L2 => self.asymmetric_l2(idx, query),
125 }
126 }
127
128 fn asymmetric_cosine(&self, idx: usize, query: &[f32]) -> f32 {
132 let quantized = self.get(idx);
133 let mut dot = 0.0f32;
134 for d in 0..self.dims {
135 let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
136 dot += dequant * query[d];
137 }
138 let stored_norm = self.norms[idx];
139 if stored_norm == 0.0 {
140 1.0
141 } else {
142 1.0 - dot / stored_norm
143 }
144 }
145
146 fn asymmetric_dot(&self, idx: usize, query: &[f32]) -> f32 {
148 let quantized = self.get(idx);
149 let mut dot = 0.0f32;
150 for d in 0..self.dims {
151 let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
152 dot += dequant * query[d];
153 }
154 -dot
155 }
156
157 fn asymmetric_l2(&self, idx: usize, query: &[f32]) -> f32 {
159 let quantized = self.get(idx);
160 let mut sum_sq = 0.0f32;
161 for d in 0..self.dims {
162 let dequant = self.mins[d] + quantized[d] as f32 * self.scales[d];
163 let diff = dequant - query[d];
164 sum_sq += diff * diff;
165 }
166 sum_sq.sqrt()
167 }
168
169 pub fn to_bytes(&self) -> Vec<u8> {
171 let mut buf = Vec::new();
172 buf.extend_from_slice(&(self.dims as u32).to_le_bytes());
173 buf.extend_from_slice(&(self.num_vectors as u32).to_le_bytes());
174 buf.push(self.metric as u8);
175 for &m in &self.mins {
177 buf.extend_from_slice(&m.to_le_bytes());
178 }
179 for &s in &self.scales {
181 buf.extend_from_slice(&s.to_le_bytes());
182 }
183 for &n in &self.norms {
185 buf.extend_from_slice(&n.to_le_bytes());
186 }
187 buf.extend_from_slice(&self.data);
189 buf
190 }
191
192 pub fn from_bytes(data: &[u8]) -> Self {
194 let dims = u32::from_le_bytes(data[0..4].try_into().unwrap()) as usize;
195 let num_vectors = u32::from_le_bytes(data[4..8].try_into().unwrap()) as usize;
196 let metric = DistanceMetric::from_byte(data[8]);
197 let mut pos = 9;
198
199 let mut mins = vec![0.0f32; dims];
200 for d in 0..dims {
201 mins[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
202 pos += 4;
203 }
204
205 let mut scales = vec![0.0f32; dims];
206 for d in 0..dims {
207 scales[d] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
208 pos += 4;
209 }
210
211 let mut norms = vec![0.0f32; num_vectors];
212 for i in 0..num_vectors {
213 norms[i] = f32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
214 pos += 4;
215 }
216
217 let qdata = data[pos..pos + num_vectors * dims].to_vec();
218
219 Self {
220 dims,
221 num_vectors,
222 data: qdata,
223 mins,
224 scales,
225 norms,
226 metric,
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// The per-dimension minimum maps to code 0 and the maximum to code 255.
    #[test]
    fn quantize_round_trip() {
        let vectors = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);

        assert_eq!(qv.dims, 3);
        assert_eq!(qv.num_vectors, 3);
        assert_eq!(qv.get(0), &[0u8, 0, 0][..]);
        assert_eq!(qv.get(2), &[255u8, 255, 255][..]);
    }

    /// Quantization must preserve the ranking of candidates for a query.
    #[test]
    fn asymmetric_cosine_close_to_exact() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.707, 0.707, 0.0],
        ];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::Cosine);

        let query = vec![1.0, 0.0, 0.0];

        let d0 = qv.asymmetric_distance(0, &query);
        let d1 = qv.asymmetric_distance(1, &query);
        let d2 = qv.asymmetric_distance(2, &query);

        assert!(d0 < d2, "d0={} should be < d2={}", d0, d2);
        assert!(d2 < d1, "d2={} should be < d1={}", d2, d1);
    }

    /// L2 and dot-product distances are computed on the dequantized values.
    #[test]
    fn asymmetric_l2_and_dot_match_dequantized_values() {
        let vectors = vec![vec![0.0, 0.0], vec![3.0, 4.0]];

        let l2 = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        assert!((l2.asymmetric_distance(1, &[0.0, 0.0]) - 5.0).abs() < 1e-4);
        assert!((l2.asymmetric_distance(0, &[3.0, 4.0]) - 5.0).abs() < 1e-6);

        let dot = QuantizedVectors::quantize(&vectors, DistanceMetric::DotProduct);
        assert!((dot.asymmetric_distance(1, &[1.0, 2.0]) + 11.0).abs() < 1e-4);
    }

    /// Every field, including norms and the metric, survives serialization.
    #[test]
    fn serialization_round_trip() {
        let vectors = vec![vec![1.5, -2.3, 0.7, 4.1], vec![-0.5, 3.2, 1.1, -1.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        let bytes = qv.to_bytes();
        let qv2 = QuantizedVectors::from_bytes(&bytes);

        assert_eq!(qv.dims, qv2.dims);
        assert_eq!(qv.num_vectors, qv2.num_vectors);
        assert_eq!(qv.data, qv2.data);
        assert_eq!(qv.mins, qv2.mins);
        assert_eq!(qv.scales, qv2.scales);
        assert_eq!(qv.norms, qv2.norms);
        // Compared as bytes so the test does not require `PartialEq`/`Debug`
        // on `DistanceMetric`.
        assert_eq!(qv.metric as u8, qv2.metric as u8);
    }

    /// An empty input produces an empty, zero-dimensional set.
    #[test]
    fn empty_vectors() {
        let qv = QuantizedVectors::quantize(&[], DistanceMetric::Cosine);
        assert_eq!(qv.num_vectors, 0);
        assert_eq!(qv.dims, 0);
    }

    /// A dimension with zero range is encoded as the fixed code 128.
    #[test]
    fn constant_dimension() {
        let vectors = vec![vec![1.0, 5.0], vec![2.0, 5.0], vec![3.0, 5.0]];
        let qv = QuantizedVectors::quantize(&vectors, DistanceMetric::L2);
        assert_eq!(qv.get(0)[1], 128);
        assert_eq!(qv.get(1)[1], 128);
        assert_eq!(qv.get(2)[1], 128);
    }
}