1mod constraint;
10mod derived;
11mod domain;
12mod solve;
13mod target;
14mod world;
15
16pub use derived::{DerivedExpression, DerivedExpressionKind};
17pub use domain::{extract_domains_from_constraint, Bound, Domain};
18pub use target::{Target, TargetOp};
19pub use world::World;
20
21use crate::planning::semantics::{Expression, FactPath, LiteralValue, ValueKind};
22use crate::planning::ExecutionPlan;
23use crate::{Error, OperationResult};
24use serde::ser::{Serialize, SerializeStruct, Serializer};
25use std::collections::{HashMap, HashSet};
26
27use world::{WorldEnumerator, WorldSolution};
28
/// One answer produced by [`invert`]: an outcome together with the world in
/// which that outcome is reached.
#[derive(Debug, Clone, serde::Serialize)]
pub struct Solution {
    /// Outcome this solution yields (a value or a veto).
    pub outcome: OperationResult,
    /// World this solution belongs to.
    pub world: World,
    /// Outcome expression of an arithmetic solution that `invert` reports
    /// without an algebraic solve; `None` for every other solution.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub shape: Option<Expression>,
}
48
/// Result of an inversion: the solutions found and, for each of them, the
/// domains of the facts involved.
///
/// The field descriptions below hold for values built through
/// `InversionResponse::new`.
#[derive(Debug, Clone)]
pub struct InversionResponse {
    /// Solutions found by the inversion.
    pub solutions: Vec<Solution>,
    /// Fact domains per solution; `iter` pairs entry `i` with `solutions[i]`.
    pub domains: Vec<HashMap<FactPath, Domain>>,
    /// Facts whose domain, in at least one solution, is anything other than a
    /// single-value enumeration. Sorted by their string form.
    pub undetermined_facts: Vec<FactPath>,
    /// `true` when every domain of every solution is a single-value
    /// enumeration; also `true` when `domains` is empty.
    pub is_determined: bool,
}
61
62impl InversionResponse {
63 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<FactPath, Domain>>) -> Self {
65 let undetermined_facts = compute_undetermined_facts(&domains);
66 let is_determined = compute_is_determined(&domains);
67 Self {
68 solutions,
69 domains,
70 undetermined_facts,
71 is_determined,
72 }
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.solutions.is_empty()
78 }
79
80 pub fn len(&self) -> usize {
82 self.solutions.len()
83 }
84
85 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<FactPath, Domain>)> {
87 self.solutions.iter().zip(self.domains.iter())
88 }
89}
90
91impl Serialize for InversionResponse {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: Serializer,
95 {
96 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
97 state.serialize_field("solutions", &self.solutions)?;
98
99 let domains_serializable: Vec<HashMap<String, String>> = self
100 .domains
101 .iter()
102 .map(|d| {
103 d.iter()
104 .map(|(k, v)| (k.to_string(), v.to_string()))
105 .collect()
106 })
107 .collect();
108 state.serialize_field("domains", &domains_serializable)?;
109
110 let undetermined_serializable: Vec<String> = self
111 .undetermined_facts
112 .iter()
113 .map(|fp| fp.to_string())
114 .collect();
115 state.serialize_field("undetermined_facts", &undetermined_serializable)?;
116 state.serialize_field("is_determined", &self.is_determined)?;
117 state.end()
118 }
119}
120
121pub fn invert(
136 rule_name: &str,
137 target: Target,
138 plan: &ExecutionPlan,
139 provided_facts: &HashSet<FactPath>,
140) -> Result<InversionResponse, Error> {
141 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
142 Error::validation(
143 format!("Rule not found: {}.{}", plan.spec_name, rule_name),
144 None,
145 None::<String>,
146 )
147 })?;
148
149 let rule_path = executable_rule.path.clone();
150
151 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
153 let enumeration_result = enumerator.enumerate(provided_facts)?;
154
155 let mut solutions = Vec::new();
157 let mut all_domains = Vec::new();
158
159 let filtered_literal_solutions =
161 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
162
163 for world_solution in filtered_literal_solutions {
164 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
165
166 let solution = Solution {
167 outcome: world_solution.outcome,
168 world: world_solution.world,
169 shape: None,
170 };
171
172 solutions.push(solution);
173 all_domains.push(constraint_domains);
174 }
175
176 if let Some(OperationResult::Value(target_value)) = &target.outcome {
178 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
180 let algebraic_solutions = solve::solve_arithmetic_batch(
181 enumeration_result.arithmetic_solutions.clone(),
182 target_value,
183 provided_facts,
184 );
185
186 let indices: std::collections::HashSet<usize> = algebraic_solutions
188 .iter()
189 .filter_map(|(ws, _, _)| {
190 enumeration_result
191 .arithmetic_solutions
192 .iter()
193 .position(|orig| orig.world == ws.world)
194 })
195 .collect();
196
197 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
199 let constraint_domains =
200 extract_domains_from_constraint(&world_solution.constraint)?;
201
202 let mut is_valid = true;
204 for (fact_path, solved_domain) in &solved_domains {
205 if let Some(constraint_domain) = constraint_domains.get(fact_path) {
206 if let Domain::Enumeration(values) = solved_domain {
208 for value in values.iter() {
209 if !constraint_domain.contains(value) {
210 is_valid = false;
211 break;
212 }
213 }
214 }
215 }
216 if !is_valid {
217 break;
218 }
219 }
220
221 if !is_valid {
222 continue; }
224
225 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
226
227 let mut combined_domains = constraint_domains;
228 for (fact_path, domain) in solved_domains {
229 combined_domains.insert(fact_path, domain);
230 }
231
232 let solution = Solution {
233 outcome: solved_outcome_result,
234 world: world_solution.world,
235 shape: None,
236 };
237
238 solutions.push(solution);
239 all_domains.push(combined_domains);
240 }
241
242 indices
243 } else {
244 std::collections::HashSet::new()
245 };
246
247 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
250 if solved_indices.contains(&idx) {
251 continue; }
253
254 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
256
257 let unknown_facts =
259 extract_fact_paths_from_expression(&arith_solution.outcome_expression);
260 for fact_path in unknown_facts {
261 if !combined_domains.contains_key(&fact_path)
263 && !provided_facts.contains(&fact_path)
264 {
265 combined_domains.insert(fact_path, Domain::Unconstrained);
266 }
267 }
268
269 let solution = Solution {
270 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
271 world: arith_solution.world.clone(),
272 shape: Some(arith_solution.outcome_expression.clone()),
273 };
274
275 solutions.push(solution);
276 all_domains.push(combined_domains);
277 }
278 }
279
280 Ok(InversionResponse::new(solutions, all_domains))
281}
282
283fn filter_literal_solutions_by_target(
289 solutions: Vec<WorldSolution>,
290 target: &Target,
291) -> Vec<WorldSolution> {
292 let mut filtered = Vec::new();
293
294 for solution in solutions {
295 let matches = match (&target.outcome, &solution.outcome) {
296 (None, _) => {
297 true
299 }
300 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
301 match target.op {
306 TargetOp::Eq => outcome_value.value == target_value.value,
307 TargetOp::Neq => outcome_value.value != target_value.value,
308 TargetOp::Lt => {
309 compare_values(outcome_value, target_value)
310 == Some(std::cmp::Ordering::Less)
311 }
312 TargetOp::Lte => {
313 let cmp = compare_values(outcome_value, target_value);
314 cmp == Some(std::cmp::Ordering::Less)
315 || cmp == Some(std::cmp::Ordering::Equal)
316 }
317 TargetOp::Gt => {
318 compare_values(outcome_value, target_value)
319 == Some(std::cmp::Ordering::Greater)
320 }
321 TargetOp::Gte => {
322 let cmp = compare_values(outcome_value, target_value);
323 cmp == Some(std::cmp::Ordering::Greater)
324 || cmp == Some(std::cmp::Ordering::Equal)
325 }
326 }
327 }
328 (Some(OperationResult::Veto(target_msg)), OperationResult::Veto(outcome_msg)) => {
329 match target_msg {
331 None => true, Some(t_msg) => outcome_msg.as_ref().map(|m| m == t_msg).unwrap_or(false),
333 }
334 }
335 _ => false, };
337
338 if matches {
339 filtered.push(solution);
340 }
341 }
342
343 filtered
344}
345
346fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
348 match (&a.value, &b.value) {
349 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
350 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
351 (ValueKind::Scale(a_val, _), ValueKind::Scale(b_val, _)) => Some(a_val.cmp(b_val)),
352 (ValueKind::Duration(a_val, unit_a), ValueKind::Duration(b_val, unit_b)) => {
353 if unit_a == unit_b {
354 Some(a_val.cmp(b_val))
355 } else {
356 None
357 }
358 }
359 _ => None,
360 }
361}
362
363fn extract_fact_paths_from_expression(expr: &Expression) -> Vec<FactPath> {
365 let mut set = std::collections::HashSet::new();
366 expr.collect_fact_paths(&mut set);
367 set.into_iter().collect()
368}
369
370fn compute_undetermined_facts(all_domains: &[HashMap<FactPath, Domain>]) -> Vec<FactPath> {
372 let mut undetermined: HashSet<FactPath> = HashSet::new();
373
374 for solution_domains in all_domains {
375 for (fact_path, domain) in solution_domains {
376 let is_determined = matches!(
377 domain,
378 Domain::Enumeration(values) if values.len() == 1
379 );
380 if !is_determined {
381 undetermined.insert(fact_path.clone());
382 }
383 }
384 }
385
386 let mut result: Vec<FactPath> = undetermined.into_iter().collect();
387 result.sort_by_key(|a| a.to_string());
388 result
389}
390
391fn compute_is_determined(all_domains: &[HashMap<FactPath, Domain>]) -> bool {
393 if all_domains.is_empty() {
394 return true;
395 }
396
397 for solution_domains in all_domains {
398 for domain in solution_domains.values() {
399 let is_single_value = matches!(
400 domain,
401 Domain::Enumeration(values) if values.len() == 1
402 );
403 if !is_single_value {
404 return false;
405 }
406 }
407 }
408
409 true
410}
411
// Tests for target formatting, for the `compute_*` domain summaries, and for
// end-to-end inversion driven through `Engine`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::DateTimeValue;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Registers one source file with `engine`: `source` is used as the file
    // name and `code` as its contents.
    fn add_lemma_code_blocking(
        engine: &mut Engine,
        code: &str,
        source: &str,
    ) -> Result<(), Vec<crate::Error>> {
        let files: std::collections::HashMap<String, String> =
            std::iter::once((source.to_string(), code.to_string())).collect();
        engine.add_lemma_files(files)
    }

    // A concrete value target is rendered as "= <value>".
    #[test]
    fn test_format_target_eq() {
        let target = Target::value(LiteralValue::number(Decimal::from(42)));
        let formatted = target.format();
        assert_eq!(formatted, "= 42");
    }

    // An any-value target is rendered as "= any".
    #[test]
    fn test_format_target_any() {
        let target = Target::any_value();
        let formatted = target.format();
        assert_eq!(formatted, "= any");
    }

    // No domain maps at all: nothing can be undetermined.
    #[test]
    fn test_compute_undetermined_facts_empty() {
        let domains: Vec<HashMap<FactPath, Domain>> = vec![];
        let undetermined = compute_undetermined_facts(&domains);
        assert!(undetermined.is_empty());
    }

    // A single-value enumeration pins the fact, so it is not reported.
    #[test]
    fn test_compute_undetermined_facts_single_value() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number(Decimal::from(25))])),
        );
        let domains = vec![domain_map];
        let undetermined = compute_undetermined_facts(&domains);
        assert!(undetermined.is_empty());
    }

    // A range leaves the fact open, so it is reported as undetermined.
    #[test]
    fn test_compute_undetermined_facts_range() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number(Decimal::from(18)))),
                max: Bound::Unbounded,
            },
        );
        let domains = vec![domain_map];
        let undetermined = compute_undetermined_facts(&domains);
        assert_eq!(undetermined.len(), 1);
    }

    // No domain maps at all counts as determined.
    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<FactPath, Domain>> = vec![];
        assert!(compute_is_determined(&domains));
    }

    // Only single-value enumerations: determined.
    #[test]
    fn test_compute_is_determined_true() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number(Decimal::from(25))])),
        );
        let domains = vec![domain_map];
        assert!(compute_is_determined(&domains));
    }

    // A range domain makes the result not determined.
    #[test]
    fn test_compute_is_determined_false() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number(Decimal::from(18)))),
                max: Bound::Unbounded,
            },
        );
        let domains = vec![domain_map];
        assert!(!compute_is_determined(&domains));
    }

    // Inverting `another` (which references `base`) for the value 3 must give
    // at least one solution, and x = 3 must lie in every solution's domain.
    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();
        let now = DateTimeValue::now();

        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::value(LiteralValue::number(3.into())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = FactPath::new(vec![], "x".to_string());
        let three = LiteralValue::number(3.into());

        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    // The value 7 is expected to be unreachable for `another` (both x > 3 and
    // x > 5 lead to a veto), so the inversion must return no solutions.
    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();
        let now = DateTimeValue::now();

        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::value(LiteralValue::number(7.into())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    // Targeting the veto "way too much" must yield only solutions with that
    // outcome, each with the x domain (5, +inf): 5 excluded, 6 included.
    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();
        let now = DateTimeValue::now();

        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::veto(Some("way too much".to_string())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = FactPath::new(vec![], "x".to_string());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(Some("way too much".to_string())),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            match d {
                Domain::Range { min, max } => {
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5), got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf, got: {}",
                        d
                    );
                }
                other => panic!("Expected range domain for x, got: {}", other),
            }
            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    // An any-veto target must return all three veto outcomes, each with its
    // own x range, and no value outcome:
    //   "too little"   -> (-inf, 0)
    //   "too much"     -> (3, 5]
    //   "way too much" -> (5, +inf)
    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();

        let now = DateTimeValue::now();
        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::any_veto(),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = FactPath::new(vec![], "x".to_string());
        let minus_one = LiteralValue::number((-1).into());
        let zero = LiteralValue::number(0.into());
        let two = LiteralValue::number(2.into());
        let three = LiteralValue::number(3.into());
        let four = LiteralValue::number(4.into());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        // Track which veto messages were seen so each can be required below.
        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            // x = 2 produces a value, so it must not appear in any veto domain.
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(Some(msg)) if msg == "too little" => {
                    saw_too_little = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Unbounded),
                                "Expected min bound to be -inf for 'too little', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                                "Expected max bound to be (0) for 'too little', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(msg)) if msg == "too much" => {
                    saw_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                                "Expected min bound to be (3) for 'too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                                "Expected max bound to be [5] for 'too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(msg)) if msg == "way too much" => {
                    saw_way_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                                "Expected min bound to be (5) for 'way too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Unbounded),
                                "Expected max bound to be +inf for 'way too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                // Anything else is a failure: unknown message, missing
                // message, or a non-veto outcome.
                OperationResult::Veto(Some(other)) => {
                    panic!("Unexpected veto message in any-veto results: {:?}", other)
                }
                OperationResult::Veto(None) => {
                    panic!("Unexpected veto(None) in any-veto results (expected a message)")
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}