1use crate::computation::comparison_operation;
9use crate::parsing::ast::Span;
10use crate::{
11 BooleanValue, ComparisonComputation, FactPath, LemmaError, LemmaResult, LiteralValue,
12 OperationResult, Value,
13};
14use serde::ser::{Serialize, SerializeStruct, Serializer};
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::fmt;
18use std::sync::Arc;
19
20use super::constraint::Constraint;
21
22#[derive(Debug, Clone, PartialEq)]
24pub enum Domain {
25 Range { min: Bound, max: Bound },
27
28 Union(Arc<Vec<Domain>>),
30
31 Enumeration(Arc<Vec<LiteralValue>>),
33
34 Complement(Box<Domain>),
36
37 Unconstrained,
39
40 Empty,
42}
43
44impl Domain {
45 pub fn is_satisfiable(&self) -> bool {
49 match self {
50 Domain::Empty => false,
51 Domain::Enumeration(values) => !values.is_empty(),
52 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
53 Domain::Range { min, max } => !bounds_contradict(min, max),
54 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
55 Domain::Unconstrained => true,
56 }
57 }
58
59 pub fn is_empty(&self) -> bool {
61 !self.is_satisfiable()
62 }
63
64 pub fn intersect(&self, other: &Domain) -> Domain {
66 domain_intersection(self.clone(), other.clone()).unwrap_or(Domain::Empty)
67 }
68
69 pub fn contains(&self, value: &LiteralValue) -> bool {
71 match self {
72 Domain::Empty => false,
73 Domain::Unconstrained => true,
74 Domain::Enumeration(values) => values.contains(value),
75 Domain::Range { min, max } => value_within(value, min, max),
76 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
77 Domain::Complement(inner) => !inner.contains(value),
78 }
79 }
80}
81
82#[derive(Debug, Clone, PartialEq)]
84pub enum Bound {
85 Inclusive(Arc<LiteralValue>),
87
88 Exclusive(Arc<LiteralValue>),
90
91 Unbounded,
93}
94
95impl fmt::Display for Domain {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 match self {
98 Domain::Empty => write!(f, "empty"),
99 Domain::Unconstrained => write!(f, "any"),
100 Domain::Enumeration(vals) => {
101 write!(f, "{{")?;
102 for (i, v) in vals.iter().enumerate() {
103 if i > 0 {
104 write!(f, ", ")?;
105 }
106 write!(f, "{}", v)?;
107 }
108 write!(f, "}}")
109 }
110 Domain::Range { min, max } => {
111 let (l_bracket, r_bracket) = match (min, max) {
112 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
113 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
114 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
115 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
116 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
117 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
118 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
119 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
120 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
121 };
122
123 let min_str = match min {
124 Bound::Unbounded => "-inf".to_string(),
125 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
126 };
127 let max_str = match max {
128 Bound::Unbounded => "+inf".to_string(),
129 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
130 };
131 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
132 }
133 Domain::Union(parts) => {
134 for (i, p) in parts.iter().enumerate() {
135 if i > 0 {
136 write!(f, " | ")?;
137 }
138 write!(f, "{}", p)?;
139 }
140 Ok(())
141 }
142 Domain::Complement(inner) => write!(f, "not ({})", inner),
143 }
144 }
145}
146
147impl fmt::Display for Bound {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149 match self {
150 Bound::Unbounded => write!(f, "inf"),
151 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
152 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
153 }
154 }
155}
156
157impl Serialize for Domain {
158 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
159 where
160 S: Serializer,
161 {
162 match self {
163 Domain::Empty => {
164 let mut st = serializer.serialize_struct("domain", 1)?;
165 st.serialize_field("type", "empty")?;
166 st.end()
167 }
168 Domain::Unconstrained => {
169 let mut st = serializer.serialize_struct("domain", 1)?;
170 st.serialize_field("type", "unconstrained")?;
171 st.end()
172 }
173 Domain::Enumeration(vals) => {
174 let mut st = serializer.serialize_struct("domain", 2)?;
175 st.serialize_field("type", "enumeration")?;
176 st.serialize_field("values", vals)?;
177 st.end()
178 }
179 Domain::Range { min, max } => {
180 let mut st = serializer.serialize_struct("domain", 3)?;
181 st.serialize_field("type", "range")?;
182 st.serialize_field("min", min)?;
183 st.serialize_field("max", max)?;
184 st.end()
185 }
186 Domain::Union(parts) => {
187 let mut st = serializer.serialize_struct("domain", 2)?;
188 st.serialize_field("type", "union")?;
189 st.serialize_field("parts", parts)?;
190 st.end()
191 }
192 Domain::Complement(inner) => {
193 let mut st = serializer.serialize_struct("domain", 2)?;
194 st.serialize_field("type", "complement")?;
195 st.serialize_field("inner", inner)?;
196 st.end()
197 }
198 }
199 }
200}
201
202impl Serialize for Bound {
203 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
204 where
205 S: Serializer,
206 {
207 match self {
208 Bound::Unbounded => {
209 let mut st = serializer.serialize_struct("bound", 1)?;
210 st.serialize_field("type", "unbounded")?;
211 st.end()
212 }
213 Bound::Inclusive(v) => {
214 let mut st = serializer.serialize_struct("bound", 2)?;
215 st.serialize_field("type", "inclusive")?;
216 st.serialize_field("value", v.as_ref())?;
217 st.end()
218 }
219 Bound::Exclusive(v) => {
220 let mut st = serializer.serialize_struct("bound", 2)?;
221 st.serialize_field("type", "exclusive")?;
222 st.serialize_field("value", v.as_ref())?;
223 st.end()
224 }
225 }
226 }
227}
228
229pub fn extract_domains_from_constraint(
231 constraint: &Constraint,
232) -> LemmaResult<HashMap<FactPath, Domain>> {
233 let all_facts = constraint.collect_facts();
234 let mut domains = HashMap::new();
235
236 for fact_path in all_facts {
237 let domain =
238 extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
239 domains.insert(fact_path, domain);
240 }
241
242 Ok(domains)
243}
244
245fn extract_domain_for_fact(
246 constraint: &Constraint,
247 fact_path: &FactPath,
248) -> LemmaResult<Option<Domain>> {
249 let domain = match constraint {
250 Constraint::True => return Ok(None),
251 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
252
253 Constraint::Comparison { fact, op, value } => {
254 if fact == fact_path {
255 Some(comparison_to_domain(op, value.as_ref())?)
256 } else {
257 None
258 }
259 }
260
261 Constraint::Fact(fp) => {
262 if fp == fact_path {
263 Some(Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
264 BooleanValue::True,
265 )])))
266 } else {
267 None
268 }
269 }
270
271 Constraint::And(left, right) => {
272 let left_domain = extract_domain_for_fact(left, fact_path)?;
273 let right_domain = extract_domain_for_fact(right, fact_path)?;
274 match (left_domain, right_domain) {
275 (None, None) => None,
276 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
277 (Some(a), Some(b)) => match domain_intersection(a, b) {
278 Some(domain) => Some(domain),
279 None => Some(Domain::Enumeration(Arc::new(vec![]))),
280 },
281 }
282 }
283
284 Constraint::Or(left, right) => {
285 let left_domain = extract_domain_for_fact(left, fact_path)?;
286 let right_domain = extract_domain_for_fact(right, fact_path)?;
287 union_optional_domains(left_domain, right_domain)
288 }
289
290 Constraint::Not(inner) => {
291 if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
293 if fact == fact_path && op.is_equal() {
294 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
295 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
296 )))));
297 }
298 }
299
300 if let Constraint::Fact(fp) = inner.as_ref() {
302 if fp == fact_path {
303 return Ok(Some(Domain::Enumeration(Arc::new(vec![
304 LiteralValue::boolean(BooleanValue::False),
305 ]))));
306 }
307 }
308
309 extract_domain_for_fact(inner, fact_path)?
310 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
311 }
312 };
313
314 Ok(domain.map(normalize_domain))
315}
316
317fn comparison_to_domain(op: &ComparisonComputation, value: &LiteralValue) -> LemmaResult<Domain> {
318 if op.is_equal() {
319 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
320 }
321 if op.is_not_equal() {
322 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
323 vec![value.clone()],
324 )))));
325 }
326 match op {
327 ComparisonComputation::LessThan => Ok(Domain::Range {
328 min: Bound::Unbounded,
329 max: Bound::Exclusive(Arc::new(value.clone())),
330 }),
331 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
332 min: Bound::Unbounded,
333 max: Bound::Inclusive(Arc::new(value.clone())),
334 }),
335 ComparisonComputation::GreaterThan => Ok(Domain::Range {
336 min: Bound::Exclusive(Arc::new(value.clone())),
337 max: Bound::Unbounded,
338 }),
339 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
340 min: Bound::Inclusive(Arc::new(value.clone())),
341 max: Bound::Unbounded,
342 }),
343 _ => Err(LemmaError::engine(
344 format!(
345 "Unsupported comparison operator for domain extraction: {:?}",
346 op
347 ),
348 Span {
349 start: 0,
350 end: 0,
351 line: 1,
352 col: 0,
353 },
354 "<unknown>",
355 Arc::from(""),
356 "<unknown>",
357 1,
358 None::<String>,
359 )),
360 }
361}
362
363pub(crate) fn domain_for_comparison_atom(
368 op: &ComparisonComputation,
369 value: &LiteralValue,
370) -> LemmaResult<Domain> {
371 comparison_to_domain(op, value)
372}
373
374impl Domain {
375 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
382 match (self, other) {
383 (Domain::Empty, _) => true,
384 (_, Domain::Unconstrained) => true,
385 (Domain::Unconstrained, _) => false,
386
387 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
388 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
389 vals.iter().all(|v| value_within(v, min, max))
390 }
391
392 (
393 Domain::Range {
394 min: amin,
395 max: amax,
396 },
397 Domain::Range {
398 min: bmin,
399 max: bmax,
400 },
401 ) => range_within_range(amin, amax, bmin, bmax),
402
403 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
405 Domain::Enumeration(excluded) => {
406 excluded.iter().all(|p| !value_within(p, min, max))
407 }
408 _ => false,
409 },
410
411 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
413 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
414 _ => false,
415 },
416
417 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
419 match (a_inner.as_ref(), b_inner.as_ref()) {
420 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
421 excluded_b.iter().all(|v| excluded_a.contains(v))
422 }
423 _ => false,
424 }
425 }
426
427 _ => false,
428 }
429 }
430}
431
432fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
433 lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
434}
435
436fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
437 match (a, b) {
438 (_, Bound::Unbounded) => true,
439 (Bound::Unbounded, _) => false,
440 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
441 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
442 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
443 let c = lit_cmp(av.as_ref(), bv.as_ref());
444 c >= 0
445 }
446 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
447 lit_cmp(av.as_ref(), bv.as_ref()) > 0
449 }
450 }
451}
452
453fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
454 match (a, b) {
455 (Bound::Unbounded, Bound::Unbounded) => true,
456 (_, Bound::Unbounded) => true,
457 (Bound::Unbounded, _) => false,
458 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
459 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
460 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
461 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
463 }
464 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
465 lit_cmp(av.as_ref(), bv.as_ref()) < 0
467 }
468 }
469}
470
471fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
472 match (a, b) {
473 (None, None) => None,
474 (Some(d), None) | (None, Some(d)) => Some(d),
475 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
476 }
477}
478
479fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
480 if let OperationResult::Value(lit) =
481 comparison_operation(a, &ComparisonComputation::LessThan, b)
482 {
483 if let Value::Boolean(BooleanValue::True) = &lit.value {
484 return -1;
485 }
486 }
487 if let OperationResult::Value(lit) = comparison_operation(a, &ComparisonComputation::Equal, b) {
488 if let Value::Boolean(BooleanValue::True) = &lit.value {
489 return 0;
490 }
491 }
492 1
493}
494
495fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
496 let ge_min = match min {
497 Bound::Unbounded => true,
498 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
499 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
500 };
501 let le_max = match max {
502 Bound::Unbounded => true,
503 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
504 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
505 };
506 ge_min && le_max
507}
508
509fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
510 match (min, max) {
511 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
512 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
513 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
514 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
515 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
516 }
517}
518
519fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
520 match (min1, min2) {
521 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
522 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
523 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
524 Bound::Inclusive(v1)
525 } else {
526 Bound::Inclusive(v2)
527 }
528 }
529 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
530 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
531 Bound::Inclusive(v1)
532 } else {
533 Bound::Exclusive(v2)
534 }
535 }
536 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
537 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
538 Bound::Exclusive(v1)
539 } else {
540 Bound::Inclusive(v2)
541 }
542 }
543 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
544 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
545 Bound::Exclusive(v1)
546 } else {
547 Bound::Exclusive(v2)
548 }
549 }
550 }
551}
552
553fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
554 match (max1, max2) {
555 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
556 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
557 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
558 Bound::Inclusive(v1)
559 } else {
560 Bound::Inclusive(v2)
561 }
562 }
563 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
564 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
565 Bound::Inclusive(v1)
566 } else {
567 Bound::Exclusive(v2)
568 }
569 }
570 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
571 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
572 Bound::Exclusive(v1)
573 } else {
574 Bound::Inclusive(v2)
575 }
576 }
577 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
578 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
579 Bound::Exclusive(v1)
580 } else {
581 Bound::Exclusive(v2)
582 }
583 }
584 }
585}
586
587fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
588 let a = normalize_domain(a);
589 let b = normalize_domain(b);
590
591 let result = match (a, b) {
592 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
593 (Domain::Empty, _) | (_, Domain::Empty) => None,
594
595 (
596 Domain::Range {
597 min: min1,
598 max: max1,
599 },
600 Domain::Range {
601 min: min2,
602 max: max2,
603 },
604 ) => {
605 let min = compute_intersection_min(min1, min2);
606 let max = compute_intersection_max(max1, max2);
607
608 if bounds_contradict(&min, &max) {
609 None
610 } else {
611 Some(Domain::Range { min, max })
612 }
613 }
614 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
615 let filtered: Vec<LiteralValue> =
616 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
617 if filtered.is_empty() {
618 None
619 } else {
620 Some(Domain::Enumeration(Arc::new(filtered)))
621 }
622 }
623 (Domain::Enumeration(vs), Domain::Range { min, max })
624 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
625 let mut kept = Vec::new();
626 for v in vs.iter() {
627 if value_within(v, &min, &max) {
628 kept.push(v.clone());
629 }
630 }
631 if kept.is_empty() {
632 None
633 } else {
634 Some(Domain::Enumeration(Arc::new(kept)))
635 }
636 }
637 (Domain::Enumeration(vs), Domain::Complement(inner))
638 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
639 match *inner.clone() {
640 Domain::Enumeration(excluded) => {
641 let mut kept = Vec::new();
642 for v in vs.iter() {
643 if !excluded.contains(v) {
644 kept.push(v.clone());
645 }
646 }
647 if kept.is_empty() {
648 None
649 } else {
650 Some(Domain::Enumeration(Arc::new(kept)))
651 }
652 }
653 Domain::Range { min, max } => {
654 let mut kept = Vec::new();
656 for v in vs.iter() {
657 if !value_within(v, &min, &max) {
658 kept.push(v.clone());
659 }
660 }
661 if kept.is_empty() {
662 None
663 } else {
664 Some(Domain::Enumeration(Arc::new(kept)))
665 }
666 }
667 _ => {
668 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
670 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
671 }
672 }
673 }
674 (Domain::Union(v1), Domain::Union(v2)) => {
675 let mut acc: Vec<Domain> = Vec::new();
676 for a in v1.iter() {
677 for b in v2.iter() {
678 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
679 acc.push(ix);
680 }
681 }
682 }
683 if acc.is_empty() {
684 None
685 } else {
686 Some(Domain::Union(Arc::new(acc)))
687 }
688 }
689 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
690 let mut acc: Vec<Domain> = Vec::new();
691 for a in vs.iter() {
692 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
693 acc.push(ix);
694 }
695 }
696 if acc.is_empty() {
697 None
698 } else if acc.len() == 1 {
699 Some(acc.remove(0))
700 } else {
701 Some(Domain::Union(Arc::new(acc)))
702 }
703 }
704 (Domain::Range { min, max }, Domain::Complement(inner))
706 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
707 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
708 _ => {
709 let normalized_complement = normalize_domain(Domain::Complement(inner));
712 if matches!(&normalized_complement, Domain::Complement(_)) {
713 None
714 } else {
715 domain_intersection(Domain::Range { min, max }, normalized_complement)
716 }
717 }
718 },
719 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
720 match (a_inner.as_ref(), b_inner.as_ref()) {
721 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
722 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
724 excluded.extend(b_ex.iter().cloned());
725 Some(normalize_domain(Domain::Complement(Box::new(
726 Domain::Enumeration(Arc::new(excluded)),
727 ))))
728 }
729 _ => None,
730 }
731 }
732 };
733 result.map(normalize_domain)
734}
735
736fn range_minus_excluded_points(
737 min: Bound,
738 max: Bound,
739 excluded: &Arc<Vec<LiteralValue>>,
740) -> Option<Domain> {
741 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
743
744 for p in excluded.iter() {
745 let mut next: Vec<(Bound, Bound)> = Vec::new();
746
747 for (rmin, rmax) in parts {
748 if !value_within(p, &rmin, &rmax) {
749 next.push((rmin, rmax));
750 continue;
751 }
752
753 let left_max = Bound::Exclusive(Arc::new(p.clone()));
755 if !bounds_contradict(&rmin, &left_max) {
756 next.push((rmin.clone(), left_max));
757 }
758
759 let right_min = Bound::Exclusive(Arc::new(p.clone()));
761 if !bounds_contradict(&right_min, &rmax) {
762 next.push((right_min, rmax.clone()));
763 }
764 }
765
766 parts = next;
767 if parts.is_empty() {
768 return None;
769 }
770 }
771
772 if parts.is_empty() {
773 None
774 } else if parts.len() == 1 {
775 let (min, max) = parts.remove(0);
776 Some(Domain::Range { min, max })
777 } else {
778 Some(Domain::Union(Arc::new(
779 parts
780 .into_iter()
781 .map(|(min, max)| Domain::Range { min, max })
782 .collect(),
783 )))
784 }
785}
786
787fn invert_bound(bound: Bound) -> Bound {
788 match bound {
789 Bound::Unbounded => Bound::Unbounded,
790 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
791 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
792 }
793}
794
795fn normalize_domain(d: Domain) -> Domain {
796 match d {
797 Domain::Complement(inner) => {
798 let normalized_inner = normalize_domain(*inner);
799 match normalized_inner {
800 Domain::Complement(double_inner) => *double_inner,
801 Domain::Range { min, max } => match (&min, &max) {
802 (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
803 (Bound::Unbounded, max) => Domain::Range {
804 min: invert_bound(max.clone()),
805 max: Bound::Unbounded,
806 },
807 (min, Bound::Unbounded) => Domain::Range {
808 min: Bound::Unbounded,
809 max: invert_bound(min.clone()),
810 },
811 (min, max) => Domain::Union(Arc::new(vec![
812 Domain::Range {
813 min: Bound::Unbounded,
814 max: invert_bound(min.clone()),
815 },
816 Domain::Range {
817 min: invert_bound(max.clone()),
818 max: Bound::Unbounded,
819 },
820 ])),
821 },
822 Domain::Enumeration(vals) => {
823 if vals.len() == 1 {
824 if let Some(lit) = vals.first() {
825 if let Value::Boolean(BooleanValue::True) = &lit.value {
826 return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
827 BooleanValue::False,
828 )]));
829 }
830 if let Value::Boolean(BooleanValue::False) = &lit.value {
831 return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
832 BooleanValue::True,
833 )]));
834 }
835 }
836 }
837 Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
838 }
839 Domain::Unconstrained => Domain::Empty,
840 Domain::Empty => Domain::Unconstrained,
841 Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
842 }
843 }
844 Domain::Empty => Domain::Empty,
845 Domain::Union(parts) => {
846 let mut flat: Vec<Domain> = Vec::new();
847 for p in parts.iter().cloned() {
848 let normalized = normalize_domain(p);
849 match normalized {
850 Domain::Union(inner) => flat.extend(inner.iter().cloned()),
851 Domain::Unconstrained => return Domain::Unconstrained,
852 Domain::Enumeration(vals) if vals.is_empty() => {}
853 other => flat.push(other),
854 }
855 }
856
857 let mut all_enum_values: Vec<LiteralValue> = Vec::new();
858 let mut ranges: Vec<Domain> = Vec::new();
859 let mut others: Vec<Domain> = Vec::new();
860
861 for domain in flat {
862 match domain {
863 Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
864 Domain::Range { .. } => ranges.push(domain),
865 other => others.push(other),
866 }
867 }
868
869 all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
870 -1 => Ordering::Less,
871 0 => Ordering::Equal,
872 _ => Ordering::Greater,
873 });
874 all_enum_values.dedup();
875
876 all_enum_values.retain(|v| {
877 !ranges.iter().any(|r| {
878 if let Domain::Range { min, max } = r {
879 value_within(v, min, max)
880 } else {
881 false
882 }
883 })
884 });
885
886 let mut result: Vec<Domain> = Vec::new();
887 result.extend(ranges);
888 result = merge_ranges(result);
889
890 if !all_enum_values.is_empty() {
891 result.push(Domain::Enumeration(Arc::new(all_enum_values)));
892 }
893 result.extend(others);
894
895 result.sort_by(|a, b| match (a, b) {
896 (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
897 (Domain::Range { .. }, _) => Ordering::Less,
898 (_, Domain::Range { .. }) => Ordering::Greater,
899 (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
900 (Domain::Enumeration(_), _) => Ordering::Less,
901 (_, Domain::Enumeration(_)) => Ordering::Greater,
902 _ => Ordering::Equal,
903 });
904
905 if result.is_empty() {
906 Domain::Enumeration(Arc::new(vec![]))
907 } else if result.len() == 1 {
908 result.remove(0)
909 } else {
910 Domain::Union(Arc::new(result))
911 }
912 }
913 Domain::Enumeration(values) => {
914 let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
915 sorted.sort_by(|a, b| match lit_cmp(a, b) {
916 -1 => Ordering::Less,
917 0 => Ordering::Equal,
918 _ => Ordering::Greater,
919 });
920 sorted.dedup();
921 Domain::Enumeration(Arc::new(sorted))
922 }
923 other => other,
924 }
925}
926
927fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
928 let mut result = Vec::new();
929 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
930 let mut others = Vec::new();
931
932 for d in domains {
933 match d {
934 Domain::Range { min, max } => ranges.push((min, max)),
935 other => others.push(other),
936 }
937 }
938
939 if ranges.is_empty() {
940 return others;
941 }
942
943 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
944
945 let mut merged: Vec<(Bound, Bound)> = Vec::new();
946 let mut current = ranges[0].clone();
947
948 for next in ranges.iter().skip(1) {
949 if ranges_adjacent_or_overlap(¤t, next) {
950 current = (
951 min_bound(¤t.0, &next.0),
952 max_bound(¤t.1, &next.1),
953 );
954 } else {
955 merged.push(current);
956 current = next.clone();
957 }
958 }
959 merged.push(current);
960
961 for (min, max) in merged {
962 result.push(Domain::Range { min, max });
963 }
964 result.extend(others);
965
966 result
967}
968
969fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
970 match (a, b) {
971 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
972 (Bound::Unbounded, _) => Ordering::Less,
973 (_, Bound::Unbounded) => Ordering::Greater,
974 (Bound::Inclusive(v1), Bound::Inclusive(v2))
975 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
976 -1 => Ordering::Less,
977 0 => Ordering::Equal,
978 _ => Ordering::Greater,
979 },
980 (Bound::Inclusive(v1), Bound::Exclusive(v2))
981 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
982 -1 => Ordering::Less,
983 0 => {
984 if matches!(a, Bound::Inclusive(_)) {
985 Ordering::Less
986 } else {
987 Ordering::Greater
988 }
989 }
990 _ => Ordering::Greater,
991 },
992 }
993}
994
995fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
996 match (&r1.1, &r2.0) {
997 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
998 (Bound::Inclusive(v1), Bound::Inclusive(v2))
999 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1000 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1001 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1002 }
1003}
1004
1005fn min_bound(a: &Bound, b: &Bound) -> Bound {
1006 match (a, b) {
1007 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1008 _ => {
1009 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1010 a.clone()
1011 } else {
1012 b.clone()
1013 }
1014 }
1015 }
1016}
1017
1018fn max_bound(a: &Bound, b: &Bound) -> Bound {
1019 match (a, b) {
1020 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1021 _ => {
1022 if matches!(compare_bounds(a, b), Ordering::Greater) {
1023 a.clone()
1024 } else {
1025 b.clone()
1026 }
1027 }
1028 }
1029}
1030
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Numeric literal from an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Local (unqualified) fact path.
    fn fact(name: &str) -> FactPath {
        FactPath::local(name.to_string())
    }

    /// Inclusive bound on a numeric literal.
    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    /// Exclusive bound on a numeric literal.
    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// `name <op> value` as a constraint.
    fn comparison(name: &str, op: ComparisonComputation, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            fact: fact(name),
            op,
            value: Arc::new(value),
        }
    }

    /// Runs domain extraction and returns the domain computed for `name`.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        let mut domains = extract_domains_from_constraint(constraint).unwrap();
        domains.remove(&fact(name)).unwrap()
    }

    /// Single-element boolean enumeration.
    fn boolean_domain(value: BooleanValue) -> Domain {
        Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(value)]))
    }

    #[test]
    fn test_normalize_double_complement() {
        let five = Domain::Enumeration(Arc::new(vec![num(5)]));
        let wrapped = Domain::Complement(Box::new(Domain::Complement(Box::new(five.clone()))));
        assert_eq!(normalize_domain(wrapped), five);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let bounded = Domain::Range {
            min: inclusive(0),
            max: inclusive(10),
        };
        let union = Domain::Union(Arc::new(vec![bounded, Domain::Unconstrained]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let half_open = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(half_open.to_string(), "[10, 20)");

        let listed = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(listed.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = comparison("age", ComparisonComputation::GreaterThan, num(18));
        let expected = Domain::Range {
            min: exclusive(18),
            max: Bound::Unbounded,
        };
        assert_eq!(domain_of(&constraint, "age"), expected);
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(comparison("age", ComparisonComputation::GreaterThan, num(18))),
            Box::new(comparison("age", ComparisonComputation::LessThan, num(65))),
        );
        let expected = Domain::Range {
            min: exclusive(18),
            max: exclusive(65),
        };
        assert_eq!(domain_of(&constraint, "age"), expected);
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = LiteralValue::text("active".to_string());
        let constraint = comparison("status", ComparisonComputation::Equal, active.clone());
        assert_eq!(
            domain_of(&constraint, "status"),
            Domain::Enumeration(Arc::new(vec![active]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            boolean_domain(BooleanValue::True)
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            boolean_domain(BooleanValue::False)
        );
    }
}