Skip to main content

lemma/inversion/
world.rs

1//! World enumeration for inversion
2//!
3//! A "world" is a complete assignment of which branch is active for each rule.
4//! This module enumerates all valid worlds for a target rule.
5//!
6//! Also includes expression substitution and hydration utilities.
7
8use crate::planning::semantics::{
9    ArithmeticComputation, ComparisonComputation, Expression, ExpressionKind, FactPath,
10    LiteralValue, MathematicalComputation, NegationType, RulePath, SemanticConversionTarget,
11    Source,
12};
13use crate::planning::{ExecutableRule, ExecutionPlan};
14use crate::OperationResult;
15use serde::ser::{Serialize, SerializeMap, Serializer};
16use std::collections::{HashMap, HashSet, VecDeque};
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
21/// A world assigns each rule to one of its branch indices
22#[derive(Debug, Clone, Default, PartialEq, Eq)]
23pub struct World(HashMap<RulePath, usize>);
24
25impl World {
26    /// Create a new empty world
27    pub fn new() -> Self {
28        Self(HashMap::new())
29    }
30
31    /// Get the branch index for a rule
32    pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
33        self.0.get(rule_path)
34    }
35
36    /// Insert a branch assignment for a rule
37    pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
38        self.0.insert(rule_path, branch_idx)
39    }
40
41    /// Iterate over all branch assignments
42    pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
43        self.0.iter()
44    }
45}
46
47impl Serialize for World {
48    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49    where
50        S: Serializer,
51    {
52        let mut map = serializer.serialize_map(Some(self.0.len()))?;
53        for (k, v) in &self.0 {
54            map.serialize_entry(&k.to_string(), v)?;
55        }
56        map.end()
57    }
58}
59
60/// A solution from world enumeration with a resolved outcome
61#[derive(Debug, Clone)]
62pub(super) struct WorldSolution {
63    /// The world (branch assignment) that produced this solution
64    pub world: World,
65    /// The constraint under which this solution applies (facts only, no rule references)
66    pub constraint: Constraint,
67    /// The outcome (value or veto)
68    pub outcome: OperationResult,
69}
70
71/// A solution from world enumeration with an arithmetic expression outcome
72///
73/// This represents cases where the outcome is a computed expression (like `price * 5`)
74/// that couldn't be evaluated to a literal because it contains unknown facts.
75/// These need algebraic solving to determine the input values.
76#[derive(Debug, Clone)]
77pub(super) struct WorldArithmeticSolution {
78    /// The world (branch assignment) that produced this solution
79    pub world: World,
80    /// The constraint under which this solution applies (facts only, no rule references)
81    pub constraint: Constraint,
82    /// The outcome expression (contains unknown facts)
83    pub outcome_expression: Expression,
84}
85
86/// Result of world enumeration containing both literal and arithmetic solutions
87#[derive(Debug, Clone)]
88pub(super) struct EnumerationResult {
89    /// Solutions with literal outcomes (can be directly compared to target)
90    pub literal_solutions: Vec<WorldSolution>,
91    /// Solutions with arithmetic outcomes (need algebraic solving)
92    pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
93}
94
95/// Enumerates valid worlds for a target rule
96pub(super) struct WorldEnumerator<'a> {
97    plan: &'a ExecutionPlan,
98    /// Rules to process, in topological order (dependencies first)
99    rules_in_order: Vec<RulePath>,
100    /// Cache: rule path -> executable rule (for quick lookup)
101    rule_cache: HashMap<RulePath, &'a ExecutableRule>,
102}
103
104impl<'a> WorldEnumerator<'a> {
105    /// Create a new world enumerator for a target rule
106    pub(super) fn new(
107        plan: &'a ExecutionPlan,
108        target_rule: &RulePath,
109    ) -> Result<Self, crate::Error> {
110        // Build rule lookup from execution plan
111        let rule_map: HashMap<RulePath, &ExecutableRule> =
112            plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
113
114        // Find all rules that the target rule depends on (transitively)
115        let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
116
117        // Plan rules are already in topological order, so filter and preserve order
118        let rules_in_order: Vec<RulePath> = plan
119            .rules
120            .iter()
121            .filter(|r| dependent_rules.contains(&r.path))
122            .map(|r| r.path.clone())
123            .collect();
124
125        // Build rule cache for quick lookup (only rules we need)
126        let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
127            .iter()
128            .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
129            .collect();
130
131        Ok(Self {
132            plan,
133            rules_in_order,
134            rule_cache,
135        })
136    }
137
138    /// Enumerate all valid worlds for the target rule
139    ///
140    /// Returns an `EnumerationResult` containing:
141    /// - `literal_solutions`: Worlds where the outcome is a concrete literal value
142    /// - `arithmetic_solutions`: Worlds where the outcome is an arithmetic expression
143    ///   containing unknown facts (needs algebraic solving)
144    pub(super) fn enumerate(
145        &mut self,
146        provided_facts: &HashSet<FactPath>,
147    ) -> Result<EnumerationResult, crate::Error> {
148        if self.rules_in_order.is_empty() {
149            return Ok(EnumerationResult {
150                literal_solutions: vec![],
151                arithmetic_solutions: vec![],
152            });
153        }
154
155        // Start with a single empty world and true constraint
156        let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
157
158        // Process each rule in topological order
159        for rule_path in &self.rules_in_order.clone() {
160            let rule_node = match self.rule_cache.get(rule_path) {
161                Some(node) => *node,
162                None => continue,
163            };
164
165            let mut next_worlds = Vec::new();
166
167            for (world, accumulated_constraint) in current_worlds {
168                // For each branch in this rule
169                for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
170                    // Create new world with this branch assignment
171                    let mut new_world = world.clone();
172                    new_world.insert(rule_path.clone(), branch_idx);
173
174                    // Get branch constraint
175                    // For "last wins" semantics: all LATER branches must have FALSE conditions
176                    let mut branch_constraint = if let Some(ref condition) = branch.condition {
177                        // This branch's condition must be TRUE
178                        let substituted_condition = substitute_rules_in_expression(
179                            &Arc::new(condition.clone()),
180                            &new_world,
181                            self.plan,
182                        )?;
183                        let hydrated_condition = hydrate_facts_in_expression(
184                            &Arc::new(substituted_condition),
185                            self.plan,
186                            provided_facts,
187                        )?;
188                        Constraint::from_expression(&hydrated_condition)?
189                    } else {
190                        // Default branch has no explicit condition
191                        Constraint::True
192                    };
193
194                    // For "last wins": all LATER branches must NOT match
195                    // (their conditions must be FALSE)
196                    for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
197                        if let Some(ref later_condition) = later_branch.condition {
198                            let substituted_later = substitute_rules_in_expression(
199                                &Arc::new(later_condition.clone()),
200                                &new_world,
201                                self.plan,
202                            )?;
203                            let hydrated_later = hydrate_facts_in_expression(
204                                &Arc::new(substituted_later),
205                                self.plan,
206                                provided_facts,
207                            )?;
208                            let later_constraint = Constraint::from_expression(&hydrated_later)?;
209                            // Later branch's condition must be FALSE
210                            branch_constraint = branch_constraint.and(later_constraint.not());
211                        }
212                    }
213
214                    // Combine with accumulated constraint
215                    let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
216
217                    // Simplify and check if satisfiable
218                    let simplified = combined_constraint.simplify()?;
219
220                    // Only keep if not contradictory
221                    if !simplified.is_false() {
222                        next_worlds.push((new_world, simplified));
223                    }
224                }
225            }
226
227            current_worlds = next_worlds;
228
229            // Early exit if no valid worlds remain
230            if current_worlds.is_empty() {
231                break;
232            }
233        }
234
235        // Convert to WorldSolutions and WorldArithmeticSolutions
236        let target_rule_path = self
237            .rules_in_order
238            .last()
239            .unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
240
241        let mut literal_solutions = Vec::new();
242        let mut arithmetic_solutions = Vec::new();
243
244        for (world, constraint) in current_worlds {
245            // Get the outcome from the target rule's branch
246            if let Some(&branch_idx) = world.get(target_rule_path) {
247                if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
248                    if branch_idx < rule_node.branches.len() {
249                        let branch = &rule_node.branches[branch_idx];
250
251                        // Substitute and hydrate the result expression
252                        let substituted_result = substitute_rules_in_expression(
253                            &Arc::new(branch.result.clone()),
254                            &world,
255                            self.plan,
256                        )?;
257
258                        let hydrated_result = hydrate_facts_in_expression(
259                            &Arc::new(substituted_result),
260                            self.plan,
261                            provided_facts,
262                        )?;
263
264                        // Try to fold the result to a literal
265                        let folded_result = try_constant_fold_expression(&hydrated_result)
266                            .unwrap_or(hydrated_result.clone());
267
268                        // Try to extract a literal value directly
269                        if let Some(outcome) = extract_outcome(&folded_result) {
270                            literal_solutions.push(WorldSolution {
271                                world,
272                                constraint,
273                                outcome,
274                            });
275                        } else if is_boolean_expression(&folded_result) {
276                            // For boolean expressions (comparisons, logical ops), create two solutions:
277                            // one where the expression is true, one where it's false
278                            let (true_solutions, false_solutions) =
279                                create_boolean_expression_solutions(
280                                    world,
281                                    constraint,
282                                    &folded_result,
283                                )?;
284                            literal_solutions.extend(true_solutions);
285                            literal_solutions.extend(false_solutions);
286                        } else if is_arithmetic_expression(&folded_result) {
287                            // Arithmetic expression with unknown facts - needs algebraic solving
288                            arithmetic_solutions.push(WorldArithmeticSolution {
289                                world,
290                                constraint,
291                                outcome_expression: folded_result,
292                            });
293                        }
294                        // Other expression types (rule references, etc.) are silently skipped
295                        // as they indicate incomplete substitution
296                    }
297                }
298            }
299        }
300
301        Ok(EnumerationResult {
302            literal_solutions,
303            arithmetic_solutions,
304        })
305    }
306}
307
308// ============================================================================
309// Dependency and topological sorting
310// ============================================================================
311
312/// Collect all rules that a target rule depends on (transitively)
313fn collect_transitive_dependencies(
314    target_rule: &RulePath,
315    rule_map: &HashMap<RulePath, &ExecutableRule>,
316) -> Result<HashSet<RulePath>, crate::Error> {
317    let mut result = HashSet::new();
318    let mut queue = VecDeque::new();
319
320    queue.push_back(target_rule.clone());
321    result.insert(target_rule.clone());
322
323    while let Some(rule_path) = queue.pop_front() {
324        if let Some(rule) = rule_map.get(&rule_path) {
325            // Extract rule dependencies from branch expressions
326            let dependencies = extract_rule_dependencies(rule);
327            for dependency in dependencies {
328                if result.insert(dependency.clone()) {
329                    queue.push_back(dependency);
330                }
331            }
332        }
333    }
334
335    Ok(result)
336}
337
338/// Extract rule paths referenced in an executable rule's expressions
339fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
340    let mut deps = HashSet::new();
341    for branch in &rule.branches {
342        if let Some(ref condition) = branch.condition {
343            extract_rule_paths_from_expression(condition, &mut deps);
344        }
345        extract_rule_paths_from_expression(&branch.result, &mut deps);
346    }
347    deps
348}
349
350/// Recursively extract RulePath references from an expression
351fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
352    match &expr.kind {
353        ExpressionKind::RulePath(rp) => {
354            paths.insert(rp.clone());
355        }
356        ExpressionKind::Arithmetic(left, _, right)
357        | ExpressionKind::Comparison(left, _, right)
358        | ExpressionKind::LogicalAnd(left, right) => {
359            extract_rule_paths_from_expression(left, paths);
360            extract_rule_paths_from_expression(right, paths);
361        }
362        ExpressionKind::LogicalNegation(inner, _)
363        | ExpressionKind::UnitConversion(inner, _)
364        | ExpressionKind::MathematicalComputation(_, inner) => {
365            extract_rule_paths_from_expression(inner, paths);
366        }
367        ExpressionKind::DateRelative(_, date_expr, tolerance) => {
368            extract_rule_paths_from_expression(date_expr, paths);
369            if let Some(tol) = tolerance {
370                extract_rule_paths_from_expression(tol, paths);
371            }
372        }
373        ExpressionKind::DateCalendar(_, _, date_expr) => {
374            extract_rule_paths_from_expression(date_expr, paths);
375        }
376        ExpressionKind::Literal(_)
377        | ExpressionKind::FactPath(_)
378        | ExpressionKind::Veto(_)
379        | ExpressionKind::Now => {}
380    }
381}
382
383// ============================================================================
384// Expression substitution
385// ============================================================================
386
387/// Substitute rule references in an expression with their values in a given world
388///
389/// For each RulePath in the expression, looks up which branch is active in the world
390/// and replaces the RulePath with the branch's result expression.
391fn substitute_rules_in_expression(
392    expr: &Arc<Expression>,
393    world: &World,
394    plan: &ExecutionPlan,
395) -> Result<Expression, crate::Error> {
396    enum WorkItem {
397        Process(usize),
398        BuildArithmetic(ArithmeticComputation, Option<Source>),
399        BuildComparison(ComparisonComputation, Option<Source>),
400        BuildLogicalAnd(Option<Source>),
401        BuildLogicalNegation(NegationType, Option<Source>),
402        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
403        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
404        PopVisitedRules,
405    }
406
407    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
408    let mut work_stack: Vec<WorkItem> = Vec::new();
409    let mut result_pool: Vec<Expression> = Vec::new();
410    let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];
411
412    let root_idx = expr_pool.len();
413    expr_pool.push(Arc::clone(expr));
414    work_stack.push(WorkItem::Process(root_idx));
415
416    while let Some(work) = work_stack.pop() {
417        match work {
418            WorkItem::Process(expr_idx) => {
419                let e = &expr_pool[expr_idx];
420                let source_loc = e.source_location.clone();
421
422                match &e.kind {
423                    ExpressionKind::RulePath(rule_path) => {
424                        let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
425                        if visited.contains(rule_path) {
426                            unreachable!(
427                                "BUG: circular rule reference detected during substitution: {}",
428                                rule_path
429                            );
430                        }
431
432                        if let Some(&branch_idx) = world.get(rule_path) {
433                            if let Some(rule) = plan.get_rule_by_path(rule_path) {
434                                if branch_idx < rule.branches.len() {
435                                    let branch = &rule.branches[branch_idx];
436                                    let mut new_visited = visited.clone();
437                                    new_visited.insert(rule_path.clone());
438                                    visited_rules_stack.push(new_visited);
439
440                                    let sub_expr_idx = expr_pool.len();
441                                    expr_pool.push(Arc::new(branch.result.clone()));
442                                    work_stack.push(WorkItem::PopVisitedRules);
443                                    work_stack.push(WorkItem::Process(sub_expr_idx));
444                                    continue;
445                                }
446                            }
447                        }
448                        result_pool.push(Expression::with_source(
449                            ExpressionKind::RulePath(rule_path.clone()),
450                            source_loc,
451                        ));
452                    }
453                    ExpressionKind::Arithmetic(left, op, right) => {
454                        let op_clone = op.clone();
455                        let left_arc = Arc::clone(left);
456                        let right_arc = Arc::clone(right);
457
458                        let left_idx = expr_pool.len();
459                        expr_pool.push(left_arc);
460                        let right_idx = expr_pool.len();
461                        expr_pool.push(right_arc);
462
463                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
464                        work_stack.push(WorkItem::Process(right_idx));
465                        work_stack.push(WorkItem::Process(left_idx));
466                    }
467                    ExpressionKind::Comparison(left, op, right) => {
468                        let op_clone = op.clone();
469                        let left_arc = Arc::clone(left);
470                        let right_arc = Arc::clone(right);
471
472                        let left_idx = expr_pool.len();
473                        expr_pool.push(left_arc);
474                        let right_idx = expr_pool.len();
475                        expr_pool.push(right_arc);
476
477                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
478                        work_stack.push(WorkItem::Process(right_idx));
479                        work_stack.push(WorkItem::Process(left_idx));
480                    }
481                    ExpressionKind::LogicalAnd(left, right) => {
482                        let left_arc = Arc::clone(left);
483                        let right_arc = Arc::clone(right);
484
485                        let left_idx = expr_pool.len();
486                        expr_pool.push(left_arc);
487                        let right_idx = expr_pool.len();
488                        expr_pool.push(right_arc);
489
490                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
491                        work_stack.push(WorkItem::Process(right_idx));
492                        work_stack.push(WorkItem::Process(left_idx));
493                    }
494                    ExpressionKind::LogicalNegation(inner, neg_type) => {
495                        let neg_type_clone = neg_type.clone();
496                        let inner_arc = Arc::clone(inner);
497                        let inner_idx = expr_pool.len();
498                        expr_pool.push(inner_arc);
499                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
500                        work_stack.push(WorkItem::Process(inner_idx));
501                    }
502                    ExpressionKind::UnitConversion(inner, unit) => {
503                        let unit_clone = unit.clone();
504                        let inner_arc = Arc::clone(inner);
505                        let inner_idx = expr_pool.len();
506                        expr_pool.push(inner_arc);
507                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
508                        work_stack.push(WorkItem::Process(inner_idx));
509                    }
510                    ExpressionKind::MathematicalComputation(func, inner) => {
511                        let func_clone = func.clone();
512                        let inner_arc = Arc::clone(inner);
513                        let inner_idx = expr_pool.len();
514                        expr_pool.push(inner_arc);
515                        work_stack.push(WorkItem::BuildMathematicalComputation(
516                            func_clone, source_loc,
517                        ));
518                        work_stack.push(WorkItem::Process(inner_idx));
519                    }
520                    ExpressionKind::Literal(lit) => {
521                        result_pool.push(Expression::with_source(
522                            ExpressionKind::Literal(lit.clone()),
523                            source_loc,
524                        ));
525                    }
526                    ExpressionKind::FactPath(fact_path) => {
527                        result_pool.push(Expression::with_source(
528                            ExpressionKind::FactPath(fact_path.clone()),
529                            source_loc,
530                        ));
531                    }
532                    ExpressionKind::Veto(veto) => {
533                        result_pool.push(Expression::with_source(
534                            ExpressionKind::Veto(veto.clone()),
535                            source_loc,
536                        ));
537                    }
538                    ExpressionKind::Now => {
539                        result_pool.push(Expression::with_source(ExpressionKind::Now, source_loc));
540                    }
541                    ExpressionKind::DateRelative(..) | ExpressionKind::DateCalendar(..) => {
542                        result_pool.push(Expression::with_source(e.kind.clone(), source_loc));
543                    }
544                }
545            }
546            WorkItem::BuildArithmetic(op, source_loc) => {
547                let right = result_pool.pop().unwrap_or_else(|| {
548                    unreachable!(
549                        "BUG: missing right expression for Arithmetic during inversion hydration"
550                    )
551                });
552                let left = result_pool.pop().unwrap_or_else(|| {
553                    unreachable!(
554                        "BUG: missing left expression for Arithmetic during inversion hydration"
555                    )
556                });
557                result_pool.push(Expression::with_source(
558                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
559                    source_loc,
560                ));
561            }
562            WorkItem::BuildComparison(op, source_loc) => {
563                let right = result_pool.pop().unwrap_or_else(|| {
564                    unreachable!(
565                        "BUG: missing right expression for Comparison during inversion hydration"
566                    )
567                });
568                let left = result_pool.pop().unwrap_or_else(|| {
569                    unreachable!(
570                        "BUG: missing left expression for Comparison during inversion hydration"
571                    )
572                });
573                result_pool.push(Expression::with_source(
574                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
575                    source_loc,
576                ));
577            }
578            WorkItem::BuildLogicalAnd(source_loc) => {
579                let right = result_pool.pop().unwrap_or_else(|| {
580                    unreachable!(
581                        "BUG: missing right expression for LogicalAnd during inversion hydration"
582                    )
583                });
584                let left = result_pool.pop().unwrap_or_else(|| {
585                    unreachable!(
586                        "BUG: missing left expression for LogicalAnd during inversion hydration"
587                    )
588                });
589                result_pool.push(Expression::with_source(
590                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
591                    source_loc,
592                ));
593            }
594            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
595                let inner = result_pool
596                    .pop()
597                    .expect("Internal error: missing expression for LogicalNegation");
598                result_pool.push(Expression::with_source(
599                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
600                    source_loc,
601                ));
602            }
603            WorkItem::BuildUnitConversion(unit, source_loc) => {
604                let inner = result_pool
605                    .pop()
606                    .expect("Internal error: missing expression for UnitConversion");
607                result_pool.push(Expression::with_source(
608                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
609                    source_loc,
610                ));
611            }
612            WorkItem::BuildMathematicalComputation(func, source_loc) => {
613                let inner = result_pool
614                    .pop()
615                    .expect("Internal error: missing expression for MathematicalComputation");
616                result_pool.push(Expression::with_source(
617                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
618                    source_loc,
619                ));
620            }
621            WorkItem::PopVisitedRules => {
622                visited_rules_stack.pop();
623            }
624        }
625    }
626
627    Ok(result_pool
628        .pop()
629        .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
630}
631
632// ============================================================================
633// Fact hydration
634// ============================================================================
635
636/// Hydrate fact references in an expression with their known values
637///
638/// For each FactPath in the expression, if the fact is in provided_facts,
639/// replaces the FactPath with a Literal containing the fact's value.
640fn hydrate_facts_in_expression(
641    expr: &Arc<Expression>,
642    plan: &ExecutionPlan,
643    provided_facts: &HashSet<FactPath>,
644) -> Result<Expression, crate::Error> {
645    enum WorkItem {
646        Process(usize),
647        BuildArithmetic(ArithmeticComputation, Option<Source>),
648        BuildComparison(ComparisonComputation, Option<Source>),
649        BuildLogicalAnd(Option<Source>),
650        BuildLogicalNegation(NegationType, Option<Source>),
651        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
652        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
653    }
654
655    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
656    let mut work_stack: Vec<WorkItem> = Vec::new();
657    let mut result_pool: Vec<Expression> = Vec::new();
658
659    let root_idx = expr_pool.len();
660    expr_pool.push(Arc::clone(expr));
661    work_stack.push(WorkItem::Process(root_idx));
662
663    while let Some(work) = work_stack.pop() {
664        match work {
665            WorkItem::Process(expr_idx) => {
666                let (source_loc, expr_kind_ref) = {
667                    let e = &expr_pool[expr_idx];
668                    (e.source_location.clone(), &e.kind)
669                };
670
671                match expr_kind_ref {
672                    ExpressionKind::FactPath(fact_path) => {
673                        if provided_facts.contains(fact_path) {
674                            if let Some(lit) = plan.facts.get(fact_path).and_then(|d| d.value()) {
675                                result_pool.push(Expression::with_source(
676                                    ExpressionKind::Literal(Box::new(lit.clone())),
677                                    source_loc,
678                                ));
679                                continue;
680                            }
681                        }
682                        result_pool.push(Expression::with_source(
683                            ExpressionKind::FactPath(fact_path.clone()),
684                            source_loc,
685                        ));
686                    }
687                    ExpressionKind::Arithmetic(left, op, right) => {
688                        let op_clone = op.clone();
689                        let left_arc = Arc::clone(left);
690                        let right_arc = Arc::clone(right);
691
692                        let left_idx = expr_pool.len();
693                        expr_pool.push(left_arc);
694                        let right_idx = expr_pool.len();
695                        expr_pool.push(right_arc);
696
697                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
698                        work_stack.push(WorkItem::Process(right_idx));
699                        work_stack.push(WorkItem::Process(left_idx));
700                    }
701                    ExpressionKind::Comparison(left, op, right) => {
702                        let op_clone = op.clone();
703                        let left_arc = Arc::clone(left);
704                        let right_arc = Arc::clone(right);
705
706                        let left_idx = expr_pool.len();
707                        expr_pool.push(left_arc);
708                        let right_idx = expr_pool.len();
709                        expr_pool.push(right_arc);
710
711                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
712                        work_stack.push(WorkItem::Process(right_idx));
713                        work_stack.push(WorkItem::Process(left_idx));
714                    }
715                    ExpressionKind::LogicalAnd(left, right) => {
716                        let left_arc = Arc::clone(left);
717                        let right_arc = Arc::clone(right);
718
719                        let left_idx = expr_pool.len();
720                        expr_pool.push(left_arc);
721                        let right_idx = expr_pool.len();
722                        expr_pool.push(right_arc);
723
724                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
725                        work_stack.push(WorkItem::Process(right_idx));
726                        work_stack.push(WorkItem::Process(left_idx));
727                    }
728                    ExpressionKind::LogicalNegation(inner, neg_type) => {
729                        let neg_type_clone = neg_type.clone();
730                        let inner_arc = Arc::clone(inner);
731                        let inner_idx = expr_pool.len();
732                        expr_pool.push(inner_arc);
733                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
734                        work_stack.push(WorkItem::Process(inner_idx));
735                    }
736                    ExpressionKind::UnitConversion(inner, unit) => {
737                        let unit_clone = unit.clone();
738                        let inner_arc = Arc::clone(inner);
739                        let inner_idx = expr_pool.len();
740                        expr_pool.push(inner_arc);
741                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
742                        work_stack.push(WorkItem::Process(inner_idx));
743                    }
744                    ExpressionKind::MathematicalComputation(func, inner) => {
745                        let func_clone = func.clone();
746                        let inner_arc = Arc::clone(inner);
747                        let inner_idx = expr_pool.len();
748                        expr_pool.push(inner_arc);
749                        work_stack.push(WorkItem::BuildMathematicalComputation(
750                            func_clone, source_loc,
751                        ));
752                        work_stack.push(WorkItem::Process(inner_idx));
753                    }
754                    ExpressionKind::Literal(lit) => {
755                        result_pool.push(Expression::with_source(
756                            ExpressionKind::Literal(lit.clone()),
757                            source_loc,
758                        ));
759                    }
760                    ExpressionKind::RulePath(rule_path) => {
761                        result_pool.push(Expression::with_source(
762                            ExpressionKind::RulePath(rule_path.clone()),
763                            source_loc,
764                        ));
765                    }
766                    ExpressionKind::Veto(veto) => {
767                        result_pool.push(Expression::with_source(
768                            ExpressionKind::Veto(veto.clone()),
769                            source_loc,
770                        ));
771                    }
772                    ExpressionKind::Now => {
773                        result_pool.push(Expression::with_source(ExpressionKind::Now, source_loc));
774                    }
775                    ExpressionKind::DateRelative(..) | ExpressionKind::DateCalendar(..) => {
776                        result_pool
777                            .push(Expression::with_source(expr_kind_ref.clone(), source_loc));
778                    }
779                }
780            }
781            WorkItem::BuildArithmetic(op, source_loc) => {
782                let right = result_pool.pop().unwrap_or_else(|| {
783                    unreachable!("BUG: missing right expression for Arithmetic")
784                });
785                let left = result_pool
786                    .pop()
787                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Arithmetic"));
788                result_pool.push(Expression::with_source(
789                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
790                    source_loc,
791                ));
792            }
793            WorkItem::BuildComparison(op, source_loc) => {
794                let right = result_pool.pop().unwrap_or_else(|| {
795                    unreachable!("BUG: missing right expression for Comparison")
796                });
797                let left = result_pool
798                    .pop()
799                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Comparison"));
800                result_pool.push(Expression::with_source(
801                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
802                    source_loc,
803                ));
804            }
805            WorkItem::BuildLogicalAnd(source_loc) => {
806                let right = result_pool.pop().unwrap_or_else(|| {
807                    unreachable!("BUG: missing right expression for LogicalAnd")
808                });
809                let left = result_pool
810                    .pop()
811                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalAnd"));
812                result_pool.push(Expression::with_source(
813                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
814                    source_loc,
815                ));
816            }
817            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
818                let inner = result_pool
819                    .pop()
820                    .expect("Internal error: missing expression for LogicalNegation");
821                result_pool.push(Expression::with_source(
822                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
823                    source_loc,
824                ));
825            }
826            WorkItem::BuildUnitConversion(unit, source_loc) => {
827                let inner = result_pool
828                    .pop()
829                    .expect("Internal error: missing expression for UnitConversion");
830                result_pool.push(Expression::with_source(
831                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
832                    source_loc,
833                ));
834            }
835            WorkItem::BuildMathematicalComputation(func, source_loc) => {
836                let inner = result_pool
837                    .pop()
838                    .expect("Internal error: missing expression for MathematicalComputation");
839                result_pool.push(Expression::with_source(
840                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
841                    source_loc,
842                ));
843            }
844        }
845    }
846
847    Ok(result_pool
848        .pop()
849        .expect("Internal error: no result from hydration"))
850}
851
852// ============================================================================
853// Constant folding
854// ============================================================================
855
856/// Extract an outcome (value or veto) from an expression
857fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
858    match &expr.kind {
859        ExpressionKind::Literal(lit) => {
860            Some(OperationResult::Value(Box::new(lit.as_ref().clone())))
861        }
862        ExpressionKind::Veto(ve) => Some(OperationResult::Veto(ve.message.clone())),
863        _ => None,
864    }
865}
866
867/// Check if an expression is a boolean-producing expression (comparison or logical)
868fn is_boolean_expression(expr: &Expression) -> bool {
869    matches!(
870        &expr.kind,
871        ExpressionKind::Comparison(_, _, _)
872            | ExpressionKind::LogicalAnd(_, _)
873            | ExpressionKind::LogicalNegation(_, _)
874    )
875}
876
877/// Check if an expression is an arithmetic expression (contains arithmetic operations)
878///
879/// Returns true for expressions like `price * 5`, `x + y`, etc.
880fn is_arithmetic_expression(expr: &Expression) -> bool {
881    match &expr.kind {
882        ExpressionKind::Arithmetic(_, _, _) => true,
883        ExpressionKind::MathematicalComputation(_, _) => true,
884        ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
885        ExpressionKind::FactPath(_) => true, // Lone fact is also solvable
886        _ => false,
887    }
888}
889
890/// For boolean expressions that can't be evaluated to a literal (e.g., `age > 18`),
891/// create two solutions: one where the expression is true, one where it's false.
892///
893/// This allows inversion to work with rules like `rule of_age: age > 18`
894fn create_boolean_expression_solutions(
895    world: World,
896    base_constraint: Constraint,
897    boolean_expr: &Expression,
898) -> Result<(Vec<WorldSolution>, Vec<WorldSolution>), crate::Error> {
899    // Convert boolean expression to constraint
900    let expr_constraint = Constraint::from_expression(boolean_expr)?;
901
902    // Solution where the boolean expression is true
903    let true_constraint = base_constraint.clone().and(expr_constraint.clone());
904    let simplified_true = true_constraint.simplify()?;
905
906    let true_solutions = if !simplified_true.is_false() {
907        vec![WorldSolution {
908            world: world.clone(),
909            constraint: simplified_true,
910            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(true))),
911        }]
912    } else {
913        vec![]
914    };
915
916    // Solution where the boolean expression is false
917    let false_constraint = base_constraint.and(expr_constraint.not());
918    let simplified_false = false_constraint.simplify()?;
919
920    let false_solutions = if !simplified_false.is_false() {
921        vec![WorldSolution {
922            world,
923            constraint: simplified_false,
924            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(false))),
925        }]
926    } else {
927        vec![]
928    };
929
930    Ok((true_solutions, false_solutions))
931}
932
933/// Attempt constant folding on an expression (simplified version for outcomes)
934pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
935    match &expr.kind {
936        ExpressionKind::Literal(_) => Some(expr.clone()),
937        ExpressionKind::Arithmetic(left, op, right) => {
938            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
939            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
940            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
941                (&left_folded.kind, &right_folded.kind)
942            {
943                if let Some(result) = evaluate_arithmetic(left_val.as_ref(), op, right_val.as_ref())
944                {
945                    return Some(Expression::with_source(
946                        ExpressionKind::Literal(Box::new(result)),
947                        expr.source_location.clone(),
948                    ));
949                }
950            }
951            Some(Expression::with_source(
952                ExpressionKind::Arithmetic(
953                    Arc::new(left_folded),
954                    op.clone(),
955                    Arc::new(right_folded),
956                ),
957                expr.source_location.clone(),
958            ))
959        }
960        ExpressionKind::Comparison(left, op, right) => {
961            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
962            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
963            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
964                (&left_folded.kind, &right_folded.kind)
965            {
966                if let Some(result) = evaluate_comparison(left_val.as_ref(), op, right_val.as_ref())
967                {
968                    return Some(Expression::with_source(
969                        ExpressionKind::Literal(Box::new(LiteralValue::from_bool(result))),
970                        expr.source_location.clone(),
971                    ));
972                }
973            }
974            Some(Expression::with_source(
975                ExpressionKind::Comparison(
976                    Arc::new(left_folded),
977                    op.clone(),
978                    Arc::new(right_folded),
979                ),
980                expr.source_location.clone(),
981            ))
982        }
983        _ => None,
984    }
985}
986
987/// Evaluate an arithmetic operation on two literals
988///
989/// Delegates to the computation module for consistent behavior
990fn evaluate_arithmetic(
991    left: &LiteralValue,
992    op: &ArithmeticComputation,
993    right: &LiteralValue,
994) -> Option<LiteralValue> {
995    use crate::computation::arithmetic_operation;
996
997    match arithmetic_operation(left, op, right) {
998        OperationResult::Value(lit) => Some(lit.as_ref().clone()),
999        OperationResult::Veto(_) => None,
1000    }
1001}
1002
1003/// Evaluate a comparison operation on two literals
1004///
1005/// Delegates to the computation module for consistent behavior
1006fn evaluate_comparison(
1007    left: &LiteralValue,
1008    op: &ComparisonComputation,
1009    right: &LiteralValue,
1010) -> Option<bool> {
1011    use crate::computation::comparison_operation;
1012    use crate::planning::semantics::ValueKind;
1013
1014    match comparison_operation(left, op, right) {
1015        OperationResult::Value(lit) => match &lit.value {
1016            ValueKind::Boolean(b) => Some(*b),
1017            _ => None,
1018        },
1019        _ => None,
1020    }
1021}
1022
1023// ============================================================================
1024// Tests
1025// ============================================================================
1026
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planning::semantics::ValueKind;
    use rust_decimal::Decimal;

    /// Wrap a literal value in an expression with no source location.
    fn literal_expr(val: LiteralValue) -> Expression {
        let kind = ExpressionKind::Literal(Box::new(val));
        Expression::with_source(kind, None)
    }

    /// Build a source-less reference to the fact `name` with an empty segment list.
    fn fact_expr(name: &str) -> Expression {
        let path = FactPath::new(vec![], name.to_string());
        Expression::with_source(ExpressionKind::FactPath(path), None)
    }

    /// Shorthand for a number literal built from an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    #[test]
    fn test_world_new() {
        let fresh = World::new();
        assert!(fresh.0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let path = RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        };
        let mut world = World::new();
        world.insert(path.clone(), 2);

        let assigned = world.get(&path);
        assert_eq!(assigned, Some(&2));
    }

    /// An execution plan named "test" with no facts, rules, sources, or metadata.
    fn empty_plan() -> ExecutionPlan {
        ExecutionPlan {
            spec_name: "test".to_string(),
            facts: indexmap::IndexMap::new(),
            rules: Vec::new(),
            sources: HashMap::new(),
            meta: HashMap::new(),
            valid_from: None,
            valid_to: None,
        }
    }

    #[test]
    fn test_hydrate_literal_unchanged() {
        let plan = empty_plan();
        let provided = HashSet::<FactPath>::new();
        let input = Arc::new(literal_expr(num(42)));

        let hydrated = hydrate_facts_in_expression(&input, &plan, &provided).unwrap();

        match &hydrated.kind {
            ExpressionKind::Literal(lit) => {
                assert!(matches!(&lit.value, ValueKind::Number(_)));
            }
            _ => panic!("Expected literal number"),
        }
    }

    #[test]
    fn test_hydrate_fact_not_provided() {
        let plan = empty_plan();
        let provided = HashSet::<FactPath>::new();
        let input = Arc::new(fact_expr("age"));

        let hydrated = hydrate_facts_in_expression(&input, &plan, &provided).unwrap();

        // Without the fact being provided, the reference must survive hydration.
        assert!(matches!(hydrated.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let sum = Expression::with_source(
            ExpressionKind::Arithmetic(
                Arc::new(literal_expr(num(10))),
                ArithmeticComputation::Add,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&sum).unwrap();

        match &folded.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Number(n) => assert_eq!(*n, Decimal::from(15)),
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    #[test]
    fn test_constant_fold_comparison() {
        let greater = Expression::with_source(
            ExpressionKind::Comparison(
                Arc::new(literal_expr(num(10))),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&greater).unwrap();

        match &folded.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Boolean(b) => assert!(*b),
                _ => panic!("Expected literal boolean"),
            },
            _ => panic!("Expected literal boolean"),
        }
    }
}