Skip to main content

lemma/inversion/
world.rs

1//! World enumeration for inversion
2//!
3//! A "world" is a complete assignment of which branch is active for each rule.
4//! This module enumerates all valid worlds for a target rule.
5//!
6//! Also includes expression substitution and hydration utilities.
7
8use crate::planning::semantics::{
9    ArithmeticComputation, ComparisonComputation, Expression, ExpressionKind, FactPath,
10    LiteralValue, MathematicalComputation, NegationType, RulePath, SemanticConversionTarget,
11    Source,
12};
13use crate::planning::{ExecutableRule, ExecutionPlan};
14use crate::{LemmaResult, OperationResult};
15use serde::ser::{Serialize, SerializeMap, Serializer};
16use std::collections::{HashMap, HashSet, VecDeque};
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
21/// A world assigns each rule to one of its branch indices
22#[derive(Debug, Clone, Default, PartialEq, Eq)]
23pub struct World(HashMap<RulePath, usize>);
24
25impl World {
26    /// Create a new empty world
27    pub fn new() -> Self {
28        Self(HashMap::new())
29    }
30
31    /// Get the branch index for a rule
32    pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
33        self.0.get(rule_path)
34    }
35
36    /// Insert a branch assignment for a rule
37    pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
38        self.0.insert(rule_path, branch_idx)
39    }
40
41    /// Iterate over all branch assignments
42    pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
43        self.0.iter()
44    }
45}
46
47impl Serialize for World {
48    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49    where
50        S: Serializer,
51    {
52        let mut map = serializer.serialize_map(Some(self.0.len()))?;
53        for (k, v) in &self.0 {
54            map.serialize_entry(&k.to_string(), v)?;
55        }
56        map.end()
57    }
58}
59
60/// A solution from world enumeration with a resolved outcome
61#[derive(Debug, Clone)]
62pub struct WorldSolution {
63    /// The world (branch assignment) that produced this solution
64    pub world: World,
65    /// The constraint under which this solution applies (facts only, no rule references)
66    pub constraint: Constraint,
67    /// The outcome (value or veto)
68    pub outcome: OperationResult,
69}
70
71/// A solution from world enumeration with an arithmetic expression outcome
72///
73/// This represents cases where the outcome is a computed expression (like `price * 5`)
74/// that couldn't be evaluated to a literal because it contains unknown facts.
75/// These need algebraic solving to determine the input values.
76#[derive(Debug, Clone)]
77pub struct WorldArithmeticSolution {
78    /// The world (branch assignment) that produced this solution
79    pub world: World,
80    /// The constraint under which this solution applies (facts only, no rule references)
81    pub constraint: Constraint,
82    /// The outcome expression (contains unknown facts)
83    pub outcome_expression: Expression,
84}
85
86/// Result of world enumeration containing both literal and arithmetic solutions
87#[derive(Debug, Clone)]
88pub struct EnumerationResult {
89    /// Solutions with literal outcomes (can be directly compared to target)
90    pub literal_solutions: Vec<WorldSolution>,
91    /// Solutions with arithmetic outcomes (need algebraic solving)
92    pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
93}
94
95/// Enumerates valid worlds for a target rule
96pub struct WorldEnumerator<'a> {
97    plan: &'a ExecutionPlan,
98    /// Rules to process, in topological order (dependencies first)
99    rules_in_order: Vec<RulePath>,
100    /// Cache: rule path -> executable rule (for quick lookup)
101    rule_cache: HashMap<RulePath, &'a ExecutableRule>,
102}
103
104impl<'a> WorldEnumerator<'a> {
105    /// Create a new world enumerator for a target rule
106    pub fn new(plan: &'a ExecutionPlan, target_rule: &RulePath) -> LemmaResult<Self> {
107        // Build rule lookup from execution plan
108        let rule_map: HashMap<RulePath, &ExecutableRule> =
109            plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
110
111        // Find all rules that the target rule depends on (transitively)
112        let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
113
114        // Plan rules are already in topological order, so filter and preserve order
115        let rules_in_order: Vec<RulePath> = plan
116            .rules
117            .iter()
118            .filter(|r| dependent_rules.contains(&r.path))
119            .map(|r| r.path.clone())
120            .collect();
121
122        // Build rule cache for quick lookup (only rules we need)
123        let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
124            .iter()
125            .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
126            .collect();
127
128        Ok(Self {
129            plan,
130            rules_in_order,
131            rule_cache,
132        })
133    }
134
135    /// Enumerate all valid worlds for the target rule
136    ///
137    /// Returns an `EnumerationResult` containing:
138    /// - `literal_solutions`: Worlds where the outcome is a concrete literal value
139    /// - `arithmetic_solutions`: Worlds where the outcome is an arithmetic expression
140    ///   containing unknown facts (needs algebraic solving)
141    pub fn enumerate(
142        &mut self,
143        provided_facts: &HashSet<FactPath>,
144    ) -> LemmaResult<EnumerationResult> {
145        if self.rules_in_order.is_empty() {
146            return Ok(EnumerationResult {
147                literal_solutions: vec![],
148                arithmetic_solutions: vec![],
149            });
150        }
151
152        // Start with a single empty world and true constraint
153        let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
154
155        // Process each rule in topological order
156        for rule_path in &self.rules_in_order.clone() {
157            let rule_node = match self.rule_cache.get(rule_path) {
158                Some(node) => *node,
159                None => continue,
160            };
161
162            let mut next_worlds = Vec::new();
163
164            for (world, accumulated_constraint) in current_worlds {
165                // For each branch in this rule
166                for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
167                    // Create new world with this branch assignment
168                    let mut new_world = world.clone();
169                    new_world.insert(rule_path.clone(), branch_idx);
170
171                    // Get branch constraint
172                    // For "last wins" semantics: all LATER branches must have FALSE conditions
173                    let mut branch_constraint = if let Some(ref condition) = branch.condition {
174                        // This branch's condition must be TRUE
175                        let substituted_condition = substitute_rules_in_expression(
176                            &Arc::new(condition.clone()),
177                            &new_world,
178                            self.plan,
179                        )?;
180                        let hydrated_condition = hydrate_facts_in_expression(
181                            &Arc::new(substituted_condition),
182                            self.plan,
183                            provided_facts,
184                        )?;
185                        Constraint::from_expression(&hydrated_condition)?
186                    } else {
187                        // Default branch has no explicit condition
188                        Constraint::True
189                    };
190
191                    // For "last wins": all LATER branches must NOT match
192                    // (their conditions must be FALSE)
193                    for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
194                        if let Some(ref later_condition) = later_branch.condition {
195                            let substituted_later = substitute_rules_in_expression(
196                                &Arc::new(later_condition.clone()),
197                                &new_world,
198                                self.plan,
199                            )?;
200                            let hydrated_later = hydrate_facts_in_expression(
201                                &Arc::new(substituted_later),
202                                self.plan,
203                                provided_facts,
204                            )?;
205                            let later_constraint = Constraint::from_expression(&hydrated_later)?;
206                            // Later branch's condition must be FALSE
207                            branch_constraint = branch_constraint.and(later_constraint.not());
208                        }
209                    }
210
211                    // Combine with accumulated constraint
212                    let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
213
214                    // Simplify and check if satisfiable
215                    let simplified = combined_constraint.simplify()?;
216
217                    // Only keep if not contradictory
218                    if !simplified.is_false() {
219                        next_worlds.push((new_world, simplified));
220                    }
221                }
222            }
223
224            current_worlds = next_worlds;
225
226            // Early exit if no valid worlds remain
227            if current_worlds.is_empty() {
228                break;
229            }
230        }
231
232        // Convert to WorldSolutions and WorldArithmeticSolutions
233        let target_rule_path = self
234            .rules_in_order
235            .last()
236            .unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
237
238        let mut literal_solutions = Vec::new();
239        let mut arithmetic_solutions = Vec::new();
240
241        for (world, constraint) in current_worlds {
242            // Get the outcome from the target rule's branch
243            if let Some(&branch_idx) = world.get(target_rule_path) {
244                if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
245                    if branch_idx < rule_node.branches.len() {
246                        let branch = &rule_node.branches[branch_idx];
247
248                        // Substitute and hydrate the result expression
249                        let substituted_result = substitute_rules_in_expression(
250                            &Arc::new(branch.result.clone()),
251                            &world,
252                            self.plan,
253                        )?;
254
255                        let hydrated_result = hydrate_facts_in_expression(
256                            &Arc::new(substituted_result),
257                            self.plan,
258                            provided_facts,
259                        )?;
260
261                        // Try to fold the result to a literal
262                        let folded_result = try_constant_fold_expression(&hydrated_result)
263                            .unwrap_or(hydrated_result.clone());
264
265                        // Try to extract a literal value directly
266                        if let Some(outcome) = extract_outcome(&folded_result) {
267                            literal_solutions.push(WorldSolution {
268                                world,
269                                constraint,
270                                outcome,
271                            });
272                        } else if is_boolean_expression(&folded_result) {
273                            // For boolean expressions (comparisons, logical ops), create two solutions:
274                            // one where the expression is true, one where it's false
275                            let (true_solutions, false_solutions) =
276                                create_boolean_expression_solutions(
277                                    world,
278                                    constraint,
279                                    &folded_result,
280                                )?;
281                            literal_solutions.extend(true_solutions);
282                            literal_solutions.extend(false_solutions);
283                        } else if is_arithmetic_expression(&folded_result) {
284                            // Arithmetic expression with unknown facts - needs algebraic solving
285                            arithmetic_solutions.push(WorldArithmeticSolution {
286                                world,
287                                constraint,
288                                outcome_expression: folded_result,
289                            });
290                        }
291                        // Other expression types (rule references, etc.) are silently skipped
292                        // as they indicate incomplete substitution
293                    }
294                }
295            }
296        }
297
298        Ok(EnumerationResult {
299            literal_solutions,
300            arithmetic_solutions,
301        })
302    }
303}
304
305// ============================================================================
306// Dependency and topological sorting
307// ============================================================================
308
309/// Collect all rules that a target rule depends on (transitively)
310fn collect_transitive_dependencies(
311    target_rule: &RulePath,
312    rule_map: &HashMap<RulePath, &ExecutableRule>,
313) -> LemmaResult<HashSet<RulePath>> {
314    let mut result = HashSet::new();
315    let mut queue = VecDeque::new();
316
317    queue.push_back(target_rule.clone());
318    result.insert(target_rule.clone());
319
320    while let Some(rule_path) = queue.pop_front() {
321        if let Some(rule) = rule_map.get(&rule_path) {
322            // Extract rule dependencies from branch expressions
323            let dependencies = extract_rule_dependencies(rule);
324            for dependency in dependencies {
325                if result.insert(dependency.clone()) {
326                    queue.push_back(dependency);
327                }
328            }
329        }
330    }
331
332    Ok(result)
333}
334
335/// Extract rule paths referenced in an executable rule's expressions
336fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
337    let mut deps = HashSet::new();
338    for branch in &rule.branches {
339        if let Some(ref condition) = branch.condition {
340            extract_rule_paths_from_expression(condition, &mut deps);
341        }
342        extract_rule_paths_from_expression(&branch.result, &mut deps);
343    }
344    deps
345}
346
347/// Recursively extract RulePath references from an expression
348fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
349    match &expr.kind {
350        ExpressionKind::RulePath(rp) => {
351            paths.insert(rp.clone());
352        }
353        ExpressionKind::Arithmetic(left, _, right)
354        | ExpressionKind::Comparison(left, _, right)
355        | ExpressionKind::LogicalAnd(left, right)
356        | ExpressionKind::LogicalOr(left, right) => {
357            extract_rule_paths_from_expression(left, paths);
358            extract_rule_paths_from_expression(right, paths);
359        }
360        ExpressionKind::LogicalNegation(inner, _)
361        | ExpressionKind::UnitConversion(inner, _)
362        | ExpressionKind::MathematicalComputation(_, inner) => {
363            extract_rule_paths_from_expression(inner, paths);
364        }
365        ExpressionKind::Literal(_) | ExpressionKind::FactPath(_) | ExpressionKind::Veto(_) => {}
366    }
367}
368
369// ============================================================================
370// Expression substitution
371// ============================================================================
372
373/// Substitute rule references in an expression with their values in a given world
374///
375/// For each RulePath in the expression, looks up which branch is active in the world
376/// and replaces the RulePath with the branch's result expression.
377fn substitute_rules_in_expression(
378    expr: &Arc<Expression>,
379    world: &World,
380    plan: &ExecutionPlan,
381) -> LemmaResult<Expression> {
382    enum WorkItem {
383        Process(usize),
384        BuildArithmetic(ArithmeticComputation, Option<Source>),
385        BuildComparison(ComparisonComputation, Option<Source>),
386        BuildLogicalAnd(Option<Source>),
387        BuildLogicalOr(Option<Source>),
388        BuildLogicalNegation(NegationType, Option<Source>),
389        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
390        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
391        PopVisitedRules,
392    }
393
394    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
395    let mut work_stack: Vec<WorkItem> = Vec::new();
396    let mut result_pool: Vec<Expression> = Vec::new();
397    let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];
398
399    let root_idx = expr_pool.len();
400    expr_pool.push(Arc::clone(expr));
401    work_stack.push(WorkItem::Process(root_idx));
402
403    while let Some(work) = work_stack.pop() {
404        match work {
405            WorkItem::Process(expr_idx) => {
406                let e = &expr_pool[expr_idx];
407                let source_loc = e.source_location.clone();
408
409                match &e.kind {
410                    ExpressionKind::RulePath(rule_path) => {
411                        let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
412                        if visited.contains(rule_path) {
413                            unreachable!(
414                                "BUG: circular rule reference detected during substitution: {}",
415                                rule_path
416                            );
417                        }
418
419                        if let Some(&branch_idx) = world.get(rule_path) {
420                            if let Some(rule) = plan.get_rule_by_path(rule_path) {
421                                if branch_idx < rule.branches.len() {
422                                    let branch = &rule.branches[branch_idx];
423                                    let mut new_visited = visited.clone();
424                                    new_visited.insert(rule_path.clone());
425                                    visited_rules_stack.push(new_visited);
426
427                                    let sub_expr_idx = expr_pool.len();
428                                    expr_pool.push(Arc::new(branch.result.clone()));
429                                    work_stack.push(WorkItem::PopVisitedRules);
430                                    work_stack.push(WorkItem::Process(sub_expr_idx));
431                                    continue;
432                                }
433                            }
434                        }
435                        result_pool.push(Expression::with_source(
436                            ExpressionKind::RulePath(rule_path.clone()),
437                            source_loc,
438                        ));
439                    }
440                    ExpressionKind::Arithmetic(left, op, right) => {
441                        let op_clone = op.clone();
442                        let left_arc = Arc::clone(left);
443                        let right_arc = Arc::clone(right);
444
445                        let left_idx = expr_pool.len();
446                        expr_pool.push(left_arc);
447                        let right_idx = expr_pool.len();
448                        expr_pool.push(right_arc);
449
450                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
451                        work_stack.push(WorkItem::Process(right_idx));
452                        work_stack.push(WorkItem::Process(left_idx));
453                    }
454                    ExpressionKind::Comparison(left, op, right) => {
455                        let op_clone = op.clone();
456                        let left_arc = Arc::clone(left);
457                        let right_arc = Arc::clone(right);
458
459                        let left_idx = expr_pool.len();
460                        expr_pool.push(left_arc);
461                        let right_idx = expr_pool.len();
462                        expr_pool.push(right_arc);
463
464                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
465                        work_stack.push(WorkItem::Process(right_idx));
466                        work_stack.push(WorkItem::Process(left_idx));
467                    }
468                    ExpressionKind::LogicalAnd(left, right) => {
469                        let left_arc = Arc::clone(left);
470                        let right_arc = Arc::clone(right);
471
472                        let left_idx = expr_pool.len();
473                        expr_pool.push(left_arc);
474                        let right_idx = expr_pool.len();
475                        expr_pool.push(right_arc);
476
477                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
478                        work_stack.push(WorkItem::Process(right_idx));
479                        work_stack.push(WorkItem::Process(left_idx));
480                    }
481                    ExpressionKind::LogicalOr(left, right) => {
482                        let left_arc = Arc::clone(left);
483                        let right_arc = Arc::clone(right);
484
485                        let left_idx = expr_pool.len();
486                        expr_pool.push(left_arc);
487                        let right_idx = expr_pool.len();
488                        expr_pool.push(right_arc);
489
490                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
491                        work_stack.push(WorkItem::Process(right_idx));
492                        work_stack.push(WorkItem::Process(left_idx));
493                    }
494                    ExpressionKind::LogicalNegation(inner, neg_type) => {
495                        let neg_type_clone = neg_type.clone();
496                        let inner_arc = Arc::clone(inner);
497                        let inner_idx = expr_pool.len();
498                        expr_pool.push(inner_arc);
499                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
500                        work_stack.push(WorkItem::Process(inner_idx));
501                    }
502                    ExpressionKind::UnitConversion(inner, unit) => {
503                        let unit_clone = unit.clone();
504                        let inner_arc = Arc::clone(inner);
505                        let inner_idx = expr_pool.len();
506                        expr_pool.push(inner_arc);
507                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
508                        work_stack.push(WorkItem::Process(inner_idx));
509                    }
510                    ExpressionKind::MathematicalComputation(func, inner) => {
511                        let func_clone = func.clone();
512                        let inner_arc = Arc::clone(inner);
513                        let inner_idx = expr_pool.len();
514                        expr_pool.push(inner_arc);
515                        work_stack.push(WorkItem::BuildMathematicalComputation(
516                            func_clone, source_loc,
517                        ));
518                        work_stack.push(WorkItem::Process(inner_idx));
519                    }
520                    ExpressionKind::Literal(lit) => {
521                        result_pool.push(Expression::with_source(
522                            ExpressionKind::Literal(lit.clone()),
523                            source_loc,
524                        ));
525                    }
526                    ExpressionKind::FactPath(fact_path) => {
527                        result_pool.push(Expression::with_source(
528                            ExpressionKind::FactPath(fact_path.clone()),
529                            source_loc,
530                        ));
531                    }
532                    ExpressionKind::Veto(veto) => {
533                        result_pool.push(Expression::with_source(
534                            ExpressionKind::Veto(veto.clone()),
535                            source_loc,
536                        ));
537                    }
538                }
539            }
540            WorkItem::BuildArithmetic(op, source_loc) => {
541                let right = result_pool.pop().unwrap_or_else(|| {
542                    unreachable!(
543                        "BUG: missing right expression for Arithmetic during inversion hydration"
544                    )
545                });
546                let left = result_pool.pop().unwrap_or_else(|| {
547                    unreachable!(
548                        "BUG: missing left expression for Arithmetic during inversion hydration"
549                    )
550                });
551                result_pool.push(Expression::with_source(
552                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
553                    source_loc,
554                ));
555            }
556            WorkItem::BuildComparison(op, source_loc) => {
557                let right = result_pool.pop().unwrap_or_else(|| {
558                    unreachable!(
559                        "BUG: missing right expression for Comparison during inversion hydration"
560                    )
561                });
562                let left = result_pool.pop().unwrap_or_else(|| {
563                    unreachable!(
564                        "BUG: missing left expression for Comparison during inversion hydration"
565                    )
566                });
567                result_pool.push(Expression::with_source(
568                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
569                    source_loc,
570                ));
571            }
572            WorkItem::BuildLogicalAnd(source_loc) => {
573                let right = result_pool.pop().unwrap_or_else(|| {
574                    unreachable!(
575                        "BUG: missing right expression for LogicalAnd during inversion hydration"
576                    )
577                });
578                let left = result_pool.pop().unwrap_or_else(|| {
579                    unreachable!(
580                        "BUG: missing left expression for LogicalAnd during inversion hydration"
581                    )
582                });
583                result_pool.push(Expression::with_source(
584                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
585                    source_loc,
586                ));
587            }
588            WorkItem::BuildLogicalOr(source_loc) => {
589                let right = result_pool.pop().unwrap_or_else(|| {
590                    unreachable!(
591                        "BUG: missing right expression for LogicalOr during inversion hydration"
592                    )
593                });
594                let left = result_pool.pop().unwrap_or_else(|| {
595                    unreachable!(
596                        "BUG: missing left expression for LogicalOr during inversion hydration"
597                    )
598                });
599                result_pool.push(Expression::with_source(
600                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
601                    source_loc,
602                ));
603            }
604            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
605                let inner = result_pool
606                    .pop()
607                    .expect("Internal error: missing expression for LogicalNegation");
608                result_pool.push(Expression::with_source(
609                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
610                    source_loc,
611                ));
612            }
613            WorkItem::BuildUnitConversion(unit, source_loc) => {
614                let inner = result_pool
615                    .pop()
616                    .expect("Internal error: missing expression for UnitConversion");
617                result_pool.push(Expression::with_source(
618                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
619                    source_loc,
620                ));
621            }
622            WorkItem::BuildMathematicalComputation(func, source_loc) => {
623                let inner = result_pool
624                    .pop()
625                    .expect("Internal error: missing expression for MathematicalComputation");
626                result_pool.push(Expression::with_source(
627                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
628                    source_loc,
629                ));
630            }
631            WorkItem::PopVisitedRules => {
632                visited_rules_stack.pop();
633            }
634        }
635    }
636
637    Ok(result_pool
638        .pop()
639        .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
640}
641
642// ============================================================================
643// Fact hydration
644// ============================================================================
645
646/// Hydrate fact references in an expression with their known values
647///
648/// For each FactPath in the expression, if the fact is in provided_facts,
649/// replaces the FactPath with a Literal containing the fact's value.
650fn hydrate_facts_in_expression(
651    expr: &Arc<Expression>,
652    plan: &ExecutionPlan,
653    provided_facts: &HashSet<FactPath>,
654) -> LemmaResult<Expression> {
655    enum WorkItem {
656        Process(usize),
657        BuildArithmetic(ArithmeticComputation, Option<Source>),
658        BuildComparison(ComparisonComputation, Option<Source>),
659        BuildLogicalAnd(Option<Source>),
660        BuildLogicalOr(Option<Source>),
661        BuildLogicalNegation(NegationType, Option<Source>),
662        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
663        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
664    }
665
666    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
667    let mut work_stack: Vec<WorkItem> = Vec::new();
668    let mut result_pool: Vec<Expression> = Vec::new();
669
670    let root_idx = expr_pool.len();
671    expr_pool.push(Arc::clone(expr));
672    work_stack.push(WorkItem::Process(root_idx));
673
674    while let Some(work) = work_stack.pop() {
675        match work {
676            WorkItem::Process(expr_idx) => {
677                let (source_loc, expr_kind_ref) = {
678                    let e = &expr_pool[expr_idx];
679                    (e.source_location.clone(), &e.kind)
680                };
681
682                match expr_kind_ref {
683                    ExpressionKind::FactPath(fact_path) => {
684                        if provided_facts.contains(fact_path) {
685                            if let Some(lit) = plan.facts.get(fact_path).and_then(|d| d.value()) {
686                                result_pool.push(Expression::with_source(
687                                    ExpressionKind::Literal(Box::new(lit.clone())),
688                                    source_loc,
689                                ));
690                                continue;
691                            }
692                        }
693                        result_pool.push(Expression::with_source(
694                            ExpressionKind::FactPath(fact_path.clone()),
695                            source_loc,
696                        ));
697                    }
698                    ExpressionKind::Arithmetic(left, op, right) => {
699                        let op_clone = op.clone();
700                        let left_arc = Arc::clone(left);
701                        let right_arc = Arc::clone(right);
702
703                        let left_idx = expr_pool.len();
704                        expr_pool.push(left_arc);
705                        let right_idx = expr_pool.len();
706                        expr_pool.push(right_arc);
707
708                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
709                        work_stack.push(WorkItem::Process(right_idx));
710                        work_stack.push(WorkItem::Process(left_idx));
711                    }
712                    ExpressionKind::Comparison(left, op, right) => {
713                        let op_clone = op.clone();
714                        let left_arc = Arc::clone(left);
715                        let right_arc = Arc::clone(right);
716
717                        let left_idx = expr_pool.len();
718                        expr_pool.push(left_arc);
719                        let right_idx = expr_pool.len();
720                        expr_pool.push(right_arc);
721
722                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
723                        work_stack.push(WorkItem::Process(right_idx));
724                        work_stack.push(WorkItem::Process(left_idx));
725                    }
726                    ExpressionKind::LogicalAnd(left, right) => {
727                        let left_arc = Arc::clone(left);
728                        let right_arc = Arc::clone(right);
729
730                        let left_idx = expr_pool.len();
731                        expr_pool.push(left_arc);
732                        let right_idx = expr_pool.len();
733                        expr_pool.push(right_arc);
734
735                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
736                        work_stack.push(WorkItem::Process(right_idx));
737                        work_stack.push(WorkItem::Process(left_idx));
738                    }
739                    ExpressionKind::LogicalOr(left, right) => {
740                        let left_arc = Arc::clone(left);
741                        let right_arc = Arc::clone(right);
742
743                        let left_idx = expr_pool.len();
744                        expr_pool.push(left_arc);
745                        let right_idx = expr_pool.len();
746                        expr_pool.push(right_arc);
747
748                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
749                        work_stack.push(WorkItem::Process(right_idx));
750                        work_stack.push(WorkItem::Process(left_idx));
751                    }
752                    ExpressionKind::LogicalNegation(inner, neg_type) => {
753                        let neg_type_clone = neg_type.clone();
754                        let inner_arc = Arc::clone(inner);
755                        let inner_idx = expr_pool.len();
756                        expr_pool.push(inner_arc);
757                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
758                        work_stack.push(WorkItem::Process(inner_idx));
759                    }
760                    ExpressionKind::UnitConversion(inner, unit) => {
761                        let unit_clone = unit.clone();
762                        let inner_arc = Arc::clone(inner);
763                        let inner_idx = expr_pool.len();
764                        expr_pool.push(inner_arc);
765                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
766                        work_stack.push(WorkItem::Process(inner_idx));
767                    }
768                    ExpressionKind::MathematicalComputation(func, inner) => {
769                        let func_clone = func.clone();
770                        let inner_arc = Arc::clone(inner);
771                        let inner_idx = expr_pool.len();
772                        expr_pool.push(inner_arc);
773                        work_stack.push(WorkItem::BuildMathematicalComputation(
774                            func_clone, source_loc,
775                        ));
776                        work_stack.push(WorkItem::Process(inner_idx));
777                    }
778                    ExpressionKind::Literal(lit) => {
779                        result_pool.push(Expression::with_source(
780                            ExpressionKind::Literal(lit.clone()),
781                            source_loc,
782                        ));
783                    }
784                    ExpressionKind::RulePath(rule_path) => {
785                        result_pool.push(Expression::with_source(
786                            ExpressionKind::RulePath(rule_path.clone()),
787                            source_loc,
788                        ));
789                    }
790                    ExpressionKind::Veto(veto) => {
791                        result_pool.push(Expression::with_source(
792                            ExpressionKind::Veto(veto.clone()),
793                            source_loc,
794                        ));
795                    }
796                }
797            }
798            WorkItem::BuildArithmetic(op, source_loc) => {
799                let right = result_pool.pop().unwrap_or_else(|| {
800                    unreachable!("BUG: missing right expression for Arithmetic")
801                });
802                let left = result_pool
803                    .pop()
804                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Arithmetic"));
805                result_pool.push(Expression::with_source(
806                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
807                    source_loc,
808                ));
809            }
810            WorkItem::BuildComparison(op, source_loc) => {
811                let right = result_pool.pop().unwrap_or_else(|| {
812                    unreachable!("BUG: missing right expression for Comparison")
813                });
814                let left = result_pool
815                    .pop()
816                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Comparison"));
817                result_pool.push(Expression::with_source(
818                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
819                    source_loc,
820                ));
821            }
822            WorkItem::BuildLogicalAnd(source_loc) => {
823                let right = result_pool.pop().unwrap_or_else(|| {
824                    unreachable!("BUG: missing right expression for LogicalAnd")
825                });
826                let left = result_pool
827                    .pop()
828                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalAnd"));
829                result_pool.push(Expression::with_source(
830                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
831                    source_loc,
832                ));
833            }
834            WorkItem::BuildLogicalOr(source_loc) => {
835                let right = result_pool
836                    .pop()
837                    .unwrap_or_else(|| unreachable!("BUG: missing right expression for LogicalOr"));
838                let left = result_pool
839                    .pop()
840                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalOr"));
841                result_pool.push(Expression::with_source(
842                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
843                    source_loc,
844                ));
845            }
846            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
847                let inner = result_pool
848                    .pop()
849                    .expect("Internal error: missing expression for LogicalNegation");
850                result_pool.push(Expression::with_source(
851                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
852                    source_loc,
853                ));
854            }
855            WorkItem::BuildUnitConversion(unit, source_loc) => {
856                let inner = result_pool
857                    .pop()
858                    .expect("Internal error: missing expression for UnitConversion");
859                result_pool.push(Expression::with_source(
860                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
861                    source_loc,
862                ));
863            }
864            WorkItem::BuildMathematicalComputation(func, source_loc) => {
865                let inner = result_pool
866                    .pop()
867                    .expect("Internal error: missing expression for MathematicalComputation");
868                result_pool.push(Expression::with_source(
869                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
870                    source_loc,
871                ));
872            }
873        }
874    }
875
876    Ok(result_pool
877        .pop()
878        .expect("Internal error: no result from hydration"))
879}
880
881// ============================================================================
882// Constant folding
883// ============================================================================
884
885/// Extract an outcome (value or veto) from an expression
886fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
887    match &expr.kind {
888        ExpressionKind::Literal(lit) => {
889            Some(OperationResult::Value(Box::new(lit.as_ref().clone())))
890        }
891        ExpressionKind::Veto(ve) => Some(OperationResult::Veto(ve.message.clone())),
892        _ => None,
893    }
894}
895
896/// Check if an expression is a boolean-producing expression (comparison or logical)
897fn is_boolean_expression(expr: &Expression) -> bool {
898    matches!(
899        &expr.kind,
900        ExpressionKind::Comparison(_, _, _)
901            | ExpressionKind::LogicalAnd(_, _)
902            | ExpressionKind::LogicalOr(_, _)
903            | ExpressionKind::LogicalNegation(_, _)
904    )
905}
906
907/// Check if an expression is an arithmetic expression (contains arithmetic operations)
908///
909/// Returns true for expressions like `price * 5`, `x + y`, etc.
910fn is_arithmetic_expression(expr: &Expression) -> bool {
911    match &expr.kind {
912        ExpressionKind::Arithmetic(_, _, _) => true,
913        ExpressionKind::MathematicalComputation(_, _) => true,
914        ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
915        ExpressionKind::FactPath(_) => true, // Lone fact is also solvable
916        _ => false,
917    }
918}
919
920/// For boolean expressions that can't be evaluated to a literal (e.g., `age > 18`),
921/// create two solutions: one where the expression is true, one where it's false.
922///
923/// This allows inversion to work with rules like `rule of_age = age > 18`
924fn create_boolean_expression_solutions(
925    world: World,
926    base_constraint: Constraint,
927    boolean_expr: &Expression,
928) -> LemmaResult<(Vec<WorldSolution>, Vec<WorldSolution>)> {
929    // Convert boolean expression to constraint
930    let expr_constraint = Constraint::from_expression(boolean_expr)?;
931
932    // Solution where the boolean expression is true
933    let true_constraint = base_constraint.clone().and(expr_constraint.clone());
934    let simplified_true = true_constraint.simplify()?;
935
936    let true_solutions = if !simplified_true.is_false() {
937        vec![WorldSolution {
938            world: world.clone(),
939            constraint: simplified_true,
940            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(true))),
941        }]
942    } else {
943        vec![]
944    };
945
946    // Solution where the boolean expression is false
947    let false_constraint = base_constraint.and(expr_constraint.not());
948    let simplified_false = false_constraint.simplify()?;
949
950    let false_solutions = if !simplified_false.is_false() {
951        vec![WorldSolution {
952            world,
953            constraint: simplified_false,
954            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(false))),
955        }]
956    } else {
957        vec![]
958    };
959
960    Ok((true_solutions, false_solutions))
961}
962
963/// Attempt constant folding on an expression (simplified version for outcomes)
964pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
965    match &expr.kind {
966        ExpressionKind::Literal(_) => Some(expr.clone()),
967        ExpressionKind::Arithmetic(left, op, right) => {
968            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
969            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
970            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
971                (&left_folded.kind, &right_folded.kind)
972            {
973                if let Some(result) = evaluate_arithmetic(left_val.as_ref(), op, right_val.as_ref())
974                {
975                    return Some(Expression::with_source(
976                        ExpressionKind::Literal(Box::new(result)),
977                        expr.source_location.clone(),
978                    ));
979                }
980            }
981            Some(Expression::with_source(
982                ExpressionKind::Arithmetic(
983                    Arc::new(left_folded),
984                    op.clone(),
985                    Arc::new(right_folded),
986                ),
987                expr.source_location.clone(),
988            ))
989        }
990        ExpressionKind::Comparison(left, op, right) => {
991            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
992            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
993            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
994                (&left_folded.kind, &right_folded.kind)
995            {
996                if let Some(result) = evaluate_comparison(left_val.as_ref(), op, right_val.as_ref())
997                {
998                    return Some(Expression::with_source(
999                        ExpressionKind::Literal(Box::new(LiteralValue::from_bool(result))),
1000                        expr.source_location.clone(),
1001                    ));
1002                }
1003            }
1004            Some(Expression::with_source(
1005                ExpressionKind::Comparison(
1006                    Arc::new(left_folded),
1007                    op.clone(),
1008                    Arc::new(right_folded),
1009                ),
1010                expr.source_location.clone(),
1011            ))
1012        }
1013        _ => None,
1014    }
1015}
1016
1017/// Evaluate an arithmetic operation on two literals
1018///
1019/// Delegates to the computation module for consistent behavior
1020fn evaluate_arithmetic(
1021    left: &LiteralValue,
1022    op: &ArithmeticComputation,
1023    right: &LiteralValue,
1024) -> Option<LiteralValue> {
1025    use crate::computation::arithmetic_operation;
1026
1027    match arithmetic_operation(left, op, right) {
1028        OperationResult::Value(lit) => Some(lit.as_ref().clone()),
1029        OperationResult::Veto(_) => None,
1030    }
1031}
1032
1033/// Evaluate a comparison operation on two literals
1034///
1035/// Delegates to the computation module for consistent behavior
1036fn evaluate_comparison(
1037    left: &LiteralValue,
1038    op: &ComparisonComputation,
1039    right: &LiteralValue,
1040) -> Option<bool> {
1041    use crate::computation::comparison_operation;
1042    use crate::planning::semantics::ValueKind;
1043
1044    match comparison_operation(left, op, right) {
1045        OperationResult::Value(lit) => match &lit.value {
1046            ValueKind::Boolean(b) => Some(*b),
1047            _ => None,
1048        },
1049        _ => None,
1050    }
1051}
1052
1053// ============================================================================
1054// Tests
1055// ============================================================================
1056
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planning::semantics::ValueKind;
    use rust_decimal::Decimal;

    /// Literal expression with no source location.
    fn literal_expr(val: LiteralValue) -> Expression {
        Expression::with_source(ExpressionKind::Literal(Box::new(val)), None)
    }

    /// Unqualified fact reference with no source location.
    fn fact_expr(name: &str) -> Expression {
        Expression::with_source(
            ExpressionKind::FactPath(FactPath::new(vec![], name.to_string())),
            None,
        )
    }

    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// `left <op> right` as an arithmetic expression with no source location.
    fn arithmetic_expr(left: Expression, op: ArithmeticComputation, right: Expression) -> Expression {
        Expression::with_source(
            ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
            None,
        )
    }

    /// `left <op> right` as a comparison expression with no source location.
    fn comparison_expr(left: Expression, op: ComparisonComputation, right: Expression) -> Expression {
        Expression::with_source(
            ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
            None,
        )
    }

    #[test]
    fn test_world_new() {
        let world = World::new();
        assert!(world.0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let mut world = World::new();
        let rule_path = RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        };
        world.insert(rule_path.clone(), 2);
        assert_eq!(world.get(&rule_path), Some(&2));
    }

    fn empty_plan() -> ExecutionPlan {
        ExecutionPlan {
            doc_name: "test".to_string(),
            facts: HashMap::new(),
            rules: Vec::new(),
            sources: HashMap::new(),
        }
    }

    #[test]
    fn test_hydrate_literal_unchanged() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let expr = literal_expr(num(42));
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        if let ExpressionKind::Literal(lit) = &result.kind {
            assert!(matches!(&lit.value, ValueKind::Number(_)));
        } else {
            panic!("Expected literal number");
        }
    }

    #[test]
    fn test_hydrate_fact_not_provided() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let expr = fact_expr("age");
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_provided_fact_without_value_stays_fact() {
        // Marked as provided, but the plan has no value for it: nothing to inline.
        let plan = empty_plan();
        let mut provided: HashSet<FactPath> = HashSet::new();
        provided.insert(FactPath::new(vec![], "age".to_string()));

        let expr = fact_expr("age");
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_nested_preserves_operand_order() {
        // (age + 1) > 18 exercises the work-stack rebuild across two levels.
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let sum = arithmetic_expr(fact_expr("age"), ArithmeticComputation::Add, literal_expr(num(1)));
        let expr = comparison_expr(sum, ComparisonComputation::GreaterThan, literal_expr(num(18)));
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        match &result.kind {
            ExpressionKind::Comparison(left, _, right) => {
                assert!(matches!(right.kind, ExpressionKind::Literal(_)));
                match &left.kind {
                    ExpressionKind::Arithmetic(inner_left, _, inner_right) => {
                        assert!(matches!(inner_left.kind, ExpressionKind::FactPath(_)));
                        assert!(matches!(inner_right.kind, ExpressionKind::Literal(_)));
                    }
                    _ => panic!("Expected arithmetic on the left of the comparison"),
                }
            }
            _ => panic!("Expected comparison"),
        }
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let left = literal_expr(num(10));
        let right = literal_expr(num(5));
        let expr = Expression::with_source(
            ExpressionKind::Arithmetic(Arc::new(left), ArithmeticComputation::Add, Arc::new(right)),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        if let ExpressionKind::Literal(lit) = &folded.kind {
            if let ValueKind::Number(n) = &lit.value {
                assert_eq!(*n, Decimal::from(15));
            } else {
                panic!("Expected literal number");
            }
        } else {
            panic!("Expected literal number");
        }
    }

    #[test]
    fn test_constant_fold_nested_arithmetic() {
        // (2 + 3) + 4 folds all the way down to 9.
        let inner = arithmetic_expr(literal_expr(num(2)), ArithmeticComputation::Add, literal_expr(num(3)));
        let expr = arithmetic_expr(inner, ArithmeticComputation::Add, literal_expr(num(4)));

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Number(n) => assert_eq!(*n, Decimal::from(9)),
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    #[test]
    fn test_constant_fold_keeps_unfoldable_arithmetic() {
        // x + 5 cannot be folded; the arithmetic node must survive intact.
        let expr = arithmetic_expr(fact_expr("x"), ArithmeticComputation::Add, literal_expr(num(5)));

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Arithmetic(left, _, right) => {
                assert!(matches!(left.kind, ExpressionKind::FactPath(_)));
                assert!(matches!(right.kind, ExpressionKind::Literal(_)));
            }
            _ => panic!("Expected arithmetic to be kept"),
        }
    }

    #[test]
    fn test_constant_fold_unsupported_kind_returns_none() {
        assert!(try_constant_fold_expression(&fact_expr("x")).is_none());
    }

    #[test]
    fn test_constant_fold_comparison() {
        let left = literal_expr(num(10));
        let right = literal_expr(num(5));
        let expr = Expression::with_source(
            ExpressionKind::Comparison(
                Arc::new(left),
                ComparisonComputation::GreaterThan,
                Arc::new(right),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        if let ExpressionKind::Literal(lit) = &folded.kind {
            if let ValueKind::Boolean(b) = &lit.value {
                assert!(*b);
            } else {
                panic!("Expected literal boolean");
            }
        } else {
            panic!("Expected literal boolean");
        }
    }

    #[test]
    fn test_extract_outcome() {
        assert!(matches!(
            extract_outcome(&literal_expr(num(1))),
            Some(OperationResult::Value(_))
        ));
        assert!(extract_outcome(&fact_expr("x")).is_none());
    }

    #[test]
    fn test_expression_classifiers() {
        let cmp = comparison_expr(
            fact_expr("age"),
            ComparisonComputation::GreaterThan,
            literal_expr(num(18)),
        );
        assert!(is_boolean_expression(&cmp));
        assert!(!is_arithmetic_expression(&cmp));

        let sum = arithmetic_expr(fact_expr("x"), ArithmeticComputation::Add, literal_expr(num(1)));
        assert!(is_arithmetic_expression(&sum));
        assert!(!is_boolean_expression(&sum));

        assert!(is_arithmetic_expression(&fact_expr("x")));
        assert!(!is_arithmetic_expression(&literal_expr(num(1))));
    }
}