use serde::{Deserialize, Serialize};

#[cfg(target_arch = "x86_64")]
#[allow(clippy::wildcard_imports)]
use std::arch::x86_64::*;

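/// Parameters for 8-bit scalar quantization (SQ8).
///
/// Each f32 component is encoded as `q = round((v - offset) / scale)` clamped
/// to `0..=255` and reconstructed as `v ≈ q * scale + offset`. `scale` and
/// `offset` are learned from a training sample, and `dimensions` records the
/// expected vector length.
///
/// # Example
///
/// A minimal sketch of the intended flow (marked `ignore` because the import
/// path depends on where this module lives in the crate):
///
/// ```ignore
/// // Illustrative only: assumes `ScalarParams` is in scope.
/// let data: Vec<Vec<f32>> = vec![vec![0.0, 0.5, 1.0], vec![0.2, 0.4, 0.8]];
/// let refs: Vec<&[f32]> = data.iter().map(Vec::as_slice).collect();
/// let params = ScalarParams::train(&refs).expect("non-empty training sample");
///
/// let stored = params.quantize(&data[1]);
/// let query = params.prepare_query(&data[0]);
/// let dist = params.distance_l2_squared(&query, &stored);
/// assert!(dist >= 0.0);
/// ```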
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarParams {
    pub scale: f32,
    pub offset: f32,
    pub dimensions: usize,
}

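/// A vector stored as one byte per dimension, together with the precomputed
/// statistics (`sum` of the byte codes and squared norm of the dequantized
/// values) needed by the fast distance computation.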
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    pub data: Vec<u8>,
    pub sum: i32,
    pub norm_sq: f32,
}

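/// A query preprocessed once so it can be compared against many stored
/// vectors: its byte codes, the exact squared norm of the original f32
/// query, and the sum of the byte codes.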
#[derive(Debug, Clone)]
pub struct QueryPrep {
    pub quantized: Vec<u8>,
    pub norm_sq: f32,
    pub sum: i32,
}

impl ScalarParams {
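    /// Returns placeholder parameters spanning the `[0.0, 1.0]` range for the
    /// given dimensionality, intended to be replaced by trained parameters.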
    #[must_use]
    pub fn uninitialized(dimensions: usize) -> Self {
        Self {
            scale: 1.0 / 255.0,
            offset: 0.0,
            dimensions,
        }
    }

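    /// Trains quantization parameters from a sample of vectors, clipping the
    /// value range to the 1st and 99th percentiles to limit the influence of
    /// outliers.
    ///
    /// # Errors
    ///
    /// Returns an error if `vectors` is empty or if the vectors do not all
    /// have the same dimensionality.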
    pub fn train(vectors: &[&[f32]]) -> Result<Self, &'static str> {
        Self::train_with_percentiles(vectors, 0.01, 0.99)
    }

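    /// Trains quantization parameters, deriving `offset` and `scale` from the
    /// pooled component values at the given lower and upper percentiles. A
    /// degenerate (near-constant) value range falls back to a default scale
    /// of `1.0 / 255.0`.
    ///
    /// # Errors
    ///
    /// Returns an error if `vectors` is empty or if the vectors do not all
    /// have the same dimensionality.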
    pub fn train_with_percentiles(
        vectors: &[&[f32]],
        lower_percentile: f32,
        upper_percentile: f32,
    ) -> Result<Self, &'static str> {
        if vectors.is_empty() {
            return Err("Need at least one vector to train");
        }
        let dimensions = vectors[0].len();
        if !vectors.iter().all(|v| v.len() == dimensions) {
            return Err("All vectors must have same dimensions");
        }

        let mut all_values: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
        all_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let n = all_values.len();
        let lower_idx = ((n as f32 * lower_percentile) as usize).min(n - 1);
        let upper_idx = ((n as f32 * upper_percentile) as usize).min(n - 1);

        let min_val = all_values[lower_idx];
        let max_val = all_values[upper_idx];

        let range = max_val - min_val;
        let (offset, scale) = if range < 1e-7 {
            (min_val - 0.5, 1.0 / 255.0)
        } else {
            (min_val, range / 255.0)
        };

        Ok(Self {
            scale,
            offset,
            dimensions,
        })
    }

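    /// Quantizes a vector to one byte per dimension and precomputes the byte
    /// sum and the dequantized squared norm used by the distance functions.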
    #[must_use]
    pub fn quantize(&self, vector: &[f32]) -> QuantizedVector {
        debug_assert_eq!(vector.len(), self.dimensions);

        let inv_scale = 1.0 / self.scale;
        let data: Vec<u8> = vector
            .iter()
            .map(|&v| ((v - self.offset) * inv_scale).clamp(0.0, 255.0).round() as u8)
            .collect();

        let sum: i32 = data.iter().map(|&x| x as i32).sum();

        let norm_sq: f32 = data
            .iter()
            .map(|&x| {
                let dequant = x as f32 * self.scale + self.offset;
                dequant * dequant
            })
            .sum();

        QuantizedVector { data, sum, norm_sq }
    }

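    /// Quantizes a vector to raw bytes only, skipping the precomputed
    /// statistics produced by [`ScalarParams::quantize`].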
    #[must_use]
    pub fn quantize_to_bytes(&self, vector: &[f32]) -> Vec<u8> {
        debug_assert_eq!(vector.len(), self.dimensions);

        let inv_scale = 1.0 / self.scale;
        vector
            .iter()
            .map(|&v| ((v - self.offset) * inv_scale).clamp(0.0, 255.0).round() as u8)
            .collect()
    }

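    /// Reconstructs an approximate f32 vector from quantized bytes.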
    #[must_use]
    pub fn dequantize(&self, quantized: &[u8]) -> Vec<f32> {
        quantized
            .iter()
            .map(|&q| q as f32 * self.scale + self.offset)
            .collect()
    }

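    /// Squared L2 norm of the dequantized values, matching the `norm_sq`
    /// field stored in [`QuantizedVector`].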
    #[must_use]
    pub fn dequantized_norm_squared(&self, quantized: &[u8]) -> f32 {
        quantized
            .iter()
            .map(|&q| {
                let dequant = q as f32 * self.scale + self.offset;
                dequant * dequant
            })
            .sum()
    }

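    /// Sum of the byte codes, matching the `sum` field stored in
    /// [`QuantizedVector`].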
    #[must_use]
    pub fn quantized_sum(&self, quantized: &[u8]) -> i32 {
        quantized.iter().map(|&x| x as i32).sum()
    }

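    /// Prepares a query for repeated distance computations: quantizes it,
    /// records the exact squared norm of the f32 query, and caches the sum
    /// of its byte codes.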
    #[must_use]
    pub fn prepare_query(&self, query: &[f32]) -> QueryPrep {
        debug_assert_eq!(query.len(), self.dimensions);

        let inv_scale = 1.0 / self.scale;
        let quantized: Vec<u8> = query
            .iter()
            .map(|&v| ((v - self.offset) * inv_scale).clamp(0.0, 255.0).round() as u8)
            .collect();

        let norm_sq: f32 = query.iter().map(|x| x * x).sum();
        let sum: i32 = quantized.iter().map(|&x| x as i32).sum();

        QueryPrep {
            quantized,
            norm_sq,
            sum,
        }
    }

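    /// Approximate squared L2 distance between a prepared query and a
    /// quantized vector.
    ///
    /// Since `dequant(x) = x * scale + offset`, the dot product of two
    /// dequantized vectors expands to
    /// `scale^2 * sum(q_i * v_i) + scale * offset * (sum(q_i) + sum(v_i)) + offset^2 * d`,
    /// so only the integer dot product of the byte codes has to be computed
    /// per pair; the distance is then `norm_sq(query) + norm_sq(vec) - 2 * dot`.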
    #[inline(always)]
    #[must_use]
    pub fn distance_l2_squared(&self, query_prep: &QueryPrep, vec: &QuantizedVector) -> f32 {
        let int_dot = self.int_dot_product(&query_prep.quantized, &vec.data);

        let scale_sq = self.scale * self.scale;
        let dot = scale_sq * int_dot as f32
            + self.scale * self.offset * (query_prep.sum + vec.sum) as f32
            + self.offset * self.offset * self.dimensions as f32;

        query_prep.norm_sq + vec.norm_sq - 2.0 * dot
    }

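    /// Same computation as [`ScalarParams::distance_l2_squared`], but taking
    /// the stored vector's byte codes, byte sum, and squared norm as separate
    /// arguments for callers that keep these fields in separate storage.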
    #[inline(always)]
    #[must_use]
    pub fn distance_l2_squared_raw(
        &self,
        query_prep: &QueryPrep,
        vec_data: &[u8],
        vec_sum: i32,
        vec_norm_sq: f32,
    ) -> f32 {
        let int_dot = self.int_dot_product(&query_prep.quantized, vec_data);

        let scale_sq = self.scale * self.scale;
        let dot = scale_sq * int_dot as f32
            + self.scale * self.offset * (query_prep.sum + vec_sum) as f32
            + self.offset * self.offset * self.dimensions as f32;

        query_prep.norm_sq + vec_norm_sq - 2.0 * dot
    }

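    /// Integer dot product of two byte slices, dispatched to AVX2 on x86_64
    /// (when available, otherwise scalar) and to NEON on aarch64; other
    /// targets use the scalar fallback.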
    #[inline(always)]
    fn int_dot_product(&self, query: &[u8], vec: &[u8]) -> u32 {
        debug_assert_eq!(query.len(), vec.len());

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") {
                return unsafe { self.int_dot_product_avx2(query, vec) };
            }
            Self::int_dot_product_scalar(query, vec)
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.int_dot_product_neon(query, vec) }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self::int_dot_product_scalar(query, vec)
        }
    }

    #[inline]
    #[allow(dead_code)]
    fn int_dot_product_scalar(query: &[u8], vec: &[u8]) -> u32 {
        query
            .iter()
            .zip(vec.iter())
            .map(|(&q, &v)| q as u32 * v as u32)
            .sum()
    }

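    /// AVX2 path: widens the byte codes to 16-bit lanes, multiply-accumulates
    /// adjacent pairs into 32-bit lanes with `_mm256_madd_epi16`, reduces the
    /// accumulator horizontally, and finishes any remaining tail elements in
    /// scalar code.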
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    #[allow(clippy::unused_self)]
    unsafe fn int_dot_product_avx2(&self, query: &[u8], vec: &[u8]) -> u32 {
        let mut sum = _mm256_setzero_si256();
        let mut i = 0;

        while i + 32 <= query.len() {
            let q = _mm256_loadu_si256(query.as_ptr().add(i).cast());
            let v = _mm256_loadu_si256(vec.as_ptr().add(i).cast());

            let q_lo = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<0>(q));
            let q_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<1>(q));
            let v_lo = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<0>(v));
            let v_hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256::<1>(v));

            let prod_lo = _mm256_madd_epi16(q_lo, v_lo);
            let prod_hi = _mm256_madd_epi16(q_hi, v_hi);
            sum = _mm256_add_epi32(sum, prod_lo);
            sum = _mm256_add_epi32(sum, prod_hi);

            i += 32;
        }

        while i + 16 <= query.len() {
            let q = _mm256_cvtepu8_epi16(_mm_loadu_si128(query.as_ptr().add(i).cast()));
            let v = _mm256_cvtepu8_epi16(_mm_loadu_si128(vec.as_ptr().add(i).cast()));
            let prod = _mm256_madd_epi16(q, v);
            sum = _mm256_add_epi32(sum, prod);
            i += 16;
        }

        let sum128 = _mm_add_epi32(
            _mm256_extracti128_si256::<0>(sum),
            _mm256_extracti128_si256::<1>(sum),
        );
        let sum64 = _mm_add_epi32(sum128, _mm_srli_si128::<8>(sum128));
        let sum32 = _mm_add_epi32(sum64, _mm_srli_si128::<4>(sum64));
        let mut result = _mm_cvtsi128_si32(sum32) as u32;

        for j in i..query.len() {
            result += query[j] as u32 * vec[j] as u32;
        }

        result
    }

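    /// NEON path: processes 64 bytes per iteration across four independent
    /// accumulators, widening u8 products to u16 with `vmull_u8` /
    /// `vmull_high_u8` and pairwise-accumulating into u32 lanes with
    /// `vpadalq_u16`; the tail is handled 16 bytes at a time, then in scalar
    /// code.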
    #[cfg(target_arch = "aarch64")]
    #[inline(always)]
    #[allow(clippy::unused_self)]
    unsafe fn int_dot_product_neon(&self, query: &[u8], vec: &[u8]) -> u32 {
        use std::arch::aarch64::{
            vaddq_u32, vaddvq_u32, vdupq_n_u32, vget_low_u8, vld1q_u8, vmull_high_u8, vmull_u8,
            vpadalq_u16,
        };

        let mut sum0 = vdupq_n_u32(0);
        let mut sum1 = vdupq_n_u32(0);
        let mut sum2 = vdupq_n_u32(0);
        let mut sum3 = vdupq_n_u32(0);
        let mut i = 0;

        while i + 64 <= query.len() {
            let q0 = vld1q_u8(query.as_ptr().add(i));
            let v0 = vld1q_u8(vec.as_ptr().add(i));
            let prod0_lo = vmull_u8(vget_low_u8(q0), vget_low_u8(v0));
            let prod0_hi = vmull_high_u8(q0, v0);
            sum0 = vpadalq_u16(sum0, prod0_lo);
            sum0 = vpadalq_u16(sum0, prod0_hi);

            let q1 = vld1q_u8(query.as_ptr().add(i + 16));
            let v1 = vld1q_u8(vec.as_ptr().add(i + 16));
            let prod1_lo = vmull_u8(vget_low_u8(q1), vget_low_u8(v1));
            let prod1_hi = vmull_high_u8(q1, v1);
            sum1 = vpadalq_u16(sum1, prod1_lo);
            sum1 = vpadalq_u16(sum1, prod1_hi);

            let q2 = vld1q_u8(query.as_ptr().add(i + 32));
            let v2 = vld1q_u8(vec.as_ptr().add(i + 32));
            let prod2_lo = vmull_u8(vget_low_u8(q2), vget_low_u8(v2));
            let prod2_hi = vmull_high_u8(q2, v2);
            sum2 = vpadalq_u16(sum2, prod2_lo);
            sum2 = vpadalq_u16(sum2, prod2_hi);

            let q3 = vld1q_u8(query.as_ptr().add(i + 48));
            let v3 = vld1q_u8(vec.as_ptr().add(i + 48));
            let prod3_lo = vmull_u8(vget_low_u8(q3), vget_low_u8(v3));
            let prod3_hi = vmull_high_u8(q3, v3);
            sum3 = vpadalq_u16(sum3, prod3_lo);
            sum3 = vpadalq_u16(sum3, prod3_hi);

            i += 64;
        }

        while i + 16 <= query.len() {
            let q = vld1q_u8(query.as_ptr().add(i));
            let v = vld1q_u8(vec.as_ptr().add(i));
            let prod_lo = vmull_u8(vget_low_u8(q), vget_low_u8(v));
            let prod_hi = vmull_high_u8(q, v);
            sum0 = vpadalq_u16(sum0, prod_lo);
            sum0 = vpadalq_u16(sum0, prod_hi);
            i += 16;
        }

        let sum01 = vaddq_u32(sum0, sum1);
        let sum23 = vaddq_u32(sum2, sum3);
        let sum_all = vaddq_u32(sum01, sum23);
        let mut result = vaddvq_u32(sum_all);

        for j in i..query.len() {
            result += query[j] as u32 * vec[j] as u32;
        }

        result
    }
}

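/// Squared L2 distance between two already-quantized vectors, computed
/// directly on the byte codes (the result is in code units, not in the
/// original value scale).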
#[inline]
#[must_use]
pub fn symmetric_l2_squared_u8(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| {
            let diff = (i16::from(x) - i16::from(y)) as i32;
            (diff * diff) as u32
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_and_quantize() {
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.5, 1.0, 0.3],
            vec![0.1, 0.6, 0.9, 0.4],
            vec![0.2, 0.4, 0.8, 0.5],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();

        let params = ScalarParams::train(&refs).unwrap();

        let quantized = params.quantize(&vectors[0]);
        assert_eq!(quantized.data.len(), 4);
        assert!(quantized.sum > 0);
        assert!(quantized.norm_sq > 0.0);
    }

    #[test]
    fn test_distance_accuracy() {
        use rand::Rng;

        let dim = 128;
        let n_vectors = 100;
        let mut rng = rand::thread_rng();

        let vectors: Vec<Vec<f32>> = (0..n_vectors)
            .map(|_| {
                let v: Vec<f32> = (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect();
                let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
                v.iter().map(|x| x / norm).collect()
            })
            .collect();

        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
        let params = ScalarParams::train(&refs).unwrap();

        let quantized: Vec<_> = vectors.iter().map(|v| params.quantize(v)).collect();

        let query = &vectors[0];
        let query_prep = params.prepare_query(query);

        let mut max_rel_error = 0.0f32;

        for (i, (orig, quant)) in vectors.iter().zip(quantized.iter()).enumerate() {
            if i == 0 {
                continue;
            }

            let true_dist: f32 = query
                .iter()
                .zip(orig.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();

            let quant_dist = params.distance_l2_squared(&query_prep, quant);

            let rel_error = (true_dist - quant_dist).abs() / true_dist.max(1e-6);
            max_rel_error = max_rel_error.max(rel_error);
        }

        println!(
            "SQ8 max relative distance error: {:.2}%",
            max_rel_error * 100.0
        );
        assert!(
            max_rel_error < 0.15,
            "Distance error too large: {max_rel_error:.4}"
        );
    }

    #[test]
    fn test_int_dot_product() {
        let vectors: Vec<Vec<f32>> = vec![vec![0.5; 768], vec![0.3; 768]];
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();

        let params = ScalarParams::train(&refs).unwrap();
        let query_prep = params.prepare_query(&vectors[0]);
        let quantized = params.quantize(&vectors[1]);

        let dist = params.distance_l2_squared(&query_prep, &quantized);
        assert!(dist >= 0.0);
        assert!(!dist.is_nan());
    }

    #[test]
    fn test_dequantize_roundtrip() {
        let vectors: Vec<Vec<f32>> = vec![
            vec![0.0, 0.5, 1.0],
            vec![0.1, 0.6, 0.9],
            vec![0.2, 0.4, 0.8],
        ];
        let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();

        let params = ScalarParams::train(&refs).unwrap();
        let quantized = params.quantize(&vectors[0]);
        let dequantized = params.dequantize(&quantized.data);

        for (orig, deq) in vectors[0].iter().zip(dequantized.iter()) {
            assert!(
                (orig - deq).abs() < 0.05,
                "Roundtrip error too large: {orig} vs {deq}"
            );
        }
    }

    #[test]
    fn test_symmetric_distance() {
        let a: Vec<u8> = vec![0, 100, 200, 255];
        let b: Vec<u8> = vec![0, 100, 200, 255];
        let dist = symmetric_l2_squared_u8(&a, &b);
        assert_eq!(dist, 0);

        let c: Vec<u8> = vec![10, 110, 210, 245];
        let dist2 = symmetric_l2_squared_u8(&a, &c);
        assert!(dist2 > 0);
    }

    #[test]
    fn test_compression_ratio() {
        let dims = 768;
        let original_size = dims * 4;
        let quantized_size = dims;
        assert_eq!(original_size / quantized_size, 4);
    }
}