Skip to main content

lemma/inversion/
world.rs

1//! World enumeration for inversion
2//!
3//! A "world" is a complete assignment of which branch is active for each rule.
4//! This module enumerates all valid worlds for a target rule.
5//!
6//! Also includes expression substitution and hydration utilities.
7
8use crate::planning::semantics::{
9    ArithmeticComputation, ComparisonComputation, DataPath, Expression, ExpressionKind,
10    LiteralValue, MathematicalComputation, NegationType, RulePath, SemanticConversionTarget,
11    Source,
12};
13use crate::planning::{ExecutableRule, ExecutionPlan};
14use crate::OperationResult;
15use serde::ser::{Serialize, SerializeMap, Serializer};
16use std::collections::{HashMap, HashSet, VecDeque};
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
21/// A world assigns each rule to one of its branch indices
22#[derive(Debug, Clone, Default, PartialEq, Eq)]
23pub struct World(HashMap<RulePath, usize>);
24
25impl World {
26    /// Create a new empty world
27    pub fn new() -> Self {
28        Self(HashMap::new())
29    }
30
31    /// Get the branch index for a rule
32    pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
33        self.0.get(rule_path)
34    }
35
36    /// Insert a branch assignment for a rule
37    pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
38        self.0.insert(rule_path, branch_idx)
39    }
40
41    /// Iterate over all branch assignments
42    pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
43        self.0.iter()
44    }
45}
46
47impl Serialize for World {
48    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49    where
50        S: Serializer,
51    {
52        let mut map = serializer.serialize_map(Some(self.0.len()))?;
53        for (k, v) in &self.0 {
54            map.serialize_entry(&k.to_string(), v)?;
55        }
56        map.end()
57    }
58}
59
/// A solution from world enumeration with a resolved outcome
///
/// Produced by `WorldEnumerator::enumerate` when the target rule's active branch
/// resolves to a literal outcome. The true/false cases of a boolean-expression
/// outcome are also reported as values of this type.
#[derive(Debug, Clone)]
pub(super) struct WorldSolution {
    /// The world (branch assignment) that produced this solution
    pub world: World,
    /// The constraint under which this solution applies (data only, no rule references)
    pub constraint: Constraint,
    /// The outcome (value or veto)
    pub outcome: OperationResult,
}
70
/// A solution from world enumeration with an arithmetic expression outcome
///
/// This represents cases where the outcome is a computed expression (like `price * 5`)
/// that couldn't be evaluated to a literal because it contains unknown data.
/// These need algebraic solving to determine the input values.
///
/// Produced only when the folded outcome is neither a literal nor a boolean
/// expression but is classified as arithmetic.
#[derive(Debug, Clone)]
pub(super) struct WorldArithmeticSolution {
    /// The world (branch assignment) that produced this solution
    pub world: World,
    /// The constraint under which this solution applies (data only, no rule references)
    pub constraint: Constraint,
    /// The outcome expression (contains unknown data), after rule substitution,
    /// data hydration and constant folding
    pub outcome_expression: Expression,
}
85
/// Result of world enumeration containing both literal and arithmetic solutions
#[derive(Debug, Clone)]
pub(super) struct EnumerationResult {
    /// Solutions with literal outcomes (can be directly compared to target).
    /// Boolean-expression outcomes are split into their true/false solutions here too.
    pub literal_solutions: Vec<WorldSolution>,
    /// Solutions with arithmetic outcomes (need algebraic solving)
    pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
}
94
/// Enumerates valid worlds for a target rule
///
/// Only the target rule and the rules it transitively references are retained.
pub(super) struct WorldEnumerator<'a> {
    /// Execution plan, passed to rule substitution and data hydration
    plan: &'a ExecutionPlan,
    /// Rules to process, in topological order (dependencies first); the target rule is last
    rules_in_order: Vec<RulePath>,
    /// Cache: rule path -> executable rule (for quick lookup), restricted to `rules_in_order`
    rule_cache: HashMap<RulePath, &'a ExecutableRule>,
}
103
104impl<'a> WorldEnumerator<'a> {
105    /// Create a new world enumerator for a target rule
106    pub(super) fn new(
107        plan: &'a ExecutionPlan,
108        target_rule: &RulePath,
109    ) -> Result<Self, crate::Error> {
110        // Build rule lookup from execution plan
111        let rule_map: HashMap<RulePath, &ExecutableRule> =
112            plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
113
114        // Find all rules that the target rule depends on (transitively)
115        let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
116
117        // Plan rules are already in topological order, so filter and preserve order
118        let rules_in_order: Vec<RulePath> = plan
119            .rules
120            .iter()
121            .filter(|r| dependent_rules.contains(&r.path))
122            .map(|r| r.path.clone())
123            .collect();
124
125        // Build rule cache for quick lookup (only rules we need)
126        let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
127            .iter()
128            .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
129            .collect();
130
131        Ok(Self {
132            plan,
133            rules_in_order,
134            rule_cache,
135        })
136    }
137
138    /// Enumerate all valid worlds for the target rule
139    ///
140    /// Returns an `EnumerationResult` containing:
141    /// - `literal_solutions`: Worlds where the outcome is a concrete literal value
142    /// - `arithmetic_solutions`: Worlds where the outcome is an arithmetic expression
143    ///   containing unknown data (needs algebraic solving)
144    pub(super) fn enumerate(
145        &mut self,
146        provided_data: &HashSet<DataPath>,
147    ) -> Result<EnumerationResult, crate::Error> {
148        if self.rules_in_order.is_empty() {
149            return Ok(EnumerationResult {
150                literal_solutions: vec![],
151                arithmetic_solutions: vec![],
152            });
153        }
154
155        // Start with a single empty world and true constraint
156        let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
157
158        // Process each rule in topological order
159        for rule_path in &self.rules_in_order.clone() {
160            let rule_node = match self.rule_cache.get(rule_path) {
161                Some(node) => *node,
162                None => continue,
163            };
164
165            let mut next_worlds = Vec::new();
166
167            for (world, accumulated_constraint) in current_worlds {
168                // For each branch in this rule
169                for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
170                    // Create new world with this branch assignment
171                    let mut new_world = world.clone();
172                    new_world.insert(rule_path.clone(), branch_idx);
173
174                    // Get branch constraint
175                    // For "last wins" semantics: all LATER branches must have FALSE conditions
176                    let mut branch_constraint = if let Some(ref condition) = branch.condition {
177                        // This branch's condition must be TRUE
178                        let substituted_condition = substitute_rules_in_expression(
179                            &Arc::new(condition.clone()),
180                            &new_world,
181                            self.plan,
182                        )?;
183                        let hydrated_condition = hydrate_data_in_expression(
184                            &Arc::new(substituted_condition),
185                            self.plan,
186                            provided_data,
187                        )?;
188                        Constraint::from_expression(&hydrated_condition)?
189                    } else {
190                        // Default branch has no explicit condition
191                        Constraint::True
192                    };
193
194                    // For "last wins": all LATER branches must NOT match
195                    // (their conditions must be FALSE)
196                    for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
197                        if let Some(ref later_condition) = later_branch.condition {
198                            let substituted_later = substitute_rules_in_expression(
199                                &Arc::new(later_condition.clone()),
200                                &new_world,
201                                self.plan,
202                            )?;
203                            let hydrated_later = hydrate_data_in_expression(
204                                &Arc::new(substituted_later),
205                                self.plan,
206                                provided_data,
207                            )?;
208                            let later_constraint = Constraint::from_expression(&hydrated_later)?;
209                            // Later branch's condition must be FALSE
210                            branch_constraint = branch_constraint.and(later_constraint.not());
211                        }
212                    }
213
214                    // Combine with accumulated constraint
215                    let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
216
217                    // Simplify and check if satisfiable
218                    let simplified = combined_constraint.simplify()?;
219
220                    // Only keep if not contradictory
221                    if !simplified.is_false() {
222                        next_worlds.push((new_world, simplified));
223                    }
224                }
225            }
226
227            current_worlds = next_worlds;
228
229            // Early exit if no valid worlds remain
230            if current_worlds.is_empty() {
231                break;
232            }
233        }
234
235        // Convert to WorldSolutions and WorldArithmeticSolutions
236        let target_rule_path = self
237            .rules_in_order
238            .last()
239            .unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
240
241        let mut literal_solutions = Vec::new();
242        let mut arithmetic_solutions = Vec::new();
243
244        for (world, constraint) in current_worlds {
245            // Get the outcome from the target rule's branch
246            if let Some(&branch_idx) = world.get(target_rule_path) {
247                if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
248                    if branch_idx < rule_node.branches.len() {
249                        let branch = &rule_node.branches[branch_idx];
250
251                        // Substitute and hydrate the result expression
252                        let substituted_result = substitute_rules_in_expression(
253                            &Arc::new(branch.result.clone()),
254                            &world,
255                            self.plan,
256                        )?;
257
258                        let hydrated_result = hydrate_data_in_expression(
259                            &Arc::new(substituted_result),
260                            self.plan,
261                            provided_data,
262                        )?;
263
264                        // Try to fold the result to a literal
265                        let folded_result = try_constant_fold_expression(&hydrated_result)
266                            .unwrap_or(hydrated_result.clone());
267
268                        // Try to extract a literal value directly
269                        if let Some(outcome) = extract_outcome(&folded_result) {
270                            literal_solutions.push(WorldSolution {
271                                world,
272                                constraint,
273                                outcome,
274                            });
275                        } else if is_boolean_expression(&folded_result) {
276                            // For boolean expressions (comparisons, logical ops), create two solutions:
277                            // one where the expression is true, one where it's false
278                            let (true_solutions, false_solutions) =
279                                create_boolean_expression_solutions(
280                                    world,
281                                    constraint,
282                                    &folded_result,
283                                )?;
284                            literal_solutions.extend(true_solutions);
285                            literal_solutions.extend(false_solutions);
286                        } else if is_arithmetic_expression(&folded_result) {
287                            // Arithmetic expression with unknown data - needs algebraic solving
288                            arithmetic_solutions.push(WorldArithmeticSolution {
289                                world,
290                                constraint,
291                                outcome_expression: folded_result,
292                            });
293                        }
294                        // Other expression types (rule references, etc.) are silently skipped
295                        // as they indicate incomplete substitution
296                    }
297                }
298            }
299        }
300
301        Ok(EnumerationResult {
302            literal_solutions,
303            arithmetic_solutions,
304        })
305    }
306}
307
308// ============================================================================
309// Dependency and topological sorting
310// ============================================================================
311
312/// Collect all rules that a target rule depends on (transitively)
313fn collect_transitive_dependencies(
314    target_rule: &RulePath,
315    rule_map: &HashMap<RulePath, &ExecutableRule>,
316) -> Result<HashSet<RulePath>, crate::Error> {
317    let mut result = HashSet::new();
318    let mut queue = VecDeque::new();
319
320    queue.push_back(target_rule.clone());
321    result.insert(target_rule.clone());
322
323    while let Some(rule_path) = queue.pop_front() {
324        if let Some(rule) = rule_map.get(&rule_path) {
325            // Extract rule dependencies from branch expressions
326            let dependencies = extract_rule_dependencies(rule);
327            for dependency in dependencies {
328                if result.insert(dependency.clone()) {
329                    queue.push_back(dependency);
330                }
331            }
332        }
333    }
334
335    Ok(result)
336}
337
338/// Extract rule paths referenced in an executable rule's expressions
339fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
340    let mut deps = HashSet::new();
341    for branch in &rule.branches {
342        if let Some(ref condition) = branch.condition {
343            extract_rule_paths_from_expression(condition, &mut deps);
344        }
345        extract_rule_paths_from_expression(&branch.result, &mut deps);
346    }
347    deps
348}
349
350/// Recursively extract RulePath references from an expression
351fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
352    match &expr.kind {
353        ExpressionKind::RulePath(rp) => {
354            paths.insert(rp.clone());
355        }
356        ExpressionKind::Arithmetic(left, _, right)
357        | ExpressionKind::Comparison(left, _, right)
358        | ExpressionKind::LogicalAnd(left, right)
359        | ExpressionKind::LogicalOr(left, right)
360        | ExpressionKind::RangeLiteral(left, right)
361        | ExpressionKind::RangeContainment(left, right) => {
362            extract_rule_paths_from_expression(left, paths);
363            extract_rule_paths_from_expression(right, paths);
364        }
365        ExpressionKind::LogicalNegation(inner, _)
366        | ExpressionKind::UnitConversion(inner, _)
367        | ExpressionKind::MathematicalComputation(_, inner)
368        | ExpressionKind::ResultIsVeto(inner)
369        | ExpressionKind::PastFutureRange(_, inner) => {
370            extract_rule_paths_from_expression(inner, paths);
371        }
372        ExpressionKind::DateRelative(_, date_expr) => {
373            extract_rule_paths_from_expression(date_expr, paths);
374        }
375        ExpressionKind::DateCalendar(_, _, date_expr) => {
376            extract_rule_paths_from_expression(date_expr, paths);
377        }
378        ExpressionKind::Literal(_)
379        | ExpressionKind::DataPath(_)
380        | ExpressionKind::Veto(_)
381        | ExpressionKind::Now => {}
382    }
383}
384
385// ============================================================================
386// Expression substitution
387// ============================================================================
388
/// Substitute rule references in an expression with their values in a given world
///
/// For each RulePath in the expression, looks up which branch is active in the world
/// and replaces the RulePath with the branch's result expression.
///
/// Substitution is transitive: a substituted branch result is itself processed, so rule
/// references nested inside it are also resolved against `world`. A RulePath is left in
/// place when the rule has no assignment in `world`, when `plan.get_rule_by_path` does
/// not find it, or when its assigned branch index is out of range.
///
/// The traversal uses an explicit work stack instead of recursion. A `Process` item
/// either pushes a finished leaf onto `result_pool`, or schedules a `Build*` item
/// followed by its children. Because the left child is pushed last, it is processed
/// first. Each `Build*` item therefore pops its right operand before its left operand
/// from `result_pool` and pushes the rebuilt node.
///
/// NOTE(review): `RangeLiteral`, `PastFutureRange`, `RangeContainment`, `DateRelative`
/// and `DateCalendar` are copied verbatim, so rule references nested inside them are NOT
/// substituted. In contrast, `extract_rule_paths_from_expression` descends into these
/// kinds when collecting dependencies. Confirm that this asymmetry is intended.
///
/// # Panics
///
/// Panics (via `unreachable!`/`expect`) when a circular rule reference is met along a
/// substitution path, or when the work stack and result pool get out of sync.
///
/// # Errors
///
/// The body never returns `Err`; the `Result` is part of the signature only.
fn substitute_rules_in_expression(
    expr: &Arc<Expression>,
    world: &World,
    plan: &ExecutionPlan,
) -> Result<Expression, crate::Error> {
    // Deferred work: `Process` visits a pooled expression; `Build*` reassembles a node
    // from operands already in `result_pool`; `PopVisitedRules` ends a rule expansion.
    enum WorkItem {
        Process(usize),
        BuildArithmetic(ArithmeticComputation, Option<Source>),
        BuildComparison(ComparisonComputation, Option<Source>),
        BuildLogicalAnd(Option<Source>),
        BuildLogicalOr(Option<Source>),
        BuildLogicalNegation(NegationType, Option<Source>),
        BuildResultIsVeto(Option<Source>),
        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
        PopVisitedRules,
    }

    // Sub-expressions awaiting processing, referenced by index from `WorkItem::Process`.
    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
    let mut work_stack: Vec<WorkItem> = Vec::new();
    // Finished (substituted) sub-expressions, consumed by the `Build*` items.
    let mut result_pool: Vec<Expression> = Vec::new();
    // Rules currently being expanded along the active substitution path; the top entry
    // is consulted for cycle detection.
    let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];

    let root_idx = expr_pool.len();
    expr_pool.push(Arc::clone(expr));
    work_stack.push(WorkItem::Process(root_idx));

    while let Some(work) = work_stack.pop() {
        match work {
            WorkItem::Process(expr_idx) => {
                let e = &expr_pool[expr_idx];
                let source_loc = e.source_location.clone();

                match &e.kind {
                    ExpressionKind::RulePath(rule_path) => {
                        let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
                        if visited.contains(rule_path) {
                            unreachable!(
                                "BUG: circular rule reference detected during substitution: {}",
                                rule_path
                            );
                        }

                        if let Some(&branch_idx) = world.get(rule_path) {
                            if let Some(rule) = plan.get_rule_by_path(rule_path) {
                                if branch_idx < rule.branches.len() {
                                    let branch = &rule.branches[branch_idx];
                                    let mut new_visited = visited.clone();
                                    new_visited.insert(rule_path.clone());
                                    visited_rules_stack.push(new_visited);

                                    let sub_expr_idx = expr_pool.len();
                                    expr_pool.push(Arc::new(branch.result.clone()));
                                    // PopVisitedRules sits below the branch result's work,
                                    // so it runs only after that subtree is fully rebuilt.
                                    work_stack.push(WorkItem::PopVisitedRules);
                                    work_stack.push(WorkItem::Process(sub_expr_idx));
                                    continue;
                                }
                            }
                        }
                        // Not resolvable in this world: keep the rule reference as-is.
                        result_pool.push(Expression::with_source(
                            ExpressionKind::RulePath(rule_path.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Arithmetic(left, op, right) => {
                        let op_clone = op.clone();
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::Comparison(left, op, right) => {
                        let op_clone = op.clone();
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalAnd(left, right) => {
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalOr(left, right) => {
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    // Copied verbatim: nested rule references are not substituted here.
                    ExpressionKind::RangeLiteral(..)
                    | ExpressionKind::PastFutureRange(..)
                    | ExpressionKind::RangeContainment(..) => {
                        result_pool.push(Expression::with_source(e.kind.clone(), source_loc));
                    }
                    ExpressionKind::LogicalNegation(inner, neg_type) => {
                        let neg_type_clone = neg_type.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::UnitConversion(inner, unit) => {
                        let unit_clone = unit.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::MathematicalComputation(func, inner) => {
                        let func_clone = func.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildMathematicalComputation(
                            func_clone, source_loc,
                        ));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::ResultIsVeto(inner) => {
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildResultIsVeto(source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::Literal(lit) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::Literal(lit.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::DataPath(data_path) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::DataPath(data_path.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Veto(veto) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::Veto(veto.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Now => {
                        result_pool.push(Expression::with_source(ExpressionKind::Now, source_loc));
                    }
                    // Copied verbatim: nested rule references are not substituted here.
                    ExpressionKind::DateRelative(..) | ExpressionKind::DateCalendar(..) => {
                        result_pool.push(Expression::with_source(e.kind.clone(), source_loc));
                    }
                }
            }
            // NOTE(review): the panic messages below say "hydration" although this is the
            // substitution pass.
            WorkItem::BuildArithmetic(op, source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for Arithmetic during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for Arithmetic during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildComparison(op, source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for Comparison during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for Comparison during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalAnd(source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for LogicalAnd during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for LogicalAnd during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalOr(source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for LogicalOr during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for LogicalOr during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for LogicalNegation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
                    source_loc,
                ));
            }
            WorkItem::BuildResultIsVeto(source_loc) => {
                let inner = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing operand expression for ResultIsVeto during inversion substitution"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::ResultIsVeto(Arc::new(inner)),
                    source_loc,
                ));
            }
            WorkItem::BuildUnitConversion(unit, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for UnitConversion");
                result_pool.push(Expression::with_source(
                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
                    source_loc,
                ));
            }
            WorkItem::BuildMathematicalComputation(func, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for MathematicalComputation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
                    source_loc,
                ));
            }
            // End of a rule expansion: restore the visited set of the enclosing path.
            WorkItem::PopVisitedRules => {
                visited_rules_stack.pop();
            }
        }
    }

    // Exactly one fully rebuilt root expression remains.
    Ok(result_pool
        .pop()
        .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
}
687
688// ============================================================================
689// Data hydration
690// ============================================================================
691
692/// Hydrate data references in an expression with their known values
693///
694/// For each DataPath in the expression, if the data is in provided_data,
695/// replaces the DataPath with a Literal containing the data's value.
696fn hydrate_data_in_expression(
697    expr: &Arc<Expression>,
698    plan: &ExecutionPlan,
699    provided_data: &HashSet<DataPath>,
700) -> Result<Expression, crate::Error> {
701    enum WorkItem {
702        Process(usize),
703        BuildArithmetic(ArithmeticComputation, Option<Source>),
704        BuildComparison(ComparisonComputation, Option<Source>),
705        BuildLogicalAnd(Option<Source>),
706        BuildLogicalOr(Option<Source>),
707        BuildLogicalNegation(NegationType, Option<Source>),
708        BuildResultIsVeto(Option<Source>),
709        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
710        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
711    }
712
713    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
714    let mut work_stack: Vec<WorkItem> = Vec::new();
715    let mut result_pool: Vec<Expression> = Vec::new();
716
717    let root_idx = expr_pool.len();
718    expr_pool.push(Arc::clone(expr));
719    work_stack.push(WorkItem::Process(root_idx));
720
721    while let Some(work) = work_stack.pop() {
722        match work {
723            WorkItem::Process(expr_idx) => {
724                let (source_loc, expr_kind_ref) = {
725                    let e = &expr_pool[expr_idx];
726                    (e.source_location.clone(), &e.kind)
727                };
728
729                match expr_kind_ref {
730                    ExpressionKind::DataPath(data_path) => {
731                        if provided_data.contains(data_path) {
732                            if let Some(lit) = plan.data.get(data_path).and_then(|d| d.value()) {
733                                result_pool.push(Expression::with_source(
734                                    ExpressionKind::Literal(Box::new(lit.clone())),
735                                    source_loc,
736                                ));
737                                continue;
738                            }
739                        }
740                        result_pool.push(Expression::with_source(
741                            ExpressionKind::DataPath(data_path.clone()),
742                            source_loc,
743                        ));
744                    }
745                    ExpressionKind::Arithmetic(left, op, right) => {
746                        let op_clone = op.clone();
747                        let left_arc = Arc::clone(left);
748                        let right_arc = Arc::clone(right);
749
750                        let left_idx = expr_pool.len();
751                        expr_pool.push(left_arc);
752                        let right_idx = expr_pool.len();
753                        expr_pool.push(right_arc);
754
755                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
756                        work_stack.push(WorkItem::Process(right_idx));
757                        work_stack.push(WorkItem::Process(left_idx));
758                    }
759                    ExpressionKind::Comparison(left, op, right) => {
760                        let op_clone = op.clone();
761                        let left_arc = Arc::clone(left);
762                        let right_arc = Arc::clone(right);
763
764                        let left_idx = expr_pool.len();
765                        expr_pool.push(left_arc);
766                        let right_idx = expr_pool.len();
767                        expr_pool.push(right_arc);
768
769                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
770                        work_stack.push(WorkItem::Process(right_idx));
771                        work_stack.push(WorkItem::Process(left_idx));
772                    }
773                    ExpressionKind::LogicalAnd(left, right) => {
774                        let left_arc = Arc::clone(left);
775                        let right_arc = Arc::clone(right);
776
777                        let left_idx = expr_pool.len();
778                        expr_pool.push(left_arc);
779                        let right_idx = expr_pool.len();
780                        expr_pool.push(right_arc);
781
782                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
783                        work_stack.push(WorkItem::Process(right_idx));
784                        work_stack.push(WorkItem::Process(left_idx));
785                    }
786                    ExpressionKind::LogicalOr(left, right) => {
787                        let left_arc = Arc::clone(left);
788                        let right_arc = Arc::clone(right);
789
790                        let left_idx = expr_pool.len();
791                        expr_pool.push(left_arc);
792                        let right_idx = expr_pool.len();
793                        expr_pool.push(right_arc);
794
795                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
796                        work_stack.push(WorkItem::Process(right_idx));
797                        work_stack.push(WorkItem::Process(left_idx));
798                    }
799                    ExpressionKind::RangeLiteral(..)
800                    | ExpressionKind::PastFutureRange(..)
801                    | ExpressionKind::RangeContainment(..) => {
802                        result_pool
803                            .push(Expression::with_source(expr_kind_ref.clone(), source_loc));
804                    }
805                    ExpressionKind::LogicalNegation(inner, neg_type) => {
806                        let neg_type_clone = neg_type.clone();
807                        let inner_arc = Arc::clone(inner);
808                        let inner_idx = expr_pool.len();
809                        expr_pool.push(inner_arc);
810                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
811                        work_stack.push(WorkItem::Process(inner_idx));
812                    }
813                    ExpressionKind::UnitConversion(inner, unit) => {
814                        let unit_clone = unit.clone();
815                        let inner_arc = Arc::clone(inner);
816                        let inner_idx = expr_pool.len();
817                        expr_pool.push(inner_arc);
818                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
819                        work_stack.push(WorkItem::Process(inner_idx));
820                    }
821                    ExpressionKind::MathematicalComputation(func, inner) => {
822                        let func_clone = func.clone();
823                        let inner_arc = Arc::clone(inner);
824                        let inner_idx = expr_pool.len();
825                        expr_pool.push(inner_arc);
826                        work_stack.push(WorkItem::BuildMathematicalComputation(
827                            func_clone, source_loc,
828                        ));
829                        work_stack.push(WorkItem::Process(inner_idx));
830                    }
831                    ExpressionKind::ResultIsVeto(inner) => {
832                        let inner_arc = Arc::clone(inner);
833                        let inner_idx = expr_pool.len();
834                        expr_pool.push(inner_arc);
835                        work_stack.push(WorkItem::BuildResultIsVeto(source_loc));
836                        work_stack.push(WorkItem::Process(inner_idx));
837                    }
838                    ExpressionKind::Literal(lit) => {
839                        result_pool.push(Expression::with_source(
840                            ExpressionKind::Literal(lit.clone()),
841                            source_loc,
842                        ));
843                    }
844                    ExpressionKind::RulePath(rule_path) => {
845                        result_pool.push(Expression::with_source(
846                            ExpressionKind::RulePath(rule_path.clone()),
847                            source_loc,
848                        ));
849                    }
850                    ExpressionKind::Veto(veto) => {
851                        result_pool.push(Expression::with_source(
852                            ExpressionKind::Veto(veto.clone()),
853                            source_loc,
854                        ));
855                    }
856                    ExpressionKind::Now => {
857                        result_pool.push(Expression::with_source(ExpressionKind::Now, source_loc));
858                    }
859                    ExpressionKind::DateRelative(..) | ExpressionKind::DateCalendar(..) => {
860                        result_pool
861                            .push(Expression::with_source(expr_kind_ref.clone(), source_loc));
862                    }
863                }
864            }
865            WorkItem::BuildArithmetic(op, source_loc) => {
866                let right = result_pool.pop().unwrap_or_else(|| {
867                    unreachable!("BUG: missing right expression for Arithmetic")
868                });
869                let left = result_pool
870                    .pop()
871                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Arithmetic"));
872                result_pool.push(Expression::with_source(
873                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
874                    source_loc,
875                ));
876            }
877            WorkItem::BuildComparison(op, source_loc) => {
878                let right = result_pool.pop().unwrap_or_else(|| {
879                    unreachable!("BUG: missing right expression for Comparison")
880                });
881                let left = result_pool
882                    .pop()
883                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Comparison"));
884                result_pool.push(Expression::with_source(
885                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
886                    source_loc,
887                ));
888            }
889            WorkItem::BuildLogicalAnd(source_loc) => {
890                let right = result_pool.pop().unwrap_or_else(|| {
891                    unreachable!("BUG: missing right expression for LogicalAnd")
892                });
893                let left = result_pool
894                    .pop()
895                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalAnd"));
896                result_pool.push(Expression::with_source(
897                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
898                    source_loc,
899                ));
900            }
901            WorkItem::BuildLogicalOr(source_loc) => {
902                let right = result_pool
903                    .pop()
904                    .unwrap_or_else(|| unreachable!("BUG: missing right expression for LogicalOr"));
905                let left = result_pool
906                    .pop()
907                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalOr"));
908                result_pool.push(Expression::with_source(
909                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
910                    source_loc,
911                ));
912            }
913            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
914                let inner = result_pool
915                    .pop()
916                    .expect("Internal error: missing expression for LogicalNegation");
917                result_pool.push(Expression::with_source(
918                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
919                    source_loc,
920                ));
921            }
922            WorkItem::BuildResultIsVeto(source_loc) => {
923                let inner = result_pool.pop().unwrap_or_else(|| {
924                    unreachable!(
925                        "BUG: missing operand expression for ResultIsVeto during inversion hydration"
926                    )
927                });
928                result_pool.push(Expression::with_source(
929                    ExpressionKind::ResultIsVeto(Arc::new(inner)),
930                    source_loc,
931                ));
932            }
933            WorkItem::BuildUnitConversion(unit, source_loc) => {
934                let inner = result_pool
935                    .pop()
936                    .expect("Internal error: missing expression for UnitConversion");
937                result_pool.push(Expression::with_source(
938                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
939                    source_loc,
940                ));
941            }
942            WorkItem::BuildMathematicalComputation(func, source_loc) => {
943                let inner = result_pool
944                    .pop()
945                    .expect("Internal error: missing expression for MathematicalComputation");
946                result_pool.push(Expression::with_source(
947                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
948                    source_loc,
949                ));
950            }
951        }
952    }
953
954    Ok(result_pool
955        .pop()
956        .expect("Internal error: no result from hydration"))
957}
958
959// ============================================================================
960// Constant folding
961// ============================================================================
962
963/// Extract an outcome (value or veto) from an expression
964fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
965    match &expr.kind {
966        ExpressionKind::Literal(lit) => {
967            Some(OperationResult::Value(Box::new(lit.as_ref().clone())))
968        }
969        ExpressionKind::Veto(ve) => Some(OperationResult::Veto(
970            crate::evaluation::operations::VetoType::UserDefined {
971                message: ve.message.clone(),
972            },
973        )),
974        _ => None,
975    }
976}
977
978/// Check if an expression is a boolean-producing expression (comparison or logical)
979fn is_boolean_expression(expr: &Expression) -> bool {
980    matches!(
981        &expr.kind,
982        ExpressionKind::Comparison(_, _, _)
983            | ExpressionKind::LogicalAnd(_, _)
984            | ExpressionKind::LogicalOr(_, _)
985            | ExpressionKind::LogicalNegation(_, _)
986            | ExpressionKind::ResultIsVeto(_)
987    )
988}
989
990/// Check if an expression is an arithmetic expression (contains arithmetic operations)
991///
992/// Returns true for expressions like `price * 5`, `x + y`, etc.
993fn is_arithmetic_expression(expr: &Expression) -> bool {
994    match &expr.kind {
995        ExpressionKind::Arithmetic(_, _, _) => true,
996        ExpressionKind::MathematicalComputation(_, _) => true,
997        ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
998        ExpressionKind::DataPath(_) => true, // Lone data is also solvable
999        _ => false,
1000    }
1001}
1002
1003/// For boolean expressions that can't be evaluated to a literal (e.g., `age > 18`),
1004/// create two solutions: one where the expression is true, one where it's false.
1005///
1006/// This allows inversion to work with rules like `rule of_age: age > 18`
1007fn create_boolean_expression_solutions(
1008    world: World,
1009    base_constraint: Constraint,
1010    boolean_expr: &Expression,
1011) -> Result<(Vec<WorldSolution>, Vec<WorldSolution>), crate::Error> {
1012    // Convert boolean expression to constraint
1013    let expr_constraint = Constraint::from_expression(boolean_expr)?;
1014
1015    // Solution where the boolean expression is true
1016    let true_constraint = base_constraint.clone().and(expr_constraint.clone());
1017    let simplified_true = true_constraint.simplify()?;
1018
1019    let true_solutions = if !simplified_true.is_false() {
1020        vec![WorldSolution {
1021            world: world.clone(),
1022            constraint: simplified_true,
1023            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(true))),
1024        }]
1025    } else {
1026        vec![]
1027    };
1028
1029    // Solution where the boolean expression is false
1030    let false_constraint = base_constraint.and(expr_constraint.not());
1031    let simplified_false = false_constraint.simplify()?;
1032
1033    let false_solutions = if !simplified_false.is_false() {
1034        vec![WorldSolution {
1035            world,
1036            constraint: simplified_false,
1037            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(false))),
1038        }]
1039    } else {
1040        vec![]
1041    };
1042
1043    Ok((true_solutions, false_solutions))
1044}
1045
1046/// Attempt constant folding on an expression (simplified version for outcomes)
1047pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
1048    match &expr.kind {
1049        ExpressionKind::Literal(_) => Some(expr.clone()),
1050        ExpressionKind::Arithmetic(left, op, right) => {
1051            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1052            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1053            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1054                (&left_folded.kind, &right_folded.kind)
1055            {
1056                if let Some(result) = evaluate_arithmetic(left_val.as_ref(), op, right_val.as_ref())
1057                {
1058                    return Some(Expression::with_source(
1059                        ExpressionKind::Literal(Box::new(result)),
1060                        expr.source_location.clone(),
1061                    ));
1062                }
1063            }
1064            Some(Expression::with_source(
1065                ExpressionKind::Arithmetic(
1066                    Arc::new(left_folded),
1067                    op.clone(),
1068                    Arc::new(right_folded),
1069                ),
1070                expr.source_location.clone(),
1071            ))
1072        }
1073        ExpressionKind::Comparison(left, op, right) => {
1074            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1075            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1076            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1077                (&left_folded.kind, &right_folded.kind)
1078            {
1079                if let Some(result) = evaluate_comparison(left_val.as_ref(), op, right_val.as_ref())
1080                {
1081                    return Some(Expression::with_source(
1082                        ExpressionKind::Literal(Box::new(LiteralValue::from_bool(result))),
1083                        expr.source_location.clone(),
1084                    ));
1085                }
1086            }
1087            Some(Expression::with_source(
1088                ExpressionKind::Comparison(
1089                    Arc::new(left_folded),
1090                    op.clone(),
1091                    Arc::new(right_folded),
1092                ),
1093                expr.source_location.clone(),
1094            ))
1095        }
1096        _ => None,
1097    }
1098}
1099
1100/// Evaluate an arithmetic operation on two literals
1101///
1102/// Delegates to the computation module for consistent behavior
1103fn evaluate_arithmetic(
1104    left: &LiteralValue,
1105    op: &ArithmeticComputation,
1106    right: &LiteralValue,
1107) -> Option<LiteralValue> {
1108    use crate::computation::arithmetic_operation;
1109
1110    match arithmetic_operation(left, op, right) {
1111        OperationResult::Value(lit) => Some(lit.as_ref().clone()),
1112        OperationResult::Veto(_) => None,
1113    }
1114}
1115
1116/// Evaluate a comparison operation on two literals
1117///
1118/// Delegates to the computation module for consistent behavior
1119fn evaluate_comparison(
1120    left: &LiteralValue,
1121    op: &ComparisonComputation,
1122    right: &LiteralValue,
1123) -> Option<bool> {
1124    use crate::computation::comparison_operation;
1125    use crate::planning::semantics::ValueKind;
1126
1127    match comparison_operation(
1128        left,
1129        op,
1130        right,
1131        crate::computation::UnitResolutionContext::NamedQuantityOnly,
1132    ) {
1133        OperationResult::Value(lit) => match &lit.value {
1134            ValueKind::Boolean(b) => Some(*b),
1135            _ => None,
1136        },
1137        _ => None,
1138    }
1139}
1140
1141// ============================================================================
1142// Tests
1143// ============================================================================
1144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::computation::rational::RationalInteger;
    use crate::planning::semantics::ValueKind;

    /// Wrap a literal value in an expression with no source location.
    fn literal_expr(val: LiteralValue) -> Expression {
        Expression::with_source(ExpressionKind::Literal(Box::new(val)), None)
    }

    /// Build an integer number literal.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(RationalInteger::new(n as i128, 1))
    }

    /// Build `left <op> right` as an arithmetic expression with no source location.
    fn arithmetic(left: Expression, op: ArithmeticComputation, right: Expression) -> Expression {
        Expression::with_source(
            ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
            None,
        )
    }

    /// Build `left <op> right` as a comparison expression with no source location.
    fn comparison(left: Expression, op: ComparisonComputation, right: Expression) -> Expression {
        Expression::with_source(
            ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
            None,
        )
    }

    /// An expression that is neither a literal nor foldable.
    fn now_expr() -> Expression {
        Expression::with_source(ExpressionKind::Now, None)
    }

    /// Extract the number held by a literal expression, panicking otherwise.
    fn expect_number(expr: &Expression) -> &RationalInteger {
        match &expr.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Number(n) => n,
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    /// Extract the boolean held by a literal expression, panicking otherwise.
    fn expect_boolean(expr: &Expression) -> bool {
        match &expr.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Boolean(b) => *b,
                _ => panic!("Expected literal boolean"),
            },
            _ => panic!("Expected literal boolean"),
        }
    }

    fn test_rule_path() -> RulePath {
        RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        }
    }

    #[test]
    fn test_world_new() {
        let world = World::new();
        assert!(world.0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let mut world = World::new();
        let rule_path = test_rule_path();
        world.insert(rule_path.clone(), 2);
        assert_eq!(world.get(&rule_path), Some(&2));
    }

    #[test]
    fn test_world_insert_returns_previous_branch() {
        let mut world = World::new();
        let rule_path = test_rule_path();
        assert_eq!(world.insert(rule_path.clone(), 1), None);
        assert_eq!(world.insert(rule_path.clone(), 3), Some(1));
        assert_eq!(world.get(&rule_path), Some(&3));
        assert_eq!(world.iter().count(), 1);
    }

    #[test]
    fn test_world_new_equals_default() {
        assert_eq!(World::new(), World::default());
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let expr = arithmetic(
            literal_expr(num(10)),
            ArithmeticComputation::Add,
            literal_expr(num(5)),
        );
        let folded = try_constant_fold_expression(&expr).unwrap();
        assert_eq!(*expect_number(&folded), RationalInteger::new(15, 1));
    }

    #[test]
    fn test_constant_fold_nested_arithmetic() {
        // (10 + 5) + 5
        let inner = arithmetic(
            literal_expr(num(10)),
            ArithmeticComputation::Add,
            literal_expr(num(5)),
        );
        let expr = arithmetic(inner, ArithmeticComputation::Add, literal_expr(num(5)));
        let folded = try_constant_fold_expression(&expr).unwrap();
        assert_eq!(*expect_number(&folded), RationalInteger::new(20, 1));
    }

    #[test]
    fn test_constant_fold_comparison() {
        let expr = comparison(
            literal_expr(num(10)),
            ComparisonComputation::GreaterThan,
            literal_expr(num(5)),
        );
        let folded = try_constant_fold_expression(&expr).unwrap();
        assert!(expect_boolean(&folded));
    }

    #[test]
    fn test_constant_fold_comparison_false() {
        let expr = comparison(
            literal_expr(num(5)),
            ComparisonComputation::GreaterThan,
            literal_expr(num(10)),
        );
        let folded = try_constant_fold_expression(&expr).unwrap();
        assert!(!expect_boolean(&folded));
    }

    #[test]
    fn test_constant_fold_keeps_non_literal_operand() {
        // now + (2 + 3): the right side folds, the left side stays symbolic.
        let right = arithmetic(
            literal_expr(num(2)),
            ArithmeticComputation::Add,
            literal_expr(num(3)),
        );
        let expr = arithmetic(now_expr(), ArithmeticComputation::Add, right);
        let folded = try_constant_fold_expression(&expr).unwrap();
        match &folded.kind {
            ExpressionKind::Arithmetic(left, _, right) => {
                assert!(matches!(left.kind, ExpressionKind::Now));
                assert_eq!(*expect_number(right), RationalInteger::new(5, 1));
            }
            _ => panic!("Expected arithmetic expression"),
        }
    }

    #[test]
    fn test_constant_fold_unsupported_kind_returns_none() {
        assert!(try_constant_fold_expression(&now_expr()).is_none());
    }

    #[test]
    fn test_extract_outcome() {
        match extract_outcome(&literal_expr(num(7))) {
            Some(OperationResult::Value(lit)) => match &lit.value {
                ValueKind::Number(n) => assert_eq!(*n, RationalInteger::new(7, 1)),
                _ => panic!("Expected numeric outcome"),
            },
            _ => panic!("Expected value outcome"),
        }
        assert!(extract_outcome(&now_expr()).is_none());
    }

    #[test]
    fn test_expression_classification() {
        let sum = arithmetic(
            literal_expr(num(1)),
            ArithmeticComputation::Add,
            literal_expr(num(2)),
        );
        let cmp = comparison(
            literal_expr(num(1)),
            ComparisonComputation::GreaterThan,
            literal_expr(num(2)),
        );
        let flag = literal_expr(LiteralValue::from_bool(true));

        assert!(is_boolean_expression(&cmp));
        assert!(!is_boolean_expression(&sum));
        assert!(!is_boolean_expression(&flag));

        assert!(is_arithmetic_expression(&sum));
        assert!(!is_arithmetic_expression(&cmp));
        assert!(!is_arithmetic_expression(&flag));
    }
}