1use crate::{
9 BooleanValue, ComparisonComputation, FactPath, LemmaResult, LiteralValue, OperationResult,
10 Value,
11};
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
20#[derive(Debug, Clone, PartialEq)]
22pub enum Domain {
23 Range { min: Bound, max: Bound },
25
26 Union(Arc<Vec<Domain>>),
28
29 Enumeration(Arc<Vec<LiteralValue>>),
31
32 Complement(Box<Domain>),
34
35 Unconstrained,
37
38 Empty,
40}
41
42impl Domain {
43 pub fn is_satisfiable(&self) -> bool {
47 match self {
48 Domain::Empty => false,
49 Domain::Enumeration(values) => !values.is_empty(),
50 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51 Domain::Range { min, max } => !bounds_contradict(min, max),
52 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53 Domain::Unconstrained => true,
54 }
55 }
56
57 pub fn is_empty(&self) -> bool {
59 !self.is_satisfiable()
60 }
61
62 pub fn intersect(&self, other: &Domain) -> Domain {
64 domain_intersection(self.clone(), other.clone()).unwrap_or(Domain::Empty)
65 }
66
67 pub fn contains(&self, value: &LiteralValue) -> bool {
69 match self {
70 Domain::Empty => false,
71 Domain::Unconstrained => true,
72 Domain::Enumeration(values) => values.contains(value),
73 Domain::Range { min, max } => value_within(value, min, max),
74 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
75 Domain::Complement(inner) => !inner.contains(value),
76 }
77 }
78}
79
80#[derive(Debug, Clone, PartialEq)]
82pub enum Bound {
83 Inclusive(Arc<LiteralValue>),
85
86 Exclusive(Arc<LiteralValue>),
88
89 Unbounded,
91}
92
93impl fmt::Display for Domain {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 Domain::Empty => write!(f, "empty"),
97 Domain::Unconstrained => write!(f, "any"),
98 Domain::Enumeration(vals) => {
99 write!(f, "{{")?;
100 for (i, v) in vals.iter().enumerate() {
101 if i > 0 {
102 write!(f, ", ")?;
103 }
104 write!(f, "{}", v)?;
105 }
106 write!(f, "}}")
107 }
108 Domain::Range { min, max } => {
109 let (l_bracket, r_bracket) = match (min, max) {
110 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
111 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
112 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
113 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
114 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
115 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
116 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
117 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
118 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
119 };
120
121 let min_str = match min {
122 Bound::Unbounded => "-inf".to_string(),
123 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
124 };
125 let max_str = match max {
126 Bound::Unbounded => "+inf".to_string(),
127 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128 };
129 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
130 }
131 Domain::Union(parts) => {
132 for (i, p) in parts.iter().enumerate() {
133 if i > 0 {
134 write!(f, " | ")?;
135 }
136 write!(f, "{}", p)?;
137 }
138 Ok(())
139 }
140 Domain::Complement(inner) => write!(f, "not ({})", inner),
141 }
142 }
143}
144
145impl fmt::Display for Bound {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 match self {
148 Bound::Unbounded => write!(f, "inf"),
149 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
150 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
151 }
152 }
153}
154
155impl Serialize for Domain {
156 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157 where
158 S: Serializer,
159 {
160 match self {
161 Domain::Empty => {
162 let mut st = serializer.serialize_struct("domain", 1)?;
163 st.serialize_field("type", "empty")?;
164 st.end()
165 }
166 Domain::Unconstrained => {
167 let mut st = serializer.serialize_struct("domain", 1)?;
168 st.serialize_field("type", "unconstrained")?;
169 st.end()
170 }
171 Domain::Enumeration(vals) => {
172 let mut st = serializer.serialize_struct("domain", 2)?;
173 st.serialize_field("type", "enumeration")?;
174 st.serialize_field("values", vals)?;
175 st.end()
176 }
177 Domain::Range { min, max } => {
178 let mut st = serializer.serialize_struct("domain", 3)?;
179 st.serialize_field("type", "range")?;
180 st.serialize_field("min", min)?;
181 st.serialize_field("max", max)?;
182 st.end()
183 }
184 Domain::Union(parts) => {
185 let mut st = serializer.serialize_struct("domain", 2)?;
186 st.serialize_field("type", "union")?;
187 st.serialize_field("parts", parts)?;
188 st.end()
189 }
190 Domain::Complement(inner) => {
191 let mut st = serializer.serialize_struct("domain", 2)?;
192 st.serialize_field("type", "complement")?;
193 st.serialize_field("inner", inner)?;
194 st.end()
195 }
196 }
197 }
198}
199
200impl Serialize for Bound {
201 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
202 where
203 S: Serializer,
204 {
205 match self {
206 Bound::Unbounded => {
207 let mut st = serializer.serialize_struct("bound", 1)?;
208 st.serialize_field("type", "unbounded")?;
209 st.end()
210 }
211 Bound::Inclusive(v) => {
212 let mut st = serializer.serialize_struct("bound", 2)?;
213 st.serialize_field("type", "inclusive")?;
214 st.serialize_field("value", v.as_ref())?;
215 st.end()
216 }
217 Bound::Exclusive(v) => {
218 let mut st = serializer.serialize_struct("bound", 2)?;
219 st.serialize_field("type", "exclusive")?;
220 st.serialize_field("value", v.as_ref())?;
221 st.end()
222 }
223 }
224 }
225}
226
227pub fn extract_domains_from_constraint(
229 constraint: &Constraint,
230) -> LemmaResult<HashMap<FactPath, Domain>> {
231 let all_facts = constraint.collect_facts();
232 let mut domains = HashMap::new();
233
234 for fact_path in all_facts {
235 let domain =
236 extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
237 domains.insert(fact_path, domain);
238 }
239
240 Ok(domains)
241}
242
243fn extract_domain_for_fact(
244 constraint: &Constraint,
245 fact_path: &FactPath,
246) -> LemmaResult<Option<Domain>> {
247 let domain = match constraint {
248 Constraint::True => return Ok(None),
249 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
250
251 Constraint::Comparison { fact, op, value } => {
252 if fact == fact_path {
253 Some(comparison_to_domain(op, value.as_ref())?)
254 } else {
255 None
256 }
257 }
258
259 Constraint::Fact(fp) => {
260 if fp == fact_path {
261 Some(Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
262 BooleanValue::True,
263 )])))
264 } else {
265 None
266 }
267 }
268
269 Constraint::And(left, right) => {
270 let left_domain = extract_domain_for_fact(left, fact_path)?;
271 let right_domain = extract_domain_for_fact(right, fact_path)?;
272 match (left_domain, right_domain) {
273 (None, None) => None,
274 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
275 (Some(a), Some(b)) => match domain_intersection(a, b) {
276 Some(domain) => Some(domain),
277 None => Some(Domain::Enumeration(Arc::new(vec![]))),
278 },
279 }
280 }
281
282 Constraint::Or(left, right) => {
283 let left_domain = extract_domain_for_fact(left, fact_path)?;
284 let right_domain = extract_domain_for_fact(right, fact_path)?;
285 union_optional_domains(left_domain, right_domain)
286 }
287
288 Constraint::Not(inner) => {
289 if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
291 if fact == fact_path && op.is_equal() {
292 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
293 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
294 )))));
295 }
296 }
297
298 if let Constraint::Fact(fp) = inner.as_ref() {
300 if fp == fact_path {
301 return Ok(Some(Domain::Enumeration(Arc::new(vec![
302 LiteralValue::boolean(BooleanValue::False),
303 ]))));
304 }
305 }
306
307 extract_domain_for_fact(inner, fact_path)?
308 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
309 }
310 };
311
312 Ok(domain.map(normalize_domain))
313}
314
315fn comparison_to_domain(op: &ComparisonComputation, value: &LiteralValue) -> LemmaResult<Domain> {
316 if op.is_equal() {
317 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
318 }
319 if op.is_not_equal() {
320 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
321 vec![value.clone()],
322 )))));
323 }
324 match op {
325 ComparisonComputation::LessThan => Ok(Domain::Range {
326 min: Bound::Unbounded,
327 max: Bound::Exclusive(Arc::new(value.clone())),
328 }),
329 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
330 min: Bound::Unbounded,
331 max: Bound::Inclusive(Arc::new(value.clone())),
332 }),
333 ComparisonComputation::GreaterThan => Ok(Domain::Range {
334 min: Bound::Exclusive(Arc::new(value.clone())),
335 max: Bound::Unbounded,
336 }),
337 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
338 min: Bound::Inclusive(Arc::new(value.clone())),
339 max: Bound::Unbounded,
340 }),
341 _ => unreachable!(
342 "BUG: unsupported comparison operator for domain extraction: {:?}",
343 op
344 ),
345 }
346}
347
348pub(crate) fn domain_for_comparison_atom(
353 op: &ComparisonComputation,
354 value: &LiteralValue,
355) -> LemmaResult<Domain> {
356 comparison_to_domain(op, value)
357}
358
359impl Domain {
360 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
367 match (self, other) {
368 (Domain::Empty, _) => true,
369 (_, Domain::Unconstrained) => true,
370 (Domain::Unconstrained, _) => false,
371
372 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
373 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
374 vals.iter().all(|v| value_within(v, min, max))
375 }
376
377 (
378 Domain::Range {
379 min: amin,
380 max: amax,
381 },
382 Domain::Range {
383 min: bmin,
384 max: bmax,
385 },
386 ) => range_within_range(amin, amax, bmin, bmax),
387
388 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
390 Domain::Enumeration(excluded) => {
391 excluded.iter().all(|p| !value_within(p, min, max))
392 }
393 _ => false,
394 },
395
396 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
398 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
399 _ => false,
400 },
401
402 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
404 match (a_inner.as_ref(), b_inner.as_ref()) {
405 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
406 excluded_b.iter().all(|v| excluded_a.contains(v))
407 }
408 _ => false,
409 }
410 }
411
412 _ => false,
413 }
414 }
415}
416
417fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
418 lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
419}
420
421fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
422 match (a, b) {
423 (_, Bound::Unbounded) => true,
424 (Bound::Unbounded, _) => false,
425 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
426 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
427 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
428 let c = lit_cmp(av.as_ref(), bv.as_ref());
429 c >= 0
430 }
431 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
432 lit_cmp(av.as_ref(), bv.as_ref()) > 0
434 }
435 }
436}
437
438fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
439 match (a, b) {
440 (Bound::Unbounded, Bound::Unbounded) => true,
441 (_, Bound::Unbounded) => true,
442 (Bound::Unbounded, _) => false,
443 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
444 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
445 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
446 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
448 }
449 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
450 lit_cmp(av.as_ref(), bv.as_ref()) < 0
452 }
453 }
454}
455
456fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
457 match (a, b) {
458 (None, None) => None,
459 (Some(d), None) | (None, Some(d)) => Some(d),
460 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
461 }
462}
463
464fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
465 use std::cmp::Ordering;
466
467 match (&a.value, &b.value) {
468 (Value::Number(la), Value::Number(lb)) => match la.cmp(lb) {
469 Ordering::Less => -1,
470 Ordering::Equal => 0,
471 Ordering::Greater => 1,
472 },
473
474 (Value::Boolean(la), Value::Boolean(lb)) => match bool::from(la).cmp(&bool::from(lb)) {
475 Ordering::Less => -1,
476 Ordering::Equal => 0,
477 Ordering::Greater => 1,
478 },
479
480 (Value::Text(la), Value::Text(lb)) => match la.cmp(lb) {
481 Ordering::Less => -1,
482 Ordering::Equal => 0,
483 Ordering::Greater => 1,
484 },
485
486 (Value::Date(la), Value::Date(lb)) => match la.cmp(lb) {
487 Ordering::Less => -1,
488 Ordering::Equal => 0,
489 Ordering::Greater => 1,
490 },
491
492 (Value::Time(la), Value::Time(lb)) => match la.cmp(lb) {
493 Ordering::Less => -1,
494 Ordering::Equal => 0,
495 Ordering::Greater => 1,
496 },
497
498 (Value::Duration(la, lua), Value::Duration(lb, lub)) => {
499 let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
500 let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
501 match a_sec.cmp(&b_sec) {
502 Ordering::Less => -1,
503 Ordering::Equal => 0,
504 Ordering::Greater => 1,
505 }
506 }
507
508 (Value::Ratio(la, _), Value::Ratio(lb, _)) => match la.cmp(lb) {
509 Ordering::Less => -1,
510 Ordering::Equal => 0,
511 Ordering::Greater => 1,
512 },
513
514 (Value::Scale(la, lua), Value::Scale(lb, lub)) => {
515 if a.lemma_type != b.lemma_type {
516 unreachable!(
517 "BUG: lit_cmp compared different scale types ({} vs {})",
518 a.lemma_type.name(),
519 b.lemma_type.name()
520 );
521 }
522
523 match (lua, lub) {
524 (None, None) => match la.cmp(lb) {
525 Ordering::Less => -1,
526 Ordering::Equal => 0,
527 Ordering::Greater => 1,
528 },
529 (Some(lua), Some(lub)) => {
530 if lua.eq_ignore_ascii_case(lub) {
531 return match la.cmp(lb) {
532 Ordering::Less => -1,
533 Ordering::Equal => 0,
534 Ordering::Greater => 1,
535 };
536 }
537
538 let target = crate::semantic::ConversionTarget::ScaleUnit(lua.clone());
539 let converted = crate::computation::convert_unit(b, &target);
540 let converted_value = match converted {
541 OperationResult::Value(lit) => match lit.value {
542 Value::Scale(v, _) => v,
543 _ => {
544 unreachable!("BUG: scale unit conversion returned non-scale value")
545 }
546 },
547 OperationResult::Veto(msg) => unreachable!(
548 "BUG: scale unit conversion vetoed unexpectedly: {:?}",
549 msg
550 ),
551 };
552
553 match la.cmp(&converted_value) {
554 Ordering::Less => -1,
555 Ordering::Equal => 0,
556 Ordering::Greater => 1,
557 }
558 }
559 _ => unreachable!("BUG: lit_cmp saw scale value missing unit on one side"),
560 }
561 }
562
563 _ => unreachable!(
564 "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
565 a.get_type(),
566 b.get_type()
567 ),
568 }
569}
570
571fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
572 let ge_min = match min {
573 Bound::Unbounded => true,
574 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
575 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
576 };
577 let le_max = match max {
578 Bound::Unbounded => true,
579 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
580 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
581 };
582 ge_min && le_max
583}
584
585fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
586 match (min, max) {
587 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
588 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
589 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
590 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
591 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
592 }
593}
594
595fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
596 match (min1, min2) {
597 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
598 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
599 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
600 Bound::Inclusive(v1)
601 } else {
602 Bound::Inclusive(v2)
603 }
604 }
605 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
606 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
607 Bound::Inclusive(v1)
608 } else {
609 Bound::Exclusive(v2)
610 }
611 }
612 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
613 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
614 Bound::Exclusive(v1)
615 } else {
616 Bound::Inclusive(v2)
617 }
618 }
619 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
620 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
621 Bound::Exclusive(v1)
622 } else {
623 Bound::Exclusive(v2)
624 }
625 }
626 }
627}
628
629fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
630 match (max1, max2) {
631 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
632 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
633 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
634 Bound::Inclusive(v1)
635 } else {
636 Bound::Inclusive(v2)
637 }
638 }
639 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
640 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
641 Bound::Inclusive(v1)
642 } else {
643 Bound::Exclusive(v2)
644 }
645 }
646 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
647 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
648 Bound::Exclusive(v1)
649 } else {
650 Bound::Inclusive(v2)
651 }
652 }
653 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
654 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
655 Bound::Exclusive(v1)
656 } else {
657 Bound::Exclusive(v2)
658 }
659 }
660 }
661}
662
663fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
664 let a = normalize_domain(a);
665 let b = normalize_domain(b);
666
667 let result = match (a, b) {
668 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
669 (Domain::Empty, _) | (_, Domain::Empty) => None,
670
671 (
672 Domain::Range {
673 min: min1,
674 max: max1,
675 },
676 Domain::Range {
677 min: min2,
678 max: max2,
679 },
680 ) => {
681 let min = compute_intersection_min(min1, min2);
682 let max = compute_intersection_max(max1, max2);
683
684 if bounds_contradict(&min, &max) {
685 None
686 } else {
687 Some(Domain::Range { min, max })
688 }
689 }
690 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
691 let filtered: Vec<LiteralValue> =
692 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
693 if filtered.is_empty() {
694 None
695 } else {
696 Some(Domain::Enumeration(Arc::new(filtered)))
697 }
698 }
699 (Domain::Enumeration(vs), Domain::Range { min, max })
700 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
701 let mut kept = Vec::new();
702 for v in vs.iter() {
703 if value_within(v, &min, &max) {
704 kept.push(v.clone());
705 }
706 }
707 if kept.is_empty() {
708 None
709 } else {
710 Some(Domain::Enumeration(Arc::new(kept)))
711 }
712 }
713 (Domain::Enumeration(vs), Domain::Complement(inner))
714 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
715 match *inner.clone() {
716 Domain::Enumeration(excluded) => {
717 let mut kept = Vec::new();
718 for v in vs.iter() {
719 if !excluded.contains(v) {
720 kept.push(v.clone());
721 }
722 }
723 if kept.is_empty() {
724 None
725 } else {
726 Some(Domain::Enumeration(Arc::new(kept)))
727 }
728 }
729 Domain::Range { min, max } => {
730 let mut kept = Vec::new();
732 for v in vs.iter() {
733 if !value_within(v, &min, &max) {
734 kept.push(v.clone());
735 }
736 }
737 if kept.is_empty() {
738 None
739 } else {
740 Some(Domain::Enumeration(Arc::new(kept)))
741 }
742 }
743 _ => {
744 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
746 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
747 }
748 }
749 }
750 (Domain::Union(v1), Domain::Union(v2)) => {
751 let mut acc: Vec<Domain> = Vec::new();
752 for a in v1.iter() {
753 for b in v2.iter() {
754 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
755 acc.push(ix);
756 }
757 }
758 }
759 if acc.is_empty() {
760 None
761 } else {
762 Some(Domain::Union(Arc::new(acc)))
763 }
764 }
765 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
766 let mut acc: Vec<Domain> = Vec::new();
767 for a in vs.iter() {
768 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
769 acc.push(ix);
770 }
771 }
772 if acc.is_empty() {
773 None
774 } else if acc.len() == 1 {
775 Some(acc.remove(0))
776 } else {
777 Some(Domain::Union(Arc::new(acc)))
778 }
779 }
780 (Domain::Range { min, max }, Domain::Complement(inner))
782 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
783 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
784 _ => {
785 let normalized_complement = normalize_domain(Domain::Complement(inner));
788 if matches!(&normalized_complement, Domain::Complement(_)) {
789 None
790 } else {
791 domain_intersection(Domain::Range { min, max }, normalized_complement)
792 }
793 }
794 },
795 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
796 match (a_inner.as_ref(), b_inner.as_ref()) {
797 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
798 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
800 excluded.extend(b_ex.iter().cloned());
801 Some(normalize_domain(Domain::Complement(Box::new(
802 Domain::Enumeration(Arc::new(excluded)),
803 ))))
804 }
805 _ => None,
806 }
807 }
808 };
809 result.map(normalize_domain)
810}
811
812fn range_minus_excluded_points(
813 min: Bound,
814 max: Bound,
815 excluded: &Arc<Vec<LiteralValue>>,
816) -> Option<Domain> {
817 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
819
820 for p in excluded.iter() {
821 let mut next: Vec<(Bound, Bound)> = Vec::new();
822
823 for (rmin, rmax) in parts {
824 if !value_within(p, &rmin, &rmax) {
825 next.push((rmin, rmax));
826 continue;
827 }
828
829 let left_max = Bound::Exclusive(Arc::new(p.clone()));
831 if !bounds_contradict(&rmin, &left_max) {
832 next.push((rmin.clone(), left_max));
833 }
834
835 let right_min = Bound::Exclusive(Arc::new(p.clone()));
837 if !bounds_contradict(&right_min, &rmax) {
838 next.push((right_min, rmax.clone()));
839 }
840 }
841
842 parts = next;
843 if parts.is_empty() {
844 return None;
845 }
846 }
847
848 if parts.is_empty() {
849 None
850 } else if parts.len() == 1 {
851 let (min, max) = parts.remove(0);
852 Some(Domain::Range { min, max })
853 } else {
854 Some(Domain::Union(Arc::new(
855 parts
856 .into_iter()
857 .map(|(min, max)| Domain::Range { min, max })
858 .collect(),
859 )))
860 }
861}
862
863fn invert_bound(bound: Bound) -> Bound {
864 match bound {
865 Bound::Unbounded => Bound::Unbounded,
866 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
867 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
868 }
869}
870
871fn normalize_domain(d: Domain) -> Domain {
872 match d {
873 Domain::Complement(inner) => {
874 let normalized_inner = normalize_domain(*inner);
875 match normalized_inner {
876 Domain::Complement(double_inner) => *double_inner,
877 Domain::Range { min, max } => match (&min, &max) {
878 (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
879 (Bound::Unbounded, max) => Domain::Range {
880 min: invert_bound(max.clone()),
881 max: Bound::Unbounded,
882 },
883 (min, Bound::Unbounded) => Domain::Range {
884 min: Bound::Unbounded,
885 max: invert_bound(min.clone()),
886 },
887 (min, max) => Domain::Union(Arc::new(vec![
888 Domain::Range {
889 min: Bound::Unbounded,
890 max: invert_bound(min.clone()),
891 },
892 Domain::Range {
893 min: invert_bound(max.clone()),
894 max: Bound::Unbounded,
895 },
896 ])),
897 },
898 Domain::Enumeration(vals) => {
899 if vals.len() == 1 {
900 if let Some(lit) = vals.first() {
901 if let Value::Boolean(BooleanValue::True) = &lit.value {
902 return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
903 BooleanValue::False,
904 )]));
905 }
906 if let Value::Boolean(BooleanValue::False) = &lit.value {
907 return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
908 BooleanValue::True,
909 )]));
910 }
911 }
912 }
913 Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
914 }
915 Domain::Unconstrained => Domain::Empty,
916 Domain::Empty => Domain::Unconstrained,
917 Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
918 }
919 }
920 Domain::Empty => Domain::Empty,
921 Domain::Union(parts) => {
922 let mut flat: Vec<Domain> = Vec::new();
923 for p in parts.iter().cloned() {
924 let normalized = normalize_domain(p);
925 match normalized {
926 Domain::Union(inner) => flat.extend(inner.iter().cloned()),
927 Domain::Unconstrained => return Domain::Unconstrained,
928 Domain::Enumeration(vals) if vals.is_empty() => {}
929 other => flat.push(other),
930 }
931 }
932
933 let mut all_enum_values: Vec<LiteralValue> = Vec::new();
934 let mut ranges: Vec<Domain> = Vec::new();
935 let mut others: Vec<Domain> = Vec::new();
936
937 for domain in flat {
938 match domain {
939 Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
940 Domain::Range { .. } => ranges.push(domain),
941 other => others.push(other),
942 }
943 }
944
945 all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
946 -1 => Ordering::Less,
947 0 => Ordering::Equal,
948 _ => Ordering::Greater,
949 });
950 all_enum_values.dedup();
951
952 all_enum_values.retain(|v| {
953 !ranges.iter().any(|r| {
954 if let Domain::Range { min, max } = r {
955 value_within(v, min, max)
956 } else {
957 false
958 }
959 })
960 });
961
962 let mut result: Vec<Domain> = Vec::new();
963 result.extend(ranges);
964 result = merge_ranges(result);
965
966 if !all_enum_values.is_empty() {
967 result.push(Domain::Enumeration(Arc::new(all_enum_values)));
968 }
969 result.extend(others);
970
971 result.sort_by(|a, b| match (a, b) {
972 (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
973 (Domain::Range { .. }, _) => Ordering::Less,
974 (_, Domain::Range { .. }) => Ordering::Greater,
975 (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
976 (Domain::Enumeration(_), _) => Ordering::Less,
977 (_, Domain::Enumeration(_)) => Ordering::Greater,
978 _ => Ordering::Equal,
979 });
980
981 if result.is_empty() {
982 Domain::Enumeration(Arc::new(vec![]))
983 } else if result.len() == 1 {
984 result.remove(0)
985 } else {
986 Domain::Union(Arc::new(result))
987 }
988 }
989 Domain::Enumeration(values) => {
990 let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
991 sorted.sort_by(|a, b| match lit_cmp(a, b) {
992 -1 => Ordering::Less,
993 0 => Ordering::Equal,
994 _ => Ordering::Greater,
995 });
996 sorted.dedup();
997 Domain::Enumeration(Arc::new(sorted))
998 }
999 other => other,
1000 }
1001}
1002
1003fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1004 let mut result = Vec::new();
1005 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1006 let mut others = Vec::new();
1007
1008 for d in domains {
1009 match d {
1010 Domain::Range { min, max } => ranges.push((min, max)),
1011 other => others.push(other),
1012 }
1013 }
1014
1015 if ranges.is_empty() {
1016 return others;
1017 }
1018
1019 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1020
1021 let mut merged: Vec<(Bound, Bound)> = Vec::new();
1022 let mut current = ranges[0].clone();
1023
1024 for next in ranges.iter().skip(1) {
1025 if ranges_adjacent_or_overlap(¤t, next) {
1026 current = (
1027 min_bound(¤t.0, &next.0),
1028 max_bound(¤t.1, &next.1),
1029 );
1030 } else {
1031 merged.push(current);
1032 current = next.clone();
1033 }
1034 }
1035 merged.push(current);
1036
1037 for (min, max) in merged {
1038 result.push(Domain::Range { min, max });
1039 }
1040 result.extend(others);
1041
1042 result
1043}
1044
1045fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1046 match (a, b) {
1047 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1048 (Bound::Unbounded, _) => Ordering::Less,
1049 (_, Bound::Unbounded) => Ordering::Greater,
1050 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1051 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1052 -1 => Ordering::Less,
1053 0 => Ordering::Equal,
1054 _ => Ordering::Greater,
1055 },
1056 (Bound::Inclusive(v1), Bound::Exclusive(v2))
1057 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1058 -1 => Ordering::Less,
1059 0 => {
1060 if matches!(a, Bound::Inclusive(_)) {
1061 Ordering::Less
1062 } else {
1063 Ordering::Greater
1064 }
1065 }
1066 _ => Ordering::Greater,
1067 },
1068 }
1069}
1070
1071fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1072 match (&r1.1, &r2.0) {
1073 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1074 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1075 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1076 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1077 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1078 }
1079}
1080
1081fn min_bound(a: &Bound, b: &Bound) -> Bound {
1082 match (a, b) {
1083 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1084 _ => {
1085 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1086 a.clone()
1087 } else {
1088 b.clone()
1089 }
1090 }
1091 }
1092}
1093
1094fn max_bound(a: &Bound, b: &Bound) -> Bound {
1095 match (a, b) {
1096 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1097 _ => {
1098 if matches!(compare_bounds(a, b), Ordering::Greater) {
1099 a.clone()
1100 } else {
1101 b.clone()
1102 }
1103 }
1104 }
1105}
1106
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Numeric literal helper.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Local fact path helper.
    fn fact(name: &str) -> FactPath {
        FactPath::local(name.to_string())
    }

    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Builds `name <op> value`.
    fn compare(name: &str, op: ComparisonComputation, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            fact: fact(name),
            op,
            value: Arc::new(value),
        }
    }

    /// Extracts the domain of `name` from `constraint`, panicking if absent.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        let mut domains = extract_domains_from_constraint(constraint).unwrap();
        domains.remove(&fact(name)).unwrap()
    }

    fn single_boolean(value: BooleanValue) -> Domain {
        Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(value)]))
    }

    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let wrapped = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        assert_eq!(normalize_domain(wrapped), inner);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: inclusive(0),
                max: inclusive(10),
            },
            Domain::Unconstrained,
        ]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(range.to_string(), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(enumeration.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = compare("age", ComparisonComputation::GreaterThan, num(18));
        assert_eq!(
            domain_of(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(compare("age", ComparisonComputation::GreaterThan, num(18))),
            Box::new(compare("age", ComparisonComputation::LessThan, num(65))),
        );
        assert_eq!(
            domain_of(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: exclusive(65),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = LiteralValue::text("active".to_string());
        let constraint = compare("status", ComparisonComputation::Equal, active.clone());
        assert_eq!(
            domain_of(&constraint, "status"),
            Domain::Enumeration(Arc::new(vec![active]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            single_boolean(BooleanValue::True)
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            single_boolean(BooleanValue::False)
        );
    }
}