espresso_logic/expression/ast.rs
1//! AST representation and tree traversal operations
2//!
3//! This module contains the AST types and fold operations for boolean expressions.
4
5use super::BoolExpr;
6use std::sync::Arc;
7
/// Node type for expression tree folding
///
/// Describes the shape of a single expression-tree node without exposing the
/// internal `Arc`-based AST. Values of this type are handed to the callbacks of
/// [`BoolExpr::fold`] and [`BoolExpr::fold_with_context`] to traverse and transform
/// expression trees.
///
/// # Generic Parameter
///
/// - For [`BoolExpr::fold`]: `T` is the result already computed for each child node
///   (results accumulate bottom-up).
/// - For [`BoolExpr::fold_with_context`]: `T` is `()`, because child results are obtained
///   by calling the recursion callbacks (context flows top-down) rather than being stored
///   in the node.
///
/// `ExprNode` is `Copy` and `Hash` whenever `T` is, so, for example, an `ExprNode<()>`
/// can be inspected or matched more than once without cloning.
///
/// # Examples
///
/// See [`BoolExpr::fold`] and [`BoolExpr::fold_with_context`] for detailed usage examples.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExprNode<'a, T> {
    /// A variable with the given name
    Variable(&'a str),
    /// Logical AND, carrying the results for the left and right subtrees
    And(T, T),
    /// Logical OR, carrying the results for the left and right subtrees
    Or(T, T),
    /// Logical NOT, carrying the result for the single inner subtree
    Not(T),
    /// A constant boolean value
    Constant(bool),
}
38
/// Tree-shaped (AST) form of a boolean expression.
///
/// Child nodes are stored as `Arc<BoolExprAst>` rather than as `BoolExpr`, so this type is a
/// plain syntax tree with no link back to the BDD layer. Keeping the two apart is what lets
/// an AST be rebuilt from a BDD without a circular dependency between the representations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum BoolExprAst {
    /// Leaf: a named variable.
    Variable(Arc<str>),
    /// Conjunction (AND) of a left and a right operand.
    And(Arc<Self>, Arc<Self>),
    /// Disjunction (OR) of a left and a right operand.
    Or(Arc<Self>, Arc<Self>),
    /// Negation (NOT) of a single operand.
    Not(Arc<Self>),
    /// Leaf: the constant `true` or `false`.
    Constant(bool),
}
56
57impl BoolExpr {
58 /// Get or create the AST representation from the BDD
59 ///
60 /// This is called internally when AST is needed (for display, fold, etc.)
61 ///
62 /// Uses factorisation-based reconstruction for beautiful, compact expressions.
63 pub(super) fn get_or_create_ast(&self) -> Arc<BoolExprAst> {
64 // Check if we have a cached AST
65 if let Some(ast) = self.ast_cache.get() {
66 return Arc::clone(ast);
67 }
68
69 // Need to reconstruct from BDD using factorisation
70 let ast = self.to_ast_optimised();
71
72 // Try to store it (may fail if another thread beat us to it, that's fine)
73 let _ = self.ast_cache.set(Arc::clone(&ast));
74
75 ast
76 }
77
78 /// Convert BDD to optimised AST representation using factorisation
79 ///
80 /// Extracts cubes from the BDD and applies algebraic factorisation to produce
81 /// a compact, readable expression.
82 ///
83 /// Uses local caching for both DNF and AST.
84 pub(super) fn to_ast_optimised(&self) -> Arc<BoolExprAst> {
85 // Check local AST cache first
86 if let Some(ast) = self.ast_cache.get() {
87 return Arc::clone(ast);
88 }
89
90 // AST not cached, get DNF and factorise
91 let dnf = self.get_or_create_dnf();
92
93 // Convert DNF cubes to the format expected by factorisation
94 let cube_terms: Vec<(std::collections::BTreeMap<Arc<str>, bool>, bool)> = dnf
95 .cubes()
96 .iter()
97 .map(|cube| (cube.clone(), true))
98 .collect();
99
100 // Use factorisation to build a nice AST
101 let ast = crate::expression::factorization::factorise_cubes_to_ast(cube_terms);
102
103 // Cache locally
104 let _ = self.ast_cache.set(Arc::clone(&ast));
105
106 ast
107 }
108
109 /// Fold the expression tree depth-first from leaves to root
110 ///
111 /// This method traverses the expression tree recursively, calling the provided
112 /// function `f` on each node. The function receives an [`ExprNode`] containing
113 /// the node type and accumulated results from child nodes.
114 ///
115 /// This is useful for implementing custom expression transformations and analyses
116 /// without needing access to private expression internals.
117 ///
118 /// # Examples
119 ///
120 /// Count the number of operations in an expression:
121 ///
122 /// ```
123 /// use espresso_logic::{BoolExpr, ExprNode};
124 ///
125 /// let a = BoolExpr::variable("a");
126 /// let b = BoolExpr::variable("b");
127 /// let expr = a.and(&b);
128 ///
129 /// let op_count = expr.fold(|node| match node {
130 /// ExprNode::Variable(_) | ExprNode::Constant(_) => 0,
131 /// ExprNode::And(l, r) | ExprNode::Or(l, r) => l + r + 1,
132 /// ExprNode::Not(inner) => inner + 1,
133 /// });
134 ///
135 /// assert_eq!(op_count, 1); // Just AND
136 /// ```
137 pub fn fold<T, F>(&self, f: F) -> T
138 where
139 F: Fn(ExprNode<T>) -> T + Copy,
140 {
141 self.fold_impl(&f)
142 }
143
144 fn fold_impl<T, F>(&self, f: &F) -> T
145 where
146 F: Fn(ExprNode<T>) -> T,
147 {
148 let ast = self.get_or_create_ast();
149 Self::fold_ast(&ast, f)
150 }
151
152 /// Fold over an AST (helper for fold_impl)
153 fn fold_ast<T, F>(ast: &BoolExprAst, f: &F) -> T
154 where
155 F: Fn(ExprNode<T>) -> T,
156 {
157 match ast {
158 BoolExprAst::Variable(name) => f(ExprNode::Variable(name)),
159 BoolExprAst::And(left, right) => {
160 let left_result = Self::fold_ast(left, f);
161 let right_result = Self::fold_ast(right, f);
162 f(ExprNode::And(left_result, right_result))
163 }
164 BoolExprAst::Or(left, right) => {
165 let left_result = Self::fold_ast(left, f);
166 let right_result = Self::fold_ast(right, f);
167 f(ExprNode::Or(left_result, right_result))
168 }
169 BoolExprAst::Not(inner) => {
170 let inner_result = Self::fold_ast(inner, f);
171 f(ExprNode::Not(inner_result))
172 }
173 BoolExprAst::Constant(val) => f(ExprNode::Constant(*val)),
174 }
175 }
176
177 /// Fold with context parameter passed top-down through the tree
178 ///
179 /// Unlike [`fold`], which passes results bottom-up from children to parents,
180 /// this method passes a context parameter top-down from parents to children.
181 /// The function `f` receives the current node type, context from parent,
182 /// and closures to recursively process children with modified context.
183 ///
184 /// This is useful for operations like applying De Morgan's laws where negations
185 /// need to be pushed down through the tree.
186 ///
187 /// # Examples
188 ///
189 /// Count depth with context tracking current level:
190 ///
191 /// ```
192 /// use espresso_logic::{BoolExpr, ExprNode};
193 ///
194 /// let a = BoolExpr::variable("a");
195 /// let b = BoolExpr::variable("b");
196 /// let expr = a.and(&b).not();
197 ///
198 /// // Count depth with context tracking current level
199 /// let max_depth = expr.fold_with_context(0, |node, depth, recurse_left, recurse_right| {
200 /// match node {
201 /// ExprNode::Variable(_) | ExprNode::Constant(_) => depth,
202 /// ExprNode::Not(_) => recurse_left(depth + 1),
203 /// ExprNode::And(_, _) | ExprNode::Or(_, _) => {
204 /// let left_depth = recurse_left(depth + 1);
205 /// let right_depth = recurse_right(depth + 1);
206 /// left_depth.max(right_depth)
207 /// }
208 /// }
209 /// });
210 /// ```
211 ///
212 /// Apply De Morgan's laws to push negations down:
213 ///
214 /// ```
215 /// use espresso_logic::{BoolExpr, ExprNode};
216 /// use std::collections::BTreeMap;
217 /// use std::sync::Arc;
218 ///
219 /// fn to_dnf_naive(expr: &BoolExpr) -> Vec<BTreeMap<Arc<str>, bool>> {
220 /// expr.fold_with_context(false, |node, negate, recurse_left, recurse_right| {
221 /// match node {
222 /// ExprNode::Variable(name) => {
223 /// let mut cube = BTreeMap::new();
224 /// cube.insert(Arc::from(name), !negate);
225 /// vec![cube]
226 /// }
227 /// ExprNode::Not(()) => recurse_left(!negate), // Flip negation
228 /// ExprNode::And((), ()) if negate => {
229 /// // De Morgan: ~(A * B) = ~A + ~B
230 /// let mut result = recurse_left(true);
231 /// result.extend(recurse_right(true));
232 /// result
233 /// }
234 /// ExprNode::Or((), ()) if negate => {
235 /// // De Morgan: ~(A + B) = ~A * ~B (cross product)
236 /// vec![] // Simplified for example
237 /// }
238 /// _ => vec![] // Other cases omitted
239 /// }
240 /// })
241 /// }
242 /// ```
243 ///
244 /// [`fold`]: BoolExpr::fold
245 pub fn fold_with_context<C, T, F>(&self, context: C, f: F) -> T
246 where
247 C: Copy,
248 F: Fn(
249 ExprNode<()>,
250 C,
251 &dyn Fn(C) -> T, // recurse_left/inner
252 &dyn Fn(C) -> T, // recurse_right
253 ) -> T
254 + Copy,
255 {
256 self.fold_with_context_impl(context, &f)
257 }
258
259 fn fold_with_context_impl<C, T, F>(&self, context: C, f: &F) -> T
260 where
261 C: Copy,
262 F: Fn(ExprNode<()>, C, &dyn Fn(C) -> T, &dyn Fn(C) -> T) -> T,
263 {
264 let ast = self.get_or_create_ast();
265 Self::fold_with_context_ast(&ast, context, f)
266 }
267
268 /// Fold with context over an AST (helper for fold_with_context_impl)
269 fn fold_with_context_ast<C, T, F>(ast: &BoolExprAst, context: C, f: &F) -> T
270 where
271 C: Copy,
272 F: Fn(ExprNode<()>, C, &dyn Fn(C) -> T, &dyn Fn(C) -> T) -> T,
273 {
274 match ast {
275 BoolExprAst::Variable(name) => f(
276 ExprNode::Variable(name),
277 context,
278 &|_| unreachable!(),
279 &|_| unreachable!(),
280 ),
281 BoolExprAst::Constant(val) => f(
282 ExprNode::Constant(*val),
283 context,
284 &|_| unreachable!(),
285 &|_| unreachable!(),
286 ),
287 BoolExprAst::Not(inner) => {
288 let recurse = |ctx: C| Self::fold_with_context_ast(inner, ctx, f);
289 f(ExprNode::Not(()), context, &recurse, &|_| unreachable!())
290 }
291 BoolExprAst::And(left, right) => {
292 let recurse_left = |ctx: C| Self::fold_with_context_ast(left, ctx, f);
293 let recurse_right = |ctx: C| Self::fold_with_context_ast(right, ctx, f);
294 f(
295 ExprNode::And((), ()),
296 context,
297 &recurse_left,
298 &recurse_right,
299 )
300 }
301 BoolExprAst::Or(left, right) => {
302 let recurse_left = |ctx: C| Self::fold_with_context_ast(left, ctx, f);
303 let recurse_right = |ctx: C| Self::fold_with_context_ast(right, ctx, f);
304 f(ExprNode::Or((), ()), context, &recurse_left, &recurse_right)
305 }
306 }
307 }
308}