1use crate::planning::{ExecutableRule, ExecutionPlan};
9use crate::{
10 ArithmeticComputation, BooleanValue, ComparisonComputation, ConversionTarget, Expression,
11 ExpressionKind, FactPath, LemmaResult, LiteralValue, MathematicalComputation, NegationType,
12 OperationResult, RulePath, Value,
13};
14use serde::ser::{Serialize, SerializeMap, Serializer};
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
20#[derive(Debug, Clone, Default, PartialEq, Eq)]
22pub struct World(HashMap<RulePath, usize>);
23
24impl World {
25 pub fn new() -> Self {
27 Self(HashMap::new())
28 }
29
30 pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
32 self.0.get(rule_path)
33 }
34
35 pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
37 self.0.insert(rule_path, branch_idx)
38 }
39
40 pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
42 self.0.iter()
43 }
44}
45
46impl Serialize for World {
47 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
48 where
49 S: Serializer,
50 {
51 let mut map = serializer.serialize_map(Some(self.0.len()))?;
52 for (k, v) in &self.0 {
53 map.serialize_entry(&k.to_string(), v)?;
54 }
55 map.end()
56 }
57}
58
/// A world in which the target rule's result resolved to a concrete outcome.
///
/// Produced by `WorldEnumerator::enumerate` when the folded result is a literal or a veto,
/// or once per satisfiable truth value when the result is a boolean expression.
#[derive(Debug, Clone)]
pub struct WorldSolution {
    /// The branch selected for each enumerated rule.
    pub world: World,
    /// The simplified constraint under which this world (and outcome) is reachable.
    pub constraint: Constraint,
    /// The target rule's outcome in this world: a literal value or a veto.
    pub outcome: OperationResult,
}
69
/// A world in which the target rule's result is still a symbolic arithmetic expression.
///
/// This happens, for example, when the result depends on facts that were not provided and
/// therefore could not be hydrated into literals.
#[derive(Debug, Clone)]
pub struct WorldArithmeticSolution {
    /// The branch selected for each enumerated rule.
    pub world: World,
    /// The simplified constraint under which this world is reachable.
    pub constraint: Constraint,
    /// The target rule's result after rule substitution, fact hydration and constant folding.
    pub outcome_expression: Expression,
}
84
/// All surviving worlds of an enumeration, split by the shape of the target rule's result.
#[derive(Debug, Clone)]
pub struct EnumerationResult {
    /// Worlds whose result is a literal, a veto, or a boolean split into true/false cases.
    pub literal_solutions: Vec<WorldSolution>,
    /// Worlds whose result remains a symbolic arithmetic expression.
    pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
}
93
/// Enumerates branch combinations ("worlds") for a target rule and its dependencies.
pub struct WorldEnumerator<'a> {
    /// The plan whose rules and fact values are being analysed.
    plan: &'a ExecutionPlan,
    /// The target rule and its transitive dependencies, in the order they appear in
    /// `plan.rules`. The target is expected to be the last entry.
    rules_in_order: Vec<RulePath>,
    /// Lookup from each path in `rules_in_order` to its rule in `plan`.
    rule_cache: HashMap<RulePath, &'a ExecutableRule>,
}
102
103impl<'a> WorldEnumerator<'a> {
104 pub fn new(plan: &'a ExecutionPlan, target_rule: &RulePath) -> LemmaResult<Self> {
106 let rule_map: HashMap<RulePath, &ExecutableRule> =
108 plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
109
110 let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
112
113 let rules_in_order: Vec<RulePath> = plan
115 .rules
116 .iter()
117 .filter(|r| dependent_rules.contains(&r.path))
118 .map(|r| r.path.clone())
119 .collect();
120
121 let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
123 .iter()
124 .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
125 .collect();
126
127 Ok(Self {
128 plan,
129 rules_in_order,
130 rule_cache,
131 })
132 }
133
134 pub fn enumerate(
141 &mut self,
142 provided_facts: &HashSet<FactPath>,
143 ) -> LemmaResult<EnumerationResult> {
144 if self.rules_in_order.is_empty() {
145 return Ok(EnumerationResult {
146 literal_solutions: vec![],
147 arithmetic_solutions: vec![],
148 });
149 }
150
151 let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
153
154 for rule_path in &self.rules_in_order.clone() {
156 let rule_node = match self.rule_cache.get(rule_path) {
157 Some(node) => *node,
158 None => continue,
159 };
160
161 let mut next_worlds = Vec::new();
162
163 for (world, accumulated_constraint) in current_worlds {
164 for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
166 let mut new_world = world.clone();
168 new_world.insert(rule_path.clone(), branch_idx);
169
170 let mut branch_constraint = if let Some(ref condition) = branch.condition {
173 let substituted_condition = substitute_rules_in_expression(
175 &Arc::new(condition.clone()),
176 &new_world,
177 self.plan,
178 )?;
179 let hydrated_condition = hydrate_facts_in_expression(
180 &Arc::new(substituted_condition),
181 self.plan,
182 provided_facts,
183 )?;
184 Constraint::from_expression(&hydrated_condition)?
185 } else {
186 Constraint::True
188 };
189
190 for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
193 if let Some(ref later_condition) = later_branch.condition {
194 let substituted_later = substitute_rules_in_expression(
195 &Arc::new(later_condition.clone()),
196 &new_world,
197 self.plan,
198 )?;
199 let hydrated_later = hydrate_facts_in_expression(
200 &Arc::new(substituted_later),
201 self.plan,
202 provided_facts,
203 )?;
204 let later_constraint = Constraint::from_expression(&hydrated_later)?;
205 branch_constraint = branch_constraint.and(later_constraint.not());
207 }
208 }
209
210 let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
212
213 let simplified = combined_constraint.simplify()?;
215
216 if !simplified.is_false() {
218 next_worlds.push((new_world, simplified));
219 }
220 }
221 }
222
223 current_worlds = next_worlds;
224
225 if current_worlds.is_empty() {
227 break;
228 }
229 }
230
231 let target_rule_path = self
233 .rules_in_order
234 .last()
235 .unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
236
237 let mut literal_solutions = Vec::new();
238 let mut arithmetic_solutions = Vec::new();
239
240 for (world, constraint) in current_worlds {
241 if let Some(&branch_idx) = world.get(target_rule_path) {
243 if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
244 if branch_idx < rule_node.branches.len() {
245 let branch = &rule_node.branches[branch_idx];
246
247 let substituted_result = substitute_rules_in_expression(
249 &Arc::new(branch.result.clone()),
250 &world,
251 self.plan,
252 )?;
253
254 let hydrated_result = hydrate_facts_in_expression(
255 &Arc::new(substituted_result),
256 self.plan,
257 provided_facts,
258 )?;
259
260 let folded_result = try_constant_fold_expression(&hydrated_result)
262 .unwrap_or(hydrated_result.clone());
263
264 if let Some(outcome) = extract_outcome(&folded_result) {
266 literal_solutions.push(WorldSolution {
267 world,
268 constraint,
269 outcome,
270 });
271 } else if is_boolean_expression(&folded_result) {
272 let (true_solutions, false_solutions) =
275 create_boolean_expression_solutions(
276 world,
277 constraint,
278 &folded_result,
279 )?;
280 literal_solutions.extend(true_solutions);
281 literal_solutions.extend(false_solutions);
282 } else if is_arithmetic_expression(&folded_result) {
283 arithmetic_solutions.push(WorldArithmeticSolution {
285 world,
286 constraint,
287 outcome_expression: folded_result,
288 });
289 }
290 }
293 }
294 }
295 }
296
297 Ok(EnumerationResult {
298 literal_solutions,
299 arithmetic_solutions,
300 })
301 }
302}
303
304fn collect_transitive_dependencies(
310 target_rule: &RulePath,
311 rule_map: &HashMap<RulePath, &ExecutableRule>,
312) -> LemmaResult<HashSet<RulePath>> {
313 let mut result = HashSet::new();
314 let mut queue = VecDeque::new();
315
316 queue.push_back(target_rule.clone());
317 result.insert(target_rule.clone());
318
319 while let Some(rule_path) = queue.pop_front() {
320 if let Some(rule) = rule_map.get(&rule_path) {
321 let dependencies = extract_rule_dependencies(rule);
323 for dependency in dependencies {
324 if result.insert(dependency.clone()) {
325 queue.push_back(dependency);
326 }
327 }
328 }
329 }
330
331 Ok(result)
332}
333
334fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
336 let mut deps = HashSet::new();
337 for branch in &rule.branches {
338 if let Some(ref condition) = branch.condition {
339 extract_rule_paths_from_expression(condition, &mut deps);
340 }
341 extract_rule_paths_from_expression(&branch.result, &mut deps);
342 }
343 deps
344}
345
346fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
348 match &expr.kind {
349 ExpressionKind::RulePath(rp) => {
350 paths.insert(rp.clone());
351 }
352 ExpressionKind::Arithmetic(left, _, right)
353 | ExpressionKind::Comparison(left, _, right)
354 | ExpressionKind::LogicalAnd(left, right)
355 | ExpressionKind::LogicalOr(left, right) => {
356 extract_rule_paths_from_expression(left, paths);
357 extract_rule_paths_from_expression(right, paths);
358 }
359 ExpressionKind::LogicalNegation(inner, _)
360 | ExpressionKind::UnitConversion(inner, _)
361 | ExpressionKind::MathematicalComputation(_, inner) => {
362 extract_rule_paths_from_expression(inner, paths);
363 }
364 ExpressionKind::Literal(_)
365 | ExpressionKind::FactPath(_)
366 | ExpressionKind::Veto(_)
367 | ExpressionKind::Reference(_)
368 | ExpressionKind::UnresolvedUnitLiteral(_, _)
369 | ExpressionKind::FactReference(_)
370 | ExpressionKind::RuleReference(_) => {}
371 }
372}
373
374fn substitute_rules_in_expression(
383 expr: &Arc<Expression>,
384 world: &World,
385 plan: &ExecutionPlan,
386) -> LemmaResult<Expression> {
387 enum WorkItem {
388 Process(usize),
389 BuildArithmetic(ArithmeticComputation, Option<crate::Source>),
390 BuildComparison(ComparisonComputation, Option<crate::Source>),
391 BuildLogicalAnd(Option<crate::Source>),
392 BuildLogicalOr(Option<crate::Source>),
393 BuildLogicalNegation(NegationType, Option<crate::Source>),
394 BuildUnitConversion(ConversionTarget, Option<crate::Source>),
395 BuildMathematicalComputation(MathematicalComputation, Option<crate::Source>),
396 PopVisitedRules,
397 }
398
399 let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
400 let mut work_stack: Vec<WorkItem> = Vec::new();
401 let mut result_pool: Vec<Expression> = Vec::new();
402 let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];
403
404 let root_idx = expr_pool.len();
405 expr_pool.push(Arc::clone(expr));
406 work_stack.push(WorkItem::Process(root_idx));
407
408 while let Some(work) = work_stack.pop() {
409 match work {
410 WorkItem::Process(expr_idx) => {
411 let e = &expr_pool[expr_idx];
412 let source_loc = e.source_location.clone();
413
414 match &e.kind {
415 ExpressionKind::RulePath(rule_path) => {
416 let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
417 if visited.contains(rule_path) {
418 unreachable!(
419 "BUG: circular rule reference detected during substitution: {}",
420 rule_path
421 );
422 }
423
424 if let Some(&branch_idx) = world.get(rule_path) {
425 if let Some(rule) = plan.get_rule_by_path(rule_path) {
426 if branch_idx < rule.branches.len() {
427 let branch = &rule.branches[branch_idx];
428 let mut new_visited = visited.clone();
429 new_visited.insert(rule_path.clone());
430 visited_rules_stack.push(new_visited);
431
432 let sub_expr_idx = expr_pool.len();
433 expr_pool.push(Arc::new(branch.result.clone()));
434 work_stack.push(WorkItem::PopVisitedRules);
435 work_stack.push(WorkItem::Process(sub_expr_idx));
436 continue;
437 }
438 }
439 }
440 result_pool.push(Expression::new(
441 ExpressionKind::RulePath(rule_path.clone()),
442 source_loc,
443 ));
444 }
445 ExpressionKind::RuleReference(_) => {
446 unreachable!(
447 "BUG: RuleReference found during substitution (should be RulePath)"
448 );
449 }
450 ExpressionKind::FactReference(_) => {
451 unreachable!(
452 "BUG: FactReference found during substitution (should be FactPath)"
453 );
454 }
455 ExpressionKind::Reference(_) => {
456 unreachable!(
457 "BUG: unresolved Reference found during substitution (should be resolved during planning)"
458 );
459 }
460 ExpressionKind::UnresolvedUnitLiteral(_, _) => {
461 unreachable!(
462 "UnresolvedUnitLiteral found during substitution - this is a bug: unresolved units should be resolved during planning"
463 );
464 }
465 ExpressionKind::Arithmetic(left, op, right) => {
466 let op_clone = op.clone();
467 let left_arc = Arc::clone(left);
468 let right_arc = Arc::clone(right);
469
470 let left_idx = expr_pool.len();
471 expr_pool.push(left_arc);
472 let right_idx = expr_pool.len();
473 expr_pool.push(right_arc);
474
475 work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
476 work_stack.push(WorkItem::Process(right_idx));
477 work_stack.push(WorkItem::Process(left_idx));
478 }
479 ExpressionKind::Comparison(left, op, right) => {
480 let op_clone = op.clone();
481 let left_arc = Arc::clone(left);
482 let right_arc = Arc::clone(right);
483
484 let left_idx = expr_pool.len();
485 expr_pool.push(left_arc);
486 let right_idx = expr_pool.len();
487 expr_pool.push(right_arc);
488
489 work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
490 work_stack.push(WorkItem::Process(right_idx));
491 work_stack.push(WorkItem::Process(left_idx));
492 }
493 ExpressionKind::LogicalAnd(left, right) => {
494 let left_arc = Arc::clone(left);
495 let right_arc = Arc::clone(right);
496
497 let left_idx = expr_pool.len();
498 expr_pool.push(left_arc);
499 let right_idx = expr_pool.len();
500 expr_pool.push(right_arc);
501
502 work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
503 work_stack.push(WorkItem::Process(right_idx));
504 work_stack.push(WorkItem::Process(left_idx));
505 }
506 ExpressionKind::LogicalOr(left, right) => {
507 let left_arc = Arc::clone(left);
508 let right_arc = Arc::clone(right);
509
510 let left_idx = expr_pool.len();
511 expr_pool.push(left_arc);
512 let right_idx = expr_pool.len();
513 expr_pool.push(right_arc);
514
515 work_stack.push(WorkItem::BuildLogicalOr(source_loc));
516 work_stack.push(WorkItem::Process(right_idx));
517 work_stack.push(WorkItem::Process(left_idx));
518 }
519 ExpressionKind::LogicalNegation(inner, neg_type) => {
520 let neg_type_clone = neg_type.clone();
521 let inner_arc = Arc::clone(inner);
522 let inner_idx = expr_pool.len();
523 expr_pool.push(inner_arc);
524 work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
525 work_stack.push(WorkItem::Process(inner_idx));
526 }
527 ExpressionKind::UnitConversion(inner, unit) => {
528 let unit_clone = unit.clone();
529 let inner_arc = Arc::clone(inner);
530 let inner_idx = expr_pool.len();
531 expr_pool.push(inner_arc);
532 work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
533 work_stack.push(WorkItem::Process(inner_idx));
534 }
535 ExpressionKind::MathematicalComputation(func, inner) => {
536 let func_clone = func.clone();
537 let inner_arc = Arc::clone(inner);
538 let inner_idx = expr_pool.len();
539 expr_pool.push(inner_arc);
540 work_stack.push(WorkItem::BuildMathematicalComputation(
541 func_clone, source_loc,
542 ));
543 work_stack.push(WorkItem::Process(inner_idx));
544 }
545 ExpressionKind::Literal(lit) => {
546 result_pool.push(Expression::new(
547 ExpressionKind::Literal(lit.clone()),
548 source_loc,
549 ));
550 }
551 ExpressionKind::FactPath(fact_path) => {
552 result_pool.push(Expression::new(
553 ExpressionKind::FactPath(fact_path.clone()),
554 source_loc,
555 ));
556 }
557 ExpressionKind::Veto(veto) => {
558 result_pool.push(Expression::new(
559 ExpressionKind::Veto(veto.clone()),
560 source_loc,
561 ));
562 }
563 }
564 }
565 WorkItem::BuildArithmetic(op, source_loc) => {
566 let right = result_pool.pop().unwrap_or_else(|| {
567 unreachable!(
568 "BUG: missing right expression for Arithmetic during inversion hydration"
569 )
570 });
571 let left = result_pool.pop().unwrap_or_else(|| {
572 unreachable!(
573 "BUG: missing left expression for Arithmetic during inversion hydration"
574 )
575 });
576 result_pool.push(Expression::new(
577 ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
578 source_loc,
579 ));
580 }
581 WorkItem::BuildComparison(op, source_loc) => {
582 let right = result_pool.pop().unwrap_or_else(|| {
583 unreachable!(
584 "BUG: missing right expression for Comparison during inversion hydration"
585 )
586 });
587 let left = result_pool.pop().unwrap_or_else(|| {
588 unreachable!(
589 "BUG: missing left expression for Comparison during inversion hydration"
590 )
591 });
592 result_pool.push(Expression::new(
593 ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
594 source_loc,
595 ));
596 }
597 WorkItem::BuildLogicalAnd(source_loc) => {
598 let right = result_pool.pop().unwrap_or_else(|| {
599 unreachable!(
600 "BUG: missing right expression for LogicalAnd during inversion hydration"
601 )
602 });
603 let left = result_pool.pop().unwrap_or_else(|| {
604 unreachable!(
605 "BUG: missing left expression for LogicalAnd during inversion hydration"
606 )
607 });
608 result_pool.push(Expression::new(
609 ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
610 source_loc,
611 ));
612 }
613 WorkItem::BuildLogicalOr(source_loc) => {
614 let right = result_pool.pop().unwrap_or_else(|| {
615 unreachable!(
616 "BUG: missing right expression for LogicalOr during inversion hydration"
617 )
618 });
619 let left = result_pool.pop().unwrap_or_else(|| {
620 unreachable!(
621 "BUG: missing left expression for LogicalOr during inversion hydration"
622 )
623 });
624 result_pool.push(Expression::new(
625 ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
626 source_loc,
627 ));
628 }
629 WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
630 let inner = result_pool
631 .pop()
632 .expect("Internal error: missing expression for LogicalNegation");
633 result_pool.push(Expression::new(
634 ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
635 source_loc,
636 ));
637 }
638 WorkItem::BuildUnitConversion(unit, source_loc) => {
639 let inner = result_pool
640 .pop()
641 .expect("Internal error: missing expression for UnitConversion");
642 result_pool.push(Expression::new(
643 ExpressionKind::UnitConversion(Arc::new(inner), unit),
644 source_loc,
645 ));
646 }
647 WorkItem::BuildMathematicalComputation(func, source_loc) => {
648 let inner = result_pool
649 .pop()
650 .expect("Internal error: missing expression for MathematicalComputation");
651 result_pool.push(Expression::new(
652 ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
653 source_loc,
654 ));
655 }
656 WorkItem::PopVisitedRules => {
657 visited_rules_stack.pop();
658 }
659 }
660 }
661
662 Ok(result_pool
663 .pop()
664 .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
665}
666
667fn hydrate_facts_in_expression(
676 expr: &Arc<Expression>,
677 plan: &ExecutionPlan,
678 provided_facts: &HashSet<FactPath>,
679) -> LemmaResult<Expression> {
680 enum WorkItem {
681 Process(usize),
682 BuildArithmetic(ArithmeticComputation, Option<crate::Source>),
683 BuildComparison(ComparisonComputation, Option<crate::Source>),
684 BuildLogicalAnd(Option<crate::Source>),
685 BuildLogicalOr(Option<crate::Source>),
686 BuildLogicalNegation(NegationType, Option<crate::Source>),
687 BuildUnitConversion(ConversionTarget, Option<crate::Source>),
688 BuildMathematicalComputation(MathematicalComputation, Option<crate::Source>),
689 }
690
691 let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
692 let mut work_stack: Vec<WorkItem> = Vec::new();
693 let mut result_pool: Vec<Expression> = Vec::new();
694
695 let root_idx = expr_pool.len();
696 expr_pool.push(Arc::clone(expr));
697 work_stack.push(WorkItem::Process(root_idx));
698
699 while let Some(work) = work_stack.pop() {
700 match work {
701 WorkItem::Process(expr_idx) => {
702 let (source_loc, expr_kind_ref) = {
703 let e = &expr_pool[expr_idx];
704 (e.source_location.clone(), &e.kind)
705 };
706
707 match expr_kind_ref {
708 ExpressionKind::FactPath(fact_path) => {
709 if provided_facts.contains(fact_path) {
710 if let Some(lit) = plan.fact_values.get(fact_path) {
711 result_pool.push(Expression::new(
712 ExpressionKind::Literal(lit.clone()),
713 source_loc,
714 ));
715 continue;
716 }
717 }
718 result_pool.push(Expression::new(
719 ExpressionKind::FactPath(fact_path.clone()),
720 source_loc,
721 ));
722 }
723 ExpressionKind::FactReference(_) => {
724 unreachable!(
725 "BUG: FactReference found during hydration (should be FactPath)"
726 );
727 }
728 ExpressionKind::RuleReference(_) => {
729 unreachable!(
730 "BUG: RuleReference found during hydration (should be RulePath)"
731 );
732 }
733 ExpressionKind::Reference(_) => {
734 unreachable!(
735 "BUG: unresolved Reference found during hydration (should be resolved during planning)"
736 );
737 }
738 ExpressionKind::UnresolvedUnitLiteral(_, _) => {
739 unreachable!(
740 "UnresolvedUnitLiteral found during hydration - this is a bug: unresolved units should be resolved during planning"
741 );
742 }
743 ExpressionKind::Arithmetic(left, op, right) => {
744 let op_clone = op.clone();
745 let left_arc = Arc::clone(left);
746 let right_arc = Arc::clone(right);
747
748 let left_idx = expr_pool.len();
749 expr_pool.push(left_arc);
750 let right_idx = expr_pool.len();
751 expr_pool.push(right_arc);
752
753 work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
754 work_stack.push(WorkItem::Process(right_idx));
755 work_stack.push(WorkItem::Process(left_idx));
756 }
757 ExpressionKind::Comparison(left, op, right) => {
758 let op_clone = op.clone();
759 let left_arc = Arc::clone(left);
760 let right_arc = Arc::clone(right);
761
762 let left_idx = expr_pool.len();
763 expr_pool.push(left_arc);
764 let right_idx = expr_pool.len();
765 expr_pool.push(right_arc);
766
767 work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
768 work_stack.push(WorkItem::Process(right_idx));
769 work_stack.push(WorkItem::Process(left_idx));
770 }
771 ExpressionKind::LogicalAnd(left, right) => {
772 let left_arc = Arc::clone(left);
773 let right_arc = Arc::clone(right);
774
775 let left_idx = expr_pool.len();
776 expr_pool.push(left_arc);
777 let right_idx = expr_pool.len();
778 expr_pool.push(right_arc);
779
780 work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
781 work_stack.push(WorkItem::Process(right_idx));
782 work_stack.push(WorkItem::Process(left_idx));
783 }
784 ExpressionKind::LogicalOr(left, right) => {
785 let left_arc = Arc::clone(left);
786 let right_arc = Arc::clone(right);
787
788 let left_idx = expr_pool.len();
789 expr_pool.push(left_arc);
790 let right_idx = expr_pool.len();
791 expr_pool.push(right_arc);
792
793 work_stack.push(WorkItem::BuildLogicalOr(source_loc));
794 work_stack.push(WorkItem::Process(right_idx));
795 work_stack.push(WorkItem::Process(left_idx));
796 }
797 ExpressionKind::LogicalNegation(inner, neg_type) => {
798 let neg_type_clone = neg_type.clone();
799 let inner_arc = Arc::clone(inner);
800 let inner_idx = expr_pool.len();
801 expr_pool.push(inner_arc);
802 work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
803 work_stack.push(WorkItem::Process(inner_idx));
804 }
805 ExpressionKind::UnitConversion(inner, unit) => {
806 let unit_clone = unit.clone();
807 let inner_arc = Arc::clone(inner);
808 let inner_idx = expr_pool.len();
809 expr_pool.push(inner_arc);
810 work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
811 work_stack.push(WorkItem::Process(inner_idx));
812 }
813 ExpressionKind::MathematicalComputation(func, inner) => {
814 let func_clone = func.clone();
815 let inner_arc = Arc::clone(inner);
816 let inner_idx = expr_pool.len();
817 expr_pool.push(inner_arc);
818 work_stack.push(WorkItem::BuildMathematicalComputation(
819 func_clone, source_loc,
820 ));
821 work_stack.push(WorkItem::Process(inner_idx));
822 }
823 ExpressionKind::Literal(lit) => {
824 result_pool.push(Expression::new(
825 ExpressionKind::Literal(lit.clone()),
826 source_loc,
827 ));
828 }
829 ExpressionKind::RulePath(rule_path) => {
830 result_pool.push(Expression::new(
831 ExpressionKind::RulePath(rule_path.clone()),
832 source_loc,
833 ));
834 }
835 ExpressionKind::Veto(veto) => {
836 result_pool.push(Expression::new(
837 ExpressionKind::Veto(veto.clone()),
838 source_loc,
839 ));
840 }
841 }
842 }
843 WorkItem::BuildArithmetic(op, source_loc) => {
844 let right = result_pool.pop().unwrap_or_else(|| {
845 unreachable!("BUG: missing right expression for Arithmetic")
846 });
847 let left = result_pool
848 .pop()
849 .unwrap_or_else(|| unreachable!("BUG: missing left expression for Arithmetic"));
850 result_pool.push(Expression::new(
851 ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
852 source_loc,
853 ));
854 }
855 WorkItem::BuildComparison(op, source_loc) => {
856 let right = result_pool.pop().unwrap_or_else(|| {
857 unreachable!("BUG: missing right expression for Comparison")
858 });
859 let left = result_pool
860 .pop()
861 .unwrap_or_else(|| unreachable!("BUG: missing left expression for Comparison"));
862 result_pool.push(Expression::new(
863 ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
864 source_loc,
865 ));
866 }
867 WorkItem::BuildLogicalAnd(source_loc) => {
868 let right = result_pool.pop().unwrap_or_else(|| {
869 unreachable!("BUG: missing right expression for LogicalAnd")
870 });
871 let left = result_pool
872 .pop()
873 .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalAnd"));
874 result_pool.push(Expression::new(
875 ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
876 source_loc,
877 ));
878 }
879 WorkItem::BuildLogicalOr(source_loc) => {
880 let right = result_pool
881 .pop()
882 .unwrap_or_else(|| unreachable!("BUG: missing right expression for LogicalOr"));
883 let left = result_pool
884 .pop()
885 .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalOr"));
886 result_pool.push(Expression::new(
887 ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
888 source_loc,
889 ));
890 }
891 WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
892 let inner = result_pool
893 .pop()
894 .expect("Internal error: missing expression for LogicalNegation");
895 result_pool.push(Expression::new(
896 ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
897 source_loc,
898 ));
899 }
900 WorkItem::BuildUnitConversion(unit, source_loc) => {
901 let inner = result_pool
902 .pop()
903 .expect("Internal error: missing expression for UnitConversion");
904 result_pool.push(Expression::new(
905 ExpressionKind::UnitConversion(Arc::new(inner), unit),
906 source_loc,
907 ));
908 }
909 WorkItem::BuildMathematicalComputation(func, source_loc) => {
910 let inner = result_pool
911 .pop()
912 .expect("Internal error: missing expression for MathematicalComputation");
913 result_pool.push(Expression::new(
914 ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
915 source_loc,
916 ));
917 }
918 }
919 }
920
921 Ok(result_pool
922 .pop()
923 .expect("Internal error: no result from hydration"))
924}
925
926fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
932 match &expr.kind {
933 ExpressionKind::Literal(lit) => Some(OperationResult::Value(lit.clone())),
934 ExpressionKind::Veto(ve) => Some(OperationResult::Veto(ve.message.clone())),
935 _ => None,
936 }
937}
938
939fn is_boolean_expression(expr: &Expression) -> bool {
941 matches!(
942 &expr.kind,
943 ExpressionKind::Comparison(_, _, _)
944 | ExpressionKind::LogicalAnd(_, _)
945 | ExpressionKind::LogicalOr(_, _)
946 | ExpressionKind::LogicalNegation(_, _)
947 )
948}
949
950fn is_arithmetic_expression(expr: &Expression) -> bool {
954 match &expr.kind {
955 ExpressionKind::Arithmetic(_, _, _) => true,
956 ExpressionKind::MathematicalComputation(_, _) => true,
957 ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
958 ExpressionKind::FactPath(_) => true, _ => false,
960 }
961}
962
963fn create_boolean_expression_solutions(
968 world: World,
969 base_constraint: Constraint,
970 boolean_expr: &Expression,
971) -> LemmaResult<(Vec<WorldSolution>, Vec<WorldSolution>)> {
972 let expr_constraint = Constraint::from_expression(boolean_expr)?;
974
975 let true_constraint = base_constraint.clone().and(expr_constraint.clone());
977 let simplified_true = true_constraint.simplify()?;
978
979 let true_solutions = if !simplified_true.is_false() {
980 vec![WorldSolution {
981 world: world.clone(),
982 constraint: simplified_true,
983 outcome: OperationResult::Value(LiteralValue::boolean(BooleanValue::True)),
984 }]
985 } else {
986 vec![]
987 };
988
989 let false_constraint = base_constraint.and(expr_constraint.not());
991 let simplified_false = false_constraint.simplify()?;
992
993 let false_solutions = if !simplified_false.is_false() {
994 vec![WorldSolution {
995 world,
996 constraint: simplified_false,
997 outcome: OperationResult::Value(LiteralValue::boolean(BooleanValue::False)),
998 }]
999 } else {
1000 vec![]
1001 };
1002
1003 Ok((true_solutions, false_solutions))
1004}
1005
1006pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
1008 match &expr.kind {
1009 ExpressionKind::Literal(_) => Some(expr.clone()),
1010 ExpressionKind::Arithmetic(left, op, right) => {
1011 let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1012 let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1013 if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1014 (&left_folded.kind, &right_folded.kind)
1015 {
1016 if let Some(result) = evaluate_arithmetic(left_val, op, right_val) {
1017 return Some(Expression::new(
1018 ExpressionKind::Literal(result),
1019 expr.source_location.clone(),
1020 ));
1021 }
1022 }
1023 Some(Expression::new(
1024 ExpressionKind::Arithmetic(
1025 Arc::new(left_folded),
1026 op.clone(),
1027 Arc::new(right_folded),
1028 ),
1029 expr.source_location.clone(),
1030 ))
1031 }
1032 ExpressionKind::Comparison(left, op, right) => {
1033 let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
1034 let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
1035 if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
1036 (&left_folded.kind, &right_folded.kind)
1037 {
1038 if let Some(result) = evaluate_comparison(left_val, op, right_val) {
1039 return Some(Expression::new(
1040 ExpressionKind::Literal(LiteralValue::boolean(result)),
1041 expr.source_location.clone(),
1042 ));
1043 }
1044 }
1045 Some(Expression::new(
1046 ExpressionKind::Comparison(
1047 Arc::new(left_folded),
1048 op.clone(),
1049 Arc::new(right_folded),
1050 ),
1051 expr.source_location.clone(),
1052 ))
1053 }
1054 _ => None,
1055 }
1056}
1057
1058fn evaluate_arithmetic(
1062 left: &LiteralValue,
1063 op: &ArithmeticComputation,
1064 right: &LiteralValue,
1065) -> Option<LiteralValue> {
1066 use crate::computation::arithmetic_operation;
1067
1068 match arithmetic_operation(left, op, right) {
1069 OperationResult::Value(lit) => Some(lit),
1070 OperationResult::Veto(_) => None,
1071 }
1072}
1073
1074fn evaluate_comparison(
1078 left: &LiteralValue,
1079 op: &ComparisonComputation,
1080 right: &LiteralValue,
1081) -> Option<BooleanValue> {
1082 use crate::computation::comparison_operation;
1083
1084 match comparison_operation(left, op, right) {
1085 OperationResult::Value(lit) => match &lit.value {
1086 Value::Boolean(b) => Some(b.clone()),
1087 _ => None,
1088 },
1089 _ => None,
1090 }
1091}
1092
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    fn literal_expr(val: LiteralValue) -> Expression {
        Expression::new(ExpressionKind::Literal(val), None)
    }

    fn fact_expr(name: &str) -> Expression {
        Expression::new(
            ExpressionKind::FactPath(FactPath::local(name.to_string())),
            None,
        )
    }

    fn add_expr(left: Expression, right: Expression) -> Expression {
        Expression::new(
            ExpressionKind::Arithmetic(Arc::new(left), ArithmeticComputation::Add, Arc::new(right)),
            None,
        )
    }

    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    fn local_rule(name: &str) -> RulePath {
        RulePath {
            segments: vec![],
            rule: name.to_string(),
        }
    }

    /// Asserts that `expr` is a numeric literal equal to `expected`.
    fn assert_number(expr: &Expression, expected: i64) {
        match &expr.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                Value::Number(n) => assert_eq!(*n, Decimal::from(expected)),
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    fn empty_plan() -> ExecutionPlan {
        ExecutionPlan {
            doc_name: "test".to_string(),
            fact_schema: HashMap::new(),
            fact_values: HashMap::new(),
            doc_refs: HashMap::new(),
            fact_sources: HashMap::new(),
            rules: Vec::new(),
            sources: HashMap::new(),
        }
    }

    #[test]
    fn test_world_new() {
        let world = World::new();
        assert!(world.0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let mut world = World::new();
        let rule_path = local_rule("test_rule");
        assert_eq!(world.insert(rule_path.clone(), 2), None);
        assert_eq!(world.get(&rule_path), Some(&2));
        assert_eq!(world.insert(rule_path.clone(), 3), Some(2));
    }

    #[test]
    fn test_hydrate_literal_unchanged() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let result =
            hydrate_facts_in_expression(&Arc::new(literal_expr(num(42))), &plan, &provided)
                .unwrap();

        assert_number(&result, 42);
    }

    #[test]
    fn test_hydrate_fact_not_provided() {
        let plan = empty_plan();
        let provided: HashSet<FactPath> = HashSet::new();

        let result =
            hydrate_facts_in_expression(&Arc::new(fact_expr("age")), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_provided_fact_becomes_literal() {
        let mut plan = empty_plan();
        plan.fact_values
            .insert(FactPath::local("age".to_string()), num(30));
        let mut provided: HashSet<FactPath> = HashSet::new();
        provided.insert(FactPath::local("age".to_string()));

        let result =
            hydrate_facts_in_expression(&Arc::new(fact_expr("age")), &plan, &provided).unwrap();

        assert_number(&result, 30);
    }

    #[test]
    fn test_hydrate_valued_fact_not_provided_stays_symbolic() {
        let mut plan = empty_plan();
        plan.fact_values
            .insert(FactPath::local("age".to_string()), num(30));
        let provided: HashSet<FactPath> = HashSet::new();

        let result =
            hydrate_facts_in_expression(&Arc::new(fact_expr("age")), &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_hydrate_then_fold_arithmetic() {
        let mut plan = empty_plan();
        plan.fact_values
            .insert(FactPath::local("age".to_string()), num(30));
        let mut provided: HashSet<FactPath> = HashSet::new();
        provided.insert(FactPath::local("age".to_string()));

        let expr = add_expr(fact_expr("age"), literal_expr(num(1)));
        let hydrated = hydrate_facts_in_expression(&Arc::new(expr), &plan, &provided).unwrap();
        let folded = try_constant_fold_expression(&hydrated).unwrap();

        assert_number(&folded, 31);
    }

    #[test]
    fn test_substitute_rule_without_world_entry_is_unchanged() {
        let plan = empty_plan();
        let path = local_rule("r");
        let expr = Expression::new(ExpressionKind::RulePath(path.clone()), None);

        let result = substitute_rules_in_expression(&Arc::new(expr), &World::new(), &plan).unwrap();

        assert!(matches!(&result.kind, ExpressionKind::RulePath(rp) if *rp == path));
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let expr = add_expr(literal_expr(num(10)), literal_expr(num(5)));

        let folded = try_constant_fold_expression(&expr).unwrap();

        assert_number(&folded, 15);
    }

    #[test]
    fn test_constant_fold_nested_arithmetic() {
        let expr = add_expr(
            add_expr(literal_expr(num(2)), literal_expr(num(3))),
            literal_expr(num(4)),
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        assert_number(&folded, 9);
    }

    #[test]
    fn test_constant_fold_partial_keeps_structure() {
        let expr = add_expr(
            fact_expr("x"),
            add_expr(literal_expr(num(1)), literal_expr(num(2))),
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Arithmetic(left, _, right) => {
                assert!(matches!(left.kind, ExpressionKind::FactPath(_)));
                assert_number(right, 3);
            }
            _ => panic!("Expected arithmetic expression"),
        }
    }

    #[test]
    fn test_constant_fold_comparison() {
        let expr = Expression::new(
            ExpressionKind::Comparison(
                Arc::new(literal_expr(num(10))),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                Value::Boolean(b) => assert_eq!(*b, BooleanValue::True),
                _ => panic!("Expected literal boolean"),
            },
            _ => panic!("Expected literal boolean"),
        }
    }

    #[test]
    fn test_outcome_and_classification_helpers() {
        assert!(matches!(
            extract_outcome(&literal_expr(num(1))),
            Some(OperationResult::Value(_))
        ));
        assert!(extract_outcome(&fact_expr("x")).is_none());

        assert!(is_arithmetic_expression(&fact_expr("x")));
        assert!(!is_arithmetic_expression(&literal_expr(num(1))));
        assert!(!is_boolean_expression(&fact_expr("x")));
    }
}