Skip to main content

lemma/inversion/
world.rs

1//! World enumeration for inversion
2//!
3//! A "world" is a complete assignment of which branch is active for each rule.
4//! This module enumerates all valid worlds for a target rule.
5//!
6//! Also includes expression substitution and hydration utilities.
7
8use crate::planning::{ExecutableRule, ExecutionPlan};
9use crate::{
10    ArithmeticComputation, BooleanValue, ComparisonComputation, ConversionTarget, Expression,
11    ExpressionKind, FactPath, LemmaResult, LiteralValue, MathematicalComputation, NegationType,
12    OperationResult, RulePath, Value,
13};
14use serde::ser::{Serialize, SerializeMap, Serializer};
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
20/// A world assigns each rule to one of its branch indices
21#[derive(Debug, Clone, Default, PartialEq, Eq)]
22pub struct World(HashMap<RulePath, usize>);
23
24impl World {
25    /// Create a new empty world
26    pub fn new() -> Self {
27        Self(HashMap::new())
28    }
29
30    /// Get the branch index for a rule
31    pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
32        self.0.get(rule_path)
33    }
34
35    /// Insert a branch assignment for a rule
36    pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
37        self.0.insert(rule_path, branch_idx)
38    }
39
40    /// Iterate over all branch assignments
41    pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
42        self.0.iter()
43    }
44}
45
46impl Serialize for World {
47    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
48    where
49        S: Serializer,
50    {
51        let mut map = serializer.serialize_map(Some(self.0.len()))?;
52        for (k, v) in &self.0 {
53            map.serialize_entry(&k.to_string(), v)?;
54        }
55        map.end()
56    }
57}
58
59/// A solution from world enumeration with a resolved outcome
60#[derive(Debug, Clone)]
61pub struct WorldSolution {
62    /// The world (branch assignment) that produced this solution
63    pub world: World,
64    /// The constraint under which this solution applies (facts only, no rule references)
65    pub constraint: Constraint,
66    /// The outcome (value or veto)
67    pub outcome: OperationResult,
68}
69
70/// A solution from world enumeration with an arithmetic expression outcome
71///
72/// This represents cases where the outcome is a computed expression (like `price * 5`)
73/// that couldn't be evaluated to a literal because it contains unknown facts.
74/// These need algebraic solving to determine the input values.
75#[derive(Debug, Clone)]
76pub struct WorldArithmeticSolution {
77    /// The world (branch assignment) that produced this solution
78    pub world: World,
79    /// The constraint under which this solution applies (facts only, no rule references)
80    pub constraint: Constraint,
81    /// The outcome expression (contains unknown facts)
82    pub outcome_expression: Expression,
83}
84
85/// Result of world enumeration containing both literal and arithmetic solutions
86#[derive(Debug, Clone)]
87pub struct EnumerationResult {
88    /// Solutions with literal outcomes (can be directly compared to target)
89    pub literal_solutions: Vec<WorldSolution>,
90    /// Solutions with arithmetic outcomes (need algebraic solving)
91    pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
92}
93
94/// Enumerates valid worlds for a target rule
95pub struct WorldEnumerator<'a> {
96    plan: &'a ExecutionPlan,
97    /// Rules to process, in topological order (dependencies first)
98    rules_in_order: Vec<RulePath>,
99    /// Cache: rule path -> executable rule (for quick lookup)
100    rule_cache: HashMap<RulePath, &'a ExecutableRule>,
101}
102
103impl<'a> WorldEnumerator<'a> {
104    /// Create a new world enumerator for a target rule
105    pub fn new(plan: &'a ExecutionPlan, target_rule: &RulePath) -> LemmaResult<Self> {
106        // Build rule lookup from execution plan
107        let rule_map: HashMap<RulePath, &ExecutableRule> =
108            plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
109
110        // Find all rules that the target rule depends on (transitively)
111        let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
112
113        // Plan rules are already in topological order, so filter and preserve order
114        let rules_in_order: Vec<RulePath> = plan
115            .rules
116            .iter()
117            .filter(|r| dependent_rules.contains(&r.path))
118            .map(|r| r.path.clone())
119            .collect();
120
121        // Build rule cache for quick lookup (only rules we need)
122        let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
123            .iter()
124            .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
125            .collect();
126
127        Ok(Self {
128            plan,
129            rules_in_order,
130            rule_cache,
131        })
132    }
133
134    /// Enumerate all valid worlds for the target rule
135    ///
136    /// Returns an `EnumerationResult` containing:
137    /// - `literal_solutions`: Worlds where the outcome is a concrete literal value
138    /// - `arithmetic_solutions`: Worlds where the outcome is an arithmetic expression
139    ///   containing unknown facts (needs algebraic solving)
140    pub fn enumerate(
141        &mut self,
142        provided_facts: &HashSet<FactPath>,
143    ) -> LemmaResult<EnumerationResult> {
144        if self.rules_in_order.is_empty() {
145            return Ok(EnumerationResult {
146                literal_solutions: vec![],
147                arithmetic_solutions: vec![],
148            });
149        }
150
151        // Start with a single empty world and true constraint
152        let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
153
154        // Process each rule in topological order
155        for rule_path in &self.rules_in_order.clone() {
156            let rule_node = match self.rule_cache.get(rule_path) {
157                Some(node) => *node,
158                None => continue,
159            };
160
161            let mut next_worlds = Vec::new();
162
163            for (world, accumulated_constraint) in current_worlds {
164                // For each branch in this rule
165                for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
166                    // Create new world with this branch assignment
167                    let mut new_world = world.clone();
168                    new_world.insert(rule_path.clone(), branch_idx);
169
170                    // Get branch constraint
171                    // For "last wins" semantics: all LATER branches must have FALSE conditions
172                    let mut branch_constraint = if let Some(ref condition) = branch.condition {
173                        // This branch's condition must be TRUE
174                        let substituted_condition = substitute_rules_in_expression(
175                            &Arc::new(condition.clone()),
176                            &new_world,
177                            self.plan,
178                        )?;
179                        let hydrated_condition = hydrate_facts_in_expression(
180                            &Arc::new(substituted_condition),
181                            self.plan,
182                            provided_facts,
183                        )?;
184                        Constraint::from_expression(&hydrated_condition)?
185                    } else {
186                        // Default branch has no explicit condition
187                        Constraint::True
188                    };
189
190                    // For "last wins": all LATER branches must NOT match
191                    // (their conditions must be FALSE)
192                    for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
193                        if let Some(ref later_condition) = later_branch.condition {
194                            let substituted_later = substitute_rules_in_expression(
195                                &Arc::new(later_condition.clone()),
196                                &new_world,
197                                self.plan,
198                            )?;
199                            let hydrated_later = hydrate_facts_in_expression(
200                                &Arc::new(substituted_later),
201                                self.plan,
202                                provided_facts,
203                            )?;
204                            let later_constraint = Constraint::from_expression(&hydrated_later)?;
205                            // Later branch's condition must be FALSE
206                            branch_constraint = branch_constraint.and(later_constraint.not());
207                        }
208                    }
209
210                    // Combine with accumulated constraint
211                    let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
212
213                    // Simplify and check if satisfiable
214                    let simplified = combined_constraint.simplify()?;
215
216                    // Only keep if not contradictory
217                    if !simplified.is_false() {
218                        next_worlds.push((new_world, simplified));
219                    }
220                }
221            }
222
223            current_worlds = next_worlds;
224
225            // Early exit if no valid worlds remain
226            if current_worlds.is_empty() {
227                break;
228            }
229        }
230
231        // Convert to WorldSolutions and WorldArithmeticSolutions
232        let target_rule_path = self
233            .rules_in_order
234            .last()
235            .unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
236
237        let mut literal_solutions = Vec::new();
238        let mut arithmetic_solutions = Vec::new();
239
240        for (world, constraint) in current_worlds {
241            // Get the outcome from the target rule's branch
242            if let Some(&branch_idx) = world.get(target_rule_path) {
243                if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
244                    if branch_idx < rule_node.branches.len() {
245                        let branch = &rule_node.branches[branch_idx];
246
247                        // Substitute and hydrate the result expression
248                        let substituted_result = substitute_rules_in_expression(
249                            &Arc::new(branch.result.clone()),
250                            &world,
251                            self.plan,
252                        )?;
253
254                        let hydrated_result = hydrate_facts_in_expression(
255                            &Arc::new(substituted_result),
256                            self.plan,
257                            provided_facts,
258                        )?;
259
260                        // Try to fold the result to a literal
261                        let folded_result = try_constant_fold_expression(&hydrated_result)
262                            .unwrap_or(hydrated_result.clone());
263
264                        // Try to extract a literal value directly
265                        if let Some(outcome) = extract_outcome(&folded_result) {
266                            literal_solutions.push(WorldSolution {
267                                world,
268                                constraint,
269                                outcome,
270                            });
271                        } else if is_boolean_expression(&folded_result) {
272                            // For boolean expressions (comparisons, logical ops), create two solutions:
273                            // one where the expression is true, one where it's false
274                            let (true_solutions, false_solutions) =
275                                create_boolean_expression_solutions(
276                                    world,
277                                    constraint,
278                                    &folded_result,
279                                )?;
280                            literal_solutions.extend(true_solutions);
281                            literal_solutions.extend(false_solutions);
282                        } else if is_arithmetic_expression(&folded_result) {
283                            // Arithmetic expression with unknown facts - needs algebraic solving
284                            arithmetic_solutions.push(WorldArithmeticSolution {
285                                world,
286                                constraint,
287                                outcome_expression: folded_result,
288                            });
289                        }
290                        // Other expression types (rule references, etc.) are silently skipped
291                        // as they indicate incomplete substitution
292                    }
293                }
294            }
295        }
296
297        Ok(EnumerationResult {
298            literal_solutions,
299            arithmetic_solutions,
300        })
301    }
302}
303
304// ============================================================================
305// Dependency and topological sorting
306// ============================================================================
307
308/// Collect all rules that a target rule depends on (transitively)
309fn collect_transitive_dependencies(
310    target_rule: &RulePath,
311    rule_map: &HashMap<RulePath, &ExecutableRule>,
312) -> LemmaResult<HashSet<RulePath>> {
313    let mut result = HashSet::new();
314    let mut queue = VecDeque::new();
315
316    queue.push_back(target_rule.clone());
317    result.insert(target_rule.clone());
318
319    while let Some(rule_path) = queue.pop_front() {
320        if let Some(rule) = rule_map.get(&rule_path) {
321            // Extract rule dependencies from branch expressions
322            let dependencies = extract_rule_dependencies(rule);
323            for dependency in dependencies {
324                if result.insert(dependency.clone()) {
325                    queue.push_back(dependency);
326                }
327            }
328        }
329    }
330
331    Ok(result)
332}
333
334/// Extract rule paths referenced in an executable rule's expressions
335fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
336    let mut deps = HashSet::new();
337    for branch in &rule.branches {
338        if let Some(ref condition) = branch.condition {
339            extract_rule_paths_from_expression(condition, &mut deps);
340        }
341        extract_rule_paths_from_expression(&branch.result, &mut deps);
342    }
343    deps
344}
345
346/// Recursively extract RulePath references from an expression
347fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
348    match &expr.kind {
349        ExpressionKind::RulePath(rp) => {
350            paths.insert(rp.clone());
351        }
352        ExpressionKind::Arithmetic(left, _, right)
353        | ExpressionKind::Comparison(left, _, right)
354        | ExpressionKind::LogicalAnd(left, right)
355        | ExpressionKind::LogicalOr(left, right) => {
356            extract_rule_paths_from_expression(left, paths);
357            extract_rule_paths_from_expression(right, paths);
358        }
359        ExpressionKind::LogicalNegation(inner, _)
360        | ExpressionKind::UnitConversion(inner, _)
361        | ExpressionKind::MathematicalComputation(_, inner) => {
362            extract_rule_paths_from_expression(inner, paths);
363        }
364        ExpressionKind::Literal(_)
365        | ExpressionKind::FactPath(_)
366        | ExpressionKind::Veto(_)
367        | ExpressionKind::Reference(_)
368        | ExpressionKind::UnresolvedUnitLiteral(_, _)
369        | ExpressionKind::FactReference(_)
370        | ExpressionKind::RuleReference(_) => {}
371    }
372}
373
374// ============================================================================
375// Expression substitution
376// ============================================================================
377
378/// Substitute rule references in an expression with their values in a given world
379///
380/// For each RulePath in the expression, looks up which branch is active in the world
381/// and replaces the RulePath with the branch's result expression.
382fn substitute_rules_in_expression(
383    expr: &Arc<Expression>,
384    world: &World,
385    plan: &ExecutionPlan,
386) -> LemmaResult<Expression> {
387    enum WorkItem {
388        Process(usize),
389        BuildArithmetic(ArithmeticComputation, Option<crate::Source>),
390        BuildComparison(ComparisonComputation, Option<crate::Source>),
391        BuildLogicalAnd(Option<crate::Source>),
392        BuildLogicalOr(Option<crate::Source>),
393        BuildLogicalNegation(NegationType, Option<crate::Source>),
394        BuildUnitConversion(ConversionTarget, Option<crate::Source>),
395        BuildMathematicalComputation(MathematicalComputation, Option<crate::Source>),
396        PopVisitedRules,
397    }
398
399    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
400    let mut work_stack: Vec<WorkItem> = Vec::new();
401    let mut result_pool: Vec<Expression> = Vec::new();
402    let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];
403
404    let root_idx = expr_pool.len();
405    expr_pool.push(Arc::clone(expr));
406    work_stack.push(WorkItem::Process(root_idx));
407
408    while let Some(work) = work_stack.pop() {
409        match work {
410            WorkItem::Process(expr_idx) => {
411                let e = &expr_pool[expr_idx];
412                let source_loc = e.source_location.clone();
413
414                match &e.kind {
415                    ExpressionKind::RulePath(rule_path) => {
416                        let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
417                        if visited.contains(rule_path) {
418                            unreachable!(
419                                "BUG: circular rule reference detected during substitution: {}",
420                                rule_path
421                            );
422                        }
423
424                        if let Some(&branch_idx) = world.get(rule_path) {
425                            if let Some(rule) = plan.get_rule_by_path(rule_path) {
426                                if branch_idx < rule.branches.len() {
427                                    let branch = &rule.branches[branch_idx];
428                                    let mut new_visited = visited.clone();
429                                    new_visited.insert(rule_path.clone());
430                                    visited_rules_stack.push(new_visited);
431
432                                    let sub_expr_idx = expr_pool.len();
433                                    expr_pool.push(Arc::new(branch.result.clone()));
434                                    work_stack.push(WorkItem::PopVisitedRules);
435                                    work_stack.push(WorkItem::Process(sub_expr_idx));
436                                    continue;
437                                }
438                            }
439                        }
440                        result_pool.push(Expression::new(
441                            ExpressionKind::RulePath(rule_path.clone()),
442                            source_loc,
443                        ));
444                    }
445                    ExpressionKind::RuleReference(_) => {
446                        unreachable!(
447                            "BUG: RuleReference found during substitution (should be RulePath)"
448                        );
449                    }
450                    ExpressionKind::FactReference(_) => {
451                        unreachable!(
452                            "BUG: FactReference found during substitution (should be FactPath)"
453                        );
454                    }
455                    ExpressionKind::Reference(_) => {
456                        unreachable!(
457                            "BUG: unresolved Reference found during substitution (should be resolved during planning)"
458                        );
459                    }
460                    ExpressionKind::UnresolvedUnitLiteral(_, _) => {
461                        unreachable!(
462                            "UnresolvedUnitLiteral found during substitution - this is a bug: unresolved units should be resolved during planning"
463                        );
464                    }
465                    ExpressionKind::Arithmetic(left, op, right) => {
466                        let op_clone = op.clone();
467                        let left_arc = Arc::clone(left);
468                        let right_arc = Arc::clone(right);
469
470                        let left_idx = expr_pool.len();
471                        expr_pool.push(left_arc);
472                        let right_idx = expr_pool.len();
473                        expr_pool.push(right_arc);
474
475                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
476                        work_stack.push(WorkItem::Process(right_idx));
477                        work_stack.push(WorkItem::Process(left_idx));
478                    }
479                    ExpressionKind::Comparison(left, op, right) => {
480                        let op_clone = op.clone();
481                        let left_arc = Arc::clone(left);
482                        let right_arc = Arc::clone(right);
483
484                        let left_idx = expr_pool.len();
485                        expr_pool.push(left_arc);
486                        let right_idx = expr_pool.len();
487                        expr_pool.push(right_arc);
488
489                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
490                        work_stack.push(WorkItem::Process(right_idx));
491                        work_stack.push(WorkItem::Process(left_idx));
492                    }
493                    ExpressionKind::LogicalAnd(left, right) => {
494                        let left_arc = Arc::clone(left);
495                        let right_arc = Arc::clone(right);
496
497                        let left_idx = expr_pool.len();
498                        expr_pool.push(left_arc);
499                        let right_idx = expr_pool.len();
500                        expr_pool.push(right_arc);
501
502                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
503                        work_stack.push(WorkItem::Process(right_idx));
504                        work_stack.push(WorkItem::Process(left_idx));
505                    }
506                    ExpressionKind::LogicalOr(left, right) => {
507                        let left_arc = Arc::clone(left);
508                        let right_arc = Arc::clone(right);
509
510                        let left_idx = expr_pool.len();
511                        expr_pool.push(left_arc);
512                        let right_idx = expr_pool.len();
513                        expr_pool.push(right_arc);
514
515                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
516                        work_stack.push(WorkItem::Process(right_idx));
517                        work_stack.push(WorkItem::Process(left_idx));
518                    }
519                    ExpressionKind::LogicalNegation(inner, neg_type) => {
520                        let neg_type_clone = neg_type.clone();
521                        let inner_arc = Arc::clone(inner);
522                        let inner_idx = expr_pool.len();
523                        expr_pool.push(inner_arc);
524                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
525                        work_stack.push(WorkItem::Process(inner_idx));
526                    }
527                    ExpressionKind::UnitConversion(inner, unit) => {
528                        let unit_clone = unit.clone();
529                        let inner_arc = Arc::clone(inner);
530                        let inner_idx = expr_pool.len();
531                        expr_pool.push(inner_arc);
532                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
533                        work_stack.push(WorkItem::Process(inner_idx));
534                    }
535                    ExpressionKind::MathematicalComputation(func, inner) => {
536                        let func_clone = func.clone();
537                        let inner_arc = Arc::clone(inner);
538                        let inner_idx = expr_pool.len();
539                        expr_pool.push(inner_arc);
540                        work_stack.push(WorkItem::BuildMathematicalComputation(
541                            func_clone, source_loc,
542                        ));
543                        work_stack.push(WorkItem::Process(inner_idx));
544                    }
545                    ExpressionKind::Literal(lit) => {
546                        result_pool.push(Expression::new(
547                            ExpressionKind::Literal(lit.clone()),
548                            source_loc,
549                        ));
550                    }
551                    ExpressionKind::FactPath(fact_path) => {
552                        result_pool.push(Expression::new(
553                            ExpressionKind::FactPath(fact_path.clone()),
554                            source_loc,
555                        ));
556                    }
557                    ExpressionKind::Veto(veto) => {
558                        result_pool.push(Expression::new(
559                            ExpressionKind::Veto(veto.clone()),
560                            source_loc,
561                        ));
562                    }
563                }
564            }
565            WorkItem::BuildArithmetic(op, source_loc) => {
566                let right = result_pool.pop().unwrap_or_else(|| {
567                    unreachable!(
568                        "BUG: missing right expression for Arithmetic during inversion hydration"
569                    )
570                });
571                let left = result_pool.pop().unwrap_or_else(|| {
572                    unreachable!(
573                        "BUG: missing left expression for Arithmetic during inversion hydration"
574                    )
575                });
576                result_pool.push(Expression::new(
577                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
578                    source_loc,
579                ));
580            }
581            WorkItem::BuildComparison(op, source_loc) => {
582                let right = result_pool.pop().unwrap_or_else(|| {
583                    unreachable!(
584                        "BUG: missing right expression for Comparison during inversion hydration"
585                    )
586                });
587                let left = result_pool.pop().unwrap_or_else(|| {
588                    unreachable!(
589                        "BUG: missing left expression for Comparison during inversion hydration"
590                    )
591                });
592                result_pool.push(Expression::new(
593                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
594                    source_loc,
595                ));
596            }
597            WorkItem::BuildLogicalAnd(source_loc) => {
598                let right = result_pool.pop().unwrap_or_else(|| {
599                    unreachable!(
600                        "BUG: missing right expression for LogicalAnd during inversion hydration"
601                    )
602                });
603                let left = result_pool.pop().unwrap_or_else(|| {
604                    unreachable!(
605                        "BUG: missing left expression for LogicalAnd during inversion hydration"
606                    )
607                });
608                result_pool.push(Expression::new(
609                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
610                    source_loc,
611                ));
612            }
613            WorkItem::BuildLogicalOr(source_loc) => {
614                let right = result_pool.pop().unwrap_or_else(|| {
615                    unreachable!(
616                        "BUG: missing right expression for LogicalOr during inversion hydration"
617                    )
618                });
619                let left = result_pool.pop().unwrap_or_else(|| {
620                    unreachable!(
621                        "BUG: missing left expression for LogicalOr during inversion hydration"
622                    )
623                });
624                result_pool.push(Expression::new(
625                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
626                    source_loc,
627                ));
628            }
629            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
630                let inner = result_pool
631                    .pop()
632                    .expect("Internal error: missing expression for LogicalNegation");
633                result_pool.push(Expression::new(
634                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
635                    source_loc,
636                ));
637            }
638            WorkItem::BuildUnitConversion(unit, source_loc) => {
639                let inner = result_pool
640                    .pop()
641                    .expect("Internal error: missing expression for UnitConversion");
642                result_pool.push(Expression::new(
643                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
644                    source_loc,
645                ));
646            }
647            WorkItem::BuildMathematicalComputation(func, source_loc) => {
648                let inner = result_pool
649                    .pop()
650                    .expect("Internal error: missing expression for MathematicalComputation");
651                result_pool.push(Expression::new(
652                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
653                    source_loc,
654                ));
655            }
656            WorkItem::PopVisitedRules => {
657                visited_rules_stack.pop();
658            }
659        }
660    }
661
662    Ok(result_pool
663        .pop()
664        .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
665}
666
667// ============================================================================
668// Fact hydration
669// ============================================================================
670
671/// Hydrate fact references in an expression with their known values
672///
673/// For each FactPath in the expression, if the fact is in provided_facts,
674/// replaces the FactPath with a Literal containing the fact's value.
675fn hydrate_facts_in_expression(
676    expr: &Arc<Expression>,
677    plan: &ExecutionPlan,
678    provided_facts: &HashSet<FactPath>,
679) -> LemmaResult<Expression> {
680    enum WorkItem {
681        Process(usize),
682        BuildArithmetic(ArithmeticComputation, Option<crate::Source>),
683        BuildComparison(ComparisonComputation, Option<crate::Source>),
684        BuildLogicalAnd(Option<crate::Source>),
685        BuildLogicalOr(Option<crate::Source>),
686        BuildLogicalNegation(NegationType, Option<crate::Source>),
687        BuildUnitConversion(ConversionTarget, Option<crate::Source>),
688        BuildMathematicalComputation(MathematicalComputation, Option<crate::Source>),
689    }
690
691    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
692    let mut work_stack: Vec<WorkItem> = Vec::new();
693    let mut result_pool: Vec<Expression> = Vec::new();
694
695    let root_idx = expr_pool.len();
696    expr_pool.push(Arc::clone(expr));
697    work_stack.push(WorkItem::Process(root_idx));
698
699    while let Some(work) = work_stack.pop() {
700        match work {
701            WorkItem::Process(expr_idx) => {
702                let (source_loc, expr_kind_ref) = {
703                    let e = &expr_pool[expr_idx];
704                    (e.source_location.clone(), &e.kind)
705                };
706
707                match expr_kind_ref {
708                    ExpressionKind::FactPath(fact_path) => {
709                        if provided_facts.contains(fact_path) {
710                            if let Some(lit) = plan.fact_values.get(fact_path) {
711                                result_pool.push(Expression::new(
712                                    ExpressionKind::Literal(lit.clone()),
713                                    source_loc,
714                                ));
715                                continue;
716                            }
717                        }
718                        result_pool.push(Expression::new(
719                            ExpressionKind::FactPath(fact_path.clone()),
720                            source_loc,
721                        ));
722                    }
723                    ExpressionKind::FactReference(_) => {
724                        unreachable!(
725                            "BUG: FactReference found during hydration (should be FactPath)"
726                        );
727                    }
728                    ExpressionKind::RuleReference(_) => {
729                        unreachable!(
730                            "BUG: RuleReference found during hydration (should be RulePath)"
731                        );
732                    }
733                    ExpressionKind::Reference(_) => {
734                        unreachable!(
735                            "BUG: unresolved Reference found during hydration (should be resolved during planning)"
736                        );
737                    }
738                    ExpressionKind::UnresolvedUnitLiteral(_, _) => {
739                        unreachable!(
740                            "UnresolvedUnitLiteral found during hydration - this is a bug: unresolved units should be resolved during planning"
741                        );
742                    }
743                    ExpressionKind::Arithmetic(left, op, right) => {
744                        let op_clone = op.clone();
745                        let left_arc = Arc::clone(left);
746                        let right_arc = Arc::clone(right);
747
748                        let left_idx = expr_pool.len();
749                        expr_pool.push(left_arc);
750                        let right_idx = expr_pool.len();
751                        expr_pool.push(right_arc);
752
753                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
754                        work_stack.push(WorkItem::Process(right_idx));
755                        work_stack.push(WorkItem::Process(left_idx));
756                    }
757                    ExpressionKind::Comparison(left, op, right) => {
758                        let op_clone = op.clone();
759                        let left_arc = Arc::clone(left);
760                        let right_arc = Arc::clone(right);
761
762                        let left_idx = expr_pool.len();
763                        expr_pool.push(left_arc);
764                        let right_idx = expr_pool.len();
765                        expr_pool.push(right_arc);
766
767                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
768                        work_stack.push(WorkItem::Process(right_idx));
769                        work_stack.push(WorkItem::Process(left_idx));
770                    }
771                    ExpressionKind::LogicalAnd(left, right) => {
772                        let left_arc = Arc::clone(left);
773                        let right_arc = Arc::clone(right);
774
775                        let left_idx = expr_pool.len();
776                        expr_pool.push(left_arc);
777                        let right_idx = expr_pool.len();
778                        expr_pool.push(right_arc);
779
780                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
781                        work_stack.push(WorkItem::Process(right_idx));
782                        work_stack.push(WorkItem::Process(left_idx));
783                    }
784                    ExpressionKind::LogicalOr(left, right) => {
785                        let left_arc = Arc::clone(left);
786                        let right_arc = Arc::clone(right);
787
788                        let left_idx = expr_pool.len();
789                        expr_pool.push(left_arc);
790                        let right_idx = expr_pool.len();
791                        expr_pool.push(right_arc);
792
793                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
794                        work_stack.push(WorkItem::Process(right_idx));
795                        work_stack.push(WorkItem::Process(left_idx));
796                    }
797                    ExpressionKind::LogicalNegation(inner, neg_type) => {
798                        let neg_type_clone = neg_type.clone();
799                        let inner_arc = Arc::clone(inner);
800                        let inner_idx = expr_pool.len();
801                        expr_pool.push(inner_arc);
802                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
803                        work_stack.push(WorkItem::Process(inner_idx));
804                    }
805                    ExpressionKind::UnitConversion(inner, unit) => {
806                        let unit_clone = unit.clone();
807                        let inner_arc = Arc::clone(inner);
808                        let inner_idx = expr_pool.len();
809                        expr_pool.push(inner_arc);
810                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
811                        work_stack.push(WorkItem::Process(inner_idx));
812                    }
813                    ExpressionKind::MathematicalComputation(func, inner) => {
814                        let func_clone = func.clone();
815                        let inner_arc = Arc::clone(inner);
816                        let inner_idx = expr_pool.len();
817                        expr_pool.push(inner_arc);
818                        work_stack.push(WorkItem::BuildMathematicalComputation(
819                            func_clone, source_loc,
820                        ));
821                        work_stack.push(WorkItem::Process(inner_idx));
822                    }
823                    ExpressionKind::Literal(lit) => {
824                        result_pool.push(Expression::new(
825                            ExpressionKind::Literal(lit.clone()),
826                            source_loc,
827                        ));
828                    }
829                    ExpressionKind::RulePath(rule_path) => {
830                        result_pool.push(Expression::new(
831                            ExpressionKind::RulePath(rule_path.clone()),
832                            source_loc,
833                        ));
834                    }
835                    ExpressionKind::Veto(veto) => {
836                        result_pool.push(Expression::new(
837                            ExpressionKind::Veto(veto.clone()),
838                            source_loc,
839                        ));
840                    }
841                }
842            }
843            WorkItem::BuildArithmetic(op, source_loc) => {
844                let right = result_pool.pop().unwrap_or_else(|| {
845                    unreachable!("BUG: missing right expression for Arithmetic")
846                });
847                let left = result_pool
848                    .pop()
849                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Arithmetic"));
850                result_pool.push(Expression::new(
851                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
852                    source_loc,
853                ));
854            }
855            WorkItem::BuildComparison(op, source_loc) => {
856                let right = result_pool.pop().unwrap_or_else(|| {
857                    unreachable!("BUG: missing right expression for Comparison")
858                });
859                let left = result_pool
860                    .pop()
861                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Comparison"));
862                result_pool.push(Expression::new(
863                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
864                    source_loc,
865                ));
866            }
867            WorkItem::BuildLogicalAnd(source_loc) => {
868                let right = result_pool.pop().unwrap_or_else(|| {
869                    unreachable!("BUG: missing right expression for LogicalAnd")
870                });
871                let left = result_pool
872                    .pop()
873                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalAnd"));
874                result_pool.push(Expression::new(
875                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
876                    source_loc,
877                ));
878            }
879            WorkItem::BuildLogicalOr(source_loc) => {
880                let right = result_pool
881                    .pop()
882                    .unwrap_or_else(|| unreachable!("BUG: missing right expression for LogicalOr"));
883                let left = result_pool
884                    .pop()
885                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalOr"));
886                result_pool.push(Expression::new(
887                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
888                    source_loc,
889                ));
890            }
891            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
892                let inner = result_pool
893                    .pop()
894                    .expect("Internal error: missing expression for LogicalNegation");
895                result_pool.push(Expression::new(
896                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
897                    source_loc,
898                ));
899            }
900            WorkItem::BuildUnitConversion(unit, source_loc) => {
901                let inner = result_pool
902                    .pop()
903                    .expect("Internal error: missing expression for UnitConversion");
904                result_pool.push(Expression::new(
905                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
906                    source_loc,
907                ));
908            }
909            WorkItem::BuildMathematicalComputation(func, source_loc) => {
910                let inner = result_pool
911                    .pop()
912                    .expect("Internal error: missing expression for MathematicalComputation");
913                result_pool.push(Expression::new(
914                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
915                    source_loc,
916                ));
917            }
918        }
919    }
920
921    Ok(result_pool
922        .pop()
923        .expect("Internal error: no result from hydration"))
924}
925
926// ============================================================================
927// Constant folding
928// ============================================================================
929
930/// Extract an outcome (value or veto) from an expression
931fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
932    match &expr.kind {
933        ExpressionKind::Literal(lit) => Some(OperationResult::Value(lit.clone())),
934        ExpressionKind::Veto(ve) => Some(OperationResult::Veto(ve.message.clone())),
935        _ => None,
936    }
937}
938
939/// Check if an expression is a boolean-producing expression (comparison or logical)
940fn is_boolean_expression(expr: &Expression) -> bool {
941    matches!(
942        &expr.kind,
943        ExpressionKind::Comparison(_, _, _)
944            | ExpressionKind::LogicalAnd(_, _)
945            | ExpressionKind::LogicalOr(_, _)
946            | ExpressionKind::LogicalNegation(_, _)
947    )
948}
949
950/// Check if an expression is an arithmetic expression (contains arithmetic operations)
951///
952/// Returns true for expressions like `price * 5`, `x + y`, etc.
953fn is_arithmetic_expression(expr: &Expression) -> bool {
954    match &expr.kind {
955        ExpressionKind::Arithmetic(_, _, _) => true,
956        ExpressionKind::MathematicalComputation(_, _) => true,
957        ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
958        ExpressionKind::FactPath(_) => true, // Lone fact is also solvable
959        _ => false,
960    }
961}
962
963/// For boolean expressions that can't be evaluated to a literal (e.g., `age > 18`),
964/// create two solutions: one where the expression is true, one where it's false.
965///
966/// This allows inversion to work with rules like `rule of_age = age > 18`
967fn create_boolean_expression_solutions(
968    world: World,
969    base_constraint: Constraint,
970    boolean_expr: &Expression,
971) -> LemmaResult<(Vec<WorldSolution>, Vec<WorldSolution>)> {
972    // Convert boolean expression to constraint
973    let expr_constraint = Constraint::from_expression(boolean_expr)?;
974
975    // Solution where the boolean expression is true
976    let true_constraint = base_constraint.clone().and(expr_constraint.clone());
977    let simplified_true = true_constraint.simplify()?;
978
979    let true_solutions = if !simplified_true.is_false() {
980        vec![WorldSolution {
981            world: world.clone(),
982            constraint: simplified_true,
983            outcome: OperationResult::Value(LiteralValue::boolean(BooleanValue::True)),
984        }]
985    } else {
986        vec![]
987    };
988
989    // Solution where the boolean expression is false
990    let false_constraint = base_constraint.and(expr_constraint.not());
991    let simplified_false = false_constraint.simplify()?;
992
993    let false_solutions = if !simplified_false.is_false() {
994        vec![WorldSolution {
995            world,
996            constraint: simplified_false,
997            outcome: OperationResult::Value(LiteralValue::boolean(BooleanValue::False)),
998        }]
999    } else {
1000        vec![]
1001    };
1002
1003    Ok((true_solutions, false_solutions))
1004}
1005
1006/// Attempt constant folding on an expression (simplified version for outcomes)
1007pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
1008    match &expr.kind {
1009        ExpressionKind::Literal(_) => Some(expr.clone()),
1010        ExpressionKind::Arithmetic(left, op, right) => {
1011            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1012            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1013            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1014                (&left_folded.kind, &right_folded.kind)
1015            {
1016                if let Some(result) = evaluate_arithmetic(left_val, op, right_val) {
1017                    return Some(Expression::new(
1018                        ExpressionKind::Literal(result),
1019                        expr.source_location.clone(),
1020                    ));
1021                }
1022            }
1023            Some(Expression::new(
1024                ExpressionKind::Arithmetic(
1025                    Arc::new(left_folded),
1026                    op.clone(),
1027                    Arc::new(right_folded),
1028                ),
1029                expr.source_location.clone(),
1030            ))
1031        }
1032        ExpressionKind::Comparison(left, op, right) => {
1033            let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1034            let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1035            if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1036                (&left_folded.kind, &right_folded.kind)
1037            {
1038                if let Some(result) = evaluate_comparison(left_val, op, right_val) {
1039                    return Some(Expression::new(
1040                        ExpressionKind::Literal(LiteralValue::boolean(result)),
1041                        expr.source_location.clone(),
1042                    ));
1043                }
1044            }
1045            Some(Expression::new(
1046                ExpressionKind::Comparison(
1047                    Arc::new(left_folded),
1048                    op.clone(),
1049                    Arc::new(right_folded),
1050                ),
1051                expr.source_location.clone(),
1052            ))
1053        }
1054        _ => None,
1055    }
1056}
1057
1058/// Evaluate an arithmetic operation on two literals
1059///
1060/// Delegates to the computation module for consistent behavior
1061fn evaluate_arithmetic(
1062    left: &LiteralValue,
1063    op: &ArithmeticComputation,
1064    right: &LiteralValue,
1065) -> Option<LiteralValue> {
1066    use crate::computation::arithmetic_operation;
1067
1068    match arithmetic_operation(left, op, right) {
1069        OperationResult::Value(lit) => Some(lit),
1070        OperationResult::Veto(_) => None,
1071    }
1072}
1073
1074/// Evaluate a comparison operation on two literals
1075///
1076/// Delegates to the computation module for consistent behavior
1077fn evaluate_comparison(
1078    left: &LiteralValue,
1079    op: &ComparisonComputation,
1080    right: &LiteralValue,
1081) -> Option<BooleanValue> {
1082    use crate::computation::comparison_operation;
1083
1084    match comparison_operation(left, op, right) {
1085        OperationResult::Value(lit) => match &lit.value {
1086            Value::Boolean(b) => Some(b.clone()),
1087            _ => None,
1088        },
1089        _ => None,
1090    }
1091}
1092
1093// ============================================================================
1094// Tests
1095// ============================================================================
1096
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    fn literal_expr(val: LiteralValue) -> Expression {
        Expression::new(ExpressionKind::Literal(val), None)
    }

    fn fact_expr(name: &str) -> Expression {
        Expression::new(
            ExpressionKind::FactPath(FactPath::local(name.to_string())),
            None,
        )
    }

    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Build `left + right` with no source location.
    fn add_expr(left: Expression, right: Expression) -> Expression {
        Expression::new(
            ExpressionKind::Arithmetic(Arc::new(left), ArithmeticComputation::Add, Arc::new(right)),
            None,
        )
    }

    /// Return the number inside a literal-number expression; panic otherwise.
    fn number_of(expr: &Expression) -> Decimal {
        match &expr.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                Value::Number(n) => *n,
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    fn empty_plan() -> ExecutionPlan {
        ExecutionPlan {
            doc_name: "test".to_string(),
            fact_schema: HashMap::new(),
            fact_values: HashMap::new(),
            doc_refs: HashMap::new(),
            fact_sources: HashMap::new(),
            rules: Vec::new(),
            sources: HashMap::new(),
        }
    }

    /// A plan whose only known fact value is `name = value`.
    fn plan_with_fact(name: &str, value: i64) -> ExecutionPlan {
        let mut plan = empty_plan();
        plan.fact_values
            .insert(FactPath::local(name.to_string()), num(value));
        plan
    }

    /// The set of local fact paths with the given names.
    fn provided(names: &[&str]) -> HashSet<FactPath> {
        names
            .iter()
            .map(|name| FactPath::local((*name).to_string()))
            .collect()
    }

    #[test]
    fn test_world_new() {
        let world = World::new();
        assert!(world.0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let mut world = World::new();
        let rule_path = RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        };
        world.insert(rule_path.clone(), 2);
        assert_eq!(world.get(&rule_path), Some(&2));
    }

    #[test]
    fn test_hydrate_literal_unchanged() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let expr = literal_expr(num(42));
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        assert_eq!(number_of(&result), Decimal::from(42));
    }

    #[test]
    fn test_hydrate_fact_not_provided() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let expr = fact_expr("age");
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_provided_fact_becomes_literal() {
        let plan = plan_with_fact("age", 30);
        let provided = provided(&["age"]);

        let result =
            hydrate_facts_in_expression(&Arc::new(fact_expr("age")), &plan, &provided).unwrap();

        assert_eq!(number_of(&result), Decimal::from(30));
    }

    #[test]
    fn test_hydrate_provided_fact_without_value_stays_fact() {
        let plan = empty_plan();
        let provided = provided(&["age"]);

        let result =
            hydrate_facts_in_expression(&Arc::new(fact_expr("age")), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_known_value_ignored_when_not_provided() {
        let plan = plan_with_fact("age", 30);
        let provided: HashSet<FactPath> = HashSet::new();

        let result =
            hydrate_facts_in_expression(&Arc::new(fact_expr("age")), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_preserves_operand_order() {
        let plan = plan_with_fact("age", 30);
        let provided = provided(&["age"]);
        let expr = add_expr(fact_expr("age"), literal_expr(num(1)));

        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        match &result.kind {
            ExpressionKind::Arithmetic(left, op, right) => {
                assert!(matches!(op, ArithmeticComputation::Add));
                assert_eq!(number_of(left), Decimal::from(30));
                assert_eq!(number_of(right), Decimal::from(1));
            }
            _ => panic!("Expected arithmetic expression"),
        }
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let expr = add_expr(literal_expr(num(10)), literal_expr(num(5)));

        let folded = try_constant_fold_expression(&expr).unwrap();

        assert_eq!(number_of(&folded), Decimal::from(15));
    }

    #[test]
    fn test_constant_fold_comparison() {
        let left = literal_expr(num(10));
        let right = literal_expr(num(5));
        let expr = Expression::new(
            ExpressionKind::Comparison(
                Arc::new(left),
                ComparisonComputation::GreaterThan,
                Arc::new(right),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        if let ExpressionKind::Literal(lit) = &folded.kind {
            if let Value::Boolean(b) = &lit.value {
                assert_eq!(*b, BooleanValue::True);
            } else {
                panic!("Expected literal boolean");
            }
        } else {
            panic!("Expected literal boolean");
        }
    }

    #[test]
    fn test_constant_fold_partial_keeps_fact() {
        // x + (2 + 3) folds only the right operand.
        let inner = add_expr(literal_expr(num(2)), literal_expr(num(3)));
        let expr = add_expr(fact_expr("x"), inner);

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Arithmetic(left, _, right) => {
                assert!(matches!(left.kind, ExpressionKind::FactPath(_)));
                assert_eq!(number_of(right), Decimal::from(5));
            }
            _ => panic!("Expected arithmetic expression"),
        }
    }

    #[test]
    fn test_constant_fold_unsupported_kind_returns_none() {
        assert!(try_constant_fold_expression(&fact_expr("x")).is_none());
    }

    #[test]
    fn test_extract_outcome() {
        assert!(matches!(
            extract_outcome(&literal_expr(num(7))),
            Some(OperationResult::Value(_))
        ));
        assert!(extract_outcome(&fact_expr("x")).is_none());
    }

    #[test]
    fn test_expression_classification() {
        let sum = add_expr(fact_expr("x"), literal_expr(num(1)));
        let comparison = Expression::new(
            ExpressionKind::Comparison(
                Arc::new(fact_expr("x")),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(1))),
            ),
            None,
        );

        assert!(is_arithmetic_expression(&sum));
        assert!(is_arithmetic_expression(&fact_expr("x")));
        assert!(!is_arithmetic_expression(&literal_expr(num(1))));
        assert!(!is_arithmetic_expression(&comparison));

        assert!(is_boolean_expression(&comparison));
        assert!(!is_boolean_expression(&sum));
        assert!(!is_boolean_expression(&fact_expr("x")));
    }
}