1use crate::planning::semantics::{
9 ComparisonComputation, FactPath, LiteralValue, SemanticConversionTarget, ValueKind,
10};
11use crate::{LemmaResult, OperationResult};
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
/// The set of values a fact may take, as derived from a constraint.
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// Values lying between `min` and `max`; each end may be inclusive,
    /// exclusive or unbounded (see [`Bound`]).
    Range { min: Bound, max: Bound },

    /// Values that belong to at least one of the listed sub-domains.
    Union(Arc<Vec<Domain>>),

    /// An explicit list of admitted values; an empty list admits nothing.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is *not* admitted by the wrapped domain.
    Complement(Box<Domain>),

    /// No restriction: every value is admitted.
    Unconstrained,

    /// No value is admitted.
    Empty,
}
41
42impl Domain {
43 pub fn is_satisfiable(&self) -> bool {
47 match self {
48 Domain::Empty => false,
49 Domain::Enumeration(values) => !values.is_empty(),
50 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51 Domain::Range { min, max } => !bounds_contradict(min, max),
52 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53 Domain::Unconstrained => true,
54 }
55 }
56
57 pub fn is_empty(&self) -> bool {
59 !self.is_satisfiable()
60 }
61
62 pub fn intersect(&self, other: &Domain) -> Domain {
64 domain_intersection(self.clone(), other.clone()).unwrap_or(Domain::Empty)
65 }
66
67 pub fn contains(&self, value: &LiteralValue) -> bool {
69 match self {
70 Domain::Empty => false,
71 Domain::Unconstrained => true,
72 Domain::Enumeration(values) => values.contains(value),
73 Domain::Range { min, max } => value_within(value, min, max),
74 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
75 Domain::Complement(inner) => !inner.contains(value),
76 }
77 }
78}
79
/// One end of a [`Domain::Range`].
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// The limit value itself belongs to the range.
    Inclusive(Arc<LiteralValue>),

    /// The limit value itself is excluded from the range.
    Exclusive(Arc<LiteralValue>),

    /// The range has no limit on this side.
    Unbounded,
}
92
93impl fmt::Display for Domain {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 Domain::Empty => write!(f, "empty"),
97 Domain::Unconstrained => write!(f, "any"),
98 Domain::Enumeration(vals) => {
99 write!(f, "{{")?;
100 for (i, v) in vals.iter().enumerate() {
101 if i > 0 {
102 write!(f, ", ")?;
103 }
104 write!(f, "{}", v)?;
105 }
106 write!(f, "}}")
107 }
108 Domain::Range { min, max } => {
109 let (l_bracket, r_bracket) = match (min, max) {
110 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
111 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
112 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
113 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
114 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
115 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
116 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
117 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
118 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
119 };
120
121 let min_str = match min {
122 Bound::Unbounded => "-inf".to_string(),
123 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
124 };
125 let max_str = match max {
126 Bound::Unbounded => "+inf".to_string(),
127 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128 };
129 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
130 }
131 Domain::Union(parts) => {
132 for (i, p) in parts.iter().enumerate() {
133 if i > 0 {
134 write!(f, " | ")?;
135 }
136 write!(f, "{}", p)?;
137 }
138 Ok(())
139 }
140 Domain::Complement(inner) => write!(f, "not ({})", inner),
141 }
142 }
143}
144
145impl fmt::Display for Bound {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 match self {
148 Bound::Unbounded => write!(f, "inf"),
149 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
150 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
151 }
152 }
153}
154
155impl Serialize for Domain {
156 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157 where
158 S: Serializer,
159 {
160 match self {
161 Domain::Empty => {
162 let mut st = serializer.serialize_struct("domain", 1)?;
163 st.serialize_field("type", "empty")?;
164 st.end()
165 }
166 Domain::Unconstrained => {
167 let mut st = serializer.serialize_struct("domain", 1)?;
168 st.serialize_field("type", "unconstrained")?;
169 st.end()
170 }
171 Domain::Enumeration(vals) => {
172 let mut st = serializer.serialize_struct("domain", 2)?;
173 st.serialize_field("type", "enumeration")?;
174 st.serialize_field("values", vals)?;
175 st.end()
176 }
177 Domain::Range { min, max } => {
178 let mut st = serializer.serialize_struct("domain", 3)?;
179 st.serialize_field("type", "range")?;
180 st.serialize_field("min", min)?;
181 st.serialize_field("max", max)?;
182 st.end()
183 }
184 Domain::Union(parts) => {
185 let mut st = serializer.serialize_struct("domain", 2)?;
186 st.serialize_field("type", "union")?;
187 st.serialize_field("parts", parts)?;
188 st.end()
189 }
190 Domain::Complement(inner) => {
191 let mut st = serializer.serialize_struct("domain", 2)?;
192 st.serialize_field("type", "complement")?;
193 st.serialize_field("inner", inner)?;
194 st.end()
195 }
196 }
197 }
198}
199
200impl Serialize for Bound {
201 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
202 where
203 S: Serializer,
204 {
205 match self {
206 Bound::Unbounded => {
207 let mut st = serializer.serialize_struct("bound", 1)?;
208 st.serialize_field("type", "unbounded")?;
209 st.end()
210 }
211 Bound::Inclusive(v) => {
212 let mut st = serializer.serialize_struct("bound", 2)?;
213 st.serialize_field("type", "inclusive")?;
214 st.serialize_field("value", v.as_ref())?;
215 st.end()
216 }
217 Bound::Exclusive(v) => {
218 let mut st = serializer.serialize_struct("bound", 2)?;
219 st.serialize_field("type", "exclusive")?;
220 st.serialize_field("value", v.as_ref())?;
221 st.end()
222 }
223 }
224 }
225}
226
227pub fn extract_domains_from_constraint(
229 constraint: &Constraint,
230) -> LemmaResult<HashMap<FactPath, Domain>> {
231 let all_facts = constraint.collect_facts();
232 let mut domains = HashMap::new();
233
234 for fact_path in all_facts {
235 let domain =
236 extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
237 domains.insert(fact_path, domain);
238 }
239
240 Ok(domains)
241}
242
243fn extract_domain_for_fact(
244 constraint: &Constraint,
245 fact_path: &FactPath,
246) -> LemmaResult<Option<Domain>> {
247 let domain = match constraint {
248 Constraint::True => return Ok(None),
249 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
250
251 Constraint::Comparison { fact, op, value } => {
252 if fact == fact_path {
253 Some(comparison_to_domain(op, value.as_ref())?)
254 } else {
255 None
256 }
257 }
258
259 Constraint::Fact(fp) => {
260 if fp == fact_path {
261 Some(Domain::Enumeration(Arc::new(vec![
262 LiteralValue::from_bool(true),
263 ])))
264 } else {
265 None
266 }
267 }
268
269 Constraint::And(left, right) => {
270 let left_domain = extract_domain_for_fact(left, fact_path)?;
271 let right_domain = extract_domain_for_fact(right, fact_path)?;
272 match (left_domain, right_domain) {
273 (None, None) => None,
274 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
275 (Some(a), Some(b)) => match domain_intersection(a, b) {
276 Some(domain) => Some(domain),
277 None => Some(Domain::Enumeration(Arc::new(vec![]))),
278 },
279 }
280 }
281
282 Constraint::Or(left, right) => {
283 let left_domain = extract_domain_for_fact(left, fact_path)?;
284 let right_domain = extract_domain_for_fact(right, fact_path)?;
285 union_optional_domains(left_domain, right_domain)
286 }
287
288 Constraint::Not(inner) => {
289 if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
291 if fact == fact_path && op.is_equal() {
292 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
293 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
294 )))));
295 }
296 }
297
298 if let Constraint::Fact(fp) = inner.as_ref() {
300 if fp == fact_path {
301 return Ok(Some(Domain::Enumeration(Arc::new(vec![
302 LiteralValue::from_bool(false),
303 ]))));
304 }
305 }
306
307 extract_domain_for_fact(inner, fact_path)?
308 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
309 }
310 };
311
312 Ok(domain.map(normalize_domain))
313}
314
315fn comparison_to_domain(op: &ComparisonComputation, value: &LiteralValue) -> LemmaResult<Domain> {
316 if op.is_equal() {
317 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
318 }
319 if op.is_not_equal() {
320 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
321 vec![value.clone()],
322 )))));
323 }
324 match op {
325 ComparisonComputation::LessThan => Ok(Domain::Range {
326 min: Bound::Unbounded,
327 max: Bound::Exclusive(Arc::new(value.clone())),
328 }),
329 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
330 min: Bound::Unbounded,
331 max: Bound::Inclusive(Arc::new(value.clone())),
332 }),
333 ComparisonComputation::GreaterThan => Ok(Domain::Range {
334 min: Bound::Exclusive(Arc::new(value.clone())),
335 max: Bound::Unbounded,
336 }),
337 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
338 min: Bound::Inclusive(Arc::new(value.clone())),
339 max: Bound::Unbounded,
340 }),
341 _ => unreachable!(
342 "BUG: unsupported comparison operator for domain extraction: {:?}",
343 op
344 ),
345 }
346}
347
348pub(crate) fn domain_for_comparison_atom(
353 op: &ComparisonComputation,
354 value: &LiteralValue,
355) -> LemmaResult<Domain> {
356 comparison_to_domain(op, value)
357}
358
359impl Domain {
360 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
367 match (self, other) {
368 (Domain::Empty, _) => true,
369 (_, Domain::Unconstrained) => true,
370 (Domain::Unconstrained, _) => false,
371
372 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
373 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
374 vals.iter().all(|v| value_within(v, min, max))
375 }
376
377 (
378 Domain::Range {
379 min: amin,
380 max: amax,
381 },
382 Domain::Range {
383 min: bmin,
384 max: bmax,
385 },
386 ) => range_within_range(amin, amax, bmin, bmax),
387
388 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
390 Domain::Enumeration(excluded) => {
391 excluded.iter().all(|p| !value_within(p, min, max))
392 }
393 _ => false,
394 },
395
396 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
398 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
399 _ => false,
400 },
401
402 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
404 match (a_inner.as_ref(), b_inner.as_ref()) {
405 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
406 excluded_b.iter().all(|v| excluded_a.contains(v))
407 }
408 _ => false,
409 }
410 }
411
412 _ => false,
413 }
414 }
415}
416
417fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
418 lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
419}
420
421fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
422 match (a, b) {
423 (_, Bound::Unbounded) => true,
424 (Bound::Unbounded, _) => false,
425 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
426 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
427 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
428 let c = lit_cmp(av.as_ref(), bv.as_ref());
429 c >= 0
430 }
431 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
432 lit_cmp(av.as_ref(), bv.as_ref()) > 0
434 }
435 }
436}
437
438fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
439 match (a, b) {
440 (Bound::Unbounded, Bound::Unbounded) => true,
441 (_, Bound::Unbounded) => true,
442 (Bound::Unbounded, _) => false,
443 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
444 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
445 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
446 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
448 }
449 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
450 lit_cmp(av.as_ref(), bv.as_ref()) < 0
452 }
453 }
454}
455
456fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
457 match (a, b) {
458 (None, None) => None,
459 (Some(d), None) | (None, Some(d)) => Some(d),
460 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
461 }
462}
463
464fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
465 use std::cmp::Ordering;
466
467 match (&a.value, &b.value) {
468 (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
469 Ordering::Less => -1,
470 Ordering::Equal => 0,
471 Ordering::Greater => 1,
472 },
473
474 (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
475 Ordering::Less => -1,
476 Ordering::Equal => 0,
477 Ordering::Greater => 1,
478 },
479
480 (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
481 Ordering::Less => -1,
482 Ordering::Equal => 0,
483 Ordering::Greater => 1,
484 },
485
486 (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
487 Ordering::Less => -1,
488 Ordering::Equal => 0,
489 Ordering::Greater => 1,
490 },
491
492 (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
493 Ordering::Less => -1,
494 Ordering::Equal => 0,
495 Ordering::Greater => 1,
496 },
497
498 (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub)) => {
499 let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
500 let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
501 match a_sec.cmp(&b_sec) {
502 Ordering::Less => -1,
503 Ordering::Equal => 0,
504 Ordering::Greater => 1,
505 }
506 }
507
508 (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
509 Ordering::Less => -1,
510 Ordering::Equal => 0,
511 Ordering::Greater => 1,
512 },
513
514 (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub)) => {
515 if a.lemma_type != b.lemma_type {
516 unreachable!(
517 "BUG: lit_cmp compared different scale types ({} vs {})",
518 a.lemma_type.name(),
519 b.lemma_type.name()
520 );
521 }
522
523 if lua.eq_ignore_ascii_case(lub) {
524 return match la.cmp(lb) {
525 Ordering::Less => -1,
526 Ordering::Equal => 0,
527 Ordering::Greater => 1,
528 };
529 }
530
531 let target = SemanticConversionTarget::ScaleUnit(lua.clone());
533 let converted = crate::computation::convert_unit(b, &target);
534 let converted_value = match converted {
535 OperationResult::Value(lit) => match lit.value {
536 ValueKind::Scale(v, _) => v,
537 _ => unreachable!("BUG: scale unit conversion returned non-scale value"),
538 },
539 OperationResult::Veto(msg) => {
540 unreachable!("BUG: scale unit conversion vetoed unexpectedly: {:?}", msg)
541 }
542 };
543
544 match la.cmp(&converted_value) {
545 Ordering::Less => -1,
546 Ordering::Equal => 0,
547 Ordering::Greater => 1,
548 }
549 }
550
551 _ => unreachable!(
552 "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
553 a.get_type(),
554 b.get_type()
555 ),
556 }
557}
558
559fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
560 let ge_min = match min {
561 Bound::Unbounded => true,
562 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
563 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
564 };
565 let le_max = match max {
566 Bound::Unbounded => true,
567 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
568 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
569 };
570 ge_min && le_max
571}
572
573fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
574 match (min, max) {
575 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
576 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
577 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
578 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
579 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
580 }
581}
582
583fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
584 match (min1, min2) {
585 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
586 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
587 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
588 Bound::Inclusive(v1)
589 } else {
590 Bound::Inclusive(v2)
591 }
592 }
593 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
594 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
595 Bound::Inclusive(v1)
596 } else {
597 Bound::Exclusive(v2)
598 }
599 }
600 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
601 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
602 Bound::Exclusive(v1)
603 } else {
604 Bound::Inclusive(v2)
605 }
606 }
607 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
608 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
609 Bound::Exclusive(v1)
610 } else {
611 Bound::Exclusive(v2)
612 }
613 }
614 }
615}
616
617fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
618 match (max1, max2) {
619 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
620 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
621 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
622 Bound::Inclusive(v1)
623 } else {
624 Bound::Inclusive(v2)
625 }
626 }
627 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
628 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
629 Bound::Inclusive(v1)
630 } else {
631 Bound::Exclusive(v2)
632 }
633 }
634 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
635 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
636 Bound::Exclusive(v1)
637 } else {
638 Bound::Inclusive(v2)
639 }
640 }
641 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
642 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
643 Bound::Exclusive(v1)
644 } else {
645 Bound::Exclusive(v2)
646 }
647 }
648 }
649}
650
651fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
652 let a = normalize_domain(a);
653 let b = normalize_domain(b);
654
655 let result = match (a, b) {
656 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
657 (Domain::Empty, _) | (_, Domain::Empty) => None,
658
659 (
660 Domain::Range {
661 min: min1,
662 max: max1,
663 },
664 Domain::Range {
665 min: min2,
666 max: max2,
667 },
668 ) => {
669 let min = compute_intersection_min(min1, min2);
670 let max = compute_intersection_max(max1, max2);
671
672 if bounds_contradict(&min, &max) {
673 None
674 } else {
675 Some(Domain::Range { min, max })
676 }
677 }
678 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
679 let filtered: Vec<LiteralValue> =
680 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
681 if filtered.is_empty() {
682 None
683 } else {
684 Some(Domain::Enumeration(Arc::new(filtered)))
685 }
686 }
687 (Domain::Enumeration(vs), Domain::Range { min, max })
688 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
689 let mut kept = Vec::new();
690 for v in vs.iter() {
691 if value_within(v, &min, &max) {
692 kept.push(v.clone());
693 }
694 }
695 if kept.is_empty() {
696 None
697 } else {
698 Some(Domain::Enumeration(Arc::new(kept)))
699 }
700 }
701 (Domain::Enumeration(vs), Domain::Complement(inner))
702 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
703 match *inner.clone() {
704 Domain::Enumeration(excluded) => {
705 let mut kept = Vec::new();
706 for v in vs.iter() {
707 if !excluded.contains(v) {
708 kept.push(v.clone());
709 }
710 }
711 if kept.is_empty() {
712 None
713 } else {
714 Some(Domain::Enumeration(Arc::new(kept)))
715 }
716 }
717 Domain::Range { min, max } => {
718 let mut kept = Vec::new();
720 for v in vs.iter() {
721 if !value_within(v, &min, &max) {
722 kept.push(v.clone());
723 }
724 }
725 if kept.is_empty() {
726 None
727 } else {
728 Some(Domain::Enumeration(Arc::new(kept)))
729 }
730 }
731 _ => {
732 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
734 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
735 }
736 }
737 }
738 (Domain::Union(v1), Domain::Union(v2)) => {
739 let mut acc: Vec<Domain> = Vec::new();
740 for a in v1.iter() {
741 for b in v2.iter() {
742 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
743 acc.push(ix);
744 }
745 }
746 }
747 if acc.is_empty() {
748 None
749 } else {
750 Some(Domain::Union(Arc::new(acc)))
751 }
752 }
753 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
754 let mut acc: Vec<Domain> = Vec::new();
755 for a in vs.iter() {
756 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
757 acc.push(ix);
758 }
759 }
760 if acc.is_empty() {
761 None
762 } else if acc.len() == 1 {
763 Some(acc.remove(0))
764 } else {
765 Some(Domain::Union(Arc::new(acc)))
766 }
767 }
768 (Domain::Range { min, max }, Domain::Complement(inner))
770 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
771 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
772 _ => {
773 let normalized_complement = normalize_domain(Domain::Complement(inner));
776 if matches!(&normalized_complement, Domain::Complement(_)) {
777 None
778 } else {
779 domain_intersection(Domain::Range { min, max }, normalized_complement)
780 }
781 }
782 },
783 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
784 match (a_inner.as_ref(), b_inner.as_ref()) {
785 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
786 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
788 excluded.extend(b_ex.iter().cloned());
789 Some(normalize_domain(Domain::Complement(Box::new(
790 Domain::Enumeration(Arc::new(excluded)),
791 ))))
792 }
793 _ => None,
794 }
795 }
796 };
797 result.map(normalize_domain)
798}
799
800fn range_minus_excluded_points(
801 min: Bound,
802 max: Bound,
803 excluded: &Arc<Vec<LiteralValue>>,
804) -> Option<Domain> {
805 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
807
808 for p in excluded.iter() {
809 let mut next: Vec<(Bound, Bound)> = Vec::new();
810
811 for (rmin, rmax) in parts {
812 if !value_within(p, &rmin, &rmax) {
813 next.push((rmin, rmax));
814 continue;
815 }
816
817 let left_max = Bound::Exclusive(Arc::new(p.clone()));
819 if !bounds_contradict(&rmin, &left_max) {
820 next.push((rmin.clone(), left_max));
821 }
822
823 let right_min = Bound::Exclusive(Arc::new(p.clone()));
825 if !bounds_contradict(&right_min, &rmax) {
826 next.push((right_min, rmax.clone()));
827 }
828 }
829
830 parts = next;
831 if parts.is_empty() {
832 return None;
833 }
834 }
835
836 if parts.is_empty() {
837 None
838 } else if parts.len() == 1 {
839 let (min, max) = parts.remove(0);
840 Some(Domain::Range { min, max })
841 } else {
842 Some(Domain::Union(Arc::new(
843 parts
844 .into_iter()
845 .map(|(min, max)| Domain::Range { min, max })
846 .collect(),
847 )))
848 }
849}
850
851fn invert_bound(bound: Bound) -> Bound {
852 match bound {
853 Bound::Unbounded => Bound::Unbounded,
854 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
855 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
856 }
857}
858
/// Rewrites a domain into a simpler, more canonical shape.
///
/// * Complements are pushed inwards where possible: double complements
///   cancel, complemented ranges become ranges (or a union of two ranges),
///   a complemented single boolean becomes the other boolean, and
///   complemented `Unconstrained`/`Empty` swap. Complements of other
///   enumerations and of unions are kept as complements.
/// * Unions are flattened, absorb `Unconstrained`, drop empty enumerations,
///   have their enumerated values sorted/deduplicated and their ranges
///   merged, and collapse to a single part when only one remains.
/// * Enumerations are sorted (using `lit_cmp`) and deduplicated.
/// * Everything else is returned unchanged.
fn normalize_domain(d: Domain) -> Domain {
    match d {
        Domain::Complement(inner) => {
            let normalized_inner = normalize_domain(*inner);
            match normalized_inner {
                // not (not X) == X
                Domain::Complement(double_inner) => *double_inner,
                Domain::Range { min, max } => match (&min, &max) {
                    // Complement of the fully open range: nothing is left
                    // (expressed as an empty enumeration).
                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
                    // not (-inf .. max) == everything from max upward, with
                    // the inclusivity of `max` flipped.
                    (Bound::Unbounded, max) => Domain::Range {
                        min: invert_bound(max.clone()),
                        max: Bound::Unbounded,
                    },
                    // not (min .. +inf) == everything up to min, with the
                    // inclusivity of `min` flipped.
                    (min, Bound::Unbounded) => Domain::Range {
                        min: Bound::Unbounded,
                        max: invert_bound(min.clone()),
                    },
                    // Bounded on both sides: the part below plus the part above.
                    (min, max) => Domain::Union(Arc::new(vec![
                        Domain::Range {
                            min: Bound::Unbounded,
                            max: invert_bound(min.clone()),
                        },
                        Domain::Range {
                            min: invert_bound(max.clone()),
                            max: Bound::Unbounded,
                        },
                    ])),
                },
                Domain::Enumeration(vals) => {
                    // A single boolean has exactly one alternative, so its
                    // complement is written as the other boolean value.
                    if vals.len() == 1 {
                        if let Some(lit) = vals.first() {
                            if let ValueKind::Boolean(true) = &lit.value {
                                return Domain::Enumeration(Arc::new(vec![
                                    LiteralValue::from_bool(false),
                                ]));
                            }
                            if let ValueKind::Boolean(false) = &lit.value {
                                return Domain::Enumeration(Arc::new(vec![
                                    LiteralValue::from_bool(true),
                                ]));
                            }
                        }
                    }
                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
                }
                Domain::Unconstrained => Domain::Empty,
                Domain::Empty => Domain::Unconstrained,
                // A complemented union is kept as-is (with its inner union normalized).
                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
            }
        }
        Domain::Empty => Domain::Empty,
        Domain::Union(parts) => {
            // Step 1: normalize every part and flatten nested unions.
            let mut flat: Vec<Domain> = Vec::new();
            for p in parts.iter().cloned() {
                let normalized = normalize_domain(p);
                match normalized {
                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
                    // Anything united with "everything" is "everything".
                    Domain::Unconstrained => return Domain::Unconstrained,
                    // Empty enumerations contribute nothing.
                    Domain::Enumeration(vals) if vals.is_empty() => {}
                    other => flat.push(other),
                }
            }

            // Step 2: bucket the parts by shape.
            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
            let mut ranges: Vec<Domain> = Vec::new();
            let mut others: Vec<Domain> = Vec::new();

            for domain in flat {
                match domain {
                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
                    Domain::Range { .. } => ranges.push(domain),
                    other => others.push(other),
                }
            }

            // Step 3: one sorted, deduplicated list of enumerated values.
            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
                -1 => Ordering::Less,
                0 => Ordering::Equal,
                _ => Ordering::Greater,
            });
            all_enum_values.dedup();

            // Step 4: drop enumerated values already covered by one of the ranges.
            all_enum_values.retain(|v| {
                !ranges.iter().any(|r| {
                    if let Domain::Range { min, max } = r {
                        value_within(v, min, max)
                    } else {
                        false
                    }
                })
            });

            // Step 5: merge the ranges, then append the enumeration and the rest.
            let mut result: Vec<Domain> = Vec::new();
            result.extend(ranges);
            result = merge_ranges(result);

            if !all_enum_values.is_empty() {
                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
            }
            result.extend(others);

            // Step 6: order the parts as ranges, then enumerations, then
            // everything else; parts of the same shape compare as equal.
            result.sort_by(|a, b| match (a, b) {
                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
                (Domain::Range { .. }, _) => Ordering::Less,
                (_, Domain::Range { .. }) => Ordering::Greater,
                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
                (Domain::Enumeration(_), _) => Ordering::Less,
                (_, Domain::Enumeration(_)) => Ordering::Greater,
                _ => Ordering::Equal,
            });

            // Step 7: collapse trivial unions.
            if result.is_empty() {
                Domain::Enumeration(Arc::new(vec![]))
            } else if result.len() == 1 {
                result.remove(0)
            } else {
                Domain::Union(Arc::new(result))
            }
        }
        Domain::Enumeration(values) => {
            // Sort and deduplicate the enumerated values.
            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
            sorted.sort_by(|a, b| match lit_cmp(a, b) {
                -1 => Ordering::Less,
                0 => Ordering::Equal,
                _ => Ordering::Greater,
            });
            sorted.dedup();
            Domain::Enumeration(Arc::new(sorted))
        }
        other => other,
    }
}
990
991fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
992 let mut result = Vec::new();
993 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
994 let mut others = Vec::new();
995
996 for d in domains {
997 match d {
998 Domain::Range { min, max } => ranges.push((min, max)),
999 other => others.push(other),
1000 }
1001 }
1002
1003 if ranges.is_empty() {
1004 return others;
1005 }
1006
1007 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1008
1009 let mut merged: Vec<(Bound, Bound)> = Vec::new();
1010 let mut current = ranges[0].clone();
1011
1012 for next in ranges.iter().skip(1) {
1013 if ranges_adjacent_or_overlap(¤t, next) {
1014 current = (
1015 min_bound(¤t.0, &next.0),
1016 max_bound(¤t.1, &next.1),
1017 );
1018 } else {
1019 merged.push(current);
1020 current = next.clone();
1021 }
1022 }
1023 merged.push(current);
1024
1025 for (min, max) in merged {
1026 result.push(Domain::Range { min, max });
1027 }
1028 result.extend(others);
1029
1030 result
1031}
1032
1033fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1034 match (a, b) {
1035 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1036 (Bound::Unbounded, _) => Ordering::Less,
1037 (_, Bound::Unbounded) => Ordering::Greater,
1038 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1039 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1040 -1 => Ordering::Less,
1041 0 => Ordering::Equal,
1042 _ => Ordering::Greater,
1043 },
1044 (Bound::Inclusive(v1), Bound::Exclusive(v2))
1045 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1046 -1 => Ordering::Less,
1047 0 => {
1048 if matches!(a, Bound::Inclusive(_)) {
1049 Ordering::Less
1050 } else {
1051 Ordering::Greater
1052 }
1053 }
1054 _ => Ordering::Greater,
1055 },
1056 }
1057}
1058
1059fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1060 match (&r1.1, &r2.0) {
1061 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1062 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1063 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1064 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1065 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1066 }
1067}
1068
1069fn min_bound(a: &Bound, b: &Bound) -> Bound {
1070 match (a, b) {
1071 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1072 _ => {
1073 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1074 a.clone()
1075 } else {
1076 b.clone()
1077 }
1078 }
1079 }
1080}
1081
1082fn max_bound(a: &Bound, b: &Bound) -> Bound {
1083 match (a, b) {
1084 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1085 _ => {
1086 if matches!(compare_bounds(a, b), Ordering::Greater) {
1087 a.clone()
1088 } else {
1089 b.clone()
1090 }
1091 }
1092 }
1093}
1094
// Unit tests for normalization, display formatting and domain extraction.
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Builds a numeric literal from an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Builds a fact path with no leading segments and the given name.
    fn fact(name: &str) -> FactPath {
        FactPath::new(vec![], name.to_string())
    }

    /// `not (not X)` normalizes back to `X`.
    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let double = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        let normalized = normalize_domain(double);
        assert_eq!(normalized, inner);
    }

    /// A union containing `Unconstrained` normalizes to `Unconstrained`.
    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(0))),
                max: Bound::Inclusive(Arc::new(num(10))),
            },
            Domain::Unconstrained,
        ]));
        let normalized = normalize_domain(union);
        assert_eq!(normalized, Domain::Unconstrained);
    }

    /// Ranges render in interval notation and enumerations in braces.
    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: Bound::Inclusive(Arc::new(num(10))),
            max: Bound::Exclusive(Arc::new(num(20))),
        };
        assert_eq!(format!("{}", range), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(format!("{}", enumeration), "{1, 2, 3}");
    }

    /// `age > 18` yields the range `(18, +inf)`.
    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = Constraint::Comparison {
            fact: fact("age"),
            op: ComparisonComputation::GreaterThan,
            value: Arc::new(num(18)),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&fact("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Unbounded,
            }
        );
    }

    /// `age > 18 and age < 65` yields the range `(18, 65)`.
    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(Constraint::Comparison {
                fact: fact("age"),
                op: ComparisonComputation::GreaterThan,
                value: Arc::new(num(18)),
            }),
            Box::new(Constraint::Comparison {
                fact: fact("age"),
                op: ComparisonComputation::LessThan,
                value: Arc::new(num(65)),
            }),
        );

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&fact("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Exclusive(Arc::new(num(65))),
            }
        );
    }

    /// An equality comparison yields a one-element enumeration.
    #[test]
    fn test_extract_domain_from_equality() {
        let constraint = Constraint::Comparison {
            fact: fact("status"),
            op: ComparisonComputation::Equal,
            value: Arc::new(LiteralValue::text("active".to_string())),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let status_domain = domains.get(&fact("status")).unwrap();

        assert_eq!(
            *status_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::text("active".to_string())]))
        );
    }

    /// A bare boolean fact constrains the fact to `true`.
    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&fact("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    /// A negated boolean fact constrains the fact to `false`.
    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&fact("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }
}