1use crate::planning::semantics::{ComparisonComputation, DataPath, LiteralValue, ValueKind};
9use serde::ser::{Serialize, SerializeStruct, Serializer};
10use std::cmp::Ordering;
11use std::collections::HashMap;
12use std::fmt;
13use std::sync::Arc;
14
15use super::constraint::Constraint;
16
/// A set of values that a single data item may take, as derived from a constraint.
///
/// Several shapes can denote the same set: both `Empty` and an empty
/// `Enumeration` are used in this module to mean "no value at all".
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// The values between `min` and `max`, honouring each bound's inclusivity.
    Range { min: Bound, max: Bound },

    /// The set union of the contained domains.
    Union(Arc<Vec<Domain>>),

    /// Exactly the listed values.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is *not* contained in the inner domain.
    Complement(Box<Domain>),

    /// Any value at all.
    Unconstrained,

    /// No value at all.
    Empty,
}
38
39impl Domain {
40 pub fn is_satisfiable(&self) -> bool {
44 match self {
45 Domain::Empty => false,
46 Domain::Enumeration(values) => !values.is_empty(),
47 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
48 Domain::Range { min, max } => !bounds_contradict(min, max),
49 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
50 Domain::Unconstrained => true,
51 }
52 }
53
54 pub fn is_empty(&self) -> bool {
56 !self.is_satisfiable()
57 }
58
59 pub fn intersect(&self, other: &Domain) -> Domain {
62 match domain_intersection(self.clone(), other.clone()) {
63 Some(d) => d,
64 None => Domain::Empty,
65 }
66 }
67
68 pub fn contains(&self, value: &LiteralValue) -> bool {
70 match self {
71 Domain::Empty => false,
72 Domain::Unconstrained => true,
73 Domain::Enumeration(values) => values.contains(value),
74 Domain::Range { min, max } => value_within(value, min, max),
75 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
76 Domain::Complement(inner) => !inner.contains(value),
77 }
78 }
79}
80
/// One end of a `Domain::Range`.
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// The endpoint value itself belongs to the range.
    Inclusive(Arc<LiteralValue>),

    /// The endpoint value is excluded from the range.
    Exclusive(Arc<LiteralValue>),

    /// No limit on this side of the range.
    Unbounded,
}
93
94impl fmt::Display for Domain {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 match self {
97 Domain::Empty => write!(f, "empty"),
98 Domain::Unconstrained => write!(f, "any"),
99 Domain::Enumeration(vals) => {
100 write!(f, "{{")?;
101 for (i, v) in vals.iter().enumerate() {
102 if i > 0 {
103 write!(f, ", ")?;
104 }
105 write!(f, "{}", v)?;
106 }
107 write!(f, "}}")
108 }
109 Domain::Range { min, max } => {
110 let (l_bracket, r_bracket) = match (min, max) {
111 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
112 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
113 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
114 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
115 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
116 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
117 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
118 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
119 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
120 };
121
122 let min_str = match min {
123 Bound::Unbounded => "-inf".to_string(),
124 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
125 };
126 let max_str = match max {
127 Bound::Unbounded => "+inf".to_string(),
128 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
129 };
130 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
131 }
132 Domain::Union(parts) => {
133 for (i, p) in parts.iter().enumerate() {
134 if i > 0 {
135 write!(f, " | ")?;
136 }
137 write!(f, "{}", p)?;
138 }
139 Ok(())
140 }
141 Domain::Complement(inner) => write!(f, "not ({})", inner),
142 }
143 }
144}
145
146impl fmt::Display for Bound {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 match self {
149 Bound::Unbounded => write!(f, "inf"),
150 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
151 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
152 }
153 }
154}
155
156impl Serialize for Domain {
157 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
158 where
159 S: Serializer,
160 {
161 match self {
162 Domain::Empty => {
163 let mut st = serializer.serialize_struct("domain", 1)?;
164 st.serialize_field("type", "empty")?;
165 st.end()
166 }
167 Domain::Unconstrained => {
168 let mut st = serializer.serialize_struct("domain", 1)?;
169 st.serialize_field("type", "unconstrained")?;
170 st.end()
171 }
172 Domain::Enumeration(vals) => {
173 let mut st = serializer.serialize_struct("domain", 2)?;
174 st.serialize_field("type", "enumeration")?;
175 st.serialize_field("values", vals)?;
176 st.end()
177 }
178 Domain::Range { min, max } => {
179 let mut st = serializer.serialize_struct("domain", 3)?;
180 st.serialize_field("type", "range")?;
181 st.serialize_field("min", min)?;
182 st.serialize_field("max", max)?;
183 st.end()
184 }
185 Domain::Union(parts) => {
186 let mut st = serializer.serialize_struct("domain", 2)?;
187 st.serialize_field("type", "union")?;
188 st.serialize_field("parts", parts)?;
189 st.end()
190 }
191 Domain::Complement(inner) => {
192 let mut st = serializer.serialize_struct("domain", 2)?;
193 st.serialize_field("type", "complement")?;
194 st.serialize_field("inner", inner)?;
195 st.end()
196 }
197 }
198 }
199}
200
201impl Serialize for Bound {
202 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
203 where
204 S: Serializer,
205 {
206 match self {
207 Bound::Unbounded => {
208 let mut st = serializer.serialize_struct("bound", 1)?;
209 st.serialize_field("type", "unbounded")?;
210 st.end()
211 }
212 Bound::Inclusive(v) => {
213 let mut st = serializer.serialize_struct("bound", 2)?;
214 st.serialize_field("type", "inclusive")?;
215 st.serialize_field("value", v.as_ref())?;
216 st.end()
217 }
218 Bound::Exclusive(v) => {
219 let mut st = serializer.serialize_struct("bound", 2)?;
220 st.serialize_field("type", "exclusive")?;
221 st.serialize_field("value", v.as_ref())?;
222 st.end()
223 }
224 }
225 }
226}
227
228pub fn extract_domains_from_constraint(
230 constraint: &Constraint,
231) -> Result<HashMap<DataPath, Domain>, crate::Error> {
232 let all_datas = constraint.collect_data();
233 let mut domains = HashMap::new();
234
235 for data_path in all_datas {
236 let domain =
240 extract_domain_for_data(constraint, &data_path)?.unwrap_or(Domain::Unconstrained);
241 domains.insert(data_path, domain);
242 }
243
244 Ok(domains)
245}
246
247fn extract_domain_for_data(
248 constraint: &Constraint,
249 data_path: &DataPath,
250) -> Result<Option<Domain>, crate::Error> {
251 let domain = match constraint {
252 Constraint::True => return Ok(None),
253 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
254
255 Constraint::Comparison { data, op, value } => {
256 if data == data_path {
257 Some(comparison_to_domain(op, value.as_ref())?)
258 } else {
259 None
260 }
261 }
262
263 Constraint::Data(fp) => {
264 if fp == data_path {
265 Some(Domain::Enumeration(Arc::new(vec![
266 LiteralValue::from_bool(true),
267 ])))
268 } else {
269 None
270 }
271 }
272
273 Constraint::And(left, right) => {
274 let left_domain = extract_domain_for_data(left, data_path)?;
275 let right_domain = extract_domain_for_data(right, data_path)?;
276 match (left_domain, right_domain) {
277 (None, None) => None,
278 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
279 (Some(a), Some(b)) => match domain_intersection(a, b) {
280 Some(domain) => Some(domain),
281 None => Some(Domain::Enumeration(Arc::new(vec![]))),
282 },
283 }
284 }
285
286 Constraint::Or(left, right) => {
287 let left_domain = extract_domain_for_data(left, data_path)?;
288 let right_domain = extract_domain_for_data(right, data_path)?;
289 union_optional_domains(left_domain, right_domain)
290 }
291
292 Constraint::Not(inner) => {
293 if let Constraint::Comparison { data, op, value } = inner.as_ref() {
295 if data == data_path && op.is_equal() {
296 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
297 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
298 )))));
299 }
300 }
301
302 if let Constraint::Data(fp) = inner.as_ref() {
304 if fp == data_path {
305 return Ok(Some(Domain::Enumeration(Arc::new(vec![
306 LiteralValue::from_bool(false),
307 ]))));
308 }
309 }
310
311 extract_domain_for_data(inner, data_path)?
312 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
313 }
314 };
315
316 Ok(domain.map(normalize_domain))
317}
318
319fn comparison_to_domain(
320 op: &ComparisonComputation,
321 value: &LiteralValue,
322) -> Result<Domain, crate::Error> {
323 if op.is_equal() {
324 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
325 }
326 if op.is_not_equal() {
327 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
328 vec![value.clone()],
329 )))));
330 }
331 match op {
332 ComparisonComputation::LessThan => Ok(Domain::Range {
333 min: Bound::Unbounded,
334 max: Bound::Exclusive(Arc::new(value.clone())),
335 }),
336 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
337 min: Bound::Unbounded,
338 max: Bound::Inclusive(Arc::new(value.clone())),
339 }),
340 ComparisonComputation::GreaterThan => Ok(Domain::Range {
341 min: Bound::Exclusive(Arc::new(value.clone())),
342 max: Bound::Unbounded,
343 }),
344 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
345 min: Bound::Inclusive(Arc::new(value.clone())),
346 max: Bound::Unbounded,
347 }),
348 _ => unreachable!(
349 "BUG: unsupported comparison operator for domain extraction: {:?}",
350 op
351 ),
352 }
353}
354
355pub(crate) fn domain_for_comparison_atom(
360 op: &ComparisonComputation,
361 value: &LiteralValue,
362) -> Result<Domain, crate::Error> {
363 comparison_to_domain(op, value)
364}
365
366impl Domain {
367 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
374 match (self, other) {
375 (Domain::Empty, _) => true,
376 (_, Domain::Unconstrained) => true,
377 (Domain::Unconstrained, _) => false,
378
379 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
380 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
381 vals.iter().all(|v| value_within(v, min, max))
382 }
383
384 (
385 Domain::Range {
386 min: amin,
387 max: amax,
388 },
389 Domain::Range {
390 min: bmin,
391 max: bmax,
392 },
393 ) => range_within_range(amin, amax, bmin, bmax),
394
395 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
397 Domain::Enumeration(excluded) => {
398 excluded.iter().all(|p| !value_within(p, min, max))
399 }
400 _ => false,
401 },
402
403 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
405 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
406 _ => false,
407 },
408
409 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
411 match (a_inner.as_ref(), b_inner.as_ref()) {
412 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
413 excluded_b.iter().all(|v| excluded_a.contains(v))
414 }
415 _ => false,
416 }
417 }
418
419 _ => false,
420 }
421 }
422}
423
424fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
425 lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
426}
427
428fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
429 match (a, b) {
430 (_, Bound::Unbounded) => true,
431 (Bound::Unbounded, _) => false,
432 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
433 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
434 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
435 let c = lit_cmp(av.as_ref(), bv.as_ref());
436 c >= 0
437 }
438 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
439 lit_cmp(av.as_ref(), bv.as_ref()) > 0
441 }
442 }
443}
444
445fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
446 match (a, b) {
447 (Bound::Unbounded, Bound::Unbounded) => true,
448 (_, Bound::Unbounded) => true,
449 (Bound::Unbounded, _) => false,
450 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
451 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
452 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
453 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
455 }
456 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
457 lit_cmp(av.as_ref(), bv.as_ref()) < 0
459 }
460 }
461}
462
463fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
464 match (a, b) {
465 (None, None) => None,
466 (Some(d), None) | (None, Some(d)) => Some(d),
467 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
468 }
469}
470
471fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
472 use std::cmp::Ordering;
473
474 match (&a.value, &b.value) {
475 (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
476 Ordering::Less => -1,
477 Ordering::Equal => 0,
478 Ordering::Greater => 1,
479 },
480
481 (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
482 Ordering::Less => -1,
483 Ordering::Equal => 0,
484 Ordering::Greater => 1,
485 },
486
487 (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
488 Ordering::Less => -1,
489 Ordering::Equal => 0,
490 Ordering::Greater => 1,
491 },
492
493 (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
494 Ordering::Less => -1,
495 Ordering::Equal => 0,
496 Ordering::Greater => 1,
497 },
498
499 (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
500 Ordering::Less => -1,
501 Ordering::Equal => 0,
502 Ordering::Greater => 1,
503 },
504
505 (ValueKind::Quantity(la, sig_a), ValueKind::Quantity(lb, sig_b))
506 if a.lemma_type.is_calendar_like() && b.lemma_type.is_calendar_like() =>
507 {
508 let lu =
509 crate::planning::semantics::semantic_calendar_unit_from_quantity_signature(sig_a);
510 let lu2 =
511 crate::planning::semantics::semantic_calendar_unit_from_quantity_signature(sig_b);
512 let a_months = crate::computation::units::convert_calendar_magnitude(
513 *la,
514 &lu,
515 &crate::planning::semantics::SemanticCalendarUnit::Month,
516 );
517 let b_months = crate::computation::units::convert_calendar_magnitude(
518 *lb,
519 &lu2,
520 &crate::planning::semantics::SemanticCalendarUnit::Month,
521 );
522 match a_months.cmp(&b_months) {
523 Ordering::Less => -1,
524 Ordering::Equal => 0,
525 Ordering::Greater => 1,
526 }
527 }
528
529 (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
530 Ordering::Less => -1,
531 Ordering::Equal => 0,
532 Ordering::Greater => 1,
533 },
534
535 (ValueKind::Quantity(la, _), ValueKind::Quantity(lb, _)) => {
536 let a_decomp = a
537 .lemma_type
538 .quantity_type_decomposition()
539 .expect("BUG: decomposition must be resolved after planning");
540 let b_decomp = b
541 .lemma_type
542 .quantity_type_decomposition()
543 .expect("BUG: decomposition must be resolved after planning");
544 if a_decomp != b_decomp {
545 unreachable!(
546 "BUG: lit_cmp compared quantities with different decompositions ({:?} vs {:?})",
547 a_decomp, b_decomp
548 );
549 }
550 match la.cmp(lb) {
551 Ordering::Less => -1,
552 Ordering::Equal => 0,
553 Ordering::Greater => 1,
554 }
555 }
556
557 _ => unreachable!(
558 "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
559 a.get_type(),
560 b.get_type()
561 ),
562 }
563}
564
565fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
566 let ge_min = match min {
567 Bound::Unbounded => true,
568 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
569 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
570 };
571 let le_max = match max {
572 Bound::Unbounded => true,
573 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
574 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
575 };
576 ge_min && le_max
577}
578
579fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
580 match (min, max) {
581 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
582 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
583 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
584 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
585 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
586 }
587}
588
589fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
590 match (min1, min2) {
591 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
592 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
593 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
594 Bound::Inclusive(v1)
595 } else {
596 Bound::Inclusive(v2)
597 }
598 }
599 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
600 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
601 Bound::Inclusive(v1)
602 } else {
603 Bound::Exclusive(v2)
604 }
605 }
606 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
607 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
608 Bound::Exclusive(v1)
609 } else {
610 Bound::Inclusive(v2)
611 }
612 }
613 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
614 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
615 Bound::Exclusive(v1)
616 } else {
617 Bound::Exclusive(v2)
618 }
619 }
620 }
621}
622
623fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
624 match (max1, max2) {
625 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
626 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
627 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
628 Bound::Inclusive(v1)
629 } else {
630 Bound::Inclusive(v2)
631 }
632 }
633 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
634 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
635 Bound::Inclusive(v1)
636 } else {
637 Bound::Exclusive(v2)
638 }
639 }
640 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
641 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
642 Bound::Exclusive(v1)
643 } else {
644 Bound::Inclusive(v2)
645 }
646 }
647 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
648 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
649 Bound::Exclusive(v1)
650 } else {
651 Bound::Exclusive(v2)
652 }
653 }
654 }
655}
656
657fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
658 let a = normalize_domain(a);
659 let b = normalize_domain(b);
660
661 let result = match (a, b) {
662 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
663 (Domain::Empty, _) | (_, Domain::Empty) => None,
664
665 (
666 Domain::Range {
667 min: min1,
668 max: max1,
669 },
670 Domain::Range {
671 min: min2,
672 max: max2,
673 },
674 ) => {
675 let min = compute_intersection_min(min1, min2);
676 let max = compute_intersection_max(max1, max2);
677
678 if bounds_contradict(&min, &max) {
679 None
680 } else {
681 Some(Domain::Range { min, max })
682 }
683 }
684 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
685 let filtered: Vec<LiteralValue> =
686 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
687 if filtered.is_empty() {
688 None
689 } else {
690 Some(Domain::Enumeration(Arc::new(filtered)))
691 }
692 }
693 (Domain::Enumeration(vs), Domain::Range { min, max })
694 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
695 let mut kept = Vec::new();
696 for v in vs.iter() {
697 if value_within(v, &min, &max) {
698 kept.push(v.clone());
699 }
700 }
701 if kept.is_empty() {
702 None
703 } else {
704 Some(Domain::Enumeration(Arc::new(kept)))
705 }
706 }
707 (Domain::Enumeration(vs), Domain::Complement(inner))
708 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
709 match *inner.clone() {
710 Domain::Enumeration(excluded) => {
711 let mut kept = Vec::new();
712 for v in vs.iter() {
713 if !excluded.contains(v) {
714 kept.push(v.clone());
715 }
716 }
717 if kept.is_empty() {
718 None
719 } else {
720 Some(Domain::Enumeration(Arc::new(kept)))
721 }
722 }
723 Domain::Range { min, max } => {
724 let mut kept = Vec::new();
726 for v in vs.iter() {
727 if !value_within(v, &min, &max) {
728 kept.push(v.clone());
729 }
730 }
731 if kept.is_empty() {
732 None
733 } else {
734 Some(Domain::Enumeration(Arc::new(kept)))
735 }
736 }
737 _ => {
738 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
740 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
741 }
742 }
743 }
744 (Domain::Union(v1), Domain::Union(v2)) => {
745 let mut acc: Vec<Domain> = Vec::new();
746 for a in v1.iter() {
747 for b in v2.iter() {
748 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
749 acc.push(ix);
750 }
751 }
752 }
753 if acc.is_empty() {
754 None
755 } else {
756 Some(Domain::Union(Arc::new(acc)))
757 }
758 }
759 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
760 let mut acc: Vec<Domain> = Vec::new();
761 for a in vs.iter() {
762 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
763 acc.push(ix);
764 }
765 }
766 if acc.is_empty() {
767 None
768 } else if acc.len() == 1 {
769 Some(acc.remove(0))
770 } else {
771 Some(Domain::Union(Arc::new(acc)))
772 }
773 }
774 (Domain::Range { min, max }, Domain::Complement(inner))
776 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
777 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
778 _ => {
779 let normalized_complement = normalize_domain(Domain::Complement(inner));
782 if matches!(&normalized_complement, Domain::Complement(_)) {
783 None
784 } else {
785 domain_intersection(Domain::Range { min, max }, normalized_complement)
786 }
787 }
788 },
789 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
790 match (a_inner.as_ref(), b_inner.as_ref()) {
791 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
792 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
794 excluded.extend(b_ex.iter().cloned());
795 Some(normalize_domain(Domain::Complement(Box::new(
796 Domain::Enumeration(Arc::new(excluded)),
797 ))))
798 }
799 _ => None,
800 }
801 }
802 };
803 result.map(normalize_domain)
804}
805
806fn range_minus_excluded_points(
807 min: Bound,
808 max: Bound,
809 excluded: &Arc<Vec<LiteralValue>>,
810) -> Option<Domain> {
811 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
813
814 for p in excluded.iter() {
815 let mut next: Vec<(Bound, Bound)> = Vec::new();
816
817 for (rmin, rmax) in parts {
818 if !value_within(p, &rmin, &rmax) {
819 next.push((rmin, rmax));
820 continue;
821 }
822
823 let left_max = Bound::Exclusive(Arc::new(p.clone()));
825 if !bounds_contradict(&rmin, &left_max) {
826 next.push((rmin.clone(), left_max));
827 }
828
829 let right_min = Bound::Exclusive(Arc::new(p.clone()));
831 if !bounds_contradict(&right_min, &rmax) {
832 next.push((right_min, rmax.clone()));
833 }
834 }
835
836 parts = next;
837 if parts.is_empty() {
838 return None;
839 }
840 }
841
842 if parts.is_empty() {
843 None
844 } else if parts.len() == 1 {
845 let (min, max) = parts.remove(0);
846 Some(Domain::Range { min, max })
847 } else {
848 Some(Domain::Union(Arc::new(
849 parts
850 .into_iter()
851 .map(|(min, max)| Domain::Range { min, max })
852 .collect(),
853 )))
854 }
855}
856
857fn invert_bound(bound: Bound) -> Bound {
858 match bound {
859 Bound::Unbounded => Bound::Unbounded,
860 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
861 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
862 }
863}
864
865fn normalize_domain(d: Domain) -> Domain {
866 match d {
867 Domain::Complement(inner) => {
868 let normalized_inner = normalize_domain(*inner);
869 match normalized_inner {
870 Domain::Complement(double_inner) => *double_inner,
871 Domain::Range { min, max } => match (&min, &max) {
872 (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
873 (Bound::Unbounded, max) => Domain::Range {
874 min: invert_bound(max.clone()),
875 max: Bound::Unbounded,
876 },
877 (min, Bound::Unbounded) => Domain::Range {
878 min: Bound::Unbounded,
879 max: invert_bound(min.clone()),
880 },
881 (min, max) => Domain::Union(Arc::new(vec![
882 Domain::Range {
883 min: Bound::Unbounded,
884 max: invert_bound(min.clone()),
885 },
886 Domain::Range {
887 min: invert_bound(max.clone()),
888 max: Bound::Unbounded,
889 },
890 ])),
891 },
892 Domain::Enumeration(vals) => {
893 if vals.len() == 1 {
894 if let Some(lit) = vals.first() {
895 if let ValueKind::Boolean(true) = &lit.value {
896 return Domain::Enumeration(Arc::new(vec![
897 LiteralValue::from_bool(false),
898 ]));
899 }
900 if let ValueKind::Boolean(false) = &lit.value {
901 return Domain::Enumeration(Arc::new(vec![
902 LiteralValue::from_bool(true),
903 ]));
904 }
905 }
906 }
907 Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
908 }
909 Domain::Unconstrained => Domain::Empty,
910 Domain::Empty => Domain::Unconstrained,
911 Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
912 }
913 }
914 Domain::Empty => Domain::Empty,
915 Domain::Union(parts) => {
916 let mut flat: Vec<Domain> = Vec::new();
917 for p in parts.iter().cloned() {
918 let normalized = normalize_domain(p);
919 match normalized {
920 Domain::Union(inner) => flat.extend(inner.iter().cloned()),
921 Domain::Unconstrained => return Domain::Unconstrained,
922 Domain::Enumeration(vals) if vals.is_empty() => {}
923 other => flat.push(other),
924 }
925 }
926
927 let mut all_enum_values: Vec<LiteralValue> = Vec::new();
928 let mut ranges: Vec<Domain> = Vec::new();
929 let mut others: Vec<Domain> = Vec::new();
930
931 for domain in flat {
932 match domain {
933 Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
934 Domain::Range { .. } => ranges.push(domain),
935 other => others.push(other),
936 }
937 }
938
939 all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
940 -1 => Ordering::Less,
941 0 => Ordering::Equal,
942 _ => Ordering::Greater,
943 });
944 all_enum_values.dedup();
945
946 all_enum_values.retain(|v| {
947 !ranges.iter().any(|r| {
948 if let Domain::Range { min, max } = r {
949 value_within(v, min, max)
950 } else {
951 false
952 }
953 })
954 });
955
956 let mut result: Vec<Domain> = Vec::new();
957 result.extend(ranges);
958 result = merge_ranges(result);
959
960 if !all_enum_values.is_empty() {
961 result.push(Domain::Enumeration(Arc::new(all_enum_values)));
962 }
963 result.extend(others);
964
965 result.sort_by(|a, b| match (a, b) {
966 (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
967 (Domain::Range { .. }, _) => Ordering::Less,
968 (_, Domain::Range { .. }) => Ordering::Greater,
969 (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
970 (Domain::Enumeration(_), _) => Ordering::Less,
971 (_, Domain::Enumeration(_)) => Ordering::Greater,
972 _ => Ordering::Equal,
973 });
974
975 if result.is_empty() {
976 Domain::Enumeration(Arc::new(vec![]))
977 } else if result.len() == 1 {
978 result.remove(0)
979 } else {
980 Domain::Union(Arc::new(result))
981 }
982 }
983 Domain::Enumeration(values) => {
984 let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
985 sorted.sort_by(|a, b| match lit_cmp(a, b) {
986 -1 => Ordering::Less,
987 0 => Ordering::Equal,
988 _ => Ordering::Greater,
989 });
990 sorted.dedup();
991 Domain::Enumeration(Arc::new(sorted))
992 }
993 other => other,
994 }
995}
996
997fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
998 let mut result = Vec::new();
999 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1000 let mut others = Vec::new();
1001
1002 for d in domains {
1003 match d {
1004 Domain::Range { min, max } => ranges.push((min, max)),
1005 other => others.push(other),
1006 }
1007 }
1008
1009 if ranges.is_empty() {
1010 return others;
1011 }
1012
1013 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1014
1015 let mut merged: Vec<(Bound, Bound)> = Vec::new();
1016 let mut current = ranges[0].clone();
1017
1018 for next in ranges.iter().skip(1) {
1019 if ranges_adjacent_or_overlap(¤t, next) {
1020 current = (
1021 min_bound(¤t.0, &next.0),
1022 max_bound(¤t.1, &next.1),
1023 );
1024 } else {
1025 merged.push(current);
1026 current = next.clone();
1027 }
1028 }
1029 merged.push(current);
1030
1031 for (min, max) in merged {
1032 result.push(Domain::Range { min, max });
1033 }
1034 result.extend(others);
1035
1036 result
1037}
1038
1039fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1040 match (a, b) {
1041 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1042 (Bound::Unbounded, _) => Ordering::Less,
1043 (_, Bound::Unbounded) => Ordering::Greater,
1044 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1045 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1046 -1 => Ordering::Less,
1047 0 => Ordering::Equal,
1048 _ => Ordering::Greater,
1049 },
1050 (Bound::Inclusive(v1), Bound::Exclusive(v2))
1051 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1052 -1 => Ordering::Less,
1053 0 => {
1054 if matches!(a, Bound::Inclusive(_)) {
1055 Ordering::Less
1056 } else {
1057 Ordering::Greater
1058 }
1059 }
1060 _ => Ordering::Greater,
1061 },
1062 }
1063}
1064
1065fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1066 match (&r1.1, &r2.0) {
1067 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1068 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1069 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1070 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1071 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1072 }
1073}
1074
1075fn min_bound(a: &Bound, b: &Bound) -> Bound {
1076 match (a, b) {
1077 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1078 _ => {
1079 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1080 a.clone()
1081 } else {
1082 b.clone()
1083 }
1084 }
1085 }
1086}
1087
1088fn max_bound(a: &Bound, b: &Bound) -> Bound {
1089 match (a, b) {
1090 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1091 _ => {
1092 if matches!(compare_bounds(a, b), Ordering::Greater) {
1093 a.clone()
1094 } else {
1095 b.clone()
1096 }
1097 }
1098 }
1099}
1100
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer literal helper: builds the number `n / 1`.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(crate::computation::rational::RationalInteger::new(
            n as i128, 1,
        ))
    }

    /// Builds a `DataPath` named `name`, passing an empty vector as the first
    /// `DataPath::new` argument.
    fn data(name: &str) -> DataPath {
        DataPath::new(vec![], name.to_string())
    }

    // not (not {5}) normalizes back to {5}.
    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let double = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        let normalized = normalize_domain(double);
        assert_eq!(normalized, inner);
    }

    // A union that contains `Unconstrained` collapses to `Unconstrained`.
    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(0))),
                max: Bound::Inclusive(Arc::new(num(10))),
            },
            Domain::Unconstrained,
        ]));
        let normalized = normalize_domain(union);
        assert_eq!(normalized, Domain::Unconstrained);
    }

    // Ranges render in interval notation, enumerations in braces.
    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: Bound::Inclusive(Arc::new(num(10))),
            max: Bound::Exclusive(Arc::new(num(20))),
        };
        assert_eq!(format!("{}", range), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(format!("{}", enumeration), "{1, 2, 3}");
    }

    // `age > 18` yields the range (18, +inf).
    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = Constraint::Comparison {
            data: data("age"),
            op: ComparisonComputation::GreaterThan,
            value: Arc::new(num(18)),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&data("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Unbounded,
            }
        );
    }

    // `age > 18 and age < 65` intersects to (18, 65).
    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(Constraint::Comparison {
                data: data("age"),
                op: ComparisonComputation::GreaterThan,
                value: Arc::new(num(18)),
            }),
            Box::new(Constraint::Comparison {
                data: data("age"),
                op: ComparisonComputation::LessThan,
                value: Arc::new(num(65)),
            }),
        );

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&data("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Exclusive(Arc::new(num(65))),
            }
        );
    }

    // Equality with a text value yields a one-element enumeration.
    #[test]
    fn test_extract_domain_from_equality() {
        let constraint = Constraint::Comparison {
            data: data("status"),
            op: ComparisonComputation::Is,
            value: Arc::new(LiteralValue::text("active".to_string())),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let status_domain = domains.get(&data("status")).unwrap();

        assert_eq!(
            *status_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::text("active".to_string())]))
        );
    }

    // A bare boolean data reference restricts the data to `true`.
    #[test]
    fn test_extract_domain_from_boolean_data() {
        let constraint = Constraint::Data(data("is_active"));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&data("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    // Negating a boolean data reference restricts the data to `false`.
    #[test]
    fn test_extract_domain_from_not_boolean_data() {
        let constraint = Constraint::Not(Box::new(Constraint::Data(data("is_active"))));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&data("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }

    // End-to-end check: evaluates a small spec, then verifies that `lit_cmp`
    // orders the two resulting quantities (4800 eur vs 1000 eur) correctly.
    #[test]
    fn inversion_lit_cmp_uses_signature_factor_for_compatible_dimensions() {
        use crate::engine::Engine;
        use crate::parsing::source::SourceType;
        use std::collections::HashMap;
        use std::path::PathBuf;
        use std::sync::Arc as StdArc;

        let code = r#"spec lit_cmp_test
uses lemma units
data money: quantity
    -> unit eur 1
data rate: quantity
    -> unit eur_per_minute eur/minute
data r: 40 eur_per_minute
data h: 2 hour
rule compound_cost: (r * h) as eur
rule fixed_cost: 1000 eur
"#;
        let mut engine = Engine::new();
        engine
            .load(
                code,
                SourceType::Path(StdArc::new(PathBuf::from("t.lemma"))),
            )
            .expect("must load");
        let response = engine
            .run(None, "lit_cmp_test", None, HashMap::new(), true)
            .expect("must eval");
        let compound = response
            .results
            .get("compound_cost")
            .unwrap()
            .trace
            .as_ref()
            .expect("trace")
            .result
            .value()
            .unwrap()
            .clone();
        let fixed = response
            .results
            .get("fixed_cost")
            .unwrap()
            .trace
            .as_ref()
            .expect("trace")
            .result
            .value()
            .unwrap()
            .clone();
        assert!(
            lit_cmp(&compound, &fixed) > 0,
            "compound (4800 eur) must be greater than fixed (1000 eur); got {}",
            lit_cmp(&compound, &fixed)
        );
    }
}