Skip to main content

diskann_tools/utils/
compute_bitmap.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use bit_set::BitSet;
7use diskann_label_filter::attribute::AttributeValue;
8use diskann_label_filter::parser::format::Document;
9use diskann_label_filter::utils::flatten_utils::{
10    flatten_json_pointers_with_config, FlattenConfig,
11};
12use diskann_label_filter::{ASTExpr, CompareOp};
13use rayon::prelude::*;
14use std::any::Any;
15use std::cmp::Ordering;
16use std::collections::BTreeMap;
17use std::collections::HashMap;
18use std::mem::discriminant;
19use std::ops::Bound::{Excluded, Included, Unbounded};
20
/// Error produced by [`OrderedFloat::new`] when handed a NaN.
///
/// A `BTreeMap` keyed by floats needs a total order, and NaN is the one `f64`
/// value that compares unordered with everything; rejecting it at construction
/// time is what lets `OrderedFloat` implement `Ord`.
struct NotNonNan;

impl std::fmt::Display for NotNonNan {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NotNonNan")
    }
}

/// An `f64` known (when built through [`OrderedFloat::new`]) not to be NaN, which
/// makes it totally ordered and therefore usable as a `BTreeMap` key.
#[derive(Debug, Copy, Clone, PartialEq)]
struct OrderedFloat(f64);

impl OrderedFloat {
    /// Wraps `v`, failing with [`NotNonNan`] if `v` is NaN.
    pub fn new(v: f64) -> Result<Self, NotNonNan> {
        if v.is_nan() {
            return Err(NotNonNan);
        }
        Ok(Self(v))
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        // For NaN-free values `partial_cmp` never returns `None`; falling back to
        // `Equal` instead of panicking keeps the generated code branch-light.
        let order = self.0.partial_cmp(&other.0);
        order.unwrap_or(Ordering::Equal)
    }
}
60
61trait QueryAccelerator: Send + Sync {
62    fn eval(&self, op: &CompareOp) -> Result<BitSet, anyhow::Error>;
63
64    fn universe(&self) -> BitSet;
65
66    // method for testing
67    #[allow(dead_code)]
68    fn as_any(&self) -> &dyn Any;
69}
70
71struct InvertedIndexAccelerator {
72    map: HashMap<AttributeValue, BitSet>,
73}
74
75impl QueryAccelerator for InvertedIndexAccelerator {
76    fn as_any(&self) -> &dyn Any {
77        self
78    }
79
80    fn universe(&self) -> BitSet {
81        let mut result = BitSet::new();
82        for (_, bits) in self.map.iter() {
83            result.extend(bits);
84        }
85        result
86    }
87
88    fn eval(&self, op: &CompareOp) -> Result<BitSet, anyhow::Error> {
89        match op {
90            CompareOp::Eq(v) => {
91                let attr_val = AttributeValue::try_from(v)
92                    .map_err(|e| anyhow::anyhow!("Failed to convert value for Eq: {e}"))?;
93                Ok(self.map.get(&attr_val).cloned().unwrap_or_default())
94            }
95            CompareOp::Ne(v) => {
96                let attr_val = AttributeValue::try_from(v)
97                    .map_err(|e| anyhow::anyhow!("Failed to convert value for Ne: {e}"))?;
98                let mut result = BitSet::new();
99                for (val, bits) in self.map.iter() {
100                    if val != &attr_val {
101                        result.extend(bits);
102                    }
103                }
104                Ok(result)
105            }
106            _ => Err(anyhow::anyhow!(
107                "Only equality comparisons are supported with the inverted index accelerator"
108            )),
109        }
110    }
111}
112
113struct BTreeAccelerator {
114    map: BTreeMap<OrderedFloat, Vec<usize>>,
115}
116
117impl QueryAccelerator for BTreeAccelerator {
118    fn as_any(&self) -> &dyn Any {
119        self
120    }
121
122    fn universe(&self) -> BitSet {
123        let mut result = BitSet::new();
124        for (_, ids) in self.map.iter() {
125            result.extend(ids.iter().cloned());
126        }
127        result
128    }
129
130    fn eval(&self, op: &CompareOp) -> Result<BitSet, anyhow::Error> {
131        match op {
132            CompareOp::Eq(v) => {
133                let fval = v
134                    .as_f64()
135                    .ok_or_else(|| anyhow::anyhow!("Failed to convert value to f64 for Eq"))?;
136                let fval = OrderedFloat::new(fval)
137                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
138                if let Some(ids) = self.map.get(&fval) {
139                    Ok(insert_into_bitset(ids.to_vec()))
140                } else {
141                    Ok(BitSet::new())
142                }
143            }
144            CompareOp::Ne(v) => {
145                let fval = v
146                    .as_f64()
147                    .ok_or_else(|| anyhow::anyhow!("Failed to convert value to f64 for Ne"))?;
148                let fval = OrderedFloat::new(fval)
149                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
150                let mut bitset = BitSet::new();
151                for (val, ids) in self.map.iter() {
152                    if val != &fval {
153                        bitset.extend(ids.iter().cloned());
154                    }
155                }
156                Ok(bitset)
157            }
158            CompareOp::Lt(num) => {
159                let fval = OrderedFloat::new(*num)
160                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
161                let iter = self.map.range((Unbounded, Excluded(fval)));
162                Ok(insert_into_bitset(
163                    iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
164                ))
165            }
166            CompareOp::Lte(num) => {
167                let fval = OrderedFloat::new(*num)
168                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
169                let iter = self.map.range((Unbounded, Included(fval)));
170                Ok(insert_into_bitset(
171                    iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
172                ))
173            }
174            CompareOp::Gt(num) => {
175                let fval = OrderedFloat::new(*num)
176                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
177                let iter = self.map.range((Excluded(fval), Unbounded));
178                Ok(insert_into_bitset(
179                    iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
180                ))
181            }
182            CompareOp::Gte(num) => {
183                let fval = OrderedFloat::new(*num)
184                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
185                let iter = self.map.range((Included(fval), Unbounded));
186                Ok(insert_into_bitset(
187                    iter.flat_map(|(_, ids)| ids.iter().cloned()).collect(),
188                ))
189            }
190        }
191    }
192}
193
194// Helper to prepend the separator if not present
195fn prepend_separator(field: &str) -> String {
196    let separator = FlattenConfig::dot_notation().separator;
197    if !field.starts_with(&separator) {
198        format!("{}{}", separator, field)
199    } else {
200        field.to_string()
201    }
202}
203
204// Takes in an expression and returns a vector of all the labels used in the expression (raw field names, no separator prepending)
205fn compute_label_set(expr: &ASTExpr) -> Vec<String> {
206    match expr {
207        ASTExpr::Not(sub) => compute_label_set(sub),
208        ASTExpr::And(subs) => subs.iter().flat_map(compute_label_set).collect(),
209        ASTExpr::Or(subs) => subs.iter().flat_map(compute_label_set).collect(),
210        ASTExpr::Compare { field, .. } => vec![field.clone()],
211    }
212}
213
214// Takes in a set of labels and returns the universe of all possible values for those labels
215fn compute_universe(
216    universe_labels: Vec<String>,
217    query_accelerators: &HashMap<String, Box<dyn QueryAccelerator>>,
218) -> BitSet {
219    let mut universe_iter = universe_labels.iter();
220    // Initialize universe to the first accelerator's universe, then intersect with the rest
221    let mut universe = if let Some(first_label) = universe_iter.next() {
222        if let Some(accelerator) = query_accelerators.get(first_label) {
223            accelerator.universe()
224        } else {
225            BitSet::new()
226        }
227    } else {
228        BitSet::new()
229    };
230    for label in universe_iter {
231        if let Some(accelerator) = query_accelerators.get(label) {
232            universe = universe.intersection(&accelerator.universe()).collect();
233        }
234    }
235    universe
236}
237
238fn insert_into_bitset(ids: Vec<usize>) -> BitSet {
239    let mut bitset = BitSet::new();
240    bitset.extend(ids);
241    bitset
242}
243
244fn eval_query_using_accelerators(
245    query_expr: &ASTExpr,
246    query_accelerators: &HashMap<String, Box<dyn QueryAccelerator>>,
247) -> Result<BitSet, anyhow::Error> {
248    match query_expr {
249        ASTExpr::And(subs) => {
250            let mut acc: Option<BitSet> = None;
251            for e in subs {
252                let b = eval_query_using_accelerators(e, query_accelerators)?;
253                acc = Some(match acc {
254                    None => b,
255                    Some(acc_b) => acc_b.intersection(&b).collect(),
256                });
257            }
258            Ok(acc.unwrap_or_else(BitSet::new))
259        }
260        ASTExpr::Or(subs) => {
261            let mut acc: Option<BitSet> = None;
262            for e in subs {
263                let b = eval_query_using_accelerators(e, query_accelerators)?;
264                acc = Some(match acc {
265                    None => b,
266                    Some(acc_b) => acc_b.union(&b).collect(),
267                });
268            }
269            Ok(acc.unwrap_or_else(BitSet::new))
270        }
271        ASTExpr::Not(sub) => {
272            // compute the universe of all possible values
273            let universe_labels_raw = compute_label_set(query_expr);
274            let universe_labels: Vec<String> = universe_labels_raw
275                .iter()
276                .map(|f| prepend_separator(f))
277                .collect();
278            let universe = compute_universe(universe_labels, query_accelerators);
279
280            // Evaluate the sub-expression
281            let sub_result = eval_query_using_accelerators(sub, query_accelerators)?;
282
283            // Return the difference between the sub-expression result and the universe
284            Ok(universe.difference(&sub_result).collect())
285        }
286        ASTExpr::Compare { field, op } => {
287            let field = prepend_separator(field);
288            if let Some(accelerator) = query_accelerators.get(&field) {
289                accelerator.eval(op)
290            } else {
291                Ok(BitSet::new())
292            }
293        }
294    }
295}
296
297fn compute_inverted_index_accelerator(
298    key: &str,
299    doc_ids: &[usize],
300    labels: &[HashMap<String, AttributeValue>],
301) -> Result<HashMap<AttributeValue, BitSet>, anyhow::Error> {
302    let mut inverted_index: HashMap<AttributeValue, BitSet> = HashMap::new();
303    for (doc_id, label) in doc_ids.iter().zip(labels.iter()) {
304        if let Some(value) = label.get(key) {
305            inverted_index
306                .entry(value.clone())
307                .or_insert_with(BitSet::new)
308                .insert(*doc_id);
309        }
310    }
311    Ok(inverted_index)
312}
313
314fn compute_btree_accelerator(
315    key: &str,
316    labels: &[HashMap<String, AttributeValue>],
317    doc_ids: &[usize],
318) -> Result<BTreeMap<OrderedFloat, Vec<usize>>, anyhow::Error> {
319    // Implementation for computing BTree accelerator
320    let mut map: BTreeMap<OrderedFloat, Vec<usize>> = BTreeMap::new();
321    for (label, doc_id) in labels.iter().zip(doc_ids.iter().copied()) {
322        if let Some(value) = label.get(key) {
323            if let Some(f64_value) = value.as_float() {
324                let f64_value = OrderedFloat::new(f64_value)
325                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
326                map.entry(f64_value).or_default().push(doc_id);
327            } else if let Some(i64_value) = value.as_integer() {
328                // convert from i64 to f64
329                let f = i64_value as f64;
330                if f as i64 != i64_value {
331                    return Err(anyhow::anyhow!(
332                        "i64 value cannot be exactly represented as f64: {}",
333                        i64_value
334                    ));
335                }
336                let i64_value = OrderedFloat::new(f)
337                    .map_err(|e| anyhow::anyhow!("Failed to create OrderedFloat: {e}"))?;
338                map.entry(i64_value).or_default().push(doc_id);
339            } else {
340                // Error for other attribute values
341                return Err(anyhow::anyhow!(
342                    "Unsupported attribute value for key: {}",
343                    key
344                ));
345            }
346        }
347    }
348    Ok(map)
349}
350
351// Compute a global label set across all documents with a representative element
352// Make sure that each global label only maps to the same type of AttributeValue, and throw an error otherwise
353fn compute_global_label_set(
354    flattened_base_labels: &Vec<HashMap<std::string::String, AttributeValue>>,
355) -> Result<HashMap<String, AttributeValue>, anyhow::Error> {
356    let mut global_label_set = HashMap::new();
357    for labels in flattened_base_labels {
358        for (key, value) in labels {
359            if let Some(existing_value) = global_label_set.get(key) {
360                if discriminant(existing_value) != discriminant(value) {
361                    return Err(anyhow::anyhow!("Inconsistent types for key: {}", key));
362                }
363            }
364            global_label_set.insert(key.clone(), value.clone());
365        }
366    }
367    Ok(global_label_set)
368}
369
370fn compute_query_accelerator(
371    key: &str,
372    value: &AttributeValue,
373    doc_ids: &[usize],
374    flattened_base_labels: &[HashMap<String, AttributeValue>],
375) -> Result<Box<dyn QueryAccelerator>, anyhow::Error> {
376    match value {
377        AttributeValue::String(_) | AttributeValue::Bool(_) => {
378            let bitmap = compute_inverted_index_accelerator(key, doc_ids, flattened_base_labels)?;
379            Ok(Box::new(InvertedIndexAccelerator { map: bitmap }))
380        }
381        AttributeValue::Integer(_) | AttributeValue::Real(_) => {
382            let btree = compute_btree_accelerator(key, flattened_base_labels, doc_ids)?;
383            Ok(Box::new(BTreeAccelerator { map: btree }))
384        }
385        AttributeValue::Empty => Err(anyhow::anyhow!("Empty attribute value is not allowed")),
386    }
387}
388
389pub fn compute_query_bitmaps(
390    base_labels: Vec<Document>,
391    query_labels: Vec<(usize, ASTExpr)>,
392) -> Result<Vec<BitSet>, anyhow::Error> {
393    // Flatten base labels so that nested structures are converted to a flat list of key-value pairs
394    let flattened_base_labels: Vec<Vec<(std::string::String, AttributeValue)>> = base_labels
395        .iter()
396        .map(|base_label| {
397            flatten_json_pointers_with_config(&base_label.label, &FlattenConfig::dot_notation())
398        })
399        .collect();
400
401    let flattened_base_label_hashmaps: Result<Vec<HashMap<String, AttributeValue>>, anyhow::Error> =
402        flattened_base_labels
403            .iter()
404            .map(|labels| {
405                let mut map = HashMap::new();
406                for (key, value) in labels {
407                    // a base label may not have two values for the same key
408                    if let Some(_existing_value) = map.get(key) {
409                        return Err(anyhow::anyhow!(
410                            "Duplicate keys in the same document: {}",
411                            key
412                        ));
413                    }
414                    map.insert(key.clone(), value.clone());
415                }
416                Ok(map)
417            })
418            .collect();
419
420    let flattened_base_label_hashmaps = flattened_base_label_hashmaps?;
421    let base_doc_ids: Vec<usize> = base_labels
422        .iter()
423        .map(|base_label| base_label.doc_id)
424        .collect();
425
426    // compute the global set of labels ahead of time so that we can compute
427    // each accelerator in parallel
428    let global_label_set = compute_global_label_set(&flattened_base_label_hashmaps)?;
429
430    // Compute the accelerators for each label in the global set
431    #[allow(clippy::disallowed_methods)]
432    let query_accelerators: HashMap<String, Box<dyn QueryAccelerator>> = global_label_set
433        .par_iter()
434        .map(|(key, value)| {
435            compute_query_accelerator(key, value, &base_doc_ids, &flattened_base_label_hashmaps)
436                .map(|accel| (key.clone(), accel))
437        })
438        .collect::<Result<_, _>>()?;
439
440    // Evaluate each query using the precomputed accelerators
441    #[allow(clippy::disallowed_methods)]
442    let query_bitmaps: Result<Vec<BitSet>, anyhow::Error> = query_labels
443        .par_iter()
444        .map(|(_query_id, query_expr)| {
445            eval_query_using_accelerators(query_expr, &query_accelerators)
446        })
447        .collect();
448
449    let query_bitmaps = query_bitmaps?;
450
451    Ok(query_bitmaps)
452}
453
454#[cfg(test)]
455mod tests {
456    use super::*;
457    use diskann_label_filter::attribute::AttributeValue;
458    use diskann_label_filter::parser::format::Document;
459    use diskann_label_filter::{ASTExpr, CompareOp};
460    use serde_json::json;
461    use std::collections::HashMap;
462
463    #[test]
464    fn test_compute_query_bitmap_not_with_missing_field() {
465        // Three documents: two with "color", one without
466        let base_labels = vec![
467            Document {
468                doc_id: 0,
469                label: json!({"color": "red"}),
470            },
471            Document {
472                doc_id: 1,
473                label: json!({"color": "blue"}),
474            },
475            Document {
476                doc_id: 2,
477                label: json!({"shape": "circle"}), // no color field
478            },
479        ];
480
481        // Query: NOT color == "red"
482        let not_query = ASTExpr::Not(Box::new(ASTExpr::Compare {
483            field: "color".to_string(),
484            op: CompareOp::Eq(json!("red")),
485        }));
486        let queries = vec![(0, not_query)];
487        let bitmaps = compute_query_bitmaps(base_labels.clone(), queries).expect("Should succeed");
488        // Only doc 1 should match (has color and is not red)
489        assert!(bitmaps[0].contains(1));
490        assert!(!bitmaps[0].contains(0));
491        // Doc 2 does not have color, so should not be included in the NOT universe
492        assert!(!bitmaps[0].contains(2));
493    }
494
495    #[test]
496    fn test_compute_universe_function() {
497        // Sub-test 1: universe label not in query_accelerators, should return empty
498        let query_accelerators: HashMap<String, Box<dyn QueryAccelerator>> = HashMap::new();
499        let universe_labels = vec!["missing_label".to_string()];
500        let result = compute_universe(universe_labels, &query_accelerators);
501        assert!(
502            result.is_empty(),
503            "Universe should be empty if label is missing"
504        );
505
506        // Sub-test 2: both accelerator types, non-empty intersection
507        // InvertedIndexAccelerator for 'foo' with docs 1, 2
508        let mut inv_map = HashMap::new();
509        inv_map.insert(
510            AttributeValue::String("a".to_string()),
511            [1, 2].iter().cloned().collect(),
512        );
513        let inv_accel = Box::new(InvertedIndexAccelerator { map: inv_map });
514
515        // BTreeAccelerator for 'bar' with docs 2, 3
516        let mut btree_map = BTreeMap::new();
517        btree_map.insert(OrderedFloat(1.0), vec![2, 3]);
518        let btree_accel = Box::new(BTreeAccelerator { map: btree_map });
519
520        let mut query_accelerators: HashMap<String, Box<dyn QueryAccelerator>> = HashMap::new();
521        query_accelerators.insert("foo".to_string(), inv_accel);
522        query_accelerators.insert("bar".to_string(), btree_accel);
523
524        // The intersection of {1,2} and {2,3} is {2}
525        let universe_labels = vec!["foo".to_string(), "bar".to_string()];
526        let result = compute_universe(universe_labels, &query_accelerators);
527        let expected: BitSet = [2].iter().cloned().collect();
528        assert_eq!(
529            result, expected,
530            "Universe should be the intersection of both accelerator universes"
531        );
532    }
533
534    #[test]
535    fn test_compute_label_set() {
536        // OR expression: foo == 1 OR bar == 2
537        let expr_or = ASTExpr::Or(vec![
538            ASTExpr::Compare {
539                field: "foo".to_string(),
540                op: CompareOp::Eq(json!(1)),
541            },
542            ASTExpr::Compare {
543                field: "bar".to_string(),
544                op: CompareOp::Eq(json!(2)),
545            },
546        ]);
547        let mut result_or = compute_label_set(&expr_or);
548        result_or.sort();
549        assert_eq!(result_or, vec!["bar".to_string(), "foo".to_string()]);
550
551        // NOT expression: NOT (baz == 3)
552        let expr_not = ASTExpr::Not(Box::new(ASTExpr::Compare {
553            field: "baz".to_string(),
554            op: CompareOp::Eq(json!(3)),
555        }));
556        let result_not = compute_label_set(&expr_not);
557        assert_eq!(result_not, vec!["baz".to_string()]);
558    }
559
560    #[test]
561    fn test_compute_query_bitmap_duplicate_key_in_doc() {
562        // serde_json does not allow duplicate keys, but we can simulate this by flattening a document with a nested object that, when flattened, produces duplicate keys
563        // For this test, we will directly call compute_query_bitmaps with a document that, after flattening, would have duplicate keys
564        // This is a synthetic test: we create a document with a nested object and a top-level key that would flatten to the same key
565        let base_labels = vec![Document {
566            doc_id: 0,
567            label: json!({"color": {"color": "red"}, "color.color": "blue"}),
568        }];
569        // Query: color == "red"
570        let query = ASTExpr::Compare {
571            field: "color".to_string(),
572            op: CompareOp::Eq(json!("red")),
573        };
574        let result = compute_query_bitmaps(base_labels.clone(), vec![(0, query)]);
575        assert!(
576            result.is_err(),
577            "Should error on duplicate keys in the same document"
578        );
579    }
580
581    #[test]
582    fn test_compute_query_bitmap_inconsistent_types() {
583        // Two documents, same key, different value types
584        let base_labels = vec![
585            Document {
586                doc_id: 0,
587                label: json!({"foo": "bar"}),
588            },
589            Document {
590                doc_id: 1,
591                label: json!({"foo": 123}),
592            },
593        ];
594        // Query: foo == "bar"
595        let query = ASTExpr::Compare {
596            field: "foo".to_string(),
597            op: CompareOp::Eq(json!("bar")),
598        };
599        let result = compute_query_bitmaps(base_labels.clone(), vec![(0, query)]);
600        assert!(result.is_err(), "Should error on inconsistent value types");
601    }
602
603    #[test]
604    fn test_compute_query_bitmap_missing_field() {
605        // Three documents, one missing the 'color' field
606        let base_labels = vec![
607            Document {
608                doc_id: 0,
609                label: json!({"weight": 30}), // no color field
610            },
611            Document {
612                doc_id: 1,
613                label: json!({"color": "red", "weight": 10}),
614            },
615            Document {
616                doc_id: 2,
617                label: json!({"color": "blue", "weight": 20}),
618            },
619        ];
620
621        // Query: color == "red"
622        let query_color = ASTExpr::Compare {
623            field: "color".to_string(),
624            op: CompareOp::Eq(json!("red")),
625        };
626        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_color)])
627            .expect("should succeed");
628        assert!(!bitmaps[0].contains(0));
629        assert!(bitmaps[0].contains(1));
630        assert!(!bitmaps[0].contains(2));
631
632        // Query: weight >= 20
633        let query_weight = ASTExpr::Compare {
634            field: "weight".to_string(),
635            op: CompareOp::Gte(20.0),
636        };
637        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_weight)])
638            .expect("should succeed");
639        assert!(!bitmaps[0].contains(1));
640        assert!(bitmaps[0].contains(2));
641        assert!(bitmaps[0].contains(0));
642    }
643
644    #[test]
645    fn test_compute_query_bitmap_nested_value() {
646        // Two documents with nested car.color
647        let base_labels = vec![
648            Document {
649                doc_id: 0,
650                label: json!({"car": {"color": "red"}}),
651            },
652            Document {
653                doc_id: 1,
654                label: json!({"car": {"color": "blue"}}),
655            },
656        ];
657
658        // Query: car.color == "red"
659        let query_eq = ASTExpr::Compare {
660            field: "car.color".to_string(),
661            op: CompareOp::Eq(json!("red")),
662        };
663        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_eq)])
664            .expect("should succeed");
665        assert!(bitmaps[0].contains(0));
666        assert!(!bitmaps[0].contains(1));
667
668        // Query: NOT .car.color == "red" (should match blue)
669        let query_not = ASTExpr::Not(Box::new(ASTExpr::Compare {
670            field: ".car.color".to_string(),
671            op: CompareOp::Eq(json!("red")),
672        }));
673        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_not)])
674            .expect("should succeed");
675        assert!(bitmaps[0].contains(1));
676        assert!(!bitmaps[0].contains(0));
677    }
678
679    #[test]
680    fn test_compute_query_bitmap_floats() {
681        let base_labels = vec![
682            Document {
683                doc_id: 0,
684                label: json!({"score": 1.5}),
685            },
686            Document {
687                doc_id: 1,
688                label: json!({"score": 2.0}),
689            },
690            Document {
691                doc_id: 2,
692                label: json!({"score": 3.5}),
693            },
694        ];
695
696        // score < 2.0
697        let query_lt = ASTExpr::Compare {
698            field: "score".to_string(),
699            op: CompareOp::Lt(2.0),
700        };
701        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lt)])
702            .expect("should succeed");
703        assert!(bitmaps[0].contains(0));
704        assert!(!bitmaps[0].contains(1));
705        assert!(!bitmaps[0].contains(2));
706
707        // score > 2.0
708        let query_gt = ASTExpr::Compare {
709            field: "score".to_string(),
710            op: CompareOp::Gt(2.0),
711        };
712        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gt)])
713            .expect("should succeed");
714        assert!(bitmaps[0].contains(2));
715        assert!(!bitmaps[0].contains(0));
716        assert!(!bitmaps[0].contains(1));
717
718        // score <= 2.0
719        let query_lte = ASTExpr::Compare {
720            field: "score".to_string(),
721            op: CompareOp::Lte(2.0),
722        };
723        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lte)])
724            .expect("should succeed");
725        assert!(bitmaps[0].contains(0));
726        assert!(bitmaps[0].contains(1));
727        assert!(!bitmaps[0].contains(2));
728
729        // score >= 2.0
730        let query_gte = ASTExpr::Compare {
731            field: "score".to_string(),
732            op: CompareOp::Gte(2.0),
733        };
734        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gte)])
735            .expect("should succeed");
736        assert!(bitmaps[0].contains(1));
737        assert!(bitmaps[0].contains(2));
738        assert!(!bitmaps[0].contains(0));
739
740        // score >= 2.0 AND score <= 3.5 (range: [2.0, 3.5])
741        let query_range = ASTExpr::And(vec![
742            ASTExpr::Compare {
743                field: "score".to_string(),
744                op: CompareOp::Gte(2.0),
745            },
746            ASTExpr::Compare {
747                field: "score".to_string(),
748                op: CompareOp::Lte(3.5),
749            },
750        ]);
751        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_range)])
752            .expect("should succeed");
753        // Should match doc 1 (2.0) and doc 2 (3.5)
754        assert!(bitmaps[0].contains(1));
755        assert!(bitmaps[0].contains(2));
756        assert!(!bitmaps[0].contains(0));
757    }
758
759    #[test]
760    fn test_compute_query_bitmap_ints() {
761        let base_labels = vec![
762            Document {
763                doc_id: 0,
764                label: json!({"age": 10}),
765            },
766            Document {
767                doc_id: 1,
768                label: json!({"age": 20}),
769            },
770            Document {
771                doc_id: 2,
772                label: json!({"age": 30}),
773            },
774        ];
775
776        // age < 20
777        let query_lt = ASTExpr::Compare {
778            field: "age".to_string(),
779            op: CompareOp::Lt(20.0),
780        };
781        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lt)])
782            .expect("should succeed");
783        assert!(bitmaps[0].contains(0));
784        assert!(!bitmaps[0].contains(1));
785        assert!(!bitmaps[0].contains(2));
786
787        // age > 20
788        let query_gt = ASTExpr::Compare {
789            field: "age".to_string(),
790            op: CompareOp::Gt(20.0),
791        };
792        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gt)])
793            .expect("should succeed");
794        assert!(bitmaps[0].contains(2));
795        assert!(!bitmaps[0].contains(0));
796        assert!(!bitmaps[0].contains(1));
797
798        // age <= 20
799        let query_lte = ASTExpr::Compare {
800            field: "age".to_string(),
801            op: CompareOp::Lte(20.0),
802        };
803        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_lte)])
804            .expect("should succeed");
805        assert!(bitmaps[0].contains(0));
806        assert!(bitmaps[0].contains(1));
807        assert!(!bitmaps[0].contains(2));
808
809        // age >= 20
810        let query_gte = ASTExpr::Compare {
811            field: "age".to_string(),
812            op: CompareOp::Gte(20.0),
813        };
814        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_gte)])
815            .expect("should succeed");
816        assert!(bitmaps[0].contains(1));
817        assert!(bitmaps[0].contains(2));
818        assert!(!bitmaps[0].contains(0));
819
820        // age >= 20 AND age <= 30 (range: [20, 30])
821        let query_range = ASTExpr::And(vec![
822            ASTExpr::Compare {
823                field: "age".to_string(),
824                op: CompareOp::Gte(20.0),
825            },
826            ASTExpr::Compare {
827                field: "age".to_string(),
828                op: CompareOp::Lte(30.0),
829            },
830        ]);
831        let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_range)])
832            .expect("should succeed");
833        // Should match doc 1 (20) and doc 2 (30)
834        assert!(bitmaps[0].contains(1));
835        assert!(bitmaps[0].contains(2));
836        assert!(!bitmaps[0].contains(0));
837    }
838
#[test]
fn test_compute_query_bitmap_ints_uses_document_ids_in_accelerator() {
    // Non-contiguous document ids whose `age` label equals the id itself, so
    // any confusion between positional index (0..3) and document id is visible.
    let base_labels = [10, 20, 30]
        .iter()
        .map(|&age| Document {
            doc_id: age,
            label: json!({ "age": age }),
        })
        .collect::<Vec<_>>();

    let at_least_twenty = ASTExpr::Compare {
        field: "age".to_string(),
        op: CompareOp::Gte(20.0),
    };
    let bitmaps = compute_query_bitmaps(base_labels, vec![(0, at_least_twenty)])
        .expect("should succeed");
    let matched = &bitmaps[0];

    // The bitmap must be keyed by document id: ids 20 and 30 are selected...
    for doc in [20, 30] {
        assert!(matched.contains(doc));
    }
    // ...while id 10 and the positional indices 0, 1, 2 are absent.
    for doc in [10, 0, 1, 2] {
        assert!(!matched.contains(doc));
    }
}
870
#[test]
fn test_compute_query_bitmap_bools() {
    // Two documents with a boolean field: doc 0 is `true`, doc 1 is `false`.
    let base_labels = vec![
        Document {
            doc_id: 0,
            label: json!({"flag": true}),
        },
        Document {
            doc_id: 1,
            label: json!({"flag": false}),
        },
    ];

    // Query: flag == true -> only doc 0 should match.
    let query_true = ASTExpr::Compare {
        field: "flag".to_string(),
        op: CompareOp::Eq(json!(true)),
    };
    let bitmaps = compute_query_bitmaps(base_labels.clone(), vec![(0, query_true)])
        .expect("should succeed");
    assert!(bitmaps[0].contains(0));
    assert!(!bitmaps[0].contains(1));

    // Query: flag == false -> only doc 1 should match. This exercises the
    // `false` key of the boolean index, which the `true` query alone never hits.
    let query_false = ASTExpr::Compare {
        field: "flag".to_string(),
        op: CompareOp::Eq(json!(false)),
    };
    // Last use of the labels, so they are moved rather than cloned.
    let bitmaps =
        compute_query_bitmaps(base_labels, vec![(0, query_false)]).expect("should succeed");
    assert!(bitmaps[0].contains(1));
    assert!(!bitmaps[0].contains(0));
}
896
#[test]
fn test_compute_query_bitmaps_mixed_labels() {
    // Documents mixing a string attribute (`color`) with a numeric one (`size`):
    //   doc 0: red / 10, doc 1: blue / 20, doc 2: red / 20.
    let base_labels = vec![
        Document {
            doc_id: 0,
            label: json!({"color": "red", "size": 10}),
        },
        Document {
            doc_id: 1,
            label: json!({"color": "blue", "size": 20}),
        },
        Document {
            doc_id: 2,
            label: json!({"color": "red", "size": 20}),
        },
    ];

    // Query: color == "red"
    let query1 = ASTExpr::Compare {
        field: "color".to_string(),
        op: CompareOp::Eq(json!("red")),
    };
    // Query: size == 20
    let query2 = ASTExpr::Compare {
        field: "size".to_string(),
        op: CompareOp::Eq(json!(20)),
    };
    // Query: color == "red" AND size == 20
    let query3 = ASTExpr::And(vec![
        ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(json!("red")),
        },
        ASTExpr::Compare {
            field: "size".to_string(),
            op: CompareOp::Eq(json!(20)),
        },
    ]);
    // Query: color == "red" OR size == 10
    let query4 = ASTExpr::Or(vec![
        ASTExpr::Compare {
            field: "color".to_string(),
            op: CompareOp::Eq(json!("red")),
        },
        ASTExpr::Compare {
            field: "size".to_string(),
            op: CompareOp::Eq(json!(10)),
        },
    ]);

    // Query ids 0..=3 index the returned bitmaps.
    let queries = vec![(0, query1), (1, query2), (2, query3), (3, query4)];

    // Cloned because the labels are reused for the NOT query below.
    let bitmaps = compute_query_bitmaps(base_labels.clone(), queries).expect("should succeed");
    // color == "red" => doc 0, 2
    assert!(bitmaps[0].contains(0));
    assert!(bitmaps[0].contains(2));
    assert!(!bitmaps[0].contains(1));
    // size == 20 => doc 1, 2
    assert!(bitmaps[1].contains(1));
    assert!(bitmaps[1].contains(2));
    assert!(!bitmaps[1].contains(0));
    // color == "red" AND size == 20 => doc 2
    assert!(bitmaps[2].contains(2));
    assert!(!bitmaps[2].contains(0));
    assert!(!bitmaps[2].contains(1));
    // color == "red" OR size == 10 => doc 0, 2
    assert!(bitmaps[3].contains(0));
    assert!(bitmaps[3].contains(2));
    assert!(!bitmaps[3].contains(1));

    // Query: NOT color == "red"
    let not_query = ASTExpr::Not(Box::new(ASTExpr::Compare {
        field: "color".to_string(),
        op: CompareOp::Eq(json!("red")),
    }));
    let queries_with_not = vec![(0, not_query)];
    // Final use of the labels: move them instead of cloning.
    let bitmaps = compute_query_bitmaps(base_labels, queries_with_not).expect("Should succeed");
    // The result should be a bitmap with doc 1 (not red)
    assert!(bitmaps[0].contains(1));
    assert!(!bitmaps[0].contains(0));
    assert!(!bitmaps[0].contains(2));
}
980
#[test]
fn test_compute_query_accelerator() {
    // Prepare base labels: two attribute maps covering String, Integer, Real
    // and Bool values. Position i in `base` belongs to document `doc_ids[i]`.
    let mut doc1 = HashMap::new();
    doc1.insert("foo".to_string(), AttributeValue::String("bar".to_string()));
    doc1.insert("num".to_string(), AttributeValue::Integer(42));
    doc1.insert("real".to_string(), AttributeValue::Real(3.13));
    doc1.insert("flag".to_string(), AttributeValue::Bool(true));
    let mut doc2 = HashMap::new();
    doc2.insert("foo".to_string(), AttributeValue::String("baz".to_string()));
    doc2.insert("num".to_string(), AttributeValue::Integer(7));
    doc2.insert("real".to_string(), AttributeValue::Real(2.71));
    doc2.insert("flag".to_string(), AttributeValue::Bool(false));
    let base = vec![doc1, doc2];
    // Deliberately not 0/1, so positional indices can be told apart from ids.
    let doc_ids = vec![10, 42];

    // String -> inverted index keyed by value, mapping to document ids.
    let accel = compute_query_accelerator(
        "foo",
        &AttributeValue::String("bar".to_string()),
        &doc_ids,
        &base,
    )
    .expect("Should succeed for String");
    let accel = accel
        .as_any()
        .downcast_ref::<InvertedIndexAccelerator>()
        .expect("Expected InvertedIndexAccelerator");
    assert!(accel
        .map
        .contains_key(&AttributeValue::String("bar".to_string())));
    assert!(accel
        .map
        .contains_key(&AttributeValue::String("baz".to_string())));
    assert_eq!(
        accel
            .map
            .get(&AttributeValue::String("bar".to_string()))
            .expect("bar key should exist")
            .iter()
            .collect::<Vec<_>>(),
        vec![10]
    );
    assert_eq!(
        accel
            .map
            .get(&AttributeValue::String("baz".to_string()))
            .expect("baz key should exist")
            .iter()
            .collect::<Vec<_>>(),
        vec![42]
    );

    // Bool -> also an inverted index.
    let accel = compute_query_accelerator("flag", &AttributeValue::Bool(true), &doc_ids, &base)
        .expect("Should succeed for Bool");
    let accel = accel
        .as_any()
        .downcast_ref::<InvertedIndexAccelerator>()
        .expect("Expected InvertedIndexAccelerator");
    assert!(accel.map.contains_key(&AttributeValue::Bool(true)));
    assert!(accel.map.contains_key(&AttributeValue::Bool(false)));
    // Like the String case, each key must resolve to document ids (10, 42),
    // not to positional indices (0, 1).
    assert_eq!(
        accel
            .map
            .get(&AttributeValue::Bool(true))
            .expect("true key should exist")
            .iter()
            .collect::<Vec<_>>(),
        vec![10]
    );
    assert_eq!(
        accel
            .map
            .get(&AttributeValue::Bool(false))
            .expect("false key should exist")
            .iter()
            .collect::<Vec<_>>(),
        vec![42]
    );

    // Integer -> B-tree keyed by the value converted to an `OrderedFloat`.
    let accel = compute_query_accelerator("num", &AttributeValue::Integer(42), &doc_ids, &base)
        .expect("Should succeed for Integer");
    let accel = accel
        .as_any()
        .downcast_ref::<BTreeAccelerator>()
        .expect("Expected BTreeAccelerator");
    assert!(accel.map.contains_key(&super::OrderedFloat(42.0)));
    assert!(accel.map.contains_key(&super::OrderedFloat(7.0)));

    // Real -> B-tree keyed by `OrderedFloat`.
    let accel = compute_query_accelerator("real", &AttributeValue::Real(3.13), &doc_ids, &base)
        .expect("Should succeed for Real");
    let accel = accel
        .as_any()
        .downcast_ref::<BTreeAccelerator>()
        .expect("Expected BTreeAccelerator");
    assert!(accel.map.contains_key(&super::OrderedFloat(3.13)));
    assert!(accel.map.contains_key(&super::OrderedFloat(2.71)));

    // Empty -> no accelerator can be built; an error is returned.
    let err = compute_query_accelerator("none", &AttributeValue::Empty, &doc_ids, &base);
    assert!(err.is_err());
}
1068}