1use crate::entry::{Entry, EntryType, Key, Value};
43use crate::error::{QdbError, Result};
44use num_complex::Complex64;
45use serde::{Deserialize, Serialize};
46use std::cmp::Ordering;
47use std::collections::{BinaryHeap, HashMap, HashSet};
48
/// A dense `f32` vector with its dimension and a lazily cached Euclidean norm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// Raw vector components.
    pub data: Vec<f32>,
    /// Number of components; `Embedding::new` sets this to `data.len()`.
    pub dim: usize,
    /// Norm cached by `Embedding::norm`; never serialized, so it starts as
    /// `None` after deserialization.
    #[serde(skip)]
    norm: Option<f32>,
}
64
65impl Embedding {
66 pub fn new(data: Vec<f32>) -> Self {
68 let dim = data.len();
69 Self {
70 data,
71 dim,
72 norm: None,
73 }
74 }
75
76 pub fn zeros(dim: usize) -> Self {
78 Self::new(vec![0.0; dim])
79 }
80
81 pub fn random(dim: usize) -> Self {
83 use std::collections::hash_map::DefaultHasher;
84 use std::hash::{Hash, Hasher};
85 use std::time::{SystemTime, UNIX_EPOCH};
86
87 let seed = SystemTime::now()
88 .duration_since(UNIX_EPOCH)
89 .unwrap_or_default()
90 .as_nanos() as u64;
91
92 let mut hasher = DefaultHasher::new();
93 let data: Vec<f32> = (0..dim)
94 .map(|i| {
95 (seed + i as u64).hash(&mut hasher);
96 let h = hasher.finish();
97 ((h as f64) / (u64::MAX as f64) * 2.0 - 1.0) as f32
98 })
99 .collect();
100
101 Self::new(data)
102 }
103
104 pub fn norm(&mut self) -> f32 {
106 if let Some(n) = self.norm {
107 return n;
108 }
109 let n = self.data.iter().map(|x| x * x).sum::<f32>().sqrt();
110 self.norm = Some(n);
111 n
112 }
113
114 pub fn normalize(&mut self) {
116 let n = self.norm();
117 if n > 1e-10 {
118 for x in &mut self.data {
119 *x /= n;
120 }
121 self.norm = Some(1.0);
122 }
123 }
124
125 pub fn normalized(&self) -> Self {
127 let mut copy = self.clone();
128 copy.normalize();
129 copy
130 }
131
132 pub fn dot(&self, other: &Embedding) -> f32 {
134 self.data
135 .iter()
136 .zip(other.data.iter())
137 .map(|(a, b)| a * b)
138 .sum()
139 }
140
141 pub fn add(&self, other: &Embedding) -> Self {
143 let data: Vec<f32> = self
144 .data
145 .iter()
146 .zip(other.data.iter())
147 .map(|(a, b)| a + b)
148 .collect();
149 Self::new(data)
150 }
151
152 pub fn scale(&self, s: f32) -> Self {
154 let data: Vec<f32> = self.data.iter().map(|x| x * s).collect();
155 Self::new(data)
156 }
157}
158
159impl From<Vec<f32>> for Embedding {
160 fn from(data: Vec<f32>) -> Self {
161 Self::new(data)
162 }
163}
164
165impl From<&[f32]> for Embedding {
166 fn from(data: &[f32]) -> Self {
167 Self::new(data.to_vec())
168 }
169}
170
171impl Embedding {
176 pub fn to_amplitude_encoding(&self) -> Self {
179 self.normalized()
180 }
181
182 pub fn from_complex_amplitudes(amplitudes: &[Complex64]) -> Self {
185 let data: Vec<f32> = amplitudes.iter().map(|c| c.re as f32).collect();
186 Self::new(data)
187 }
188
189 pub fn from_probabilities(probs: &[f64]) -> Self {
192 let data: Vec<f32> = probs.iter().map(|p| p.sqrt() as f32).collect();
193 Self::new(data)
194 }
195
196 pub fn fidelity(&self, other: &Embedding) -> f32 {
200 let inner_product: f32 = self
201 .data
202 .iter()
203 .zip(other.data.iter())
204 .map(|(a, b)| a * b)
205 .sum();
206 inner_product.powi(2)
207 }
208
209 pub fn trace_distance(&self, other: &Embedding) -> f32 {
212 (1.0 - self.fidelity(other)).sqrt()
214 }
215
216 pub fn quantize_binary(&self) -> Self {
218 let data: Vec<f32> = self.data.iter().map(|x| if *x > 0.0 { 1.0 } else { -1.0 }).collect();
219 Self::new(data)
220 }
221
222 pub fn product_quantize(&self, num_subvectors: usize) -> Vec<Vec<f32>> {
225 let subvec_dim = self.dim / num_subvectors;
226 (0..num_subvectors)
227 .map(|i| {
228 let start = i * subvec_dim;
229 let end = ((i + 1) * subvec_dim).min(self.dim);
230 self.data[start..end].to_vec()
231 })
232 .collect()
233 }
234}
235
/// An embedding stored together with its identifier and optional payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
    /// Identifier of the entry; used as the key when converted with `to_entry`.
    pub id: String,
    /// The vector itself.
    pub embedding: Embedding,
    /// Arbitrary key/value metadata attached to the entry.
    pub metadata: HashMap<String, Value>,
    /// Optional text content associated with the vector.
    pub content: Option<String>,
    /// Optional namespace label.
    pub namespace: Option<String>,
}
250
251impl VectorEntry {
252 pub fn new(id: impl Into<String>, embedding: Embedding) -> Self {
254 Self {
255 id: id.into(),
256 embedding,
257 metadata: HashMap::new(),
258 content: None,
259 namespace: None,
260 }
261 }
262
263 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
265 self.metadata.insert(key.into(), value.into());
266 self
267 }
268
269 pub fn with_content(mut self, content: impl Into<String>) -> Self {
271 self.content = Some(content.into());
272 self
273 }
274
275 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
277 self.namespace = Some(namespace.into());
278 self
279 }
280
281 pub fn to_entry(&self) -> Entry {
283 let key = Key::string(&self.id);
284 let value = Value::Json(serde_json::to_value(self).unwrap_or_default());
285 let mut entry = Entry::new(key, value);
286 entry.entry_type = EntryType::Document;
287 entry.metadata.insert("type".to_string(), "vector".to_string());
288 entry
289 .metadata
290 .insert("dim".to_string(), self.embedding.dim.to_string());
291 entry
292 }
293}
294
/// Metric used to compare two embeddings; `Cosine` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum DistanceMetric {
    /// One minus the cosine of the angle between the vectors.
    #[default]
    Cosine,
    /// Square root of the summed squared component differences.
    Euclidean,
    /// Negated dot product, so larger products count as "closer".
    DotProduct,
    /// Sum of absolute component differences.
    Manhattan,
    /// One minus the squared dot product.
    Fidelity,
    /// Number of positions where the "is strictly positive" flags differ.
    Hamming,
}
317
318impl DistanceMetric {
319 pub fn distance(&self, a: &Embedding, b: &Embedding) -> f32 {
321 match self {
322 DistanceMetric::Cosine => {
323 let dot = a.dot(b);
324 let norm_a = a.data.iter().map(|x| x * x).sum::<f32>().sqrt();
325 let norm_b = b.data.iter().map(|x| x * x).sum::<f32>().sqrt();
326 if norm_a < 1e-10 || norm_b < 1e-10 {
327 1.0
328 } else {
329 1.0 - (dot / (norm_a * norm_b))
330 }
331 }
332 DistanceMetric::Euclidean => a
333 .data
334 .iter()
335 .zip(b.data.iter())
336 .map(|(x, y)| (x - y).powi(2))
337 .sum::<f32>()
338 .sqrt(),
339 DistanceMetric::DotProduct => -a.dot(b), DistanceMetric::Manhattan => a
341 .data
342 .iter()
343 .zip(b.data.iter())
344 .map(|(x, y)| (x - y).abs())
345 .sum(),
346 DistanceMetric::Fidelity => {
347 let inner_product: f32 = a.data.iter().zip(b.data.iter()).map(|(x, y)| x * y).sum();
350 1.0 - inner_product.powi(2)
351 }
352 DistanceMetric::Hamming => {
353 a.data
355 .iter()
356 .zip(b.data.iter())
357 .filter(|(x, y)| (**x > 0.0) != (**y > 0.0))
358 .count() as f32
359 }
360 }
361 }
362
363 pub fn similarity(&self, a: &Embedding, b: &Embedding) -> f32 {
365 match self {
366 DistanceMetric::Cosine => {
367 let dot = a.dot(b);
368 let norm_a = a.data.iter().map(|x| x * x).sum::<f32>().sqrt();
369 let norm_b = b.data.iter().map(|x| x * x).sum::<f32>().sqrt();
370 if norm_a < 1e-10 || norm_b < 1e-10 {
371 0.0
372 } else {
373 dot / (norm_a * norm_b)
374 }
375 }
376 DistanceMetric::DotProduct => a.dot(b),
377 DistanceMetric::Euclidean => 1.0 / (1.0 + self.distance(a, b)),
378 DistanceMetric::Manhattan => 1.0 / (1.0 + self.distance(a, b)),
379 DistanceMetric::Fidelity => {
380 let inner_product: f32 = a.data.iter().zip(b.data.iter()).map(|(x, y)| x * y).sum();
382 inner_product.powi(2)
383 }
384 DistanceMetric::Hamming => {
385 let matches = a
387 .data
388 .iter()
389 .zip(b.data.iter())
390 .filter(|(x, y)| (**x > 0.0) == (**y > 0.0))
391 .count();
392 matches as f32 / a.dim.max(1) as f32
393 }
394 }
395 }
396}
397
/// Cosine distance (`1 - cos θ`) between two slices, accumulating the dot
/// product and both squared norms in one manually unrolled pass.
///
/// Returns `1.0` when either vector's norm is below `1e-10`, matching
/// `DistanceMetric::Cosine`. Only the first `a.len()` elements of `b` are
/// used.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
#[inline]
pub fn cosine_distance_fast(a: &[f32], b: &[f32]) -> f32 {
    // One up-front slice bounds check instead of one per element access.
    let b = &b[..a.len()];

    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;

    let mut a_chunks = a.chunks_exact(4);
    let mut b_chunks = b.chunks_exact(4);
    for (x, y) in a_chunks.by_ref().zip(b_chunks.by_ref()) {
        dot += x[0] * y[0] + x[1] * y[1] + x[2] * y[2] + x[3] * y[3];
        norm_a += x[0] * x[0] + x[1] * x[1] + x[2] * x[2] + x[3] * x[3];
        norm_b += y[0] * y[0] + y[1] * y[1] + y[2] * y[2] + y[3] * y[3];
    }

    for (x, y) in a_chunks.remainder().iter().zip(b_chunks.remainder()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // Take the square roots separately: multiplying the two squared norms
    // first overflows to infinity (or underflows to zero) in f32 for vectors
    // of large (or small) magnitude, which made identical vectors report a
    // distance of 1.0.
    let len_a = norm_a.sqrt();
    let len_b = norm_b.sqrt();
    if len_a < 1e-10 || len_b < 1e-10 {
        1.0
    } else {
        1.0 - (dot / (len_a * len_b))
    }
}
432
/// Squared Euclidean distance between `a` and `b`, summed four elements at a
/// time with a scalar tail for the leftover elements.
///
/// Only the first `a.len()` elements of `b` are read; indexing panics if `b`
/// is shorter than `a`.
#[inline]
pub fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let unrolled_len = len - len % 4;
    let mut acc = 0.0f32;

    let mut base = 0;
    while base < unrolled_len {
        let e0 = a[base] - b[base];
        let e1 = a[base + 1] - b[base + 1];
        let e2 = a[base + 2] - b[base + 2];
        let e3 = a[base + 3] - b[base + 3];
        acc += e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
        base += 4;
    }

    for pos in unrolled_len..len {
        let diff = a[pos] - b[pos];
        acc += diff * diff;
    }

    acc
}
455
/// Tuning parameters for `VectorIndex`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndexConfig {
    /// Maximum number of links kept per node on layers above 0.
    pub m: usize,
    /// Maximum number of links kept per node on layer 0.
    pub m0: usize,
    /// Candidate-list size used while inserting.
    pub ef_construction: usize,
    /// Candidate-list size used while searching layer 0.
    pub ef_search: usize,
    /// Distance metric used to compare embeddings inside the index.
    pub metric: DistanceMetric,
    /// Upper bound on the number of layers a node may be assigned to.
    pub max_layers: usize,
}
476
477impl Default for VectorIndexConfig {
478 fn default() -> Self {
479 Self {
480 m: 16,
481 m0: 32,
482 ef_construction: 200,
483 ef_search: 50,
484 metric: DistanceMetric::Cosine,
485 max_layers: 16,
486 }
487 }
488}
489
/// A single node of the layered proximity graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HnswNode {
    /// Identifier of the node (also its key in `VectorIndex::nodes`).
    id: String,
    /// Vector stored at this node.
    embedding: Embedding,
    /// Neighbor ids per layer: `connections[l]` holds the links on layer `l`.
    connections: Vec<Vec<String>>,
}
498
/// One hit returned by a nearest-neighbor search.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matching vector.
    pub id: String,
    /// Distance to the query under the index metric.
    pub distance: f32,
    /// Score derived from the distance; `VectorIndex::search` sets it to
    /// `1.0 - min(distance, 1.0)`.
    pub score: f32,
    /// Full entry, filled in by `VectorStore`; `None` for raw index searches.
    pub entry: Option<VectorEntry>,
}
511
512impl PartialEq for SearchResult {
513 fn eq(&self, other: &Self) -> bool {
514 self.id == other.id
515 }
516}
517
518impl Eq for SearchResult {}
519
520impl PartialOrd for SearchResult {
521 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
522 Some(self.cmp(other))
523 }
524}
525
526impl Ord for SearchResult {
527 fn cmp(&self, other: &Self) -> Ordering {
528 other
530 .distance
531 .partial_cmp(&self.distance)
532 .unwrap_or(Ordering::Equal)
533 }
534}
535
/// Layered proximity-graph index over embeddings of a fixed dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndex {
    /// Tuning parameters and distance metric.
    config: VectorIndexConfig,
    /// All nodes, keyed by id.
    nodes: HashMap<String, HnswNode>,
    /// Id of the node where searches start; `None` while the index is empty.
    entry_point: Option<String>,
    /// Highest layer index currently used by the entry point.
    max_layer: usize,
    /// Dimension every inserted embedding must have.
    dim: usize,
}
550
551impl VectorIndex {
552 pub fn new(dim: usize) -> Self {
554 Self::with_config(dim, VectorIndexConfig::default())
555 }
556
557 pub fn with_config(dim: usize, config: VectorIndexConfig) -> Self {
559 Self {
560 config,
561 nodes: HashMap::new(),
562 entry_point: None,
563 max_layer: 0,
564 dim,
565 }
566 }
567
568 pub fn len(&self) -> usize {
570 self.nodes.len()
571 }
572
573 pub fn is_empty(&self) -> bool {
575 self.nodes.is_empty()
576 }
577
578 pub fn insert(&mut self, id: impl Into<String>, embedding: Embedding) -> Result<()> {
580 let id = id.into();
581
582 if embedding.dim != self.dim {
583 return Err(QdbError::InvalidInput(format!(
584 "Embedding dimension {} does not match index dimension {}",
585 embedding.dim, self.dim
586 )));
587 }
588
589 let layer = self.random_layer();
591
592 let mut node = HnswNode {
594 id: id.clone(),
595 embedding,
596 connections: vec![Vec::new(); layer + 1],
597 };
598
599 if self.entry_point.is_none() {
600 self.entry_point = Some(id.clone());
602 self.max_layer = layer;
603 self.nodes.insert(id, node);
604 return Ok(());
605 }
606
607 let mut current = self.entry_point.clone().unwrap();
609
610 for l in (layer + 1..=self.max_layer).rev() {
612 current = self.search_layer_greedy(&node.embedding, ¤t, l);
613 }
614
615 for l in (0..=layer.min(self.max_layer)).rev() {
617 let neighbors = self.search_layer(&node.embedding, ¤t, self.config.ef_construction, l);
618
619 let m = if l == 0 { self.config.m0 } else { self.config.m };
621 let selected: Vec<String> = neighbors.into_iter().take(m).map(|(id, _)| id).collect();
622
623 for neighbor_id in &selected {
625 if let Some(neighbor) = self.nodes.get_mut(neighbor_id) {
626 if neighbor.connections.len() > l {
627 neighbor.connections[l].push(id.clone());
628 if neighbor.connections[l].len() > m {
630 self.prune_connections(neighbor_id, l, m);
631 }
632 }
633 }
634 }
635
636 node.connections[l] = selected;
637
638 if !node.connections[l].is_empty() {
639 current = node.connections[l][0].clone();
640 }
641 }
642
643 if layer > self.max_layer {
645 self.entry_point = Some(id.clone());
646 self.max_layer = layer;
647 }
648
649 self.nodes.insert(id, node);
650 Ok(())
651 }
652
653 pub fn search(&self, query: &Embedding, k: usize) -> Vec<SearchResult> {
655 if self.entry_point.is_none() || k == 0 {
656 return Vec::new();
657 }
658
659 let mut current = self.entry_point.clone().unwrap();
660
661 for l in (1..=self.max_layer).rev() {
663 current = self.search_layer_greedy(query, ¤t, l);
664 }
665
666 let candidates = self.search_layer(query, ¤t, self.config.ef_search, 0);
668
669 candidates
671 .into_iter()
672 .take(k)
673 .map(|(id, dist)| SearchResult {
674 id: id.clone(),
675 distance: dist,
676 score: 1.0 - dist.min(1.0),
677 entry: None,
678 })
679 .collect()
680 }
681
682 pub fn search_with_filter<F>(&self, query: &Embedding, k: usize, filter: F) -> Vec<SearchResult>
684 where
685 F: Fn(&str) -> bool,
686 {
687 let candidates = self.search(query, k * 10);
689
690 candidates
691 .into_iter()
692 .filter(|r| filter(&r.id))
693 .take(k)
694 .collect()
695 }
696
697 pub fn delete(&mut self, id: &str) -> bool {
699 if let Some(node) = self.nodes.remove(id) {
700 for (layer, neighbors) in node.connections.iter().enumerate() {
702 for neighbor_id in neighbors {
703 if let Some(neighbor) = self.nodes.get_mut(neighbor_id) {
704 if neighbor.connections.len() > layer {
705 neighbor.connections[layer].retain(|x| x != id);
706 }
707 }
708 }
709 }
710
711 if self.entry_point.as_ref() == Some(&id.to_string()) {
713 self.entry_point = self.nodes.keys().next().cloned();
714 if let Some(ep) = &self.entry_point {
715 if let Some(node) = self.nodes.get(ep) {
716 self.max_layer = node.connections.len().saturating_sub(1);
717 }
718 }
719 }
720
721 true
722 } else {
723 false
724 }
725 }
726
727 fn random_layer(&self) -> usize {
729 use std::collections::hash_map::DefaultHasher;
730 use std::hash::{Hash, Hasher};
731 use std::time::{SystemTime, UNIX_EPOCH};
732
733 let seed = SystemTime::now()
734 .duration_since(UNIX_EPOCH)
735 .unwrap_or_default()
736 .as_nanos() as u64;
737
738 let mut hasher = DefaultHasher::new();
739 seed.hash(&mut hasher);
740 let r = (hasher.finish() as f64) / (u64::MAX as f64);
741
742 let ml = 1.0 / (self.config.m as f64).ln();
743 ((-r.ln() * ml) as usize).min(self.config.max_layers - 1)
744 }
745
746 fn search_layer_greedy(&self, query: &Embedding, start: &str, layer: usize) -> String {
748 let mut current = start.to_string();
749 let mut current_dist = self.distance_to(query, ¤t);
750
751 loop {
752 let mut changed = false;
753
754 if let Some(node) = self.nodes.get(¤t) {
755 if node.connections.len() > layer {
756 for neighbor_id in &node.connections[layer] {
757 let dist = self.distance_to(query, neighbor_id);
758 if dist < current_dist {
759 current = neighbor_id.clone();
760 current_dist = dist;
761 changed = true;
762 }
763 }
764 }
765 }
766
767 if !changed {
768 break;
769 }
770 }
771
772 current
773 }
774
775 fn search_layer(
777 &self,
778 query: &Embedding,
779 start: &str,
780 ef: usize,
781 layer: usize,
782 ) -> Vec<(String, f32)> {
783 let mut visited: HashSet<String> = HashSet::new();
784 let mut candidates: BinaryHeap<(OrderedFloat, String)> = BinaryHeap::new();
785 let mut results: BinaryHeap<(OrderedFloat, String)> = BinaryHeap::new();
786
787 let start_dist = self.distance_to(query, start);
788 candidates.push((OrderedFloat(-start_dist), start.to_string()));
789 results.push((OrderedFloat(start_dist), start.to_string()));
790 visited.insert(start.to_string());
791
792 while let Some((OrderedFloat(neg_dist), current)) = candidates.pop() {
793 let current_dist = -neg_dist;
794
795 if let Some((OrderedFloat(worst_dist), _)) = results.peek() {
797 if current_dist > *worst_dist && results.len() >= ef {
798 break;
799 }
800 }
801
802 if let Some(node) = self.nodes.get(¤t) {
803 if node.connections.len() > layer {
804 for neighbor_id in &node.connections[layer] {
805 if visited.contains(neighbor_id) {
806 continue;
807 }
808 visited.insert(neighbor_id.clone());
809
810 let dist = self.distance_to(query, neighbor_id);
811
812 let should_add = if results.len() < ef {
813 true
814 } else if let Some((OrderedFloat(worst_dist), _)) = results.peek() {
815 dist < *worst_dist
816 } else {
817 true
818 };
819
820 if should_add {
821 candidates.push((OrderedFloat(-dist), neighbor_id.clone()));
822 results.push((OrderedFloat(dist), neighbor_id.clone()));
823
824 if results.len() > ef {
825 results.pop();
826 }
827 }
828 }
829 }
830 }
831 }
832
833 let mut result_vec: Vec<(String, f32)> = results
835 .into_iter()
836 .map(|(OrderedFloat(d), id)| (id, d))
837 .collect();
838 result_vec.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
839 result_vec
840 }
841
842 fn distance_to(&self, query: &Embedding, id: &str) -> f32 {
844 if let Some(node) = self.nodes.get(id) {
845 self.config.metric.distance(query, &node.embedding)
846 } else {
847 f32::MAX
848 }
849 }
850
851 fn prune_connections(&mut self, id: &str, layer: usize, m: usize) {
853 if let Some(node) = self.nodes.get(id) {
854 let embedding = node.embedding.clone();
855 let connections = node.connections[layer].clone();
856
857 let mut scored: Vec<(String, f32)> = connections
859 .into_iter()
860 .filter_map(|neighbor_id| {
861 self.nodes
862 .get(&neighbor_id)
863 .map(|n| (neighbor_id, self.config.metric.distance(&embedding, &n.embedding)))
864 })
865 .collect();
866
867 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
868
869 if let Some(node) = self.nodes.get_mut(id) {
870 node.connections[layer] = scored.into_iter().take(m).map(|(id, _)| id).collect();
871 }
872 }
873 }
874}
875
/// `f32` wrapper with a total order, so distances can be stored in a
/// `BinaryHeap`.
///
/// Ordering and equality both follow `f32::total_cmp`: NaN takes part in the
/// order instead of comparing "equal" to every value, and `==` always agrees
/// with `cmp`, as the `Eq`/`Ord` contract requires. (`-0.0` sorts before
/// `+0.0` under this order.)
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f32);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}
893
/// Configuration for a `VectorStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Dimension of every stored embedding.
    pub dim: usize,
    /// Metric used by `VectorStore::similarity` and reported in stats.
    pub metric: DistanceMetric,
    /// Configuration handed to the underlying `VectorIndex` (which carries its
    /// own `metric` field).
    pub index_config: VectorIndexConfig,
}
908
909impl VectorStoreConfig {
910 pub fn openai() -> Self {
912 Self {
913 dim: 1536,
914 metric: DistanceMetric::Cosine,
915 index_config: VectorIndexConfig::default(),
916 }
917 }
918
919 pub fn sentence_transformers() -> Self {
921 Self {
922 dim: 384,
923 metric: DistanceMetric::Cosine,
924 index_config: VectorIndexConfig::default(),
925 }
926 }
927
928 pub fn custom(dim: usize, metric: DistanceMetric) -> Self {
930 Self {
931 dim,
932 metric,
933 index_config: VectorIndexConfig {
934 metric,
935 ..Default::default()
936 },
937 }
938 }
939}
940
/// Vector store pairing a nearest-neighbor index with the full entries.
#[derive(Debug)]
pub struct VectorStore {
    /// Store-level configuration (dimension, metric, index settings).
    config: VectorStoreConfig,
    /// Index over the embeddings, keyed by entry id.
    index: VectorIndex,
    /// Full entries (embedding plus metadata/content/namespace), keyed by id.
    entries: HashMap<String, VectorEntry>,
}
951
952impl VectorStore {
953 pub fn new(config: VectorStoreConfig) -> Self {
955 let index = VectorIndex::with_config(config.dim, config.index_config.clone());
956 Self {
957 config,
958 index,
959 entries: HashMap::new(),
960 }
961 }
962
963 pub fn with_dim(dim: usize) -> Self {
965 Self::new(VectorStoreConfig::custom(dim, DistanceMetric::Cosine))
966 }
967
968 pub fn len(&self) -> usize {
970 self.entries.len()
971 }
972
973 pub fn is_empty(&self) -> bool {
975 self.entries.is_empty()
976 }
977
978 pub fn put(&mut self, entry: VectorEntry) -> Result<()> {
980 let id = entry.id.clone();
981 self.index.insert(&id, entry.embedding.clone())?;
982 self.entries.insert(id, entry);
983 Ok(())
984 }
985
986 pub fn put_embedding(&mut self, id: impl Into<String>, embedding: impl Into<Embedding>) -> Result<()> {
988 let entry = VectorEntry::new(id, embedding.into());
989 self.put(entry)
990 }
991
992 pub fn put_batch(&mut self, entries: Vec<VectorEntry>) -> Result<usize> {
994 let mut count = 0;
995 for entry in entries {
996 self.put(entry)?;
997 count += 1;
998 }
999 Ok(count)
1000 }
1001
1002 pub fn search(&self, query: &Embedding, k: usize) -> Vec<SearchResult> {
1004 let mut results = self.index.search(query, k);
1005
1006 for result in &mut results {
1008 result.entry = self.entries.get(&result.id).cloned();
1009 }
1010
1011 results
1012 }
1013
1014 pub fn search_with_filter<F>(&self, query: &Embedding, k: usize, filter: F) -> Vec<SearchResult>
1016 where
1017 F: Fn(&VectorEntry) -> bool,
1018 {
1019 let mut results = self.index.search(query, k * 10);
1020
1021 results
1023 .into_iter()
1024 .filter_map(|mut r| {
1025 if let Some(entry) = self.entries.get(&r.id) {
1026 if filter(entry) {
1027 r.entry = Some(entry.clone());
1028 return Some(r);
1029 }
1030 }
1031 None
1032 })
1033 .take(k)
1034 .collect()
1035 }
1036
1037 pub fn search_namespace(&self, query: &Embedding, k: usize, namespace: &str) -> Vec<SearchResult> {
1039 self.search_with_filter(query, k, |e| {
1040 e.namespace.as_ref().map(|n| n == namespace).unwrap_or(false)
1041 })
1042 }
1043
1044 pub fn get(&self, id: &str) -> Option<&VectorEntry> {
1046 self.entries.get(id)
1047 }
1048
1049 pub fn delete(&mut self, id: &str) -> bool {
1051 self.index.delete(id);
1052 self.entries.remove(id).is_some()
1053 }
1054
1055 pub fn ids(&self) -> Vec<&String> {
1057 self.entries.keys().collect()
1058 }
1059
1060 pub fn similarity(&self, id1: &str, id2: &str) -> Option<f32> {
1062 let e1 = self.entries.get(id1)?;
1063 let e2 = self.entries.get(id2)?;
1064 Some(self.config.metric.similarity(&e1.embedding, &e2.embedding))
1065 }
1066
1067 pub fn upsert(&mut self, entry: VectorEntry) -> Result<bool> {
1069 let id = entry.id.clone();
1070 let existed = self.entries.contains_key(&id);
1071
1072 if existed {
1073 self.index.delete(&id);
1074 }
1075
1076 self.index.insert(&id, entry.embedding.clone())?;
1077 self.entries.insert(id, entry);
1078 Ok(existed)
1079 }
1080
1081 pub fn clear(&mut self) {
1083 self.entries.clear();
1084 self.index = VectorIndex::with_config(self.config.dim, self.config.index_config.clone());
1085 }
1086
1087 pub fn stats(&self) -> VectorStoreStats {
1089 let total_vectors = self.entries.len();
1090 let total_dimensions = self.config.dim;
1091 let memory_estimate = total_vectors * total_dimensions * 4; let namespaces: HashSet<_> = self.entries
1094 .values()
1095 .filter_map(|e| e.namespace.as_ref())
1096 .collect();
1097
1098 VectorStoreStats {
1099 total_vectors,
1100 dimensions: total_dimensions,
1101 memory_bytes: memory_estimate,
1102 index_layers: self.index.max_layer,
1103 namespaces: namespaces.len(),
1104 metric: self.config.metric,
1105 }
1106 }
1107
1108 pub fn find_similar(&self, id: &str, k: usize) -> Vec<SearchResult> {
1110 if let Some(entry) = self.entries.get(id) {
1111 let mut results = self.search(&entry.embedding, k + 1);
1112 results.retain(|r| r.id != id);
1113 results.truncate(k);
1114 results
1115 } else {
1116 Vec::new()
1117 }
1118 }
1119
1120 pub fn contains(&self, id: &str) -> bool {
1122 self.entries.contains_key(id)
1123 }
1124
1125 pub fn config(&self) -> &VectorStoreConfig {
1127 &self.config
1128 }
1129
1130 pub fn update_metadata(&mut self, id: &str, key: impl Into<String>, value: impl Into<Value>) -> bool {
1132 if let Some(entry) = self.entries.get_mut(id) {
1133 entry.metadata.insert(key.into(), value.into());
1134 true
1135 } else {
1136 false
1137 }
1138 }
1139
1140 pub fn get_namespace(&self, namespace: &str) -> Vec<&VectorEntry> {
1142 self.entries
1143 .values()
1144 .filter(|e| e.namespace.as_ref().map(|n| n == namespace).unwrap_or(false))
1145 .collect()
1146 }
1147
1148 pub fn count_namespace(&self, namespace: &str) -> usize {
1150 self.entries
1151 .values()
1152 .filter(|e| e.namespace.as_ref().map(|n| n == namespace).unwrap_or(false))
1153 .count()
1154 }
1155
1156 pub fn delete_namespace(&mut self, namespace: &str) -> usize {
1158 let ids_to_delete: Vec<String> = self.entries
1159 .iter()
1160 .filter(|(_, e)| e.namespace.as_ref().map(|n| n == namespace).unwrap_or(false))
1161 .map(|(id, _)| id.clone())
1162 .collect();
1163
1164 let count = ids_to_delete.len();
1165 for id in ids_to_delete {
1166 self.delete(&id);
1167 }
1168 count
1169 }
1170
1171 pub fn centroid(&self, namespace: Option<&str>) -> Option<Embedding> {
1173 let vectors: Vec<&Embedding> = self.entries
1174 .values()
1175 .filter(|e| {
1176 namespace.map(|ns| e.namespace.as_ref().map(|n| n == ns).unwrap_or(false)).unwrap_or(true)
1177 })
1178 .map(|e| &e.embedding)
1179 .collect();
1180
1181 if vectors.is_empty() {
1182 return None;
1183 }
1184
1185 let dim = vectors[0].dim;
1186 let mut centroid = vec![0.0f32; dim];
1187
1188 for v in &vectors {
1189 for (i, val) in v.data.iter().enumerate() {
1190 centroid[i] += val;
1191 }
1192 }
1193
1194 let n = vectors.len() as f32;
1195 for val in &mut centroid {
1196 *val /= n;
1197 }
1198
1199 Some(Embedding::new(centroid))
1200 }
1201}
1202
/// Snapshot of store statistics produced by `VectorStore::stats`.
#[derive(Debug, Clone)]
pub struct VectorStoreStats {
    /// Number of stored entries.
    pub total_vectors: usize,
    /// Configured embedding dimension.
    pub dimensions: usize,
    /// Estimated bytes used by the raw vector components.
    pub memory_bytes: usize,
    /// The index's current `max_layer` value.
    pub index_layers: usize,
    /// Number of distinct namespaces among the entries.
    pub namespaces: usize,
    /// Metric configured for the store.
    pub metric: DistanceMetric,
}
1219
1220impl std::fmt::Display for VectorStoreStats {
1221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1222 write!(
1223 f,
1224 "VectorStore: {} vectors, {} dims, {:.2} MB, {} layers, {} namespaces, {:?}",
1225 self.total_vectors,
1226 self.dimensions,
1227 self.memory_bytes as f64 / (1024.0 * 1024.0),
1228 self.index_layers,
1229 self.namespaces,
1230 self.metric
1231 )
1232 }
1233}
1234
/// Unit tests for embeddings, distance metrics, the index, and the store.
#[cfg(test)]
mod tests {
    use super::*;

    /// `Embedding::new` records the vector length as the dimension.
    #[test]
    fn test_embedding_creation() {
        let emb = Embedding::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(emb.dim, 3);
    }

    /// Orthogonal unit vectors have dot product 0; identical ones have 1.
    #[test]
    fn test_embedding_dot_product() {
        let a = Embedding::new(vec![1.0, 0.0, 0.0]);
        let b = Embedding::new(vec![0.0, 1.0, 0.0]);
        let c = Embedding::new(vec![1.0, 0.0, 0.0]);

        assert!((a.dot(&b) - 0.0).abs() < 1e-6);
        assert!((a.dot(&c) - 1.0).abs() < 1e-6);
    }

    /// Cosine distance of orthogonal vectors is 1; their Euclidean distance is sqrt(2).
    #[test]
    fn test_distance_metrics() {
        let a = Embedding::new(vec![1.0, 0.0]);
        let b = Embedding::new(vec![0.0, 1.0]);

        let cosine_dist = DistanceMetric::Cosine.distance(&a, &b);
        assert!((cosine_dist - 1.0).abs() < 1e-6);
        let euclidean_dist = DistanceMetric::Euclidean.distance(&a, &b);
        assert!((euclidean_dist - 2.0_f32.sqrt()).abs() < 1e-6);
    }

    /// Searching the index returns the requested number of hits, exact match first.
    #[test]
    fn test_vector_index_insert_search() {
        let mut index = VectorIndex::new(3);

        index.insert("a", Embedding::new(vec![1.0, 0.0, 0.0])).unwrap();
        index.insert("b", Embedding::new(vec![0.9, 0.1, 0.0])).unwrap();
        index.insert("c", Embedding::new(vec![0.0, 1.0, 0.0])).unwrap();

        let query = Embedding::new(vec![1.0, 0.0, 0.0]);
        let results = index.search(&query, 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
    }

    /// Store search returns the nearest document with its entry attached.
    #[test]
    fn test_vector_store() {
        let mut store = VectorStore::with_dim(3);

        store
            .put(VectorEntry::new("doc1", Embedding::new(vec![1.0, 0.0, 0.0])).with_content("Hello"))
            .unwrap();
        store
            .put(VectorEntry::new("doc2", Embedding::new(vec![0.0, 1.0, 0.0])).with_content("World"))
            .unwrap();

        let query = Embedding::new(vec![0.9, 0.1, 0.0]);
        let results = store.search(&query, 1);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
        assert!(results[0].entry.is_some());
    }

    /// Namespace search only returns entries from the requested namespace.
    #[test]
    fn test_vector_store_namespace() {
        let mut store = VectorStore::with_dim(2);

        store
            .put(
                VectorEntry::new("a", Embedding::new(vec![1.0, 0.0]))
                    .with_namespace("ns1"),
            )
            .unwrap();
        store
            .put(
                VectorEntry::new("b", Embedding::new(vec![0.9, 0.1]))
                    .with_namespace("ns2"),
            )
            .unwrap();

        let query = Embedding::new(vec![1.0, 0.0]);
        let results = store.search_namespace(&query, 10, "ns1");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "a");
    }

    /// Fidelity is ~1 for identical (approximately unit) vectors and ~0 for orthogonal ones.
    #[test]
    fn test_quantum_fidelity() {
        let bell = Embedding::new(vec![0.707, 0.0, 0.0, 0.707]);
        let same = Embedding::new(vec![0.707, 0.0, 0.0, 0.707]);
        let orthogonal = Embedding::new(vec![0.0, 0.707, 0.707, 0.0]);

        let fid_same = bell.fidelity(&same);
        assert!((fid_same - 1.0).abs() < 0.01);

        let fid_orth = bell.fidelity(&orthogonal);
        assert!(fid_orth < 0.01);
    }

    /// For normalized vectors 45 degrees apart, fidelity similarity and distance are both ~0.5.
    #[test]
    fn test_fidelity_distance_metric() {
        let a = Embedding::new(vec![1.0, 0.0]).normalized();
        let b = Embedding::new(vec![0.707, 0.707]).normalized();

        let fid_dist = DistanceMetric::Fidelity.distance(&a, &b);
        let fid_sim = DistanceMetric::Fidelity.similarity(&a, &b);

        assert!((fid_sim - 0.5).abs() < 0.1);
        assert!((fid_dist - 0.5).abs() < 0.1);
    }

    /// Amplitude encoding produces a unit-norm vector.
    #[test]
    fn test_amplitude_encoding() {
        let emb = Embedding::new(vec![3.0, 4.0]);
        let amp = emb.to_amplitude_encoding();

        let norm: f32 = amp.data.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    /// `upsert` reports whether the id already existed and replaces the embedding.
    #[test]
    fn test_upsert() {
        let mut store = VectorStore::with_dim(2);

        let existed = store.upsert(VectorEntry::new("a", Embedding::new(vec![1.0, 0.0]))).unwrap();
        assert!(!existed);

        let existed = store.upsert(VectorEntry::new("a", Embedding::new(vec![0.0, 1.0]))).unwrap();
        assert!(existed);

        let entry = store.get("a").unwrap();
        assert!((entry.embedding.data[1] - 1.0).abs() < 1e-6);
    }

    /// `find_similar` excludes the query entry itself and ranks the nearest neighbor first.
    #[test]
    fn test_find_similar() {
        let mut store = VectorStore::with_dim(2);

        store.put(VectorEntry::new("a", Embedding::new(vec![1.0, 0.0]))).unwrap();
        store.put(VectorEntry::new("b", Embedding::new(vec![0.9, 0.1]))).unwrap();
        store.put(VectorEntry::new("c", Embedding::new(vec![0.0, 1.0]))).unwrap();

        let similar = store.find_similar("a", 2);
        assert_eq!(similar.len(), 2);
        assert_eq!(similar[0].id, "b");
    }

    /// Stats report the vector count, dimension, and number of distinct namespaces.
    #[test]
    fn test_stats() {
        let mut store = VectorStore::with_dim(128);

        for i in 0..10 {
            store.put(
                VectorEntry::new(format!("v{}", i), Embedding::zeros(128))
                    .with_namespace(if i < 5 { "ns1" } else { "ns2" })
            ).unwrap();
        }

        let stats = store.stats();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, 128);
        assert_eq!(stats.namespaces, 2);
    }

    /// The centroid of two unit axis vectors is their component-wise mean.
    #[test]
    fn test_centroid() {
        let mut store = VectorStore::with_dim(2);

        store.put(VectorEntry::new("a", Embedding::new(vec![1.0, 0.0]))).unwrap();
        store.put(VectorEntry::new("b", Embedding::new(vec![0.0, 1.0]))).unwrap();

        let centroid = store.centroid(None).unwrap();
        assert!((centroid.data[0] - 0.5).abs() < 1e-6);
        assert!((centroid.data[1] - 0.5).abs() < 1e-6);
    }

    /// Hamming distance counts the positions whose signs differ (one here).
    #[test]
    fn test_hamming_distance() {
        let a = Embedding::new(vec![1.0, -1.0, 1.0, -1.0]);
        let b = Embedding::new(vec![1.0, 1.0, 1.0, -1.0]);

        let dist = DistanceMetric::Hamming.distance(&a, &b);
        assert!((dist - 1.0).abs() < 1e-6);
    }

    /// The unrolled `cosine_distance_fast` agrees with `DistanceMetric::Cosine`.
    #[test]
    fn test_simd_cosine() {
        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..128).map(|i| (i * 2) as f32).collect();

        let fast = cosine_distance_fast(&a, &b);
        let slow = DistanceMetric::Cosine.distance(
            &Embedding::new(a.clone()),
            &Embedding::new(b.clone())
        );

        assert!((fast - slow).abs() < 1e-4);
    }
}