1use crate::computation::units::UnitResolutionContext;
9use crate::planning::semantics::{
10 ComparisonComputation, DataPath, LiteralValue, SemanticConversionTarget, ValueKind,
11};
12use crate::OperationResult;
13use serde::ser::{Serialize, SerializeStruct, Serializer};
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
/// A set of literal values that a single data path may take.
///
/// Membership is defined by [`Domain::contains`].
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// Every value lying between `min` and `max`; each end may be inclusive,
    /// exclusive or unbounded.
    Range { min: Bound, max: Bound },

    /// Every value contained in at least one of the listed sub-domains.
    Union(Arc<Vec<Domain>>),

    /// Exactly the listed values; an empty list contains nothing.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is *not* contained in the wrapped domain.
    Complement(Box<Domain>),

    /// No restriction: contains every value.
    Unconstrained,

    /// Contains no value at all.
    Empty,
}
42
43impl Domain {
44 pub fn is_satisfiable(&self) -> bool {
48 match self {
49 Domain::Empty => false,
50 Domain::Enumeration(values) => !values.is_empty(),
51 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
52 Domain::Range { min, max } => !bounds_contradict(min, max),
53 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
54 Domain::Unconstrained => true,
55 }
56 }
57
58 pub fn is_empty(&self) -> bool {
60 !self.is_satisfiable()
61 }
62
63 pub fn intersect(&self, other: &Domain) -> Domain {
66 match domain_intersection(self.clone(), other.clone()) {
67 Some(d) => d,
68 None => Domain::Empty,
69 }
70 }
71
72 pub fn contains(&self, value: &LiteralValue) -> bool {
74 match self {
75 Domain::Empty => false,
76 Domain::Unconstrained => true,
77 Domain::Enumeration(values) => values.contains(value),
78 Domain::Range { min, max } => value_within(value, min, max),
79 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
80 Domain::Complement(inner) => !inner.contains(value),
81 }
82 }
83}
84
/// One end of a [`Domain::Range`].
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// The limit value itself belongs to the range.
    Inclusive(Arc<LiteralValue>),

    /// The limit value itself is excluded from the range.
    Exclusive(Arc<LiteralValue>),

    /// The range has no limit on this side.
    Unbounded,
}
97
98impl fmt::Display for Domain {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 match self {
101 Domain::Empty => write!(f, "empty"),
102 Domain::Unconstrained => write!(f, "any"),
103 Domain::Enumeration(vals) => {
104 write!(f, "{{")?;
105 for (i, v) in vals.iter().enumerate() {
106 if i > 0 {
107 write!(f, ", ")?;
108 }
109 write!(f, "{}", v)?;
110 }
111 write!(f, "}}")
112 }
113 Domain::Range { min, max } => {
114 let (l_bracket, r_bracket) = match (min, max) {
115 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
116 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
117 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
118 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
119 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
120 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
121 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
122 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
123 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
124 };
125
126 let min_str = match min {
127 Bound::Unbounded => "-inf".to_string(),
128 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
129 };
130 let max_str = match max {
131 Bound::Unbounded => "+inf".to_string(),
132 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
133 };
134 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
135 }
136 Domain::Union(parts) => {
137 for (i, p) in parts.iter().enumerate() {
138 if i > 0 {
139 write!(f, " | ")?;
140 }
141 write!(f, "{}", p)?;
142 }
143 Ok(())
144 }
145 Domain::Complement(inner) => write!(f, "not ({})", inner),
146 }
147 }
148}
149
150impl fmt::Display for Bound {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 match self {
153 Bound::Unbounded => write!(f, "inf"),
154 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
155 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
156 }
157 }
158}
159
160impl Serialize for Domain {
161 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
162 where
163 S: Serializer,
164 {
165 match self {
166 Domain::Empty => {
167 let mut st = serializer.serialize_struct("domain", 1)?;
168 st.serialize_field("type", "empty")?;
169 st.end()
170 }
171 Domain::Unconstrained => {
172 let mut st = serializer.serialize_struct("domain", 1)?;
173 st.serialize_field("type", "unconstrained")?;
174 st.end()
175 }
176 Domain::Enumeration(vals) => {
177 let mut st = serializer.serialize_struct("domain", 2)?;
178 st.serialize_field("type", "enumeration")?;
179 st.serialize_field("values", vals)?;
180 st.end()
181 }
182 Domain::Range { min, max } => {
183 let mut st = serializer.serialize_struct("domain", 3)?;
184 st.serialize_field("type", "range")?;
185 st.serialize_field("min", min)?;
186 st.serialize_field("max", max)?;
187 st.end()
188 }
189 Domain::Union(parts) => {
190 let mut st = serializer.serialize_struct("domain", 2)?;
191 st.serialize_field("type", "union")?;
192 st.serialize_field("parts", parts)?;
193 st.end()
194 }
195 Domain::Complement(inner) => {
196 let mut st = serializer.serialize_struct("domain", 2)?;
197 st.serialize_field("type", "complement")?;
198 st.serialize_field("inner", inner)?;
199 st.end()
200 }
201 }
202 }
203}
204
205impl Serialize for Bound {
206 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207 where
208 S: Serializer,
209 {
210 match self {
211 Bound::Unbounded => {
212 let mut st = serializer.serialize_struct("bound", 1)?;
213 st.serialize_field("type", "unbounded")?;
214 st.end()
215 }
216 Bound::Inclusive(v) => {
217 let mut st = serializer.serialize_struct("bound", 2)?;
218 st.serialize_field("type", "inclusive")?;
219 st.serialize_field("value", v.as_ref())?;
220 st.end()
221 }
222 Bound::Exclusive(v) => {
223 let mut st = serializer.serialize_struct("bound", 2)?;
224 st.serialize_field("type", "exclusive")?;
225 st.serialize_field("value", v.as_ref())?;
226 st.end()
227 }
228 }
229 }
230}
231
232pub fn extract_domains_from_constraint(
234 constraint: &Constraint,
235) -> Result<HashMap<DataPath, Domain>, crate::Error> {
236 let all_datas = constraint.collect_data();
237 let mut domains = HashMap::new();
238
239 for data_path in all_datas {
240 let domain =
244 extract_domain_for_data(constraint, &data_path)?.unwrap_or(Domain::Unconstrained);
245 domains.insert(data_path, domain);
246 }
247
248 Ok(domains)
249}
250
251fn extract_domain_for_data(
252 constraint: &Constraint,
253 data_path: &DataPath,
254) -> Result<Option<Domain>, crate::Error> {
255 let domain = match constraint {
256 Constraint::True => return Ok(None),
257 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
258
259 Constraint::Comparison { data, op, value } => {
260 if data == data_path {
261 Some(comparison_to_domain(op, value.as_ref())?)
262 } else {
263 None
264 }
265 }
266
267 Constraint::Data(fp) => {
268 if fp == data_path {
269 Some(Domain::Enumeration(Arc::new(vec![
270 LiteralValue::from_bool(true),
271 ])))
272 } else {
273 None
274 }
275 }
276
277 Constraint::And(left, right) => {
278 let left_domain = extract_domain_for_data(left, data_path)?;
279 let right_domain = extract_domain_for_data(right, data_path)?;
280 match (left_domain, right_domain) {
281 (None, None) => None,
282 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
283 (Some(a), Some(b)) => match domain_intersection(a, b) {
284 Some(domain) => Some(domain),
285 None => Some(Domain::Enumeration(Arc::new(vec![]))),
286 },
287 }
288 }
289
290 Constraint::Or(left, right) => {
291 let left_domain = extract_domain_for_data(left, data_path)?;
292 let right_domain = extract_domain_for_data(right, data_path)?;
293 union_optional_domains(left_domain, right_domain)
294 }
295
296 Constraint::Not(inner) => {
297 if let Constraint::Comparison { data, op, value } = inner.as_ref() {
299 if data == data_path && op.is_equal() {
300 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
301 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
302 )))));
303 }
304 }
305
306 if let Constraint::Data(fp) = inner.as_ref() {
308 if fp == data_path {
309 return Ok(Some(Domain::Enumeration(Arc::new(vec![
310 LiteralValue::from_bool(false),
311 ]))));
312 }
313 }
314
315 extract_domain_for_data(inner, data_path)?
316 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
317 }
318 };
319
320 Ok(domain.map(normalize_domain))
321}
322
323fn comparison_to_domain(
324 op: &ComparisonComputation,
325 value: &LiteralValue,
326) -> Result<Domain, crate::Error> {
327 if op.is_equal() {
328 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
329 }
330 if op.is_not_equal() {
331 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
332 vec![value.clone()],
333 )))));
334 }
335 match op {
336 ComparisonComputation::LessThan => Ok(Domain::Range {
337 min: Bound::Unbounded,
338 max: Bound::Exclusive(Arc::new(value.clone())),
339 }),
340 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
341 min: Bound::Unbounded,
342 max: Bound::Inclusive(Arc::new(value.clone())),
343 }),
344 ComparisonComputation::GreaterThan => Ok(Domain::Range {
345 min: Bound::Exclusive(Arc::new(value.clone())),
346 max: Bound::Unbounded,
347 }),
348 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
349 min: Bound::Inclusive(Arc::new(value.clone())),
350 max: Bound::Unbounded,
351 }),
352 _ => unreachable!(
353 "BUG: unsupported comparison operator for domain extraction: {:?}",
354 op
355 ),
356 }
357}
358
/// Crate-visible entry point that returns the domain of values satisfying
/// the comparison `<op> value`; it forwards directly to
/// `comparison_to_domain`.
pub(crate) fn domain_for_comparison_atom(
    op: &ComparisonComputation,
    value: &LiteralValue,
) -> Result<Domain, crate::Error> {
    comparison_to_domain(op, value)
}
369
370impl Domain {
371 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
378 match (self, other) {
379 (Domain::Empty, _) => true,
380 (_, Domain::Unconstrained) => true,
381 (Domain::Unconstrained, _) => false,
382
383 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
384 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
385 vals.iter().all(|v| value_within(v, min, max))
386 }
387
388 (
389 Domain::Range {
390 min: amin,
391 max: amax,
392 },
393 Domain::Range {
394 min: bmin,
395 max: bmax,
396 },
397 ) => range_within_range(amin, amax, bmin, bmax),
398
399 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
401 Domain::Enumeration(excluded) => {
402 excluded.iter().all(|p| !value_within(p, min, max))
403 }
404 _ => false,
405 },
406
407 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
409 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
410 _ => false,
411 },
412
413 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
415 match (a_inner.as_ref(), b_inner.as_ref()) {
416 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
417 excluded_b.iter().all(|v| excluded_a.contains(v))
418 }
419 _ => false,
420 }
421 }
422
423 _ => false,
424 }
425 }
426}
427
/// Returns `true` when the range `(amin, amax)` lies inside `(bmin, bmax)`:
/// its lower end starts no earlier and its upper end stops no later.
fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
}
431
432fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
433 match (a, b) {
434 (_, Bound::Unbounded) => true,
435 (Bound::Unbounded, _) => false,
436 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
437 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
438 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
439 let c = lit_cmp(av.as_ref(), bv.as_ref());
440 c >= 0
441 }
442 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
443 lit_cmp(av.as_ref(), bv.as_ref()) > 0
445 }
446 }
447}
448
449fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
450 match (a, b) {
451 (Bound::Unbounded, Bound::Unbounded) => true,
452 (_, Bound::Unbounded) => true,
453 (Bound::Unbounded, _) => false,
454 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
455 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
456 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
457 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
459 }
460 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
461 lit_cmp(av.as_ref(), bv.as_ref()) < 0
463 }
464 }
465}
466
467fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
468 match (a, b) {
469 (None, None) => None,
470 (Some(d), None) | (None, Some(d)) => Some(d),
471 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
472 }
473}
474
475fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
476 use std::cmp::Ordering;
477
478 match (&a.value, &b.value) {
479 (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
480 Ordering::Less => -1,
481 Ordering::Equal => 0,
482 Ordering::Greater => 1,
483 },
484
485 (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
486 Ordering::Less => -1,
487 Ordering::Equal => 0,
488 Ordering::Greater => 1,
489 },
490
491 (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
492 Ordering::Less => -1,
493 Ordering::Equal => 0,
494 Ordering::Greater => 1,
495 },
496
497 (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
498 Ordering::Less => -1,
499 Ordering::Equal => 0,
500 Ordering::Greater => 1,
501 },
502
503 (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
504 Ordering::Less => -1,
505 Ordering::Equal => 0,
506 Ordering::Greater => 1,
507 },
508
509 (ValueKind::Calendar(la, lu), ValueKind::Calendar(lb, lu2)) => {
510 let a_months = crate::computation::units::convert_calendar_magnitude(
511 *la,
512 lu,
513 &crate::planning::semantics::SemanticCalendarUnit::Month,
514 );
515 let b_months = crate::computation::units::convert_calendar_magnitude(
516 *lb,
517 lu2,
518 &crate::planning::semantics::SemanticCalendarUnit::Month,
519 );
520 match a_months.cmp(&b_months) {
521 Ordering::Less => -1,
522 Ordering::Equal => 0,
523 Ordering::Greater => 1,
524 }
525 }
526
527 (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
528 Ordering::Less => -1,
529 Ordering::Equal => 0,
530 Ordering::Greater => 1,
531 },
532
533 (ValueKind::Quantity(la, lua, _), ValueKind::Quantity(lb, lub, _)) => {
534 if a.lemma_type != b.lemma_type {
535 unreachable!(
536 "BUG: lit_cmp compared different quantity types ({} vs {})",
537 a.lemma_type.name(),
538 b.lemma_type.name()
539 );
540 }
541
542 if lua.eq_ignore_ascii_case(lub) {
543 return match la.cmp(lb) {
544 Ordering::Less => -1,
545 Ordering::Equal => 0,
546 Ordering::Greater => 1,
547 };
548 }
549
550 let target = SemanticConversionTarget::QuantityUnit(lua.clone());
552 let converted = crate::computation::convert_unit(
553 b,
554 &target,
555 UnitResolutionContext::NamedQuantityOnly,
556 );
557 let converted_value = match converted {
558 OperationResult::Value(lit) => match lit.value {
559 ValueKind::Quantity(v, _, _) => v,
560 _ => unreachable!("BUG: quantity unit conversion returned non-quantity value"),
561 },
562 OperationResult::Veto(reason) => {
563 unreachable!(
564 "BUG: quantity unit conversion vetoed unexpectedly: {:?}",
565 reason
566 )
567 }
568 };
569
570 match la.cmp(&converted_value) {
571 Ordering::Less => -1,
572 Ordering::Equal => 0,
573 Ordering::Greater => 1,
574 }
575 }
576
577 _ => unreachable!(
578 "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
579 a.get_type(),
580 b.get_type()
581 ),
582 }
583}
584
585fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
586 let ge_min = match min {
587 Bound::Unbounded => true,
588 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
589 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
590 };
591 let le_max = match max {
592 Bound::Unbounded => true,
593 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
594 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
595 };
596 ge_min && le_max
597}
598
599fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
600 match (min, max) {
601 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
602 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
603 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
604 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
605 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
606 }
607}
608
609fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
610 match (min1, min2) {
611 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
612 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
613 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
614 Bound::Inclusive(v1)
615 } else {
616 Bound::Inclusive(v2)
617 }
618 }
619 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
620 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
621 Bound::Inclusive(v1)
622 } else {
623 Bound::Exclusive(v2)
624 }
625 }
626 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
627 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
628 Bound::Exclusive(v1)
629 } else {
630 Bound::Inclusive(v2)
631 }
632 }
633 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
634 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
635 Bound::Exclusive(v1)
636 } else {
637 Bound::Exclusive(v2)
638 }
639 }
640 }
641}
642
643fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
644 match (max1, max2) {
645 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
646 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
647 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
648 Bound::Inclusive(v1)
649 } else {
650 Bound::Inclusive(v2)
651 }
652 }
653 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
654 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
655 Bound::Inclusive(v1)
656 } else {
657 Bound::Exclusive(v2)
658 }
659 }
660 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
661 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
662 Bound::Exclusive(v1)
663 } else {
664 Bound::Inclusive(v2)
665 }
666 }
667 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
668 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
669 Bound::Exclusive(v1)
670 } else {
671 Bound::Exclusive(v2)
672 }
673 }
674 }
675}
676
677fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
678 let a = normalize_domain(a);
679 let b = normalize_domain(b);
680
681 let result = match (a, b) {
682 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
683 (Domain::Empty, _) | (_, Domain::Empty) => None,
684
685 (
686 Domain::Range {
687 min: min1,
688 max: max1,
689 },
690 Domain::Range {
691 min: min2,
692 max: max2,
693 },
694 ) => {
695 let min = compute_intersection_min(min1, min2);
696 let max = compute_intersection_max(max1, max2);
697
698 if bounds_contradict(&min, &max) {
699 None
700 } else {
701 Some(Domain::Range { min, max })
702 }
703 }
704 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
705 let filtered: Vec<LiteralValue> =
706 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
707 if filtered.is_empty() {
708 None
709 } else {
710 Some(Domain::Enumeration(Arc::new(filtered)))
711 }
712 }
713 (Domain::Enumeration(vs), Domain::Range { min, max })
714 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
715 let mut kept = Vec::new();
716 for v in vs.iter() {
717 if value_within(v, &min, &max) {
718 kept.push(v.clone());
719 }
720 }
721 if kept.is_empty() {
722 None
723 } else {
724 Some(Domain::Enumeration(Arc::new(kept)))
725 }
726 }
727 (Domain::Enumeration(vs), Domain::Complement(inner))
728 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
729 match *inner.clone() {
730 Domain::Enumeration(excluded) => {
731 let mut kept = Vec::new();
732 for v in vs.iter() {
733 if !excluded.contains(v) {
734 kept.push(v.clone());
735 }
736 }
737 if kept.is_empty() {
738 None
739 } else {
740 Some(Domain::Enumeration(Arc::new(kept)))
741 }
742 }
743 Domain::Range { min, max } => {
744 let mut kept = Vec::new();
746 for v in vs.iter() {
747 if !value_within(v, &min, &max) {
748 kept.push(v.clone());
749 }
750 }
751 if kept.is_empty() {
752 None
753 } else {
754 Some(Domain::Enumeration(Arc::new(kept)))
755 }
756 }
757 _ => {
758 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
760 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
761 }
762 }
763 }
764 (Domain::Union(v1), Domain::Union(v2)) => {
765 let mut acc: Vec<Domain> = Vec::new();
766 for a in v1.iter() {
767 for b in v2.iter() {
768 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
769 acc.push(ix);
770 }
771 }
772 }
773 if acc.is_empty() {
774 None
775 } else {
776 Some(Domain::Union(Arc::new(acc)))
777 }
778 }
779 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
780 let mut acc: Vec<Domain> = Vec::new();
781 for a in vs.iter() {
782 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
783 acc.push(ix);
784 }
785 }
786 if acc.is_empty() {
787 None
788 } else if acc.len() == 1 {
789 Some(acc.remove(0))
790 } else {
791 Some(Domain::Union(Arc::new(acc)))
792 }
793 }
794 (Domain::Range { min, max }, Domain::Complement(inner))
796 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
797 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
798 _ => {
799 let normalized_complement = normalize_domain(Domain::Complement(inner));
802 if matches!(&normalized_complement, Domain::Complement(_)) {
803 None
804 } else {
805 domain_intersection(Domain::Range { min, max }, normalized_complement)
806 }
807 }
808 },
809 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
810 match (a_inner.as_ref(), b_inner.as_ref()) {
811 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
812 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
814 excluded.extend(b_ex.iter().cloned());
815 Some(normalize_domain(Domain::Complement(Box::new(
816 Domain::Enumeration(Arc::new(excluded)),
817 ))))
818 }
819 _ => None,
820 }
821 }
822 };
823 result.map(normalize_domain)
824}
825
826fn range_minus_excluded_points(
827 min: Bound,
828 max: Bound,
829 excluded: &Arc<Vec<LiteralValue>>,
830) -> Option<Domain> {
831 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
833
834 for p in excluded.iter() {
835 let mut next: Vec<(Bound, Bound)> = Vec::new();
836
837 for (rmin, rmax) in parts {
838 if !value_within(p, &rmin, &rmax) {
839 next.push((rmin, rmax));
840 continue;
841 }
842
843 let left_max = Bound::Exclusive(Arc::new(p.clone()));
845 if !bounds_contradict(&rmin, &left_max) {
846 next.push((rmin.clone(), left_max));
847 }
848
849 let right_min = Bound::Exclusive(Arc::new(p.clone()));
851 if !bounds_contradict(&right_min, &rmax) {
852 next.push((right_min, rmax.clone()));
853 }
854 }
855
856 parts = next;
857 if parts.is_empty() {
858 return None;
859 }
860 }
861
862 if parts.is_empty() {
863 None
864 } else if parts.len() == 1 {
865 let (min, max) = parts.remove(0);
866 Some(Domain::Range { min, max })
867 } else {
868 Some(Domain::Union(Arc::new(
869 parts
870 .into_iter()
871 .map(|(min, max)| Domain::Range { min, max })
872 .collect(),
873 )))
874 }
875}
876
877fn invert_bound(bound: Bound) -> Bound {
878 match bound {
879 Bound::Unbounded => Bound::Unbounded,
880 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
881 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
882 }
883}
884
885fn normalize_domain(d: Domain) -> Domain {
886 match d {
887 Domain::Complement(inner) => {
888 let normalized_inner = normalize_domain(*inner);
889 match normalized_inner {
890 Domain::Complement(double_inner) => *double_inner,
891 Domain::Range { min, max } => match (&min, &max) {
892 (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
893 (Bound::Unbounded, max) => Domain::Range {
894 min: invert_bound(max.clone()),
895 max: Bound::Unbounded,
896 },
897 (min, Bound::Unbounded) => Domain::Range {
898 min: Bound::Unbounded,
899 max: invert_bound(min.clone()),
900 },
901 (min, max) => Domain::Union(Arc::new(vec![
902 Domain::Range {
903 min: Bound::Unbounded,
904 max: invert_bound(min.clone()),
905 },
906 Domain::Range {
907 min: invert_bound(max.clone()),
908 max: Bound::Unbounded,
909 },
910 ])),
911 },
912 Domain::Enumeration(vals) => {
913 if vals.len() == 1 {
914 if let Some(lit) = vals.first() {
915 if let ValueKind::Boolean(true) = &lit.value {
916 return Domain::Enumeration(Arc::new(vec![
917 LiteralValue::from_bool(false),
918 ]));
919 }
920 if let ValueKind::Boolean(false) = &lit.value {
921 return Domain::Enumeration(Arc::new(vec![
922 LiteralValue::from_bool(true),
923 ]));
924 }
925 }
926 }
927 Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
928 }
929 Domain::Unconstrained => Domain::Empty,
930 Domain::Empty => Domain::Unconstrained,
931 Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
932 }
933 }
934 Domain::Empty => Domain::Empty,
935 Domain::Union(parts) => {
936 let mut flat: Vec<Domain> = Vec::new();
937 for p in parts.iter().cloned() {
938 let normalized = normalize_domain(p);
939 match normalized {
940 Domain::Union(inner) => flat.extend(inner.iter().cloned()),
941 Domain::Unconstrained => return Domain::Unconstrained,
942 Domain::Enumeration(vals) if vals.is_empty() => {}
943 other => flat.push(other),
944 }
945 }
946
947 let mut all_enum_values: Vec<LiteralValue> = Vec::new();
948 let mut ranges: Vec<Domain> = Vec::new();
949 let mut others: Vec<Domain> = Vec::new();
950
951 for domain in flat {
952 match domain {
953 Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
954 Domain::Range { .. } => ranges.push(domain),
955 other => others.push(other),
956 }
957 }
958
959 all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
960 -1 => Ordering::Less,
961 0 => Ordering::Equal,
962 _ => Ordering::Greater,
963 });
964 all_enum_values.dedup();
965
966 all_enum_values.retain(|v| {
967 !ranges.iter().any(|r| {
968 if let Domain::Range { min, max } = r {
969 value_within(v, min, max)
970 } else {
971 false
972 }
973 })
974 });
975
976 let mut result: Vec<Domain> = Vec::new();
977 result.extend(ranges);
978 result = merge_ranges(result);
979
980 if !all_enum_values.is_empty() {
981 result.push(Domain::Enumeration(Arc::new(all_enum_values)));
982 }
983 result.extend(others);
984
985 result.sort_by(|a, b| match (a, b) {
986 (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
987 (Domain::Range { .. }, _) => Ordering::Less,
988 (_, Domain::Range { .. }) => Ordering::Greater,
989 (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
990 (Domain::Enumeration(_), _) => Ordering::Less,
991 (_, Domain::Enumeration(_)) => Ordering::Greater,
992 _ => Ordering::Equal,
993 });
994
995 if result.is_empty() {
996 Domain::Enumeration(Arc::new(vec![]))
997 } else if result.len() == 1 {
998 result.remove(0)
999 } else {
1000 Domain::Union(Arc::new(result))
1001 }
1002 }
1003 Domain::Enumeration(values) => {
1004 let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
1005 sorted.sort_by(|a, b| match lit_cmp(a, b) {
1006 -1 => Ordering::Less,
1007 0 => Ordering::Equal,
1008 _ => Ordering::Greater,
1009 });
1010 sorted.dedup();
1011 Domain::Enumeration(Arc::new(sorted))
1012 }
1013 other => other,
1014 }
1015}
1016
1017fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1018 let mut result = Vec::new();
1019 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1020 let mut others = Vec::new();
1021
1022 for d in domains {
1023 match d {
1024 Domain::Range { min, max } => ranges.push((min, max)),
1025 other => others.push(other),
1026 }
1027 }
1028
1029 if ranges.is_empty() {
1030 return others;
1031 }
1032
1033 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1034
1035 let mut merged: Vec<(Bound, Bound)> = Vec::new();
1036 let mut current = ranges[0].clone();
1037
1038 for next in ranges.iter().skip(1) {
1039 if ranges_adjacent_or_overlap(¤t, next) {
1040 current = (
1041 min_bound(¤t.0, &next.0),
1042 max_bound(¤t.1, &next.1),
1043 );
1044 } else {
1045 merged.push(current);
1046 current = next.clone();
1047 }
1048 }
1049 merged.push(current);
1050
1051 for (min, max) in merged {
1052 result.push(Domain::Range { min, max });
1053 }
1054 result.extend(others);
1055
1056 result
1057}
1058
1059fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1060 match (a, b) {
1061 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1062 (Bound::Unbounded, _) => Ordering::Less,
1063 (_, Bound::Unbounded) => Ordering::Greater,
1064 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1065 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1066 -1 => Ordering::Less,
1067 0 => Ordering::Equal,
1068 _ => Ordering::Greater,
1069 },
1070 (Bound::Inclusive(v1), Bound::Exclusive(v2))
1071 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1072 -1 => Ordering::Less,
1073 0 => {
1074 if matches!(a, Bound::Inclusive(_)) {
1075 Ordering::Less
1076 } else {
1077 Ordering::Greater
1078 }
1079 }
1080 _ => Ordering::Greater,
1081 },
1082 }
1083}
1084
1085fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1086 match (&r1.1, &r2.0) {
1087 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1088 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1089 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1090 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1091 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1092 }
1093}
1094
1095fn min_bound(a: &Bound, b: &Bound) -> Bound {
1096 match (a, b) {
1097 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1098 _ => {
1099 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1100 a.clone()
1101 } else {
1102 b.clone()
1103 }
1104 }
1105 }
1106}
1107
1108fn max_bound(a: &Bound, b: &Bound) -> Bound {
1109 match (a, b) {
1110 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1111 _ => {
1112 if matches!(compare_bounds(a, b), Ordering::Greater) {
1113 a.clone()
1114 } else {
1115 b.clone()
1116 }
1117 }
1118 }
1119}
1120
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a numeric literal for the integer `n` (as the rational `n/1`).
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(crate::computation::rational::RationalInteger::new(
            n as i128, 1,
        ))
    }

    /// Builds a data path with no leading segments and the given name.
    fn data(name: &str) -> DataPath {
        DataPath::new(vec![], name.to_string())
    }

    // A complement of a complement normalizes back to the inner domain.
    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let double = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        let normalized = normalize_domain(double);
        assert_eq!(normalized, inner);
    }

    // A union containing `Unconstrained` collapses to `Unconstrained`.
    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(0))),
                max: Bound::Inclusive(Arc::new(num(10))),
            },
            Domain::Unconstrained,
        ]));
        let normalized = normalize_domain(union);
        assert_eq!(normalized, Domain::Unconstrained);
    }

    // Ranges print in interval notation, enumerations in braces.
    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: Bound::Inclusive(Arc::new(num(10))),
            max: Bound::Exclusive(Arc::new(num(20))),
        };
        assert_eq!(format!("{}", range), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(format!("{}", enumeration), "{1, 2, 3}");
    }

    // `age > 18` yields the range (18, +inf).
    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = Constraint::Comparison {
            data: data("age"),
            op: ComparisonComputation::GreaterThan,
            value: Arc::new(num(18)),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&data("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Unbounded,
            }
        );
    }

    // `age > 18 and age < 65` yields the range (18, 65).
    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(Constraint::Comparison {
                data: data("age"),
                op: ComparisonComputation::GreaterThan,
                value: Arc::new(num(18)),
            }),
            Box::new(Constraint::Comparison {
                data: data("age"),
                op: ComparisonComputation::LessThan,
                value: Arc::new(num(65)),
            }),
        );

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&data("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Exclusive(Arc::new(num(65))),
            }
        );
    }

    // `status is "active"` yields a one-element enumeration.
    #[test]
    fn test_extract_domain_from_equality() {
        let constraint = Constraint::Comparison {
            data: data("status"),
            op: ComparisonComputation::Is,
            value: Arc::new(LiteralValue::text("active".to_string())),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let status_domain = domains.get(&data("status")).unwrap();

        assert_eq!(
            *status_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::text("active".to_string())]))
        );
    }

    // A bare data constraint restricts that data to `true`.
    #[test]
    fn test_extract_domain_from_boolean_data() {
        let constraint = Constraint::Data(data("is_active"));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&data("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    // A negated data constraint restricts that data to `false`.
    #[test]
    fn test_extract_domain_from_not_boolean_data() {
        let constraint = Constraint::Not(Box::new(Constraint::Data(data("is_active"))));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&data("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }
}