1mod constraint;
10mod derived;
11mod domain;
12mod solve;
13mod target;
14mod world;
15
16pub use derived::{DerivedExpression, DerivedExpressionKind};
17pub use domain::{extract_domains_from_constraint, Bound, Domain};
18pub use target::{Target, TargetOp};
19pub use world::World;
20
21use crate::evaluation::operations::VetoType;
22use crate::planning::semantics::{DataPath, Expression, LiteralValue, ValueKind};
23use crate::planning::ExecutionPlan;
24use crate::{Error, OperationResult};
25use serde::ser::{Serialize, SerializeStruct, Serializer};
26use std::collections::{HashMap, HashSet};
27
28use world::{WorldEnumerator, WorldSolution};
29
30#[derive(Debug, Clone, serde::Serialize)]
39pub struct Solution {
40 pub outcome: OperationResult,
42 pub world: World,
44 #[serde(skip_serializing_if = "Option::is_none")]
47 pub shape: Option<Expression>,
48}
49
50#[derive(Debug, Clone)]
52pub struct InversionResponse {
53 pub solutions: Vec<Solution>,
55 pub domains: Vec<HashMap<DataPath, Domain>>,
57 pub undetermined_data: Vec<DataPath>,
59 pub is_determined: bool,
61}
62
63impl InversionResponse {
64 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<DataPath, Domain>>) -> Self {
66 let undetermined_data = compute_undetermined_data(&domains);
67 let is_determined = compute_is_determined(&domains);
68 Self {
69 solutions,
70 domains,
71 undetermined_data,
72 is_determined,
73 }
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.solutions.is_empty()
79 }
80
81 pub fn len(&self) -> usize {
83 self.solutions.len()
84 }
85
86 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<DataPath, Domain>)> {
88 self.solutions.iter().zip(self.domains.iter())
89 }
90}
91
92impl Serialize for InversionResponse {
93 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
94 where
95 S: Serializer,
96 {
97 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
98 state.serialize_field("solutions", &self.solutions)?;
99
100 let domains_serializable: Vec<HashMap<String, String>> = self
101 .domains
102 .iter()
103 .map(|d| {
104 d.iter()
105 .map(|(k, v)| (k.to_string(), v.to_string()))
106 .collect()
107 })
108 .collect();
109 state.serialize_field("domains", &domains_serializable)?;
110
111 let undetermined_serializable: Vec<String> = self
112 .undetermined_data
113 .iter()
114 .map(|fp| fp.to_string())
115 .collect();
116 state.serialize_field("undetermined_data", &undetermined_serializable)?;
117 state.serialize_field("is_determined", &self.is_determined)?;
118 state.end()
119 }
120}
121
122pub fn invert(
137 rule_name: &str,
138 target: Target,
139 plan: &ExecutionPlan,
140 provided_data: &HashSet<DataPath>,
141) -> Result<InversionResponse, Error> {
142 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
143 Error::request(
144 format!("Rule not found: {}.{}", plan.spec_name, rule_name),
145 None::<String>,
146 )
147 })?;
148
149 let rule_path = executable_rule.path.clone();
150
151 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
153 let enumeration_result = enumerator.enumerate(provided_data)?;
154
155 let mut solutions = Vec::new();
157 let mut all_domains = Vec::new();
158
159 let filtered_literal_solutions =
161 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
162
163 for world_solution in filtered_literal_solutions {
164 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
165
166 let solution = Solution {
167 outcome: world_solution.outcome,
168 world: world_solution.world,
169 shape: None,
170 };
171
172 solutions.push(solution);
173 all_domains.push(constraint_domains);
174 }
175
176 if let Some(OperationResult::Value(target_value)) = &target.outcome {
178 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
180 let algebraic_solutions = solve::solve_arithmetic_batch(
181 enumeration_result.arithmetic_solutions.clone(),
182 target_value,
183 provided_data,
184 );
185
186 let indices: std::collections::HashSet<usize> = algebraic_solutions
188 .iter()
189 .filter_map(|(ws, _, _)| {
190 enumeration_result
191 .arithmetic_solutions
192 .iter()
193 .position(|orig| orig.world == ws.world)
194 })
195 .collect();
196
197 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
199 let constraint_domains =
200 extract_domains_from_constraint(&world_solution.constraint)?;
201
202 let mut is_valid = true;
204 for (data_path, solved_domain) in &solved_domains {
205 if let Some(constraint_domain) = constraint_domains.get(data_path) {
206 if let Domain::Enumeration(values) = solved_domain {
208 for value in values.iter() {
209 if !constraint_domain.contains(value) {
210 is_valid = false;
211 break;
212 }
213 }
214 }
215 }
216 if !is_valid {
217 break;
218 }
219 }
220
221 if !is_valid {
222 continue; }
224
225 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
226
227 let mut combined_domains = constraint_domains;
228 for (data_path, domain) in solved_domains {
229 combined_domains.insert(data_path, domain);
230 }
231
232 let solution = Solution {
233 outcome: solved_outcome_result,
234 world: world_solution.world,
235 shape: None,
236 };
237
238 solutions.push(solution);
239 all_domains.push(combined_domains);
240 }
241
242 indices
243 } else {
244 std::collections::HashSet::new()
245 };
246
247 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
250 if solved_indices.contains(&idx) {
251 continue; }
253
254 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
256
257 let unknown_data =
259 extract_data_paths_from_expression(&arith_solution.outcome_expression);
260 for data_path in unknown_data {
261 if !combined_domains.contains_key(&data_path) && !provided_data.contains(&data_path)
263 {
264 combined_domains.insert(data_path, Domain::Unconstrained);
265 }
266 }
267
268 let solution = Solution {
269 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
270 world: arith_solution.world.clone(),
271 shape: Some(arith_solution.outcome_expression.clone()),
272 };
273
274 solutions.push(solution);
275 all_domains.push(combined_domains);
276 }
277 }
278
279 Ok(InversionResponse::new(solutions, all_domains))
280}
281
282fn filter_literal_solutions_by_target(
288 solutions: Vec<WorldSolution>,
289 target: &Target,
290) -> Vec<WorldSolution> {
291 let mut filtered = Vec::new();
292
293 for solution in solutions {
294 let matches = match (&target.outcome, &solution.outcome) {
295 (None, _) => {
296 true
298 }
299 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
300 match target.op {
305 TargetOp::Eq => outcome_value.value == target_value.value,
306 TargetOp::Neq => outcome_value.value != target_value.value,
307 TargetOp::Lt => {
308 compare_values(outcome_value, target_value)
309 == Some(std::cmp::Ordering::Less)
310 }
311 TargetOp::Lte => {
312 let cmp = compare_values(outcome_value, target_value);
313 cmp == Some(std::cmp::Ordering::Less)
314 || cmp == Some(std::cmp::Ordering::Equal)
315 }
316 TargetOp::Gt => {
317 compare_values(outcome_value, target_value)
318 == Some(std::cmp::Ordering::Greater)
319 }
320 TargetOp::Gte => {
321 let cmp = compare_values(outcome_value, target_value);
322 cmp == Some(std::cmp::Ordering::Greater)
323 || cmp == Some(std::cmp::Ordering::Equal)
324 }
325 }
326 }
327 (Some(OperationResult::Veto(target_reason)), OperationResult::Veto(outcome_reason)) => {
328 match target_reason {
329 VetoType::UserDefined { message: None } => true, VetoType::UserDefined {
331 message: Some(ref t_msg),
332 } => matches!(
333 outcome_reason,
334 VetoType::UserDefined {
335 message: Some(ref o_msg)
336 }
337 if o_msg == t_msg
338 ),
339 _ => false,
340 }
341 }
342 _ => false, };
344
345 if matches {
346 filtered.push(solution);
347 }
348 }
349
350 filtered
351}
352
353fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
355 match (&a.value, &b.value) {
356 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
357 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
358 (ValueKind::Scale(a_val, _), ValueKind::Scale(b_val, _)) => Some(a_val.cmp(b_val)),
359 (ValueKind::Duration(a_val, unit_a), ValueKind::Duration(b_val, unit_b)) => {
360 if unit_a == unit_b {
361 Some(a_val.cmp(b_val))
362 } else {
363 None
364 }
365 }
366 _ => None,
367 }
368}
369
370fn extract_data_paths_from_expression(expr: &Expression) -> Vec<DataPath> {
372 let mut set = std::collections::HashSet::new();
373 expr.collect_data_paths(&mut set);
374 set.into_iter().collect()
375}
376
377fn compute_undetermined_data(all_domains: &[HashMap<DataPath, Domain>]) -> Vec<DataPath> {
379 let mut undetermined: HashSet<DataPath> = HashSet::new();
380
381 for solution_domains in all_domains {
382 for (data_path, domain) in solution_domains {
383 let is_determined = matches!(
384 domain,
385 Domain::Enumeration(values) if values.len() == 1
386 );
387 if !is_determined {
388 undetermined.insert(data_path.clone());
389 }
390 }
391 }
392
393 let mut result: Vec<DataPath> = undetermined.into_iter().collect();
394 result.sort_by_key(|a| a.to_string());
395 result
396}
397
398fn compute_is_determined(all_domains: &[HashMap<DataPath, Domain>]) -> bool {
400 if all_domains.is_empty() {
401 return true;
402 }
403
404 for solution_domains in all_domains {
405 for domain in solution_domains.values() {
406 let is_single_value = matches!(
407 domain,
408 Domain::Enumeration(values) if values.len() == 1
409 );
410 if !is_single_value {
411 return false;
412 }
413 }
414 }
415
416 true
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::DateTimeValue;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Spec shared by the engine-level inversion tests. `base` passes `x`
    /// through unless `x > 3` or `x < 0` (each a distinct veto). `another`
    /// references `base` and adds its own `x > 5` veto.
    const EXAMPLE_SPEC: &str = r#"
spec example
data x: number
rule base: x
    unless x > 3 then veto "too much"
    unless x < 0 then veto "too little"

rule another: base
    unless x > 5 then veto "way too much"
"#;

    /// Loads `EXAMPLE_SPEC` into a fresh engine and inverts rule `another`
    /// towards `target`, with no data provided.
    fn invert_another(target: Target) -> InversionResponse {
        let mut engine = Engine::new();
        engine
            .load(
                EXAMPLE_SPEC,
                crate::SourceType::Path(Arc::new(std::path::PathBuf::from("test.lemma"))),
            )
            .unwrap();
        let now = DateTimeValue::now();
        engine
            .invert("example", Some(&now), "another", target, HashMap::new())
            .expect("inversion should succeed")
    }

    /// Number literal shorthand.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Path of the top-level `x` data in `EXAMPLE_SPEC`.
    fn x_path() -> DataPath {
        DataPath::new(vec![], "x".to_string())
    }

    /// One solution whose only domain is the given one, for the `age` data.
    fn single_age_domain(domain: Domain) -> Vec<HashMap<DataPath, Domain>> {
        let mut domain_map = HashMap::new();
        domain_map.insert(DataPath::new(vec![], "age".to_string()), domain);
        vec![domain_map]
    }

    /// Domain pinned to the single value 25.
    fn exactly_25() -> Domain {
        Domain::Enumeration(Arc::new(vec![num(25)]))
    }

    /// Open range `(18, +inf)`.
    fn above_18() -> Domain {
        Domain::Range {
            min: Bound::Exclusive(Arc::new(num(18))),
            max: Bound::Unbounded,
        }
    }

    /// Splits a range domain into its bounds; panics for any other domain.
    fn range_bounds(d: &Domain) -> (&Bound, &Bound) {
        match d {
            Domain::Range { min, max } => (min, max),
            other => panic!("Expected range domain for x, got: {}", other),
        }
    }

    #[test]
    fn test_format_target_eq() {
        assert_eq!(Target::value(num(42)).format(), "= 42");
    }

    #[test]
    fn test_format_target_any() {
        assert_eq!(Target::any_value().format(), "= any");
    }

    #[test]
    fn test_compute_undetermined_data_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = Vec::new();
        assert!(compute_undetermined_data(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_single_value() {
        assert!(compute_undetermined_data(&single_age_domain(exactly_25())).is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_range() {
        assert_eq!(
            compute_undetermined_data(&single_age_domain(above_18())).len(),
            1
        );
    }

    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = Vec::new();
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_true() {
        assert!(compute_is_determined(&single_age_domain(exactly_25())));
    }

    #[test]
    fn test_compute_is_determined_false() {
        assert!(!compute_is_determined(&single_age_domain(above_18())));
    }

    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let inv = invert_another(Target::value(num(3)));
        assert!(!inv.is_empty(), "expected at least one solution");

        let x = x_path();
        let three = num(3);

        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let inv = invert_another(Target::value(num(7)));
        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let inv = invert_another(Target::veto(Some("way too much".to_string())));
        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = x_path();
        let five = num(5);
        let six = num(6);

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some("way too much".to_string()),
                }),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            let (min, max) = range_bounds(d);
            assert!(
                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                "Expected min bound to be (5), got: {}",
                d
            );
            assert!(
                matches!(max, Bound::Unbounded),
                "Expected max bound to be +inf, got: {}",
                d
            );

            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let inv = invert_another(Target::any_veto());
        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = x_path();
        let minus_one = num(-1);
        let zero = num(0);
        let two = num(2);
        let three = num(3);
        let four = num(4);
        let five = num(5);
        let six = num(6);

        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(VetoType::UserDefined { message: Some(msg) })
                    if msg == "too little" =>
                {
                    saw_too_little = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Unbounded),
                        "Expected min bound to be -inf for 'too little', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                        "Expected max bound to be (0) for 'too little', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined { message: Some(msg) })
                    if msg == "too much" =>
                {
                    saw_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                        "Expected min bound to be (3) for 'too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                        "Expected max bound to be [5] for 'too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined { message: Some(msg) })
                    if msg == "way too much" =>
                {
                    saw_way_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5) for 'way too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf for 'way too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(other) => {
                    panic!("Unexpected veto in any-veto results: {:?}", other)
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}