espresso_logic/expression/
ast.rs

1//! AST representation and tree traversal operations
2//!
3//! This module contains the AST types and fold operations for boolean expressions.
4
5use super::BoolExpr;
6use std::sync::Arc;
7
/// Shallow view of a single expression node, as handed to fold callbacks
///
/// Every variant mirrors one kind of node in the expression tree while keeping
/// the internal `Arc`-based storage hidden. Closures passed to
/// [`BoolExpr::fold`] and [`BoolExpr::fold_with_context`] receive values of
/// this type.
///
/// # Generic Parameter
///
/// - With [`BoolExpr::fold`], `T` is the result already computed for each child
///   node (results flow bottom-up).
/// - With [`BoolExpr::fold_with_context`], `T` is `()`: the children are reached
///   through the recursion closures instead, and context flows top-down.
///
/// # Examples
///
/// Usage examples live on [`BoolExpr::fold`] and [`BoolExpr::fold_with_context`].
///
/// [`BoolExpr::fold`]: BoolExpr::fold
/// [`BoolExpr::fold_with_context`]: BoolExpr::fold_with_context
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExprNode<'a, T> {
    /// A variable, identified by its borrowed name
    Variable(&'a str),
    /// Logical AND; carries the values for the left and right subtrees
    And(T, T),
    /// Logical OR; carries the values for the left and right subtrees
    Or(T, T),
    /// Logical NOT; carries the value for the negated subtree
    Not(T),
    /// A literal `true` or `false`
    Constant(bool),
}
38
/// Owned syntax tree for a boolean expression
///
/// This is a plain tree: operator nodes hold `Arc<BoolExprAst>` children rather
/// than `BoolExpr` values, so an AST can be rebuilt from a BDD without the two
/// types depending on each other in a cycle.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum BoolExprAst {
    /// Leaf: a variable, stored by name
    Variable(Arc<str>),
    /// Conjunction of the left and right subtrees
    And(Arc<BoolExprAst>, Arc<BoolExprAst>),
    /// Disjunction of the left and right subtrees
    Or(Arc<BoolExprAst>, Arc<BoolExprAst>),
    /// Negation of the inner subtree
    Not(Arc<BoolExprAst>),
    /// Leaf: a literal `true` or `false`
    Constant(bool),
}
56
57impl BoolExpr {
58    /// Get or create the AST representation from the BDD
59    ///
60    /// This is called internally when AST is needed (for display, fold, etc.)
61    ///
62    /// Uses factorisation-based reconstruction for beautiful, compact expressions.
63    pub(super) fn get_or_create_ast(&self) -> Arc<BoolExprAst> {
64        // Check if we have a cached AST
65        if let Some(ast) = self.ast_cache.get() {
66            return Arc::clone(ast);
67        }
68
69        // Need to reconstruct from BDD using factorisation
70        let ast = self.to_ast_optimised();
71
72        // Try to store it (may fail if another thread beat us to it, that's fine)
73        let _ = self.ast_cache.set(Arc::clone(&ast));
74
75        ast
76    }
77
78    /// Convert BDD to optimised AST representation using factorisation
79    ///
80    /// Extracts cubes from the BDD and applies algebraic factorisation to produce
81    /// a compact, readable expression.
82    ///
83    /// Uses local caching for both DNF and AST.
84    pub(super) fn to_ast_optimised(&self) -> Arc<BoolExprAst> {
85        // Check local AST cache first
86        if let Some(ast) = self.ast_cache.get() {
87            return Arc::clone(ast);
88        }
89
90        // AST not cached, get DNF and factorise
91        let dnf = self.get_or_create_dnf();
92
93        // Convert DNF cubes to the format expected by factorisation
94        let cube_terms: Vec<(std::collections::BTreeMap<Arc<str>, bool>, bool)> = dnf
95            .cubes()
96            .iter()
97            .map(|cube| (cube.clone(), true))
98            .collect();
99
100        // Use factorisation to build a nice AST
101        let ast = crate::expression::factorization::factorise_cubes_to_ast(cube_terms);
102
103        // Cache locally
104        let _ = self.ast_cache.set(Arc::clone(&ast));
105
106        ast
107    }
108
109    /// Fold the expression tree depth-first from leaves to root
110    ///
111    /// This method traverses the expression tree recursively, calling the provided
112    /// function `f` on each node. The function receives an [`ExprNode`] containing
113    /// the node type and accumulated results from child nodes.
114    ///
115    /// This is useful for implementing custom expression transformations and analyses
116    /// without needing access to private expression internals.
117    ///
118    /// # Examples
119    ///
120    /// Count the number of operations in an expression:
121    ///
122    /// ```
123    /// use espresso_logic::{BoolExpr, ExprNode};
124    ///
125    /// let a = BoolExpr::variable("a");
126    /// let b = BoolExpr::variable("b");
127    /// let expr = a.and(&b);
128    ///
129    /// let op_count = expr.fold(|node| match node {
130    ///     ExprNode::Variable(_) | ExprNode::Constant(_) => 0,
131    ///     ExprNode::And(l, r) | ExprNode::Or(l, r) => l + r + 1,
132    ///     ExprNode::Not(inner) => inner + 1,
133    /// });
134    ///
135    /// assert_eq!(op_count, 1); // Just AND
136    /// ```
137    pub fn fold<T, F>(&self, f: F) -> T
138    where
139        F: Fn(ExprNode<T>) -> T + Copy,
140    {
141        self.fold_impl(&f)
142    }
143
144    fn fold_impl<T, F>(&self, f: &F) -> T
145    where
146        F: Fn(ExprNode<T>) -> T,
147    {
148        let ast = self.get_or_create_ast();
149        Self::fold_ast(&ast, f)
150    }
151
152    /// Fold over an AST (helper for fold_impl)
153    fn fold_ast<T, F>(ast: &BoolExprAst, f: &F) -> T
154    where
155        F: Fn(ExprNode<T>) -> T,
156    {
157        match ast {
158            BoolExprAst::Variable(name) => f(ExprNode::Variable(name)),
159            BoolExprAst::And(left, right) => {
160                let left_result = Self::fold_ast(left, f);
161                let right_result = Self::fold_ast(right, f);
162                f(ExprNode::And(left_result, right_result))
163            }
164            BoolExprAst::Or(left, right) => {
165                let left_result = Self::fold_ast(left, f);
166                let right_result = Self::fold_ast(right, f);
167                f(ExprNode::Or(left_result, right_result))
168            }
169            BoolExprAst::Not(inner) => {
170                let inner_result = Self::fold_ast(inner, f);
171                f(ExprNode::Not(inner_result))
172            }
173            BoolExprAst::Constant(val) => f(ExprNode::Constant(*val)),
174        }
175    }
176
177    /// Fold with context parameter passed top-down through the tree
178    ///
179    /// Unlike [`fold`], which passes results bottom-up from children to parents,
180    /// this method passes a context parameter top-down from parents to children.
181    /// The function `f` receives the current node type, context from parent,
182    /// and closures to recursively process children with modified context.
183    ///
184    /// This is useful for operations like applying De Morgan's laws where negations
185    /// need to be pushed down through the tree.
186    ///
187    /// # Examples
188    ///
189    /// Count depth with context tracking current level:
190    ///
191    /// ```
192    /// use espresso_logic::{BoolExpr, ExprNode};
193    ///
194    /// let a = BoolExpr::variable("a");
195    /// let b = BoolExpr::variable("b");
196    /// let expr = a.and(&b).not();
197    ///
198    /// // Count depth with context tracking current level
199    /// let max_depth = expr.fold_with_context(0, |node, depth, recurse_left, recurse_right| {
200    ///     match node {
201    ///         ExprNode::Variable(_) | ExprNode::Constant(_) => depth,
202    ///         ExprNode::Not(_) => recurse_left(depth + 1),
203    ///         ExprNode::And(_, _) | ExprNode::Or(_, _) => {
204    ///             let left_depth = recurse_left(depth + 1);
205    ///             let right_depth = recurse_right(depth + 1);
206    ///             left_depth.max(right_depth)
207    ///         }
208    ///     }
209    /// });
210    /// ```
211    ///
212    /// Apply De Morgan's laws to push negations down:
213    ///
214    /// ```
215    /// use espresso_logic::{BoolExpr, ExprNode};
216    /// use std::collections::BTreeMap;
217    /// use std::sync::Arc;
218    ///
219    /// fn to_dnf_naive(expr: &BoolExpr) -> Vec<BTreeMap<Arc<str>, bool>> {
220    ///     expr.fold_with_context(false, |node, negate, recurse_left, recurse_right| {
221    ///         match node {
222    ///             ExprNode::Variable(name) => {
223    ///                 let mut cube = BTreeMap::new();
224    ///                 cube.insert(Arc::from(name), !negate);
225    ///                 vec![cube]
226    ///             }
227    ///             ExprNode::Not(()) => recurse_left(!negate), // Flip negation
228    ///             ExprNode::And((), ()) if negate => {
229    ///                 // De Morgan: ~(A * B) = ~A + ~B
230    ///                 let mut result = recurse_left(true);
231    ///                 result.extend(recurse_right(true));
232    ///                 result
233    ///             }
234    ///             ExprNode::Or((), ()) if negate => {
235    ///                 // De Morgan: ~(A + B) = ~A * ~B (cross product)
236    ///                 vec![] // Simplified for example
237    ///             }
238    ///             _ => vec![] // Other cases omitted
239    ///         }
240    ///     })
241    /// }
242    /// ```
243    ///
244    /// [`fold`]: BoolExpr::fold
245    pub fn fold_with_context<C, T, F>(&self, context: C, f: F) -> T
246    where
247        C: Copy,
248        F: Fn(
249                ExprNode<()>,
250                C,
251                &dyn Fn(C) -> T, // recurse_left/inner
252                &dyn Fn(C) -> T, // recurse_right
253            ) -> T
254            + Copy,
255    {
256        self.fold_with_context_impl(context, &f)
257    }
258
259    fn fold_with_context_impl<C, T, F>(&self, context: C, f: &F) -> T
260    where
261        C: Copy,
262        F: Fn(ExprNode<()>, C, &dyn Fn(C) -> T, &dyn Fn(C) -> T) -> T,
263    {
264        let ast = self.get_or_create_ast();
265        Self::fold_with_context_ast(&ast, context, f)
266    }
267
268    /// Fold with context over an AST (helper for fold_with_context_impl)
269    fn fold_with_context_ast<C, T, F>(ast: &BoolExprAst, context: C, f: &F) -> T
270    where
271        C: Copy,
272        F: Fn(ExprNode<()>, C, &dyn Fn(C) -> T, &dyn Fn(C) -> T) -> T,
273    {
274        match ast {
275            BoolExprAst::Variable(name) => f(
276                ExprNode::Variable(name),
277                context,
278                &|_| unreachable!(),
279                &|_| unreachable!(),
280            ),
281            BoolExprAst::Constant(val) => f(
282                ExprNode::Constant(*val),
283                context,
284                &|_| unreachable!(),
285                &|_| unreachable!(),
286            ),
287            BoolExprAst::Not(inner) => {
288                let recurse = |ctx: C| Self::fold_with_context_ast(inner, ctx, f);
289                f(ExprNode::Not(()), context, &recurse, &|_| unreachable!())
290            }
291            BoolExprAst::And(left, right) => {
292                let recurse_left = |ctx: C| Self::fold_with_context_ast(left, ctx, f);
293                let recurse_right = |ctx: C| Self::fold_with_context_ast(right, ctx, f);
294                f(
295                    ExprNode::And((), ()),
296                    context,
297                    &recurse_left,
298                    &recurse_right,
299                )
300            }
301            BoolExprAst::Or(left, right) => {
302                let recurse_left = |ctx: C| Self::fold_with_context_ast(left, ctx, f);
303                let recurse_right = |ctx: C| Self::fold_with_context_ast(right, ctx, f);
304                f(ExprNode::Or((), ()), context, &recurse_left, &recurse_right)
305            }
306        }
307    }
308}