1mod constraint;
10mod derived;
11mod domain;
12mod solve;
13mod target;
14mod world;
15
16pub use derived::{DerivedExpression, DerivedExpressionKind};
17pub use domain::{extract_domains_from_constraint, Bound, Domain};
18pub use target::{Target, TargetOp};
19pub use world::World;
20
21use crate::planning::semantics::{Expression, FactPath, LiteralValue, ValueKind};
22use crate::planning::ExecutionPlan;
23use crate::{LemmaError, LemmaResult, OperationResult};
24use serde::ser::{Serialize, SerializeStruct, Serializer};
25use std::collections::{HashMap, HashSet};
26
27use world::{WorldEnumerator, WorldSolution};
28
/// One way the inverted rule can produce an outcome matching the requested target.
#[derive(Debug, Clone, serde::Serialize)]
pub struct Solution {
    /// The rule's outcome in this solution: either a value or a veto.
    pub outcome: OperationResult,
    /// The enumerated world in which `outcome` is produced.
    pub world: World,
    /// The unsolved arithmetic expression for the rule's value, reported when inversion
    /// could not solve that world for a concrete value. `None` for literal outcomes and
    /// algebraically solved ones. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub shape: Option<Expression>,
}
48
/// Result of inverting a rule: the solutions found plus the fact domains behind each one.
///
/// `solutions` and `domains` are parallel vectors: `domains[i]` holds the fact domains
/// under which `solutions[i]` is reached (see [`InversionResponse::iter`]).
#[derive(Debug, Clone)]
pub struct InversionResponse {
    /// Solutions found by the inversion.
    pub solutions: Vec<Solution>,
    /// Per-solution fact domains, index-aligned with `solutions`.
    pub domains: Vec<HashMap<FactPath, Domain>>,
    /// As computed by `InversionResponse::new`: facts whose domain is not a single-value
    /// enumeration in at least one solution, de-duplicated and sorted by their string form.
    pub undetermined_facts: Vec<FactPath>,
    /// As computed by `InversionResponse::new`: `true` when every fact domain in every
    /// solution is a single-value enumeration (vacuously `true` with no solutions).
    pub is_determined: bool,
}
61
62impl InversionResponse {
63 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<FactPath, Domain>>) -> Self {
65 let undetermined_facts = compute_undetermined_facts(&domains);
66 let is_determined = compute_is_determined(&domains);
67 Self {
68 solutions,
69 domains,
70 undetermined_facts,
71 is_determined,
72 }
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.solutions.is_empty()
78 }
79
80 pub fn len(&self) -> usize {
82 self.solutions.len()
83 }
84
85 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<FactPath, Domain>)> {
87 self.solutions.iter().zip(self.domains.iter())
88 }
89}
90
91impl Serialize for InversionResponse {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: Serializer,
95 {
96 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
97 state.serialize_field("solutions", &self.solutions)?;
98
99 let domains_serializable: Vec<HashMap<String, String>> = self
100 .domains
101 .iter()
102 .map(|d| {
103 d.iter()
104 .map(|(k, v)| (k.to_string(), v.to_string()))
105 .collect()
106 })
107 .collect();
108 state.serialize_field("domains", &domains_serializable)?;
109
110 let undetermined_serializable: Vec<String> = self
111 .undetermined_facts
112 .iter()
113 .map(|fp| fp.to_string())
114 .collect();
115 state.serialize_field("undetermined_facts", &undetermined_serializable)?;
116 state.serialize_field("is_determined", &self.is_determined)?;
117 state.end()
118 }
119}
120
121pub fn invert(
136 rule_name: &str,
137 target: Target,
138 plan: &ExecutionPlan,
139 provided_facts: &HashSet<FactPath>,
140) -> LemmaResult<InversionResponse> {
141 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
142 LemmaError::engine(
143 format!("Rule not found: {}.{}", plan.doc_name, rule_name),
144 None,
145 None::<String>,
146 )
147 })?;
148
149 let rule_path = executable_rule.path.clone();
150
151 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
153 let enumeration_result = enumerator.enumerate(provided_facts)?;
154
155 let mut solutions = Vec::new();
157 let mut all_domains = Vec::new();
158
159 let filtered_literal_solutions =
161 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
162
163 for world_solution in filtered_literal_solutions {
164 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
165
166 let solution = Solution {
167 outcome: world_solution.outcome,
168 world: world_solution.world,
169 shape: None,
170 };
171
172 solutions.push(solution);
173 all_domains.push(constraint_domains);
174 }
175
176 if let Some(OperationResult::Value(target_value)) = &target.outcome {
178 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
180 let algebraic_solutions = solve::solve_arithmetic_batch(
181 enumeration_result.arithmetic_solutions.clone(),
182 target_value,
183 provided_facts,
184 );
185
186 let indices: std::collections::HashSet<usize> = algebraic_solutions
188 .iter()
189 .map(|(ws, _, _)| {
190 enumeration_result
191 .arithmetic_solutions
192 .iter()
193 .position(|orig| orig.world == ws.world)
194 .unwrap_or(usize::MAX)
195 })
196 .collect();
197
198 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
200 let constraint_domains =
201 extract_domains_from_constraint(&world_solution.constraint)?;
202
203 let mut is_valid = true;
205 for (fact_path, solved_domain) in &solved_domains {
206 if let Some(constraint_domain) = constraint_domains.get(fact_path) {
207 if let Domain::Enumeration(values) = solved_domain {
209 for value in values.iter() {
210 if !constraint_domain.contains(value) {
211 is_valid = false;
212 break;
213 }
214 }
215 }
216 }
217 if !is_valid {
218 break;
219 }
220 }
221
222 if !is_valid {
223 continue; }
225
226 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
227
228 let mut combined_domains = constraint_domains;
229 for (fact_path, domain) in solved_domains {
230 combined_domains.insert(fact_path, domain);
231 }
232
233 let solution = Solution {
234 outcome: solved_outcome_result,
235 world: world_solution.world,
236 shape: None,
237 };
238
239 solutions.push(solution);
240 all_domains.push(combined_domains);
241 }
242
243 indices
244 } else {
245 std::collections::HashSet::new()
246 };
247
248 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
251 if solved_indices.contains(&idx) {
252 continue; }
254
255 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
257
258 let unknown_facts =
260 extract_fact_paths_from_expression(&arith_solution.outcome_expression);
261 for fact_path in unknown_facts {
262 if !combined_domains.contains_key(&fact_path)
264 && !provided_facts.contains(&fact_path)
265 {
266 combined_domains.insert(fact_path, Domain::Unconstrained);
267 }
268 }
269
270 let solution = Solution {
271 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
272 world: arith_solution.world.clone(),
273 shape: Some(arith_solution.outcome_expression.clone()),
274 };
275
276 solutions.push(solution);
277 all_domains.push(combined_domains);
278 }
279 }
280
281 Ok(InversionResponse::new(solutions, all_domains))
282}
283
284fn filter_literal_solutions_by_target(
290 solutions: Vec<WorldSolution>,
291 target: &Target,
292) -> Vec<WorldSolution> {
293 let mut filtered = Vec::new();
294
295 for solution in solutions {
296 let matches = match (&target.outcome, &solution.outcome) {
297 (None, _) => {
298 true
300 }
301 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
302 match target.op {
307 TargetOp::Eq => outcome_value.value == target_value.value,
308 TargetOp::Neq => outcome_value.value != target_value.value,
309 TargetOp::Lt => {
310 compare_values(outcome_value, target_value)
311 == Some(std::cmp::Ordering::Less)
312 }
313 TargetOp::Lte => {
314 let cmp = compare_values(outcome_value, target_value);
315 cmp == Some(std::cmp::Ordering::Less)
316 || cmp == Some(std::cmp::Ordering::Equal)
317 }
318 TargetOp::Gt => {
319 compare_values(outcome_value, target_value)
320 == Some(std::cmp::Ordering::Greater)
321 }
322 TargetOp::Gte => {
323 let cmp = compare_values(outcome_value, target_value);
324 cmp == Some(std::cmp::Ordering::Greater)
325 || cmp == Some(std::cmp::Ordering::Equal)
326 }
327 }
328 }
329 (Some(OperationResult::Veto(target_msg)), OperationResult::Veto(outcome_msg)) => {
330 match target_msg {
332 None => true, Some(t_msg) => outcome_msg.as_ref().map(|m| m == t_msg).unwrap_or(false),
334 }
335 }
336 _ => false, };
338
339 if matches {
340 filtered.push(solution);
341 }
342 }
343
344 filtered
345}
346
347fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
349 match (&a.value, &b.value) {
350 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
351 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
352 (ValueKind::Scale(a_val, _), ValueKind::Scale(b_val, _)) => Some(a_val.cmp(b_val)),
353 (ValueKind::Duration(a_val, unit_a), ValueKind::Duration(b_val, unit_b)) => {
354 if unit_a == unit_b {
355 Some(a_val.cmp(b_val))
356 } else {
357 None
358 }
359 }
360 _ => None,
361 }
362}
363
364fn extract_fact_paths_from_expression(expr: &Expression) -> Vec<FactPath> {
366 let mut set = std::collections::HashSet::new();
367 expr.collect_fact_paths(&mut set);
368 set.into_iter().collect()
369}
370
371fn compute_undetermined_facts(all_domains: &[HashMap<FactPath, Domain>]) -> Vec<FactPath> {
373 let mut undetermined: HashSet<FactPath> = HashSet::new();
374
375 for solution_domains in all_domains {
376 for (fact_path, domain) in solution_domains {
377 let is_determined = matches!(
378 domain,
379 Domain::Enumeration(values) if values.len() == 1
380 );
381 if !is_determined {
382 undetermined.insert(fact_path.clone());
383 }
384 }
385 }
386
387 let mut result: Vec<FactPath> = undetermined.into_iter().collect();
388 result.sort_by_key(|a| a.to_string());
389 result
390}
391
392fn compute_is_determined(all_domains: &[HashMap<FactPath, Domain>]) -> bool {
394 if all_domains.is_empty() {
395 return true;
396 }
397
398 for solution_domains in all_domains {
399 for domain in solution_domains.values() {
400 let is_single_value = matches!(
401 domain,
402 Domain::Enumeration(values) if values.len() == 1
403 );
404 if !is_single_value {
405 return false;
406 }
407 }
408 }
409
410 true
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Lemma document shared by the `invert` tests.
    ///
    /// - `base` evaluates to `x`.
    ///   - It is vetoed with "too much" when `x > 3`.
    ///   - It is vetoed with "too little" when `x < 0`.
    /// - `another` forwards `base`, and is vetoed with "way too much" when `x > 5`.
    const EXAMPLE_DOC: &str = r#"
doc example
fact x = [number]
rule base = x
    unless x > 3 then veto "too much"
    unless x < 0 then veto "too little"

rule another = base?
    unless x > 5 then veto "way too much"
"#;

    /// Synchronously loads `code` into `engine` under the file name `source`.
    ///
    /// A single load error is returned as-is; several are wrapped in
    /// `LemmaError::MultipleErrors`.
    fn add_lemma_code_blocking(
        engine: &mut Engine,
        code: &str,
        source: &str,
    ) -> crate::LemmaResult<()> {
        let files: std::collections::HashMap<String, String> =
            std::iter::once((source.to_string(), code.to_string())).collect();
        tokio::runtime::Runtime::new()
            .expect("tokio runtime")
            .block_on(engine.add_lemma_files(files))
            .map_err(|errs| match errs.len() {
                0 => unreachable!("add_lemma_files returned Err with empty error list"),
                1 => errs.into_iter().next().unwrap(),
                _ => crate::LemmaError::MultipleErrors(errs),
            })
    }

    /// Builds an engine with [`EXAMPLE_DOC`] loaded as `test.lemma`.
    fn example_engine() -> Engine {
        let mut engine = Engine::new();
        add_lemma_code_blocking(&mut engine, EXAMPLE_DOC, "test.lemma")
            .expect("example document should load");
        engine
    }

    /// Path of the fact `x` declared in [`EXAMPLE_DOC`].
    fn fact_x() -> FactPath {
        FactPath::new(vec![], "x".to_string())
    }

    /// Shorthand for a numeric literal.
    fn number(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Single-value enumeration domain holding `n`.
    fn single_value_domain(n: i64) -> Domain {
        Domain::Enumeration(Arc::new(vec![number(n)]))
    }

    /// Range domain `(min, +inf)`.
    fn open_range_above(min: i64) -> Domain {
        Domain::Range {
            min: Bound::Exclusive(Arc::new(number(min))),
            max: Bound::Unbounded,
        }
    }

    #[test]
    fn test_format_target_eq() {
        let target = Target::value(number(42));
        assert_eq!(target.format(), "= 42");
    }

    #[test]
    fn test_format_target_any() {
        let target = Target::any_value();
        assert_eq!(target.format(), "= any");
    }

    #[test]
    fn test_compute_undetermined_facts_empty() {
        let domains: Vec<HashMap<FactPath, Domain>> = vec![];
        assert!(compute_undetermined_facts(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_facts_single_value() {
        let mut domain_map = HashMap::new();
        domain_map.insert(FactPath::new(vec![], "age".to_string()), single_value_domain(25));
        let domains = vec![domain_map];
        assert!(compute_undetermined_facts(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_facts_range() {
        let mut domain_map = HashMap::new();
        domain_map.insert(FactPath::new(vec![], "age".to_string()), open_range_above(18));
        let domains = vec![domain_map];
        assert_eq!(compute_undetermined_facts(&domains).len(), 1);
    }

    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<FactPath, Domain>> = vec![];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_true() {
        let mut domain_map = HashMap::new();
        domain_map.insert(FactPath::new(vec![], "age".to_string()), single_value_domain(25));
        let domains = vec![domain_map];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_false() {
        let mut domain_map = HashMap::new();
        domain_map.insert(FactPath::new(vec![], "age".to_string()), open_range_above(18));
        let domains = vec![domain_map];
        assert!(!compute_is_determined(&domains));
    }

    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let inv = example_engine()
            .invert("example", "another", Target::value(number(3)), HashMap::new())
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = fact_x();
        let three = number(3);

        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let inv = example_engine()
            .invert("example", "another", Target::value(number(7)), HashMap::new())
            .expect("inversion should succeed");

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let inv = example_engine()
            .invert(
                "example",
                "another",
                Target::veto(Some("way too much".to_string())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = fact_x();
        let five = number(5);
        let six = number(6);

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(Some("way too much".to_string())),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            match d {
                Domain::Range { min, max } => {
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5), got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf, got: {}",
                        d
                    );
                }
                other => panic!("Expected range domain for x, got: {}", other),
            }
            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let inv = example_engine()
            .invert("example", "another", Target::any_veto(), HashMap::new())
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = fact_x();
        let minus_one = number(-1);
        let zero = number(0);
        let two = number(2);
        let three = number(3);
        let four = number(4);
        let five = number(5);
        let six = number(6);

        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(Some(msg)) if msg == "too little" => {
                    saw_too_little = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Unbounded),
                                "Expected min bound to be -inf for 'too little', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                                "Expected max bound to be (0) for 'too little', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(msg)) if msg == "too much" => {
                    saw_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                                "Expected min bound to be (3) for 'too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                                "Expected max bound to be [5] for 'too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(msg)) if msg == "way too much" => {
                    saw_way_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                                "Expected min bound to be (5) for 'way too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Unbounded),
                                "Expected max bound to be +inf for 'way too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(Some(other)) => {
                    panic!("Unexpected veto message in any-veto results: {:?}", other)
                }
                OperationResult::Veto(None) => {
                    panic!("Unexpected veto(None) in any-veto results (expected a message)")
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}