1mod constraint;
10mod domain;
11mod solve;
12mod target;
13mod world;
14
15pub use domain::{extract_domains_from_constraint, Bound, Domain};
16pub use target::{Target, TargetOp};
17pub use world::World;
18
19use crate::parsing::ast::Span;
20use crate::planning::ExecutionPlan;
21use crate::{
22 Expression, ExpressionKind, FactPath, LemmaError, LemmaResult, LiteralValue, OperationResult,
23 Value,
24};
25use serde::ser::{Serialize, SerializeStruct, Serializer};
26use std::collections::{HashMap, HashSet};
27
28use world::{WorldEnumerator, WorldSolution};
29
30#[derive(Debug, Clone, serde::Serialize)]
39pub struct Solution {
40 pub outcome: OperationResult,
42 pub world: World,
44 #[serde(skip_serializing_if = "Option::is_none")]
47 pub shape: Option<Expression>,
48}
49
50#[derive(Debug, Clone)]
52pub struct InversionResponse {
53 pub solutions: Vec<Solution>,
55 pub domains: Vec<HashMap<FactPath, Domain>>,
57 pub undetermined_facts: Vec<FactPath>,
59 pub is_determined: bool,
61}
62
63impl InversionResponse {
64 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<FactPath, Domain>>) -> Self {
66 let undetermined_facts = compute_undetermined_facts(&domains);
67 let is_determined = compute_is_determined(&domains);
68 Self {
69 solutions,
70 domains,
71 undetermined_facts,
72 is_determined,
73 }
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.solutions.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.solutions.len()
84 }
85
86 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<FactPath, Domain>)> {
88 self.solutions.iter().zip(self.domains.iter())
89 }
90
91 pub fn free_variables(&self) -> &[FactPath] {
93 &self.undetermined_facts
94 }
95
96 pub fn is_fully_constrained(&self) -> bool {
98 self.is_determined
99 }
100}
101
102impl Serialize for InversionResponse {
103 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104 where
105 S: Serializer,
106 {
107 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
108 state.serialize_field("solutions", &self.solutions)?;
109
110 let domains_serializable: Vec<HashMap<String, String>> = self
111 .domains
112 .iter()
113 .map(|d| {
114 d.iter()
115 .map(|(k, v)| (k.to_string(), v.to_string()))
116 .collect()
117 })
118 .collect();
119 state.serialize_field("domains", &domains_serializable)?;
120
121 let undetermined_serializable: Vec<String> = self
122 .undetermined_facts
123 .iter()
124 .map(|fp| fp.to_string())
125 .collect();
126 state.serialize_field("undetermined_facts", &undetermined_serializable)?;
127 state.serialize_field("is_determined", &self.is_determined)?;
128 state.end()
129 }
130}
131
132pub fn invert(
147 rule_name: &str,
148 target: Target,
149 plan: &ExecutionPlan,
150 provided_facts: &HashSet<FactPath>,
151) -> LemmaResult<InversionResponse> {
152 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
153 LemmaError::engine(
154 format!("Rule not found: {}.{}", plan.doc_name, rule_name),
155 Span {
156 start: 0,
157 end: 0,
158 line: 1,
159 col: 0,
160 },
161 "<inversion>",
162 std::sync::Arc::from(""),
163 plan.doc_name.clone(),
164 1,
165 None::<String>,
166 )
167 })?;
168
169 let rule_path = executable_rule.path.clone();
170
171 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
173 let enumeration_result = enumerator.enumerate(provided_facts)?;
174
175 let mut solutions = Vec::new();
177 let mut all_domains = Vec::new();
178
179 let filtered_literal_solutions =
181 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
182
183 for world_solution in filtered_literal_solutions {
184 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
185
186 let solution = Solution {
187 outcome: world_solution.outcome,
188 world: world_solution.world,
189 shape: None,
190 };
191
192 solutions.push(solution);
193 all_domains.push(constraint_domains);
194 }
195
196 if let Some(OperationResult::Value(target_value)) = &target.outcome {
198 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
200 let algebraic_solutions = solve::solve_arithmetic_batch(
201 enumeration_result.arithmetic_solutions.clone(),
202 target_value,
203 provided_facts,
204 );
205
206 let indices: std::collections::HashSet<usize> = algebraic_solutions
208 .iter()
209 .map(|(ws, _, _)| {
210 enumeration_result
211 .arithmetic_solutions
212 .iter()
213 .position(|orig| orig.world == ws.world)
214 .unwrap_or(usize::MAX)
215 })
216 .collect();
217
218 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
220 let constraint_domains =
221 extract_domains_from_constraint(&world_solution.constraint)?;
222
223 let mut is_valid = true;
225 for (fact_path, solved_domain) in &solved_domains {
226 if let Some(constraint_domain) = constraint_domains.get(fact_path) {
227 if let Domain::Enumeration(values) = solved_domain {
229 for value in values.iter() {
230 if !constraint_domain.contains(value) {
231 is_valid = false;
232 break;
233 }
234 }
235 }
236 }
237 if !is_valid {
238 break;
239 }
240 }
241
242 if !is_valid {
243 continue; }
245
246 let solved_outcome_result = OperationResult::Value(solved_outcome);
247
248 let mut combined_domains = constraint_domains;
249 for (fact_path, domain) in solved_domains {
250 combined_domains.insert(fact_path, domain);
251 }
252
253 let solution = Solution {
254 outcome: solved_outcome_result,
255 world: world_solution.world,
256 shape: None,
257 };
258
259 solutions.push(solution);
260 all_domains.push(combined_domains);
261 }
262
263 indices
264 } else {
265 std::collections::HashSet::new()
266 };
267
268 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
271 if solved_indices.contains(&idx) {
272 continue; }
274
275 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
277
278 let unknown_facts =
280 extract_fact_paths_from_expression(&arith_solution.outcome_expression);
281 for fact_path in unknown_facts {
282 if !combined_domains.contains_key(&fact_path)
284 && !provided_facts.contains(&fact_path)
285 {
286 combined_domains.insert(fact_path, Domain::Unconstrained);
287 }
288 }
289
290 let solution = Solution {
291 outcome: OperationResult::Value(target_value.clone()),
292 world: arith_solution.world.clone(),
293 shape: Some(arith_solution.outcome_expression.clone()),
294 };
295
296 solutions.push(solution);
297 all_domains.push(combined_domains);
298 }
299 }
300
301 Ok(InversionResponse::new(solutions, all_domains))
302}
303
304fn filter_literal_solutions_by_target(
310 solutions: Vec<WorldSolution>,
311 target: &Target,
312) -> Vec<WorldSolution> {
313 let mut filtered = Vec::new();
314
315 for solution in solutions {
316 let matches = match (&target.outcome, &solution.outcome) {
317 (None, _) => {
318 true
320 }
321 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
322 match target.op {
324 TargetOp::Eq => outcome_value == target_value,
325 TargetOp::Neq => outcome_value != target_value,
326 TargetOp::Lt => {
327 compare_values(outcome_value, target_value)
328 == Some(std::cmp::Ordering::Less)
329 }
330 TargetOp::Lte => {
331 let cmp = compare_values(outcome_value, target_value);
332 cmp == Some(std::cmp::Ordering::Less)
333 || cmp == Some(std::cmp::Ordering::Equal)
334 }
335 TargetOp::Gt => {
336 compare_values(outcome_value, target_value)
337 == Some(std::cmp::Ordering::Greater)
338 }
339 TargetOp::Gte => {
340 let cmp = compare_values(outcome_value, target_value);
341 cmp == Some(std::cmp::Ordering::Greater)
342 || cmp == Some(std::cmp::Ordering::Equal)
343 }
344 }
345 }
346 (Some(OperationResult::Veto(target_msg)), OperationResult::Veto(outcome_msg)) => {
347 match target_msg {
349 None => true, Some(t_msg) => outcome_msg.as_ref().map(|m| m == t_msg).unwrap_or(false),
351 }
352 }
353 _ => false, };
355
356 if matches {
357 filtered.push(solution);
358 }
359 }
360
361 filtered
362}
363
364fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
366 match (&a.value, &b.value) {
367 (Value::Number(a_val), Value::Number(b_val)) => Some(a_val.cmp(b_val)),
368 (Value::Ratio(a_val, _), Value::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
369 (Value::Scale(a_val, _), Value::Scale(b_val, _)) => Some(a_val.cmp(b_val)),
370 (Value::Duration(a_val, unit_a), Value::Duration(b_val, unit_b)) => {
371 if unit_a == unit_b {
372 Some(a_val.cmp(b_val))
373 } else {
374 None
375 }
376 }
377 _ => None,
378 }
379}
380
381fn extract_fact_paths_from_expression(expr: &Expression) -> Vec<FactPath> {
383 let mut paths = Vec::new();
384 collect_fact_paths(expr, &mut paths);
385 paths
386}
387
388fn collect_fact_paths(expr: &Expression, paths: &mut Vec<FactPath>) {
389 match &expr.kind {
390 ExpressionKind::FactPath(fp) => {
391 if !paths.contains(fp) {
392 paths.push(fp.clone());
393 }
394 }
395 ExpressionKind::Arithmetic(left, _, right)
396 | ExpressionKind::Comparison(left, _, right)
397 | ExpressionKind::LogicalAnd(left, right)
398 | ExpressionKind::LogicalOr(left, right) => {
399 collect_fact_paths(left, paths);
400 collect_fact_paths(right, paths);
401 }
402 ExpressionKind::LogicalNegation(inner, _)
403 | ExpressionKind::UnitConversion(inner, _)
404 | ExpressionKind::MathematicalComputation(_, inner) => {
405 collect_fact_paths(inner, paths);
406 }
407 ExpressionKind::Literal(_)
408 | ExpressionKind::Veto(_)
409 | ExpressionKind::Reference(_)
410 | ExpressionKind::UnresolvedUnitLiteral(_, _)
411 | ExpressionKind::FactReference(_)
412 | ExpressionKind::RuleReference(_)
413 | ExpressionKind::RulePath(_) => {}
414 }
415}
416
417fn compute_undetermined_facts(all_domains: &[HashMap<FactPath, Domain>]) -> Vec<FactPath> {
419 let mut undetermined: HashSet<FactPath> = HashSet::new();
420
421 for solution_domains in all_domains {
422 for (fact_path, domain) in solution_domains {
423 let is_determined = matches!(
424 domain,
425 Domain::Enumeration(values) if values.len() == 1
426 );
427 if !is_determined {
428 undetermined.insert(fact_path.clone());
429 }
430 }
431 }
432
433 let mut result: Vec<FactPath> = undetermined.into_iter().collect();
434 result.sort_by_key(|a| a.to_string());
435 result
436}
437
438fn compute_is_determined(all_domains: &[HashMap<FactPath, Domain>]) -> bool {
440 if all_domains.is_empty() {
441 return true;
442 }
443
444 for solution_domains in all_domains {
445 for domain in solution_domains.values() {
446 let is_single_value = matches!(
447 domain,
448 Domain::Enumeration(values) if values.len() == 1
449 );
450 if !is_single_value {
451 return false;
452 }
453 }
454 }
455
456 true
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Lemma document shared by the `test_invert_strict_*` tests.
    const EXAMPLE_DOC: &str = r#"
doc example
fact x = [number]
rule base = x
  unless x > 3 then veto "too much"
  unless x < 0 then veto "too little"

rule another = base?
  unless x > 5 then veto "way too much"
"#;

    /// One solution whose only fact, `age`, has the given domain.
    fn age_domains(domain: Domain) -> Vec<HashMap<FactPath, Domain>> {
        let mut by_fact = HashMap::new();
        by_fact.insert(FactPath::local("age".to_string()), domain);
        vec![by_fact]
    }

    /// Enumeration holding the single value 25.
    fn exactly_25() -> Domain {
        Domain::Enumeration(Arc::new(vec![LiteralValue::number(Decimal::from(25))]))
    }

    /// Open range `(18, +inf)`.
    fn above_18() -> Domain {
        Domain::Range {
            min: Bound::Exclusive(Arc::new(LiteralValue::number(Decimal::from(18)))),
            max: Bound::Unbounded,
        }
    }

    /// Returns the bounds of a range domain; panics for any other domain.
    fn range_bounds(domain: &Domain) -> (&Bound, &Bound) {
        match domain {
            Domain::Range { min, max } => (min, max),
            other => panic!("Expected range domain for x, got: {}", other),
        }
    }

    #[test]
    fn test_format_target_eq() {
        let formatted = Target::value(LiteralValue::number(Decimal::from(42))).format();
        assert_eq!(formatted, "= 42");
    }

    #[test]
    fn test_format_target_any() {
        let formatted = Target::any_value().format();
        assert_eq!(formatted, "= any");
    }

    #[test]
    fn test_compute_undetermined_facts_empty() {
        let no_domains: Vec<HashMap<FactPath, Domain>> = Vec::new();
        assert!(compute_undetermined_facts(&no_domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_facts_single_value() {
        let domains = age_domains(exactly_25());
        assert!(compute_undetermined_facts(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_facts_range() {
        let domains = age_domains(above_18());
        assert_eq!(compute_undetermined_facts(&domains).len(), 1);
    }

    #[test]
    fn test_compute_is_determined_empty() {
        let no_domains: Vec<HashMap<FactPath, Domain>> = Vec::new();
        assert!(compute_is_determined(&no_domains));
    }

    #[test]
    fn test_compute_is_determined_true() {
        let domains = age_domains(exactly_25());
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_false() {
        let domains = age_domains(above_18());
        assert!(!compute_is_determined(&domains));
    }

    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let mut engine = Engine::new();
        engine.add_lemma_code(EXAMPLE_DOC, "test.lemma").unwrap();

        let inv = engine
            .invert(
                "example",
                "another",
                Target::value(LiteralValue::number(3)),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = FactPath::local("x".to_string());
        let three = LiteralValue::number(3);

        // Every solution that yields 3 must allow x = 3.
        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let mut engine = Engine::new();
        engine.add_lemma_code(EXAMPLE_DOC, "test.lemma").unwrap();

        let inv = engine
            .invert(
                "example",
                "another",
                Target::value(LiteralValue::number(7)),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let mut engine = Engine::new();
        engine.add_lemma_code(EXAMPLE_DOC, "test.lemma").unwrap();

        let inv = engine
            .invert(
                "example",
                "another",
                Target::veto(Some("way too much".to_string())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = FactPath::local("x".to_string());
        let five = LiteralValue::number(5);
        let six = LiteralValue::number(6);

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(Some("way too much".to_string())),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            let (min, max) = range_bounds(d);
            assert!(
                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                "Expected min bound to be (5), got: {}",
                d
            );
            assert!(
                matches!(max, Bound::Unbounded),
                "Expected max bound to be +inf, got: {}",
                d
            );

            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let mut engine = Engine::new();
        engine.add_lemma_code(EXAMPLE_DOC, "test.lemma").unwrap();

        let inv = engine
            .invert("example", "another", Target::any_veto(), HashMap::new())
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = FactPath::local("x".to_string());
        let minus_one = LiteralValue::number(-1);
        let zero = LiteralValue::number(0);
        let two = LiteralValue::number(2);
        let three = LiteralValue::number(3);
        let four = LiteralValue::number(4);
        let five = LiteralValue::number(5);
        let six = LiteralValue::number(6);

        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            // Only vetoes carrying a message are acceptable here.
            let message = match &solution.outcome {
                OperationResult::Veto(Some(msg)) => msg.as_str(),
                OperationResult::Veto(None) => {
                    panic!("Unexpected veto(None) in any-veto results (expected a message)")
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            };

            match message {
                "too little" => {
                    saw_too_little = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Unbounded),
                        "Expected min bound to be -inf for 'too little', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                        "Expected max bound to be (0) for 'too little', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                "too much" => {
                    saw_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                        "Expected min bound to be (3) for 'too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                        "Expected max bound to be [5] for 'too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                "way too much" => {
                    saw_way_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5) for 'way too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf for 'way too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                other => {
                    panic!("Unexpected veto message in any-veto results: {:?}", other)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}