1use crate::planning::semantics::{
9 ComparisonComputation, FactPath, LiteralValue, SemanticConversionTarget, ValueKind,
10};
11use crate::OperationResult;
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
/// The set of values a single fact is allowed to take.
///
/// Values of this type are produced from a `Constraint` and combined with the
/// set-like helpers in this file (intersection, union normalization, subset tests).
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// All values lying between the `min` and `max` bounds.
    Range { min: Bound, max: Bound },

    /// All values that belong to at least one of the listed sub-domains.
    Union(Arc<Vec<Domain>>),

    /// An explicit list of admissible values; an empty list admits nothing.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is *not* a member of the wrapped domain.
    Complement(Box<Domain>),

    /// No restriction: every value is admissible.
    Unconstrained,

    /// No value is admissible.
    Empty,
}
41
42impl Domain {
43 pub fn is_satisfiable(&self) -> bool {
47 match self {
48 Domain::Empty => false,
49 Domain::Enumeration(values) => !values.is_empty(),
50 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51 Domain::Range { min, max } => !bounds_contradict(min, max),
52 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53 Domain::Unconstrained => true,
54 }
55 }
56
57 pub fn is_empty(&self) -> bool {
59 !self.is_satisfiable()
60 }
61
62 pub fn intersect(&self, other: &Domain) -> Domain {
65 match domain_intersection(self.clone(), other.clone()) {
66 Some(d) => d,
67 None => Domain::Empty,
68 }
69 }
70
71 pub fn contains(&self, value: &LiteralValue) -> bool {
73 match self {
74 Domain::Empty => false,
75 Domain::Unconstrained => true,
76 Domain::Enumeration(values) => values.contains(value),
77 Domain::Range { min, max } => value_within(value, min, max),
78 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
79 Domain::Complement(inner) => !inner.contains(value),
80 }
81 }
82}
83
/// One endpoint of a [`Domain::Range`].
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// The endpoint value itself belongs to the range.
    Inclusive(Arc<LiteralValue>),

    /// The endpoint value itself is excluded from the range.
    Exclusive(Arc<LiteralValue>),

    /// The range has no limit on this side.
    Unbounded,
}
96
97impl fmt::Display for Domain {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 match self {
100 Domain::Empty => write!(f, "empty"),
101 Domain::Unconstrained => write!(f, "any"),
102 Domain::Enumeration(vals) => {
103 write!(f, "{{")?;
104 for (i, v) in vals.iter().enumerate() {
105 if i > 0 {
106 write!(f, ", ")?;
107 }
108 write!(f, "{}", v)?;
109 }
110 write!(f, "}}")
111 }
112 Domain::Range { min, max } => {
113 let (l_bracket, r_bracket) = match (min, max) {
114 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
115 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
116 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
117 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
118 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
119 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
120 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
121 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
122 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
123 };
124
125 let min_str = match min {
126 Bound::Unbounded => "-inf".to_string(),
127 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128 };
129 let max_str = match max {
130 Bound::Unbounded => "+inf".to_string(),
131 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
132 };
133 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
134 }
135 Domain::Union(parts) => {
136 for (i, p) in parts.iter().enumerate() {
137 if i > 0 {
138 write!(f, " | ")?;
139 }
140 write!(f, "{}", p)?;
141 }
142 Ok(())
143 }
144 Domain::Complement(inner) => write!(f, "not ({})", inner),
145 }
146 }
147}
148
149impl fmt::Display for Bound {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 match self {
152 Bound::Unbounded => write!(f, "inf"),
153 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
154 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
155 }
156 }
157}
158
159impl Serialize for Domain {
160 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161 where
162 S: Serializer,
163 {
164 match self {
165 Domain::Empty => {
166 let mut st = serializer.serialize_struct("domain", 1)?;
167 st.serialize_field("type", "empty")?;
168 st.end()
169 }
170 Domain::Unconstrained => {
171 let mut st = serializer.serialize_struct("domain", 1)?;
172 st.serialize_field("type", "unconstrained")?;
173 st.end()
174 }
175 Domain::Enumeration(vals) => {
176 let mut st = serializer.serialize_struct("domain", 2)?;
177 st.serialize_field("type", "enumeration")?;
178 st.serialize_field("values", vals)?;
179 st.end()
180 }
181 Domain::Range { min, max } => {
182 let mut st = serializer.serialize_struct("domain", 3)?;
183 st.serialize_field("type", "range")?;
184 st.serialize_field("min", min)?;
185 st.serialize_field("max", max)?;
186 st.end()
187 }
188 Domain::Union(parts) => {
189 let mut st = serializer.serialize_struct("domain", 2)?;
190 st.serialize_field("type", "union")?;
191 st.serialize_field("parts", parts)?;
192 st.end()
193 }
194 Domain::Complement(inner) => {
195 let mut st = serializer.serialize_struct("domain", 2)?;
196 st.serialize_field("type", "complement")?;
197 st.serialize_field("inner", inner)?;
198 st.end()
199 }
200 }
201 }
202}
203
204impl Serialize for Bound {
205 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: Serializer,
208 {
209 match self {
210 Bound::Unbounded => {
211 let mut st = serializer.serialize_struct("bound", 1)?;
212 st.serialize_field("type", "unbounded")?;
213 st.end()
214 }
215 Bound::Inclusive(v) => {
216 let mut st = serializer.serialize_struct("bound", 2)?;
217 st.serialize_field("type", "inclusive")?;
218 st.serialize_field("value", v.as_ref())?;
219 st.end()
220 }
221 Bound::Exclusive(v) => {
222 let mut st = serializer.serialize_struct("bound", 2)?;
223 st.serialize_field("type", "exclusive")?;
224 st.serialize_field("value", v.as_ref())?;
225 st.end()
226 }
227 }
228 }
229}
230
231pub fn extract_domains_from_constraint(
233 constraint: &Constraint,
234) -> Result<HashMap<FactPath, Domain>, crate::Error> {
235 let all_facts = constraint.collect_facts();
236 let mut domains = HashMap::new();
237
238 for fact_path in all_facts {
239 let domain =
243 extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
244 domains.insert(fact_path, domain);
245 }
246
247 Ok(domains)
248}
249
250fn extract_domain_for_fact(
251 constraint: &Constraint,
252 fact_path: &FactPath,
253) -> Result<Option<Domain>, crate::Error> {
254 let domain = match constraint {
255 Constraint::True => return Ok(None),
256 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
257
258 Constraint::Comparison { fact, op, value } => {
259 if fact == fact_path {
260 Some(comparison_to_domain(op, value.as_ref())?)
261 } else {
262 None
263 }
264 }
265
266 Constraint::Fact(fp) => {
267 if fp == fact_path {
268 Some(Domain::Enumeration(Arc::new(vec![
269 LiteralValue::from_bool(true),
270 ])))
271 } else {
272 None
273 }
274 }
275
276 Constraint::And(left, right) => {
277 let left_domain = extract_domain_for_fact(left, fact_path)?;
278 let right_domain = extract_domain_for_fact(right, fact_path)?;
279 match (left_domain, right_domain) {
280 (None, None) => None,
281 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
282 (Some(a), Some(b)) => match domain_intersection(a, b) {
283 Some(domain) => Some(domain),
284 None => Some(Domain::Enumeration(Arc::new(vec![]))),
285 },
286 }
287 }
288
289 Constraint::Or(left, right) => {
290 let left_domain = extract_domain_for_fact(left, fact_path)?;
291 let right_domain = extract_domain_for_fact(right, fact_path)?;
292 union_optional_domains(left_domain, right_domain)
293 }
294
295 Constraint::Not(inner) => {
296 if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
298 if fact == fact_path && op.is_equal() {
299 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
300 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
301 )))));
302 }
303 }
304
305 if let Constraint::Fact(fp) = inner.as_ref() {
307 if fp == fact_path {
308 return Ok(Some(Domain::Enumeration(Arc::new(vec![
309 LiteralValue::from_bool(false),
310 ]))));
311 }
312 }
313
314 extract_domain_for_fact(inner, fact_path)?
315 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
316 }
317 };
318
319 Ok(domain.map(normalize_domain))
320}
321
322fn comparison_to_domain(
323 op: &ComparisonComputation,
324 value: &LiteralValue,
325) -> Result<Domain, crate::Error> {
326 if op.is_equal() {
327 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
328 }
329 if op.is_not_equal() {
330 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
331 vec![value.clone()],
332 )))));
333 }
334 match op {
335 ComparisonComputation::LessThan => Ok(Domain::Range {
336 min: Bound::Unbounded,
337 max: Bound::Exclusive(Arc::new(value.clone())),
338 }),
339 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
340 min: Bound::Unbounded,
341 max: Bound::Inclusive(Arc::new(value.clone())),
342 }),
343 ComparisonComputation::GreaterThan => Ok(Domain::Range {
344 min: Bound::Exclusive(Arc::new(value.clone())),
345 max: Bound::Unbounded,
346 }),
347 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
348 min: Bound::Inclusive(Arc::new(value.clone())),
349 max: Bound::Unbounded,
350 }),
351 _ => unreachable!(
352 "BUG: unsupported comparison operator for domain extraction: {:?}",
353 op
354 ),
355 }
356}
357
/// Crate-visible entry point that returns the domain satisfying the single
/// comparison `<op> value`.
///
/// Delegates directly to `comparison_to_domain`; errors and panics are those
/// of that function.
pub(crate) fn domain_for_comparison_atom(
    op: &ComparisonComputation,
    value: &LiteralValue,
) -> Result<Domain, crate::Error> {
    comparison_to_domain(op, value)
}
368
369impl Domain {
370 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
377 match (self, other) {
378 (Domain::Empty, _) => true,
379 (_, Domain::Unconstrained) => true,
380 (Domain::Unconstrained, _) => false,
381
382 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
383 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
384 vals.iter().all(|v| value_within(v, min, max))
385 }
386
387 (
388 Domain::Range {
389 min: amin,
390 max: amax,
391 },
392 Domain::Range {
393 min: bmin,
394 max: bmax,
395 },
396 ) => range_within_range(amin, amax, bmin, bmax),
397
398 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
400 Domain::Enumeration(excluded) => {
401 excluded.iter().all(|p| !value_within(p, min, max))
402 }
403 _ => false,
404 },
405
406 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
408 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
409 _ => false,
410 },
411
412 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
414 match (a_inner.as_ref(), b_inner.as_ref()) {
415 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
416 excluded_b.iter().all(|v| excluded_a.contains(v))
417 }
418 _ => false,
419 }
420 }
421
422 _ => false,
423 }
424 }
425}
426
427fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
428 lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
429}
430
431fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
432 match (a, b) {
433 (_, Bound::Unbounded) => true,
434 (Bound::Unbounded, _) => false,
435 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
436 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
437 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
438 let c = lit_cmp(av.as_ref(), bv.as_ref());
439 c >= 0
440 }
441 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
442 lit_cmp(av.as_ref(), bv.as_ref()) > 0
444 }
445 }
446}
447
448fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
449 match (a, b) {
450 (Bound::Unbounded, Bound::Unbounded) => true,
451 (_, Bound::Unbounded) => true,
452 (Bound::Unbounded, _) => false,
453 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
454 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
455 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
456 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
458 }
459 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
460 lit_cmp(av.as_ref(), bv.as_ref()) < 0
462 }
463 }
464}
465
466fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
467 match (a, b) {
468 (None, None) => None,
469 (Some(d), None) | (None, Some(d)) => Some(d),
470 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
471 }
472}
473
474fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
475 use std::cmp::Ordering;
476
477 match (&a.value, &b.value) {
478 (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
479 Ordering::Less => -1,
480 Ordering::Equal => 0,
481 Ordering::Greater => 1,
482 },
483
484 (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
485 Ordering::Less => -1,
486 Ordering::Equal => 0,
487 Ordering::Greater => 1,
488 },
489
490 (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
491 Ordering::Less => -1,
492 Ordering::Equal => 0,
493 Ordering::Greater => 1,
494 },
495
496 (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
497 Ordering::Less => -1,
498 Ordering::Equal => 0,
499 Ordering::Greater => 1,
500 },
501
502 (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
503 Ordering::Less => -1,
504 Ordering::Equal => 0,
505 Ordering::Greater => 1,
506 },
507
508 (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub)) => {
509 let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
510 let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
511 match a_sec.cmp(&b_sec) {
512 Ordering::Less => -1,
513 Ordering::Equal => 0,
514 Ordering::Greater => 1,
515 }
516 }
517
518 (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
519 Ordering::Less => -1,
520 Ordering::Equal => 0,
521 Ordering::Greater => 1,
522 },
523
524 (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub)) => {
525 if a.lemma_type != b.lemma_type {
526 unreachable!(
527 "BUG: lit_cmp compared different scale types ({} vs {})",
528 a.lemma_type.name(),
529 b.lemma_type.name()
530 );
531 }
532
533 if lua.eq_ignore_ascii_case(lub) {
534 return match la.cmp(lb) {
535 Ordering::Less => -1,
536 Ordering::Equal => 0,
537 Ordering::Greater => 1,
538 };
539 }
540
541 let target = SemanticConversionTarget::ScaleUnit(lua.clone());
543 let converted = crate::computation::convert_unit(b, &target);
544 let converted_value = match converted {
545 OperationResult::Value(lit) => match lit.value {
546 ValueKind::Scale(v, _) => v,
547 _ => unreachable!("BUG: scale unit conversion returned non-scale value"),
548 },
549 OperationResult::Veto(msg) => {
550 unreachable!("BUG: scale unit conversion vetoed unexpectedly: {:?}", msg)
551 }
552 };
553
554 match la.cmp(&converted_value) {
555 Ordering::Less => -1,
556 Ordering::Equal => 0,
557 Ordering::Greater => 1,
558 }
559 }
560
561 _ => unreachable!(
562 "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
563 a.get_type(),
564 b.get_type()
565 ),
566 }
567}
568
569fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
570 let ge_min = match min {
571 Bound::Unbounded => true,
572 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
573 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
574 };
575 let le_max = match max {
576 Bound::Unbounded => true,
577 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
578 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
579 };
580 ge_min && le_max
581}
582
583fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
584 match (min, max) {
585 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
586 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
587 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
588 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
589 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
590 }
591}
592
593fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
594 match (min1, min2) {
595 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
596 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
597 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
598 Bound::Inclusive(v1)
599 } else {
600 Bound::Inclusive(v2)
601 }
602 }
603 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
604 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
605 Bound::Inclusive(v1)
606 } else {
607 Bound::Exclusive(v2)
608 }
609 }
610 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
611 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
612 Bound::Exclusive(v1)
613 } else {
614 Bound::Inclusive(v2)
615 }
616 }
617 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
618 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
619 Bound::Exclusive(v1)
620 } else {
621 Bound::Exclusive(v2)
622 }
623 }
624 }
625}
626
627fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
628 match (max1, max2) {
629 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
630 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
631 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
632 Bound::Inclusive(v1)
633 } else {
634 Bound::Inclusive(v2)
635 }
636 }
637 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
638 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
639 Bound::Inclusive(v1)
640 } else {
641 Bound::Exclusive(v2)
642 }
643 }
644 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
645 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
646 Bound::Exclusive(v1)
647 } else {
648 Bound::Inclusive(v2)
649 }
650 }
651 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
652 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
653 Bound::Exclusive(v1)
654 } else {
655 Bound::Exclusive(v2)
656 }
657 }
658 }
659}
660
661fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
662 let a = normalize_domain(a);
663 let b = normalize_domain(b);
664
665 let result = match (a, b) {
666 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
667 (Domain::Empty, _) | (_, Domain::Empty) => None,
668
669 (
670 Domain::Range {
671 min: min1,
672 max: max1,
673 },
674 Domain::Range {
675 min: min2,
676 max: max2,
677 },
678 ) => {
679 let min = compute_intersection_min(min1, min2);
680 let max = compute_intersection_max(max1, max2);
681
682 if bounds_contradict(&min, &max) {
683 None
684 } else {
685 Some(Domain::Range { min, max })
686 }
687 }
688 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
689 let filtered: Vec<LiteralValue> =
690 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
691 if filtered.is_empty() {
692 None
693 } else {
694 Some(Domain::Enumeration(Arc::new(filtered)))
695 }
696 }
697 (Domain::Enumeration(vs), Domain::Range { min, max })
698 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
699 let mut kept = Vec::new();
700 for v in vs.iter() {
701 if value_within(v, &min, &max) {
702 kept.push(v.clone());
703 }
704 }
705 if kept.is_empty() {
706 None
707 } else {
708 Some(Domain::Enumeration(Arc::new(kept)))
709 }
710 }
711 (Domain::Enumeration(vs), Domain::Complement(inner))
712 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
713 match *inner.clone() {
714 Domain::Enumeration(excluded) => {
715 let mut kept = Vec::new();
716 for v in vs.iter() {
717 if !excluded.contains(v) {
718 kept.push(v.clone());
719 }
720 }
721 if kept.is_empty() {
722 None
723 } else {
724 Some(Domain::Enumeration(Arc::new(kept)))
725 }
726 }
727 Domain::Range { min, max } => {
728 let mut kept = Vec::new();
730 for v in vs.iter() {
731 if !value_within(v, &min, &max) {
732 kept.push(v.clone());
733 }
734 }
735 if kept.is_empty() {
736 None
737 } else {
738 Some(Domain::Enumeration(Arc::new(kept)))
739 }
740 }
741 _ => {
742 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
744 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
745 }
746 }
747 }
748 (Domain::Union(v1), Domain::Union(v2)) => {
749 let mut acc: Vec<Domain> = Vec::new();
750 for a in v1.iter() {
751 for b in v2.iter() {
752 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
753 acc.push(ix);
754 }
755 }
756 }
757 if acc.is_empty() {
758 None
759 } else {
760 Some(Domain::Union(Arc::new(acc)))
761 }
762 }
763 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
764 let mut acc: Vec<Domain> = Vec::new();
765 for a in vs.iter() {
766 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
767 acc.push(ix);
768 }
769 }
770 if acc.is_empty() {
771 None
772 } else if acc.len() == 1 {
773 Some(acc.remove(0))
774 } else {
775 Some(Domain::Union(Arc::new(acc)))
776 }
777 }
778 (Domain::Range { min, max }, Domain::Complement(inner))
780 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
781 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
782 _ => {
783 let normalized_complement = normalize_domain(Domain::Complement(inner));
786 if matches!(&normalized_complement, Domain::Complement(_)) {
787 None
788 } else {
789 domain_intersection(Domain::Range { min, max }, normalized_complement)
790 }
791 }
792 },
793 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
794 match (a_inner.as_ref(), b_inner.as_ref()) {
795 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
796 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
798 excluded.extend(b_ex.iter().cloned());
799 Some(normalize_domain(Domain::Complement(Box::new(
800 Domain::Enumeration(Arc::new(excluded)),
801 ))))
802 }
803 _ => None,
804 }
805 }
806 };
807 result.map(normalize_domain)
808}
809
810fn range_minus_excluded_points(
811 min: Bound,
812 max: Bound,
813 excluded: &Arc<Vec<LiteralValue>>,
814) -> Option<Domain> {
815 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
817
818 for p in excluded.iter() {
819 let mut next: Vec<(Bound, Bound)> = Vec::new();
820
821 for (rmin, rmax) in parts {
822 if !value_within(p, &rmin, &rmax) {
823 next.push((rmin, rmax));
824 continue;
825 }
826
827 let left_max = Bound::Exclusive(Arc::new(p.clone()));
829 if !bounds_contradict(&rmin, &left_max) {
830 next.push((rmin.clone(), left_max));
831 }
832
833 let right_min = Bound::Exclusive(Arc::new(p.clone()));
835 if !bounds_contradict(&right_min, &rmax) {
836 next.push((right_min, rmax.clone()));
837 }
838 }
839
840 parts = next;
841 if parts.is_empty() {
842 return None;
843 }
844 }
845
846 if parts.is_empty() {
847 None
848 } else if parts.len() == 1 {
849 let (min, max) = parts.remove(0);
850 Some(Domain::Range { min, max })
851 } else {
852 Some(Domain::Union(Arc::new(
853 parts
854 .into_iter()
855 .map(|(min, max)| Domain::Range { min, max })
856 .collect(),
857 )))
858 }
859}
860
861fn invert_bound(bound: Bound) -> Bound {
862 match bound {
863 Bound::Unbounded => Bound::Unbounded,
864 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
865 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
866 }
867}
868
869fn normalize_domain(d: Domain) -> Domain {
870 match d {
871 Domain::Complement(inner) => {
872 let normalized_inner = normalize_domain(*inner);
873 match normalized_inner {
874 Domain::Complement(double_inner) => *double_inner,
875 Domain::Range { min, max } => match (&min, &max) {
876 (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
877 (Bound::Unbounded, max) => Domain::Range {
878 min: invert_bound(max.clone()),
879 max: Bound::Unbounded,
880 },
881 (min, Bound::Unbounded) => Domain::Range {
882 min: Bound::Unbounded,
883 max: invert_bound(min.clone()),
884 },
885 (min, max) => Domain::Union(Arc::new(vec![
886 Domain::Range {
887 min: Bound::Unbounded,
888 max: invert_bound(min.clone()),
889 },
890 Domain::Range {
891 min: invert_bound(max.clone()),
892 max: Bound::Unbounded,
893 },
894 ])),
895 },
896 Domain::Enumeration(vals) => {
897 if vals.len() == 1 {
898 if let Some(lit) = vals.first() {
899 if let ValueKind::Boolean(true) = &lit.value {
900 return Domain::Enumeration(Arc::new(vec![
901 LiteralValue::from_bool(false),
902 ]));
903 }
904 if let ValueKind::Boolean(false) = &lit.value {
905 return Domain::Enumeration(Arc::new(vec![
906 LiteralValue::from_bool(true),
907 ]));
908 }
909 }
910 }
911 Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
912 }
913 Domain::Unconstrained => Domain::Empty,
914 Domain::Empty => Domain::Unconstrained,
915 Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
916 }
917 }
918 Domain::Empty => Domain::Empty,
919 Domain::Union(parts) => {
920 let mut flat: Vec<Domain> = Vec::new();
921 for p in parts.iter().cloned() {
922 let normalized = normalize_domain(p);
923 match normalized {
924 Domain::Union(inner) => flat.extend(inner.iter().cloned()),
925 Domain::Unconstrained => return Domain::Unconstrained,
926 Domain::Enumeration(vals) if vals.is_empty() => {}
927 other => flat.push(other),
928 }
929 }
930
931 let mut all_enum_values: Vec<LiteralValue> = Vec::new();
932 let mut ranges: Vec<Domain> = Vec::new();
933 let mut others: Vec<Domain> = Vec::new();
934
935 for domain in flat {
936 match domain {
937 Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
938 Domain::Range { .. } => ranges.push(domain),
939 other => others.push(other),
940 }
941 }
942
943 all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
944 -1 => Ordering::Less,
945 0 => Ordering::Equal,
946 _ => Ordering::Greater,
947 });
948 all_enum_values.dedup();
949
950 all_enum_values.retain(|v| {
951 !ranges.iter().any(|r| {
952 if let Domain::Range { min, max } = r {
953 value_within(v, min, max)
954 } else {
955 false
956 }
957 })
958 });
959
960 let mut result: Vec<Domain> = Vec::new();
961 result.extend(ranges);
962 result = merge_ranges(result);
963
964 if !all_enum_values.is_empty() {
965 result.push(Domain::Enumeration(Arc::new(all_enum_values)));
966 }
967 result.extend(others);
968
969 result.sort_by(|a, b| match (a, b) {
970 (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
971 (Domain::Range { .. }, _) => Ordering::Less,
972 (_, Domain::Range { .. }) => Ordering::Greater,
973 (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
974 (Domain::Enumeration(_), _) => Ordering::Less,
975 (_, Domain::Enumeration(_)) => Ordering::Greater,
976 _ => Ordering::Equal,
977 });
978
979 if result.is_empty() {
980 Domain::Enumeration(Arc::new(vec![]))
981 } else if result.len() == 1 {
982 result.remove(0)
983 } else {
984 Domain::Union(Arc::new(result))
985 }
986 }
987 Domain::Enumeration(values) => {
988 let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
989 sorted.sort_by(|a, b| match lit_cmp(a, b) {
990 -1 => Ordering::Less,
991 0 => Ordering::Equal,
992 _ => Ordering::Greater,
993 });
994 sorted.dedup();
995 Domain::Enumeration(Arc::new(sorted))
996 }
997 other => other,
998 }
999}
1000
1001fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1002 let mut result = Vec::new();
1003 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1004 let mut others = Vec::new();
1005
1006 for d in domains {
1007 match d {
1008 Domain::Range { min, max } => ranges.push((min, max)),
1009 other => others.push(other),
1010 }
1011 }
1012
1013 if ranges.is_empty() {
1014 return others;
1015 }
1016
1017 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1018
1019 let mut merged: Vec<(Bound, Bound)> = Vec::new();
1020 let mut current = ranges[0].clone();
1021
1022 for next in ranges.iter().skip(1) {
1023 if ranges_adjacent_or_overlap(¤t, next) {
1024 current = (
1025 min_bound(¤t.0, &next.0),
1026 max_bound(¤t.1, &next.1),
1027 );
1028 } else {
1029 merged.push(current);
1030 current = next.clone();
1031 }
1032 }
1033 merged.push(current);
1034
1035 for (min, max) in merged {
1036 result.push(Domain::Range { min, max });
1037 }
1038 result.extend(others);
1039
1040 result
1041}
1042
1043fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1044 match (a, b) {
1045 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1046 (Bound::Unbounded, _) => Ordering::Less,
1047 (_, Bound::Unbounded) => Ordering::Greater,
1048 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1049 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1050 -1 => Ordering::Less,
1051 0 => Ordering::Equal,
1052 _ => Ordering::Greater,
1053 },
1054 (Bound::Inclusive(v1), Bound::Exclusive(v2))
1055 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1056 -1 => Ordering::Less,
1057 0 => {
1058 if matches!(a, Bound::Inclusive(_)) {
1059 Ordering::Less
1060 } else {
1061 Ordering::Greater
1062 }
1063 }
1064 _ => Ordering::Greater,
1065 },
1066 }
1067}
1068
1069fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1070 match (&r1.1, &r2.0) {
1071 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1072 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1073 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1074 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1075 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1076 }
1077}
1078
1079fn min_bound(a: &Bound, b: &Bound) -> Bound {
1080 match (a, b) {
1081 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1082 _ => {
1083 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1084 a.clone()
1085 } else {
1086 b.clone()
1087 }
1088 }
1089 }
1090}
1091
1092fn max_bound(a: &Bound, b: &Bound) -> Bound {
1093 match (a, b) {
1094 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1095 _ => {
1096 if matches!(compare_bounds(a, b), Ordering::Greater) {
1097 a.clone()
1098 } else {
1099 b.clone()
1100 }
1101 }
1102 }
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Builds a numeric literal.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Builds a top-level fact path with the given name.
    fn fact(name: &str) -> FactPath {
        FactPath::new(vec![], name.to_string())
    }

    /// Inclusive numeric bound.
    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    /// Exclusive numeric bound.
    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Enumeration domain over the given values.
    fn enumeration(values: Vec<LiteralValue>) -> Domain {
        Domain::Enumeration(Arc::new(values))
    }

    /// Numeric comparison constraint on the named fact.
    fn compare(name: &str, op: ComparisonComputation, n: i64) -> Constraint {
        Constraint::Comparison {
            fact: fact(name),
            op,
            value: Arc::new(num(n)),
        }
    }

    /// Extracts all domains from `constraint` and returns the one for `name`.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        let mut domains = extract_domains_from_constraint(constraint).unwrap();
        domains.remove(&fact(name)).unwrap()
    }

    #[test]
    fn test_normalize_double_complement() {
        let inner = enumeration(vec![num(5)]);
        let twice_negated =
            Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        assert_eq!(normalize_domain(twice_negated), inner);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: inclusive(0),
                max: inclusive(10),
            },
            Domain::Unconstrained,
        ]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(format!("{}", range), "[10, 20)");

        let listed = enumeration(vec![num(1), num(2), num(3)]);
        assert_eq!(format!("{}", listed), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = compare("age", ComparisonComputation::GreaterThan, 18);

        assert_eq!(
            domain_of(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(compare("age", ComparisonComputation::GreaterThan, 18)),
            Box::new(compare("age", ComparisonComputation::LessThan, 65)),
        );

        assert_eq!(
            domain_of(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: exclusive(65),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let constraint = Constraint::Comparison {
            fact: fact("status"),
            op: ComparisonComputation::Equal,
            value: Arc::new(LiteralValue::text("active".to_string())),
        };

        assert_eq!(
            domain_of(&constraint, "status"),
            enumeration(vec![LiteralValue::text("active".to_string())])
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));

        assert_eq!(
            domain_of(&constraint, "is_active"),
            enumeration(vec![LiteralValue::from_bool(true)])
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));

        assert_eq!(
            domain_of(&constraint, "is_active"),
            enumeration(vec![LiteralValue::from_bool(false)])
        );
    }
}