a_tree/
atree.rs

1use crate::{
2    ast::*,
3    error::ATreeError,
4    evaluation::EvaluationResult,
5    events::{AttributeDefinition, AttributeTable, Event, EventBuilder},
6    parser,
7    predicates::Predicate,
8    strings::StringTable,
9};
10use slab::Slab;
11use std::{collections::HashMap, fmt::Debug, hash::Hash};
12
/// Key of an entry inside the `nodes` slab of the tree.
type NodeId = usize;
/// Identifier of an expression as returned by `OptimizedNode::id()`; it is the key used to
/// find an already inserted (sub-)expression so that it can be shared.
type ExpressionId = u64;
15
/// The A-Tree data structure as described by the paper
///
/// See the [module documentation] for more details.
///
/// [module documentation]: index.html
#[derive(Clone, Debug)]
pub struct ATree<T> {
    // Storage for every node of the tree; a `NodeId` is a key of this slab.
    nodes: Slab<Entry<T>>,
    // String table handed to the parser on insertion and to the event builder.
    strings: StringTable,
    // Attribute table built from the definitions given to `ATree::new()`.
    attributes: AttributeTable,
    // Ids of the nodes created as top-level expressions; used to compute `max_level`.
    roots: Vec<NodeId>,
    // Highest level among `roots` (1 when there are none); sizes the search queues.
    max_level: usize,
    // Ids of the leaf nodes that are evaluated first during a search.
    predicates: Vec<NodeId>,
    // Maps an expression id to the node that represents it, so expressions are shared.
    expression_to_node: HashMap<ExpressionId, NodeId>,
    // Maps a subscription id to the node it was attached to; used by `delete()`.
    nodes_by_ids: HashMap<T, NodeId>,
}
32
33impl<T: Eq + Hash + Clone + Debug> ATree<T> {
34    const DEFAULT_PREDICATES: usize = 1000;
35    const DEFAULT_NODES: usize = 2000;
36    const DEFAULT_ROOTS: usize = 50;
37
38    /// Create a new [`ATree`] with the attributes that can be used by the inserted arbitrary
39    /// boolean expressions along with their types.
40    ///
41    /// # Examples
42    ///
43    /// ```rust
44    /// use a_tree::{ATree, AttributeDefinition};
45    ///
46    /// let definitions = [
47    ///     AttributeDefinition::boolean("private"),
48    ///     AttributeDefinition::integer("exchange_id")
49    /// ];
50    /// let result = ATree::<u64>::new(&definitions);
51    /// assert!(result.is_ok());
52    /// ```
53    ///
54    /// Duplicate attributes are not allowed and the [`ATree::new()`] function will return an error if there are some:
55    ///
56    /// ```rust
57    /// use a_tree::{ATree, AttributeDefinition};
58    ///
59    /// let definitions = [
60    ///     AttributeDefinition::boolean("private"),
61    ///     AttributeDefinition::boolean("private"),
62    /// ];
63    /// let result = ATree::<u64>::new(&definitions);
64    /// assert!(result.is_err());
65    /// ```
66    pub fn new(definitions: &[AttributeDefinition]) -> Result<Self, ATreeError> {
67        let attributes = AttributeTable::new(definitions).map_err(ATreeError::Event)?;
68        let strings = StringTable::new();
69        Ok(Self {
70            attributes,
71            strings,
72            max_level: 1,
73            roots: Vec::with_capacity(Self::DEFAULT_ROOTS),
74            predicates: Vec::with_capacity(Self::DEFAULT_PREDICATES),
75            nodes: Slab::with_capacity(Self::DEFAULT_NODES),
76            expression_to_node: HashMap::new(),
77            nodes_by_ids: HashMap::new(),
78        })
79    }
80
81    /// Insert an arbitrary boolean expression inside the [`ATree`].
82    ///
83    /// # Examples
84    ///
85    /// ```rust
86    /// use a_tree::{ATree, AttributeDefinition};
87    ///
88    /// let definitions = [
89    ///     AttributeDefinition::boolean("private"),
90    ///     AttributeDefinition::integer("exchange_id")
91    /// ];
92    /// let mut atree = ATree::new(&definitions).unwrap();
93    /// assert!(atree.insert(&1u64, "exchange_id = 5").is_ok());
94    /// assert!(atree.insert(&2u64, "private").is_ok());
95    /// ```
96    #[inline]
97    pub fn insert<'a>(
98        &'a mut self,
99        subscription_id: &T,
100        expression: &'a str,
101    ) -> Result<(), ATreeError<'a>> {
102        let ast = parser::parse(expression, &self.attributes, &mut self.strings)
103            .map_err(ATreeError::ParseError)?;
104        let ast = ast.optimize();
105        self.insert_root(subscription_id, ast);
106        Ok(())
107    }
108
109    fn insert_root(&mut self, subscription_id: &T, root: OptimizedNode) {
110        let expression_id = root.id();
111        if let Some(node_id) = self.expression_to_node.get(&expression_id) {
112            add_subscription_id(
113                subscription_id,
114                *node_id,
115                &mut self.nodes,
116                &mut self.nodes_by_ids,
117            );
118            increment_use_count(*node_id, &mut self.nodes);
119            return;
120        }
121
122        let is_and = matches!(&root, OptimizedNode::And(_, _));
123        let cost = root.cost();
124        let node_id = match root {
125            OptimizedNode::And(left, right) | OptimizedNode::Or(left, right) => {
126                let left_id = self.insert_node(*left);
127                let right_id = self.insert_node(*right);
128                let left_entry = &self.nodes[left_id];
129                let right_entry = &self.nodes[right_id];
130                let rnode = ATreeNode::RNode(RNode {
131                    level: 1 + std::cmp::max(left_entry.node.level(), right_entry.node.level()),
132                    operator: if is_and { Operator::And } else { Operator::Or },
133                    children: if left_entry.cost > right_entry.cost {
134                        vec![right_id, left_id]
135                    } else {
136                        vec![left_id, right_id]
137                    },
138                });
139                let node_id = insert_node(
140                    &mut self.expression_to_node,
141                    &mut self.nodes,
142                    &expression_id,
143                    rnode,
144                    Some(subscription_id.clone()),
145                    cost,
146                );
147                if is_and {
148                    choose_access_child(
149                        left_id,
150                        right_id,
151                        node_id,
152                        &mut self.nodes,
153                        &mut self.predicates,
154                    );
155                } else {
156                    add_parent(&mut self.nodes[left_id], node_id);
157                    add_parent(&mut self.nodes[right_id], node_id);
158                    add_predicate(left_id, &self.nodes, &mut self.predicates);
159                    add_predicate(right_id, &self.nodes, &mut self.predicates);
160                }
161                node_id
162            }
163            OptimizedNode::Value(value) => {
164                let lnode = ATreeNode::lnode(&value);
165                let node_id = insert_node(
166                    &mut self.expression_to_node,
167                    &mut self.nodes,
168                    &expression_id,
169                    lnode,
170                    Some(subscription_id.clone()),
171                    cost,
172                );
173                self.predicates.push(node_id);
174                node_id
175            }
176        };
177        self.nodes_by_ids.insert(subscription_id.clone(), node_id);
178        self.roots.push(node_id);
179        self.max_level = get_max_level(&self.roots, &self.nodes);
180    }
181
182    fn insert_node(&mut self, node: OptimizedNode) -> NodeId {
183        let expression_id = node.id();
184        if let Some(node_id) = self.expression_to_node.get(&expression_id) {
185            change_rnode_to_inode(*node_id, &mut self.nodes);
186            increment_use_count(*node_id, &mut self.nodes);
187            return *node_id;
188        }
189
190        let is_and = matches!(node, OptimizedNode::And(_, _));
191        let cost = node.cost();
192        match node {
193            OptimizedNode::And(left, right) | OptimizedNode::Or(left, right) => {
194                let left_id = self.insert_node(*left);
195                let right_id = self.insert_node(*right);
196                let left_entry = &self.nodes[left_id];
197                let right_entry = &self.nodes[right_id];
198                let inode = INode {
199                    parents: vec![],
200                    level: 1 + std::cmp::max(left_entry.node.level(), right_entry.node.level()),
201                    operator: if is_and { Operator::And } else { Operator::Or },
202                    children: if left_entry.cost > right_entry.cost {
203                        vec![right_id, left_id]
204                    } else {
205                        vec![left_id, right_id]
206                    },
207                };
208                let inode = ATreeNode::INode(inode);
209                let node_id = insert_node(
210                    &mut self.expression_to_node,
211                    &mut self.nodes,
212                    &expression_id,
213                    inode,
214                    None,
215                    cost,
216                );
217                if is_and {
218                    choose_access_child(
219                        left_id,
220                        right_id,
221                        node_id,
222                        &mut self.nodes,
223                        &mut self.predicates,
224                    );
225                } else {
226                    add_parent(&mut self.nodes[left_id], node_id);
227                    add_parent(&mut self.nodes[right_id], node_id);
228                    add_predicate(left_id, &self.nodes, &mut self.predicates);
229                    add_predicate(right_id, &self.nodes, &mut self.predicates);
230                }
231                node_id
232            }
233            OptimizedNode::Value(node) => {
234                let lnode = ATreeNode::lnode(&node);
235                insert_node(
236                    &mut self.expression_to_node,
237                    &mut self.nodes,
238                    &expression_id,
239                    lnode,
240                    None,
241                    cost,
242                )
243            }
244        }
245    }
246
247    /// Create a new [`EventBuilder`] to be able to generate an [`Event`] that will be usable for
248    /// finding the matching arbitrary boolean expressions inside the [`ATree`] via the
249    /// [`ATree::search()`] function.
250    #[inline]
251    pub fn make_event(&self) -> EventBuilder {
252        EventBuilder::new(&self.attributes, &self.strings)
253    }
254
255    /// Search the [`ATree`] for arbitrary boolean expressions that match the [`Event`].
256    pub fn search(&self, event: &Event) -> Result<Report<T>, ATreeError> {
257        let mut results = EvaluationResult::new(self.nodes.len());
258        let mut matches = Vec::with_capacity(50);
259
260        // Since the predicates will already be evaluated and their parents will be put into the
261        // queues, then there is no need to keep a queue for them.
262        let mut queues = vec![Vec::with_capacity(50); self.max_level - 1];
263        process_predicates(
264            &self.predicates,
265            &self.nodes,
266            event,
267            &mut matches,
268            &mut results,
269            &mut queues,
270        );
271
272        for level in 0..queues.len() {
273            while let Some((node_id, node)) = queues[level].pop() {
274                if results.is_evaluated(node_id) {
275                    continue;
276                }
277
278                let result = evaluate_node(
279                    node_id,
280                    event,
281                    node,
282                    &self.nodes,
283                    &mut results,
284                    &mut matches,
285                );
286                add_matches(result, node, &mut matches);
287
288                if node.is_root() {
289                    continue;
290                }
291
292                for parent_id in node.parents() {
293                    let entry = &self.nodes[*parent_id];
294                    let is_evaluated = results.is_evaluated(*parent_id);
295                    if !is_evaluated
296                        && matches!(entry.operator(), Operator::And)
297                        && !result.unwrap_or(true)
298                    {
299                        results.set_result(*parent_id, Some(false));
300                        continue;
301                    }
302
303                    if !is_evaluated {
304                        queues[entry.level() - 2].push((*parent_id, entry));
305                    }
306                }
307            }
308        }
309
310        Ok(Report::new(matches))
311    }
312
313    #[inline]
314    /// Delete the specified expression
315    pub fn delete(&mut self, subscription_id: &T) {
316        if let Some(node_id) = self.nodes_by_ids.get(subscription_id) {
317            self.delete_node(subscription_id, *node_id);
318        }
319    }
320
321    #[inline]
322    fn delete_node(&mut self, subscription_id: &T, node_id: NodeId) {
323        let children = decrement_use_count(
324            subscription_id,
325            node_id,
326            &mut self.nodes,
327            &mut self.expression_to_node,
328            &mut self.roots,
329            &mut self.predicates,
330            &mut self.nodes_by_ids,
331            &mut self.max_level,
332        );
333
334        if let Some(children) = children {
335            for child in children {
336                self.delete_node(subscription_id, child);
337            }
338        }
339    }
340
341    /// Export the [`ATree`] to the Graphviz format.
342    pub fn to_graphviz(&self) -> String {
343        const DEFAULT_CAPACITY: usize = 100_000;
344        let mut builder = String::with_capacity(DEFAULT_CAPACITY);
345        builder.push_str("digraph {\n");
346        builder.push_str("rankdir = TB;\n");
347        builder.push_str(r#"node [shape = "record"];"#);
348        builder.push('\n');
349        let mut relations = Vec::with_capacity(DEFAULT_CAPACITY);
350        let mut levels = vec![vec![]; self.max_level];
351        for (id, entry) in &self.nodes {
352            match &entry.node {
353                ATreeNode::LNode(LNode {
354                    parents, predicate, ..
355                }) => {
356                    let node = format!(
357                        r#"node_{id} [label = "{{{id} | level: {} | {predicate} | subscriptions: {:?} | l-node}}", style = "rounded"];"#,
358                        entry.level(),
359                        entry.subscription_ids
360                    );
361                    levels[entry.level() - 1].push((id, node));
362
363                    for parent_id in parents {
364                        relations.push(format!("node_{id} -> node_{parent_id};"));
365                    }
366                }
367                ATreeNode::INode(INode {
368                    children,
369                    parents,
370                    operator,
371                    ..
372                }) => {
373                    let node = format!(
374                        r#"node_{id} [label = "{{{id} | level: {} | {operator:#?} | subscriptions: {:?} | i-node}}"];"#,
375                        entry.level(),
376                        entry.subscription_ids
377                    );
378                    levels[entry.level() - 1].push((id, node));
379
380                    for parent_id in parents {
381                        relations.push(format!("node_{id} -> node_{parent_id};"));
382                    }
383
384                    for child_id in children {
385                        relations.push(format!("node_{id} -> node_{child_id};"));
386                    }
387                }
388                ATreeNode::RNode(RNode {
389                    children, operator, ..
390                }) => {
391                    let node = format!(
392                        r#"node_{id} [label = "{{{id} | level: {} | {operator:#?} | subscriptions: {:?} | r-node}}"];"#,
393                        entry.level(),
394                        entry.subscription_ids
395                    );
396                    levels[entry.level() - 1].push((id, node));
397
398                    for child_id in children {
399                        relations.push(format!("node_{id} -> node_{child_id};"));
400                    }
401                }
402            }
403        }
404
405        builder.push_str("\n// nodes\n");
406        for entries in levels.into_iter().rev() {
407            for (_, node) in entries.iter() {
408                builder.push_str(node);
409                builder.push('\n');
410            }
411
412            builder.push_str("{rank = same; ");
413            for (id, _) in entries {
414                builder.push_str(&format!("node_{id}; "));
415            }
416            builder.push_str("};\n");
417        }
418
419        builder.push_str("\n// edges\n");
420        for relation in relations {
421            builder.push_str(&relation);
422            builder.push('\n');
423        }
424
425        builder.push('}');
426        builder
427    }
428}
429
430#[inline]
431#[allow(clippy::too_many_arguments)]
432fn decrement_use_count<T: Eq + Hash>(
433    subscription_id: &T,
434    node_id: NodeId,
435    nodes: &mut Slab<Entry<T>>,
436    expression_to_node: &mut HashMap<ExpressionId, NodeId>,
437    roots: &mut Vec<NodeId>,
438    predicates: &mut Vec<NodeId>,
439    nodes_by_ids: &mut HashMap<T, NodeId>,
440    max_level: &mut usize,
441) -> Option<Vec<NodeId>> {
442    let node = &mut nodes[node_id];
443    node.use_count -= 1;
444    let mut children = None;
445    node.subscription_ids.retain(|x| *x != *subscription_id);
446    nodes_by_ids.remove(subscription_id);
447    if node.use_count == 0 {
448        if !node.is_leaf() {
449            children = Some(node.children().to_vec());
450        }
451        let expression_id = node.id;
452        roots.retain(|x| *x != node_id);
453        predicates.retain(|x| *x != node_id);
454        *max_level = get_max_level(roots, nodes);
455        expression_to_node.remove(&expression_id);
456        nodes.remove(node_id);
457    }
458
459    children
460}
461
462#[inline]
463fn insert_node<T>(
464    expression_to_node: &mut HashMap<ExpressionId, NodeId>,
465    nodes: &mut Slab<Entry<T>>,
466    expression_id: &ExpressionId,
467    node: ATreeNode,
468    subscription_id: Option<T>,
469    cost: u64,
470) -> NodeId {
471    let entry = Entry::new(*expression_id, node, subscription_id, cost);
472    let node_id = nodes.insert(entry);
473    if expression_to_node.insert(*expression_id, node_id).is_some() {
474        unreachable!("{expression_id} is already present; this is a bug");
475    }
476    node_id
477}
478
479#[inline]
480fn add_parent<T>(entry: &mut Entry<T>, node_id: NodeId) {
481    entry.node.add_parent(node_id);
482}
483
484#[inline]
485fn add_subscription_id<T: Eq + Hash + Clone>(
486    subscription_id: &T,
487    node_id: NodeId,
488    nodes: &mut Slab<Entry<T>>,
489    nodes_by_ids: &mut HashMap<T, NodeId>,
490) {
491    nodes[node_id]
492        .subscription_ids
493        .push(subscription_id.clone());
494    nodes_by_ids.insert(subscription_id.clone(), node_id);
495}
496
497#[inline]
498fn increment_use_count<T>(node_id: NodeId, nodes: &mut Slab<Entry<T>>) {
499    nodes[node_id].use_count += 1;
500}
501
502#[inline]
503fn get_max_level<T>(roots: &[NodeId], nodes: &Slab<Entry<T>>) -> usize {
504    roots
505        .iter()
506        .map(|root_id| nodes[*root_id].level())
507        .max()
508        .unwrap_or(1)
509}
510
511#[inline]
512fn change_rnode_to_inode<T>(node_id: NodeId, nodes: &mut Slab<Entry<T>>) {
513    let entry = &mut nodes[node_id];
514    if let ATreeNode::RNode(RNode {
515        children,
516        level,
517        operator,
518    }) = &entry.node
519    {
520        let inode = ATreeNode::INode(INode {
521            parents: vec![],
522            children: children.to_vec(),
523            level: *level,
524            operator: operator.clone(),
525        });
526        entry.node = inode;
527    }
528}
529
530#[inline]
531fn choose_access_child<T>(
532    left_id: NodeId,
533    right_id: NodeId,
534    parent_id: NodeId,
535    nodes: &mut Slab<Entry<T>>,
536    predicates: &mut Vec<NodeId>,
537) {
538    let left_entry = &nodes[left_id];
539    let right_entry = &nodes[right_id];
540    let accessor_id = if left_entry.cost < right_entry.cost {
541        left_id
542    } else {
543        right_id
544    };
545    add_parent(&mut nodes[accessor_id], parent_id);
546    add_predicate(accessor_id, nodes, predicates);
547}
548
549#[inline]
550fn add_predicate<T>(node_id: NodeId, nodes: &Slab<Entry<T>>, predicates: &mut Vec<NodeId>) {
551    let entry = &nodes[node_id];
552    if entry.is_leaf() && !predicates.contains(&node_id) {
553        predicates.push(node_id);
554    }
555}
556
557#[inline]
558fn process_predicates<'a, T>(
559    predicates: &[NodeId],
560    nodes: &'a Slab<Entry<T>>,
561    event: &Event,
562    matches: &mut Vec<&'a T>,
563    results: &mut EvaluationResult,
564    queues: &mut [Vec<(NodeId, &'a Entry<T>)>],
565) {
566    for predicate_id in predicates {
567        let node = &nodes[*predicate_id];
568        // The evaluation is delayed as much as possible; if the predicate has no
569        // subscribers and no parents, there is no point in evaluating eagerly and
570        // it should only be evaluated if there is a need for it.
571        let delay_evaluation = node.subscription_ids.is_empty() && node.parents().is_empty();
572        if delay_evaluation || results.is_evaluated(*predicate_id) {
573            continue;
574        }
575
576        let result = node.evaluate(event);
577        results.set_result(*predicate_id, result);
578        add_matches(result, node, matches);
579
580        node.parents()
581            .iter()
582            .map(|parent_id| (*parent_id, &nodes[*parent_id]))
583            .for_each(|(parent_id, parent)| {
584                if matches!(parent.operator(), Operator::And) && !result.unwrap_or(true) {
585                    results.set_result(parent_id, Some(false));
586                } else {
587                    queues[parent.level() - 2].push((parent_id, parent));
588                }
589            })
590    }
591}
592
593#[inline]
594fn evaluate_node<'a, T>(
595    node_id: NodeId,
596    event: &Event,
597    node: &'a Entry<T>,
598    nodes: &'a Slab<Entry<T>>,
599    results: &mut EvaluationResult,
600    matches: &mut Vec<&'a T>,
601) -> Option<bool> {
602    let operator = node.operator();
603    let result = match operator {
604        Operator::And => evaluate_and(node.children(), event, nodes, results, matches),
605        Operator::Or => evaluate_or(node.children(), event, nodes, results, matches),
606    };
607    results.set_result(node_id, result);
608    result
609}
610
611#[inline]
612fn evaluate_and<'a, T>(
613    children: &[NodeId],
614    event: &Event,
615    nodes: &'a Slab<Entry<T>>,
616    results: &mut EvaluationResult,
617    matches: &mut Vec<&'a T>,
618) -> Option<bool> {
619    let mut acc = Some(true);
620    for child_id in children {
621        let result = lazy_evaluate(*child_id, event, nodes, results, matches);
622        match (acc, result) {
623            (Some(false), _) => {
624                acc = Some(false);
625                break;
626            }
627            (_, Some(false)) => {
628                acc = Some(false);
629                break;
630            }
631            (Some(a), Some(b)) => {
632                acc = Some(a && b);
633            }
634            (_, _) => {
635                acc = None;
636            }
637        }
638    }
639    acc
640}
641
642#[inline]
643fn evaluate_or<'a, T>(
644    children: &[NodeId],
645    event: &Event,
646    nodes: &'a Slab<Entry<T>>,
647    results: &mut EvaluationResult,
648    matches: &mut Vec<&'a T>,
649) -> Option<bool> {
650    let mut acc = Some(false);
651    for child_id in children {
652        let result = lazy_evaluate(*child_id, event, nodes, results, matches);
653        match (acc, result) {
654            (Some(true), _) => {
655                acc = Some(true);
656                break;
657            }
658            (_, Some(true)) => {
659                acc = Some(true);
660                break;
661            }
662            (Some(a), Some(b)) => {
663                acc = Some(a || b);
664            }
665            (_, _) => {
666                acc = None;
667            }
668        }
669    }
670
671    acc
672}
673
674#[inline]
675fn lazy_evaluate<'a, T>(
676    node_id: NodeId,
677    event: &Event,
678    nodes: &'a Slab<Entry<T>>,
679    results: &mut EvaluationResult,
680    matches: &mut Vec<&'a T>,
681) -> Option<bool> {
682    if results.is_evaluated(node_id) {
683        return results.get_result(node_id);
684    }
685    let node = &nodes[node_id];
686    let result = if node.is_leaf() {
687        let result = node.evaluate(event);
688        results.set_result(node_id, result);
689        result
690    } else {
691        evaluate_node(node_id, event, node, nodes, results, matches)
692    };
693    add_matches(result, node, matches);
694    result
695}
696
697#[inline]
698fn add_matches<'a, T>(result: Option<bool>, node: &'a Entry<T>, matches: &mut Vec<&'a T>) {
699    if !node.subscription_ids.is_empty() {
700        if let Some(true) = result {
701            for subscription_id in &node.subscription_ids {
702                matches.push(subscription_id);
703            }
704        }
705    }
706}
707
/// A node of the tree together with the bookkeeping attached to it.
#[derive(Clone, Debug)]
struct Entry<T> {
    // Id of the expression this node represents (key of `expression_to_node`).
    id: ExpressionId,
    // Subscriptions reported when this node evaluates to true.
    subscription_ids: Vec<T>,
    // The node itself (leaf, inner or root).
    node: ATreeNode,
    // Number of users of this node; starts at 1 and the node is removed when it reaches 0.
    use_count: usize,
    // Cost of the expression; used to order children and to choose the access child.
    cost: u64,
}
716
717impl<T> Entry<T> {
718    fn new(id: ExpressionId, node: ATreeNode, subscription_id: Option<T>, cost: u64) -> Self {
719        Self {
720            id,
721            node,
722            use_count: 1,
723            subscription_ids: subscription_id
724                .map_or_else(Vec::new, |subscription_id| vec![subscription_id]),
725            cost,
726        }
727    }
728
729    #[inline]
730    const fn is_leaf(&self) -> bool {
731        matches!(self.node, ATreeNode::LNode(_))
732    }
733
734    #[inline]
735    const fn is_root(&self) -> bool {
736        matches!(self.node, ATreeNode::RNode(_))
737    }
738
739    #[inline]
740    const fn level(&self) -> usize {
741        self.node.level()
742    }
743
744    #[inline]
745    fn evaluate(&self, event: &Event) -> Option<bool> {
746        self.node.evaluate(event)
747    }
748
749    #[inline]
750    fn operator(&self) -> Operator {
751        self.node.operator()
752    }
753
754    #[inline]
755    fn children(&self) -> &[NodeId] {
756        self.node.children()
757    }
758
759    #[inline]
760    fn parents(&self) -> &[NodeId] {
761        self.node.parents()
762    }
763}
764
/// The three kinds of nodes stored in the tree.
#[derive(Clone, Debug)]
#[allow(clippy::enum_variant_names)]
enum ATreeNode {
    // Leaf node: holds a predicate and may have parents.
    LNode(LNode),
    // Inner node: an operator over children that may also have parents.
    INode(INode),
    // Root node: an operator over children that never has parents.
    RNode(RNode),
}
772
773impl ATreeNode {
774    #[inline]
775    fn lnode(predicate: &Predicate) -> Self {
776        Self::LNode(LNode {
777            level: 1,
778            parents: vec![],
779            predicate: predicate.clone(),
780        })
781    }
782
783    #[inline]
784    const fn level(&self) -> usize {
785        match self {
786            Self::RNode(node) => node.level,
787            Self::LNode(node) => node.level,
788            Self::INode(node) => node.level,
789        }
790    }
791
792    #[inline]
793    fn evaluate(&self, event: &Event) -> Option<bool> {
794        match self {
795            Self::LNode(node) => node.predicate.evaluate(event),
796            node => unreachable!("evaluating {node:?} which is not a predicate; this is a bug."),
797        }
798    }
799
800    #[inline]
801    fn operator(&self) -> Operator {
802        match self {
803            Self::LNode(_) => {
804                unreachable!("trying to get the operator of leaf node; this is a bug");
805            }
806            Self::RNode(RNode { operator, .. }) | Self::INode(INode { operator, .. }) => {
807                operator.clone()
808            }
809        }
810    }
811
812    #[inline]
813    fn children(&self) -> &[NodeId] {
814        match self {
815            Self::INode(INode { children, .. }) | Self::RNode(RNode { children, .. }) => children,
816            Self::LNode(_) => unreachable!("cannot get children for l-node; this is a bug"),
817        }
818    }
819
820    #[inline]
821    fn parents(&self) -> &[NodeId] {
822        match self {
823            Self::INode(INode { parents, .. }) | Self::LNode(LNode { parents, .. }) => parents,
824            Self::RNode(_) => unreachable!("cannot get parents for r-node; this is a bug"),
825        }
826    }
827
828    #[inline]
829    fn add_parent(&mut self, parent_id: NodeId) {
830        match self {
831            ATreeNode::INode(node) => {
832                node.parents.push(parent_id);
833            }
834            ATreeNode::LNode(node) => {
835                node.parents.push(parent_id);
836            }
837            ATreeNode::RNode(node) => {
838                unreachable!("trying to insert parents to r-node {node:?} which cannot have any parents; this is a bug");
839            }
840        }
841    }
842}
843
/// Leaf node: a single predicate.
#[derive(Clone, Debug)]
struct LNode {
    // Ids of the nodes to notify once this predicate has been evaluated.
    parents: Vec<NodeId>,
    // Level of the node; `ATreeNode::lnode()` always creates leaves at level 1.
    level: usize,
    // The predicate evaluated against an event.
    predicate: Predicate,
}
850
/// Inner node: an operator applied to children, itself used by other nodes.
#[derive(Clone, Debug)]
struct INode {
    // Ids of the nodes to notify once this node has been evaluated.
    parents: Vec<NodeId>,
    // Ids of the operands, cheapest first.
    children: Vec<NodeId>,
    // One more than the highest level among the children.
    level: usize,
    // Operator combining the children.
    operator: Operator,
}
858
/// Root node: an operator applied to children, without any parent.
#[derive(Clone, Debug)]
struct RNode {
    // Ids of the operands, cheapest first.
    children: Vec<NodeId>,
    // One more than the highest level among the children.
    level: usize,
    // Operator combining the children.
    operator: Operator,
}
865
/// Outcome of a call to [`ATree::search()`]: the subscriptions that matched the event
#[derive(Debug)]
pub struct Report<'a, T> {
    matches: Vec<&'a T>,
}

impl<'a, T> Report<'a, T> {
    /// Wrap the matched subscriptions into a report
    const fn new(matches: Vec<&'a T>) -> Self {
        Report { matches }
    }

    /// The subscriptions that matched, in the order they were found
    #[inline]
    pub fn matches(&self) -> &[&'a T] {
        self.matches.as_slice()
    }
}
883
884#[cfg(test)]
885mod tests {
886    use super::*;
887
888    const AN_INVALID_BOOLEAN_EXPRESSION: &str = "invalid in (1, 2, 3 and";
889    const AN_EXPRESSION: &str = "exchange_id = 1";
890    const A_NOT_EXPRESSION: &str = "not private";
891    const AN_EXPRESSION_WITH_AND_OPERATORS: &str =
892        r#"exchange_id = 1 and deals one of ["deal-1", "deal-2"]"#;
893    const AN_EXPRESSION_WITH_OR_OPERATORS: &str =
894        r#"exchange_id = 1 or deals one of ["deal-1", "deal-2"]"#;
895    const A_COMPLEX_EXPRESSION: &str = r#"exchange_id = 1 and not private and deal_ids one of ["deal-1", "deal-2"] and segment_ids one of [1, 2, 3] and country = 'CA' and city in ['QC'] or country = 'US' and city in ['AZ']"#;
896    const ANOTHER_COMPLEX_EXPRESSION: &str = r#"exchange_id = 1 and not private and deal_ids one of ["deal-1", "deal-2"] and segment_ids one of [1, 2, 3] and country in ['FR', 'GB']"#;
897
898    fn is_sync_and_send<T: Send + Sync>() {}
899
900    #[test]
901    fn support_sync_and_send_traits() {
902        is_sync_and_send::<ATree<u64>>();
903    }
904
905    #[test]
906    fn can_build_an_atree() {
907        let definitions = [
908            AttributeDefinition::boolean("private"),
909            AttributeDefinition::string_list("deals"),
910            AttributeDefinition::integer("exchange_id"),
911            AttributeDefinition::float("bidfloor"),
912            AttributeDefinition::string("country"),
913            AttributeDefinition::integer_list("segment_ids"),
914        ];
915
916        let result = ATree::<u64>::new(&definitions);
917
918        assert!(result.is_ok());
919    }
920
921    #[test]
922    fn return_an_error_on_duplicate_definitions() {
923        let definitions = [
924            AttributeDefinition::boolean("private"),
925            AttributeDefinition::string("country"),
926            AttributeDefinition::string_list("deals"),
927            AttributeDefinition::integer("exchange_id"),
928            AttributeDefinition::float("bidfloor"),
929            AttributeDefinition::integer("country"),
930            AttributeDefinition::integer_list("segment_ids"),
931        ];
932
933        let result = ATree::<u64>::new(&definitions);
934
935        assert!(result.is_err());
936    }
937
938    #[test]
939    fn return_an_error_on_invalid_boolean_expression() {
940        let definitions = [
941            AttributeDefinition::boolean("private"),
942            AttributeDefinition::string("country"),
943            AttributeDefinition::string_list("deals"),
944            AttributeDefinition::integer("exchange_id"),
945            AttributeDefinition::integer_list("segment_ids"),
946        ];
947        let mut atree = ATree::new(&definitions).unwrap();
948
949        let result = atree.insert(&1u64, AN_INVALID_BOOLEAN_EXPRESSION);
950
951        assert!(result.is_err());
952    }
953
954    #[test]
955    fn return_an_error_on_empty_boolean_expression() {
956        let definitions = [
957            AttributeDefinition::boolean("private"),
958            AttributeDefinition::string("country"),
959            AttributeDefinition::string_list("deals"),
960            AttributeDefinition::integer("exchange_id"),
961            AttributeDefinition::integer_list("segment_ids"),
962        ];
963        let mut atree = ATree::new(&definitions).unwrap();
964
965        let result = atree.insert(&1u64, "");
966
967        assert!(result.is_err());
968    }
969
970    #[test]
971    fn can_insert_a_simple_expression() {
972        let definitions = [
973            AttributeDefinition::boolean("private"),
974            AttributeDefinition::string("country"),
975            AttributeDefinition::string_list("deals"),
976            AttributeDefinition::integer("exchange_id"),
977            AttributeDefinition::integer_list("segment_ids"),
978        ];
979        let mut atree = ATree::new(&definitions).unwrap();
980
981        let result = atree.insert(&1u64, AN_EXPRESSION);
982
983        assert!(result.is_ok());
984    }
985
986    #[test]
987    fn can_insert_an_expression_that_refers_to_a_rnode() {
988        let definitions = [
989            AttributeDefinition::boolean("private"),
990            AttributeDefinition::integer("exchange_id"),
991            AttributeDefinition::string_list("deal_ids"),
992        ];
993        let an_expression = "private or exchange_id = 1";
994        let another_expression =
995            r#"private or exchange_id = 1 or deal_ids one of ["deal-1", "deal-2"]"#;
996        let mut atree = ATree::new(&definitions).unwrap();
997        assert!(atree.insert(&1u64, an_expression).is_ok());
998        assert!(atree.insert(&2u64, another_expression).is_ok());
999    }
1000
1001    #[test]
1002    fn can_insert_the_same_expression_multiple_times() {
1003        let definitions = [
1004            AttributeDefinition::boolean("private"),
1005            AttributeDefinition::string("country"),
1006            AttributeDefinition::string_list("deals"),
1007            AttributeDefinition::integer("exchange_id"),
1008            AttributeDefinition::integer_list("segment_ids"),
1009        ];
1010        let mut atree = ATree::new(&definitions).unwrap();
1011
1012        assert!(atree.insert(&1u64, AN_EXPRESSION).is_ok());
1013        assert!(atree.insert(&2u64, AN_EXPRESSION).is_ok());
1014    }
1015
1016    #[test]
1017    fn can_insert_a_negative_expression() {
1018        let definitions = [
1019            AttributeDefinition::boolean("private"),
1020            AttributeDefinition::string("country"),
1021            AttributeDefinition::string_list("deals"),
1022            AttributeDefinition::integer("exchange_id"),
1023            AttributeDefinition::integer_list("segment_ids"),
1024        ];
1025        let mut atree = ATree::new(&definitions).unwrap();
1026
1027        let result = atree.insert(&1u64, A_NOT_EXPRESSION);
1028
1029        assert!(result.is_ok());
1030    }
1031
1032    #[test]
1033    fn can_insert_an_expression_with_and_operators() {
1034        let definitions = [
1035            AttributeDefinition::boolean("private"),
1036            AttributeDefinition::string("country"),
1037            AttributeDefinition::string_list("deals"),
1038            AttributeDefinition::integer("exchange_id"),
1039            AttributeDefinition::integer_list("segment_ids"),
1040        ];
1041        let mut atree = ATree::new(&definitions).unwrap();
1042
1043        let result = atree.insert(&1u64, AN_EXPRESSION_WITH_AND_OPERATORS);
1044
1045        assert!(result.is_ok());
1046    }
1047
1048    #[test]
1049    fn can_insert_an_expression_with_or_operators() {
1050        let definitions = [
1051            AttributeDefinition::boolean("private"),
1052            AttributeDefinition::string("country"),
1053            AttributeDefinition::string_list("deals"),
1054            AttributeDefinition::integer("exchange_id"),
1055            AttributeDefinition::integer_list("segment_ids"),
1056        ];
1057        let mut atree = ATree::new(&definitions).unwrap();
1058
1059        let result = atree.insert(&1u64, AN_EXPRESSION_WITH_OR_OPERATORS);
1060
1061        assert!(result.is_ok());
1062    }
1063
1064    #[test]
1065    fn can_insert_an_expression_with_mixed_operators() {
1066        let definitions = [
1067            AttributeDefinition::boolean("private"),
1068            AttributeDefinition::integer("exchange_id"),
1069            AttributeDefinition::string_list("deal_ids"),
1070            AttributeDefinition::integer_list("segment_ids"),
1071            AttributeDefinition::string("country"),
1072            AttributeDefinition::string("city"),
1073        ];
1074        let mut atree = ATree::new(&definitions).unwrap();
1075
1076        let result = atree.insert(&1u64, A_COMPLEX_EXPRESSION);
1077
1078        assert!(result.is_ok());
1079    }
1080
1081    #[test]
1082    fn can_insert_multiple_expressions_with_mixed_operators() {
1083        let definitions = [
1084            AttributeDefinition::boolean("private"),
1085            AttributeDefinition::integer("exchange_id"),
1086            AttributeDefinition::string_list("deal_ids"),
1087            AttributeDefinition::integer_list("segment_ids"),
1088            AttributeDefinition::string("country"),
1089            AttributeDefinition::string("city"),
1090        ];
1091        let mut atree = ATree::new(&definitions).unwrap();
1092
1093        assert!(atree.insert(&1u64, A_COMPLEX_EXPRESSION).is_ok());
1094        assert!(atree.insert(&2u64, ANOTHER_COMPLEX_EXPRESSION).is_ok());
1095    }
1096
1097    #[test]
1098    fn can_search_an_empty_tree() {
1099        let definitions = [
1100            AttributeDefinition::boolean("private"),
1101            AttributeDefinition::integer("exchange_id"),
1102            AttributeDefinition::string_list("deal_ids"),
1103            AttributeDefinition::string_list("deals"),
1104            AttributeDefinition::integer_list("segment_ids"),
1105            AttributeDefinition::string("country"),
1106            AttributeDefinition::string("city"),
1107        ];
1108        let atree = ATree::new(&definitions).unwrap();
1109        let mut builder = atree.make_event();
1110        builder.with_boolean("private", false).unwrap();
1111        let event = builder.build().unwrap();
1112
1113        let expected: Vec<&u64> = vec![];
1114        let actual = atree.search(&event).unwrap().matches().to_vec();
1115        assert_eq!(expected, actual);
1116    }
1117
1118    #[test]
1119    fn can_search_a_single_predicate() {
1120        let definitions = [
1121            AttributeDefinition::boolean("private"),
1122            AttributeDefinition::integer("exchange_id"),
1123            AttributeDefinition::string_list("deal_ids"),
1124            AttributeDefinition::string_list("deals"),
1125            AttributeDefinition::integer_list("segment_ids"),
1126            AttributeDefinition::string("country"),
1127            AttributeDefinition::string("city"),
1128        ];
1129        let mut atree = ATree::new(&definitions).unwrap();
1130        atree.insert(&1u64, "private").unwrap();
1131        let mut builder = atree.make_event();
1132        builder.with_boolean("private", true).unwrap();
1133        let event = builder.build().unwrap();
1134
1135        let expected = vec![&1u64];
1136        let actual = atree.search(&event).unwrap().matches().to_vec();
1137        assert_eq!(expected, actual);
1138    }
1139
1140    #[test]
1141    fn ignore_results_that_are_not_matched() {
1142        let definitions = [
1143            AttributeDefinition::boolean("private"),
1144            AttributeDefinition::integer("exchange_id"),
1145            AttributeDefinition::string_list("deal_ids"),
1146            AttributeDefinition::string_list("deals"),
1147            AttributeDefinition::integer_list("segment_ids"),
1148            AttributeDefinition::string("country"),
1149            AttributeDefinition::string("city"),
1150        ];
1151        let mut atree = ATree::new(&definitions).unwrap();
1152        atree.insert(&1u64, "private").unwrap();
1153        atree.insert(&2u64, A_COMPLEX_EXPRESSION).unwrap();
1154        let mut builder = atree.make_event();
1155        builder.with_boolean("private", false).unwrap();
1156        let event = builder.build().unwrap();
1157
1158        let expected: Vec<&u64> = vec![];
1159        let actual = atree.search(&event).unwrap().matches().to_vec();
1160        assert_eq!(expected, actual);
1161    }
1162
1163    #[test]
1164    fn can_search_simple_expressions() {
1165        let definitions = [
1166            AttributeDefinition::boolean("private"),
1167            AttributeDefinition::integer("exchange_id"),
1168            AttributeDefinition::string_list("deal_ids"),
1169            AttributeDefinition::string_list("deals"),
1170            AttributeDefinition::integer_list("segment_ids"),
1171            AttributeDefinition::string("country"),
1172            AttributeDefinition::string("city"),
1173        ];
1174        let mut atree = ATree::new(&definitions).unwrap();
1175        atree.insert(&1u64, "private").unwrap();
1176        atree.insert(&2u64, "not private").unwrap();
1177        let mut builder = atree.make_event();
1178        builder.with_boolean("private", true).unwrap();
1179        let event = builder.build().unwrap();
1180
1181        let expected = vec![&1u64];
1182        let mut actual = atree.search(&event).unwrap().matches().to_vec();
1183        actual.sort();
1184        assert_eq!(expected, actual);
1185    }
1186
1187    #[test]
1188    fn can_search_complex_expressions() {
1189        let definitions = [
1190            AttributeDefinition::boolean("private"),
1191            AttributeDefinition::integer("exchange_id"),
1192            AttributeDefinition::string_list("deal_ids"),
1193            AttributeDefinition::string_list("deals"),
1194            AttributeDefinition::integer_list("segment_ids"),
1195            AttributeDefinition::string("country"),
1196            AttributeDefinition::string("city"),
1197        ];
1198        let mut atree = ATree::new(&definitions).unwrap();
1199
1200        atree.insert(&1, A_COMPLEX_EXPRESSION).unwrap();
1201        atree.insert(&2, AN_EXPRESSION_WITH_AND_OPERATORS).unwrap();
1202        atree.insert(&3, AN_EXPRESSION_WITH_OR_OPERATORS).unwrap();
1203        let mut builder = atree.make_event();
1204        builder.with_integer("exchange_id", 1).unwrap();
1205        builder.with_boolean("private", true).unwrap();
1206        builder
1207            .with_string_list("deal_ids", &["deal-1", "deal-2"])
1208            .unwrap();
1209        builder
1210            .with_string_list("deals", &["deal-1", "deal-2"])
1211            .unwrap();
1212        builder.with_integer_list("segment_ids", &[2, 3]).unwrap();
1213        builder.with_string("country", "FR").unwrap();
1214        let event = builder.build().unwrap();
1215
1216        let expected = vec![&2, &3];
1217        let mut actual = atree.search(&event).unwrap().matches().to_vec();
1218        actual.sort();
1219        assert_eq!(expected, actual);
1220    }
1221
1222    #[test]
1223    fn can_search_a_tree_with_multiple_shared_sub_expressions() {
1224        let definitions = [
1225            AttributeDefinition::boolean("private"),
1226            AttributeDefinition::integer("exchange_id"),
1227            AttributeDefinition::string_list("deals"),
1228            AttributeDefinition::integer_list("segment_ids"),
1229            AttributeDefinition::string("country"),
1230            AttributeDefinition::string("city"),
1231        ];
1232        let mut atree = ATree::new(&definitions).unwrap();
1233        [
1234            (
1235                1,
1236                r#"exchange_id = 1 and not private and deals one of ["deal-1", "deal-2"]"#,
1237            ),
1238            (
1239                2,
1240                r#"exchange_id = 1 and not private and deals one of ["deal-2", "deal-3"]"#,
1241            ),
1242            (
1243                3,
1244                r#"exchange_id = 1 and not private and deals one of ["deal-2", "deal-3"] and segment_ids one of [1, 2, 3, 4]"#,
1245            ),
1246            (
1247                4,
1248                r#"exchange_id = 1 and not private and deals one of ["deal-2", "deal-3"] and segment_ids one of [5, 6, 7, 8] and country in ["CA", "US"]"#,
1249            ),
1250        ].into_iter().for_each(|(id, expression)| {
1251                atree.insert(&id, expression).unwrap()
1252        });
1253
1254        let mut builder = atree.make_event();
1255        builder.with_boolean("private", false).unwrap();
1256        builder.with_integer("exchange_id", 1).unwrap();
1257        builder
1258            .with_string_list("deals", &["deal-1", "deal-3"])
1259            .unwrap();
1260        builder.with_integer_list("segment_ids", &[2, 3]).unwrap();
1261        builder.with_string("country", "CA").unwrap();
1262        let event = builder.build().unwrap();
1263
1264        let mut matches = atree.search(&event).unwrap().matches().to_vec();
1265        matches.sort();
1266        assert_eq!(vec![&1, &2, &3], matches);
1267    }
1268
1269    #[test]
1270    fn can_delete_a_single_predicate() {
1271        let definitions = [AttributeDefinition::boolean("private")];
1272        let mut atree = ATree::new(&definitions).unwrap();
1273        atree.insert(&1u64, "private").unwrap();
1274        let mut builder = atree.make_event();
1275        builder.with_boolean("private", true).unwrap();
1276        let event = builder.build().unwrap();
1277
1278        let results = atree.search(&event).unwrap().matches().to_vec();
1279        assert_eq!(vec![&1u64], results);
1280
1281        atree.delete(&1u64);
1282        let mut builder = atree.make_event();
1283        builder.with_boolean("private", true).unwrap();
1284        let event = builder.build().unwrap();
1285        let results = atree.search(&event).unwrap().matches().to_vec();
1286        assert!(results.is_empty());
1287    }
1288
1289    #[test]
1290    fn deleting_an_expression_only_removes_the_id_not_the_expression_if_it_is_still_referenced() {
1291        let definitions = [
1292            AttributeDefinition::boolean("private"),
1293            AttributeDefinition::integer("exchange_id"),
1294            AttributeDefinition::string_list("deal_ids"),
1295            AttributeDefinition::string_list("deals"),
1296            AttributeDefinition::integer_list("segment_ids"),
1297            AttributeDefinition::string("country"),
1298            AttributeDefinition::string("city"),
1299        ];
1300        let an_expression = "private or exchange_id = 1";
1301        let another_expression =
1302            r#"private or exchange_id = 1 or deal_ids one of ["deal-1", "deal-2"]"#;
1303        let mut atree = ATree::new(&definitions).unwrap();
1304        atree.insert(&1u64, an_expression).unwrap();
1305        atree.insert(&2u64, another_expression).unwrap();
1306        let mut builder = atree.make_event();
1307        builder.with_integer("exchange_id", 1).unwrap();
1308        let event = builder.build().unwrap();
1309
1310        let results = atree.search(&event).unwrap().matches().to_vec();
1311        assert_eq!(vec![&1u64, &2u64], results);
1312
1313        atree.delete(&1u64);
1314        let mut builder = atree.make_event();
1315        builder.with_integer("exchange_id", 1).unwrap();
1316        let event = builder.build().unwrap();
1317        let results = atree.search(&event).unwrap().matches().to_vec();
1318        assert_eq!(vec![&2u64], results);
1319    }
1320
1321    #[test]
1322    fn deleting_an_expression_only_removes_the_id_not_the_expression_if_it_has_multiple_subscription_ids(
1323    ) {
1324        let definitions = [
1325            AttributeDefinition::boolean("private"),
1326            AttributeDefinition::integer("exchange_id"),
1327        ];
1328        let an_expression = "private or exchange_id = 1";
1329        let mut atree = ATree::new(&definitions).unwrap();
1330        atree.insert(&1u64, an_expression).unwrap();
1331        atree.insert(&2u64, an_expression).unwrap();
1332        let mut builder = atree.make_event();
1333        builder.with_integer("exchange_id", 1).unwrap();
1334        let event = builder.build().unwrap();
1335
1336        let results = atree.search(&event).unwrap().matches().to_vec();
1337        assert_eq!(vec![&1u64, &2u64], results);
1338
1339        atree.delete(&1u64);
1340        let mut builder = atree.make_event();
1341        builder.with_integer("exchange_id", 1).unwrap();
1342        let event = builder.build().unwrap();
1343        let results = atree.search(&event).unwrap().matches().to_vec();
1344        assert_eq!(vec![&2u64], results);
1345    }
1346
1347    #[test]
1348    fn can_delete_root_node_when_all_references_are_deleted() {
1349        let definitions = [
1350            AttributeDefinition::boolean("private"),
1351            AttributeDefinition::integer("exchange_id"),
1352        ];
1353        let an_expression = "private or exchange_id = 1";
1354        let mut atree = ATree::new(&definitions).unwrap();
1355        atree.insert(&1u64, an_expression).unwrap();
1356        atree.insert(&2u64, an_expression).unwrap();
1357        let mut builder = atree.make_event();
1358        builder.with_integer("exchange_id", 1).unwrap();
1359        let event = builder.build().unwrap();
1360
1361        let results = atree.search(&event).unwrap().matches().to_vec();
1362        assert_eq!(vec![&1u64, &2u64], results);
1363
1364        atree.delete(&1u64);
1365        atree.delete(&2u64);
1366        let mut builder = atree.make_event();
1367        builder.with_integer("exchange_id", 1).unwrap();
1368        let event = builder.build().unwrap();
1369        let results = atree.search(&event).unwrap().matches().to_vec();
1370        assert!(results.is_empty());
1371    }
1372
1373    #[test]
1374    fn can_render_to_graphviz() {
1375        let definitions = [
1376            AttributeDefinition::boolean("private"),
1377            AttributeDefinition::integer("exchange_id"),
1378            AttributeDefinition::string_list("deal_ids"),
1379            AttributeDefinition::string_list("deals"),
1380            AttributeDefinition::integer_list("segment_ids"),
1381            AttributeDefinition::string("country"),
1382            AttributeDefinition::string("city"),
1383        ];
1384        let an_expression = "private or exchange_id = 1";
1385        let another_expression =
1386            r#"private or exchange_id = 1 or deal_ids one of ["deal-1", "deal-2"]"#;
1387        let mut atree = ATree::new(&definitions).unwrap();
1388        atree.insert(&1u64, an_expression).unwrap();
1389        atree.insert(&2u64, another_expression).unwrap();
1390
1391        assert!(!atree.to_graphviz().is_empty());
1392    }
1393}