1mod constraint;
10mod derived;
11mod domain;
12mod solve;
13mod target;
14mod world;
15
16pub use derived::{DerivedExpression, DerivedExpressionKind};
17pub use domain::{extract_domains_from_constraint, Bound, Domain};
18pub use target::{Target, TargetOp};
19pub use world::World;
20
21use crate::evaluation::operations::VetoType;
22use crate::planning::semantics::{DataPath, Expression, LiteralValue, ValueKind};
23use crate::planning::ExecutionPlan;
24use crate::{Error, OperationResult};
25use serde::ser::{Serialize, SerializeStruct, Serializer};
26use std::collections::{HashMap, HashSet};
27
28use world::{WorldEnumerator, WorldSolution};
29
/// One way the inverted rule can reach the requested target.
///
/// Produced by [`invert`]; the matching data domains for a solution live at the
/// same index in [`InversionResponse::domains`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct Solution {
    /// Outcome the rule yields in this solution: the literal outcome of the
    /// world, an algebraically solved value, or (for unsolved arithmetic
    /// worlds) the requested target value.
    pub outcome: OperationResult,
    /// The world produced by world enumeration that this solution belongs to.
    /// NOTE(review): presumably one combination of branch choices — confirm
    /// against the `world` module.
    pub world: World,
    /// Outcome expression of an arithmetic world that was not solved
    /// algebraically; `None` for literal and solved solutions. Omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub shape: Option<Expression>,
}
49
/// Result of [`invert`]: the solutions found plus the data domains of each.
#[derive(Debug, Clone)]
pub struct InversionResponse {
    /// Solutions in the order they were produced.
    pub solutions: Vec<Solution>,
    /// Data domains per solution; `domains[i]` belongs to `solutions[i]`.
    pub domains: Vec<HashMap<DataPath, Domain>>,
    /// As computed by `InversionResponse::new`: every data path that is not
    /// pinned to a single value in at least one solution, deduplicated and
    /// sorted by its display form.
    pub undetermined_data: Vec<DataPath>,
    /// As computed by `InversionResponse::new`: `true` when every domain of
    /// every solution is a single-value enumeration (vacuously `true` when
    /// there are no domains).
    pub is_determined: bool,
}
62
63impl InversionResponse {
64 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<DataPath, Domain>>) -> Self {
66 let undetermined_data = compute_undetermined_data(&domains);
67 let is_determined = compute_is_determined(&domains);
68 Self {
69 solutions,
70 domains,
71 undetermined_data,
72 is_determined,
73 }
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.solutions.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.solutions.len()
84 }
85
86 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<DataPath, Domain>)> {
88 self.solutions.iter().zip(self.domains.iter())
89 }
90}
91
92impl Serialize for InversionResponse {
93 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
94 where
95 S: Serializer,
96 {
97 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
98 state.serialize_field("solutions", &self.solutions)?;
99
100 let domains_serializable: Vec<HashMap<String, String>> = self
101 .domains
102 .iter()
103 .map(|d| {
104 d.iter()
105 .map(|(k, v)| (k.to_string(), v.to_string()))
106 .collect()
107 })
108 .collect();
109 state.serialize_field("domains", &domains_serializable)?;
110
111 let undetermined_serializable: Vec<String> = self
112 .undetermined_data
113 .iter()
114 .map(|fp| fp.to_string())
115 .collect();
116 state.serialize_field("undetermined_data", &undetermined_serializable)?;
117 state.serialize_field("is_determined", &self.is_determined)?;
118 state.end()
119 }
120}
121
122pub fn invert(
137 rule_name: &str,
138 target: Target,
139 plan: &ExecutionPlan,
140 provided_data: &HashSet<DataPath>,
141) -> Result<InversionResponse, Error> {
142 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
143 Error::request(
144 format!("Rule not found: {}.{}", plan.spec_name, rule_name),
145 None::<String>,
146 )
147 })?;
148
149 let rule_path = executable_rule.path.clone();
150
151 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
153 let enumeration_result = enumerator.enumerate(provided_data)?;
154
155 let mut solutions = Vec::new();
157 let mut all_domains = Vec::new();
158
159 let filtered_literal_solutions =
161 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
162
163 for world_solution in filtered_literal_solutions {
164 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
165
166 let solution = Solution {
167 outcome: world_solution.outcome,
168 world: world_solution.world,
169 shape: None,
170 };
171
172 solutions.push(solution);
173 all_domains.push(constraint_domains);
174 }
175
176 if let Some(OperationResult::Value(target_value)) = &target.outcome {
178 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
180 let algebraic_solutions = solve::solve_arithmetic_batch(
181 enumeration_result.arithmetic_solutions.clone(),
182 target_value,
183 provided_data,
184 );
185
186 let indices: std::collections::HashSet<usize> = algebraic_solutions
188 .iter()
189 .filter_map(|(ws, _, _)| {
190 enumeration_result
191 .arithmetic_solutions
192 .iter()
193 .position(|orig| orig.world == ws.world)
194 })
195 .collect();
196
197 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
199 let constraint_domains =
200 extract_domains_from_constraint(&world_solution.constraint)?;
201
202 let mut is_valid = true;
204 for (data_path, solved_domain) in &solved_domains {
205 if let Some(constraint_domain) = constraint_domains.get(data_path) {
206 if let Domain::Enumeration(values) = solved_domain {
208 for value in values.iter() {
209 if !constraint_domain.contains(value) {
210 is_valid = false;
211 break;
212 }
213 }
214 }
215 }
216 if !is_valid {
217 break;
218 }
219 }
220
221 if !is_valid {
222 continue; }
224
225 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
226
227 let mut combined_domains = constraint_domains;
228 for (data_path, domain) in solved_domains {
229 combined_domains.insert(data_path, domain);
230 }
231
232 let solution = Solution {
233 outcome: solved_outcome_result,
234 world: world_solution.world,
235 shape: None,
236 };
237
238 solutions.push(solution);
239 all_domains.push(combined_domains);
240 }
241
242 indices
243 } else {
244 std::collections::HashSet::new()
245 };
246
247 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
250 if solved_indices.contains(&idx) {
251 continue; }
253
254 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
256
257 let unknown_data =
259 extract_data_paths_from_expression(&arith_solution.outcome_expression);
260 for data_path in unknown_data {
261 if !combined_domains.contains_key(&data_path) && !provided_data.contains(&data_path)
263 {
264 combined_domains.insert(data_path, Domain::Unconstrained);
265 }
266 }
267
268 let solution = Solution {
269 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
270 world: arith_solution.world.clone(),
271 shape: Some(arith_solution.outcome_expression.clone()),
272 };
273
274 solutions.push(solution);
275 all_domains.push(combined_domains);
276 }
277 }
278
279 Ok(InversionResponse::new(solutions, all_domains))
280}
281
282fn filter_literal_solutions_by_target(
288 solutions: Vec<WorldSolution>,
289 target: &Target,
290) -> Vec<WorldSolution> {
291 let mut filtered = Vec::new();
292
293 for solution in solutions {
294 let matches = match (&target.outcome, &solution.outcome) {
295 (None, _) => {
296 true
298 }
299 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
300 match target.op {
305 TargetOp::Eq => outcome_value.value == target_value.value,
306 TargetOp::Neq => outcome_value.value != target_value.value,
307 TargetOp::Lt => {
308 compare_values(outcome_value, target_value)
309 == Some(std::cmp::Ordering::Less)
310 }
311 TargetOp::Lte => {
312 let cmp = compare_values(outcome_value, target_value);
313 cmp == Some(std::cmp::Ordering::Less)
314 || cmp == Some(std::cmp::Ordering::Equal)
315 }
316 TargetOp::Gt => {
317 compare_values(outcome_value, target_value)
318 == Some(std::cmp::Ordering::Greater)
319 }
320 TargetOp::Gte => {
321 let cmp = compare_values(outcome_value, target_value);
322 cmp == Some(std::cmp::Ordering::Greater)
323 || cmp == Some(std::cmp::Ordering::Equal)
324 }
325 }
326 }
327 (Some(OperationResult::Veto(target_reason)), OperationResult::Veto(outcome_reason)) => {
328 match target_reason {
329 VetoType::UserDefined { message: None } => true, VetoType::UserDefined {
331 message: Some(ref t_msg),
332 } => matches!(
333 outcome_reason,
334 VetoType::UserDefined {
335 message: Some(ref o_msg)
336 }
337 if o_msg == t_msg
338 ),
339 _ => false,
340 }
341 }
342 _ => false, };
344
345 if matches {
346 filtered.push(solution);
347 }
348 }
349
350 filtered
351}
352
353fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
355 match (&a.value, &b.value) {
356 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
357 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
358 (ValueKind::Scale(a_val, _), ValueKind::Scale(b_val, _)) => Some(a_val.cmp(b_val)),
359 (ValueKind::Duration(a_val, unit_a), ValueKind::Duration(b_val, unit_b)) => {
360 if unit_a == unit_b {
361 Some(a_val.cmp(b_val))
362 } else {
363 None
364 }
365 }
366 _ => None,
367 }
368}
369
370fn extract_data_paths_from_expression(expr: &Expression) -> Vec<DataPath> {
372 let mut set = std::collections::HashSet::new();
373 expr.collect_data_paths(&mut set);
374 set.into_iter().collect()
375}
376
377fn compute_undetermined_data(all_domains: &[HashMap<DataPath, Domain>]) -> Vec<DataPath> {
379 let mut undetermined: HashSet<DataPath> = HashSet::new();
380
381 for solution_domains in all_domains {
382 for (data_path, domain) in solution_domains {
383 let is_determined = matches!(
384 domain,
385 Domain::Enumeration(values) if values.len() == 1
386 );
387 if !is_determined {
388 undetermined.insert(data_path.clone());
389 }
390 }
391 }
392
393 let mut result: Vec<DataPath> = undetermined.into_iter().collect();
394 result.sort_by_key(|a| a.to_string());
395 result
396}
397
398fn compute_is_determined(all_domains: &[HashMap<DataPath, Domain>]) -> bool {
400 if all_domains.is_empty() {
401 return true;
402 }
403
404 for solution_domains in all_domains {
405 for domain in solution_domains.values() {
406 let is_single_value = matches!(
407 domain,
408 Domain::Enumeration(values) if values.len() == 1
409 );
410 if !is_single_value {
411 return false;
412 }
413 }
414 }
415
416 true
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::DateTimeValue;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Lemma source shared by the end-to-end inversion tests below.
    const EXAMPLE_SPEC: &str = r#"
spec example
data x: number
rule base: x
  unless x > 3 then veto "too much"
  unless x < 0 then veto "too little"

rule another: base
  unless x > 5 then veto "way too much"
"#;

    /// Number literal built from an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Data path for `name`, built the same way throughout these tests.
    fn top_level(name: &str) -> DataPath {
        DataPath::new(vec![], name.to_string())
    }

    /// A single solution whose only domain maps `age` to `domain`.
    fn age_domains(domain: Domain) -> Vec<HashMap<DataPath, Domain>> {
        let mut domain_map = HashMap::new();
        domain_map.insert(top_level("age"), domain);
        vec![domain_map]
    }

    /// Enumeration domain holding only the value 25.
    fn exactly_25() -> Domain {
        Domain::Enumeration(Arc::new(vec![num(25)]))
    }

    /// Range domain `(18, +inf)`.
    fn above_18() -> Domain {
        Domain::Range {
            min: Bound::Exclusive(Arc::new(num(18))),
            max: Bound::Unbounded,
        }
    }

    /// Loads `EXAMPLE_SPEC` and inverts rule `another` towards `target`
    /// without providing any data.
    fn invert_another(target: Target) -> InversionResponse {
        let mut engine = Engine::new();
        engine
            .load(EXAMPLE_SPEC, crate::SourceType::Labeled("test.lemma"))
            .unwrap();
        let now = DateTimeValue::now();
        engine
            .invert("example", Some(&now), "another", target, HashMap::new())
            .expect("inversion should succeed")
    }

    /// Returns the `(min, max)` bounds of a range domain; panics otherwise.
    fn range_bounds(domain: &Domain) -> (&Bound, &Bound) {
        match domain {
            Domain::Range { min, max } => (min, max),
            other => panic!("Expected range domain for x, got: {}", other),
        }
    }

    /// `true` when `bound` is exclusive at exactly `value`.
    fn is_exclusive(bound: &Bound, value: &LiteralValue) -> bool {
        matches!(bound, Bound::Exclusive(v) if v.as_ref() == value)
    }

    /// `true` when `bound` is inclusive at exactly `value`.
    fn is_inclusive(bound: &Bound, value: &LiteralValue) -> bool {
        matches!(bound, Bound::Inclusive(v) if v.as_ref() == value)
    }

    #[test]
    fn test_format_target_eq() {
        let target = Target::value(num(42));
        assert_eq!(target.format(), "= 42");
    }

    #[test]
    fn test_format_target_any() {
        let target = Target::any_value();
        assert_eq!(target.format(), "= any");
    }

    #[test]
    fn test_compute_undetermined_data_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = vec![];
        assert!(compute_undetermined_data(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_single_value() {
        let domains = age_domains(exactly_25());
        assert!(compute_undetermined_data(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_range() {
        let domains = age_domains(above_18());
        assert_eq!(compute_undetermined_data(&domains).len(), 1);
    }

    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = vec![];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_true() {
        let domains = age_domains(exactly_25());
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_false() {
        let domains = age_domains(above_18());
        assert!(!compute_is_determined(&domains));
    }

    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let inv = invert_another(Target::value(num(3)));

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = top_level("x");
        let three = num(3);

        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let inv = invert_another(Target::value(num(7)));

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let inv = invert_another(Target::veto(Some("way too much".to_string())));

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = top_level("x");
        let five = num(5);
        let six = num(6);

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some("way too much".to_string()),
                }),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            let (min, max) = range_bounds(d);
            assert!(
                is_exclusive(min, &five),
                "Expected min bound to be (5), got: {}",
                d
            );
            assert!(
                matches!(max, Bound::Unbounded),
                "Expected max bound to be +inf, got: {}",
                d
            );
            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let inv = invert_another(Target::any_veto());

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = top_level("x");
        let minus_one = num(-1);
        let zero = num(0);
        let two = num(2);
        let three = num(3);
        let four = num(4);
        let five = num(5);
        let six = num(6);

        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(VetoType::UserDefined { message: Some(msg) })
                    if msg == "too little" =>
                {
                    saw_too_little = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Unbounded),
                        "Expected min bound to be -inf for 'too little', got: {}",
                        d
                    );
                    assert!(
                        is_exclusive(max, &zero),
                        "Expected max bound to be (0) for 'too little', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined { message: Some(msg) })
                    if msg == "too much" =>
                {
                    saw_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        is_exclusive(min, &three),
                        "Expected min bound to be (3) for 'too much', got: {}",
                        d
                    );
                    assert!(
                        is_inclusive(max, &five),
                        "Expected max bound to be [5] for 'too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined { message: Some(msg) })
                    if msg == "way too much" =>
                {
                    saw_way_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        is_exclusive(min, &five),
                        "Expected min bound to be (5) for 'way too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf for 'way too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(other) => {
                    panic!("Unexpected veto in any-veto results: {:?}", other)
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}