1use crate::parsing::ast::Span;
9use crate::planning::{ExecutableRule, ExecutionPlan};
10use crate::{
11 ArithmeticComputation, BooleanValue, ComparisonComputation, ConversionTarget, Expression,
12 ExpressionKind, FactPath, LemmaError, LemmaResult, LiteralValue, MathematicalComputation,
13 NegationType, OperationResult, RulePath, Value,
14};
15use serde::ser::{Serialize, SerializeMap, Serializer};
16use std::collections::{HashMap, HashSet, VecDeque};
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
21#[derive(Debug, Clone, Default, PartialEq, Eq)]
23pub struct World(HashMap<RulePath, usize>);
24
25impl World {
26 pub fn new() -> Self {
28 Self(HashMap::new())
29 }
30
31 pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
33 self.0.get(rule_path)
34 }
35
36 pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
38 self.0.insert(rule_path, branch_idx)
39 }
40
41 pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
43 self.0.iter()
44 }
45}
46
47impl Serialize for World {
48 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49 where
50 S: Serializer,
51 {
52 let mut map = serializer.serialize_map(Some(self.0.len()))?;
53 for (k, v) in &self.0 {
54 map.serialize_entry(&k.to_string(), v)?;
55 }
56 map.end()
57 }
58}
59
60#[derive(Debug, Clone)]
62pub struct WorldSolution {
63 pub world: World,
65 pub constraint: Constraint,
67 pub outcome: OperationResult,
69}
70
71#[derive(Debug, Clone)]
77pub struct WorldArithmeticSolution {
78 pub world: World,
80 pub constraint: Constraint,
82 pub outcome_expression: Expression,
84}
85
86#[derive(Debug, Clone)]
88pub struct EnumerationResult {
89 pub literal_solutions: Vec<WorldSolution>,
91 pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
93}
94
95pub struct WorldEnumerator<'a> {
97 plan: &'a ExecutionPlan,
98 rules_in_order: Vec<RulePath>,
100 rule_cache: HashMap<RulePath, &'a ExecutableRule>,
102}
103
104impl<'a> WorldEnumerator<'a> {
105 pub fn new(plan: &'a ExecutionPlan, target_rule: &RulePath) -> LemmaResult<Self> {
107 let rule_map: HashMap<RulePath, &ExecutableRule> =
109 plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
110
111 let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
113
114 let rules_in_order: Vec<RulePath> = plan
116 .rules
117 .iter()
118 .filter(|r| dependent_rules.contains(&r.path))
119 .map(|r| r.path.clone())
120 .collect();
121
122 let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
124 .iter()
125 .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
126 .collect();
127
128 Ok(Self {
129 plan,
130 rules_in_order,
131 rule_cache,
132 })
133 }
134
135 pub fn enumerate(
142 &mut self,
143 provided_facts: &HashSet<FactPath>,
144 ) -> LemmaResult<EnumerationResult> {
145 if self.rules_in_order.is_empty() {
146 return Ok(EnumerationResult {
147 literal_solutions: vec![],
148 arithmetic_solutions: vec![],
149 });
150 }
151
152 let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
154
155 for rule_path in &self.rules_in_order.clone() {
157 let rule_node = match self.rule_cache.get(rule_path) {
158 Some(node) => *node,
159 None => continue,
160 };
161
162 let mut next_worlds = Vec::new();
163
164 for (world, accumulated_constraint) in current_worlds {
165 for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
167 let mut new_world = world.clone();
169 new_world.insert(rule_path.clone(), branch_idx);
170
171 let mut branch_constraint = if let Some(ref condition) = branch.condition {
174 let substituted_condition = substitute_rules_in_expression(
176 &Arc::new(condition.clone()),
177 &new_world,
178 self.plan,
179 )?;
180 let hydrated_condition = hydrate_facts_in_expression(
181 &Arc::new(substituted_condition),
182 self.plan,
183 provided_facts,
184 )?;
185 Constraint::from_expression(&hydrated_condition)?
186 } else {
187 Constraint::True
189 };
190
191 for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
194 if let Some(ref later_condition) = later_branch.condition {
195 let substituted_later = substitute_rules_in_expression(
196 &Arc::new(later_condition.clone()),
197 &new_world,
198 self.plan,
199 )?;
200 let hydrated_later = hydrate_facts_in_expression(
201 &Arc::new(substituted_later),
202 self.plan,
203 provided_facts,
204 )?;
205 let later_constraint = Constraint::from_expression(&hydrated_later)?;
206 branch_constraint = branch_constraint.and(later_constraint.not());
208 }
209 }
210
211 let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
213
214 let simplified = combined_constraint.simplify()?;
216
217 if !simplified.is_false() {
219 next_worlds.push((new_world, simplified));
220 }
221 }
222 }
223
224 current_worlds = next_worlds;
225
226 if current_worlds.is_empty() {
228 break;
229 }
230 }
231
232 let target_rule_path = self.rules_in_order.last().ok_or_else(|| {
234 LemmaError::engine(
235 "No rules in order for world enumeration",
236 Span {
237 start: 0,
238 end: 0,
239 line: 1,
240 col: 0,
241 },
242 "<unknown>",
243 Arc::from(""),
244 "<unknown>",
245 1,
246 None::<String>,
247 )
248 })?;
249
250 let mut literal_solutions = Vec::new();
251 let mut arithmetic_solutions = Vec::new();
252
253 for (world, constraint) in current_worlds {
254 if let Some(&branch_idx) = world.get(target_rule_path) {
256 if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
257 if branch_idx < rule_node.branches.len() {
258 let branch = &rule_node.branches[branch_idx];
259
260 let substituted_result = substitute_rules_in_expression(
262 &Arc::new(branch.result.clone()),
263 &world,
264 self.plan,
265 )?;
266
267 let hydrated_result = hydrate_facts_in_expression(
268 &Arc::new(substituted_result),
269 self.plan,
270 provided_facts,
271 )?;
272
273 let folded_result = try_constant_fold_expression(&hydrated_result)
275 .unwrap_or(hydrated_result.clone());
276
277 if let Some(outcome) = extract_outcome(&folded_result) {
279 literal_solutions.push(WorldSolution {
280 world,
281 constraint,
282 outcome,
283 });
284 } else if is_boolean_expression(&folded_result) {
285 let (true_solutions, false_solutions) =
288 create_boolean_expression_solutions(
289 world,
290 constraint,
291 &folded_result,
292 )?;
293 literal_solutions.extend(true_solutions);
294 literal_solutions.extend(false_solutions);
295 } else if is_arithmetic_expression(&folded_result) {
296 arithmetic_solutions.push(WorldArithmeticSolution {
298 world,
299 constraint,
300 outcome_expression: folded_result,
301 });
302 }
303 }
306 }
307 }
308 }
309
310 Ok(EnumerationResult {
311 literal_solutions,
312 arithmetic_solutions,
313 })
314 }
315}
316
317fn collect_transitive_dependencies(
323 target_rule: &RulePath,
324 rule_map: &HashMap<RulePath, &ExecutableRule>,
325) -> LemmaResult<HashSet<RulePath>> {
326 let mut result = HashSet::new();
327 let mut queue = VecDeque::new();
328
329 queue.push_back(target_rule.clone());
330 result.insert(target_rule.clone());
331
332 while let Some(rule_path) = queue.pop_front() {
333 if let Some(rule) = rule_map.get(&rule_path) {
334 let dependencies = extract_rule_dependencies(rule);
336 for dependency in dependencies {
337 if result.insert(dependency.clone()) {
338 queue.push_back(dependency);
339 }
340 }
341 }
342 }
343
344 Ok(result)
345}
346
347fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
349 let mut deps = HashSet::new();
350 for branch in &rule.branches {
351 if let Some(ref condition) = branch.condition {
352 extract_rule_paths_from_expression(condition, &mut deps);
353 }
354 extract_rule_paths_from_expression(&branch.result, &mut deps);
355 }
356 deps
357}
358
359fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
361 match &expr.kind {
362 ExpressionKind::RulePath(rp) => {
363 paths.insert(rp.clone());
364 }
365 ExpressionKind::Arithmetic(left, _, right)
366 | ExpressionKind::Comparison(left, _, right)
367 | ExpressionKind::LogicalAnd(left, right)
368 | ExpressionKind::LogicalOr(left, right) => {
369 extract_rule_paths_from_expression(left, paths);
370 extract_rule_paths_from_expression(right, paths);
371 }
372 ExpressionKind::LogicalNegation(inner, _)
373 | ExpressionKind::UnitConversion(inner, _)
374 | ExpressionKind::MathematicalComputation(_, inner) => {
375 extract_rule_paths_from_expression(inner, paths);
376 }
377 ExpressionKind::Literal(_)
378 | ExpressionKind::FactPath(_)
379 | ExpressionKind::Veto(_)
380 | ExpressionKind::Reference(_)
381 | ExpressionKind::UnresolvedUnitLiteral(_, _)
382 | ExpressionKind::FactReference(_)
383 | ExpressionKind::RuleReference(_) => {}
384 }
385}
386
387fn substitute_rules_in_expression(
396 expr: &Arc<Expression>,
397 world: &World,
398 plan: &ExecutionPlan,
399) -> LemmaResult<Expression> {
400 enum WorkItem {
401 Process(usize),
402 BuildArithmetic(ArithmeticComputation, Option<crate::Source>),
403 BuildComparison(ComparisonComputation, Option<crate::Source>),
404 BuildLogicalAnd(Option<crate::Source>),
405 BuildLogicalOr(Option<crate::Source>),
406 BuildLogicalNegation(NegationType, Option<crate::Source>),
407 BuildUnitConversion(ConversionTarget, Option<crate::Source>),
408 BuildMathematicalComputation(MathematicalComputation, Option<crate::Source>),
409 PopVisitedRules,
410 }
411
412 let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
413 let mut work_stack: Vec<WorkItem> = Vec::new();
414 let mut result_pool: Vec<Expression> = Vec::new();
415 let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];
416
417 let root_idx = expr_pool.len();
418 expr_pool.push(Arc::clone(expr));
419 work_stack.push(WorkItem::Process(root_idx));
420
421 while let Some(work) = work_stack.pop() {
422 match work {
423 WorkItem::Process(expr_idx) => {
424 let e = &expr_pool[expr_idx];
425 let source_loc = e.source_location.clone();
426
427 match &e.kind {
428 ExpressionKind::RulePath(rule_path) => {
429 let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
430 if visited.contains(rule_path) {
431 return Err(LemmaError::engine(
432 format!(
433 "Circular rule reference detected during substitution: {}",
434 rule_path
435 ),
436 Span {
437 start: 0,
438 end: 0,
439 line: 1,
440 col: 0,
441 },
442 "<unknown>",
443 Arc::from(""),
444 "<unknown>",
445 1,
446 None::<String>,
447 ));
448 }
449
450 if let Some(&branch_idx) = world.get(rule_path) {
451 if let Some(rule) = plan.get_rule_by_path(rule_path) {
452 if branch_idx < rule.branches.len() {
453 let branch = &rule.branches[branch_idx];
454 let mut new_visited = visited.clone();
455 new_visited.insert(rule_path.clone());
456 visited_rules_stack.push(new_visited);
457
458 let sub_expr_idx = expr_pool.len();
459 expr_pool.push(Arc::new(branch.result.clone()));
460 work_stack.push(WorkItem::PopVisitedRules);
461 work_stack.push(WorkItem::Process(sub_expr_idx));
462 continue;
463 }
464 }
465 }
466 result_pool.push(Expression::new(
467 ExpressionKind::RulePath(rule_path.clone()),
468 source_loc,
469 ));
470 }
471 ExpressionKind::RuleReference(_) => {
472 return Err(LemmaError::engine(
473 "RuleReference found during substitution - should have been converted to RulePath",
474 Span { start: 0, end: 0, line: 1, col: 0 },
475 "<unknown>",
476 Arc::from(""),
477 "<unknown>",
478 1,
479 None::<String>,
480 ));
481 }
482 ExpressionKind::FactReference(_) => {
483 return Err(LemmaError::engine(
484 "FactReference found during substitution - should have been converted to FactPath",
485 Span { start: 0, end: 0, line: 1, col: 0 },
486 "<unknown>",
487 Arc::from(""),
488 "<unknown>",
489 1,
490 None::<String>,
491 ));
492 }
493 ExpressionKind::Reference(_) => {
494 return Err(LemmaError::engine(
495 "Unresolved Reference found during substitution - should have been resolved during planning",
496 Span { start: 0, end: 0, line: 1, col: 0 },
497 "<unknown>",
498 Arc::from(""),
499 "<unknown>",
500 1,
501 None::<String>,
502 ));
503 }
504 ExpressionKind::UnresolvedUnitLiteral(_, _) => {
505 unreachable!(
506 "UnresolvedUnitLiteral found during substitution - this is a bug: unresolved units should be resolved during planning"
507 );
508 }
509 ExpressionKind::Arithmetic(left, op, right) => {
510 let op_clone = op.clone();
511 let left_arc = Arc::clone(left);
512 let right_arc = Arc::clone(right);
513
514 let left_idx = expr_pool.len();
515 expr_pool.push(left_arc);
516 let right_idx = expr_pool.len();
517 expr_pool.push(right_arc);
518
519 work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
520 work_stack.push(WorkItem::Process(right_idx));
521 work_stack.push(WorkItem::Process(left_idx));
522 }
523 ExpressionKind::Comparison(left, op, right) => {
524 let op_clone = op.clone();
525 let left_arc = Arc::clone(left);
526 let right_arc = Arc::clone(right);
527
528 let left_idx = expr_pool.len();
529 expr_pool.push(left_arc);
530 let right_idx = expr_pool.len();
531 expr_pool.push(right_arc);
532
533 work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
534 work_stack.push(WorkItem::Process(right_idx));
535 work_stack.push(WorkItem::Process(left_idx));
536 }
537 ExpressionKind::LogicalAnd(left, right) => {
538 let left_arc = Arc::clone(left);
539 let right_arc = Arc::clone(right);
540
541 let left_idx = expr_pool.len();
542 expr_pool.push(left_arc);
543 let right_idx = expr_pool.len();
544 expr_pool.push(right_arc);
545
546 work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
547 work_stack.push(WorkItem::Process(right_idx));
548 work_stack.push(WorkItem::Process(left_idx));
549 }
550 ExpressionKind::LogicalOr(left, right) => {
551 let left_arc = Arc::clone(left);
552 let right_arc = Arc::clone(right);
553
554 let left_idx = expr_pool.len();
555 expr_pool.push(left_arc);
556 let right_idx = expr_pool.len();
557 expr_pool.push(right_arc);
558
559 work_stack.push(WorkItem::BuildLogicalOr(source_loc));
560 work_stack.push(WorkItem::Process(right_idx));
561 work_stack.push(WorkItem::Process(left_idx));
562 }
563 ExpressionKind::LogicalNegation(inner, neg_type) => {
564 let neg_type_clone = neg_type.clone();
565 let inner_arc = Arc::clone(inner);
566 let inner_idx = expr_pool.len();
567 expr_pool.push(inner_arc);
568 work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
569 work_stack.push(WorkItem::Process(inner_idx));
570 }
571 ExpressionKind::UnitConversion(inner, unit) => {
572 let unit_clone = unit.clone();
573 let inner_arc = Arc::clone(inner);
574 let inner_idx = expr_pool.len();
575 expr_pool.push(inner_arc);
576 work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
577 work_stack.push(WorkItem::Process(inner_idx));
578 }
579 ExpressionKind::MathematicalComputation(func, inner) => {
580 let func_clone = func.clone();
581 let inner_arc = Arc::clone(inner);
582 let inner_idx = expr_pool.len();
583 expr_pool.push(inner_arc);
584 work_stack.push(WorkItem::BuildMathematicalComputation(
585 func_clone, source_loc,
586 ));
587 work_stack.push(WorkItem::Process(inner_idx));
588 }
589 ExpressionKind::Literal(lit) => {
590 result_pool.push(Expression::new(
591 ExpressionKind::Literal(lit.clone()),
592 source_loc,
593 ));
594 }
595 ExpressionKind::FactPath(fact_path) => {
596 result_pool.push(Expression::new(
597 ExpressionKind::FactPath(fact_path.clone()),
598 source_loc,
599 ));
600 }
601 ExpressionKind::Veto(veto) => {
602 result_pool.push(Expression::new(
603 ExpressionKind::Veto(veto.clone()),
604 source_loc,
605 ));
606 }
607 }
608 }
609 WorkItem::BuildArithmetic(op, source_loc) => {
610 let right = result_pool.pop().unwrap_or_else(|| {
611 unreachable!(
612 "BUG: missing right expression for Arithmetic during inversion hydration"
613 )
614 });
615 let left = result_pool.pop().unwrap_or_else(|| {
616 unreachable!(
617 "BUG: missing left expression for Arithmetic during inversion hydration"
618 )
619 });
620 result_pool.push(Expression::new(
621 ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
622 source_loc,
623 ));
624 }
625 WorkItem::BuildComparison(op, source_loc) => {
626 let right = result_pool.pop().unwrap_or_else(|| {
627 unreachable!(
628 "BUG: missing right expression for Comparison during inversion hydration"
629 )
630 });
631 let left = result_pool.pop().unwrap_or_else(|| {
632 unreachable!(
633 "BUG: missing left expression for Comparison during inversion hydration"
634 )
635 });
636 result_pool.push(Expression::new(
637 ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
638 source_loc,
639 ));
640 }
641 WorkItem::BuildLogicalAnd(source_loc) => {
642 let right = result_pool.pop().unwrap_or_else(|| {
643 unreachable!(
644 "BUG: missing right expression for LogicalAnd during inversion hydration"
645 )
646 });
647 let left = result_pool.pop().unwrap_or_else(|| {
648 unreachable!(
649 "BUG: missing left expression for LogicalAnd during inversion hydration"
650 )
651 });
652 result_pool.push(Expression::new(
653 ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
654 source_loc,
655 ));
656 }
657 WorkItem::BuildLogicalOr(source_loc) => {
658 let right = result_pool.pop().unwrap_or_else(|| {
659 unreachable!(
660 "BUG: missing right expression for LogicalOr during inversion hydration"
661 )
662 });
663 let left = result_pool.pop().unwrap_or_else(|| {
664 unreachable!(
665 "BUG: missing left expression for LogicalOr during inversion hydration"
666 )
667 });
668 result_pool.push(Expression::new(
669 ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
670 source_loc,
671 ));
672 }
673 WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
674 let inner = result_pool
675 .pop()
676 .expect("Internal error: missing expression for LogicalNegation");
677 result_pool.push(Expression::new(
678 ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
679 source_loc,
680 ));
681 }
682 WorkItem::BuildUnitConversion(unit, source_loc) => {
683 let inner = result_pool
684 .pop()
685 .expect("Internal error: missing expression for UnitConversion");
686 result_pool.push(Expression::new(
687 ExpressionKind::UnitConversion(Arc::new(inner), unit),
688 source_loc,
689 ));
690 }
691 WorkItem::BuildMathematicalComputation(func, source_loc) => {
692 let inner = result_pool
693 .pop()
694 .expect("Internal error: missing expression for MathematicalComputation");
695 result_pool.push(Expression::new(
696 ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
697 source_loc,
698 ));
699 }
700 WorkItem::PopVisitedRules => {
701 visited_rules_stack.pop();
702 }
703 }
704 }
705
706 result_pool.pop().ok_or_else(|| {
707 LemmaError::engine(
708 "Internal error: no result from substitution",
709 Span {
710 start: 0,
711 end: 0,
712 line: 1,
713 col: 0,
714 },
715 "<unknown>",
716 Arc::from(""),
717 "<unknown>",
718 1,
719 None::<String>,
720 )
721 })
722}
723
724fn hydrate_facts_in_expression(
733 expr: &Arc<Expression>,
734 plan: &ExecutionPlan,
735 provided_facts: &HashSet<FactPath>,
736) -> LemmaResult<Expression> {
737 enum WorkItem {
738 Process(usize),
739 BuildArithmetic(ArithmeticComputation, Option<crate::Source>),
740 BuildComparison(ComparisonComputation, Option<crate::Source>),
741 BuildLogicalAnd(Option<crate::Source>),
742 BuildLogicalOr(Option<crate::Source>),
743 BuildLogicalNegation(NegationType, Option<crate::Source>),
744 BuildUnitConversion(ConversionTarget, Option<crate::Source>),
745 BuildMathematicalComputation(MathematicalComputation, Option<crate::Source>),
746 }
747
748 let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
749 let mut work_stack: Vec<WorkItem> = Vec::new();
750 let mut result_pool: Vec<Expression> = Vec::new();
751
752 let root_idx = expr_pool.len();
753 expr_pool.push(Arc::clone(expr));
754 work_stack.push(WorkItem::Process(root_idx));
755
756 while let Some(work) = work_stack.pop() {
757 match work {
758 WorkItem::Process(expr_idx) => {
759 let (source_loc, expr_kind_ref) = {
760 let e = &expr_pool[expr_idx];
761 (e.source_location.clone(), &e.kind)
762 };
763
764 match expr_kind_ref {
765 ExpressionKind::FactPath(fact_path) => {
766 if provided_facts.contains(fact_path) {
767 if let Some(lit) = plan.fact_values.get(fact_path) {
768 result_pool.push(Expression::new(
769 ExpressionKind::Literal(lit.clone()),
770 source_loc,
771 ));
772 continue;
773 }
774 }
775 result_pool.push(Expression::new(
776 ExpressionKind::FactPath(fact_path.clone()),
777 source_loc,
778 ));
779 }
780 ExpressionKind::FactReference(_) => {
781 return Err(LemmaError::engine(
782 "FactReference found during hydration - should have been converted to FactPath",
783 Span { start: 0, end: 0, line: 1, col: 0 },
784 "<unknown>",
785 Arc::from(""),
786 "<unknown>",
787 1,
788 None::<String>,
789 ));
790 }
791 ExpressionKind::RuleReference(_) => {
792 return Err(LemmaError::engine(
793 "RuleReference found during hydration - should have been converted to RulePath",
794 Span { start: 0, end: 0, line: 1, col: 0 },
795 "<unknown>",
796 Arc::from(""),
797 "<unknown>",
798 1,
799 None::<String>,
800 ));
801 }
802 ExpressionKind::Reference(_) => {
803 return Err(LemmaError::engine(
804 "Unresolved Reference found during hydration - should have been resolved during planning",
805 Span { start: 0, end: 0, line: 1, col: 0 },
806 "<unknown>",
807 Arc::from(""),
808 "<unknown>",
809 1,
810 None::<String>,
811 ));
812 }
813 ExpressionKind::UnresolvedUnitLiteral(_, _) => {
814 unreachable!(
815 "UnresolvedUnitLiteral found during hydration - this is a bug: unresolved units should be resolved during planning"
816 );
817 }
818 ExpressionKind::Arithmetic(left, op, right) => {
819 let op_clone = op.clone();
820 let left_arc = Arc::clone(left);
821 let right_arc = Arc::clone(right);
822
823 let left_idx = expr_pool.len();
824 expr_pool.push(left_arc);
825 let right_idx = expr_pool.len();
826 expr_pool.push(right_arc);
827
828 work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
829 work_stack.push(WorkItem::Process(right_idx));
830 work_stack.push(WorkItem::Process(left_idx));
831 }
832 ExpressionKind::Comparison(left, op, right) => {
833 let op_clone = op.clone();
834 let left_arc = Arc::clone(left);
835 let right_arc = Arc::clone(right);
836
837 let left_idx = expr_pool.len();
838 expr_pool.push(left_arc);
839 let right_idx = expr_pool.len();
840 expr_pool.push(right_arc);
841
842 work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
843 work_stack.push(WorkItem::Process(right_idx));
844 work_stack.push(WorkItem::Process(left_idx));
845 }
846 ExpressionKind::LogicalAnd(left, right) => {
847 let left_arc = Arc::clone(left);
848 let right_arc = Arc::clone(right);
849
850 let left_idx = expr_pool.len();
851 expr_pool.push(left_arc);
852 let right_idx = expr_pool.len();
853 expr_pool.push(right_arc);
854
855 work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
856 work_stack.push(WorkItem::Process(right_idx));
857 work_stack.push(WorkItem::Process(left_idx));
858 }
859 ExpressionKind::LogicalOr(left, right) => {
860 let left_arc = Arc::clone(left);
861 let right_arc = Arc::clone(right);
862
863 let left_idx = expr_pool.len();
864 expr_pool.push(left_arc);
865 let right_idx = expr_pool.len();
866 expr_pool.push(right_arc);
867
868 work_stack.push(WorkItem::BuildLogicalOr(source_loc));
869 work_stack.push(WorkItem::Process(right_idx));
870 work_stack.push(WorkItem::Process(left_idx));
871 }
872 ExpressionKind::LogicalNegation(inner, neg_type) => {
873 let neg_type_clone = neg_type.clone();
874 let inner_arc = Arc::clone(inner);
875 let inner_idx = expr_pool.len();
876 expr_pool.push(inner_arc);
877 work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
878 work_stack.push(WorkItem::Process(inner_idx));
879 }
880 ExpressionKind::UnitConversion(inner, unit) => {
881 let unit_clone = unit.clone();
882 let inner_arc = Arc::clone(inner);
883 let inner_idx = expr_pool.len();
884 expr_pool.push(inner_arc);
885 work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
886 work_stack.push(WorkItem::Process(inner_idx));
887 }
888 ExpressionKind::MathematicalComputation(func, inner) => {
889 let func_clone = func.clone();
890 let inner_arc = Arc::clone(inner);
891 let inner_idx = expr_pool.len();
892 expr_pool.push(inner_arc);
893 work_stack.push(WorkItem::BuildMathematicalComputation(
894 func_clone, source_loc,
895 ));
896 work_stack.push(WorkItem::Process(inner_idx));
897 }
898 ExpressionKind::Literal(lit) => {
899 result_pool.push(Expression::new(
900 ExpressionKind::Literal(lit.clone()),
901 source_loc,
902 ));
903 }
904 ExpressionKind::RulePath(rule_path) => {
905 result_pool.push(Expression::new(
906 ExpressionKind::RulePath(rule_path.clone()),
907 source_loc,
908 ));
909 }
910 ExpressionKind::Veto(veto) => {
911 result_pool.push(Expression::new(
912 ExpressionKind::Veto(veto.clone()),
913 source_loc,
914 ));
915 }
916 }
917 }
918 WorkItem::BuildArithmetic(op, source_loc) => {
919 let right = result_pool.pop().ok_or_else(|| {
920 LemmaError::engine(
921 "Internal error: missing right expression for Arithmetic",
922 Span {
923 start: 0,
924 end: 0,
925 line: 1,
926 col: 0,
927 },
928 "<unknown>",
929 Arc::from(""),
930 "<unknown>",
931 1,
932 None::<String>,
933 )
934 })?;
935 let left = result_pool.pop().ok_or_else(|| {
936 LemmaError::engine(
937 "Internal error: missing left expression for Arithmetic",
938 Span {
939 start: 0,
940 end: 0,
941 line: 1,
942 col: 0,
943 },
944 "<unknown>",
945 Arc::from(""),
946 "<unknown>",
947 1,
948 None::<String>,
949 )
950 })?;
951 result_pool.push(Expression::new(
952 ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
953 source_loc,
954 ));
955 }
956 WorkItem::BuildComparison(op, source_loc) => {
957 let right = result_pool.pop().ok_or_else(|| {
958 LemmaError::engine(
959 "Internal error: missing right expression for Comparison",
960 Span {
961 start: 0,
962 end: 0,
963 line: 1,
964 col: 0,
965 },
966 "<unknown>",
967 Arc::from(""),
968 "<unknown>",
969 1,
970 None::<String>,
971 )
972 })?;
973 let left = result_pool.pop().ok_or_else(|| {
974 LemmaError::engine(
975 "Internal error: missing left expression for Comparison",
976 Span {
977 start: 0,
978 end: 0,
979 line: 1,
980 col: 0,
981 },
982 "<unknown>",
983 Arc::from(""),
984 "<unknown>",
985 1,
986 None::<String>,
987 )
988 })?;
989 result_pool.push(Expression::new(
990 ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
991 source_loc,
992 ));
993 }
994 WorkItem::BuildLogicalAnd(source_loc) => {
995 let right = result_pool.pop().ok_or_else(|| {
996 LemmaError::engine(
997 "Internal error: missing right expression for LogicalAnd",
998 Span {
999 start: 0,
1000 end: 0,
1001 line: 1,
1002 col: 0,
1003 },
1004 "<unknown>",
1005 Arc::from(""),
1006 "<unknown>",
1007 1,
1008 None::<String>,
1009 )
1010 })?;
1011 let left = result_pool.pop().ok_or_else(|| {
1012 LemmaError::engine(
1013 "Internal error: missing left expression for LogicalAnd",
1014 Span {
1015 start: 0,
1016 end: 0,
1017 line: 1,
1018 col: 0,
1019 },
1020 "<unknown>",
1021 Arc::from(""),
1022 "<unknown>",
1023 1,
1024 None::<String>,
1025 )
1026 })?;
1027 result_pool.push(Expression::new(
1028 ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
1029 source_loc,
1030 ));
1031 }
1032 WorkItem::BuildLogicalOr(source_loc) => {
1033 let right = result_pool.pop().ok_or_else(|| {
1034 LemmaError::engine(
1035 "Internal error: missing right expression for LogicalOr",
1036 Span {
1037 start: 0,
1038 end: 0,
1039 line: 1,
1040 col: 0,
1041 },
1042 "<unknown>",
1043 Arc::from(""),
1044 "<unknown>",
1045 1,
1046 None::<String>,
1047 )
1048 })?;
1049 let left = result_pool.pop().ok_or_else(|| {
1050 LemmaError::engine(
1051 "Internal error: missing left expression for LogicalOr",
1052 Span {
1053 start: 0,
1054 end: 0,
1055 line: 1,
1056 col: 0,
1057 },
1058 "<unknown>",
1059 Arc::from(""),
1060 "<unknown>",
1061 1,
1062 None::<String>,
1063 )
1064 })?;
1065 result_pool.push(Expression::new(
1066 ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
1067 source_loc,
1068 ));
1069 }
1070 WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
1071 let inner = result_pool
1072 .pop()
1073 .expect("Internal error: missing expression for LogicalNegation");
1074 result_pool.push(Expression::new(
1075 ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
1076 source_loc,
1077 ));
1078 }
1079 WorkItem::BuildUnitConversion(unit, source_loc) => {
1080 let inner = result_pool
1081 .pop()
1082 .expect("Internal error: missing expression for UnitConversion");
1083 result_pool.push(Expression::new(
1084 ExpressionKind::UnitConversion(Arc::new(inner), unit),
1085 source_loc,
1086 ));
1087 }
1088 WorkItem::BuildMathematicalComputation(func, source_loc) => {
1089 let inner = result_pool
1090 .pop()
1091 .expect("Internal error: missing expression for MathematicalComputation");
1092 result_pool.push(Expression::new(
1093 ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
1094 source_loc,
1095 ));
1096 }
1097 }
1098 }
1099
1100 Ok(result_pool
1101 .pop()
1102 .expect("Internal error: no result from hydration"))
1103}
1104
1105fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
1111 match &expr.kind {
1112 ExpressionKind::Literal(lit) => Some(OperationResult::Value(lit.clone())),
1113 ExpressionKind::Veto(ve) => Some(OperationResult::Veto(ve.message.clone())),
1114 _ => None,
1115 }
1116}
1117
1118fn is_boolean_expression(expr: &Expression) -> bool {
1120 matches!(
1121 &expr.kind,
1122 ExpressionKind::Comparison(_, _, _)
1123 | ExpressionKind::LogicalAnd(_, _)
1124 | ExpressionKind::LogicalOr(_, _)
1125 | ExpressionKind::LogicalNegation(_, _)
1126 )
1127}
1128
1129fn is_arithmetic_expression(expr: &Expression) -> bool {
1133 match &expr.kind {
1134 ExpressionKind::Arithmetic(_, _, _) => true,
1135 ExpressionKind::MathematicalComputation(_, _) => true,
1136 ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
1137 ExpressionKind::FactPath(_) => true, _ => false,
1139 }
1140}
1141
1142fn create_boolean_expression_solutions(
1147 world: World,
1148 base_constraint: Constraint,
1149 boolean_expr: &Expression,
1150) -> LemmaResult<(Vec<WorldSolution>, Vec<WorldSolution>)> {
1151 let expr_constraint = Constraint::from_expression(boolean_expr)?;
1153
1154 let true_constraint = base_constraint.clone().and(expr_constraint.clone());
1156 let simplified_true = true_constraint.simplify()?;
1157
1158 let true_solutions = if !simplified_true.is_false() {
1159 vec![WorldSolution {
1160 world: world.clone(),
1161 constraint: simplified_true,
1162 outcome: OperationResult::Value(LiteralValue::boolean(BooleanValue::True)),
1163 }]
1164 } else {
1165 vec![]
1166 };
1167
1168 let false_constraint = base_constraint.and(expr_constraint.not());
1170 let simplified_false = false_constraint.simplify()?;
1171
1172 let false_solutions = if !simplified_false.is_false() {
1173 vec![WorldSolution {
1174 world,
1175 constraint: simplified_false,
1176 outcome: OperationResult::Value(LiteralValue::boolean(BooleanValue::False)),
1177 }]
1178 } else {
1179 vec![]
1180 };
1181
1182 Ok((true_solutions, false_solutions))
1183}
1184
1185pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
1187 match &expr.kind {
1188 ExpressionKind::Literal(_) => Some(expr.clone()),
1189 ExpressionKind::Arithmetic(left, op, right) => {
1190 let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1191 let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1192 if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1193 (&left_folded.kind, &right_folded.kind)
1194 {
1195 if let Some(result) = evaluate_arithmetic(left_val, op, right_val) {
1196 return Some(Expression::new(
1197 ExpressionKind::Literal(result),
1198 expr.source_location.clone(),
1199 ));
1200 }
1201 }
1202 Some(Expression::new(
1203 ExpressionKind::Arithmetic(
1204 Arc::new(left_folded),
1205 op.clone(),
1206 Arc::new(right_folded),
1207 ),
1208 expr.source_location.clone(),
1209 ))
1210 }
1211 ExpressionKind::Comparison(left, op, right) => {
1212 let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1213 let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1214 if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1215 (&left_folded.kind, &right_folded.kind)
1216 {
1217 if let Some(result) = evaluate_comparison(left_val, op, right_val) {
1218 return Some(Expression::new(
1219 ExpressionKind::Literal(LiteralValue::boolean(result)),
1220 expr.source_location.clone(),
1221 ));
1222 }
1223 }
1224 Some(Expression::new(
1225 ExpressionKind::Comparison(
1226 Arc::new(left_folded),
1227 op.clone(),
1228 Arc::new(right_folded),
1229 ),
1230 expr.source_location.clone(),
1231 ))
1232 }
1233 _ => None,
1234 }
1235}
1236
1237fn evaluate_arithmetic(
1241 left: &LiteralValue,
1242 op: &ArithmeticComputation,
1243 right: &LiteralValue,
1244) -> Option<LiteralValue> {
1245 use crate::computation::arithmetic_operation;
1246
1247 match arithmetic_operation(left, op, right) {
1248 OperationResult::Value(lit) => Some(lit),
1249 OperationResult::Veto(_) => None,
1250 }
1251}
1252
1253fn evaluate_comparison(
1257 left: &LiteralValue,
1258 op: &ComparisonComputation,
1259 right: &LiteralValue,
1260) -> Option<BooleanValue> {
1261 use crate::computation::comparison_operation;
1262
1263 match comparison_operation(left, op, right) {
1264 OperationResult::Value(lit) => match &lit.value {
1265 Value::Boolean(b) => Some(b.clone()),
1266 _ => None,
1267 },
1268 _ => None,
1269 }
1270}
1271
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    fn literal_expr(val: LiteralValue) -> Expression {
        Expression::new(ExpressionKind::Literal(val), None)
    }

    fn fact_expr(name: &str) -> Expression {
        Expression::new(
            ExpressionKind::FactPath(FactPath::local(name.to_string())),
            None,
        )
    }

    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Builds `left <op> right` as an arithmetic expression without a source location.
    fn arithmetic(left: Expression, op: ArithmeticComputation, right: Expression) -> Expression {
        Expression::new(
            ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
            None,
        )
    }

    /// Asserts that `expr` is a numeric literal equal to `expected`.
    fn assert_number(expr: &Expression, expected: Decimal) {
        match &expr.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                Value::Number(n) => assert_eq!(*n, expected),
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    /// Asserts that `expr` is a boolean literal equal to `expected`.
    fn assert_boolean(expr: &Expression, expected: BooleanValue) {
        match &expr.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                Value::Boolean(b) => assert_eq!(*b, expected),
                _ => panic!("Expected literal boolean"),
            },
            _ => panic!("Expected literal boolean"),
        }
    }

    #[test]
    fn test_world_new() {
        let world = World::new();
        assert!(world.0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let mut world = World::new();
        let rule_path = RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        };
        world.insert(rule_path.clone(), 2);
        assert_eq!(world.get(&rule_path), Some(&2));
    }

    fn empty_plan() -> ExecutionPlan {
        ExecutionPlan {
            doc_name: "test".to_string(),
            fact_schema: HashMap::new(),
            fact_values: HashMap::new(),
            doc_refs: HashMap::new(),
            fact_sources: HashMap::new(),
            rules: Vec::new(),
            sources: HashMap::new(),
        }
    }

    #[test]
    fn test_hydrate_literal_unchanged() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let expr = literal_expr(num(42));
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        assert_number(&result, Decimal::from(42));
    }

    #[test]
    fn test_hydrate_fact_not_provided() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let expr = fact_expr("age");
        let result = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let expr = arithmetic(
            literal_expr(num(10)),
            ArithmeticComputation::Add,
            literal_expr(num(5)),
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        assert_number(&folded, Decimal::from(15));
    }

    #[test]
    fn test_constant_fold_comparison() {
        let expr = Expression::new(
            ExpressionKind::Comparison(
                Arc::new(literal_expr(num(10))),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        assert_boolean(&folded, BooleanValue::True);
    }

    #[test]
    fn test_constant_fold_literal_is_identity() {
        let folded = try_constant_fold_expression(&literal_expr(num(7))).unwrap();
        assert_number(&folded, Decimal::from(7));
    }

    #[test]
    fn test_constant_fold_non_foldable_kind_is_none() {
        // Kinds other than literal/arithmetic/comparison are left to the caller.
        assert!(try_constant_fold_expression(&fact_expr("age")).is_none());
    }

    #[test]
    fn test_constant_fold_partial_keeps_fact_operand() {
        // age + (2 + 3): only the right-hand operand is constant.
        let expr = arithmetic(
            fact_expr("age"),
            ArithmeticComputation::Add,
            arithmetic(
                literal_expr(num(2)),
                ArithmeticComputation::Add,
                literal_expr(num(3)),
            ),
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Arithmetic(left, op, right) => {
                assert!(matches!(left.kind, ExpressionKind::FactPath(_)));
                assert!(matches!(op, ArithmeticComputation::Add));
                assert_number(right, Decimal::from(5));
            }
            _ => panic!("Expected arithmetic expression"),
        }
    }

    #[test]
    fn test_constant_fold_nested_comparison() {
        // (2 + 3) > 4 folds all the way down to a boolean literal.
        let expr = Expression::new(
            ExpressionKind::Comparison(
                Arc::new(arithmetic(
                    literal_expr(num(2)),
                    ArithmeticComputation::Add,
                    literal_expr(num(3)),
                )),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(4))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        assert_boolean(&folded, BooleanValue::True);
    }
}