1use bit_set::BitSet;
7use diskann_label_filter::attribute::AttributeValue;
8use diskann_label_filter::parser::format::Document;
9use diskann_label_filter::utils::flatten_utils::{
10 flatten_json_pointers_with_config, FlattenConfig,
11};
12use diskann_label_filter::{ASTExpr, CompareOp};
13use rayon::prelude::*;
14use std::any::Any;
15use std::cmp::Ordering;
16use std::collections::BTreeMap;
17use std::collections::HashMap;
18use std::mem::discriminant;
19use std::ops::Bound::{Excluded, Included, Unbounded};
20
/// Error returned by [`OrderedFloat::new`] when asked to wrap a NaN.
struct NotNonNan;

impl std::fmt::Display for NotNonNan {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("NotNonNan")
    }
}
32
33#[derive(Debug, Copy, Clone, PartialEq)]
34struct OrderedFloat(f64);
35
36impl OrderedFloat {
37 pub fn new(v: f64) -> Result<Self, NotNonNan> {
38 if v.is_nan() {
39 Err(NotNonNan)
40 } else {
41 Ok(Self(v))
42 }
43 }
44}
45
46impl Eq for OrderedFloat {}
47impl PartialOrd for OrderedFloat {
48 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
49 Some(self.cmp(other))
50 }
51}
52
53impl Ord for OrderedFloat {
54 fn cmp(&self, other: &Self) -> Ordering {
55 self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal)
58 }
59}
60
/// A per-field index that answers a single `CompareOp` with the set of matching doc ids.
///
/// Accelerators in this file are built and queried inside rayon parallel iterators,
/// hence the `Send + Sync` supertraits.
trait QueryAccelerator: Send + Sync {
    /// Returns the ids of documents whose value for this field satisfies `op`.
    ///
    /// The implementations in this file return an error when `op` is not supported by
    /// the index kind, or when the operand cannot be converted to the field's value type.
    fn eval(&self, op: &CompareOp) -> Result<BitSet, anyhow::Error>;

    /// Returns the ids of every indexed document. In the implementations in this file,
    /// that is every document that has a value for this field.
    fn universe(&self) -> BitSet;

    // Only the unit tests call this (to downcast to a concrete accelerator), so it is
    // dead code in non-test builds.
    #[allow(dead_code)]
    fn as_any(&self) -> &dyn Any;
}
70
71struct InvertedIndexAccelerator {
72 map: HashMap<AttributeValue, BitSet>,
73}
74
75impl QueryAccelerator for InvertedIndexAccelerator {
76 fn as_any(&self) -> &dyn Any {
77 self
78 }
79
80 fn universe(&self) -> BitSet {
81 let mut result = BitSet::new();
82 for (_, bits) in self.map.iter() {
83 result.extend(bits);
84 }
85 result
86 }
87
88 fn eval(&self, op: &CompareOp) -> Result<BitSet, anyhow::Error> {
89 match op {
90 CompareOp::Eq(v) => {
91 let attr_val = AttributeValue::try_from(v)
92 .map_err(|e| anyhow::anyhow!("Failed to convert value for Eq: {e}"))?;
93 Ok(self.map.get(&attr_val).cloned().unwrap_or_default())
94 }
95 CompareOp::Ne(v) => {
96 let attr_val = AttributeValue::try_from(v)
97 .map_err(|e| anyhow::anyhow!("Failed to convert value for Ne: {e}"))?;
98 let mut result = BitSet::new();
99 for (val, bits) in self.map.iter() {
100 if val != &attr_val {
101 result.extend(bits);
102 }
103 }
104 Ok(result)
105 }
106 _ => Err(anyhow::anyhow!(
107 "Only equality comparisons are supported with the inverted index accelerator"
108 )),
109 }
110 }
111}
112
113struct BTreeAccelerator {
114 map: BTreeMap<OrderedFloat, Vec<usize>>,
115}
116
117impl QueryAccelerator for BTreeAccelerator {
118 fn as_any(&self) -> &dyn Any {
119 self
120 }
121
122 fn universe(&self) -> BitSet {
123 let mut result = BitSet::new();
124 for (_, ids) in self.map.iter() {
125 result.extend(ids.iter().cloned());
126 }
127 result
128 }
129
130 fn eval(&self, op: &CompareOp) -> Result<BitSet, anyhow::Error> {
131 match op {
132 CompareOp::Eq(v) => {
133 let fval = v
134 .as_f64()
135 .ok_or_else(|| anyhow::anyhow!("Failed to convert value to f64 for Eq"))?;
136 let fval = OrderedFloat::new(fval)
137 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
138 if let Some(ids) = self.map.get(&fval) {
139 Ok(insert_into_bitset(ids.to_vec()))
140 } else {
141 Ok(BitSet::new())
142 }
143 }
144 CompareOp::Ne(v) => {
145 let fval = v
146 .as_f64()
147 .ok_or_else(|| anyhow::anyhow!("Failed to convert value to f64 for Ne"))?;
148 let fval = OrderedFloat::new(fval)
149 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
150 let mut bitset = BitSet::new();
151 for (val, ids) in self.map.iter() {
152 if val != &fval {
153 bitset.extend(ids.iter().cloned());
154 }
155 }
156 Ok(bitset)
157 }
158 CompareOp::Lt(num) => {
159 let fval = OrderedFloat::new(*num)
160 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
161 let iter = self.map.range((Unbounded, Excluded(fval)));
162 Ok(insert_into_bitset(
163 iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
164 ))
165 }
166 CompareOp::Lte(num) => {
167 let fval = OrderedFloat::new(*num)
168 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
169 let iter = self.map.range((Unbounded, Included(fval)));
170 Ok(insert_into_bitset(
171 iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
172 ))
173 }
174 CompareOp::Gt(num) => {
175 let fval = OrderedFloat::new(*num)
176 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
177 let iter = self.map.range((Excluded(fval), Unbounded));
178 Ok(insert_into_bitset(
179 iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
180 ))
181 }
182 CompareOp::Gte(num) => {
183 let fval = OrderedFloat::new(*num)
184 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
185 let iter = self.map.range((Included(fval), Unbounded));
186 Ok(insert_into_bitset(
187 iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
188 ))
189 }
190 }
191 }
192}
193
194fn prepend_separator(field: &str) -> String {
196 let separator = FlattenConfig::dot_notation().separator;
197 if !field.starts_with(&separator) {
198 format!("{}{}", separator, field)
199 } else {
200 field.to_string()
201 }
202}
203
204fn compute_label_set(expr: &ASTExpr) -> Vec<String> {
206 match expr {
207 ASTExpr::Not(sub) => compute_label_set(sub),
208 ASTExpr::And(subs) => subs.iter().flat_map(compute_label_set).collect(),
209 ASTExpr::Or(subs) => subs.iter().flat_map(compute_label_set).collect(),
210 ASTExpr::Compare { field, .. } => vec![field.clone()],
211 }
212}
213
214fn compute_universe(
216 universe_labels: Vec<String>,
217 query_accelerators: &HashMap<String, Box<dyn QueryAccelerator>>,
218) -> BitSet {
219 let mut universe_iter = universe_labels.iter();
220 let mut universe = if let Some(first_label) = universe_iter.next() {
222 if let Some(accelerator) = query_accelerators.get(first_label) {
223 accelerator.universe()
224 } else {
225 BitSet::new()
226 }
227 } else {
228 BitSet::new()
229 };
230 for label in universe_iter {
231 if let Some(accelerator) = query_accelerators.get(label) {
232 universe = universe.intersection(&accelerator.universe()).collect();
233 }
234 }
235 universe
236}
237
238fn insert_into_bitset(ids: Vec<usize>) -> BitSet {
239 let mut bitset = BitSet::new();
240 bitset.extend(ids);
241 bitset
242}
243
244fn eval_query_using_accelerators(
245 query_expr: &ASTExpr,
246 query_accelerators: &HashMap<String, Box<dyn QueryAccelerator>>,
247) -> Result<BitSet, anyhow::Error> {
248 match query_expr {
249 ASTExpr::And(subs) => {
250 let mut acc: Option<BitSet> = None;
251 for e in subs {
252 let b = eval_query_using_accelerators(e, query_accelerators)?;
253 acc = Some(match acc {
254 None => b,
255 Some(acc_b) => acc_b.intersection(&b).collect(),
256 });
257 }
258 Ok(acc.unwrap_or_else(BitSet::new))
259 }
260 ASTExpr::Or(subs) => {
261 let mut acc: Option<BitSet> = None;
262 for e in subs {
263 let b = eval_query_using_accelerators(e, query_accelerators)?;
264 acc = Some(match acc {
265 None => b,
266 Some(acc_b) => acc_b.union(&b).collect(),
267 });
268 }
269 Ok(acc.unwrap_or_else(BitSet::new))
270 }
271 ASTExpr::Not(sub) => {
272 let universe_labels_raw = compute_label_set(query_expr);
274 let universe_labels: Vec<String> = universe_labels_raw
275 .iter()
276 .map(|f| prepend_separator(f))
277 .collect();
278 let universe = compute_universe(universe_labels, query_accelerators);
279
280 let sub_result = eval_query_using_accelerators(sub, query_accelerators)?;
282
283 Ok(universe.difference(&sub_result).collect())
285 }
286 ASTExpr::Compare { field, op } => {
287 let field = prepend_separator(field);
288 if let Some(accelerator) = query_accelerators.get(&field) {
289 accelerator.eval(op)
290 } else {
291 Ok(BitSet::new())
292 }
293 }
294 }
295}
296
297fn compute_inverted_index_accelerator(
298 key: &str,
299 doc_ids: &[usize],
300 labels: &[HashMap<String, AttributeValue>],
301) -> Result<HashMap<AttributeValue, BitSet>, anyhow::Error> {
302 let mut inverted_index: HashMap<AttributeValue, BitSet> = HashMap::new();
303 for (doc_id, label) in doc_ids.iter().zip(labels.iter()) {
304 if let Some(value) = label.get(key) {
305 inverted_index
306 .entry(value.clone())
307 .or_insert_with(BitSet::new)
308 .insert(*doc_id);
309 }
310 }
311 Ok(inverted_index)
312}
313
314fn compute_btree_accelerator(
315 key: &str,
316 labels: &[HashMap<String, AttributeValue>],
317 doc_ids: &[usize],
318) -> Result<BTreeMap<OrderedFloat, Vec<usize>>, anyhow::Error> {
319 let mut map: BTreeMap<OrderedFloat, Vec<usize>> = BTreeMap::new();
321 for (label, doc_id) in labels.iter().zip(doc_ids.iter().copied()) {
322 if let Some(value) = label.get(key) {
323 if let Some(f64_value) = value.as_float() {
324 let f64_value = OrderedFloat::new(f64_value)
325 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
326 map.entry(f64_value).or_default().push(doc_id);
327 } else if let Some(i64_value) = value.as_integer() {
328 let f = i64_value as f64;
330 if f as i64 != i64_value {
331 return Err(anyhow::anyhow!(
332 "i64 value cannot be exactly represented as f64: {}",
333 i64_value
334 ));
335 }
336 let i64_value = OrderedFloat::new(f)
337 .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
338 map.entry(i64_value).or_default().push(doc_id);
339 } else {
340 return Err(anyhow::anyhow!(
342 "Unsupported attribute value for key: {}",
343 key
344 ));
345 }
346 }
347 }
348 Ok(map)
349}
350
351fn compute_global_label_set(
354 flattened_base_labels: &Vec<HashMap<std::string::String, AttributeValue>>,
355) -> Result<HashMap<String, AttributeValue>, anyhow::Error> {
356 let mut global_label_set = HashMap::new();
357 for labels in flattened_base_labels {
358 for (key, value) in labels {
359 if let Some(existing_value) = global_label_set.get(key) {
360 if discriminant(existing_value) != discriminant(value) {
361 return Err(anyhow::anyhow!("Inconsistent types for key: {}", key));
362 }
363 }
364 global_label_set.insert(key.clone(), value.clone());
365 }
366 }
367 Ok(global_label_set)
368}
369
370fn compute_query_accelerator(
371 key: &str,
372 value: &AttributeValue,
373 doc_ids: &[usize],
374 flattened_base_labels: &[HashMap<String, AttributeValue>],
375) -> Result<Box<dyn QueryAccelerator>, anyhow::Error> {
376 match value {
377 AttributeValue::String(_) | AttributeValue::Bool(_) => {
378 let bitmap = compute_inverted_index_accelerator(key, doc_ids, flattened_base_labels)?;
379 Ok(Box::new(InvertedIndexAccelerator { map: bitmap }))
380 }
381 AttributeValue::Integer(_) | AttributeValue::Real(_) => {
382 let btree = compute_btree_accelerator(key, flattened_base_labels, doc_ids)?;
383 Ok(Box::new(BTreeAccelerator { map: btree }))
384 }
385 AttributeValue::Empty => Err(anyhow::anyhow!("Empty attribute value is not allowed")),
386 }
387}
388
389pub fn compute_query_bitmaps(
390 base_labels: Vec<Document>,
391 query_labels: Vec<(usize, ASTExpr)>,
392) -> Result<Vec<BitSet>, anyhow::Error> {
393 let flattened_base_labels: Vec<Vec<(std::string::String, AttributeValue)>> = base_labels
395 .iter()
396 .map(|base_label| {
397 flatten_json_pointers_with_config(&base_label.label, &FlattenConfig::dot_notation())
398 })
399 .collect();
400
401 let flattened_base_label_hashmaps: Result<Vec<HashMap<String, AttributeValue>>, anyhow::Error> =
402 flattened_base_labels
403 .iter()
404 .map(|labels| {
405 let mut map = HashMap::new();
406 for (key, value) in labels {
407 if let Some(_existing_value) = map.get(key) {
409 return Err(anyhow::anyhow!(
410 "Duplicate keys in the same document: {}",
411 key
412 ));
413 }
414 map.insert(key.clone(), value.clone());
415 }
416 Ok(map)
417 })
418 .collect();
419
420 let flattened_base_label_hashmaps = flattened_base_label_hashmaps?;
421 let base_doc_ids: Vec<usize> = base_labels
422 .iter()
423 .map(|base_label| base_label.doc_id)
424 .collect();
425
426 let global_label_set = compute_global_label_set(&flattened_base_label_hashmaps)?;
429
430 #[allow(clippy::disallowed_methods)]
432 let query_accelerators: HashMap<String, Box<dyn QueryAccelerator>> = global_label_set
433 .par_iter()
434 .map(|(key, value)| {
435 compute_query_accelerator(key, value, &base_doc_ids, &flattened_base_label_hashmaps)
436 .map(|accel| (key.clone(), accel))
437 })
438 .collect::<Result<_, _>>()?;
439
440 #[allow(clippy::disallowed_methods)]
442 let query_bitmaps: Result<Vec<BitSet>, anyhow::Error> = query_labels
443 .par_iter()
444 .map(|(_query_id, query_expr)| {
445 eval_query_using_accelerators(query_expr, &query_accelerators)
446 })
447 .collect();
448
449 let query_bitmaps = query_bitmaps?;
450
451 Ok(query_bitmaps)
452}
453
// Unit tests for the query-bitmap pipeline: label flattening, per-field accelerators,
// universe computation for NOT, and end-to-end evaluation through `compute_query_bitmaps`.
#[cfg(test)]
mod tests {
    use super::*;
    use diskann_label_filter::attribute::AttributeValue;
    use diskann_label_filter::parser::format::Document;
    use diskann_label_filter::{ASTExpr, CompareOp};
    use serde_json::json;
    use std::collections::HashMap;

    // NOT only considers documents that carry the negated field: doc 2 has no "color",
    // so it is excluded even though it is not "red".
    #[test]
    fn test_compute_query_bitmap_not_with_missing_field() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"color": "red"}),
            },
            Document {
                doc_id: 1,
                label: json!({"color": "blue"}),
            },
            Document {
                doc_id: 2,
                label: json!({"shape": "circle"}),
            },
        ];

        let not_query = ASTExpr::Not(Box::new(ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(json!("red")),
        }));
        let queries = vec![(0, not_query)];
        let bitmaps = compute_query_bitmaps(base_labels.clone(), queries).expect("Should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(2));
    }

    // `compute_universe` is empty for an unknown label, and otherwise is the
    // intersection of the accelerators' universes ({1, 2} ∩ {2, 3} = {2}).
    #[test]
    fn test_compute_universe_function() {
        let query_accelerators: HashMap<String, Box<dyn QueryAccelerator>> = HashMap::new();
        let universe_labels = vec!["missing_label".to_string()];
        let result = compute_universe(universe_labels, &query_accelerators);
        assert!(
            result.is_empty(),
            "Universe should be empty if label is missing"
        );

        let mut inv_map = HashMap::new();
        inv_map.insert(
            AttributeValue::String("a".to_string()),
            [1, 2].iter().cloned().collect(),
        );
        let inv_accel = Box::new(InvertedIndexAccelerator { map: inv_map });

        let mut btree_map = BTreeMap::new();
        btree_map.insert(OrderedFloat(1.0), vec![2, 3]);
        let btree_accel = Box::new(BTreeAccelerator { map: btree_map });

        let mut query_accelerators: HashMap<String, Box<dyn QueryAccelerator>> = HashMap::new();
        query_accelerators.insert("foo".to_string(), inv_accel);
        query_accelerators.insert("bar".to_string(), btree_accel);

        let universe_labels = vec!["foo".to_string(), "bar".to_string()];
        let result = compute_universe(universe_labels, &query_accelerators);
        let expected: BitSet = [2].iter().cloned().collect();
        assert_eq!(
            result, expected,
            "Universe should be the intersection of both accelerator universes"
        );
    }

    // `compute_label_set` collects `Compare` fields through `Or` and `Not` nodes.
    #[test]
    fn test_compute_label_set() {
        let expr_or = ASTExpr::Or(vec![
            ASTExpr::Compare {
                field: "foo".to_string(),
                op: CompareOp::Eq(json!(1)),
            },
            ASTExpr::Compare {
                field: "bar".to_string(),
                op: CompareOp::Eq(json!(2)),
            },
        ]);
        let mut result_or = compute_label_set(&expr_or);
        result_or.sort();
        assert_eq!(result_or, vec!["bar".to_string(), "foo".to_string()]);

        let expr_not = ASTExpr::Not(Box::new(ASTExpr::Compare {
            field: "baz".to_string(),
            op: CompareOp::Eq(json!(3)),
        }));
        let result_not = compute_label_set(&expr_not);
        assert_eq!(result_not, vec!["baz".to_string()]);
    }

    // A nested object and a literal dotted key presumably flatten to the same key, which
    // must be rejected.
    #[test]
    fn test_compute_query_bitmap_duplicate_key_in_doc() {
        let base_labels = vec![Document {
            doc_id: 0,
            label: json!({"color": {"color": "red"}, "color.color": "blue"}),
        }];
        let query = ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(json!("red")),
        };
        let result = compute_query_bitmaps(base_labels.clone(), vec![(0, query)]);
        assert!(
            result.is_err(),
            "Should error on duplicate keys in the same document"
        );
    }

    // The same key holding a string in one document and an integer in another is an error.
    #[test]
    fn test_compute_query_bitmap_inconsistent_types() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"foo": "bar"}),
            },
            Document {
                doc_id: 1,
                label: json!({"foo": 123}),
            },
        ];
        let query = ASTExpr::Compare {
            field: "foo".to_string(),
            op: CompareOp::Eq(json!("bar")),
        };
        let result = compute_query_bitmaps(base_labels.clone(), vec![(0, query)]);
        assert!(result.is_err(), "Should error on inconsistent value types");
    }

    // Documents lacking a field never match comparisons on it; other fields still work.
    #[test]
    fn test_compute_query_bitmap_missing_field() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"weight": 30}),
            },
            Document {
                doc_id: 1,
                label: json!({"color": "red", "weight": 10}),
            },
            Document {
                doc_id: 2,
                label: json!({"color": "blue", "weight": 20}),
            },
        ];

        let query_color = ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(json!("red")),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_color)])
            .expect("should succeed");
        assert!(!bitmaps[0].contains(0));
        assert!(bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(2));

        let query_weight = ASTExpr::Compare {
            field: "weight".to_string(),
            op: CompareOp::Gte(20.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_weight)])
            .expect("should succeed");
        assert!(!bitmaps[0].contains(1));
        assert!(bitmaps[0].contains(2));
        assert!(bitmaps[0].contains(0));
    }

    // Nested objects are addressed by dotted paths, with or without a leading separator.
    #[test]
    fn test_compute_query_bitmap_nested_value() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"car": {"color": "red"}}),
            },
            Document {
                doc_id: 1,
                label: json!({"car": {"color": "blue"}}),
            },
        ];

        let query_eq = ASTExpr::Compare {
            field: "car.color".to_string(),
            op: CompareOp::Eq(json!("red")),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_eq)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));

        let query_not = ASTExpr::Not(Box::new(ASTExpr::Compare {
            field: ".car.color".to_string(),
            op: CompareOp::Eq(json!("red")),
        }));
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_not)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(0));
    }

    // Range operators on real-valued fields, including inclusive/exclusive boundaries
    // and a closed range expressed as `And`.
    #[test]
    fn test_compute_query_bitmap_floats() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"score": 1.5}),
            },
            Document {
                doc_id: 1,
                label: json!({"score": 2.0}),
            },
            Document {
                doc_id: 2,
                label: json!({"score": 3.5}),
            },
        ];

        let query_lt = ASTExpr::Compare {
            field: "score".to_string(),
            op: CompareOp::Lt(2.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lt)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(2));

        let query_gt = ASTExpr::Compare {
            field: "score".to_string(),
            op: CompareOp::Gt(2.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gt)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));

        let query_lte = ASTExpr::Compare {
            field: "score".to_string(),
            op: CompareOp::Lte(2.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lte)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(2));

        let query_gte = ASTExpr::Compare {
            field: "score".to_string(),
            op: CompareOp::Gte(2.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gte)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(0));

        let query_range = ASTExpr::And(vec![
            ASTExpr::Compare {
                field: "score".to_string(),
                op: CompareOp::Gte(2.0),
            },
            ASTExpr::Compare {
                field: "score".to_string(),
                op: CompareOp::Lte(3.5),
            },
        ]);
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_range)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(0));
    }

    // Same range checks on integer-valued fields, which are indexed as f64.
    #[test]
    fn test_compute_query_bitmap_ints() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"age": 10}),
            },
            Document {
                doc_id: 1,
                label: json!({"age": 20}),
            },
            Document {
                doc_id: 2,
                label: json!({"age": 30}),
            },
        ];

        let query_lt = ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Lt(20.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lt)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(2));

        let query_gt = ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Gt(20.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gt)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));

        let query_lte = ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Lte(20.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lte)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(2));

        let query_gte = ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Gte(20.0),
        };
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gte)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(0));

        let query_range = ASTExpr::And(vec![
            ASTExpr::Compare {
                field: "age".to_string(),
                op: CompareOp::Gte(20.0),
            },
            ASTExpr::Compare {
                field: "age".to_string(),
                op: CompareOp::Lte(30.0),
            },
        ]);
        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_range)])
            .expect("should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(0));
    }

    // Bitmaps hold each document's `doc_id`, not its position in `base_labels`.
    #[test]
    fn test_compute_query_bitmap_ints_uses_document_ids_in_accelerator() {
        let base_labels = vec![
            Document {
                doc_id: 10,
                label: json!({"age": 10}),
            },
            Document {
                doc_id: 20,
                label: json!({"age": 20}),
            },
            Document {
                doc_id: 30,
                label: json!({"age": 30}),
            },
        ];

        let query_gte = ASTExpr::Compare {
            field: "age".to_string(),
            op: CompareOp::Gte(20.0),
        };
        let bitmaps =
            compute_query_bitmaps(base_labels, vec![(0, query_gte)]).expect("should succeed");

        assert!(bitmaps[0].contains(20));
        assert!(bitmaps[0].contains(30));
        assert!(!bitmaps[0].contains(10));
        assert!(!bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(2));
    }

    // Boolean fields go through the inverted index.
    #[test]
    fn test_compute_query_bitmap_bools() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"flag": true}),
            },
            Document {
                doc_id: 1,
                label: json!({"flag": false}),
            },
        ];

        let query = ASTExpr::Compare {
            field: "flag".to_string(),
            op: CompareOp::Eq(json!(true)),
        };
        let queries = vec![(0, query)];
        let bitmaps = compute_query_bitmaps(base_labels.clone(), queries).expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(1));
    }

    // Several queries in one batch, mixing string and integer fields with And, Or and Not;
    // results come back in query order.
    #[test]
    fn test_compute_query_bitmaps_mixed_labels() {
        let base_labels = vec![
            Document {
                doc_id: 0,
                label: json!({"color": "red", "size": 10}),
            },
            Document {
                doc_id: 1,
                label: json!({"color": "blue", "size": 20}),
            },
            Document {
                doc_id: 2,
                label: json!({"color": "red", "size": 20}),
            },
        ];

        let query1 = ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(serde_json::Value::String("red".to_string())),
        };
        let query2 = ASTExpr::Compare {
            field: "size".to_string(),
            op: CompareOp::Eq(20.into()),
        };
        let query3 = ASTExpr::And(vec![
            ASTExpr::Compare {
                field: "color".to_string(),
                op: CompareOp::Eq(serde_json::Value::String("red".to_string())),
            },
            ASTExpr::Compare {
                field: "size".to_string(),
                op: CompareOp::Eq(20.into()),
            },
        ]);
        let query4 = ASTExpr::Or(vec![
            ASTExpr::Compare {
                field: "color".to_string(),
                op: CompareOp::Eq(serde_json::Value::String("red".to_string())),
            },
            ASTExpr::Compare {
                field: "size".to_string(),
                op: CompareOp::Eq(10.into()),
            },
        ]);

        let queries = vec![(0, query1), (1, query2), (2, query3), (3, query4)];

        let bitmaps = compute_query_bitmaps(base_labels.clone(), queries).expect("should succeed");
        assert!(bitmaps[0].contains(0));
        assert!(bitmaps[0].contains(2));
        assert!(!bitmaps[0].contains(1));
        assert!(bitmaps[1].contains(1));
        assert!(bitmaps[1].contains(2));
        assert!(!bitmaps[1].contains(0));
        assert!(bitmaps[2].contains(2));
        assert!(!bitmaps[2].contains(0));
        assert!(!bitmaps[2].contains(1));
        assert!(bitmaps[3].contains(0));
        assert!(bitmaps[3].contains(2));
        assert!(!bitmaps[3].contains(1));

        let not_query = ASTExpr::Not(Box::new(ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(serde_json::json!("red")),
        }));
        let queries_with_not = vec![(0, not_query)];
        let bitmaps =
            compute_query_bitmaps(base_labels.clone(), queries_with_not).expect("Should succeed");
        assert!(bitmaps[0].contains(1));
        assert!(!bitmaps[0].contains(0));
        assert!(!bitmaps[0].contains(2));
    }

    // `compute_query_accelerator` picks the index kind from the value variant (checked
    // by downcasting through `as_any`) and rejects `Empty`.
    #[test]
    fn test_compute_query_accelerator() {
        let mut doc1 = HashMap::new();
        doc1.insert("foo".to_string(), AttributeValue::String("bar".to_string()));
        doc1.insert("num".to_string(), AttributeValue::Integer(42));
        doc1.insert("real".to_string(), AttributeValue::Real(3.13));
        doc1.insert("flag".to_string(), AttributeValue::Bool(true));
        let mut doc2 = HashMap::new();
        doc2.insert("foo".to_string(), AttributeValue::String("baz".to_string()));
        doc2.insert("num".to_string(), AttributeValue::Integer(7));
        doc2.insert("real".to_string(), AttributeValue::Real(2.71));
        doc2.insert("flag".to_string(), AttributeValue::Bool(false));
        let base = vec![doc1, doc2];
        let doc_ids = vec![10, 42];

        let accel = compute_query_accelerator(
            "foo",
            &AttributeValue::String("bar".to_string()),
            &doc_ids,
            &base,
        )
        .expect("Should succeed for String");
        let accel = accel
            .as_any()
            .downcast_ref::<InvertedIndexAccelerator>()
            .expect("Expected InvertedIndexAccelerator");
        assert!(accel
            .map
            .contains_key(&AttributeValue::String("bar".to_string())));
        assert!(accel
            .map
            .contains_key(&AttributeValue::String("baz".to_string())));
        assert_eq!(
            accel
                .map
                .get(&AttributeValue::String("bar".to_string()))
                .expect("bar key should exist")
                .iter()
                .collect::<Vec<_>>(),
            vec![10]
        );
        assert_eq!(
            accel
                .map
                .get(&AttributeValue::String("baz".to_string()))
                .expect("baz key should exist")
                .iter()
                .collect::<Vec<_>>(),
            vec![42]
        );

        let accel = compute_query_accelerator("flag", &AttributeValue::Bool(true), &doc_ids, &base)
            .expect("Should succeed for Bool");
        let accel = accel
            .as_any()
            .downcast_ref::<InvertedIndexAccelerator>()
            .expect("Expected InvertedIndexAccelerator");
        assert!(accel.map.contains_key(&AttributeValue::Bool(true)));
        assert!(accel.map.contains_key(&AttributeValue::Bool(false)));

        let accel = compute_query_accelerator("num", &AttributeValue::Integer(42), &doc_ids, &base)
            .expect("Should succeed for Integer");
        let accel = accel
            .as_any()
            .downcast_ref::<BTreeAccelerator>()
            .expect("Expected BTreeAccelerator");
        assert!(accel.map.contains_key(&super::OrderedFloat(42.0)));
        assert!(accel.map.contains_key(&super::OrderedFloat(7.0)));

        let accel = compute_query_accelerator("real", &AttributeValue::Real(3.13), &doc_ids, &base)
            .expect("Should succeed for Real");
        let accel = accel
            .as_any()
            .downcast_ref::<BTreeAccelerator>()
            .expect("Expected BTreeAccelerator");
        assert!(accel.map.contains_key(&super::OrderedFloat(3.13)));
        assert!(accel.map.contains_key(&super::OrderedFloat(2.71)));

        let err = compute_query_accelerator("none", &AttributeValue::Empty, &doc_ids, &base);
        assert!(err.is_err());
    }
}