logify/
eval.rs

1use crate::expr::{Expression, Node};
2
3mod bitwise_eval;
4pub use bitwise_eval::BitwiseEval;
5mod bool_eval;
6pub use bool_eval::BoolEval;
7use serde::{Deserialize, Serialize};
8
/// Defines how abstract boolean logic is resolved into concrete results.
///
/// Implement this trait to run an [`Expression`]. It is the bridge between the
/// expression's boolean graph and your concrete domain (e.g., SQL generation, bitmask
/// operations, search-engine query execution).
///
/// # Type Parameters
/// * `T`: The **Term** type stored in the expression's leaves (e.g., `String` for tags,
///   `u32` for IDs).
/// * `R`: The **Result** type produced by evaluation (e.g., `Vec<i32>`, `RoaringBitmap`,
///   `SqlFragment`).
/// * `E`: The **Error** type any method may return. The expression evaluators propagate
///   the first error they receive and stop evaluating.
///
/// # Why there is no `not` method
/// Negation is expressed through [`eval_difference`](Self::eval_difference) instead of a
/// standalone complement. Computing "everything except X" is often expensive (or
/// impossible), whereas `A AND NOT B` can usually be computed directly from `A` and `B`.
pub trait Evaluator<T, R, E> {
    /// Returns the Universal Set (everything).
    ///
    /// The evaluator requests it only when a negation cannot be folded into a difference
    /// against another operand, for example:
    /// * a root that is itself negated (`!A` is computed as `Universal - A`),
    /// * an intersection in which every term must be excluded
    ///   (`!A & !B` is computed as `Universal - (A | B)`),
    /// * a union containing a negated term (`A | !B` needs `Universal - B`).
    ///
    /// If your domain has no universal set (e.g., an unbounded number line), you may return
    /// an error here; evaluating expressions of the shapes above will then fail with it.
    fn get_universal(&mut self) -> Result<R, E>;

    /// Returns the Empty Set (nothing). Used to evaluate the expression's empty node.
    fn get_empty(&mut self) -> Result<R, E>;

    /// Resolves a single leaf term into a result.
    ///
    /// # Example
    /// If `T` is a user ID, this might look the user up in a database and return that
    /// user's permissions.
    fn eval_set(&mut self, set: &T) -> Result<R, E>;

    /// Merges several results with a Union (OR) operation.
    ///
    /// # Arguments
    /// * `values` - Results previously produced by this evaluator (`eval_set`,
    ///   `eval_intersection`, `eval_difference`, ...). The iterator knows its exact length.
    ///
    /// # Expected Behavior
    /// Return a result containing the items present in **at least one** of the inputs.
    fn eval_union<'a, I>(&mut self, values: I) -> Result<R, E>
    where
        R: 'a,
        I: IntoIterator<Item = &'a R>,
        I::IntoIter: ExactSizeIterator;

    /// Filters several results with an Intersection (AND) operation.
    ///
    /// # Arguments
    /// * `values` - Results previously produced by this evaluator (`eval_set`,
    ///   `eval_union`, `eval_difference`, ...). The iterator knows its exact length.
    ///
    /// # Expected Behavior
    /// Return a result containing only the items present in **all** of the inputs.
    fn eval_intersection<'a, I>(&mut self, values: I) -> Result<R, E>
    where
        R: 'a,
        I: IntoIterator<Item = &'a R>,
        I::IntoIter: ExactSizeIterator;

    /// Computes the difference of two results (`include AND NOT exclude`).
    ///
    /// This is how negation is handled. The evaluator rewrites negations as differences
    /// wherever it can, so the Universal set is only materialized when unavoidable:
    ///
    /// * `!A` becomes `eval_difference(Universal, A)`
    /// * `A & !B` becomes `eval_difference(A, B)`
    ///
    /// # Arguments
    /// * `include` - The base set of items.
    /// * `exclude` - The items to remove from the base set.
    fn eval_difference(&mut self, include: &R, exclude: &R) -> Result<R, E>;
}
83
84/// A reusable memory buffer for expression evaluation.
85///
86/// When evaluating an expression multiple times (e.g., against different rows in a database),
87/// allocating new vectors for every calculation is inefficient. `EvaluatorCache` holds onto
88/// the allocated memory between runs.
89///
90/// Use this to avoid repeated allocations when evaluating the same expression multiple times.
91///
92/// # Automatic Invalidation
93/// This struct stores a version UUID of the expression it was last used with. If you pass
94/// this cache to `evaluate_with` on a modified or completely different expression, it will
95/// automatically detect the mismatch and clear itself.
96///
97/// # Memory & Performance
98/// * **Allocations:** Reuses internal vectors to minimize heap traffic.
99/// * **Cloning:** When `evaluate_with` returns, the final results for the roots are **cloned**
100///   from this cache.
101///   * If your result type `R` is large (e.g., a 10MB Bitmap or large Vector), this clone
102///     can be expensive.
103///   * **Recommendation:** Wrap large results in [`std::sync::Arc`] or [`std::rc::Rc`] so that
104///     cloning is cheap (pointer copy) rather than deep.
105///
106/// # Example
107/// ```rust
108/// use logify::{EvaluatorCache, ExpressionBuilder, Evaluator};
109/// # // Mock Setup (Hidden from docs)
110/// # struct Solver;
111/// # impl Evaluator<&str, bool, ()> for Solver {
112/// #     fn get_universal(&mut self) -> Result<bool, ()> { Ok(true) }
113/// #     fn get_empty(&mut self) -> Result<bool, ()> { Ok(false) }
114/// #     fn eval_set(&mut self, _: &&str) -> Result<bool, ()> { Ok(true) }
115/// #     fn eval_union<'a, I>(&mut self, _: I) -> Result<bool, ()> where I: IntoIterator<Item=&'a bool>, I::IntoIter: ExactSizeIterator { Ok(true) }
116/// #     fn eval_intersection<'a, I>(&mut self, _: I) -> Result<bool, ()> where I: IntoIterator<Item=&'a bool>, I::IntoIter: ExactSizeIterator { Ok(true) }
117/// #     fn eval_difference(&mut self, _: &bool, _: &bool) -> Result<bool, ()> { Ok(true) }
118/// # }
119///
120/// // Setup
121/// let mut cache = EvaluatorCache::new();
122/// let mut solver = Solver;
123///
124/// // Build a simple expression
125/// let builder = ExpressionBuilder::new();
126/// builder.add_root(builder.leaf("A"));
127/// let expr = builder.build();
128///
129/// let dataset = vec!["Row1", "Row2", "Row3"];
130///
131/// // Fast: Reuses the same vectors for every iteration
132/// for item in dataset {
133///     // In a real scenario, you would update the roots or append expressions each time.
134///     let result = expr.evaluate_with(&mut solver, &mut cache);
135/// }
136/// ```
137#[cfg_attr(feature = "fast-binary", derive(bitcode::Encode, bitcode::Decode))]
138#[derive(Serialize, Deserialize)]
139pub struct EvaluatorCache<R> {
140    pub(crate) cache: Vec<Option<R>>,
141    pub(crate) include_indices: Vec<usize>,
142    pub(crate) exclude_indices: Vec<usize>,
143    pub(crate) expr_uuid: u128, // 0 for an uninitialized cache
144}
145
146impl<R> Default for EvaluatorCache<R> {
147    fn default() -> Self {
148        Self {
149            cache: Vec::new(),
150            include_indices: Vec::new(),
151            exclude_indices: Vec::new(),
152            expr_uuid: 0,
153        }
154    }
155}
156
157impl<R> EvaluatorCache<R> {
158    /// Creates a new, empty cache.
159    pub fn new() -> Self {
160        Self::default()
161    }
162
163    /// Manually clears the internal buffers and resets the versioning.
164    ///
165    /// Usually not necessary, as `evaluate_with` handles invalidation automatically.
166    pub fn clear(&mut self) {
167        self.cache.clear();
168        self.expr_uuid = 0; // mark as uninitialized
169    }
170}
171
172impl<T> Expression<T> {
173    /// Evaluates the expression using a temporary cache.
174    ///
175    /// This is a convenience wrapper around [`evaluate_with`](Self::evaluate_with).
176    /// It creates a fresh `EvaluatorCache`, runs the evaluation, and then drops the cache.
177    ///
178    /// # Performance Note
179    /// Because this allocates memory for every call, it is not recommended for tight loops.
180    /// Use `evaluate_with` for repeated evaluations.
181    pub fn evaluate<R, E, S>(&self, solver: &mut S) -> Result<Vec<R>, E>
182    where
183        R: Clone,
184        S: Evaluator<T, R, E>,
185    {
186        let mut cache = EvaluatorCache::new();
187        self.evaluate_with(solver, &mut cache)
188    }
189
190    /// Evaluates the expression using a persistent, external cache.
191    ///
192    /// This is the most efficient way to evaluate an expression multiple times.
193    ///
194    /// # How it works
195    /// 1. **Validation:** Checks if the `cache` matches the current expression's UUID. If not, it clears the cache.
196    /// 2. **Execution:** Iterates through the node graph. Intermediate results are stored in the cache.
197    /// 3. **Reuse:** If called again, the internal vectors (`Vec<Option<R>>`) are reused, preventing heap allocation overhead.
198    ///
199    /// # Cache Invalidation
200    /// The cache is tied to the structure of the expression. Modifying the expression
201    /// (e.g., via `compress()`, `prune()`, or `optimize()`) changes the UUID, causing
202    /// the cache to reset on the next call.
203    pub fn evaluate_with<R, E, S>(
204        &self,
205        solver: &mut S,
206        cache: &mut EvaluatorCache<R>,
207    ) -> Result<Vec<R>, E>
208    where
209        R: Clone,
210        S: Evaluator<T, R, E>,
211    {
212        // cache validation
213        if cache.expr_uuid != self.uuid {
214            cache.clear();
215            cache.expr_uuid = self.uuid;
216        }
217
218        // load cache
219        let cache_vec = &mut cache.cache;
220        if cache_vec.len() < self.nodes.len() * 2 {
221            cache_vec.resize(self.nodes.len() * 2, None);
222        }
223
224        // initialize active nodes with the roots to find
225        let mut max_root = 0; // furthest root location, node 0 has no children, so safe as a flag to avoid finding children
226        let mut active = vec![false; self.nodes.len()];
227        for root in &self.roots {
228            // skip over already loaded roots
229            if cache_vec[root.idx() << 1].is_none() {
230                active[root.idx()] = true;
231                if root.idx() > max_root {
232                    max_root = root.idx();
233                }
234            }
235        }
236
237        // finds all children of uncomputed roots
238        if max_root != 0 {
239            for idx in (0..self.nodes.len()).rev() {
240                if !active[idx] {
241                    continue;
242                } // dead node
243                // activate all children
244                match &self.nodes[idx] {
245                    Node::Union(kids) | Node::Intersection(kids) => {
246                        for k in kids {
247                            active[k.idx()] = true;
248                        }
249                    }
250                    _ => {}
251                }
252            }
253        }
254
255        // evaluate each node
256        for (idx, node) in self.nodes.iter().enumerate() {
257            if idx > max_root {
258                break;
259            } // only evaluate up to the last needed root
260            if !active[idx] {
261                continue;
262            } // skips non-active nodes
263            if cache_vec[idx << 1].is_some() {
264                continue;
265            } // already evaluated
266
267            // node must be calculated
268            let result = Self::evaluate_node(
269                node,
270                solver,
271                cache_vec,
272                &mut cache.include_indices,
273                &mut cache.exclude_indices,
274            )?;
275            cache_vec[idx << 1] = Some(result);
276        }
277
278        // all root positives are now in cache
279        let mut results = Vec::with_capacity(self.roots.len());
280        for root in &self.roots {
281            if let Some(res) = &cache_vec[root.raw() as usize] {
282                results.push(res.clone());
283            } else {
284                if cache_vec[1].is_none() {
285                    cache_vec[1] = Some(solver.get_universal()?);
286                }
287                let uni = cache_vec[1].as_ref().unwrap();
288                if root.raw() == 1 {
289                    results.push(uni.clone());
290                } else {
291                    let pos = cache_vec[root.idx() << 1].as_ref().unwrap();
292                    let neg = solver.eval_difference(uni, pos)?;
293                    cache_vec[root.raw() as usize] = Some(neg.clone());
294                    results.push(neg);
295                }
296            }
297        }
298        Ok(results)
299    }
300
301    /// Evaluates the expression while aggressively freeing memory.
302    ///
303    /// Unlike standard evaluation, which keeps all intermediate results until the end,
304    /// this method calculates reference counts for every node. As soon as a node's
305    /// result is consumed by all its parents, the memory is dropped.
306    ///
307    /// # Trade-offs
308    /// * **Pros:** Significantly lower peak memory usage. Ideal for very large result types (e.g., Bitmaps, Images).
309    /// * **Cons:** Slower execution speed due to the overhead of calculating reference counts and dropping values during iteration.
310    pub fn evaluate_with_pruning<R, E, S>(&self, solver: &mut S) -> Result<Vec<R>, E>
311    where
312        R: Clone,
313        S: Evaluator<T, R, E>,
314    {
315        // create cache
316        let mut cache = vec![None; self.nodes.len() * 2];
317        let mut include_indices = Vec::new();
318        let mut exclude_indices = Vec::new();
319
320        // construct the counts
321        let mut counts = vec![0; self.nodes.len()];
322        for &root in &self.roots {
323            // retain roots until the end
324            counts[root.idx()] += 1;
325        }
326        for idx in (0..self.nodes.len()).rev() {
327            if counts[idx] == 0 {
328                continue;
329            } // dead node
330            match &self.nodes[idx] {
331                Node::Union(kids) | Node::Intersection(kids) => {
332                    for k in kids {
333                        counts[k.idx()] += 1;
334                    }
335                }
336                _ => {}
337            }
338        }
339
340        // traverse the expression linearly
341        for (idx, node) in self.nodes.iter().enumerate() {
342            if counts[idx] == 0 {
343                continue;
344            } // node isn't used
345            if cache[idx << 1].is_some() {
346                continue;
347            } // already evaluated
348
349            // node must be calculated
350            let result = Self::evaluate_node(
351                node,
352                solver,
353                &mut cache,
354                &mut include_indices,
355                &mut exclude_indices,
356            )?;
357            cache[idx << 1] = Some(result);
358
359            // decrement and remove cache if there are no more parents
360            match node {
361                Node::Union(kids) | Node::Intersection(kids) => {
362                    for k in kids {
363                        counts[k.idx()] -= 1;
364                        if counts[k.idx()] == 0 {
365                            cache[k.idx() << 1] = None;
366                            cache[(k.idx() << 1) + 1] = None;
367                        }
368                    }
369                }
370                _ => {}
371            }
372        }
373
374        // all root positives are now in cache
375        let mut results = Vec::with_capacity(self.roots.len());
376        for root in &self.roots {
377            if let Some(res) = &cache[root.raw() as usize] {
378                // root in cache
379                results.push(res.clone());
380            } else {
381                // root not in cache, must be negative and positive must be in cache
382                if cache[1].is_none() {
383                    cache[1] = Some(solver.get_universal()?);
384                }
385                let uni = cache[1].as_ref().unwrap();
386                if root.raw() == 1 {
387                    results.push(uni.clone());
388                } else {
389                    let pos = cache[root.idx() << 1].as_ref().unwrap();
390                    let neg = solver.eval_difference(uni, pos)?;
391                    cache[root.raw() as usize] = Some(neg.clone());
392                    results.push(neg);
393                }
394            }
395        }
396        Ok(results)
397    }
398
399    #[inline]
400    fn evaluate_node<R, E, S>(
401        node: &Node<T>,
402        solver: &mut S,
403        cache_vec: &mut [Option<R>],
404        include_indices: &mut Vec<usize>,
405        exclude_indices: &mut Vec<usize>,
406    ) -> Result<R, E>
407    where
408        R: Clone,
409        S: Evaluator<T, R, E>,
410    {
411        match node {
412            Node::Empty => Ok(solver.get_empty()?),
413            Node::Set(set) => Ok(solver.eval_set(set)?),
414            Node::Union(kids) => {
415                // make sure all negated terms are calculated
416                let (uni_cache, other_cache) = cache_vec.split_at_mut(2);
417                for k in kids {
418                    let idx = k.raw() as usize - 2;
419                    let pos_idx = (k.idx() << 1) - 2;
420                    if other_cache[idx].is_none() {
421                        // must be negative
422                        let uni = uni_cache[1].get_or_insert(solver.get_universal()?);
423                        let pos = other_cache[pos_idx].as_ref().unwrap();
424                        let neg = solver.eval_difference(uni, pos)?;
425                        other_cache[idx] = Some(neg); // add negative to cache
426                    }
427                }
428                // evaluate the union
429                Ok(solver.eval_union(
430                    kids.iter()
431                        .map(|k| cache_vec[k.raw() as usize].as_ref().unwrap()),
432                )?)
433            }
434            Node::Intersection(kids) => {
435                // A&B&C'&D' == (A&B)-(C|D)
436                include_indices.clear();
437                exclude_indices.clear();
438                for k in kids {
439                    if k.is_neg() {
440                        if cache_vec[k.raw() as usize].is_some() {
441                            // & is faster, so if the negative is computed, include it
442                            include_indices.push(k.raw() as usize);
443                        } else {
444                            // negative is not computed, so exclude the positive
445                            exclude_indices.push(k.idx() << 1);
446                        }
447                    } else {
448                        // k is positive, include it
449                        include_indices.push(k.raw() as usize);
450                    }
451                }
452
453                // intersections must have at least two terms
454                if exclude_indices.is_empty() {
455                    // no exclusions so use the include as the result
456                    let include = solver.eval_intersection(
457                        include_indices
458                            .iter()
459                            .map(|&i| cache_vec[i].as_ref().unwrap()),
460                    )?;
461                    Ok(include)
462                } else {
463                    // get include
464                    let include = if include_indices.is_empty() {
465                        // use universe if no inclusions are present
466                        if cache_vec[1].is_none() {
467                            cache_vec[1] = Some(solver.get_universal()?);
468                        }
469                        cache_vec[1].as_ref().unwrap()
470                    } else if include_indices.len() == 1 {
471                        cache_vec[include_indices[0]].as_ref().unwrap()
472                    } else {
473                        &solver.eval_intersection(
474                            include_indices
475                                .iter()
476                                .map(|&i| cache_vec[i].as_ref().unwrap()),
477                        )?
478                    };
479
480                    // get exclude (must be more than 1)
481                    let exclude = if exclude_indices.len() == 1 {
482                        cache_vec[exclude_indices[0]].as_ref().unwrap()
483                    } else {
484                        &solver.eval_union(
485                            exclude_indices
486                                .iter()
487                                .map(|&i| cache_vec[i].as_ref().unwrap()),
488                        )?
489                    };
490
491                    // compute difference
492                    Ok(solver.eval_difference(include, exclude)?)
493                }
494            }
495        }
496    }
497}