1mod constraint;
10mod domain;
11mod solve;
12mod target;
13mod world;
14
15pub use domain::{extract_domains_from_constraint, Bound, Domain};
16pub use target::{Target, TargetOp};
17pub use world::World;
18
19use crate::evaluation::operations::VetoType;
20use crate::planning::semantics::{DataPath, Expression, LiteralValue, ValueKind};
21use crate::planning::ExecutionPlan;
22use crate::{Error, OperationResult};
23use serde::ser::{Serialize, SerializeStruct, Serializer};
24use std::collections::{HashMap, HashSet};
25
26use world::{WorldEnumerator, WorldSolution};
27
/// One way for a rule to reach the requested target.
///
/// Produced by [`invert`]. The data domains under which this solution holds are not stored
/// here; they live at the same index in [`InversionResponse::domains`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct Solution {
    /// The rule outcome for this solution. It is one of:
    /// - the world's literal outcome;
    /// - the value produced by the algebraic solver;
    /// - for arithmetic worlds that were not solved, the requested target value itself.
    pub outcome: OperationResult,
    /// The enumerated world this solution belongs to (presumably a set of branch
    /// choices — TODO confirm against `world::World`).
    pub world: World,
    /// Set only for arithmetic worlds that were not solved algebraically: the outcome
    /// expression the rule computes in that world. `None` for every other solution, and
    /// omitted from serialized output in that case.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub shape: Option<Expression>,
}
47
/// Result of [`invert`]: every solution found for a target, each paired (by index) with
/// the data domains under which it holds.
///
/// Serialization is hand-written (see the `Serialize` impl below). Data paths and domains
/// are rendered as strings.
#[derive(Debug, Clone)]
pub struct InversionResponse {
    /// Solutions in the order [`invert`] produced them.
    pub solutions: Vec<Solution>,
    /// `domains[i]` maps each data path to its domain for `solutions[i]`; the two vectors
    /// are expected to have the same length.
    pub domains: Vec<HashMap<DataPath, Domain>>,
    /// Data paths whose domain is not a single value in at least one solution, sorted by
    /// their string form (as computed by [`InversionResponse::new`]).
    pub undetermined_data: Vec<DataPath>,
    /// `true` when every domain of every solution is a single value; also `true` when
    /// there are no solutions (as computed by [`InversionResponse::new`]).
    pub is_determined: bool,
}
60
61impl InversionResponse {
62 pub fn new(solutions: Vec<Solution>, domains: Vec<HashMap<DataPath, Domain>>) -> Self {
64 let undetermined_data = compute_undetermined_data(&domains);
65 let is_determined = compute_is_determined(&domains);
66 Self {
67 solutions,
68 domains,
69 undetermined_data,
70 is_determined,
71 }
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.solutions.is_empty()
77 }
78
79 pub fn len(&self) -> usize {
81 self.solutions.len()
82 }
83
84 pub fn iter(&self) -> impl Iterator<Item = (&Solution, &HashMap<DataPath, Domain>)> {
86 self.solutions.iter().zip(self.domains.iter())
87 }
88}
89
90impl Serialize for InversionResponse {
91 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
92 where
93 S: Serializer,
94 {
95 let mut state = serializer.serialize_struct("InversionResponse", 4)?;
96 state.serialize_field("solutions", &self.solutions)?;
97
98 let domains_serializable: Vec<HashMap<String, String>> = self
99 .domains
100 .iter()
101 .map(|d| {
102 d.iter()
103 .map(|(k, v)| (k.to_string(), v.to_string()))
104 .collect()
105 })
106 .collect();
107 state.serialize_field("domains", &domains_serializable)?;
108
109 let undetermined_serializable: Vec<String> = self
110 .undetermined_data
111 .iter()
112 .map(|fp| fp.to_string())
113 .collect();
114 state.serialize_field("undetermined_data", &undetermined_serializable)?;
115 state.serialize_field("is_determined", &self.is_determined)?;
116 state.end()
117 }
118}
119
120pub fn invert(
135 rule_name: &str,
136 target: Target,
137 plan: &ExecutionPlan,
138 provided_data: &HashSet<DataPath>,
139) -> Result<InversionResponse, Error> {
140 let executable_rule = plan.get_rule(rule_name).ok_or_else(|| {
141 Error::request(
142 format!("Rule not found: {}.{}", plan.spec_name, rule_name),
143 None::<String>,
144 )
145 })?;
146
147 let rule_path = executable_rule.path.clone();
148
149 let mut enumerator = WorldEnumerator::new(plan, &rule_path)?;
151 let enumeration_result = enumerator.enumerate(provided_data)?;
152
153 let mut solutions = Vec::new();
155 let mut all_domains = Vec::new();
156
157 let filtered_literal_solutions =
159 filter_literal_solutions_by_target(enumeration_result.literal_solutions, &target);
160
161 for world_solution in filtered_literal_solutions {
162 let constraint_domains = extract_domains_from_constraint(&world_solution.constraint)?;
163
164 let solution = Solution {
165 outcome: world_solution.outcome,
166 world: world_solution.world,
167 shape: None,
168 };
169
170 solutions.push(solution);
171 all_domains.push(constraint_domains);
172 }
173
174 if let Some(OperationResult::Value(target_value)) = &target.outcome {
176 let solved_indices: std::collections::HashSet<usize> = if target.op == TargetOp::Eq {
178 let algebraic_solutions = solve::solve_arithmetic_batch(
179 enumeration_result.arithmetic_solutions.clone(),
180 target_value,
181 provided_data,
182 );
183
184 let indices: std::collections::HashSet<usize> = algebraic_solutions
186 .iter()
187 .filter_map(|(ws, _, _)| {
188 enumeration_result
189 .arithmetic_solutions
190 .iter()
191 .position(|orig| orig.world == ws.world)
192 })
193 .collect();
194
195 for (world_solution, solved_outcome, solved_domains) in algebraic_solutions {
197 let constraint_domains =
198 extract_domains_from_constraint(&world_solution.constraint)?;
199
200 let mut is_valid = true;
202 for (data_path, solved_domain) in &solved_domains {
203 if let Some(constraint_domain) = constraint_domains.get(data_path) {
204 if let Domain::Enumeration(values) = solved_domain {
206 for value in values.iter() {
207 if !constraint_domain.contains(value) {
208 is_valid = false;
209 break;
210 }
211 }
212 }
213 }
214 if !is_valid {
215 break;
216 }
217 }
218
219 if !is_valid {
220 continue; }
222
223 let solved_outcome_result = OperationResult::Value(Box::new(solved_outcome));
224
225 let mut combined_domains = constraint_domains;
226 for (data_path, domain) in solved_domains {
227 combined_domains.insert(data_path, domain);
228 }
229
230 let solution = Solution {
231 outcome: solved_outcome_result,
232 world: world_solution.world,
233 shape: None,
234 };
235
236 solutions.push(solution);
237 all_domains.push(combined_domains);
238 }
239
240 indices
241 } else {
242 std::collections::HashSet::new()
243 };
244
245 for (idx, arith_solution) in enumeration_result.arithmetic_solutions.iter().enumerate() {
248 if solved_indices.contains(&idx) {
249 continue; }
251
252 let mut combined_domains = extract_domains_from_constraint(&arith_solution.constraint)?;
254
255 let unknown_data =
257 extract_data_paths_from_expression(&arith_solution.outcome_expression);
258 for data_path in unknown_data {
259 if !combined_domains.contains_key(&data_path) && !provided_data.contains(&data_path)
261 {
262 combined_domains.insert(data_path, Domain::Unconstrained);
263 }
264 }
265
266 let solution = Solution {
267 outcome: OperationResult::Value(Box::new(target_value.as_ref().clone())),
268 world: arith_solution.world.clone(),
269 shape: Some(arith_solution.outcome_expression.clone()),
270 };
271
272 solutions.push(solution);
273 all_domains.push(combined_domains);
274 }
275 }
276
277 Ok(InversionResponse::new(solutions, all_domains))
278}
279
280fn filter_literal_solutions_by_target(
286 solutions: Vec<WorldSolution>,
287 target: &Target,
288) -> Vec<WorldSolution> {
289 let mut filtered = Vec::new();
290
291 for solution in solutions {
292 let matches = match (&target.outcome, &solution.outcome) {
293 (None, _) => {
294 true
296 }
297 (Some(OperationResult::Value(target_value)), OperationResult::Value(outcome_value)) => {
298 match target.op {
303 TargetOp::Eq => outcome_value.value == target_value.value,
304 TargetOp::Neq => outcome_value.value != target_value.value,
305 TargetOp::Lt => {
306 compare_values(outcome_value, target_value)
307 == Some(std::cmp::Ordering::Less)
308 }
309 TargetOp::Lte => {
310 let cmp = compare_values(outcome_value, target_value);
311 cmp == Some(std::cmp::Ordering::Less)
312 || cmp == Some(std::cmp::Ordering::Equal)
313 }
314 TargetOp::Gt => {
315 compare_values(outcome_value, target_value)
316 == Some(std::cmp::Ordering::Greater)
317 }
318 TargetOp::Gte => {
319 let cmp = compare_values(outcome_value, target_value);
320 cmp == Some(std::cmp::Ordering::Greater)
321 || cmp == Some(std::cmp::Ordering::Equal)
322 }
323 }
324 }
325 (Some(OperationResult::Veto(target_reason)), OperationResult::Veto(outcome_reason)) => {
326 match target_reason {
327 VetoType::UserDefined { message: None } => true, VetoType::UserDefined {
329 message: Some(ref t_msg),
330 } => matches!(
331 outcome_reason,
332 VetoType::UserDefined {
333 message: Some(ref o_msg)
334 }
335 if o_msg == t_msg
336 ),
337 _ => false,
338 }
339 }
340 _ => false, };
342
343 if matches {
344 filtered.push(solution);
345 }
346 }
347
348 filtered
349}
350
351fn compare_values(a: &LiteralValue, b: &LiteralValue) -> Option<std::cmp::Ordering> {
353 match (&a.value, &b.value) {
354 (ValueKind::Number(a_val), ValueKind::Number(b_val)) => Some(a_val.cmp(b_val)),
355 (ValueKind::Ratio(a_val, _), ValueKind::Ratio(b_val, _)) => Some(a_val.cmp(b_val)),
356 (ValueKind::Quantity(a_val, sig_a), ValueKind::Quantity(b_val, sig_b))
357 if a.lemma_type.is_calendar_like() && b.lemma_type.is_calendar_like() =>
358 {
359 let unit_a =
360 crate::planning::semantics::semantic_calendar_unit_from_quantity_signature(sig_a);
361 let unit_b =
362 crate::planning::semantics::semantic_calendar_unit_from_quantity_signature(sig_b);
363 if unit_a == unit_b {
364 Some(a_val.cmp(b_val))
365 } else {
366 None
367 }
368 }
369 (ValueKind::Quantity(a_val, _), ValueKind::Quantity(b_val, _)) => Some(a_val.cmp(b_val)),
370 _ => None,
371 }
372}
373
374fn extract_data_paths_from_expression(expr: &Expression) -> Vec<DataPath> {
376 let mut set = std::collections::HashSet::new();
377 expr.collect_data_paths(&mut set);
378 set.into_iter().collect()
379}
380
381fn compute_undetermined_data(all_domains: &[HashMap<DataPath, Domain>]) -> Vec<DataPath> {
383 let mut undetermined: HashSet<DataPath> = HashSet::new();
384
385 for solution_domains in all_domains {
386 for (data_path, domain) in solution_domains {
387 let is_determined = matches!(
388 domain,
389 Domain::Enumeration(values) if values.len() == 1
390 );
391 if !is_determined {
392 undetermined.insert(data_path.clone());
393 }
394 }
395 }
396
397 let mut result: Vec<DataPath> = undetermined.into_iter().collect();
398 result.sort_by_key(|a| a.to_string());
399 result
400}
401
402fn compute_is_determined(all_domains: &[HashMap<DataPath, Domain>]) -> bool {
404 if all_domains.is_empty() {
405 return true;
406 }
407
408 for solution_domains in all_domains {
409 for domain in solution_domains.values() {
410 let is_single_value = matches!(
411 domain,
412 Domain::Enumeration(values) if values.len() == 1
413 );
414 if !is_single_value {
415 return false;
416 }
417 }
418 }
419
420 true
421}
422
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parsing::ast::DateTimeValue;
    use crate::Engine;
    use rust_decimal::Decimal;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Spec shared by the engine-level inversion tests.
    ///
    /// Rule `another` evaluates to `x`, except:
    /// - `x < 0` is vetoed with "too little";
    /// - `3 < x <= 5` is vetoed with "too much";
    /// - `x > 5` is vetoed with "way too much".
    const EXAMPLE_SPEC: &str = r#"
spec example
data x: number
rule base: x
  unless x > 3 then veto "too much"
  unless x < 0 then veto "too little"

rule another: base
  unless x > 5 then veto "way too much"
"#;

    /// Loads [`EXAMPLE_SPEC`] into a fresh engine and inverts rule `another` for `target`.
    /// Panics if loading or inversion fails.
    fn invert_another(target: Target) -> InversionResponse {
        let mut engine = Engine::new();
        engine
            .load(
                EXAMPLE_SPEC,
                crate::SourceType::Path(Arc::new(std::path::PathBuf::from("test.lemma"))),
            )
            .unwrap();
        let now = DateTimeValue::now();
        engine
            .invert("example", Some(&now), "another", target, HashMap::new())
            .expect("inversion should succeed")
    }

    /// Returns the `(min, max)` bounds of a range domain; panics for any other domain kind.
    fn range_bounds(d: &Domain) -> (&Bound, &Bound) {
        match d {
            Domain::Range { min, max } => (min, max),
            other => panic!("Expected range domain for x, got: {}", other),
        }
    }

    /// The `age` data path used by the domain-summary unit tests.
    fn age() -> DataPath {
        DataPath::new(vec![], "age".to_string())
    }

    /// Domains with `age` fixed to exactly 25 (fully determined).
    fn single_value_age_domains() -> HashMap<DataPath, Domain> {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            age(),
            Domain::Enumeration(Arc::new(vec![LiteralValue::number_from_decimal(
                Decimal::from(25),
            )])),
        );
        domain_map
    }

    /// Domains with `age` in the open range (18, +inf) (not determined).
    fn open_range_age_domains() -> HashMap<DataPath, Domain> {
        let mut domain_map = HashMap::new();
        domain_map.insert(
            age(),
            Domain::Range {
                min: Bound::Exclusive(Arc::new(LiteralValue::number_from_decimal(Decimal::from(
                    18,
                )))),
                max: Bound::Unbounded,
            },
        );
        domain_map
    }

    #[test]
    fn test_format_target_eq() {
        let target = Target::value(LiteralValue::number_from_decimal(Decimal::from(42)));
        assert_eq!(target.format(), "= 42");
    }

    #[test]
    fn test_format_target_any() {
        let target = Target::any_value();
        assert_eq!(target.format(), "= any");
    }

    #[test]
    fn test_compute_undetermined_data_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = vec![];
        assert!(compute_undetermined_data(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_single_value() {
        let domains = vec![single_value_age_domains()];
        assert!(compute_undetermined_data(&domains).is_empty());
    }

    #[test]
    fn test_compute_undetermined_data_range() {
        let domains = vec![open_range_age_domains()];
        assert_eq!(compute_undetermined_data(&domains).len(), 1);
    }

    #[test]
    fn test_compute_is_determined_empty() {
        let domains: Vec<HashMap<DataPath, Domain>> = vec![];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_true() {
        let domains = vec![single_value_age_domains()];
        assert!(compute_is_determined(&domains));
    }

    #[test]
    fn test_compute_is_determined_false() {
        let domains = vec![open_range_age_domains()];
        assert!(!compute_is_determined(&domains));
    }

    #[test]
    fn test_invert_strict_rule_reference_expands_constraints() {
        let inv = invert_another(Target::value(LiteralValue::number(3.into())));

        assert!(!inv.is_empty(), "expected at least one solution");

        let x = DataPath::new(vec![], "x".to_string());
        let three = LiteralValue::number(3.into());

        // Every solution for `another = 3` must keep 3 inside x's domain.
        for (_solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            assert!(
                d.contains(&three),
                "x domain should contain 3. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_no_solution_when_value_is_blocked_by_veto() {
        let inv = invert_another(Target::value(LiteralValue::number(7.into())));

        assert!(
            inv.is_empty(),
            "Should have no solutions because another can never equal 7"
        );
    }

    #[test]
    fn test_invert_strict_veto_target_constrains_domain() {
        let inv = invert_another(Target::veto(Some("way too much".to_string())));

        assert!(!inv.is_empty(), "expected solutions for veto query");

        let x = DataPath::new(vec![], "x".to_string());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        for (solution, domains) in inv.iter() {
            assert_eq!(
                solution.outcome,
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some("way too much".to_string()),
                }),
                "Expected solution outcome to be veto('way too much'), got: {:?}",
                solution.outcome
            );

            let d = domains.get(&x).expect("domain for x should exist");
            let (min, max) = range_bounds(d);
            assert!(
                matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                "Expected min bound to be (5), got: {}",
                d
            );
            assert!(
                matches!(max, Bound::Unbounded),
                "Expected max bound to be +inf, got: {}",
                d
            );

            assert!(
                !d.contains(&five),
                "x=5 should not be in veto('way too much') domain. Domain: {}",
                d
            );
            assert!(
                d.contains(&six),
                "x=6 should be in veto('way too much') domain. Domain: {}",
                d
            );
        }
    }

    #[test]
    fn test_invert_strict_any_veto_target_matches_all_veto_ranges() {
        let inv = invert_another(Target::any_veto());

        assert!(!inv.is_empty(), "expected solutions for any-veto query");

        let x = DataPath::new(vec![], "x".to_string());
        let minus_one = LiteralValue::number((-1).into());
        let zero = LiteralValue::number(0.into());
        let two = LiteralValue::number(2.into());
        let three = LiteralValue::number(3.into());
        let four = LiteralValue::number(4.into());
        let five = LiteralValue::number(5.into());
        let six = LiteralValue::number(6.into());

        let mut saw_too_little = false;
        let mut saw_too_much = false;
        let mut saw_way_too_much = false;

        for (solution, domains) in inv.iter() {
            let d = domains.get(&x).expect("domain for x should exist");
            // x = 2 yields a value, never a veto.
            assert!(
                !d.contains(&two),
                "x=2 should not be in any-veto domain. Domain: {}",
                d
            );

            match &solution.outcome {
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some(ref msg),
                }) if msg == "too little" => {
                    saw_too_little = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Unbounded),
                        "Expected min bound to be -inf for 'too little', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Exclusive(v) if v.as_ref() == &zero),
                        "Expected max bound to be (0) for 'too little', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&minus_one),
                        "x=-1 should be in veto('too little') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&zero),
                        "x=0 should not be in veto('too little') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some(ref msg),
                }) if msg == "too much" => {
                    saw_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &three),
                        "Expected min bound to be (3) for 'too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Inclusive(v) if v.as_ref() == &five),
                        "Expected max bound to be [5] for 'too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&four),
                        "x=4 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        d.contains(&five),
                        "x=5 should be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&three),
                        "x=3 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&six),
                        "x=6 should not be in veto('too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(VetoType::UserDefined {
                    message: Some(ref msg),
                }) if msg == "way too much" => {
                    saw_way_too_much = true;

                    let (min, max) = range_bounds(d);
                    assert!(
                        matches!(min, Bound::Exclusive(v) if v.as_ref() == &five),
                        "Expected min bound to be (5) for 'way too much', got: {}",
                        d
                    );
                    assert!(
                        matches!(max, Bound::Unbounded),
                        "Expected max bound to be +inf for 'way too much', got: {}",
                        d
                    );

                    assert!(
                        d.contains(&six),
                        "x=6 should be in veto('way too much') domain. Domain: {}",
                        d
                    );
                    assert!(
                        !d.contains(&five),
                        "x=5 should not be in veto('way too much') domain. Domain: {}",
                        d
                    );
                }
                OperationResult::Veto(other) => {
                    panic!("Unexpected veto in any-veto results: {:?}", other)
                }
                OperationResult::Value(v) => {
                    panic!("Unexpected value result in any-veto results: {:?}", v)
                }
            }
        }

        assert!(
            saw_too_little,
            "Expected at least one veto('too little') solution"
        );
        assert!(
            saw_too_much,
            "Expected at least one veto('too much') solution"
        );
        assert!(
            saw_way_too_much,
            "Expected at least one veto('way too much') solution"
        );
    }
}