1use anyhow::{anyhow, Result};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SqConfig {
19 pub bits: u8,
21 pub mode: QuantizationMode,
23 pub normalize: bool,
25 pub training_samples: usize,
27}
28
29impl Default for SqConfig {
30 fn default() -> Self {
31 Self {
32 bits: 8,
33 mode: QuantizationMode::Uniform,
34 normalize: false,
35 training_samples: 10_000,
36 }
37 }
38}
39
40#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
42pub enum QuantizationMode {
43 Uniform,
45 PerDimension,
47 MeanStd,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct QuantizationParams {
54 pub min: f32,
56 pub max: f32,
58 pub scale: f32,
60 pub offset: f32,
62}
63
64impl QuantizationParams {
65 pub fn from_range(min: f32, max: f32, bits: u8) -> Self {
67 let levels = (1 << bits) - 1;
68 let range = max - min;
69 let scale = if range > 1e-8 {
70 levels as f32 / range
71 } else {
72 1.0
73 };
74
75 Self {
76 min,
77 max,
78 scale,
79 offset: min,
80 }
81 }
82
83 pub fn from_mean_std(mean: f32, std: f32, bits: u8) -> Self {
85 let min = mean - 3.0 * std;
86 let max = mean + 3.0 * std;
87 Self::from_range(min, max, bits)
88 }
89
90 pub fn quantize(&self, value: f32) -> u8 {
92 let normalized = (value - self.offset) * self.scale;
93 normalized.clamp(0.0, 255.0) as u8
94 }
95
96 pub fn dequantize(&self, quantized: u8) -> f32 {
98 (quantized as f32 / self.scale) + self.offset
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct SqStats {
105 pub vector_count: usize,
107 pub dimensions: usize,
109 pub bits: u8,
111 pub compression_ratio: f32,
113 pub memory_bytes: usize,
115 pub avg_quantization_error: f32,
117}
118
119pub struct SqIndex {
121 config: SqConfig,
122 dimensions: usize,
123 quantization_params: Vec<QuantizationParams>,
124 quantized_vectors: Vec<Vec<u8>>,
125 uri_to_id: HashMap<String, usize>,
126 id_to_uri: Vec<String>,
127}
128
129impl SqIndex {
130 pub fn new(config: SqConfig, dimensions: usize) -> Self {
132 Self {
133 config,
134 dimensions,
135 quantization_params: Vec::new(),
136 quantized_vectors: Vec::new(),
137 uri_to_id: HashMap::new(),
138 id_to_uri: Vec::new(),
139 }
140 }
141
142 pub fn train(&mut self, training_vectors: &[Vec<f32>]) -> Result<()> {
144 if training_vectors.is_empty() {
145 return Err(anyhow!("No training vectors provided"));
146 }
147
148 let dim = training_vectors[0].len();
149 if dim != self.dimensions {
150 return Err(anyhow!(
151 "Training vector dimensions ({}) don't match index dimensions ({})",
152 dim,
153 self.dimensions
154 ));
155 }
156
157 let sample_count = training_vectors.len().min(self.config.training_samples);
159 let samples = &training_vectors[..sample_count];
160
161 match self.config.mode {
162 QuantizationMode::Uniform => {
163 self.train_uniform(samples)?;
164 }
165 QuantizationMode::PerDimension => {
166 self.train_per_dimension(samples)?;
167 }
168 QuantizationMode::MeanStd => {
169 self.train_mean_std(samples)?;
170 }
171 }
172
173 tracing::info!(
174 "Trained SQ index: mode={:?}, bits={}, samples={}, dimensions={}",
175 self.config.mode,
176 self.config.bits,
177 sample_count,
178 self.dimensions
179 );
180
181 Ok(())
182 }
183
184 fn train_uniform(&mut self, samples: &[Vec<f32>]) -> Result<()> {
186 let mut global_min = f32::INFINITY;
187 let mut global_max = f32::NEG_INFINITY;
188
189 for vector in samples {
190 for &value in vector {
191 global_min = global_min.min(value);
192 global_max = global_max.max(value);
193 }
194 }
195
196 let params = QuantizationParams::from_range(global_min, global_max, self.config.bits);
197 self.quantization_params = vec![params; self.dimensions];
198
199 Ok(())
200 }
201
202 fn train_per_dimension(&mut self, samples: &[Vec<f32>]) -> Result<()> {
204 let mut dim_mins = vec![f32::INFINITY; self.dimensions];
205 let mut dim_maxs = vec![f32::NEG_INFINITY; self.dimensions];
206
207 for vector in samples {
208 for (d, &value) in vector.iter().enumerate() {
209 dim_mins[d] = dim_mins[d].min(value);
210 dim_maxs[d] = dim_maxs[d].max(value);
211 }
212 }
213
214 self.quantization_params = dim_mins
215 .into_iter()
216 .zip(dim_maxs)
217 .map(|(min, max)| QuantizationParams::from_range(min, max, self.config.bits))
218 .collect();
219
220 Ok(())
221 }
222
223 fn train_mean_std(&mut self, samples: &[Vec<f32>]) -> Result<()> {
225 let n = samples.len() as f32;
226 let mut dim_means = vec![0.0; self.dimensions];
227 let mut dim_stds = vec![0.0; self.dimensions];
228
229 for vector in samples {
231 for (d, &value) in vector.iter().enumerate() {
232 dim_means[d] += value;
233 }
234 }
235 for mean in &mut dim_means {
236 *mean /= n;
237 }
238
239 for vector in samples {
241 for (d, &value) in vector.iter().enumerate() {
242 let diff = value - dim_means[d];
243 dim_stds[d] += diff * diff;
244 }
245 }
246 for std in &mut dim_stds {
247 *std = (*std / n).sqrt();
248 }
249
250 self.quantization_params = dim_means
251 .into_iter()
252 .zip(dim_stds)
253 .map(|(mean, std)| QuantizationParams::from_mean_std(mean, std, self.config.bits))
254 .collect();
255
256 Ok(())
257 }
258
259 pub fn add(&mut self, uri: String, vector: Vec<f32>) -> Result<()> {
261 if vector.len() != self.dimensions {
262 return Err(anyhow!(
263 "Vector dimensions ({}) don't match index dimensions ({})",
264 vector.len(),
265 self.dimensions
266 ));
267 }
268
269 if self.quantization_params.is_empty() {
270 return Err(anyhow!(
271 "Index not trained. Call train() before adding vectors."
272 ));
273 }
274
275 let quantized = self.quantize_vector(&vector);
276 let id = self.quantized_vectors.len();
277
278 self.uri_to_id.insert(uri.clone(), id);
279 self.id_to_uri.push(uri);
280 self.quantized_vectors.push(quantized);
281
282 Ok(())
283 }
284
285 fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
287 vector
288 .iter()
289 .zip(&self.quantization_params)
290 .map(|(&value, params)| params.quantize(value))
291 .collect()
292 }
293
294 fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
296 quantized
297 .iter()
298 .zip(&self.quantization_params)
299 .map(|(&q, params)| params.dequantize(q))
300 .collect()
301 }
302
303 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
305 if query.len() != self.dimensions {
306 return Err(anyhow!(
307 "Query dimensions ({}) don't match index dimensions ({})",
308 query.len(),
309 self.dimensions
310 ));
311 }
312
313 if self.quantized_vectors.is_empty() {
314 return Ok(Vec::new());
315 }
316
317 let query_quantized = self.quantize_vector(query);
319
320 let mut distances: Vec<(usize, f32)> = self
322 .quantized_vectors
323 .iter()
324 .enumerate()
325 .map(|(id, vec)| {
326 let dist = self.asymmetric_distance(&query_quantized, vec);
327 (id, dist)
328 })
329 .collect();
330
331 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
333
334 Ok(distances
336 .into_iter()
337 .take(k)
338 .map(|(id, dist)| (self.id_to_uri[id].clone(), dist))
339 .collect())
340 }
341
342 fn asymmetric_distance(&self, query_quantized: &[u8], db_quantized: &[u8]) -> f32 {
345 query_quantized
346 .iter()
347 .zip(db_quantized)
348 .zip(&self.quantization_params)
349 .map(|((&q1, &q2), params)| {
350 let v1 = params.dequantize(q1);
351 let v2 = params.dequantize(q2);
352 let diff = v1 - v2;
353 diff * diff
354 })
355 .sum::<f32>()
356 .sqrt()
357 }
358
359 pub fn stats(&self) -> SqStats {
361 let vector_count = self.quantized_vectors.len();
362 let bits_per_vector = self.dimensions * self.config.bits as usize;
363 let bytes_per_vector = (bits_per_vector + 7) / 8;
364 let memory_bytes = vector_count * bytes_per_vector;
365
366 let original_bytes = vector_count * self.dimensions * 4; let compression_ratio = if memory_bytes > 0 {
368 original_bytes as f32 / memory_bytes as f32
369 } else {
370 0.0
371 };
372
373 SqStats {
374 vector_count,
375 dimensions: self.dimensions,
376 bits: self.config.bits,
377 compression_ratio,
378 memory_bytes,
379 avg_quantization_error: self.estimate_quantization_error(),
380 }
381 }
382
383 fn estimate_quantization_error(&self) -> f32 {
385 if self.quantized_vectors.is_empty() {
386 return 0.0;
387 }
388
389 let sample_size = self.quantized_vectors.len().min(100);
390 let mut total_error = 0.0;
391
392 for quantized in self.quantized_vectors.iter().take(sample_size) {
393 let dequantized = self.dequantize_vector(quantized);
394 let reconstructed_quantized = self.quantize_vector(&dequantized);
395
396 let error: f32 = quantized
398 .iter()
399 .zip(&reconstructed_quantized)
400 .map(|(&a, &b)| (a as f32 - b as f32).abs())
401 .sum();
402
403 total_error += error / self.dimensions as f32;
404 }
405
406 total_error / sample_size as f32
407 }
408
409 pub fn get(&self, uri: &str) -> Option<Vec<f32>> {
411 self.uri_to_id
412 .get(uri)
413 .and_then(|&id| self.quantized_vectors.get(id))
414 .map(|q| self.dequantize_vector(q))
415 }
416
417 pub fn len(&self) -> usize {
419 self.quantized_vectors.len()
420 }
421
422 pub fn is_empty(&self) -> bool {
424 self.quantized_vectors.is_empty()
425 }
426
427 pub fn config(&self) -> &SqConfig {
429 &self.config
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Three 4-dimensional vectors forming an increasing staircase.
    fn staircase() -> Vec<Vec<f32>> {
        vec![
            vec![0.0, 1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2.0, 3.0, 4.0, 5.0],
        ]
    }

    /// Default configuration with the quantization mode overridden.
    fn config_with_mode(mode: QuantizationMode) -> SqConfig {
        SqConfig {
            mode,
            ..Default::default()
        }
    }

    #[test]
    fn test_quantization_params() {
        let p = QuantizationParams::from_range(0.0, 1.0, 8);
        for (input, code) in [(0.0, 0u8), (1.0, 255), (0.5, 127)] {
            assert_eq!(p.quantize(input), code);
        }

        let back = p.dequantize(127);
        assert!((back - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_sq_index_creation() {
        let idx = SqIndex::new(SqConfig::default(), 128);
        assert_eq!(idx.dimensions, 128);
        assert!(idx.is_empty());
    }

    #[test]
    fn test_sq_training() {
        let cfg = SqConfig {
            bits: 8,
            ..config_with_mode(QuantizationMode::PerDimension)
        };
        let mut idx = SqIndex::new(cfg, 4);

        assert!(idx.train(&staircase()).is_ok());
        assert_eq!(idx.quantization_params.len(), 4);
    }

    #[test]
    fn test_sq_add_and_search() {
        let mut idx = SqIndex::new(SqConfig::default(), 4);

        let training: Vec<Vec<f32>> = (0..3).map(|i| vec![i as f32; 4]).collect();
        idx.train(&training).unwrap();

        for (name, value) in [("vec1", 0.1_f32), ("vec2", 0.9), ("vec3", 1.8)] {
            idx.add(name.to_string(), vec![value; 4]).unwrap();
        }

        let hits = idx.search(&[0.0_f32; 4], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, "vec1");
    }

    #[test]
    fn test_sq_stats() {
        let cfg = SqConfig {
            bits: 4,
            ..Default::default()
        };
        let mut idx = SqIndex::new(cfg, 128);

        let training = vec![vec![0.5_f32; 128]; 100];
        idx.train(&training).unwrap();

        for n in 0..10 {
            idx.add(format!("vec{}", n), vec![0.5; 128]).unwrap();
        }

        let s = idx.stats();
        assert_eq!(s.vector_count, 10);
        assert_eq!(s.dimensions, 128);
        assert_eq!(s.bits, 4);
        assert!(s.compression_ratio > 1.0);
    }

    #[test]
    fn test_different_quantization_modes() {
        let data = staircase();
        let modes = [
            QuantizationMode::Uniform,
            QuantizationMode::PerDimension,
            QuantizationMode::MeanStd,
        ];

        for mode in modes {
            let mut idx = SqIndex::new(config_with_mode(mode), 4);
            assert!(idx.train(&data).is_ok());
        }
    }
}