1mod constraint;
10mod derived;
11mod domain;
12mod solve;
13mod target;
14mod world;
15
16pub use derived::{DerivedExpression, DerivedExpressionKind};
17pub use domain::{extract_domains_from_constraint, Bound, Domain};
18pub use target::{Target, TargetOp};
19pub use world::World;
20
21use crate::evaluation::operations::VetoType;
22use crate::planning::semantics::{DataPath, Expression, LiteralValue, ValueKind};
23use crate::planning::ExecutionPlan;
24use crate::{Error, OperationResult};
25use serde::ser::{Serialize, SerializeStruct, Serializer};
26use std::collections::{HashMap, HashSet};
27
28use world::{WorldEnumerator, WorldSolution};
29
/// One way for the inverted rule to reach the requested target.
///
/// Solutions are returned inside [`InversionResponse`], where each one is
/// paired with the data domains under which it holds.
#[derive(Debug, Clone, serde::Serialize)]
pub struct Solution {
    /// The rule outcome for this solution. This is the world's literal
    /// outcome, a value solved for algebraically, or, when an arithmetic
    /// outcome could not be solved, the requested target value itself.
    pub outcome: OperationResult,
    /// The world this solution was found in, as produced by the world
    /// enumerator.
    pub world: World,
    /// The symbolic outcome expression. It is only set for arithmetic worlds
    /// that were returned unsolved, and is omitted from serialized output
    /// when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub shape: Option<Expression>,
}
49
/// Result of inverting a rule: every solution found, together with the data
/// domains that lead to each one.
///
/// Serialized through a hand-written `Serialize` impl that renders data paths
/// and domains as strings.
#[derive(Debug, Clone)]
pub struct InversionResponse {
    /// Solutions found for the target.
    pub solutions: Vec<Solution>,
    /// Per-solution data domains, parallel to `solutions`: `domains[i]`
    /// belongs to `solutions[i]`.
    pub domains: Vec<HashMap<DataPath, Domain>>,
    /// Data paths whose domain is not a single-value enumeration in at least
    /// one solution, sorted by their string form. Computed by
    /// [`InversionResponse::new`].
    pub undetermined_data: Vec<DataPath>,
    /// `true` when every domain of every solution is a single-value
    /// enumeration (vacuously `true` with no solutions). Computed by
    /// [`InversionResponse::new`].
    pub is_determined: bool,
}
62
63impl InversionResponse {
64 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<DataPath, Domain>>) -> Self {
66 let undetermined_data = compute_undetermined_data(&domains);
67 let is_determined = compute_is_determined(&domains);
68 Self {
69 solutions,
70 domains,
71 undetermined_data,
72 is_determined,
73 }
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.solutions.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.solutions.len()
84 }
85
86 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<DataPath, Domain>)> {
88 self.solutions.iter().zip(self.domains.iter())
89 }
90}
91
92impl Serialize for InversionResponse {
93 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
94 where
95 S: Serializer,
96 {
97 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
98 state.serialize_field("solutions", &self.solutions)?;
99
100 let domains_serializable: Vec<HashMap<String, String>> = self
101 .domains
102 .iter()
103 .map(|d| {
104 d.iter()
105 .map(|(k, v)| (k.to_string(), v.to_string()))
106 .collect()
107 })
108 .collect();
109 state.serialize_field("domains", &domains_serializable)?;
110
111 let undetermined_serializable: Vec<String> = self
112 .undetermined_data
113 .iter()
114 .map(|fp| fp.to_string())
115 .collect();
116 state.serialize_field("undetermined_data", &undetermined_serializable)?;
117 state.serialize_field("is_determined", &self.is_determined)?;
118 state.end()
119 }
120}
121
122pub fn invert(
137 rule_name: &str,
138 target: Target,
139 plan: &ExecutionPlan,
140 provided_data: &HashSet<DataPath>,
141) -> Result<InversionResponse, Error> {
142 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
143 Error::request(
144 format!("Rule not found: {}.{}", plan.spec_name, rule_name),
145 None::<String>,
146 )
147 })?;
148
149 let rule_path = executable_rule.path.clone();
150
151 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
153 let enumeration_result = enumerator.enumerate(provided_data)?;
154
155 let mut solutions = Vec::new();
157 let mut all_domains = Vec::new();
158
159 let filtered_literal_solutions =
161 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
162
163 for world_solution in filtered_literal_solutions {
164 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
165
166 let solution = Solution {
167 outcome: world_solution.outcome,
168 world: world_solution.world,
169 shape: None,
170 };
171
172 solutions.push(solution);
173 all_domains.push(constraint_domains);
174 }
175
176 if let Some(OperationResult::Value(target_value)) = &target.outcome {
178 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
180 let algebraic_solutions = solve::solve_arithmetic_batch(
181 enumeration_result.arithmetic_solutions.clone(),
182 target_value,
183 provided_data,
184 );
185
186 let indices: std::collections::HashSet<usize> = algebraic_solutions
188 .iter()
189 .filter_map(|(ws, _, _)| {
190 enumeration_result
191 .arithmetic_solutions
192 .iter()
193 .position(|orig| orig.world == ws.world)
194 })
195 .collect();
196
197 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
199 let constraint_domains =
200 extract_domains_from_constraint(&world_solution.constraint)?;
201
202 let mut is_valid = true;
204 for (data_path, solved_domain) in &solved_domains {
205 if let Some(constraint_domain) = constraint_domains.get(data_path) {
206 if let Domain::Enumeration(values) = solved_domain {
208 for value in values.iter() {
209 if !constraint_domain.contains(value) {
210 is_valid = false;
211 break;
212 }
213 }
214 }
215 }
216 if !is_valid {
217 break;
218 }
219 }
220
221 if !is_valid {
222 continue; }
224
225 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
226
227 let mut combined_domains = constraint_domains;
228 for (data_path, domain) in solved_domains {
229 combined_domains.insert(data_path, domain);
230 }
231
232 let solution = Solution {
233 outcome: solved_outcome_result,
234 world: world_solution.world,
235 shape: None,
236 };
237
238 solutions.push(solution);
239 all_domains.push(combined_domains);
240 }
241
242 indices
243 } else {
244 std::collections::HashSet::new()
245 };
246
247 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
250 if solved_indices.contains(&idx) {
251 continue; }
253
254 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
256
257 let unknown_data =
259 extract_data_paths_from_expression(&arith_solution.outcome_expression);
260 for data_path in unknown_data {
261 if !combined_domains.contains_key(&data_path) && !provided_data.contains(&data_path)
263 {
264 combined_domains.insert(data_path, Domain::Unconstrained);
265 }
266 }
267
268 let solution = Solution {
269 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
270 world: arith_solution.world.clone(),
271 shape: Some(arith_solution.outcome_expression.clone()),
272 };
273
274 solutions.push(solution);
275 all_domains.push(combined_domains);
276 }
277 }
278
279 Ok(InversionResponse::new(solutions, all_domains))
280}
281
282fn filter_literal_solutions_by_target(
288 solutions: Vec<WorldSolution>,
289 target: &Target,
290) -> Vec<WorldSolution> {
291 let mut filtered = Vec::new();
292
293 for solution in solutions {
294 let matches = match (&target.outcome, &solution.outcome) {
295 (None, _) => {
296 true
298 }
299 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
300 match target.op {
305 TargetOp::Eq => outcome_value.value == target_value.value,
306 TargetOp::Neq => outcome_value.value != target_value.value,
307 TargetOp::Lt => {
308 compare_values(outcome_value, target_value)
309 == Some(std::cmp::Ordering::Less)
310 }
311 TargetOp::Lte => {
312 let cmp = compare_values(outcome_value, target_value);
313 cmp == Some(std::cmp::Ordering::Less)
314 || cmp == Some(std::cmp::Ordering::Equal)
315 }
316 TargetOp::Gt => {
317 compare_values(outcome_value, target_value)
318 == Some(std::cmp::Ordering::Greater)
319 }
320 TargetOp::Gte => {
321 let cmp = compare_values(outcome_value, target_value);
322 cmp == Some(std::cmp::Ordering::Greater)
323 || cmp == Some(std::cmp::Ordering::Equal)
324 }
325 }
326 }
327 (Some(OperationResult::Veto(target_reason)), OperationResult::Veto(outcome_reason)) => {
328 match target_reason {
329 VetoType::UserDefined { message: None } => true, VetoType::UserDefined {
331 message: Some(ref t_msg),
332 } => matches!(
333 outcome_reason,
334 VetoType::UserDefined {
335 message: Some(ref o_msg)
336 }
337 if o_msg == t_msg
338 ),
339 _ => false,
340 }
341 }
342 _ => false, };
344
345 if matches {
346 filtered.push(solution);
347 }
348 }
349
350 filtered
351}
352
353fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
355 match (&a.value, &b.value) {
356 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
357 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
358 (ValueKind::Quantity(a_val, _, _), ValueKind::Quantity(b_val, _, _)) => {
359 Some(a_val.cmp(b_val))
360 }
361 (ValueKind::Calendar(a_val, unit_a), ValueKind::Calendar(b_val, unit_b)) => {
362 if unit_a == unit_b {
363 Some(a_val.cmp(b_val))
364 } else {
365 None
366 }
367 }
368 _ => None,
369 }
370}
371
372fn extract_data_paths_from_expression(expr: &Expression) -> Vec<DataPath> {
374 let mut set = std::collections::HashSet::new();
375 expr.collect_data_paths(&mut set);
376 set.into_iter().collect()
377}
378
379fn compute_undetermined_data(all_domains: &[HashMap<DataPath, Domain>]) -> Vec<DataPath> {
381 let mut undetermined: HashSet<DataPath> = HashSet::new();
382
383 for solution_domains in all_domains {
384 for (data_path, domain) in solution_domains {
385 let is_determined = matches!(
386 domain,
387 Domain::Enumeration(values) if values.len() == 1
388 );
389 if !is_determined {
390 undetermined.insert(data_path.clone());
391 }
392 }
393 }
394
395 let mut result: Vec<DataPath> = undetermined.into_iter().collect();
396 result.sort_by_key(|a| a.to_string());
397 result
398}
399
400fn compute_is_determined(all_domains: &[HashMap<DataPath, Domain>]) -> bool {
402 if all_domains.is_empty() {
403 return true;
404 }
405
406 for solution_domains in all_domains {
407 for domain in solution_domains.values() {
408 let is_single_value = matches!(
409 domain,
410 Domain::Enumeration(values) if values.len() == 1
411 );
412 if !is_single_value {
413 return false;
414 }
415 }
416 }
417
418 true
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::DateTimeValue;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Spec shared by the engine-level inversion tests.
    ///
    /// It declares number data `x` and two rules:
    /// * `base` returns `x`, vetoing when `x > 3` or `x < 0`;
    /// * `another` returns `base`, with an extra veto when `x > 5`.
    const EXAMPLE_SPEC: &str = r#"
spec example
data x: number
rule base: x
    unless x > 3 then veto "too much"
    unless x < 0 then veto "too little"

rule another: base
    unless x > 5 then veto "way too much"
"#;

    /// Creates an engine with [`EXAMPLE_SPEC`] loaded from a synthetic
    /// `test.lemma` source path.
    fn load_example_engine() -> Engine {
        let mut engine = Engine::new();
        engine
            .load(
                EXAMPLE_SPEC,
                crate::SourceType::Path(Arc::new(std::path::PathBuf::from("test.lemma"))),
            )
            .unwrap();
        engine
    }

    #[test]
    fn test_format_target_eq() {
        let target = Target::value(LiteralValue::number_from_decimal(Decimal::from(42)));
        let formatted = target.format();
        assert_eq!(formatted, "= 42");
    }

    #[test]
    fn test_format_target_any() {
        let target = Target::any_value();
        let formatted = target.format();
        assert_eq!(formatted, "= any");
    }

    #[test]
    fn test_compute_undetermined_data_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = vec![];
        let undetermined = compute_undetermined_data(&domains);
        assert!(undetermined.is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_single_value() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            DataPath::new(vec![], "age".to_string()),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number_from_decimal(
                Decimal::from(25),
            )])),
        );
        let domains = vec![domain_map];
        let undetermined = compute_undetermined_data(&domains);
        assert!(undetermined.is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_range() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            DataPath::new(vec![], "age".to_string()),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number_from_decimal(Decimal::from(
                    18,
                )))),
                max: Bound::Unbounded,
            },
        );
        let domains = vec![domain_map];
        let undetermined = compute_undetermined_data(&domains);
        assert_eq!(undetermined.len(), 1);
    }

    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = vec![];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_true() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            DataPath::new(vec![], "age".to_string()),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number_from_decimal(
                Decimal::from(25),
            )])),
        );
        let domains = vec![domain_map];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_false() {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            DataPath::new(vec![], "age".to_string()),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number_from_decimal(Decimal::from(
                    18,
                )))),
                max: Bound::Unbounded,
            },
        );
        let domains = vec![domain_map];
        assert!(!compute_is_determined(&domains));
    }

    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let now = DateTimeValue::now();
        let inv = load_example_engine()
            .invert(
                "example",
                Some(&now),
                "another",
                Target::value(LiteralValue::number(3.into())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = DataPath::new(vec![], "x".to_string());
        let three = LiteralValue::number(3.into());

        // Every solution must still admit x = 3.
        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let now = DateTimeValue::now();
        let inv = load_example_engine()
            .invert(
                "example",
                Some(&now),
                "another",
                Target::value(LiteralValue::number(7.into())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let now = DateTimeValue::now();
        let inv = load_example_engine()
            .invert(
                "example",
                Some(&now),
                "another",
                Target::veto(Some("way too much".to_string())),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = DataPath::new(vec![], "x".to_string());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some("way too much".to_string()),
                }),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            match d {
                Domain::Range { min, max } => {
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5), got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf, got: {}",
                        d
                    );
                }
                other => panic!("Expected range domain for x, got: {}", other),
            }
            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let now = DateTimeValue::now();
        let inv = load_example_engine()
            .invert(
                "example",
                Some(&now),
                "another",
                Target::any_veto(),
                HashMap::new(),
            )
            .expect("inversion should succeed");

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = DataPath::new(vec![], "x".to_string());
        let minus_one = LiteralValue::number((-1).into());
        let zero = LiteralValue::number(0.into());
        let two = LiteralValue::number(2.into());
        let three = LiteralValue::number(3.into());
        let four = LiteralValue::number(4.into());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some(ref msg),
                }) if msg == "too little" => {
                    saw_too_little = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Unbounded),
                                "Expected min bound to be -inf for 'too little', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                                "Expected max bound to be (0) for 'too little', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some(ref msg),
                }) if msg == "too much" => {
                    saw_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                                "Expected min bound to be (3) for 'too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                                "Expected max bound to be [5] for 'too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some(ref msg),
                }) if msg == "way too much" => {
                    saw_way_too_much = true;

                    match d {
                        Domain::Range { min, max } => {
                            assert!(
                                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                                "Expected min bound to be (5) for 'way too much', got: {}",
                                d
                            );
                            assert!(
                                matches!(max, Bound::Unbounded),
                                "Expected max bound to be +inf for 'way too much', got: {}",
                                d
                            );
                        }
                        other => panic!("Expected range domain for x, got: {}", other),
                    }

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(other) => {
                    panic!("Unexpected veto in any-veto results: {:?}", other)
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}