1use std::cmp::Ordering;
20use std::collections::{HashMap, HashSet};
21use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
22use std::time::{SystemTime, UNIX_EPOCH};
23
24use super::distance::{distance_simd, DistanceMetric};
25use super::hnsw::{HnswConfig, HnswIndex, NodeId};
26use super::vector_metadata::{MetadataEntry, MetadataFilter, MetadataStore};
27use crate::storage::index::{BloomSegment, HasBloom};
28
/// Identifier of a [`VectorSegment`]; `VectorCollection` hands these out from a
/// monotonically increasing counter.
pub type SegmentId = u64;

/// Identifier of a stored vector, used as the key for vectors, metadata and
/// bloom-filter lookups.
pub type VectorId = u64;
34
/// Lifecycle state of a [`VectorSegment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SegmentState {
    /// Accepts inserts and deletes; searched exhaustively (brute force).
    Growing,
    /// Read-only; an HNSW index has been built by `VectorSegment::seal`.
    Sealed,
    /// Not entered by any code in this file; searched the same way as `Sealed`.
    /// NOTE(review): presumably means "persisted to storage" — confirm.
    Flushed,
}
45
/// Per-collection settings that control segment sizing and indexing.
#[derive(Debug, Clone)]
pub struct SegmentConfig {
    /// Once the growing segment holds at least this many vectors, the
    /// collection seals it and starts a new one on the next insert.
    pub max_vectors: usize,
    /// Parameters for the HNSW index built when a segment is sealed. Its
    /// `metric` also becomes the collection's metric in `VectorCollection::with_config`.
    pub hnsw_config: HnswConfig,
}
54
55impl Default for SegmentConfig {
56 fn default() -> Self {
57 Self {
58 max_vectors: 10_000,
59 hnsw_config: HnswConfig::default(),
60 }
61 }
62}
63
/// Total ordering for distances in which NaN sorts after every real number.
///
/// Two NaNs compare equal. This guarantees that result lists sorted with this
/// comparator push broken (NaN) distances to the end instead of panicking or
/// producing an inconsistent order.
fn cmp_distance(a: f32, b: f32) -> Ordering {
    if let Some(order) = a.partial_cmp(&b) {
        return order;
    }
    // `partial_cmp` only fails when at least one side is NaN.
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        _ => Ordering::Less,
    }
}
78
/// A batch of vectors plus their metadata, searched as a unit.
///
/// A segment starts out `Growing` (writable, brute-force search). `seal` turns
/// it read-only and builds an HNSW index that is used for later searches.
pub struct VectorSegment {
    /// Segment identifier.
    pub id: SegmentId,
    /// Lifecycle state; only `Growing` segments accept inserts and deletes.
    pub state: SegmentState,
    /// Required length of every stored vector and of every query.
    pub dimension: usize,
    /// Distance metric used by brute-force search.
    pub metric: DistanceMetric,
    /// Raw vectors by id. Kept after sealing so results can carry a copy.
    vectors: HashMap<VectorId, Vec<f32>>,
    /// Per-vector metadata, used for filtering and returned with results.
    metadata: MetadataStore,
    /// HNSW index; `None` until the segment is sealed.
    hnsw_index: Option<HnswIndex>,
    /// Vector id -> HNSW node id, filled in by `seal`.
    id_to_hnsw: HashMap<VectorId, NodeId>,
    /// HNSW node id -> vector id, used to translate index results back.
    hnsw_to_id: HashMap<NodeId, VectorId>,
    /// Bloom filter over every id ever inserted; consulted first on lookups.
    /// Deletes do not remove ids from it.
    bloom: BloomSegment,
    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,
    /// Time of the last insert, delete or seal, in seconds since the Unix epoch.
    pub updated_at: u64,
}
108
/// Exposes the segment's id bloom filter to generic index code.
impl HasBloom for VectorSegment {
    /// Always `Some`: every segment maintains a bloom filter over inserted ids.
    fn bloom_segment(&self) -> Option<&BloomSegment> {
        Some(&self.bloom)
    }
}
114
115impl VectorSegment {
116 pub fn new(id: SegmentId, dimension: usize, metric: DistanceMetric) -> Self {
118 let now = SystemTime::now()
119 .duration_since(UNIX_EPOCH)
120 .unwrap()
121 .as_secs();
122
123 Self {
124 id,
125 state: SegmentState::Growing,
126 dimension,
127 metric,
128 vectors: HashMap::new(),
129 metadata: MetadataStore::new(),
130 hnsw_index: None,
131 id_to_hnsw: HashMap::new(),
132 hnsw_to_id: HashMap::new(),
133 bloom: BloomSegment::with_capacity(10_000),
134 created_at: now,
135 updated_at: now,
136 }
137 }
138
139 pub fn len(&self) -> usize {
141 self.vectors.len()
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.vectors.is_empty()
147 }
148
149 pub fn can_write(&self) -> bool {
151 self.state == SegmentState::Growing
152 }
153
154 pub fn insert(
156 &mut self,
157 id: VectorId,
158 vector: Vec<f32>,
159 metadata: MetadataEntry,
160 ) -> Result<(), VectorStoreError> {
161 if !self.can_write() {
162 return Err(VectorStoreError::SegmentSealed);
163 }
164
165 if vector.len() != self.dimension {
166 return Err(VectorStoreError::DimensionMismatch {
167 expected: self.dimension,
168 got: vector.len(),
169 });
170 }
171
172 self.bloom.insert(&id.to_le_bytes());
173 self.vectors.insert(id, vector);
174 self.metadata.insert(id, metadata);
175 self.update_timestamp();
176
177 Ok(())
178 }
179
180 pub fn get_vector(&self, id: VectorId) -> Option<&Vec<f32>> {
182 if self.bloom.definitely_absent(&id.to_le_bytes()) {
183 return None;
184 }
185 self.vectors.get(&id)
186 }
187
188 pub fn get_metadata(&self, id: VectorId) -> Option<&MetadataEntry> {
190 if self.bloom.definitely_absent(&id.to_le_bytes()) {
191 return None;
192 }
193 self.metadata.get(id)
194 }
195
196 pub fn delete(&mut self, id: VectorId) -> Result<bool, VectorStoreError> {
198 if !self.can_write() {
199 return Err(VectorStoreError::SegmentSealed);
200 }
201
202 let existed = self.vectors.remove(&id).is_some();
203 if existed {
204 self.metadata.remove(id);
205 self.update_timestamp();
206 }
207
208 Ok(existed)
209 }
210
211 pub fn search(
213 &self,
214 query: &[f32],
215 k: usize,
216 filter: Option<&MetadataFilter>,
217 ) -> Vec<SearchResult> {
218 if query.len() != self.dimension {
219 return Vec::new();
220 }
221
222 match self.state {
223 SegmentState::Growing => self.brute_force_search(query, k, filter),
224 SegmentState::Sealed | SegmentState::Flushed => self.hnsw_search(query, k, filter),
225 }
226 }
227
228 pub fn seal(&mut self, config: &HnswConfig) {
230 if self.state != SegmentState::Growing {
231 return;
232 }
233
234 let mut hnsw = HnswIndex::new(self.dimension, config.clone());
236
237 for (&vector_id, vector) in &self.vectors {
238 let hnsw_id = hnsw.insert(vector.clone());
239 self.id_to_hnsw.insert(vector_id, hnsw_id);
240 self.hnsw_to_id.insert(hnsw_id, vector_id);
241 }
242
243 self.hnsw_index = Some(hnsw);
244 self.state = SegmentState::Sealed;
245 self.update_timestamp();
246 }
247
248 fn update_timestamp(&mut self) {
253 self.updated_at = SystemTime::now()
254 .duration_since(UNIX_EPOCH)
255 .unwrap()
256 .as_secs();
257 }
258
259 fn brute_force_search(
260 &self,
261 query: &[f32],
262 k: usize,
263 filter: Option<&MetadataFilter>,
264 ) -> Vec<SearchResult> {
265 let allowed: Option<HashSet<VectorId>> = filter.map(|f| self.metadata.filter(f));
267
268 let mut results: Vec<SearchResult> = self
270 .vectors
271 .iter()
272 .filter(|(id, _)| allowed.as_ref().map(|a| a.contains(id)).unwrap_or(true))
273 .map(|(&id, vector)| {
274 let dist = distance_simd(query, vector, self.metric);
275 SearchResult {
276 id,
277 distance: dist,
278 vector: Some(vector.clone()),
279 metadata: self.metadata.get(id).cloned(),
280 }
281 })
282 .collect();
283
284 results.sort_by(|a, b| cmp_distance(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
286 results.truncate(k);
287
288 results
289 }
290
291 fn hnsw_search(
292 &self,
293 query: &[f32],
294 k: usize,
295 filter: Option<&MetadataFilter>,
296 ) -> Vec<SearchResult> {
297 let hnsw = match &self.hnsw_index {
298 Some(h) => h,
299 None => return self.brute_force_search(query, k, filter),
300 };
301
302 let hnsw_results = if let Some(f) = filter {
303 let allowed_vector_ids = self.metadata.filter(f);
305 let allowed_hnsw_ids: HashSet<NodeId> = allowed_vector_ids
306 .iter()
307 .filter_map(|vid| self.id_to_hnsw.get(vid))
308 .copied()
309 .collect();
310
311 hnsw.search_filtered(query, k, &allowed_hnsw_ids)
312 } else {
313 hnsw.search(query, k)
314 };
315
316 hnsw_results
318 .into_iter()
319 .filter_map(|r| {
320 let vector_id = self.hnsw_to_id.get(&r.id)?;
321 Some(SearchResult {
322 id: *vector_id,
323 distance: r.distance,
324 vector: self.vectors.get(vector_id).cloned(),
325 metadata: self.metadata.get(*vector_id).cloned(),
326 })
327 })
328 .collect()
329 }
330}
331
/// One hit returned by a segment or collection search.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Id of the matching vector.
    pub id: VectorId,
    /// Distance to the query under the segment's metric. Searches sort results
    /// ascending, so smaller means closer. NaN distances sort last.
    pub distance: f32,
    /// Copy of the stored vector (`None` only if the id has no stored vector).
    pub vector: Option<Vec<f32>>,
    /// Copy of the vector's metadata, if any was stored.
    pub metadata: Option<MetadataEntry>,
}
344
/// A named set of same-dimension vectors, stored across multiple segments.
///
/// New vectors go to a single growing segment. That segment is sealed (indexed)
/// once it reaches `config.max_vectors`, and the next insert starts a new one.
pub struct VectorCollection {
    /// Collection name.
    pub name: String,
    /// Required length of every vector and query.
    pub dimension: usize,
    /// Distance metric given to newly created segments.
    pub metric: DistanceMetric,
    /// Segment sizing and HNSW parameters.
    config: SegmentConfig,
    /// All segments (growing and sealed) by id.
    segments: HashMap<SegmentId, VectorSegment>,
    /// The segment currently receiving inserts, if any.
    growing_segment: Option<SegmentId>,
    /// Next segment id to hand out.
    next_segment_id: AtomicU64,
    /// Next auto-assigned vector id; kept above every explicitly inserted id.
    next_vector_id: AtomicU64,
    /// Which segment holds each vector id.
    vector_to_segment: HashMap<VectorId, SegmentId>,
}
366
367impl VectorCollection {
368 pub fn new(name: impl Into<String>, dimension: usize) -> Self {
370 Self::with_config(name, dimension, SegmentConfig::default())
371 }
372
373 pub fn with_config(name: impl Into<String>, dimension: usize, config: SegmentConfig) -> Self {
375 let metric = config.hnsw_config.metric;
376
377 Self {
378 name: name.into(),
379 dimension,
380 metric,
381 config,
382 segments: HashMap::new(),
383 growing_segment: None,
384 next_segment_id: AtomicU64::new(0),
385 next_vector_id: AtomicU64::new(0),
386 vector_to_segment: HashMap::new(),
387 }
388 }
389
390 pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
392 self.metric = metric;
393 self.config.hnsw_config.metric = metric;
394 self
395 }
396
397 pub fn len(&self) -> usize {
399 self.segments.values().map(|s| s.len()).sum()
400 }
401
402 pub fn is_empty(&self) -> bool {
404 self.len() == 0
405 }
406
407 pub fn segment_count(&self) -> usize {
409 self.segments.len()
410 }
411
412 pub fn insert(
414 &mut self,
415 vector: Vec<f32>,
416 metadata: Option<MetadataEntry>,
417 ) -> Result<VectorId, VectorStoreError> {
418 if vector.len() != self.dimension {
419 return Err(VectorStoreError::DimensionMismatch {
420 expected: self.dimension,
421 got: vector.len(),
422 });
423 }
424
425 let segment_id = self.ensure_growing_segment();
427 let segment = self.segments.get_mut(&segment_id).unwrap();
428
429 let vector_id = self.next_vector_id.fetch_add(1, AtomicOrdering::SeqCst);
430 segment.insert(vector_id, vector, metadata.unwrap_or_default())?;
431
432 self.vector_to_segment.insert(vector_id, segment_id);
433
434 if segment.len() >= self.config.max_vectors {
436 self.seal_segment(segment_id);
437 }
438
439 Ok(vector_id)
440 }
441
442 pub fn insert_with_id(
444 &mut self,
445 id: VectorId,
446 vector: Vec<f32>,
447 metadata: Option<MetadataEntry>,
448 ) -> Result<(), VectorStoreError> {
449 if vector.len() != self.dimension {
450 return Err(VectorStoreError::DimensionMismatch {
451 expected: self.dimension,
452 got: vector.len(),
453 });
454 }
455
456 let segment_id = self.ensure_growing_segment();
457 let segment = self.segments.get_mut(&segment_id).unwrap();
458
459 segment.insert(id, vector, metadata.unwrap_or_default())?;
460 self.vector_to_segment.insert(id, segment_id);
461
462 let current_next = self.next_vector_id.load(AtomicOrdering::SeqCst);
464 if id >= current_next {
465 self.next_vector_id.store(id + 1, AtomicOrdering::SeqCst);
466 }
467
468 if segment.len() >= self.config.max_vectors {
470 self.seal_segment(segment_id);
471 }
472
473 Ok(())
474 }
475
476 pub fn get(&self, id: VectorId) -> Option<&Vec<f32>> {
478 let segment_id = self.vector_to_segment.get(&id)?;
479 self.segments.get(segment_id)?.get_vector(id)
480 }
481
482 pub fn get_metadata(&self, id: VectorId) -> Option<&MetadataEntry> {
484 let segment_id = self.vector_to_segment.get(&id)?;
485 self.segments.get(segment_id)?.get_metadata(id)
486 }
487
488 pub fn search(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
490 self.search_with_filter(query, k, None)
491 }
492
493 pub fn search_with_filter(
495 &self,
496 query: &[f32],
497 k: usize,
498 filter: Option<&MetadataFilter>,
499 ) -> Vec<SearchResult> {
500 if query.len() != self.dimension {
501 return Vec::new();
502 }
503
504 let mut all_results: Vec<SearchResult> = Vec::new();
506 for segment in self.segments.values() {
507 let segment_results = segment.search(query, k, filter);
508 all_results.extend(segment_results);
509 }
510
511 all_results
513 .sort_by(|a, b| cmp_distance(a.distance, b.distance).then_with(|| a.id.cmp(&b.id)));
514 all_results.truncate(k);
515
516 all_results
517 }
518
519 pub fn delete(&mut self, id: VectorId) -> Result<bool, VectorStoreError> {
521 let segment_id = match self.vector_to_segment.get(&id) {
522 Some(&sid) => sid,
523 None => return Ok(false),
524 };
525
526 let segment = match self.segments.get_mut(&segment_id) {
527 Some(s) => s,
528 None => return Ok(false),
529 };
530
531 if !segment.can_write() {
533 return Err(VectorStoreError::SegmentSealed);
534 }
535
536 let deleted = segment.delete(id)?;
537 if deleted {
538 self.vector_to_segment.remove(&id);
539 }
540
541 Ok(deleted)
542 }
543
544 pub fn seal_growing(&mut self) {
546 if let Some(segment_id) = self.growing_segment.take() {
547 self.seal_segment(segment_id);
548 }
549 }
550
551 fn seal_segment(&mut self, segment_id: SegmentId) {
553 if let Some(segment) = self.segments.get_mut(&segment_id) {
554 segment.seal(&self.config.hnsw_config);
555 }
556 if self.growing_segment == Some(segment_id) {
557 self.growing_segment = None;
558 }
559 }
560
561 fn ensure_growing_segment(&mut self) -> SegmentId {
563 if let Some(id) = self.growing_segment {
564 return id;
565 }
566
567 let segment_id = self.next_segment_id.fetch_add(1, AtomicOrdering::SeqCst);
568 let segment = VectorSegment::new(segment_id, self.dimension, self.metric);
569 self.segments.insert(segment_id, segment);
570 self.growing_segment = Some(segment_id);
571
572 segment_id
573 }
574}
575
576#[derive(Debug, Clone)]
578pub enum VectorStoreError {
579 DimensionMismatch { expected: usize, got: usize },
581 SegmentSealed,
583 VectorNotFound(VectorId),
585 CollectionNotFound(String),
587}
588
589impl std::fmt::Display for VectorStoreError {
590 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
591 match self {
592 Self::DimensionMismatch { expected, got } => {
593 write!(f, "Dimension mismatch: expected {}, got {}", expected, got)
594 }
595 Self::SegmentSealed => write!(f, "Segment is sealed"),
596 Self::VectorNotFound(id) => write!(f, "Vector not found: {}", id),
597 Self::CollectionNotFound(name) => write!(f, "Collection not found: {}", name),
598 }
599 }
600}
601
602impl std::error::Error for VectorStoreError {}
603
/// Top-level registry of named [`VectorCollection`]s.
pub struct VectorStore {
    /// Collections keyed by their name.
    collections: HashMap<String, VectorCollection>,
}
609
610impl VectorStore {
611 pub fn new() -> Self {
613 Self {
614 collections: HashMap::new(),
615 }
616 }
617
618 pub fn create_collection(
620 &mut self,
621 name: impl Into<String>,
622 dimension: usize,
623 ) -> &mut VectorCollection {
624 let name = name.into();
625 self.collections
626 .entry(name.clone())
627 .or_insert_with(|| VectorCollection::new(name.clone(), dimension))
628 }
629
630 pub fn get(&self, name: &str) -> Option<&VectorCollection> {
632 self.collections.get(name)
633 }
634
635 pub fn get_mut(&mut self, name: &str) -> Option<&mut VectorCollection> {
637 self.collections.get_mut(name)
638 }
639
640 pub fn drop_collection(&mut self, name: &str) -> bool {
642 self.collections.remove(name).is_some()
643 }
644
645 pub fn list_collections(&self) -> Vec<&str> {
647 self.collections.keys().map(|s| s.as_str()).collect()
648 }
649}
650
651impl Default for VectorStore {
652 fn default() -> Self {
653 Self::new()
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::MetadataValue;

    /// Deterministic pseudo-random vector with components in roughly `[-1, 1]`.
    ///
    /// Uses xorshift64. The seed is offset by a non-zero constant because
    /// xorshift has a fixed point at zero: seed `0` would otherwise produce an
    /// all-zero vector.
    fn random_vector(dim: usize, seed: u64) -> Vec<f32> {
        let mut state = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                (state as f32) / (u64::MAX as f32) * 2.0 - 1.0
            })
            .collect()
    }

    #[test]
    fn test_collection_basic() {
        let mut collection = VectorCollection::new("test", 3);

        let id1 = collection.insert(vec![1.0, 0.0, 0.0], None).unwrap();
        let id2 = collection.insert(vec![0.0, 1.0, 0.0], None).unwrap();
        let id3 = collection.insert(vec![0.0, 0.0, 1.0], None).unwrap();

        assert_eq!((id1, id2, id3), (0, 1, 2));
        assert_eq!(collection.len(), 3);
        assert_eq!(collection.segment_count(), 1);
        assert!(collection.get(id1).is_some());
        assert!(collection.get(id2).is_some());
        assert!(collection.get(id3).is_some());
    }

    #[test]
    fn test_collection_search() {
        let mut collection = VectorCollection::new("test", 2);

        collection.insert(vec![0.0, 0.0], None).unwrap();
        collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![2.0, 0.0], None).unwrap();
        collection.insert(vec![3.0, 0.0], None).unwrap();

        let results = collection.search(&[0.9, 0.0], 2);
        assert_eq!(results.len(), 2);
        assert!(results[0].distance <= results[1].distance);
        assert!(results.iter().all(|r| r.vector.is_some()));
    }

    #[test]
    fn test_collection_search_with_filter() {
        let mut collection = VectorCollection::new("test", 2);

        for i in 0..10 {
            let mut metadata = MetadataEntry::new();
            metadata.insert("index", MetadataValue::Integer(i));
            metadata.insert("even", MetadataValue::Bool(i % 2 == 0));
            collection
                .insert(vec![i as f32, 0.0], Some(metadata))
                .unwrap();
        }

        let filter = MetadataFilter::eq("even", true);
        let results = collection.search_with_filter(&[5.0, 0.0], 3, Some(&filter));

        assert_eq!(results.len(), 3);
        for result in &results {
            let meta = result.metadata.as_ref().unwrap();
            assert_eq!(meta.get("even"), Some(MetadataValue::Bool(true)));
        }
    }

    #[test]
    fn test_segment_seal() {
        let mut segment = VectorSegment::new(0, 3, DistanceMetric::L2);

        for i in 0..100 {
            segment
                .insert(i, random_vector(3, i), MetadataEntry::new())
                .unwrap();
        }

        assert_eq!(segment.len(), 100);
        assert!(segment.can_write());
        assert_eq!(segment.state, SegmentState::Growing);

        segment.seal(&HnswConfig::default());

        assert!(!segment.can_write());
        assert_eq!(segment.state, SegmentState::Sealed);
        assert!(segment.hnsw_index.is_some());

        let results = segment.search(&random_vector(3, 12345), 5, None);
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_auto_seal() {
        let config = SegmentConfig {
            max_vectors: 10,
            hnsw_config: HnswConfig::default(),
        };
        let mut collection = VectorCollection::with_config("test", 3, config);

        for i in 0..15 {
            collection.insert(random_vector(3, i), None).unwrap();
        }

        // The first 10 vectors fill and seal segment 0; the remaining 5 open segment 1.
        assert_eq!(collection.len(), 15);
        assert_eq!(collection.segment_count(), 2);

        let sealed_count = collection
            .segments
            .values()
            .filter(|s| s.state == SegmentState::Sealed)
            .count();
        assert_eq!(sealed_count, 1);
    }

    #[test]
    fn test_vector_store() {
        let mut store = VectorStore::new();

        store.create_collection("hosts", 128);
        store.create_collection("vulnerabilities", 256);

        assert_eq!(store.list_collections().len(), 2);

        let hosts = store.get_mut("hosts").unwrap();
        hosts.insert(random_vector(128, 0), None).unwrap();
        hosts.insert(random_vector(128, 1), None).unwrap();

        assert_eq!(store.get("hosts").unwrap().len(), 2);
        assert_eq!(store.get("vulnerabilities").unwrap().len(), 0);

        assert!(store.drop_collection("vulnerabilities"));
        assert_eq!(store.list_collections().len(), 1);
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut collection = VectorCollection::new("test", 3);

        let result = collection.insert(vec![1.0, 2.0], None);
        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch { .. })
        ));
        assert!(collection.is_empty());
    }

    #[test]
    fn test_search_handles_nan() {
        let mut collection = VectorCollection::new("test", 2);
        collection.insert(vec![0.0, 0.0], None).unwrap();
        collection.insert(vec![f32::NAN, 0.0], None).unwrap();

        let results = collection.search(&[0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_search_handles_nan_after_seal() {
        let mut collection = VectorCollection::new("test", 2);
        collection.insert(vec![0.0, 0.0], None).unwrap();
        collection.insert(vec![f32::NAN, 0.0], None).unwrap();

        collection.seal_growing();

        let results = collection.search(&[0.0, 0.0], 2);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_delete() {
        let mut collection = VectorCollection::new("test", 3);

        let id1 = collection.insert(vec![1.0, 0.0, 0.0], None).unwrap();
        let id2 = collection.insert(vec![0.0, 1.0, 0.0], None).unwrap();

        assert_eq!(collection.len(), 2);

        assert!(collection.delete(id1).unwrap());
        assert_eq!(collection.len(), 1);
        assert!(collection.get(id1).is_none());
        assert!(collection.get(id2).is_some());

        // Deleting again is a no-op rather than an error.
        assert!(!collection.delete(id1).unwrap());
    }

    #[test]
    fn test_delete_from_sealed_segment_is_rejected() {
        let config = SegmentConfig {
            max_vectors: 2,
            hnsw_config: HnswConfig::default(),
        };
        let mut collection = VectorCollection::with_config("test", 2, config);

        let id = collection.insert(vec![1.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0], None).unwrap();

        assert!(matches!(
            collection.delete(id),
            Err(VectorStoreError::SegmentSealed)
        ));
        assert!(collection.get(id).is_some());
    }

    #[test]
    fn test_insert_with_id_advances_next_id() {
        let mut collection = VectorCollection::new("test", 2);

        collection.insert_with_id(10, vec![1.0, 0.0], None).unwrap();
        let next = collection.insert(vec![0.0, 1.0], None).unwrap();

        assert_eq!(next, 11);
        assert_eq!(collection.len(), 2);
    }

    #[test]
    fn test_cosine_metric() {
        let mut collection = VectorCollection::new("test", 3).with_metric(DistanceMetric::Cosine);

        collection.insert(vec![1.0, 0.0, 0.0], None).unwrap();
        collection.insert(vec![0.0, 1.0, 0.0], None).unwrap();
        collection.insert(vec![0.707, 0.707, 0.0], None).unwrap();

        let results = collection.search(&[0.707, 0.707, 0.0], 1);
        assert_eq!(results.len(), 1);
        assert!(results[0].distance < 0.01);
    }

    #[test]
    fn test_metadata_complex_filter() {
        let mut collection = VectorCollection::new("test", 2);

        for i in 0..20 {
            let mut metadata = MetadataEntry::new();
            metadata.insert("score", MetadataValue::Integer(i));
            metadata.insert(
                "type",
                MetadataValue::String(if i < 10 { "low" } else { "high" }.to_string()),
            );
            collection
                .insert(vec![i as f32, 0.0], Some(metadata))
                .unwrap();
        }

        let filter = MetadataFilter::and(vec![
            MetadataFilter::eq("type", "high"),
            MetadataFilter::gt("score", MetadataValue::Integer(15)),
        ]);

        let results = collection.search_with_filter(&[17.0, 0.0], 5, Some(&filter));

        // Exactly scores 16..=19 match; the growing segment is searched exhaustively.
        assert_eq!(results.len(), 4);
        for result in &results {
            let meta = result.metadata.as_ref().unwrap();
            let score = match meta.get("score") {
                Some(MetadataValue::Integer(s)) => s,
                _ => panic!("Expected integer score"),
            };
            assert!(score > 15);
        }
    }
}