1use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15use common::DistanceMetric;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
19pub enum QuantizationType {
20 SQ4,
22 #[default]
24 SQ8,
25 SQ16,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct SQConfig {
32 pub quantization_type: QuantizationType,
34 pub dimensions: usize,
36 pub metric: DistanceMetric,
38 pub store_originals: bool,
40}
41
42impl Default for SQConfig {
43 fn default() -> Self {
44 Self {
45 quantization_type: QuantizationType::SQ8,
46 dimensions: 0,
47 metric: DistanceMetric::Cosine,
48 store_originals: false,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct SQStats {
56 pub num_vectors: usize,
58 pub original_memory_bytes: usize,
60 pub quantized_memory_bytes: usize,
62 pub compression_ratio: f32,
64 pub quantization_type: QuantizationType,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70struct DimensionParams {
71 min_val: f32,
72 max_val: f32,
73 scale: f32,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct SQIndex {
79 config: SQConfig,
80 dimension_params: Vec<DimensionParams>,
82 quantized_vectors: Vec<Vec<u8>>,
84 ids: Vec<String>,
86 original_vectors: Option<Vec<Vec<f32>>>,
88 id_to_index: HashMap<String, usize>,
90 trained: bool,
92}
93
/// One hit returned by an index search; higher scores are better for every metric.
#[derive(Clone, Debug)]
pub struct SQSearchResult {
    /// External id of the matched vector.
    pub id: String,
    /// Final similarity: the exact float similarity against the stored original when
    /// originals are kept, otherwise the same value as `quantized_score`.
    pub score: f32,
    /// Similarity estimated from the quantized representation.
    pub quantized_score: f32,
}
101
102impl SQIndex {
103 pub fn new(config: SQConfig) -> Self {
105 Self {
106 dimension_params: Vec::new(),
107 quantized_vectors: Vec::new(),
108 ids: Vec::new(),
109 original_vectors: if config.store_originals {
110 Some(Vec::new())
111 } else {
112 None
113 },
114 id_to_index: HashMap::new(),
115 trained: false,
116 config,
117 }
118 }
119
120 pub fn train(&mut self, vectors: &[Vec<f32>]) -> Result<(), String> {
122 if vectors.is_empty() {
123 return Err("Cannot train on empty vector set".to_string());
124 }
125
126 let dimensions = vectors[0].len();
127 if self.config.dimensions == 0 {
128 self.config.dimensions = dimensions;
129 } else if self.config.dimensions != dimensions {
130 return Err(format!(
131 "Dimension mismatch: expected {}, got {}",
132 self.config.dimensions, dimensions
133 ));
134 }
135
136 let mut dimension_params = Vec::with_capacity(dimensions);
138
139 for dim in 0..dimensions {
140 let mut min_val = f32::MAX;
141 let mut max_val = f32::MIN;
142
143 for vector in vectors {
144 let val = vector[dim];
145 min_val = min_val.min(val);
146 max_val = max_val.max(val);
147 }
148
149 let range = (max_val - min_val).max(1e-10);
151 let scale = self.get_max_quantized_value() / range;
152
153 dimension_params.push(DimensionParams {
154 min_val,
155 max_val,
156 scale,
157 });
158 }
159
160 self.dimension_params = dimension_params;
161 self.trained = true;
162 Ok(())
163 }
164
165 fn get_max_quantized_value(&self) -> f32 {
167 match self.config.quantization_type {
168 QuantizationType::SQ4 => 15.0,
169 QuantizationType::SQ8 => 255.0,
170 QuantizationType::SQ16 => 65535.0,
171 }
172 }
173
174 fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
176 match self.config.quantization_type {
177 QuantizationType::SQ8 => self.quantize_sq8(vector),
178 QuantizationType::SQ4 => self.quantize_sq4(vector),
179 QuantizationType::SQ16 => self.quantize_sq16(vector),
180 }
181 }
182
183 fn quantize_sq8(&self, vector: &[f32]) -> Vec<u8> {
185 vector
186 .iter()
187 .enumerate()
188 .map(|(i, &val)| {
189 let params = &self.dimension_params[i];
190 let normalized = (val - params.min_val) * params.scale;
191 normalized.clamp(0.0, 255.0) as u8
192 })
193 .collect()
194 }
195
196 fn quantize_sq4(&self, vector: &[f32]) -> Vec<u8> {
198 let mut result = Vec::with_capacity(vector.len().div_ceil(2));
199
200 for chunk in vector.chunks(2) {
201 let low = {
202 let params = &self.dimension_params[result.len() * 2];
203 let normalized = (chunk[0] - params.min_val) * params.scale;
204 (normalized.clamp(0.0, 15.0) as u8) & 0x0F
205 };
206
207 let high = if chunk.len() > 1 {
208 let params = &self.dimension_params[result.len() * 2 + 1];
209 let normalized = (chunk[1] - params.min_val) * params.scale;
210 ((normalized.clamp(0.0, 15.0) as u8) & 0x0F) << 4
211 } else {
212 0
213 };
214
215 result.push(low | high);
216 }
217
218 result
219 }
220
221 fn quantize_sq16(&self, vector: &[f32]) -> Vec<u8> {
223 let mut result = Vec::with_capacity(vector.len() * 2);
224
225 for (i, &val) in vector.iter().enumerate() {
226 let params = &self.dimension_params[i];
227 let normalized = (val - params.min_val) * params.scale;
228 let quantized = normalized.clamp(0.0, 65535.0) as u16;
229 result.extend_from_slice(&quantized.to_le_bytes());
230 }
231
232 result
233 }
234
235 pub fn dequantize_vector(&self, quantized: &[u8]) -> Vec<f32> {
237 match self.config.quantization_type {
238 QuantizationType::SQ8 => self.dequantize_sq8(quantized),
239 QuantizationType::SQ4 => self.dequantize_sq4(quantized),
240 QuantizationType::SQ16 => self.dequantize_sq16(quantized),
241 }
242 }
243
244 fn dequantize_sq8(&self, quantized: &[u8]) -> Vec<f32> {
245 quantized
246 .iter()
247 .enumerate()
248 .map(|(i, &val)| {
249 let params = &self.dimension_params[i];
250 params.min_val + (val as f32 / params.scale)
251 })
252 .collect()
253 }
254
255 fn dequantize_sq4(&self, quantized: &[u8]) -> Vec<f32> {
256 let mut result = Vec::with_capacity(self.config.dimensions);
257
258 for (byte_idx, &byte) in quantized.iter().enumerate() {
259 let dim_idx = byte_idx * 2;
260 if dim_idx < self.config.dimensions {
261 let low = byte & 0x0F;
262 let params = &self.dimension_params[dim_idx];
263 result.push(params.min_val + (low as f32 / params.scale));
264 }
265
266 if dim_idx + 1 < self.config.dimensions {
267 let high = (byte >> 4) & 0x0F;
268 let params = &self.dimension_params[dim_idx + 1];
269 result.push(params.min_val + (high as f32 / params.scale));
270 }
271 }
272
273 result
274 }
275
276 fn dequantize_sq16(&self, quantized: &[u8]) -> Vec<f32> {
277 quantized
278 .chunks(2)
279 .enumerate()
280 .map(|(i, bytes)| {
281 let val = u16::from_le_bytes([bytes[0], bytes[1]]);
282 let params = &self.dimension_params[i];
283 params.min_val + (val as f32 / params.scale)
284 })
285 .collect()
286 }
287
288 pub fn add(&mut self, ids: &[String], vectors: &[Vec<f32>]) -> Result<(), String> {
290 if !self.trained {
291 self.train(vectors)?;
293 }
294
295 for (id, vector) in ids.iter().zip(vectors.iter()) {
296 if vector.len() != self.config.dimensions {
297 return Err(format!(
298 "Dimension mismatch for {}: expected {}, got {}",
299 id,
300 self.config.dimensions,
301 vector.len()
302 ));
303 }
304
305 if let Some(&existing_idx) = self.id_to_index.get(id) {
307 self.quantized_vectors[existing_idx] = self.quantize_vector(vector);
309 if let Some(ref mut originals) = self.original_vectors {
310 originals[existing_idx] = vector.clone();
311 }
312 } else {
313 let idx = self.quantized_vectors.len();
315 self.quantized_vectors.push(self.quantize_vector(vector));
316 self.ids.push(id.clone());
317 self.id_to_index.insert(id.clone(), idx);
318
319 if let Some(ref mut originals) = self.original_vectors {
320 originals.push(vector.clone());
321 }
322 }
323 }
324
325 Ok(())
326 }
327
328 pub fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<SQSearchResult>, String> {
330 if !self.trained {
331 return Err("Index not trained".to_string());
332 }
333
334 if query.len() != self.config.dimensions {
335 return Err(format!(
336 "Query dimension mismatch: expected {}, got {}",
337 self.config.dimensions,
338 query.len()
339 ));
340 }
341
342 let quantized_query = self.quantize_vector(query);
344
345 let mut scores: Vec<(usize, f32)> = self
347 .quantized_vectors
348 .iter()
349 .enumerate()
350 .map(|(idx, qv)| {
351 let score = self.quantized_distance(&quantized_query, qv);
352 (idx, score)
353 })
354 .collect();
355
356 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
358
359 let results: Vec<SQSearchResult> = scores
361 .into_iter()
362 .take(top_k)
363 .map(|(idx, quantized_score)| {
364 let final_score = if let Some(ref originals) = self.original_vectors {
365 self.float_similarity(query, &originals[idx])
366 } else {
367 quantized_score
368 };
369
370 SQSearchResult {
371 id: self.ids[idx].clone(),
372 score: final_score,
373 quantized_score,
374 }
375 })
376 .collect();
377
378 Ok(results)
379 }
380
381 fn quantized_distance(&self, a: &[u8], b: &[u8]) -> f32 {
383 match self.config.quantization_type {
384 QuantizationType::SQ8 => self.sq8_distance(a, b),
385 QuantizationType::SQ4 => self.sq4_distance(a, b),
386 QuantizationType::SQ16 => self.sq16_distance(a, b),
387 }
388 }
389
390 fn sq8_distance(&self, a: &[u8], b: &[u8]) -> f32 {
392 match self.config.metric {
393 DistanceMetric::Cosine | DistanceMetric::DotProduct => {
394 let dot: i32 = a
396 .iter()
397 .zip(b.iter())
398 .map(|(&x, &y)| x as i32 * y as i32)
399 .sum();
400
401 let norm_a: i32 = a.iter().map(|&x| x as i32 * x as i32).sum();
403 let norm_b: i32 = b.iter().map(|&x| x as i32 * x as i32).sum();
404
405 let denom = ((norm_a as f32).sqrt() * (norm_b as f32).sqrt()).max(1e-10);
406 dot as f32 / denom
407 }
408 DistanceMetric::Euclidean => {
409 let dist_sq: i32 = a
411 .iter()
412 .zip(b.iter())
413 .map(|(&x, &y)| {
414 let diff = x as i32 - y as i32;
415 diff * diff
416 })
417 .sum();
418 -(dist_sq as f32).sqrt()
419 }
420 }
421 }
422
423 fn sq4_distance(&self, a: &[u8], b: &[u8]) -> f32 {
425 let a_unpacked = self.unpack_sq4(a);
427 let b_unpacked = self.unpack_sq4(b);
428 self.sq8_distance(&a_unpacked, &b_unpacked)
429 }
430
431 fn unpack_sq4(&self, packed: &[u8]) -> Vec<u8> {
432 let mut result = Vec::with_capacity(self.config.dimensions);
433 for &byte in packed {
434 result.push(byte & 0x0F);
435 if result.len() < self.config.dimensions {
436 result.push((byte >> 4) & 0x0F);
437 }
438 }
439 result
440 }
441
442 fn sq16_distance(&self, a: &[u8], b: &[u8]) -> f32 {
444 match self.config.metric {
445 DistanceMetric::Cosine | DistanceMetric::DotProduct => {
446 let mut dot: i64 = 0;
447 let mut norm_a: i64 = 0;
448 let mut norm_b: i64 = 0;
449
450 for i in (0..a.len()).step_by(2) {
451 let va = u16::from_le_bytes([a[i], a[i + 1]]) as i64;
452 let vb = u16::from_le_bytes([b[i], b[i + 1]]) as i64;
453 dot += va * vb;
454 norm_a += va * va;
455 norm_b += vb * vb;
456 }
457
458 let denom = ((norm_a as f64).sqrt() * (norm_b as f64).sqrt()).max(1e-10);
459 (dot as f64 / denom) as f32
460 }
461 DistanceMetric::Euclidean => {
462 let mut dist_sq: i64 = 0;
463 for i in (0..a.len()).step_by(2) {
464 let va = u16::from_le_bytes([a[i], a[i + 1]]) as i64;
465 let vb = u16::from_le_bytes([b[i], b[i + 1]]) as i64;
466 let diff = va - vb;
467 dist_sq += diff * diff;
468 }
469 -((dist_sq as f64).sqrt() as f32)
470 }
471 }
472 }
473
474 fn float_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
476 match self.config.metric {
477 DistanceMetric::Cosine => {
478 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
479 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
480 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
481 dot / (norm_a * norm_b).max(1e-10)
482 }
483 DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
484 DistanceMetric::Euclidean => {
485 let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
486 -dist_sq.sqrt()
487 }
488 }
489 }
490
491 pub fn delete(&mut self, ids: &[String]) -> usize {
493 let mut deleted = 0;
494
495 for id in ids {
496 if let Some(idx) = self.id_to_index.remove(id) {
497 let last_idx = self.quantized_vectors.len() - 1;
500
501 if idx != last_idx {
502 self.quantized_vectors.swap(idx, last_idx);
504 self.ids.swap(idx, last_idx);
505 if let Some(ref mut originals) = self.original_vectors {
506 originals.swap(idx, last_idx);
507 }
508 self.id_to_index.insert(self.ids[idx].clone(), idx);
510 }
511
512 self.quantized_vectors.pop();
513 self.ids.pop();
514 if let Some(ref mut originals) = self.original_vectors {
515 originals.pop();
516 }
517
518 deleted += 1;
519 }
520 }
521
522 deleted
523 }
524
525 pub fn stats(&self) -> SQStats {
527 let bytes_per_quantized = match self.config.quantization_type {
528 QuantizationType::SQ4 => self.config.dimensions.div_ceil(2),
529 QuantizationType::SQ8 => self.config.dimensions,
530 QuantizationType::SQ16 => self.config.dimensions * 2,
531 };
532
533 let original_memory = self.quantized_vectors.len() * self.config.dimensions * 4;
534 let quantized_memory = self.quantized_vectors.len() * bytes_per_quantized;
535
536 SQStats {
537 num_vectors: self.quantized_vectors.len(),
538 original_memory_bytes: original_memory,
539 quantized_memory_bytes: quantized_memory,
540 compression_ratio: if quantized_memory > 0 {
541 original_memory as f32 / quantized_memory as f32
542 } else {
543 0.0
544 },
545 quantization_type: self.config.quantization_type,
546 }
547 }
548
549 pub fn len(&self) -> usize {
551 self.quantized_vectors.len()
552 }
553
554 pub fn is_empty(&self) -> bool {
556 self.quantized_vectors.is_empty()
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// Five small 4-dimensional vectors shared by several tests.
    fn create_test_vectors() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.5, 0.5, 0.0, 0.0],
            vec![0.0, 0.5, 0.5, 0.0],
        ]
    }

    /// Ids `v0`, `v1`, ..., one per vector.
    fn make_ids(count: usize) -> Vec<String> {
        (0..count).map(|i| format!("v{}", i)).collect()
    }

    /// Cosine-metric configuration used by every test.
    fn cosine_config(
        quantization_type: QuantizationType,
        dimensions: usize,
        store_originals: bool,
    ) -> SQConfig {
        SQConfig {
            quantization_type,
            dimensions,
            metric: DistanceMetric::Cosine,
            store_originals,
        }
    }

    /// Index over `create_test_vectors()` with ids `v0..v4`.
    fn populated_index(quantization_type: QuantizationType, store_originals: bool) -> SQIndex {
        let mut index = SQIndex::new(cosine_config(quantization_type, 4, store_originals));
        let vectors = create_test_vectors();
        index.add(&make_ids(vectors.len()), &vectors).unwrap();
        index
    }

    #[test]
    fn test_sq8_basic() {
        let index = populated_index(QuantizationType::SQ8, false);
        assert_eq!(index.len(), 5);

        let results = index.search(&create_test_vectors()[0], 3).unwrap();
        assert_eq!(results.len(), 3);
        // A stored vector should be its own nearest neighbour.
        assert_eq!(results[0].id, "v0");
    }

    #[test]
    fn test_sq4_compression() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ4, 8, false));
        let vectors = vec![
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
            vec![0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
        ];
        let ids = vec!["a".to_string(), "b".to_string()];

        index.add(&ids, &vectors).unwrap();

        // 8 f32 values (32 bytes) pack into 4 bytes at 4 bits per dimension.
        assert!(index.stats().compression_ratio > 6.0);
    }

    #[test]
    fn test_sq16_accuracy() {
        let index = populated_index(QuantizationType::SQ16, true);

        let results = index.search(&create_test_vectors()[0], 2).unwrap();
        // With originals stored, the best hit is rescored exactly against itself.
        assert!(results[0].score > 0.99);
    }

    #[test]
    fn test_delete() {
        let mut index = populated_index(QuantizationType::SQ8, false);
        assert_eq!(index.len(), 5);

        let deleted = index.delete(&["v0".to_string(), "v2".to_string()]);
        assert_eq!(deleted, 2);
        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_delete_keeps_moved_ids_searchable() {
        let mut index = populated_index(QuantizationType::SQ8, false);
        let vectors = create_test_vectors();

        // Deleting from the middle moves the last rows into the freed slots.
        index.delete(&["v0".to_string(), "v2".to_string()]);

        assert_eq!(index.search(&vectors[4], 1).unwrap()[0].id, "v4");
        assert_eq!(index.search(&vectors[3], 1).unwrap()[0].id, "v3");
        assert_eq!(index.search(&vectors[1], 1).unwrap()[0].id, "v1");
    }

    #[test]
    fn test_delete_unknown_id_is_noop() {
        let mut index = populated_index(QuantizationType::SQ8, false);
        assert_eq!(index.delete(&["missing".to_string()]), 0);
        assert_eq!(index.len(), 5);
    }

    #[test]
    fn test_search_untrained_errors() {
        let index = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, false));
        assert!(index.search(&[0.0, 0.0, 0.0, 0.0], 1).is_err());
    }

    #[test]
    fn test_search_dimension_mismatch_errors() {
        let index = populated_index(QuantizationType::SQ8, false);
        assert!(index.search(&[1.0, 0.0, 0.0], 1).is_err());
    }

    #[test]
    fn test_add_dimension_mismatch_errors() {
        let mut index = populated_index(QuantizationType::SQ8, false);
        let result = index.add(&["bad".to_string()], &[vec![1.0, 2.0, 3.0]]);
        assert!(result.is_err());
        assert_eq!(index.len(), 5);
    }

    #[test]
    fn test_dequantize_roundtrip() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, false));
        let vectors = vec![vec![0.1, 0.5, 0.3, 0.9]];

        index.train(&vectors).unwrap();
        let quantized = index.quantize_vector(&vectors[0]);
        let dequantized = index.dequantize_vector(&quantized);

        for (orig, deq) in vectors[0].iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.05, "Dequantized value too different");
        }
    }

    #[test]
    fn test_sq4_roundtrip_odd_dimensions() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ4, 3, false));
        index
            .train(&[vec![0.0, 0.0, 0.0], vec![1.0, 1.0, 1.0]])
            .unwrap();

        let sample = [0.3, 0.6, 0.9];
        let quantized = index.quantize_vector(&sample);
        // Three 4-bit codes need two bytes; the last high nibble is padding.
        assert_eq!(quantized.len(), 2);

        let dequantized = index.dequantize_vector(&quantized);
        assert_eq!(dequantized.len(), 3);
        for (orig, deq) in sample.iter().zip(dequantized.iter()) {
            // One quantization step over a unit range at 16 levels.
            assert!((orig - deq).abs() <= 1.0 / 15.0 + 1e-4);
        }
    }

    #[test]
    fn test_update_existing() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, false));
        let ids = vec!["v1".to_string()];

        index.add(&ids, &[vec![1.0, 0.0, 0.0, 0.0]]).unwrap();
        assert_eq!(index.len(), 1);

        index.add(&ids, &[vec![0.0, 1.0, 0.0, 0.0]]).unwrap();
        // Re-adding an existing id overwrites it instead of appending.
        assert_eq!(index.len(), 1);

        let results = index.search(&[0.0, 1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, "v1");
    }

    #[test]
    fn test_update_existing_with_originals() {
        let mut index = SQIndex::new(cosine_config(QuantizationType::SQ8, 4, true));
        let ids = vec!["v1".to_string()];

        index.add(&ids, &[vec![1.0, 0.0, 0.0, 0.0]]).unwrap();
        index.add(&ids, &[vec![0.0, 1.0, 0.0, 0.0]]).unwrap();

        // The stored original must be replaced too, so exact rescoring sees the new vector.
        let results = index.search(&[0.0, 1.0, 0.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, "v1");
        assert!(results[0].score > 0.99);
    }
}