1use crate::planning::semantics::{
9 ArithmeticComputation, ComparisonComputation, Expression, ExpressionKind, FactPath,
10 LiteralValue, MathematicalComputation, NegationType, RulePath, SemanticConversionTarget,
11 Source,
12};
13use crate::planning::{ExecutableRule, ExecutionPlan};
14use crate::{LemmaResult, OperationResult};
15use serde::ser::{Serialize, SerializeMap, Serializer};
16use std::collections::{HashMap, HashSet, VecDeque};
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
/// One combination of branch choices: for each rule path, the index of the branch selected
/// for that rule (an index into that rule's `branches`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct World(HashMap<RulePath, usize>);
24
25impl World {
26 pub fn new() -> Self {
28 Self(HashMap::new())
29 }
30
31 pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
33 self.0.get(rule_path)
34 }
35
36 pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
38 self.0.insert(rule_path, branch_idx)
39 }
40
41 pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
43 self.0.iter()
44 }
45}
46
47impl Serialize for World {
48 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49 where
50 S: Serializer,
51 {
52 let mut map = serializer.serialize_map(Some(self.0.len()))?;
53 for (k, v) in &self.0 {
54 map.serialize_entry(&k.to_string(), v)?;
55 }
56 map.end()
57 }
58}
59
/// A world in which the target rule resolved to a concrete outcome.
#[derive(Debug, Clone)]
pub struct WorldSolution {
    /// Branch chosen for each enumerated rule in this world.
    pub world: World,
    /// Constraint under which this world yields `outcome`. It is built from the chosen
    /// branches' conditions; for boolean results it also includes the result expression or
    /// its negation.
    pub constraint: Constraint,
    /// The resulting value or veto.
    pub outcome: OperationResult,
}
70
/// A world in which the target rule resolved to an arithmetic expression that constant
/// folding did not reduce to a literal (for example because it still references facts).
#[derive(Debug, Clone)]
pub struct WorldArithmeticSolution {
    /// Branch chosen for each enumerated rule in this world.
    pub world: World,
    /// Constraint, built from the chosen branches' conditions, under which this world applies.
    pub constraint: Constraint,
    /// The target rule's result after rule substitution, fact hydration and constant folding.
    pub outcome_expression: Expression,
}
85
/// Solutions produced by `WorldEnumerator::enumerate`, split by the shape of the outcome.
#[derive(Debug, Clone)]
pub struct EnumerationResult {
    /// Worlds whose outcome is a literal value, a veto, or a `true`/`false` boolean result.
    pub literal_solutions: Vec<WorldSolution>,
    /// Worlds whose outcome is a still-symbolic arithmetic expression.
    pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
}
94
/// Enumerates branch combinations ("worlds") for a target rule and its dependencies.
pub struct WorldEnumerator<'a> {
    /// Plan that rules and fact values are read from.
    plan: &'a ExecutionPlan,
    /// The target rule and its transitive dependencies, in the order they appear in
    /// `plan.rules`.
    rules_in_order: Vec<RulePath>,
    /// Rule definitions for the paths in `rules_in_order`.
    rule_cache: HashMap<RulePath, &'a ExecutableRule>,
}
103
104impl<'a> WorldEnumerator<'a> {
105 pub fn new(plan: &'a ExecutionPlan, target_rule: &RulePath) -> LemmaResult<Self> {
107 let rule_map: HashMap<RulePath, &ExecutableRule> =
109 plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
110
111 let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
113
114 let rules_in_order: Vec<RulePath> = plan
116 .rules
117 .iter()
118 .filter(|r| dependent_rules.contains(&r.path))
119 .map(|r| r.path.clone())
120 .collect();
121
122 let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
124 .iter()
125 .filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
126 .collect();
127
128 Ok(Self {
129 plan,
130 rules_in_order,
131 rule_cache,
132 })
133 }
134
135 pub fn enumerate(
142 &mut self,
143 provided_facts: &HashSet<FactPath>,
144 ) -> LemmaResult<EnumerationResult> {
145 if self.rules_in_order.is_empty() {
146 return Ok(EnumerationResult {
147 literal_solutions: vec![],
148 arithmetic_solutions: vec![],
149 });
150 }
151
152 let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
154
155 for rule_path in &self.rules_in_order.clone() {
157 let rule_node = match self.rule_cache.get(rule_path) {
158 Some(node) => *node,
159 None => continue,
160 };
161
162 let mut next_worlds = Vec::new();
163
164 for (world, accumulated_constraint) in current_worlds {
165 for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
167 let mut new_world = world.clone();
169 new_world.insert(rule_path.clone(), branch_idx);
170
171 let mut branch_constraint = if let Some(ref condition) = branch.condition {
174 let substituted_condition = substitute_rules_in_expression(
176 &Arc::new(condition.clone()),
177 &new_world,
178 self.plan,
179 )?;
180 let hydrated_condition = hydrate_facts_in_expression(
181 &Arc::new(substituted_condition),
182 self.plan,
183 provided_facts,
184 )?;
185 Constraint::from_expression(&hydrated_condition)?
186 } else {
187 Constraint::True
189 };
190
191 for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
194 if let Some(ref later_condition) = later_branch.condition {
195 let substituted_later = substitute_rules_in_expression(
196 &Arc::new(later_condition.clone()),
197 &new_world,
198 self.plan,
199 )?;
200 let hydrated_later = hydrate_facts_in_expression(
201 &Arc::new(substituted_later),
202 self.plan,
203 provided_facts,
204 )?;
205 let later_constraint = Constraint::from_expression(&hydrated_later)?;
206 branch_constraint = branch_constraint.and(later_constraint.not());
208 }
209 }
210
211 let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
213
214 let simplified = combined_constraint.simplify()?;
216
217 if !simplified.is_false() {
219 next_worlds.push((new_world, simplified));
220 }
221 }
222 }
223
224 current_worlds = next_worlds;
225
226 if current_worlds.is_empty() {
228 break;
229 }
230 }
231
232 let target_rule_path = self
234 .rules_in_order
235 .last()
236 .unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
237
238 let mut literal_solutions = Vec::new();
239 let mut arithmetic_solutions = Vec::new();
240
241 for (world, constraint) in current_worlds {
242 if let Some(&branch_idx) = world.get(target_rule_path) {
244 if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
245 if branch_idx < rule_node.branches.len() {
246 let branch = &rule_node.branches[branch_idx];
247
248 let substituted_result = substitute_rules_in_expression(
250 &Arc::new(branch.result.clone()),
251 &world,
252 self.plan,
253 )?;
254
255 let hydrated_result = hydrate_facts_in_expression(
256 &Arc::new(substituted_result),
257 self.plan,
258 provided_facts,
259 )?;
260
261 let folded_result = try_constant_fold_expression(&hydrated_result)
263 .unwrap_or(hydrated_result.clone());
264
265 if let Some(outcome) = extract_outcome(&folded_result) {
267 literal_solutions.push(WorldSolution {
268 world,
269 constraint,
270 outcome,
271 });
272 } else if is_boolean_expression(&folded_result) {
273 let (true_solutions, false_solutions) =
276 create_boolean_expression_solutions(
277 world,
278 constraint,
279 &folded_result,
280 )?;
281 literal_solutions.extend(true_solutions);
282 literal_solutions.extend(false_solutions);
283 } else if is_arithmetic_expression(&folded_result) {
284 arithmetic_solutions.push(WorldArithmeticSolution {
286 world,
287 constraint,
288 outcome_expression: folded_result,
289 });
290 }
291 }
294 }
295 }
296 }
297
298 Ok(EnumerationResult {
299 literal_solutions,
300 arithmetic_solutions,
301 })
302 }
303}
304
305fn collect_transitive_dependencies(
311 target_rule: &RulePath,
312 rule_map: &HashMap<RulePath, &ExecutableRule>,
313) -> LemmaResult<HashSet<RulePath>> {
314 let mut result = HashSet::new();
315 let mut queue = VecDeque::new();
316
317 queue.push_back(target_rule.clone());
318 result.insert(target_rule.clone());
319
320 while let Some(rule_path) = queue.pop_front() {
321 if let Some(rule) = rule_map.get(&rule_path) {
322 let dependencies = extract_rule_dependencies(rule);
324 for dependency in dependencies {
325 if result.insert(dependency.clone()) {
326 queue.push_back(dependency);
327 }
328 }
329 }
330 }
331
332 Ok(result)
333}
334
335fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
337 let mut deps = HashSet::new();
338 for branch in &rule.branches {
339 if let Some(ref condition) = branch.condition {
340 extract_rule_paths_from_expression(condition, &mut deps);
341 }
342 extract_rule_paths_from_expression(&branch.result, &mut deps);
343 }
344 deps
345}
346
347fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
349 match &expr.kind {
350 ExpressionKind::RulePath(rp) => {
351 paths.insert(rp.clone());
352 }
353 ExpressionKind::Arithmetic(left, _, right)
354 | ExpressionKind::Comparison(left, _, right)
355 | ExpressionKind::LogicalAnd(left, right)
356 | ExpressionKind::LogicalOr(left, right) => {
357 extract_rule_paths_from_expression(left, paths);
358 extract_rule_paths_from_expression(right, paths);
359 }
360 ExpressionKind::LogicalNegation(inner, _)
361 | ExpressionKind::UnitConversion(inner, _)
362 | ExpressionKind::MathematicalComputation(_, inner) => {
363 extract_rule_paths_from_expression(inner, paths);
364 }
365 ExpressionKind::Literal(_) | ExpressionKind::FactPath(_) | ExpressionKind::Veto(_) => {}
366 }
367}
368
/// Replaces rule references in `expr` with the result expression of the branch that `world`
/// selects for each referenced rule.
///
/// Substitution is transitive: a substituted branch result is itself processed, so rule
/// references inside it are replaced too. A rule reference is kept as a `RulePath` node when
/// any of these holds:
/// - `world` has no branch for it;
/// - `plan` has no such rule;
/// - the recorded branch index is out of range.
///
/// Fact references, literals and vetoes are copied unchanged. Rebuilt nodes keep their
/// source locations; a substituted reference ends up with the source location of the branch
/// result that replaces it.
///
/// The tree is walked with an explicit work stack instead of recursion. `Process(idx)`
/// expands the node at `expr_pool[idx]`. For an operator node it first pushes a `Build*`
/// item, then the operand `Process` items on top of it. Both operands are therefore finished
/// before the `Build*` item runs, with their results on `result_pool` (left below right).
/// `visited_rules_stack` holds, per substitution depth, the rules being expanded on the
/// current path. `PopVisitedRules` restores the previous depth once a substituted result has
/// been fully processed.
///
/// # Panics
///
/// Panics when a rule reference is reached while that rule is already being expanded on the
/// current path (a circular reference). Also panics if an operand is missing from
/// `result_pool`, which would indicate a bug in this traversal.
///
/// # Errors
///
/// Never returns `Err` in the current implementation.
///
/// NOTE(review): several `unreachable!` messages in the `Build*` arms mention "inversion
/// hydration". They look copied from another pass and may be worth rewording.
fn substitute_rules_in_expression(
    expr: &Arc<Expression>,
    world: &World,
    plan: &ExecutionPlan,
) -> LemmaResult<Expression> {
    // Deferred traversal steps. `Build*` items carry the operator and source location needed
    // to rebuild a node once its operands are on `result_pool`.
    enum WorkItem {
        Process(usize),
        BuildArithmetic(ArithmeticComputation, Option<Source>),
        BuildComparison(ComparisonComputation, Option<Source>),
        BuildLogicalAnd(Option<Source>),
        BuildLogicalOr(Option<Source>),
        BuildLogicalNegation(NegationType, Option<Source>),
        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
        // Leaves one level of rule substitution (pops `visited_rules_stack`).
        PopVisitedRules,
    }

    // Every expression scheduled for processing; `Process` items refer to entries by index.
    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
    let mut work_stack: Vec<WorkItem> = Vec::new();
    // Finished (substituted) sub-expressions, consumed by `Build*` items.
    let mut result_pool: Vec<Expression> = Vec::new();
    // Rules being expanded on the current path, one set per substitution depth.
    let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];

    let root_idx = expr_pool.len();
    expr_pool.push(Arc::clone(expr));
    work_stack.push(WorkItem::Process(root_idx));

    while let Some(work) = work_stack.pop() {
        match work {
            WorkItem::Process(expr_idx) => {
                let e = &expr_pool[expr_idx];
                let source_loc = e.source_location.clone();

                match &e.kind {
                    ExpressionKind::RulePath(rule_path) => {
                        let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
                        if visited.contains(rule_path) {
                            unreachable!(
                                "BUG: circular rule reference detected during substitution: {}",
                                rule_path
                            );
                        }

                        if let Some(&branch_idx) = world.get(rule_path) {
                            if let Some(rule) = plan.get_rule_by_path(rule_path) {
                                if branch_idx < rule.branches.len() {
                                    let branch = &rule.branches[branch_idx];
                                    // Mark this rule as in progress for everything reached
                                    // through its substituted result.
                                    let mut new_visited = visited.clone();
                                    new_visited.insert(rule_path.clone());
                                    visited_rules_stack.push(new_visited);

                                    let sub_expr_idx = expr_pool.len();
                                    expr_pool.push(Arc::new(branch.result.clone()));
                                    // Pushed first so it runs only after the whole
                                    // substituted result has been processed.
                                    work_stack.push(WorkItem::PopVisitedRules);
                                    work_stack.push(WorkItem::Process(sub_expr_idx));
                                    continue;
                                }
                            }
                        }
                        // No usable branch for this rule in `world`: keep the reference.
                        result_pool.push(Expression::with_source(
                            ExpressionKind::RulePath(rule_path.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Arithmetic(left, op, right) => {
                        let op_clone = op.clone();
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        // Left is pushed last so it is processed first.
                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::Comparison(left, op, right) => {
                        let op_clone = op.clone();
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalAnd(left, right) => {
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalOr(left, right) => {
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalNegation(inner, neg_type) => {
                        let neg_type_clone = neg_type.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::UnitConversion(inner, unit) => {
                        let unit_clone = unit.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::MathematicalComputation(func, inner) => {
                        let func_clone = func.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildMathematicalComputation(
                            func_clone, source_loc,
                        ));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    // Leaves without rule references are copied as-is.
                    ExpressionKind::Literal(lit) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::Literal(lit.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::FactPath(fact_path) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::FactPath(fact_path.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Veto(veto) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::Veto(veto.clone()),
                            source_loc,
                        ));
                    }
                }
            }
            // Binary builders pop the right operand first: it was finished last.
            WorkItem::BuildArithmetic(op, source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for Arithmetic during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for Arithmetic during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildComparison(op, source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for Comparison during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for Comparison during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalAnd(source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for LogicalAnd during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for LogicalAnd during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalOr(source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing right expression for LogicalOr during inversion hydration"
                    )
                });
                let left = result_pool.pop().unwrap_or_else(|| {
                    unreachable!(
                        "BUG: missing left expression for LogicalOr during inversion hydration"
                    )
                });
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for LogicalNegation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
                    source_loc,
                ));
            }
            WorkItem::BuildUnitConversion(unit, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for UnitConversion");
                result_pool.push(Expression::with_source(
                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
                    source_loc,
                ));
            }
            WorkItem::BuildMathematicalComputation(func, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for MathematicalComputation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
                    source_loc,
                ));
            }
            WorkItem::PopVisitedRules => {
                visited_rules_stack.pop();
            }
        }
    }

    // The root's rebuilt expression is the only item left on the result pool.
    Ok(result_pool
        .pop()
        .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
}
641
/// Replaces fact references in `expr` with literal values, for facts listed in
/// `provided_facts`.
///
/// A `FactPath` node becomes a `Literal` only when the fact is in `provided_facts` and
/// `plan.facts` has a value for it (`.value()` returns `Some`); otherwise the reference is
/// kept. Every other node is rebuilt with the same operator and source location. Rule
/// references, literals and vetoes are copied unchanged.
///
/// The tree is walked with an explicit work stack instead of recursion. `Process(idx)`
/// expands the node at `expr_pool[idx]`. For an operator node it first pushes a `Build*`
/// item, then the operand `Process` items on top of it. Both operands are finished before the
/// `Build*` item runs, with their results on `result_pool` (left below right).
///
/// # Panics
///
/// Panics if an operand is missing from `result_pool`, which would indicate a bug in this
/// traversal.
///
/// # Errors
///
/// Never returns `Err` in the current implementation.
fn hydrate_facts_in_expression(
    expr: &Arc<Expression>,
    plan: &ExecutionPlan,
    provided_facts: &HashSet<FactPath>,
) -> LemmaResult<Expression> {
    // Deferred traversal steps. `Build*` items carry the operator and source location needed
    // to rebuild a node once its operands are on `result_pool`.
    enum WorkItem {
        Process(usize),
        BuildArithmetic(ArithmeticComputation, Option<Source>),
        BuildComparison(ComparisonComputation, Option<Source>),
        BuildLogicalAnd(Option<Source>),
        BuildLogicalOr(Option<Source>),
        BuildLogicalNegation(NegationType, Option<Source>),
        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
    }

    // Every expression scheduled for processing; `Process` items refer to entries by index.
    let mut expr_pool: Vec<Arc<Expression>> = Vec::new();
    let mut work_stack: Vec<WorkItem> = Vec::new();
    // Finished (hydrated) sub-expressions, consumed by `Build*` items.
    let mut result_pool: Vec<Expression> = Vec::new();

    let root_idx = expr_pool.len();
    expr_pool.push(Arc::clone(expr));
    work_stack.push(WorkItem::Process(root_idx));

    while let Some(work) = work_stack.pop() {
        match work {
            WorkItem::Process(expr_idx) => {
                let (source_loc, expr_kind_ref) = {
                    let e = &expr_pool[expr_idx];
                    (e.source_location.clone(), &e.kind)
                };

                match expr_kind_ref {
                    ExpressionKind::FactPath(fact_path) => {
                        // Inline the value only for provided facts that the plan has a value
                        // for; otherwise fall through and keep the reference.
                        if provided_facts.contains(fact_path) {
                            if let Some(lit) = plan.facts.get(fact_path).and_then(|d| d.value()) {
                                result_pool.push(Expression::with_source(
                                    ExpressionKind::Literal(Box::new(lit.clone())),
                                    source_loc,
                                ));
                                continue;
                            }
                        }
                        result_pool.push(Expression::with_source(
                            ExpressionKind::FactPath(fact_path.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Arithmetic(left, op, right) => {
                        let op_clone = op.clone();
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        // Left is pushed last so it is processed first.
                        work_stack.push(WorkItem::BuildArithmetic(op_clone, source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::Comparison(left, op, right) => {
                        let op_clone = op.clone();
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildComparison(op_clone, source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalAnd(left, right) => {
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalOr(left, right) => {
                        let left_arc = Arc::clone(left);
                        let right_arc = Arc::clone(right);

                        let left_idx = expr_pool.len();
                        expr_pool.push(left_arc);
                        let right_idx = expr_pool.len();
                        expr_pool.push(right_arc);

                        work_stack.push(WorkItem::BuildLogicalOr(source_loc));
                        work_stack.push(WorkItem::Process(right_idx));
                        work_stack.push(WorkItem::Process(left_idx));
                    }
                    ExpressionKind::LogicalNegation(inner, neg_type) => {
                        let neg_type_clone = neg_type.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildLogicalNegation(neg_type_clone, source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::UnitConversion(inner, unit) => {
                        let unit_clone = unit.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildUnitConversion(unit_clone, source_loc));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    ExpressionKind::MathematicalComputation(func, inner) => {
                        let func_clone = func.clone();
                        let inner_arc = Arc::clone(inner);
                        let inner_idx = expr_pool.len();
                        expr_pool.push(inner_arc);
                        work_stack.push(WorkItem::BuildMathematicalComputation(
                            func_clone, source_loc,
                        ));
                        work_stack.push(WorkItem::Process(inner_idx));
                    }
                    // Leaves other than facts are copied as-is.
                    ExpressionKind::Literal(lit) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::Literal(lit.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::RulePath(rule_path) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::RulePath(rule_path.clone()),
                            source_loc,
                        ));
                    }
                    ExpressionKind::Veto(veto) => {
                        result_pool.push(Expression::with_source(
                            ExpressionKind::Veto(veto.clone()),
                            source_loc,
                        ));
                    }
                }
            }
            // Binary builders pop the right operand first: it was finished last.
            WorkItem::BuildArithmetic(op, source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!("BUG: missing right expression for Arithmetic")
                });
                let left = result_pool
                    .pop()
                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Arithmetic"));
                result_pool.push(Expression::with_source(
                    ExpressionKind::Arithmetic(Arc::new(left), op, Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildComparison(op, source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!("BUG: missing right expression for Comparison")
                });
                let left = result_pool
                    .pop()
                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for Comparison"));
                result_pool.push(Expression::with_source(
                    ExpressionKind::Comparison(Arc::new(left), op, Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalAnd(source_loc) => {
                let right = result_pool.pop().unwrap_or_else(|| {
                    unreachable!("BUG: missing right expression for LogicalAnd")
                });
                let left = result_pool
                    .pop()
                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalAnd"));
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalAnd(Arc::new(left), Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalOr(source_loc) => {
                let right = result_pool
                    .pop()
                    .unwrap_or_else(|| unreachable!("BUG: missing right expression for LogicalOr"));
                let left = result_pool
                    .pop()
                    .unwrap_or_else(|| unreachable!("BUG: missing left expression for LogicalOr"));
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalOr(Arc::new(left), Arc::new(right)),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for LogicalNegation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalNegation(Arc::new(inner), neg_type),
                    source_loc,
                ));
            }
            WorkItem::BuildUnitConversion(unit, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for UnitConversion");
                result_pool.push(Expression::with_source(
                    ExpressionKind::UnitConversion(Arc::new(inner), unit),
                    source_loc,
                ));
            }
            WorkItem::BuildMathematicalComputation(func, source_loc) => {
                let inner = result_pool
                    .pop()
                    .expect("Internal error: missing expression for MathematicalComputation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::MathematicalComputation(func, Arc::new(inner)),
                    source_loc,
                ));
            }
        }
    }

    // The root's rebuilt expression is the only item left on the result pool.
    Ok(result_pool
        .pop()
        .expect("Internal error: no result from hydration"))
}
880
881fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
887 match &expr.kind {
888 ExpressionKind::Literal(lit) => {
889 Some(OperationResult::Value(Box::new(lit.as_ref().clone())))
890 }
891 ExpressionKind::Veto(ve) => Some(OperationResult::Veto(ve.message.clone())),
892 _ => None,
893 }
894}
895
896fn is_boolean_expression(expr: &Expression) -> bool {
898 matches!(
899 &expr.kind,
900 ExpressionKind::Comparison(_, _, _)
901 | ExpressionKind::LogicalAnd(_, _)
902 | ExpressionKind::LogicalOr(_, _)
903 | ExpressionKind::LogicalNegation(_, _)
904 )
905}
906
907fn is_arithmetic_expression(expr: &Expression) -> bool {
911 match &expr.kind {
912 ExpressionKind::Arithmetic(_, _, _) => true,
913 ExpressionKind::MathematicalComputation(_, _) => true,
914 ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
915 ExpressionKind::FactPath(_) => true, _ => false,
917 }
918}
919
920fn create_boolean_expression_solutions(
925 world: World,
926 base_constraint: Constraint,
927 boolean_expr: &Expression,
928) -> LemmaResult<(Vec<WorldSolution>, Vec<WorldSolution>)> {
929 let expr_constraint = Constraint::from_expression(boolean_expr)?;
931
932 let true_constraint = base_constraint.clone().and(expr_constraint.clone());
934 let simplified_true = true_constraint.simplify()?;
935
936 let true_solutions = if !simplified_true.is_false() {
937 vec![WorldSolution {
938 world: world.clone(),
939 constraint: simplified_true,
940 outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(true))),
941 }]
942 } else {
943 vec![]
944 };
945
946 let false_constraint = base_constraint.and(expr_constraint.not());
948 let simplified_false = false_constraint.simplify()?;
949
950 let false_solutions = if !simplified_false.is_false() {
951 vec![WorldSolution {
952 world,
953 constraint: simplified_false,
954 outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(false))),
955 }]
956 } else {
957 vec![]
958 };
959
960 Ok((true_solutions, false_solutions))
961}
962
963pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
965 match &expr.kind {
966 ExpressionKind::Literal(_) => Some(expr.clone()),
967 ExpressionKind::Arithmetic(left, op, right) => {
968 let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
969 let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
970 if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
971 (&left_folded.kind, &right_folded.kind)
972 {
973 if let Some(result) = evaluate_arithmetic(left_val.as_ref(), op, right_val.as_ref())
974 {
975 return Some(Expression::with_source(
976 ExpressionKind::Literal(Box::new(result)),
977 expr.source_location.clone(),
978 ));
979 }
980 }
981 Some(Expression::with_source(
982 ExpressionKind::Arithmetic(
983 Arc::new(left_folded),
984 op.clone(),
985 Arc::new(right_folded),
986 ),
987 expr.source_location.clone(),
988 ))
989 }
990 ExpressionKind::Comparison(left, op, right) => {
991 let left_folded = try_constant_fold_expression(left).unwrap_or((**left).clone());
992 let right_folded = try_constant_fold_expression(right).unwrap_or((**right).clone());
993 if let (ExpressionKind::Literal(ref left_val), ExpressionKind::Literal(ref right_val)) =
994 (&left_folded.kind, &right_folded.kind)
995 {
996 if let Some(result) = evaluate_comparison(left_val.as_ref(), op, right_val.as_ref())
997 {
998 return Some(Expression::with_source(
999 ExpressionKind::Literal(Box::new(LiteralValue::from_bool(result))),
1000 expr.source_location.clone(),
1001 ));
1002 }
1003 }
1004 Some(Expression::with_source(
1005 ExpressionKind::Comparison(
1006 Arc::new(left_folded),
1007 op.clone(),
1008 Arc::new(right_folded),
1009 ),
1010 expr.source_location.clone(),
1011 ))
1012 }
1013 _ => None,
1014 }
1015}
1016
1017fn evaluate_arithmetic(
1021 left: &LiteralValue,
1022 op: &ArithmeticComputation,
1023 right: &LiteralValue,
1024) -> Option<LiteralValue> {
1025 use crate::computation::arithmetic_operation;
1026
1027 match arithmetic_operation(left, op, right) {
1028 OperationResult::Value(lit) => Some(lit.as_ref().clone()),
1029 OperationResult::Veto(_) => None,
1030 }
1031}
1032
1033fn evaluate_comparison(
1037 left: &LiteralValue,
1038 op: &ComparisonComputation,
1039 right: &LiteralValue,
1040) -> Option<bool> {
1041 use crate::computation::comparison_operation;
1042 use crate::planning::semantics::ValueKind;
1043
1044 match comparison_operation(left, op, right) {
1045 OperationResult::Value(lit) => match &lit.value {
1046 ValueKind::Boolean(b) => Some(*b),
1047 _ => None,
1048 },
1049 _ => None,
1050 }
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planning::semantics::ValueKind;
    use rust_decimal::Decimal;

    /// Wraps a literal value in an expression without a source location.
    fn literal_expr(val: LiteralValue) -> Expression {
        Expression::with_source(ExpressionKind::Literal(Box::new(val)), None)
    }

    /// Builds a reference to the top-level fact `name`.
    fn fact_expr(name: &str) -> Expression {
        let path = FactPath::new(vec![], name.to_string());
        Expression::with_source(ExpressionKind::FactPath(path), None)
    }

    /// Shorthand for an integer number literal.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// An execution plan with no facts, rules or sources.
    fn empty_plan() -> ExecutionPlan {
        ExecutionPlan {
            doc_name: "test".to_string(),
            facts: HashMap::new(),
            rules: Vec::new(),
            sources: HashMap::new(),
        }
    }

    #[test]
    fn test_world_new() {
        assert!(World::new().0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let rule_path = RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        };
        let mut world = World::new();
        world.insert(rule_path.clone(), 2);
        assert_eq!(world.get(&rule_path), Some(&2));
    }

    #[test]
    fn test_hydrate_literal_unchanged() {
        let plan = empty_plan();
        let provided = HashSet::<FactPath>::new();

        let input = Arc::new(literal_expr(num(42)));
        let result = hydrate_facts_in_expression(&input, &plan, &provided).unwrap();

        match &result.kind {
            ExpressionKind::Literal(lit) => {
                assert!(matches!(&lit.value, ValueKind::Number(_)));
            }
            _ => panic!("Expected literal number"),
        }
    }

    #[test]
    fn test_hydrate_fact_not_provided() {
        let plan = empty_plan();
        let provided = HashSet::<FactPath>::new();

        let input = Arc::new(fact_expr("age"));
        let result = hydrate_facts_in_expression(&input, &plan, &provided).unwrap();

        assert!(matches!(result.kind, ExpressionKind::FactPath(_)));
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let expr = Expression::with_source(
            ExpressionKind::Arithmetic(
                Arc::new(literal_expr(num(10))),
                ArithmeticComputation::Add,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Number(n) => assert_eq!(*n, Decimal::from(15)),
                _ => panic!("Expected literal number"),
            },
            _ => panic!("Expected literal number"),
        }
    }

    #[test]
    fn test_constant_fold_comparison() {
        let expr = Expression::with_source(
            ExpressionKind::Comparison(
                Arc::new(literal_expr(num(10))),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );

        let folded = try_constant_fold_expression(&expr).unwrap();

        match &folded.kind {
            ExpressionKind::Literal(lit) => match &lit.value {
                ValueKind::Boolean(b) => assert!(*b),
                _ => panic!("Expected literal boolean"),
            },
            _ => panic!("Expected literal boolean"),
        }
    }
}