1mod constraint;
10mod derived;
11mod domain;
12mod solve;
13mod target;
14mod world;
15
16pub use derived::{DerivedExpression, DerivedExpressionKind};
17pub use domain::{extract_domains_from_constraint, Bound, Domain};
18pub use target::{Target, TargetOp};
19pub use world::World;
20
21use crate::planning::semantics::{Expression, FactPath, LiteralValue, ValueKind};
22use crate::planning::ExecutionPlan;
23use crate::{Error, OperationResult};
24use serde::ser::{Serialize, SerializeStruct, Serializer};
25use std::collections::{HashMap, HashSet};
26
27use world::{WorldEnumerator, WorldSolution};
28
29#[derive(Debug, Clone, serde::Serialize)]
38pub struct Solution {
39 pub outcome: OperationResult,
41 pub world: World,
43 #[serde(skip_serializing_if = "Option::is_none")]
46 pub shape: Option<Expression>,
47}
48
49#[derive(Debug, Clone)]
51pub struct InversionResponse {
52 pub solutions: Vec<Solution>,
54 pub domains: Vec<HashMap<FactPath, Domain>>,
56 pub undetermined_facts: Vec<FactPath>,
58 pub is_determined: bool,
60}
61
62impl InversionResponse {
63 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<FactPath, Domain>>) -> Self {
65 let undetermined_facts = compute_undetermined_facts(&domains);
66 let is_determined = compute_is_determined(&domains);
67 Self {
68 solutions,
69 domains,
70 undetermined_facts,
71 is_determined,
72 }
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.solutions.is_empty()
78 }
79
80 pub fn len(&self) -> usize {
82 self.solutions.len()
83 }
84
85 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<FactPath, Domain>)> {
87 self.solutions.iter().zip(self.domains.iter())
88 }
89}
90
91impl Serialize for InversionResponse {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: Serializer,
95 {
96 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
97 state.serialize_field("solutions", &self.solutions)?;
98
99 let domains_serializable: Vec<HashMap<String, String>> = self
100 .domains
101 .iter()
102 .map(|d| {
103 d.iter()
104 .map(|(k, v)| (k.to_string(), v.to_string()))
105 .collect()
106 })
107 .collect();
108 state.serialize_field("domains", &domains_serializable)?;
109
110 let undetermined_serializable: Vec<String> = self
111 .undetermined_facts
112 .iter()
113 .map(|fp| fp.to_string())
114 .collect();
115 state.serialize_field("undetermined_facts", &undetermined_serializable)?;
116 state.serialize_field("is_determined", &self.is_determined)?;
117 state.end()
118 }
119}
120
121pub fn invert(
136 rule_name: &str,
137 target: Target,
138 plan: &ExecutionPlan,
139 provided_facts: &HashSet<FactPath>,
140) -> Result<InversionResponse, Error> {
141 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
142 Error::request(
143 format!("Rule not found: {}.{}", plan.spec_name, rule_name),
144 None::<String>,
145 )
146 })?;
147
148 let rule_path = executable_rule.path.clone();
149
150 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
152 let enumeration_result = enumerator.enumerate(provided_facts)?;
153
154 let mut solutions = Vec::new();
156 let mut all_domains = Vec::new();
157
158 let filtered_literal_solutions =
160 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
161
162 for world_solution in filtered_literal_solutions {
163 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
164
165 let solution = Solution {
166 outcome: world_solution.outcome,
167 world: world_solution.world,
168 shape: None,
169 };
170
171 solutions.push(solution);
172 all_domains.push(constraint_domains);
173 }
174
175 if let Some(OperationResult::Value(target_value)) = &target.outcome {
177 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
179 let algebraic_solutions = solve::solve_arithmetic_batch(
180 enumeration_result.arithmetic_solutions.clone(),
181 target_value,
182 provided_facts,
183 );
184
185 let indices: std::collections::HashSet<usize> = algebraic_solutions
187 .iter()
188 .filter_map(|(ws, _, _)| {
189 enumeration_result
190 .arithmetic_solutions
191 .iter()
192 .position(|orig| orig.world == ws.world)
193 })
194 .collect();
195
196 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
198 let constraint_domains =
199 extract_domains_from_constraint(&world_solution.constraint)?;
200
201 let mut is_valid = true;
203 for (fact_path, solved_domain) in &solved_domains {
204 if let Some(constraint_domain) = constraint_domains.get(fact_path) {
205 if let Domain::Enumeration(values) = solved_domain {
207 for value in values.iter() {
208 if !constraint_domain.contains(value) {
209 is_valid = false;
210 break;
211 }
212 }
213 }
214 }
215 if !is_valid {
216 break;
217 }
218 }
219
220 if !is_valid {
221 continue; }
223
224 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
225
226 let mut combined_domains = constraint_domains;
227 for (fact_path, domain) in solved_domains {
228 combined_domains.insert(fact_path, domain);
229 }
230
231 let solution = Solution {
232 outcome: solved_outcome_result,
233 world: world_solution.world,
234 shape: None,
235 };
236
237 solutions.push(solution);
238 all_domains.push(combined_domains);
239 }
240
241 indices
242 } else {
243 std::collections::HashSet::new()
244 };
245
246 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
249 if solved_indices.contains(&idx) {
250 continue; }
252
253 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
255
256 let unknown_facts =
258 extract_fact_paths_from_expression(&arith_solution.outcome_expression);
259 for fact_path in unknown_facts {
260 if !combined_domains.contains_key(&fact_path)
262 && !provided_facts.contains(&fact_path)
263 {
264 combined_domains.insert(fact_path, Domain::Unconstrained);
265 }
266 }
267
268 let solution = Solution {
269 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
270 world: arith_solution.world.clone(),
271 shape: Some(arith_solution.outcome_expression.clone()),
272 };
273
274 solutions.push(solution);
275 all_domains.push(combined_domains);
276 }
277 }
278
279 Ok(InversionResponse::new(solutions, all_domains))
280}
281
282fn filter_literal_solutions_by_target(
288 solutions: Vec<WorldSolution>,
289 target: &Target,
290) -> Vec<WorldSolution> {
291 let mut filtered = Vec::new();
292
293 for solution in solutions {
294 let matches = match (&target.outcome, &solution.outcome) {
295 (None, _) => {
296 true
298 }
299 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
300 match target.op {
305 TargetOp::Eq => outcome_value.value == target_value.value,
306 TargetOp::Neq => outcome_value.value != target_value.value,
307 TargetOp::Lt => {
308 compare_values(outcome_value, target_value)
309 == Some(std::cmp::Ordering::Less)
310 }
311 TargetOp::Lte => {
312 let cmp = compare_values(outcome_value, target_value);
313 cmp == Some(std::cmp::Ordering::Less)
314 || cmp == Some(std::cmp::Ordering::Equal)
315 }
316 TargetOp::Gt => {
317 compare_values(outcome_value, target_value)
318 == Some(std::cmp::Ordering::Greater)
319 }
320 TargetOp::Gte => {
321 let cmp = compare_values(outcome_value, target_value);
322 cmp == Some(std::cmp::Ordering::Greater)
323 || cmp == Some(std::cmp::Ordering::Equal)
324 }
325 }
326 }
327 (Some(OperationResult::Veto(target_msg)), OperationResult::Veto(outcome_msg)) => {
328 match target_msg {
330 None => true, Some(t_msg) => outcome_msg.as_ref().map(|m| m == t_msg).unwrap_or(false),
332 }
333 }
334 _ => false, };
336
337 if matches {
338 filtered.push(solution);
339 }
340 }
341
342 filtered
343}
344
345fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
347 match (&a.value, &b.value) {
348 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
349 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
350 (ValueKind::Scale(a_val, _), ValueKind::Scale(b_val, _)) => Some(a_val.cmp(b_val)),
351 (ValueKind::Duration(a_val, unit_a), ValueKind::Duration(b_val, unit_b)) => {
352 if unit_a == unit_b {
353 Some(a_val.cmp(b_val))
354 } else {
355 None
356 }
357 }
358 _ => None,
359 }
360}
361
362fn extract_fact_paths_from_expression(expr: &Expression) -> Vec<FactPath> {
364 let mut set = std::collections::HashSet::new();
365 expr.collect_fact_paths(&mut set);
366 set.into_iter().collect()
367}
368
369fn compute_undetermined_facts(all_domains: &[HashMap<FactPath, Domain>]) -> Vec<FactPath> {
371 let mut undetermined: HashSet<FactPath> = HashSet::new();
372
373 for solution_domains in all_domains {
374 for (fact_path, domain) in solution_domains {
375 let is_determined = matches!(
376 domain,
377 Domain::Enumeration(values) if values.len() == 1
378 );
379 if !is_determined {
380 undetermined.insert(fact_path.clone());
381 }
382 }
383 }
384
385 let mut result: Vec<FactPath> = undetermined.into_iter().collect();
386 result.sort_by_key(|a| a.to_string());
387 result
388}
389
390fn compute_is_determined(all_domains: &[HashMap<FactPath, Domain>]) -> bool {
392 if all_domains.is_empty() {
393 return true;
394 }
395
396 for solution_domains in all_domains {
397 for domain in solution_domains.values() {
398 let is_single_value = matches!(
399 domain,
400 Domain::Enumeration(values) if values.len() == 1
401 );
402 if !is_single_value {
403 return false;
404 }
405 }
406 }
407
408 true
409}
410
#[cfg(test)]
mod tests {
    // Tests for target formatting, the determinacy helpers used by
    // `InversionResponse::new`, and end-to-end inversion through `Engine::invert`.
    use super::*;
    use crate::parsing::ast::DateTimeValue;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Adds one Lemma source string to `engine`, registered under the file name `source`.
    fn add_lemma_code_blocking(
        engine: &mut Engine,
        code: &str,
        source: &str,
    ) -> Result<(), Vec<crate::Error>> {
        let files: std::collections::HashMap<String, String> =
            std::iter::once((source.to_string(), code.to_string())).collect();
        engine.add_lemma_files(files)
    }

    // An equality value target formats as "= <value>".
    #[test]
    fn test_format_target_eq() {
        let target = Target::value(LiteralValue::number(Decimal::from(42)));
        let formatted = target.format();
        assert_eq!(formatted, "= 42");
    }

    // The "any value" target formats as "= any".
    #[test]
    fn test_format_target_any() {
        let target = Target::any_value();
        let formatted = target.format();
        assert_eq!(formatted, "= any");
    }

    // No solutions means no undetermined facts.
    #[test]
    fn test_compute_undetermined_facts_empty() {
        let domains: Vec<HashMap<FactPath, Domain>> = vec![];
        let undetermined = compute_undetermined_facts(&domains);
        assert!(undetermined.is_empty());
    }

    // A single-value enumeration counts as determined.
    #[test]
    fn test_compute_undetermined_facts_single_value() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number(Decimal::from(25))])),
        );
        let domains = vec![domain_map];
        let undetermined = compute_undetermined_facts(&domains);
        assert!(undetermined.is_empty());
    }

    // A range domain leaves the fact undetermined.
    #[test]
    fn test_compute_undetermined_facts_range() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number(Decimal::from(18)))),
                max: Bound::Unbounded,
            },
        );
        let domains = vec![domain_map];
        let undetermined = compute_undetermined_facts(&domains);
        assert_eq!(undetermined.len(), 1);
    }

    // With no solutions the response is considered determined.
    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<FactPath, Domain>> = vec![];
        assert!(compute_is_determined(&domains));
    }

    // All facts pinned to single values -> determined.
    #[test]
    fn test_compute_is_determined_true() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number(Decimal::from(25))])),
        );
        let domains = vec![domain_map];
        assert!(compute_is_determined(&domains));
    }

    // Any range domain -> not determined.
    #[test]
    fn test_compute_is_determined_false() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            FactPath::new(vec![], "age".to_string()),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number(Decimal::from(18)))),
                max: Bound::Unbounded,
            },
        );
        let domains = vec![domain_map];
        assert!(!compute_is_determined(&domains));
    }

    // Inverting `another` (which references `base`) for value 3 must yield
    // solutions whose domain for `x` admits 3, i.e. `base`'s constraints are
    // taken into account.
    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();
        let now = DateTimeValue::now();

        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::value(LiteralValue::number(3.into())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = FactPath::new(vec![], "x".to_string());
        let three = LiteralValue::number(3.into());

        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    // `base` vetoes every x > 3, so `another` can never produce the value 7:
    // inversion must return no solutions.
    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();
        let now = DateTimeValue::now();

        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::value(LiteralValue::number(7.into())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    // Targeting the specific veto "way too much" must constrain x to the open
    // range (5, +inf), and every solution must carry exactly that veto.
    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();
        let now = DateTimeValue::now();

        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::veto(Some("way too much".to_string())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = FactPath::new(vec![], "x".to_string());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(Some("way too much".to_string())),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            // Expected domain shape: exclusive lower bound 5, no upper bound.
            let d = domains.get(&x).expect("domain for x should exist");
            match d {
                Domain::Range { min, max } => {
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5), got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf, got: {}",
                        d
                    );
                }
                other => panic!("Expected range domain for x, got: {}", other),
            }
            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    // An any-veto target must return each of the three vetoes with its own x
    // range:
    //   "too little"   -> (-inf, 0)
    //   "too much"     -> (3, 5]
    //   "way too much" -> (5, +inf)
    // No solution may admit a non-vetoed value such as x = 2.
    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let code = r#"
spec example
fact x: [number]
rule base: x
 unless x > 3 then veto "too much"
 unless x < 0 then veto "too little"

rule another: base
 unless x > 5 then veto "way too much"
"#;

        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, code, "test.lemma").unwrap();

        let now = DateTimeValue::now();
        let inv = engine
            .invert(
                "example",
                &now,
                "another",
                Target::any_veto(),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = FactPath::new(vec![], "x".to_string());
        let minus_one = LiteralValue::number((-1).into());
        let zero = LiteralValue::number(0.into());
        let two = LiteralValue::number(2.into());
        let three = LiteralValue::number(3.into());
        let four = LiteralValue::number(4.into());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        // Track which veto messages appeared; each must be seen at least once.
        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(Some(msg)) if msg == "too little" => {
                    saw_too_little = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Unbounded),
                                "Expected min bound to be -inf for 'too little', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                                "Expected max bound to be (0) for 'too little', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(msg)) if msg == "too much" => {
                    saw_too_much = true;

                    // Upper bound is inclusive 5: above 5 the later
                    // "way too much" veto applies instead.
                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                                "Expected min bound to be (3) for 'too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                                "Expected max bound to be [5] for 'too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(msg)) if msg == "way too much" => {
                    saw_way_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                                "Expected min bound to be (5) for 'way too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Unbounded),
                                "Expected max bound to be +inf for 'way too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(other)) => {
                    panic!("Unexpected veto message in any-veto results: {:?}", other)
                }
                OperationResult::Veto(None) => {
                    panic!("Unexpected veto(None) in any-veto results (expected a message)")
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}