1use std::collections::{BTreeMap, HashMap, HashSet};
13
14use super::hnsw::NodeId;
15use crate::storage::schema::{CanonicalKey, CanonicalKeyFamily};
16
17pub use reddb_types::vector_metadata::{
24 metadata_value_to_canonical_key, MetadataEntry, MetadataFilter, MetadataValue,
25};
26
27#[derive(Debug, Clone, Default)]
29struct KeyIndex {
30 string_index: HashMap<String, HashSet<NodeId>>,
32 integer_index: HashMap<i64, HashSet<NodeId>>,
34 bool_index: HashMap<bool, HashSet<NodeId>>,
36 ordered_index: BTreeMap<CanonicalKey, HashSet<NodeId>>,
38 range_family: Option<CanonicalKeyFamily>,
40 has_mixed_families: bool,
41 all_ids: HashSet<NodeId>,
43}
44
45impl KeyIndex {
46 fn new() -> Self {
47 Self::default()
48 }
49
50 fn insert(&mut self, id: NodeId, value: &MetadataValue) {
51 self.all_ids.insert(id);
52 match value {
53 MetadataValue::String(s) => {
54 self.string_index.entry(s.clone()).or_default().insert(id);
55 }
56 MetadataValue::Integer(i) => {
57 self.integer_index.entry(*i).or_default().insert(id);
58 }
59 MetadataValue::Bool(b) => {
60 self.bool_index.entry(*b).or_default().insert(id);
61 }
62 MetadataValue::Float(_) | MetadataValue::Null => {}
63 }
64
65 if let Some(key) = metadata_value_to_canonical_key(value) {
66 match self.range_family {
67 Some(existing) if existing != key.family() => self.has_mixed_families = true,
68 None => self.range_family = Some(key.family()),
69 _ => {}
70 }
71 self.ordered_index.entry(key).or_default().insert(id);
72 }
73 }
74
75 fn remove(&mut self, id: NodeId, value: &MetadataValue) {
76 self.all_ids.remove(&id);
77 match value {
78 MetadataValue::String(s) => {
79 if let Some(ids) = self.string_index.get_mut(s) {
80 ids.remove(&id);
81 }
82 }
83 MetadataValue::Integer(i) => {
84 if let Some(ids) = self.integer_index.get_mut(i) {
85 ids.remove(&id);
86 }
87 }
88 MetadataValue::Bool(b) => {
89 if let Some(ids) = self.bool_index.get_mut(b) {
90 ids.remove(&id);
91 }
92 }
93 _ => {}
94 }
95
96 if let Some(key) = metadata_value_to_canonical_key(value) {
97 if let Some(ids) = self.ordered_index.get_mut(&key) {
98 ids.remove(&id);
99 if ids.is_empty() {
100 self.ordered_index.remove(&key);
101 }
102 }
103 }
104 }
105
106 fn exact_match_ids(&self, value: &MetadataValue) -> Option<HashSet<NodeId>> {
107 match value {
108 MetadataValue::String(s) => Some(self.string_index.get(s).cloned().unwrap_or_default()),
109 MetadataValue::Integer(i) => {
110 Some(self.integer_index.get(i).cloned().unwrap_or_default())
111 }
112 MetadataValue::Bool(b) => Some(self.bool_index.get(b).cloned().unwrap_or_default()),
113 MetadataValue::Null => Some(HashSet::new()),
114 MetadataValue::Float(f) if f.is_nan() => Some(HashSet::new()),
115 MetadataValue::Float(_) => metadata_value_to_canonical_key(value)
116 .map(|key| self.ordered_index.get(&key).cloned().unwrap_or_default()),
117 }
118 }
119
120 fn supports_range_key(&self, key: &CanonicalKey) -> bool {
121 !self.has_mixed_families && self.range_family == Some(key.family())
122 }
123
124 fn range_match_ids(
125 &self,
126 value: &MetadataValue,
127 op: MetadataRangeOp,
128 ) -> Option<HashSet<NodeId>> {
129 let key = metadata_value_to_canonical_key(value)?;
130 if !self.supports_range_key(&key) {
131 return None;
132 }
133
134 let mut out = HashSet::new();
135 match op {
136 MetadataRangeOp::Gt => {
137 for ids in self
138 .ordered_index
139 .range((std::ops::Bound::Excluded(key), std::ops::Bound::Unbounded))
140 .map(|(_, ids)| ids)
141 {
142 out.extend(ids.iter().copied());
143 }
144 }
145 MetadataRangeOp::Gte => {
146 for ids in self
147 .ordered_index
148 .range((std::ops::Bound::Included(key), std::ops::Bound::Unbounded))
149 .map(|(_, ids)| ids)
150 {
151 out.extend(ids.iter().copied());
152 }
153 }
154 MetadataRangeOp::Lt => {
155 for ids in self
156 .ordered_index
157 .range((std::ops::Bound::Unbounded, std::ops::Bound::Excluded(key)))
158 .map(|(_, ids)| ids)
159 {
160 out.extend(ids.iter().copied());
161 }
162 }
163 MetadataRangeOp::Lte => {
164 for ids in self
165 .ordered_index
166 .range((std::ops::Bound::Unbounded, std::ops::Bound::Included(key)))
167 .map(|(_, ids)| ids)
168 {
169 out.extend(ids.iter().copied());
170 }
171 }
172 }
173 Some(out)
174 }
175}
176
/// Comparison applied by a range lookup: a stored value is kept when it relates
/// to the probe value as named by the variant.
#[derive(Debug, Clone, Copy)]
enum MetadataRangeOp {
    /// Stored value strictly greater than the probe.
    Gt,
    /// Stored value greater than or equal to the probe.
    Gte,
    /// Stored value strictly less than the probe.
    Lt,
    /// Stored value less than or equal to the probe.
    Lte,
}
184
185pub struct MetadataStore {
187 entries: HashMap<NodeId, MetadataEntry>,
189 indexes: HashMap<String, KeyIndex>,
191}
192
193impl MetadataStore {
194 pub fn new() -> Self {
196 Self {
197 entries: HashMap::new(),
198 indexes: HashMap::new(),
199 }
200 }
201
202 pub fn len(&self) -> usize {
204 self.entries.len()
205 }
206
207 pub fn is_empty(&self) -> bool {
209 self.entries.is_empty()
210 }
211
212 pub fn insert(&mut self, id: NodeId, entry: MetadataEntry) {
214 if let Some(old_entry) = self.entries.get(&id) {
216 for key in old_entry.keys() {
217 if let Some(value) = old_entry.get(&key) {
218 if let Some(index) = self.indexes.get_mut(&key) {
219 index.remove(id, &value);
220 }
221 }
222 }
223 }
224
225 for key in entry.keys() {
227 if let Some(value) = entry.get(&key) {
228 self.indexes
229 .entry(key.clone())
230 .or_default()
231 .insert(id, &value);
232 }
233 }
234
235 self.entries.insert(id, entry);
236 }
237
238 pub fn get(&self, id: NodeId) -> Option<&MetadataEntry> {
240 self.entries.get(&id)
241 }
242
243 pub fn remove(&mut self, id: NodeId) -> Option<MetadataEntry> {
245 if let Some(entry) = self.entries.remove(&id) {
246 for key in entry.keys() {
247 if let Some(value) = entry.get(&key) {
248 if let Some(index) = self.indexes.get_mut(&key) {
249 index.remove(id, &value);
250 }
251 }
252 }
253 Some(entry)
254 } else {
255 None
256 }
257 }
258
259 pub fn filter(&self, filter: &MetadataFilter) -> HashSet<NodeId> {
261 self.filter_internal(filter)
262 }
263
264 fn filter_internal(&self, filter: &MetadataFilter) -> HashSet<NodeId> {
265 match filter {
266 MetadataFilter::Eq(key, value) => self
267 .indexes
268 .get(key)
269 .and_then(|idx| idx.exact_match_ids(value))
270 .unwrap_or_else(|| {
271 self.entries
272 .iter()
273 .filter(|(_, entry)| {
274 entry
275 .get(key)
276 .map(|candidate| candidate.matches_eq(value))
277 .unwrap_or(false)
278 })
279 .map(|(id, _)| *id)
280 .collect()
281 }),
282 MetadataFilter::Ne(key, value) => {
283 let all: HashSet<_> = self.entries.keys().copied().collect();
284 if let Some(index) = self.indexes.get(key) {
285 if let Some(exact) = index.exact_match_ids(value) {
286 return all.difference(&exact).copied().collect();
287 }
288 }
289 self.entries
290 .iter()
291 .filter(|(_, entry)| {
292 entry
293 .get(key)
294 .map(|candidate| !candidate.matches_eq(value))
295 .unwrap_or(true)
296 })
297 .map(|(id, _)| *id)
298 .collect()
299 }
300 MetadataFilter::Gt(key, value) => self
301 .indexes
302 .get(key)
303 .and_then(|idx| idx.range_match_ids(value, MetadataRangeOp::Gt))
304 .unwrap_or_else(|| {
305 self.entries
306 .iter()
307 .filter(|(_, entry)| {
308 entry
309 .get(key)
310 .and_then(|candidate| candidate.compare(value))
311 .map(|ord| ord == std::cmp::Ordering::Greater)
312 .unwrap_or(false)
313 })
314 .map(|(id, _)| *id)
315 .collect()
316 }),
317 MetadataFilter::Gte(key, value) => self
318 .indexes
319 .get(key)
320 .and_then(|idx| idx.range_match_ids(value, MetadataRangeOp::Gte))
321 .unwrap_or_else(|| {
322 self.entries
323 .iter()
324 .filter(|(_, entry)| {
325 entry
326 .get(key)
327 .and_then(|candidate| candidate.compare(value))
328 .map(|ord| ord != std::cmp::Ordering::Less)
329 .unwrap_or(false)
330 })
331 .map(|(id, _)| *id)
332 .collect()
333 }),
334 MetadataFilter::Lt(key, value) => self
335 .indexes
336 .get(key)
337 .and_then(|idx| idx.range_match_ids(value, MetadataRangeOp::Lt))
338 .unwrap_or_else(|| {
339 self.entries
340 .iter()
341 .filter(|(_, entry)| {
342 entry
343 .get(key)
344 .and_then(|candidate| candidate.compare(value))
345 .map(|ord| ord == std::cmp::Ordering::Less)
346 .unwrap_or(false)
347 })
348 .map(|(id, _)| *id)
349 .collect()
350 }),
351 MetadataFilter::Lte(key, value) => self
352 .indexes
353 .get(key)
354 .and_then(|idx| idx.range_match_ids(value, MetadataRangeOp::Lte))
355 .unwrap_or_else(|| {
356 self.entries
357 .iter()
358 .filter(|(_, entry)| {
359 entry
360 .get(key)
361 .and_then(|candidate| candidate.compare(value))
362 .map(|ord| ord != std::cmp::Ordering::Greater)
363 .unwrap_or(false)
364 })
365 .map(|(id, _)| *id)
366 .collect()
367 }),
368 MetadataFilter::In(key, values) => {
369 if let Some(index) = self.indexes.get(key) {
370 if let Some(result) =
371 values.iter().try_fold(HashSet::new(), |mut acc, value| {
372 let ids = index.exact_match_ids(value)?;
373 acc.extend(ids);
374 Some(acc)
375 })
376 {
377 return result;
378 }
379 }
380 self.entries
381 .iter()
382 .filter(|(_, entry)| {
383 entry
384 .get(key)
385 .map(|candidate| values.iter().any(|value| candidate.matches_eq(value)))
386 .unwrap_or(false)
387 })
388 .map(|(id, _)| *id)
389 .collect()
390 }
391 MetadataFilter::NotIn(key, values) => {
392 let all: HashSet<_> = self.entries.keys().copied().collect();
393 if let Some(index) = self.indexes.get(key) {
394 if let Some(matched) =
395 values.iter().try_fold(HashSet::new(), |mut acc, value| {
396 let ids = index.exact_match_ids(value)?;
397 acc.extend(ids);
398 Some(acc)
399 })
400 {
401 return all.difference(&matched).copied().collect();
402 }
403 }
404 self.entries
405 .iter()
406 .filter(|(_, entry)| {
407 entry
408 .get(key)
409 .map(|candidate| {
410 !values.iter().any(|value| candidate.matches_eq(value))
411 })
412 .unwrap_or(true)
413 })
414 .map(|(id, _)| *id)
415 .collect()
416 }
417 MetadataFilter::Exists(key) => self
418 .indexes
419 .get(key)
420 .map(|idx| idx.all_ids.clone())
421 .unwrap_or_default(),
422 MetadataFilter::And(filters) => {
423 if filters.is_empty() {
424 return self.entries.keys().copied().collect();
425 }
426 let mut result = self.filter_internal(&filters[0]);
427 for filter in &filters[1..] {
428 let other = self.filter_internal(filter);
429 result = result.intersection(&other).copied().collect();
430 }
431 result
432 }
433 MetadataFilter::Or(filters) => {
434 let mut result = HashSet::new();
435 for filter in filters {
436 result.extend(self.filter_internal(filter));
437 }
438 result
439 }
440 MetadataFilter::Not(inner) => {
441 let all: HashSet<_> = self.entries.keys().copied().collect();
442 let matched = self.filter_internal(inner);
443 all.difference(&matched).copied().collect()
444 }
445 _ => self
447 .entries
448 .iter()
449 .filter(|(_, entry)| filter.matches(entry))
450 .map(|(id, _)| *id)
451 .collect(),
452 }
453 }
454}
455
456impl Default for MetadataStore {
457 fn default() -> Self {
458 Self::new()
459 }
460}
461
#[cfg(test)]
mod tests {
    //! Unit tests for `MetadataStore` filtering. Each test exercises both the
    //! indexed lookup paths and the scan fallbacks through the public API.

    use super::*;

    #[test]
    fn test_metadata_entry() {
        let mut entry = MetadataEntry::new();
        entry.insert("name", MetadataValue::String("test".to_string()));
        entry.insert("count", MetadataValue::Integer(42));
        entry.insert("score", MetadataValue::Float(2.5));
        entry.insert("active", MetadataValue::Bool(true));

        assert_eq!(
            entry.get("name"),
            Some(MetadataValue::String("test".to_string()))
        );
        assert_eq!(entry.get("count"), Some(MetadataValue::Integer(42)));
        assert!(entry.get("score").is_some());
        assert_eq!(entry.get("active"), Some(MetadataValue::Bool(true)));
        assert!(entry.get("nonexistent").is_none());
    }

    #[test]
    fn test_filter_eq() {
        let mut store = MetadataStore::new();

        let mut entry1 = MetadataEntry::new();
        entry1.insert("type", MetadataValue::String("host".to_string()));

        let mut entry2 = MetadataEntry::new();
        entry2.insert("type", MetadataValue::String("service".to_string()));

        store.insert(1, entry1);
        store.insert(2, entry2);

        let filter = MetadataFilter::eq("type", "host");
        let results = store.filter(&filter);

        assert_eq!(results.len(), 1);
        assert!(results.contains(&1));
    }

    #[test]
    fn test_filter_comparison() {
        let mut store = MetadataStore::new();

        for i in 0..10 {
            let mut entry = MetadataEntry::new();
            entry.insert("score", MetadataValue::Integer(i));
            store.insert(i as u64, entry);
        }

        // score > 5 matches 6, 7, 8, 9.
        let filter = MetadataFilter::gt("score", MetadataValue::Integer(5));
        let results = store.filter(&filter);
        assert_eq!(results.len(), 4);

        // score >= 5 matches 5 through 9.
        let filter = MetadataFilter::gte("score", MetadataValue::Integer(5));
        let results = store.filter(&filter);
        assert_eq!(results.len(), 5);

        // score < 3 matches 0, 1, 2.
        let filter = MetadataFilter::lt("score", MetadataValue::Integer(3));
        let results = store.filter(&filter);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_filter_and() {
        let mut store = MetadataStore::new();

        let mut entry1 = MetadataEntry::new();
        entry1.insert("type", MetadataValue::String("host".to_string()));
        entry1.insert("active", MetadataValue::Bool(true));

        let mut entry2 = MetadataEntry::new();
        entry2.insert("type", MetadataValue::String("host".to_string()));
        entry2.insert("active", MetadataValue::Bool(false));

        let mut entry3 = MetadataEntry::new();
        entry3.insert("type", MetadataValue::String("service".to_string()));
        entry3.insert("active", MetadataValue::Bool(true));

        store.insert(1, entry1);
        store.insert(2, entry2);
        store.insert(3, entry3);

        let filter = MetadataFilter::and(vec![
            MetadataFilter::eq("type", "host"),
            MetadataFilter::eq("active", true),
        ]);
        let results = store.filter(&filter);

        assert_eq!(results.len(), 1);
        assert!(results.contains(&1));
    }

    #[test]
    fn test_filter_or() {
        let mut store = MetadataStore::new();

        let mut entry1 = MetadataEntry::new();
        entry1.insert("type", MetadataValue::String("host".to_string()));

        let mut entry2 = MetadataEntry::new();
        entry2.insert("type", MetadataValue::String("service".to_string()));

        let mut entry3 = MetadataEntry::new();
        entry3.insert("type", MetadataValue::String("network".to_string()));

        store.insert(1, entry1);
        store.insert(2, entry2);
        store.insert(3, entry3);

        let filter = MetadataFilter::or(vec![
            MetadataFilter::eq("type", "host"),
            MetadataFilter::eq("type", "service"),
        ]);
        let results = store.filter(&filter);

        assert_eq!(results.len(), 2);
        assert!(results.contains(&1));
        assert!(results.contains(&2));
    }

    #[test]
    fn test_filter_contains() {
        let mut store = MetadataStore::new();

        let mut entry1 = MetadataEntry::new();
        entry1.insert(
            "description",
            MetadataValue::String("SSH vulnerability".to_string()),
        );

        let mut entry2 = MetadataEntry::new();
        entry2.insert(
            "description",
            MetadataValue::String("HTTP server".to_string()),
        );

        store.insert(1, entry1);
        store.insert(2, entry2);

        let filter =
            MetadataFilter::Contains("description".to_string(), "vulnerability".to_string());
        let results = store.filter(&filter);

        assert_eq!(results.len(), 1);
        assert!(results.contains(&1));
    }

    #[test]
    fn test_filter_in() {
        let mut store = MetadataStore::new();

        for i in 0..5 {
            let mut entry = MetadataEntry::new();
            entry.insert("id", MetadataValue::Integer(i));
            store.insert(i as u64, entry);
        }

        let filter = MetadataFilter::In(
            "id".to_string(),
            vec![MetadataValue::Integer(1), MetadataValue::Integer(3)],
        );
        let results = store.filter(&filter);

        assert_eq!(results.len(), 2);
        assert!(results.contains(&1));
        assert!(results.contains(&3));
    }

    #[test]
    fn test_remove_updates_index() {
        let mut store = MetadataStore::new();

        let mut entry = MetadataEntry::new();
        entry.insert("type", MetadataValue::String("host".to_string()));
        store.insert(1, entry);

        assert_eq!(store.filter(&MetadataFilter::eq("type", "host")).len(), 1);

        store.remove(1);

        assert_eq!(store.filter(&MetadataFilter::eq("type", "host")).len(), 0);
    }

    #[test]
    fn test_filter_float_eq_uses_canonical_index() {
        let mut store = MetadataStore::new();

        let mut entry1 = MetadataEntry::new();
        entry1.insert("score", MetadataValue::Float(1.5));
        store.insert(1, entry1);

        let mut entry2 = MetadataEntry::new();
        entry2.insert("score", MetadataValue::Float(2.5));
        store.insert(2, entry2);

        let results = store.filter(&MetadataFilter::eq("score", MetadataValue::Float(2.5)));
        assert_eq!(results, HashSet::from([2]));
    }

    #[test]
    fn test_filter_string_range_uses_ordered_index() {
        let mut store = MetadataStore::new();

        for (id, tier) in [(1, "alpha"), (2, "bravo"), (3, "delta")] {
            let mut entry = MetadataEntry::new();
            entry.insert("tier", MetadataValue::String(tier.to_string()));
            store.insert(id, entry);
        }

        let results = store.filter(&MetadataFilter::gte(
            "tier",
            MetadataValue::String("bravo".to_string()),
        ));
        assert_eq!(results, HashSet::from([2, 3]));
    }

    #[test]
    fn test_reinsert_replaces_indexed_values() {
        let mut store = MetadataStore::new();

        let mut first = MetadataEntry::new();
        first.insert("type", MetadataValue::String("host".to_string()));
        store.insert(1, first);

        let mut second = MetadataEntry::new();
        second.insert("type", MetadataValue::String("service".to_string()));
        store.insert(1, second);

        assert_eq!(store.len(), 1);
        assert!(store.filter(&MetadataFilter::eq("type", "host")).is_empty());
        assert_eq!(
            store.filter(&MetadataFilter::eq("type", "service")),
            HashSet::from([1])
        );
    }

    #[test]
    fn test_filter_ne_and_not_in_include_entries_missing_key() {
        let mut store = MetadataStore::new();

        for (id, kind) in [(1, "host"), (2, "service")] {
            let mut entry = MetadataEntry::new();
            entry.insert("type", MetadataValue::String(kind.to_string()));
            store.insert(id, entry);
        }
        let mut untyped = MetadataEntry::new();
        untyped.insert("port", MetadataValue::Integer(22));
        store.insert(3, untyped);

        let ne = MetadataFilter::Ne(
            "type".to_string(),
            MetadataValue::String("host".to_string()),
        );
        assert_eq!(store.filter(&ne), HashSet::from([2, 3]));

        let not_in = MetadataFilter::NotIn(
            "type".to_string(),
            vec![MetadataValue::String("service".to_string())],
        );
        assert_eq!(store.filter(&not_in), HashSet::from([1, 3]));
    }

    #[test]
    fn test_exists_tracks_insert_and_remove() {
        let mut store = MetadataStore::new();

        for (id, kind) in [(1, "host"), (2, "service")] {
            let mut entry = MetadataEntry::new();
            entry.insert("type", MetadataValue::String(kind.to_string()));
            store.insert(id, entry);
        }

        let exists = MetadataFilter::Exists("type".to_string());
        assert_eq!(store.filter(&exists), HashSet::from([1, 2]));

        store.remove(1);
        assert_eq!(store.filter(&exists), HashSet::from([2]));

        store.remove(2);
        assert!(store.filter(&exists).is_empty());
    }

    #[test]
    fn test_range_after_key_changes_value_family() {
        let mut store = MetadataStore::new();

        // Populate "tier" with an integer, then remove it entirely.
        let mut numeric = MetadataEntry::new();
        numeric.insert("tier", MetadataValue::Integer(7));
        store.insert(1, numeric);
        store.remove(1);

        // Refill the same key with strings; range queries must still be correct.
        for (id, tier) in [(2, "alpha"), (3, "bravo"), (4, "delta")] {
            let mut entry = MetadataEntry::new();
            entry.insert("tier", MetadataValue::String(tier.to_string()));
            store.insert(id, entry);
        }

        let results = store.filter(&MetadataFilter::gte(
            "tier",
            MetadataValue::String("bravo".to_string()),
        ));
        assert_eq!(results, HashSet::from([3, 4]));
    }
}