1use crate::planning::semantics::{
9 ComparisonComputation, DataPath, LiteralValue, SemanticConversionTarget, ValueKind,
10};
11use crate::OperationResult;
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
/// A set of values that a single data item may take.
///
/// Domains are built from constraints by `extract_domains_from_constraint` and combined
/// with the helpers in this module. Note that an empty `Enumeration` is used alongside
/// `Empty` to mean "no value" (see `Constraint::False` handling and `is_satisfiable`).
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// Every value between `min` and `max`; each end is inclusive, exclusive or unbounded.
    Range { min: Bound, max: Bound },

    /// Every value that belongs to at least one of the member domains.
    Union(Arc<Vec<Domain>>),

    /// An explicit, finite list of values.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that does NOT belong to the inner domain.
    Complement(Box<Domain>),

    /// Any value at all.
    Unconstrained,

    /// No value at all.
    Empty,
}
41
42impl Domain {
43 pub fn is_satisfiable(&self) -> bool {
47 match self {
48 Domain::Empty => false,
49 Domain::Enumeration(values) => !values.is_empty(),
50 Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51 Domain::Range { min, max } => !bounds_contradict(min, max),
52 Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53 Domain::Unconstrained => true,
54 }
55 }
56
57 pub fn is_empty(&self) -> bool {
59 !self.is_satisfiable()
60 }
61
62 pub fn intersect(&self, other: &Domain) -> Domain {
65 match domain_intersection(self.clone(), other.clone()) {
66 Some(d) => d,
67 None => Domain::Empty,
68 }
69 }
70
71 pub fn contains(&self, value: &LiteralValue) -> bool {
73 match self {
74 Domain::Empty => false,
75 Domain::Unconstrained => true,
76 Domain::Enumeration(values) => values.contains(value),
77 Domain::Range { min, max } => value_within(value, min, max),
78 Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
79 Domain::Complement(inner) => !inner.contains(value),
80 }
81 }
82}
83
/// One end of a `Domain::Range`.
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// The end value itself belongs to the range.
    Inclusive(Arc<LiteralValue>),

    /// The end value is excluded from the range.
    Exclusive(Arc<LiteralValue>),

    /// No limit on this side (`-inf` as a lower bound, `+inf` as an upper bound).
    Unbounded,
}
96
97impl fmt::Display for Domain {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 match self {
100 Domain::Empty => write!(f, "empty"),
101 Domain::Unconstrained => write!(f, "any"),
102 Domain::Enumeration(vals) => {
103 write!(f, "{{")?;
104 for (i, v) in vals.iter().enumerate() {
105 if i > 0 {
106 write!(f, ", ")?;
107 }
108 write!(f, "{}", v)?;
109 }
110 write!(f, "}}")
111 }
112 Domain::Range { min, max } => {
113 let (l_bracket, r_bracket) = match (min, max) {
114 (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
115 (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
116 (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
117 (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
118 (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
119 (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
120 (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
121 (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
122 (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
123 };
124
125 let min_str = match min {
126 Bound::Unbounded => "-inf".to_string(),
127 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128 };
129 let max_str = match max {
130 Bound::Unbounded => "+inf".to_string(),
131 Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
132 };
133 write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
134 }
135 Domain::Union(parts) => {
136 for (i, p) in parts.iter().enumerate() {
137 if i > 0 {
138 write!(f, " | ")?;
139 }
140 write!(f, "{}", p)?;
141 }
142 Ok(())
143 }
144 Domain::Complement(inner) => write!(f, "not ({})", inner),
145 }
146 }
147}
148
149impl fmt::Display for Bound {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 match self {
152 Bound::Unbounded => write!(f, "inf"),
153 Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
154 Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
155 }
156 }
157}
158
159impl Serialize for Domain {
160 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161 where
162 S: Serializer,
163 {
164 match self {
165 Domain::Empty => {
166 let mut st = serializer.serialize_struct("domain", 1)?;
167 st.serialize_field("type", "empty")?;
168 st.end()
169 }
170 Domain::Unconstrained => {
171 let mut st = serializer.serialize_struct("domain", 1)?;
172 st.serialize_field("type", "unconstrained")?;
173 st.end()
174 }
175 Domain::Enumeration(vals) => {
176 let mut st = serializer.serialize_struct("domain", 2)?;
177 st.serialize_field("type", "enumeration")?;
178 st.serialize_field("values", vals)?;
179 st.end()
180 }
181 Domain::Range { min, max } => {
182 let mut st = serializer.serialize_struct("domain", 3)?;
183 st.serialize_field("type", "range")?;
184 st.serialize_field("min", min)?;
185 st.serialize_field("max", max)?;
186 st.end()
187 }
188 Domain::Union(parts) => {
189 let mut st = serializer.serialize_struct("domain", 2)?;
190 st.serialize_field("type", "union")?;
191 st.serialize_field("parts", parts)?;
192 st.end()
193 }
194 Domain::Complement(inner) => {
195 let mut st = serializer.serialize_struct("domain", 2)?;
196 st.serialize_field("type", "complement")?;
197 st.serialize_field("inner", inner)?;
198 st.end()
199 }
200 }
201 }
202}
203
204impl Serialize for Bound {
205 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: Serializer,
208 {
209 match self {
210 Bound::Unbounded => {
211 let mut st = serializer.serialize_struct("bound", 1)?;
212 st.serialize_field("type", "unbounded")?;
213 st.end()
214 }
215 Bound::Inclusive(v) => {
216 let mut st = serializer.serialize_struct("bound", 2)?;
217 st.serialize_field("type", "inclusive")?;
218 st.serialize_field("value", v.as_ref())?;
219 st.end()
220 }
221 Bound::Exclusive(v) => {
222 let mut st = serializer.serialize_struct("bound", 2)?;
223 st.serialize_field("type", "exclusive")?;
224 st.serialize_field("value", v.as_ref())?;
225 st.end()
226 }
227 }
228 }
229}
230
231pub fn extract_domains_from_constraint(
233 constraint: &Constraint,
234) -> Result<HashMap<DataPath, Domain>, crate::Error> {
235 let all_datas = constraint.collect_data();
236 let mut domains = HashMap::new();
237
238 for data_path in all_datas {
239 let domain =
243 extract_domain_for_data(constraint, &data_path)?.unwrap_or(Domain::Unconstrained);
244 domains.insert(data_path, domain);
245 }
246
247 Ok(domains)
248}
249
250fn extract_domain_for_data(
251 constraint: &Constraint,
252 data_path: &DataPath,
253) -> Result<Option<Domain>, crate::Error> {
254 let domain = match constraint {
255 Constraint::True => return Ok(None),
256 Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
257
258 Constraint::Comparison { data, op, value } => {
259 if data == data_path {
260 Some(comparison_to_domain(op, value.as_ref())?)
261 } else {
262 None
263 }
264 }
265
266 Constraint::Data(fp) => {
267 if fp == data_path {
268 Some(Domain::Enumeration(Arc::new(vec![
269 LiteralValue::from_bool(true),
270 ])))
271 } else {
272 None
273 }
274 }
275
276 Constraint::And(left, right) => {
277 let left_domain = extract_domain_for_data(left, data_path)?;
278 let right_domain = extract_domain_for_data(right, data_path)?;
279 match (left_domain, right_domain) {
280 (None, None) => None,
281 (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
282 (Some(a), Some(b)) => match domain_intersection(a, b) {
283 Some(domain) => Some(domain),
284 None => Some(Domain::Enumeration(Arc::new(vec![]))),
285 },
286 }
287 }
288
289 Constraint::Or(left, right) => {
290 let left_domain = extract_domain_for_data(left, data_path)?;
291 let right_domain = extract_domain_for_data(right, data_path)?;
292 union_optional_domains(left_domain, right_domain)
293 }
294
295 Constraint::Not(inner) => {
296 if let Constraint::Comparison { data, op, value } = inner.as_ref() {
298 if data == data_path && op.is_equal() {
299 return Ok(Some(normalize_domain(Domain::Complement(Box::new(
300 Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
301 )))));
302 }
303 }
304
305 if let Constraint::Data(fp) = inner.as_ref() {
307 if fp == data_path {
308 return Ok(Some(Domain::Enumeration(Arc::new(vec![
309 LiteralValue::from_bool(false),
310 ]))));
311 }
312 }
313
314 extract_domain_for_data(inner, data_path)?
315 .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
316 }
317 };
318
319 Ok(domain.map(normalize_domain))
320}
321
322fn comparison_to_domain(
323 op: &ComparisonComputation,
324 value: &LiteralValue,
325) -> Result<Domain, crate::Error> {
326 if op.is_equal() {
327 return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
328 }
329 if op.is_not_equal() {
330 return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
331 vec![value.clone()],
332 )))));
333 }
334 match op {
335 ComparisonComputation::LessThan => Ok(Domain::Range {
336 min: Bound::Unbounded,
337 max: Bound::Exclusive(Arc::new(value.clone())),
338 }),
339 ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
340 min: Bound::Unbounded,
341 max: Bound::Inclusive(Arc::new(value.clone())),
342 }),
343 ComparisonComputation::GreaterThan => Ok(Domain::Range {
344 min: Bound::Exclusive(Arc::new(value.clone())),
345 max: Bound::Unbounded,
346 }),
347 ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
348 min: Bound::Inclusive(Arc::new(value.clone())),
349 max: Bound::Unbounded,
350 }),
351 _ => unreachable!(
352 "BUG: unsupported comparison operator for domain extraction: {:?}",
353 op
354 ),
355 }
356}
357
358pub(crate) fn domain_for_comparison_atom(
363 op: &ComparisonComputation,
364 value: &LiteralValue,
365) -> Result<Domain, crate::Error> {
366 comparison_to_domain(op, value)
367}
368
369impl Domain {
370 pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
377 match (self, other) {
378 (Domain::Empty, _) => true,
379 (_, Domain::Unconstrained) => true,
380 (Domain::Unconstrained, _) => false,
381
382 (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
383 (Domain::Enumeration(vals), Domain::Range { min, max }) => {
384 vals.iter().all(|v| value_within(v, min, max))
385 }
386
387 (
388 Domain::Range {
389 min: amin,
390 max: amax,
391 },
392 Domain::Range {
393 min: bmin,
394 max: bmax,
395 },
396 ) => range_within_range(amin, amax, bmin, bmax),
397
398 (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
400 Domain::Enumeration(excluded) => {
401 excluded.iter().all(|p| !value_within(p, min, max))
402 }
403 _ => false,
404 },
405
406 (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
408 Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
409 _ => false,
410 },
411
412 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
414 match (a_inner.as_ref(), b_inner.as_ref()) {
415 (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
416 excluded_b.iter().all(|v| excluded_a.contains(v))
417 }
418 _ => false,
419 }
420 }
421
422 _ => false,
423 }
424 }
425}
426
427fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
428 lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
429}
430
431fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
432 match (a, b) {
433 (_, Bound::Unbounded) => true,
434 (Bound::Unbounded, _) => false,
435 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
436 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
437 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
438 let c = lit_cmp(av.as_ref(), bv.as_ref());
439 c >= 0
440 }
441 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
442 lit_cmp(av.as_ref(), bv.as_ref()) > 0
444 }
445 }
446}
447
448fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
449 match (a, b) {
450 (Bound::Unbounded, Bound::Unbounded) => true,
451 (_, Bound::Unbounded) => true,
452 (Bound::Unbounded, _) => false,
453 (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
454 (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
455 (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
456 lit_cmp(av.as_ref(), bv.as_ref()) <= 0
458 }
459 (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
460 lit_cmp(av.as_ref(), bv.as_ref()) < 0
462 }
463 }
464}
465
466fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
467 match (a, b) {
468 (None, None) => None,
469 (Some(d), None) | (None, Some(d)) => Some(d),
470 (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
471 }
472}
473
474fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
475 use std::cmp::Ordering;
476
477 match (&a.value, &b.value) {
478 (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
479 Ordering::Less => -1,
480 Ordering::Equal => 0,
481 Ordering::Greater => 1,
482 },
483
484 (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
485 Ordering::Less => -1,
486 Ordering::Equal => 0,
487 Ordering::Greater => 1,
488 },
489
490 (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
491 Ordering::Less => -1,
492 Ordering::Equal => 0,
493 Ordering::Greater => 1,
494 },
495
496 (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
497 Ordering::Less => -1,
498 Ordering::Equal => 0,
499 Ordering::Greater => 1,
500 },
501
502 (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
503 Ordering::Less => -1,
504 Ordering::Equal => 0,
505 Ordering::Greater => 1,
506 },
507
508 (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub)) => {
509 let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
510 let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
511 match a_sec.cmp(&b_sec) {
512 Ordering::Less => -1,
513 Ordering::Equal => 0,
514 Ordering::Greater => 1,
515 }
516 }
517
518 (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
519 Ordering::Less => -1,
520 Ordering::Equal => 0,
521 Ordering::Greater => 1,
522 },
523
524 (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub)) => {
525 if a.lemma_type != b.lemma_type {
526 unreachable!(
527 "BUG: lit_cmp compared different scale types ({} vs {})",
528 a.lemma_type.name(),
529 b.lemma_type.name()
530 );
531 }
532
533 if lua.eq_ignore_ascii_case(lub) {
534 return match la.cmp(lb) {
535 Ordering::Less => -1,
536 Ordering::Equal => 0,
537 Ordering::Greater => 1,
538 };
539 }
540
541 let target = SemanticConversionTarget::ScaleUnit(lua.clone());
543 let converted = crate::computation::convert_unit(b, &target);
544 let converted_value = match converted {
545 OperationResult::Value(lit) => match lit.value {
546 ValueKind::Scale(v, _) => v,
547 _ => unreachable!("BUG: scale unit conversion returned non-scale value"),
548 },
549 OperationResult::Veto(reason) => {
550 unreachable!(
551 "BUG: scale unit conversion vetoed unexpectedly: {:?}",
552 reason
553 )
554 }
555 };
556
557 match la.cmp(&converted_value) {
558 Ordering::Less => -1,
559 Ordering::Equal => 0,
560 Ordering::Greater => 1,
561 }
562 }
563
564 _ => unreachable!(
565 "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
566 a.get_type(),
567 b.get_type()
568 ),
569 }
570}
571
572fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
573 let ge_min = match min {
574 Bound::Unbounded => true,
575 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
576 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
577 };
578 let le_max = match max {
579 Bound::Unbounded => true,
580 Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
581 Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
582 };
583 ge_min && le_max
584}
585
586fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
587 match (min, max) {
588 (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
589 (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
590 (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
591 (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
592 (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
593 }
594}
595
596fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
597 match (min1, min2) {
598 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
599 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
600 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
601 Bound::Inclusive(v1)
602 } else {
603 Bound::Inclusive(v2)
604 }
605 }
606 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
607 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
608 Bound::Inclusive(v1)
609 } else {
610 Bound::Exclusive(v2)
611 }
612 }
613 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
614 if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
615 Bound::Exclusive(v1)
616 } else {
617 Bound::Inclusive(v2)
618 }
619 }
620 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
621 if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
622 Bound::Exclusive(v1)
623 } else {
624 Bound::Exclusive(v2)
625 }
626 }
627 }
628}
629
630fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
631 match (max1, max2) {
632 (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
633 (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
634 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
635 Bound::Inclusive(v1)
636 } else {
637 Bound::Inclusive(v2)
638 }
639 }
640 (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
641 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
642 Bound::Inclusive(v1)
643 } else {
644 Bound::Exclusive(v2)
645 }
646 }
647 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
648 if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
649 Bound::Exclusive(v1)
650 } else {
651 Bound::Inclusive(v2)
652 }
653 }
654 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
655 if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
656 Bound::Exclusive(v1)
657 } else {
658 Bound::Exclusive(v2)
659 }
660 }
661 }
662}
663
664fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
665 let a = normalize_domain(a);
666 let b = normalize_domain(b);
667
668 let result = match (a, b) {
669 (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
670 (Domain::Empty, _) | (_, Domain::Empty) => None,
671
672 (
673 Domain::Range {
674 min: min1,
675 max: max1,
676 },
677 Domain::Range {
678 min: min2,
679 max: max2,
680 },
681 ) => {
682 let min = compute_intersection_min(min1, min2);
683 let max = compute_intersection_max(max1, max2);
684
685 if bounds_contradict(&min, &max) {
686 None
687 } else {
688 Some(Domain::Range { min, max })
689 }
690 }
691 (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
692 let filtered: Vec<LiteralValue> =
693 v1.iter().filter(|x| v2.contains(x)).cloned().collect();
694 if filtered.is_empty() {
695 None
696 } else {
697 Some(Domain::Enumeration(Arc::new(filtered)))
698 }
699 }
700 (Domain::Enumeration(vs), Domain::Range { min, max })
701 | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
702 let mut kept = Vec::new();
703 for v in vs.iter() {
704 if value_within(v, &min, &max) {
705 kept.push(v.clone());
706 }
707 }
708 if kept.is_empty() {
709 None
710 } else {
711 Some(Domain::Enumeration(Arc::new(kept)))
712 }
713 }
714 (Domain::Enumeration(vs), Domain::Complement(inner))
715 | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
716 match *inner.clone() {
717 Domain::Enumeration(excluded) => {
718 let mut kept = Vec::new();
719 for v in vs.iter() {
720 if !excluded.contains(v) {
721 kept.push(v.clone());
722 }
723 }
724 if kept.is_empty() {
725 None
726 } else {
727 Some(Domain::Enumeration(Arc::new(kept)))
728 }
729 }
730 Domain::Range { min, max } => {
731 let mut kept = Vec::new();
733 for v in vs.iter() {
734 if !value_within(v, &min, &max) {
735 kept.push(v.clone());
736 }
737 }
738 if kept.is_empty() {
739 None
740 } else {
741 Some(Domain::Enumeration(Arc::new(kept)))
742 }
743 }
744 _ => {
745 let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
747 domain_intersection(Domain::Enumeration(vs.clone()), normalized)
748 }
749 }
750 }
751 (Domain::Union(v1), Domain::Union(v2)) => {
752 let mut acc: Vec<Domain> = Vec::new();
753 for a in v1.iter() {
754 for b in v2.iter() {
755 if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
756 acc.push(ix);
757 }
758 }
759 }
760 if acc.is_empty() {
761 None
762 } else {
763 Some(Domain::Union(Arc::new(acc)))
764 }
765 }
766 (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
767 let mut acc: Vec<Domain> = Vec::new();
768 for a in vs.iter() {
769 if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
770 acc.push(ix);
771 }
772 }
773 if acc.is_empty() {
774 None
775 } else if acc.len() == 1 {
776 Some(acc.remove(0))
777 } else {
778 Some(Domain::Union(Arc::new(acc)))
779 }
780 }
781 (Domain::Range { min, max }, Domain::Complement(inner))
783 | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
784 Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
785 _ => {
786 let normalized_complement = normalize_domain(Domain::Complement(inner));
789 if matches!(&normalized_complement, Domain::Complement(_)) {
790 None
791 } else {
792 domain_intersection(Domain::Range { min, max }, normalized_complement)
793 }
794 }
795 },
796 (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
797 match (a_inner.as_ref(), b_inner.as_ref()) {
798 (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
799 let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
801 excluded.extend(b_ex.iter().cloned());
802 Some(normalize_domain(Domain::Complement(Box::new(
803 Domain::Enumeration(Arc::new(excluded)),
804 ))))
805 }
806 _ => None,
807 }
808 }
809 };
810 result.map(normalize_domain)
811}
812
813fn range_minus_excluded_points(
814 min: Bound,
815 max: Bound,
816 excluded: &Arc<Vec<LiteralValue>>,
817) -> Option<Domain> {
818 let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
820
821 for p in excluded.iter() {
822 let mut next: Vec<(Bound, Bound)> = Vec::new();
823
824 for (rmin, rmax) in parts {
825 if !value_within(p, &rmin, &rmax) {
826 next.push((rmin, rmax));
827 continue;
828 }
829
830 let left_max = Bound::Exclusive(Arc::new(p.clone()));
832 if !bounds_contradict(&rmin, &left_max) {
833 next.push((rmin.clone(), left_max));
834 }
835
836 let right_min = Bound::Exclusive(Arc::new(p.clone()));
838 if !bounds_contradict(&right_min, &rmax) {
839 next.push((right_min, rmax.clone()));
840 }
841 }
842
843 parts = next;
844 if parts.is_empty() {
845 return None;
846 }
847 }
848
849 if parts.is_empty() {
850 None
851 } else if parts.len() == 1 {
852 let (min, max) = parts.remove(0);
853 Some(Domain::Range { min, max })
854 } else {
855 Some(Domain::Union(Arc::new(
856 parts
857 .into_iter()
858 .map(|(min, max)| Domain::Range { min, max })
859 .collect(),
860 )))
861 }
862}
863
864fn invert_bound(bound: Bound) -> Bound {
865 match bound {
866 Bound::Unbounded => Bound::Unbounded,
867 Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
868 Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
869 }
870}
871
/// Rewrites `d` into the canonical shape the other helpers in this module expect.
///
/// * `Complement`:
///   - double complements cancel;
///   - the complement of a range becomes one or two ranges with inverted bounds, or
///     an empty enumeration for the fully unbounded range;
///   - `not any` / `not empty` swap;
///   - a one-element boolean enumeration is flipped to the other boolean;
///   - other enumerations and unions stay wrapped in `Complement`.
/// * `Union`:
///   - members are normalized and flattened;
///   - an `Unconstrained` member absorbs the whole union;
///   - empty enumerations are dropped;
///   - all enumeration values are merged into one sorted, de-duplicated enumeration,
///     minus values that some range already covers;
///   - overlapping or adjacent ranges are merged;
///   - members are ordered ranges first, then the enumeration, then everything else;
///   - an empty result becomes an empty enumeration, and a single member is
///     returned unwrapped.
/// * `Enumeration`: values are sorted with `lit_cmp` and de-duplicated.
/// * `Empty`, `Range` and `Unconstrained` are returned unchanged.
///
/// Sorting goes through `lit_cmp`, which panics (`unreachable!`) when asked to compare
/// literals of different kinds, so the values being sorted must share a kind.
fn normalize_domain(d: Domain) -> Domain {
    match d {
        Domain::Complement(inner) => {
            // Normalize the inner domain first so the cases below see canonical shapes.
            let normalized_inner = normalize_domain(*inner);
            match normalized_inner {
                // not(not(x)) == x
                Domain::Complement(double_inner) => *double_inner,
                // The complement of a range is what lies outside its two ends; each
                // end flips inclusivity (`invert_bound`).
                Domain::Range { min, max } => match (&min, &max) {
                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
                    (Bound::Unbounded, max) => Domain::Range {
                        min: invert_bound(max.clone()),
                        max: Bound::Unbounded,
                    },
                    (min, Bound::Unbounded) => Domain::Range {
                        min: Bound::Unbounded,
                        max: invert_bound(min.clone()),
                    },
                    (min, max) => Domain::Union(Arc::new(vec![
                        Domain::Range {
                            min: Bound::Unbounded,
                            max: invert_bound(min.clone()),
                        },
                        Domain::Range {
                            min: invert_bound(max.clone()),
                            max: Bound::Unbounded,
                        },
                    ])),
                },
                Domain::Enumeration(vals) => {
                    // NOTE(review): flipping a single boolean assumes the data item
                    // ranges over exactly {true, false}; confirm for all callers.
                    if vals.len() == 1 {
                        if let Some(lit) = vals.first() {
                            if let ValueKind::Boolean(true) = &lit.value {
                                return Domain::Enumeration(Arc::new(vec![
                                    LiteralValue::from_bool(false),
                                ]));
                            }
                            if let ValueKind::Boolean(false) = &lit.value {
                                return Domain::Enumeration(Arc::new(vec![
                                    LiteralValue::from_bool(true),
                                ]));
                            }
                        }
                    }
                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
                }
                Domain::Unconstrained => Domain::Empty,
                Domain::Empty => Domain::Unconstrained,
                // A complemented union has no simpler form here; `domain_intersection`
                // has to deal with it.
                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
            }
        }
        Domain::Empty => Domain::Empty,
        Domain::Union(parts) => {
            // Normalize and flatten members; one `Unconstrained` member makes the
            // whole union unconstrained, and empty enumerations contribute nothing.
            let mut flat: Vec<Domain> = Vec::new();
            for p in parts.iter().cloned() {
                let normalized = normalize_domain(p);
                match normalized {
                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
                    Domain::Unconstrained => return Domain::Unconstrained,
                    Domain::Enumeration(vals) if vals.is_empty() => {}
                    other => flat.push(other),
                }
            }

            // Bucket members: enumeration values are pooled, ranges are merged
            // later, everything else is kept as-is.
            // NOTE(review): a `Domain::Empty` member is not dropped and lands in `others`.
            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
            let mut ranges: Vec<Domain> = Vec::new();
            let mut others: Vec<Domain> = Vec::new();

            for domain in flat {
                match domain {
                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
                    Domain::Range { .. } => ranges.push(domain),
                    other => others.push(other),
                }
            }

            // Sort so `dedup` (which only removes adjacent duplicates) catches all repeats.
            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
                -1 => Ordering::Less,
                0 => Ordering::Equal,
                _ => Ordering::Greater,
            });
            all_enum_values.dedup();

            // Values already covered by some range are redundant.
            all_enum_values.retain(|v| {
                !ranges.iter().any(|r| {
                    if let Domain::Range { min, max } = r {
                        value_within(v, min, max)
                    } else {
                        false
                    }
                })
            });

            let mut result: Vec<Domain> = Vec::new();
            result.extend(ranges);
            result = merge_ranges(result);

            if !all_enum_values.is_empty() {
                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
            }
            result.extend(others);

            // Canonical member order: ranges, then the enumeration, then the rest.
            // `sort_by` is stable, so members of the same kind keep their order.
            result.sort_by(|a, b| match (a, b) {
                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
                (Domain::Range { .. }, _) => Ordering::Less,
                (_, Domain::Range { .. }) => Ordering::Greater,
                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
                (Domain::Enumeration(_), _) => Ordering::Less,
                (_, Domain::Enumeration(_)) => Ordering::Greater,
                _ => Ordering::Equal,
            });

            if result.is_empty() {
                Domain::Enumeration(Arc::new(vec![]))
            } else if result.len() == 1 {
                result.remove(0)
            } else {
                Domain::Union(Arc::new(result))
            }
        }
        Domain::Enumeration(values) => {
            // Sorted, duplicate-free values make structurally equal enumerations compare equal.
            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
            sorted.sort_by(|a, b| match lit_cmp(a, b) {
                -1 => Ordering::Less,
                0 => Ordering::Equal,
                _ => Ordering::Greater,
            });
            sorted.dedup();
            Domain::Enumeration(Arc::new(sorted))
        }
        other => other,
    }
}
1003
1004fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1005 let mut result = Vec::new();
1006 let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1007 let mut others = Vec::new();
1008
1009 for d in domains {
1010 match d {
1011 Domain::Range { min, max } => ranges.push((min, max)),
1012 other => others.push(other),
1013 }
1014 }
1015
1016 if ranges.is_empty() {
1017 return others;
1018 }
1019
1020 ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1021
1022 let mut merged: Vec<(Bound, Bound)> = Vec::new();
1023 let mut current = ranges[0].clone();
1024
1025 for next in ranges.iter().skip(1) {
1026 if ranges_adjacent_or_overlap(¤t, next) {
1027 current = (
1028 min_bound(¤t.0, &next.0),
1029 max_bound(¤t.1, &next.1),
1030 );
1031 } else {
1032 merged.push(current);
1033 current = next.clone();
1034 }
1035 }
1036 merged.push(current);
1037
1038 for (min, max) in merged {
1039 result.push(Domain::Range { min, max });
1040 }
1041 result.extend(others);
1042
1043 result
1044}
1045
1046fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1047 match (a, b) {
1048 (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1049 (Bound::Unbounded, _) => Ordering::Less,
1050 (_, Bound::Unbounded) => Ordering::Greater,
1051 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1052 | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1053 -1 => Ordering::Less,
1054 0 => Ordering::Equal,
1055 _ => Ordering::Greater,
1056 },
1057 (Bound::Inclusive(v1), Bound::Exclusive(v2))
1058 | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1059 -1 => Ordering::Less,
1060 0 => {
1061 if matches!(a, Bound::Inclusive(_)) {
1062 Ordering::Less
1063 } else {
1064 Ordering::Greater
1065 }
1066 }
1067 _ => Ordering::Greater,
1068 },
1069 }
1070}
1071
1072fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1073 match (&r1.1, &r2.0) {
1074 (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1075 (Bound::Inclusive(v1), Bound::Inclusive(v2))
1076 | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1077 (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1078 (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1079 }
1080}
1081
1082fn min_bound(a: &Bound, b: &Bound) -> Bound {
1083 match (a, b) {
1084 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1085 _ => {
1086 if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1087 a.clone()
1088 } else {
1089 b.clone()
1090 }
1091 }
1092 }
1093}
1094
1095fn max_bound(a: &Bound, b: &Bound) -> Bound {
1096 match (a, b) {
1097 (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1098 _ => {
1099 if matches!(compare_bounds(a, b), Ordering::Greater) {
1100 a.clone()
1101 } else {
1102 b.clone()
1103 }
1104 }
1105 }
1106}
1107
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Builds a plain numeric literal.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Builds a data path with no parent segments.
    fn data(name: &str) -> DataPath {
        DataPath::new(vec![], name.to_string())
    }

    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Builds the constraint `name <op> value`.
    fn comparison(name: &str, op: ComparisonComputation, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            data: data(name),
            op,
            value: Arc::new(value),
        }
    }

    /// Extracts the domain recorded for `name`, panicking if it is missing.
    fn extracted(constraint: &Constraint, name: &str) -> Domain {
        let mut domains = extract_domains_from_constraint(constraint).unwrap();
        domains.remove(&data(name)).unwrap()
    }

    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let wrapped = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        assert_eq!(normalize_domain(wrapped), inner);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: inclusive(0),
                max: inclusive(10),
            },
            Domain::Unconstrained,
        ]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(range.to_string(), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(enumeration.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = comparison("age", ComparisonComputation::GreaterThan, num(18));

        assert_eq!(
            extracted(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(comparison("age", ComparisonComputation::GreaterThan, num(18))),
            Box::new(comparison("age", ComparisonComputation::LessThan, num(65))),
        );

        assert_eq!(
            extracted(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: exclusive(65),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = LiteralValue::text("active".to_string());
        let constraint = comparison("status", ComparisonComputation::Is, active.clone());

        assert_eq!(
            extracted(&constraint, "status"),
            Domain::Enumeration(Arc::new(vec![active]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_data() {
        let constraint = Constraint::Data(data("is_active"));

        assert_eq!(
            extracted(&constraint, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_data() {
        let constraint = Constraint::Not(Box::new(Constraint::Data(data("is_active"))));

        assert_eq!(
            extracted(&constraint, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }
}