Skip to main content

lemma/inversion/
domain.rs

//! Domain types and operations for inversion.
//!
//! Provides:
//! - `Domain` and `Bound` types for representing the set of concrete values a data item may take
//! - Domain operations: intersection, union, complement, normalization, and subset checks
//! - `extract_domains_from_constraint()`: extracts a domain for every data mentioned in a constraint
7
8use crate::computation::units::UnitResolutionContext;
9use crate::planning::semantics::{
10    ComparisonComputation, DataPath, LiteralValue, SemanticConversionTarget, ValueKind,
11};
12use crate::OperationResult;
13use serde::ser::{Serialize, SerializeStruct, Serializer};
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use super::constraint::Constraint;
20
/// Domain specification for valid values.
///
/// A `Domain` describes the set of values a single data item may take. Membership is
/// decided by [`Domain::contains`] and satisfiability by [`Domain::is_satisfiable`].
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// A single continuous range; each end is an inclusive, exclusive, or unbounded [`Bound`].
    Range { min: Bound, max: Bound },

    /// Union of several domains: a value belongs if any part contains it.
    /// Parts may be any domain (e.g. ranges split around excluded points, or the two
    /// sides of an `or`) and are not required to be disjoint.
    Union(Arc<Vec<Domain>>),

    /// Specific enumerated values only. An empty enumeration admits no value.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is NOT contained in the inner domain.
    Complement(Box<Domain>),

    /// Any value (no constraints)
    Unconstrained,

    /// Empty domain (no valid values) - represents unsatisfiable constraints
    Empty,
}
42
43impl Domain {
44    /// Check if this domain is satisfiable (has at least one valid value)
45    ///
46    /// Returns false for Empty domains and empty Enumerations.
47    pub fn is_satisfiable(&self) -> bool {
48        match self {
49            Domain::Empty => false,
50            Domain::Enumeration(values) => !values.is_empty(),
51            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
52            Domain::Range { min, max } => !bounds_contradict(min, max),
53            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
54            Domain::Unconstrained => true,
55        }
56    }
57
58    /// Check if this domain is empty (unsatisfiable)
59    pub fn is_empty(&self) -> bool {
60        !self.is_satisfiable()
61    }
62
63    /// Intersect this domain with another, returning Empty if no overlap.
64    /// `domain_intersection` returns `None` exactly when the result is empty.
65    pub fn intersect(&self, other: &Domain) -> Domain {
66        match domain_intersection(self.clone(), other.clone()) {
67            Some(d) => d,
68            None => Domain::Empty,
69        }
70    }
71
72    /// Check if a value is contained in this domain
73    pub fn contains(&self, value: &LiteralValue) -> bool {
74        match self {
75            Domain::Empty => false,
76            Domain::Unconstrained => true,
77            Domain::Enumeration(values) => values.contains(value),
78            Domain::Range { min, max } => value_within(value, min, max),
79            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
80            Domain::Complement(inner) => !inner.contains(value),
81        }
82    }
83}
84
/// Bound specification for ranges: one end of a [`Domain::Range`].
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// Inclusive bound: the value itself belongs to the range (`[v` or `v]`)
    Inclusive(Arc<LiteralValue>),

    /// Exclusive bound: the value itself does not belong to the range (`(v` or `v)`)
    Exclusive(Arc<LiteralValue>),

    /// Unbounded (-infinity or +infinity)
    Unbounded,
}
97
98impl fmt::Display for Domain {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        match self {
101            Domain::Empty => write!(f, "empty"),
102            Domain::Unconstrained => write!(f, "any"),
103            Domain::Enumeration(vals) => {
104                write!(f, "{{")?;
105                for (i, v) in vals.iter().enumerate() {
106                    if i > 0 {
107                        write!(f, ", ")?;
108                    }
109                    write!(f, "{}", v)?;
110                }
111                write!(f, "}}")
112            }
113            Domain::Range { min, max } => {
114                let (l_bracket, r_bracket) = match (min, max) {
115                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
116                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
117                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
118                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
119                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
120                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
121                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
122                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
123                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
124                };
125
126                let min_str = match min {
127                    Bound::Unbounded => "-inf".to_string(),
128                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
129                };
130                let max_str = match max {
131                    Bound::Unbounded => "+inf".to_string(),
132                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
133                };
134                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
135            }
136            Domain::Union(parts) => {
137                for (i, p) in parts.iter().enumerate() {
138                    if i > 0 {
139                        write!(f, " | ")?;
140                    }
141                    write!(f, "{}", p)?;
142                }
143                Ok(())
144            }
145            Domain::Complement(inner) => write!(f, "not ({})", inner),
146        }
147    }
148}
149
150impl fmt::Display for Bound {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        match self {
153            Bound::Unbounded => write!(f, "inf"),
154            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
155            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
156        }
157    }
158}
159
160impl Serialize for Domain {
161    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
162    where
163        S: Serializer,
164    {
165        match self {
166            Domain::Empty => {
167                let mut st = serializer.serialize_struct("domain", 1)?;
168                st.serialize_field("type", "empty")?;
169                st.end()
170            }
171            Domain::Unconstrained => {
172                let mut st = serializer.serialize_struct("domain", 1)?;
173                st.serialize_field("type", "unconstrained")?;
174                st.end()
175            }
176            Domain::Enumeration(vals) => {
177                let mut st = serializer.serialize_struct("domain", 2)?;
178                st.serialize_field("type", "enumeration")?;
179                st.serialize_field("values", vals)?;
180                st.end()
181            }
182            Domain::Range { min, max } => {
183                let mut st = serializer.serialize_struct("domain", 3)?;
184                st.serialize_field("type", "range")?;
185                st.serialize_field("min", min)?;
186                st.serialize_field("max", max)?;
187                st.end()
188            }
189            Domain::Union(parts) => {
190                let mut st = serializer.serialize_struct("domain", 2)?;
191                st.serialize_field("type", "union")?;
192                st.serialize_field("parts", parts)?;
193                st.end()
194            }
195            Domain::Complement(inner) => {
196                let mut st = serializer.serialize_struct("domain", 2)?;
197                st.serialize_field("type", "complement")?;
198                st.serialize_field("inner", inner)?;
199                st.end()
200            }
201        }
202    }
203}
204
205impl Serialize for Bound {
206    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207    where
208        S: Serializer,
209    {
210        match self {
211            Bound::Unbounded => {
212                let mut st = serializer.serialize_struct("bound", 1)?;
213                st.serialize_field("type", "unbounded")?;
214                st.end()
215            }
216            Bound::Inclusive(v) => {
217                let mut st = serializer.serialize_struct("bound", 2)?;
218                st.serialize_field("type", "inclusive")?;
219                st.serialize_field("value", v.as_ref())?;
220                st.end()
221            }
222            Bound::Exclusive(v) => {
223                let mut st = serializer.serialize_struct("bound", 2)?;
224                st.serialize_field("type", "exclusive")?;
225                st.serialize_field("value", v.as_ref())?;
226                st.end()
227            }
228        }
229    }
230}
231
232/// Extract domains for all data mentioned in a constraint
233pub fn extract_domains_from_constraint(
234    constraint: &Constraint,
235) -> Result<HashMap<DataPath, Domain>, crate::Error> {
236    let all_datas = constraint.collect_data();
237    let mut domains = HashMap::new();
238
239    for data_path in all_datas {
240        // None means the data appears in the constraint but has no extractable
241        // bound (e.g. only used in equality with another data). Treating it as
242        // Unconstrained is correct: the solver will enumerate values.
243        let domain =
244            extract_domain_for_data(constraint, &data_path)?.unwrap_or(Domain::Unconstrained);
245        domains.insert(data_path, domain);
246    }
247
248    Ok(domains)
249}
250
251fn extract_domain_for_data(
252    constraint: &Constraint,
253    data_path: &DataPath,
254) -> Result<Option<Domain>, crate::Error> {
255    let domain = match constraint {
256        Constraint::True => return Ok(None),
257        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
258
259        Constraint::Comparison { data, op, value } => {
260            if data == data_path {
261                Some(comparison_to_domain(op, value.as_ref())?)
262            } else {
263                None
264            }
265        }
266
267        Constraint::Data(fp) => {
268            if fp == data_path {
269                Some(Domain::Enumeration(Arc::new(vec![
270                    LiteralValue::from_bool(true),
271                ])))
272            } else {
273                None
274            }
275        }
276
277        Constraint::And(left, right) => {
278            let left_domain = extract_domain_for_data(left, data_path)?;
279            let right_domain = extract_domain_for_data(right, data_path)?;
280            match (left_domain, right_domain) {
281                (None, None) => None,
282                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
283                (Some(a), Some(b)) => match domain_intersection(a, b) {
284                    Some(domain) => Some(domain),
285                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
286                },
287            }
288        }
289
290        Constraint::Or(left, right) => {
291            let left_domain = extract_domain_for_data(left, data_path)?;
292            let right_domain = extract_domain_for_data(right, data_path)?;
293            union_optional_domains(left_domain, right_domain)
294        }
295
296        Constraint::Not(inner) => {
297            // Handle not (data is value)
298            if let Constraint::Comparison { data, op, value } = inner.as_ref() {
299                if data == data_path && op.is_equal() {
300                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
301                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
302                    )))));
303                }
304            }
305
306            // Handle not (boolean_data)
307            if let Constraint::Data(fp) = inner.as_ref() {
308                if fp == data_path {
309                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
310                        LiteralValue::from_bool(false),
311                    ]))));
312                }
313            }
314
315            extract_domain_for_data(inner, data_path)?
316                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
317        }
318    };
319
320    Ok(domain.map(normalize_domain))
321}
322
323fn comparison_to_domain(
324    op: &ComparisonComputation,
325    value: &LiteralValue,
326) -> Result<Domain, crate::Error> {
327    if op.is_equal() {
328        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
329    }
330    if op.is_not_equal() {
331        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
332            vec![value.clone()],
333        )))));
334    }
335    match op {
336        ComparisonComputation::LessThan => Ok(Domain::Range {
337            min: Bound::Unbounded,
338            max: Bound::Exclusive(Arc::new(value.clone())),
339        }),
340        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
341            min: Bound::Unbounded,
342            max: Bound::Inclusive(Arc::new(value.clone())),
343        }),
344        ComparisonComputation::GreaterThan => Ok(Domain::Range {
345            min: Bound::Exclusive(Arc::new(value.clone())),
346            max: Bound::Unbounded,
347        }),
348        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
349            min: Bound::Inclusive(Arc::new(value.clone())),
350            max: Bound::Unbounded,
351        }),
352        _ => unreachable!(
353            "BUG: unsupported comparison operator for domain extraction: {:?}",
354            op
355        ),
356    }
357}
358
359/// Compute the domain for a single comparison-atom used by inversion constraints.
360///
361/// This is used by numeric-aware constraint simplification to derive implications/exclusions
362/// between comparison atoms on the same data.
363pub(crate) fn domain_for_comparison_atom(
364    op: &ComparisonComputation,
365    value: &LiteralValue,
366) -> Result<Domain, crate::Error> {
367    comparison_to_domain(op, value)
368}
369
370impl Domain {
371    /// Proven subset check for the atom-domain forms we generate from comparisons:
372    /// - Range
373    /// - Enumeration
374    /// - Complement(Enumeration) (used for `is not`)
375    ///
376    /// Returns false when the relationship cannot be proven with these forms.
377    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
378        match (self, other) {
379            (Domain::Empty, _) => true,
380            (_, Domain::Unconstrained) => true,
381            (Domain::Unconstrained, _) => false,
382
383            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
384            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
385                vals.iter().all(|v| value_within(v, min, max))
386            }
387
388            (
389                Domain::Range {
390                    min: amin,
391                    max: amax,
392                },
393                Domain::Range {
394                    min: bmin,
395                    max: bmax,
396                },
397            ) => range_within_range(amin, amax, bmin, bmax),
398
399            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
400            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
401                Domain::Enumeration(excluded) => {
402                    excluded.iter().all(|p| !value_within(p, min, max))
403                }
404                _ => false,
405            },
406
407            // {v} ⊆ not({p}) when v is not excluded
408            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
409                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
410                _ => false,
411            },
412
413            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
414            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
415                match (a_inner.as_ref(), b_inner.as_ref()) {
416                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
417                        excluded_b.iter().all(|v| excluded_a.contains(v))
418                    }
419                    _ => false,
420                }
421            }
422
423            _ => false,
424        }
425    }
426}
427
428fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
429    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
430}
431
432fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
433    match (a, b) {
434        (_, Bound::Unbounded) => true,
435        (Bound::Unbounded, _) => false,
436        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
437        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
438        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
439            let c = lit_cmp(av.as_ref(), bv.as_ref());
440            c >= 0
441        }
442        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
443            // a >= (b) only if a's value > b's value
444            lit_cmp(av.as_ref(), bv.as_ref()) > 0
445        }
446    }
447}
448
449fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
450    match (a, b) {
451        (Bound::Unbounded, Bound::Unbounded) => true,
452        (_, Bound::Unbounded) => true,
453        (Bound::Unbounded, _) => false,
454        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
455        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
456        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
457            // (a) <= [b] when a <= b
458            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
459        }
460        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
461            // [a] <= (b) only if a < b
462            lit_cmp(av.as_ref(), bv.as_ref()) < 0
463        }
464    }
465}
466
467fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
468    match (a, b) {
469        (None, None) => None,
470        (Some(d), None) | (None, Some(d)) => Some(d),
471        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
472    }
473}
474
475fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
476    use std::cmp::Ordering;
477
478    match (&a.value, &b.value) {
479        (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
480            Ordering::Less => -1,
481            Ordering::Equal => 0,
482            Ordering::Greater => 1,
483        },
484
485        (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
486            Ordering::Less => -1,
487            Ordering::Equal => 0,
488            Ordering::Greater => 1,
489        },
490
491        (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
492            Ordering::Less => -1,
493            Ordering::Equal => 0,
494            Ordering::Greater => 1,
495        },
496
497        (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
498            Ordering::Less => -1,
499            Ordering::Equal => 0,
500            Ordering::Greater => 1,
501        },
502
503        (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
504            Ordering::Less => -1,
505            Ordering::Equal => 0,
506            Ordering::Greater => 1,
507        },
508
509        (ValueKind::Calendar(la, lu), ValueKind::Calendar(lb, lu2)) => {
510            let a_months = crate::computation::units::convert_calendar_magnitude(
511                *la,
512                lu,
513                &crate::planning::semantics::SemanticCalendarUnit::Month,
514            );
515            let b_months = crate::computation::units::convert_calendar_magnitude(
516                *lb,
517                lu2,
518                &crate::planning::semantics::SemanticCalendarUnit::Month,
519            );
520            match a_months.cmp(&b_months) {
521                Ordering::Less => -1,
522                Ordering::Equal => 0,
523                Ordering::Greater => 1,
524            }
525        }
526
527        (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
528            Ordering::Less => -1,
529            Ordering::Equal => 0,
530            Ordering::Greater => 1,
531        },
532
533        (ValueKind::Quantity(la, lua, _), ValueKind::Quantity(lb, lub, _)) => {
534            if a.lemma_type != b.lemma_type {
535                unreachable!(
536                    "BUG: lit_cmp compared different quantity types ({} vs {})",
537                    a.lemma_type.name(),
538                    b.lemma_type.name()
539                );
540            }
541
542            if lua.eq_ignore_ascii_case(lub) {
543                return match la.cmp(lb) {
544                    Ordering::Less => -1,
545                    Ordering::Equal => 0,
546                    Ordering::Greater => 1,
547                };
548            }
549
550            // Convert b to a's unit for comparison
551            let target = SemanticConversionTarget::QuantityUnit(lua.clone());
552            let converted = crate::computation::convert_unit(
553                b,
554                &target,
555                UnitResolutionContext::NamedQuantityOnly,
556            );
557            let converted_value = match converted {
558                OperationResult::Value(lit) => match lit.value {
559                    ValueKind::Quantity(v, _, _) => v,
560                    _ => unreachable!("BUG: quantity unit conversion returned non-quantity value"),
561                },
562                OperationResult::Veto(reason) => {
563                    unreachable!(
564                        "BUG: quantity unit conversion vetoed unexpectedly: {:?}",
565                        reason
566                    )
567                }
568            };
569
570            match la.cmp(&converted_value) {
571                Ordering::Less => -1,
572                Ordering::Equal => 0,
573                Ordering::Greater => 1,
574            }
575        }
576
577        _ => unreachable!(
578            "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
579            a.get_type(),
580            b.get_type()
581        ),
582    }
583}
584
585fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
586    let ge_min = match min {
587        Bound::Unbounded => true,
588        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
589        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
590    };
591    let le_max = match max {
592        Bound::Unbounded => true,
593        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
594        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
595    };
596    ge_min && le_max
597}
598
599fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
600    match (min, max) {
601        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
602        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
603        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
604        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
605        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
606    }
607}
608
609fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
610    match (min1, min2) {
611        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
612        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
613            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
614                Bound::Inclusive(v1)
615            } else {
616                Bound::Inclusive(v2)
617            }
618        }
619        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
620            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
621                Bound::Inclusive(v1)
622            } else {
623                Bound::Exclusive(v2)
624            }
625        }
626        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
627            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
628                Bound::Exclusive(v1)
629            } else {
630                Bound::Inclusive(v2)
631            }
632        }
633        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
634            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
635                Bound::Exclusive(v1)
636            } else {
637                Bound::Exclusive(v2)
638            }
639        }
640    }
641}
642
643fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
644    match (max1, max2) {
645        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
646        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
647            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
648                Bound::Inclusive(v1)
649            } else {
650                Bound::Inclusive(v2)
651            }
652        }
653        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
654            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
655                Bound::Inclusive(v1)
656            } else {
657                Bound::Exclusive(v2)
658            }
659        }
660        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
661            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
662                Bound::Exclusive(v1)
663            } else {
664                Bound::Inclusive(v2)
665            }
666        }
667        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
668            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
669                Bound::Exclusive(v1)
670            } else {
671                Bound::Exclusive(v2)
672            }
673        }
674    }
675}
676
677fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
678    let a = normalize_domain(a);
679    let b = normalize_domain(b);
680
681    let result = match (a, b) {
682        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
683        (Domain::Empty, _) | (_, Domain::Empty) => None,
684
685        (
686            Domain::Range {
687                min: min1,
688                max: max1,
689            },
690            Domain::Range {
691                min: min2,
692                max: max2,
693            },
694        ) => {
695            let min = compute_intersection_min(min1, min2);
696            let max = compute_intersection_max(max1, max2);
697
698            if bounds_contradict(&min, &max) {
699                None
700            } else {
701                Some(Domain::Range { min, max })
702            }
703        }
704        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
705            let filtered: Vec<LiteralValue> =
706                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
707            if filtered.is_empty() {
708                None
709            } else {
710                Some(Domain::Enumeration(Arc::new(filtered)))
711            }
712        }
713        (Domain::Enumeration(vs), Domain::Range { min, max })
714        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
715            let mut kept = Vec::new();
716            for v in vs.iter() {
717                if value_within(v, &min, &max) {
718                    kept.push(v.clone());
719                }
720            }
721            if kept.is_empty() {
722                None
723            } else {
724                Some(Domain::Enumeration(Arc::new(kept)))
725            }
726        }
727        (Domain::Enumeration(vs), Domain::Complement(inner))
728        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
729            match *inner.clone() {
730                Domain::Enumeration(excluded) => {
731                    let mut kept = Vec::new();
732                    for v in vs.iter() {
733                        if !excluded.contains(v) {
734                            kept.push(v.clone());
735                        }
736                    }
737                    if kept.is_empty() {
738                        None
739                    } else {
740                        Some(Domain::Enumeration(Arc::new(kept)))
741                    }
742                }
743                Domain::Range { min, max } => {
744                    // Filter enumeration values that are NOT in the range
745                    let mut kept = Vec::new();
746                    for v in vs.iter() {
747                        if !value_within(v, &min, &max) {
748                            kept.push(v.clone());
749                        }
750                    }
751                    if kept.is_empty() {
752                        None
753                    } else {
754                        Some(Domain::Enumeration(Arc::new(kept)))
755                    }
756                }
757                _ => {
758                    // For other complement types, normalize and recurse
759                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
760                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
761                }
762            }
763        }
764        (Domain::Union(v1), Domain::Union(v2)) => {
765            let mut acc: Vec<Domain> = Vec::new();
766            for a in v1.iter() {
767                for b in v2.iter() {
768                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
769                        acc.push(ix);
770                    }
771                }
772            }
773            if acc.is_empty() {
774                None
775            } else {
776                Some(Domain::Union(Arc::new(acc)))
777            }
778        }
779        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
780            let mut acc: Vec<Domain> = Vec::new();
781            for a in vs.iter() {
782                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
783                    acc.push(ix);
784                }
785            }
786            if acc.is_empty() {
787                None
788            } else if acc.len() == 1 {
789                Some(acc.remove(0))
790            } else {
791                Some(Domain::Union(Arc::new(acc)))
792            }
793        }
794        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
795        (Domain::Range { min, max }, Domain::Complement(inner))
796        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
797            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
798            _ => {
799                // Normalize the complement (not just the inner value) and recurse.
800                // If normalization doesn't change it, we must not recurse infinitely.
801                let normalized_complement = normalize_domain(Domain::Complement(inner));
802                if matches!(&normalized_complement, Domain::Complement(_)) {
803                    None
804                } else {
805                    domain_intersection(Domain::Range { min, max }, normalized_complement)
806                }
807            }
808        },
809        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
810            match (a_inner.as_ref(), b_inner.as_ref()) {
811                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
812                    // not(A) ∩ not(B) == not(A ∪ B)
813                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
814                    excluded.extend(b_ex.iter().cloned());
815                    Some(normalize_domain(Domain::Complement(Box::new(
816                        Domain::Enumeration(Arc::new(excluded)),
817                    ))))
818                }
819                _ => None,
820            }
821        }
822    };
823    result.map(normalize_domain)
824}
825
826fn range_minus_excluded_points(
827    min: Bound,
828    max: Bound,
829    excluded: &Arc<Vec<LiteralValue>>,
830) -> Option<Domain> {
831    // Start with a single range and iteratively split on excluded points that fall within it.
832    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
833
834    for p in excluded.iter() {
835        let mut next: Vec<(Bound, Bound)> = Vec::new();
836
837        for (rmin, rmax) in parts {
838            if !value_within(p, &rmin, &rmax) {
839                next.push((rmin, rmax));
840                continue;
841            }
842
843            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
844            let left_max = Bound::Exclusive(Arc::new(p.clone()));
845            if !bounds_contradict(&rmin, &left_max) {
846                next.push((rmin.clone(), left_max));
847            }
848
849            // Right part: (p, rmax)
850            let right_min = Bound::Exclusive(Arc::new(p.clone()));
851            if !bounds_contradict(&right_min, &rmax) {
852                next.push((right_min, rmax.clone()));
853            }
854        }
855
856        parts = next;
857        if parts.is_empty() {
858            return None;
859        }
860    }
861
862    if parts.is_empty() {
863        None
864    } else if parts.len() == 1 {
865        let (min, max) = parts.remove(0);
866        Some(Domain::Range { min, max })
867    } else {
868        Some(Domain::Union(Arc::new(
869            parts
870                .into_iter()
871                .map(|(min, max)| Domain::Range { min, max })
872                .collect(),
873        )))
874    }
875}
876
877fn invert_bound(bound: Bound) -> Bound {
878    match bound {
879        Bound::Unbounded => Bound::Unbounded,
880        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
881        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
882    }
883}
884
885fn normalize_domain(d: Domain) -> Domain {
886    match d {
887        Domain::Complement(inner) => {
888            let normalized_inner = normalize_domain(*inner);
889            match normalized_inner {
890                Domain::Complement(double_inner) => *double_inner,
891                Domain::Range { min, max } => match (&min, &max) {
892                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
893                    (Bound::Unbounded, max) => Domain::Range {
894                        min: invert_bound(max.clone()),
895                        max: Bound::Unbounded,
896                    },
897                    (min, Bound::Unbounded) => Domain::Range {
898                        min: Bound::Unbounded,
899                        max: invert_bound(min.clone()),
900                    },
901                    (min, max) => Domain::Union(Arc::new(vec![
902                        Domain::Range {
903                            min: Bound::Unbounded,
904                            max: invert_bound(min.clone()),
905                        },
906                        Domain::Range {
907                            min: invert_bound(max.clone()),
908                            max: Bound::Unbounded,
909                        },
910                    ])),
911                },
912                Domain::Enumeration(vals) => {
913                    if vals.len() == 1 {
914                        if let Some(lit) = vals.first() {
915                            if let ValueKind::Boolean(true) = &lit.value {
916                                return Domain::Enumeration(Arc::new(vec![
917                                    LiteralValue::from_bool(false),
918                                ]));
919                            }
920                            if let ValueKind::Boolean(false) = &lit.value {
921                                return Domain::Enumeration(Arc::new(vec![
922                                    LiteralValue::from_bool(true),
923                                ]));
924                            }
925                        }
926                    }
927                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
928                }
929                Domain::Unconstrained => Domain::Empty,
930                Domain::Empty => Domain::Unconstrained,
931                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
932            }
933        }
934        Domain::Empty => Domain::Empty,
935        Domain::Union(parts) => {
936            let mut flat: Vec<Domain> = Vec::new();
937            for p in parts.iter().cloned() {
938                let normalized = normalize_domain(p);
939                match normalized {
940                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
941                    Domain::Unconstrained => return Domain::Unconstrained,
942                    Domain::Enumeration(vals) if vals.is_empty() => {}
943                    other => flat.push(other),
944                }
945            }
946
947            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
948            let mut ranges: Vec<Domain> = Vec::new();
949            let mut others: Vec<Domain> = Vec::new();
950
951            for domain in flat {
952                match domain {
953                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
954                    Domain::Range { .. } => ranges.push(domain),
955                    other => others.push(other),
956                }
957            }
958
959            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
960                -1 => Ordering::Less,
961                0 => Ordering::Equal,
962                _ => Ordering::Greater,
963            });
964            all_enum_values.dedup();
965
966            all_enum_values.retain(|v| {
967                !ranges.iter().any(|r| {
968                    if let Domain::Range { min, max } = r {
969                        value_within(v, min, max)
970                    } else {
971                        false
972                    }
973                })
974            });
975
976            let mut result: Vec<Domain> = Vec::new();
977            result.extend(ranges);
978            result = merge_ranges(result);
979
980            if !all_enum_values.is_empty() {
981                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
982            }
983            result.extend(others);
984
985            result.sort_by(|a, b| match (a, b) {
986                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
987                (Domain::Range { .. }, _) => Ordering::Less,
988                (_, Domain::Range { .. }) => Ordering::Greater,
989                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
990                (Domain::Enumeration(_), _) => Ordering::Less,
991                (_, Domain::Enumeration(_)) => Ordering::Greater,
992                _ => Ordering::Equal,
993            });
994
995            if result.is_empty() {
996                Domain::Enumeration(Arc::new(vec![]))
997            } else if result.len() == 1 {
998                result.remove(0)
999            } else {
1000                Domain::Union(Arc::new(result))
1001            }
1002        }
1003        Domain::Enumeration(values) => {
1004            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
1005            sorted.sort_by(|a, b| match lit_cmp(a, b) {
1006                -1 => Ordering::Less,
1007                0 => Ordering::Equal,
1008                _ => Ordering::Greater,
1009            });
1010            sorted.dedup();
1011            Domain::Enumeration(Arc::new(sorted))
1012        }
1013        other => other,
1014    }
1015}
1016
1017fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1018    let mut result = Vec::new();
1019    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1020    let mut others = Vec::new();
1021
1022    for d in domains {
1023        match d {
1024            Domain::Range { min, max } => ranges.push((min, max)),
1025            other => others.push(other),
1026        }
1027    }
1028
1029    if ranges.is_empty() {
1030        return others;
1031    }
1032
1033    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1034
1035    let mut merged: Vec<(Bound, Bound)> = Vec::new();
1036    let mut current = ranges[0].clone();
1037
1038    for next in ranges.iter().skip(1) {
1039        if ranges_adjacent_or_overlap(&current, next) {
1040            current = (
1041                min_bound(&current.0, &next.0),
1042                max_bound(&current.1, &next.1),
1043            );
1044        } else {
1045            merged.push(current);
1046            current = next.clone();
1047        }
1048    }
1049    merged.push(current);
1050
1051    for (min, max) in merged {
1052        result.push(Domain::Range { min, max });
1053    }
1054    result.extend(others);
1055
1056    result
1057}
1058
1059fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1060    match (a, b) {
1061        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1062        (Bound::Unbounded, _) => Ordering::Less,
1063        (_, Bound::Unbounded) => Ordering::Greater,
1064        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1065        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1066            -1 => Ordering::Less,
1067            0 => Ordering::Equal,
1068            _ => Ordering::Greater,
1069        },
1070        (Bound::Inclusive(v1), Bound::Exclusive(v2))
1071        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1072            -1 => Ordering::Less,
1073            0 => {
1074                if matches!(a, Bound::Inclusive(_)) {
1075                    Ordering::Less
1076                } else {
1077                    Ordering::Greater
1078                }
1079            }
1080            _ => Ordering::Greater,
1081        },
1082    }
1083}
1084
1085fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1086    match (&r1.1, &r2.0) {
1087        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1088        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1089        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1090        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1091        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1092    }
1093}
1094
1095fn min_bound(a: &Bound, b: &Bound) -> Bound {
1096    match (a, b) {
1097        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1098        _ => {
1099            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1100                a.clone()
1101            } else {
1102                b.clone()
1103            }
1104        }
1105    }
1106}
1107
1108fn max_bound(a: &Bound, b: &Bound) -> Bound {
1109    match (a, b) {
1110        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1111        _ => {
1112            if matches!(compare_bounds(a, b), Ordering::Greater) {
1113                a.clone()
1114            } else {
1115                b.clone()
1116            }
1117        }
1118    }
1119}
1120
#[cfg(test)]
mod tests {
    use super::*;

    // Numeric literal for the integer `n`.
    fn num(n: i64) -> LiteralValue {
        let value = crate::computation::rational::RationalInteger::new(i128::from(n), 1);
        LiteralValue::number(value)
    }

    // Data path with no parent segments.
    fn data(name: &str) -> DataPath {
        DataPath::new(Vec::new(), name.to_string())
    }

    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    // Comparison constraint `name <op> value`.
    fn comparison(name: &str, op: ComparisonComputation, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            data: data(name),
            op,
            value: Arc::new(value),
        }
    }

    // Domain extracted for `name`; panics if extraction fails or `name` has no domain.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        let domains = extract_domains_from_constraint(constraint).unwrap();
        let key = data(name);
        domains.get(&key).unwrap().clone()
    }

    #[test]
    fn test_normalize_double_complement() {
        let five = Domain::Enumeration(Arc::new(vec![num(5)]));
        let not_not_five =
            Domain::Complement(Box::new(Domain::Complement(Box::new(five.clone()))));
        assert_eq!(normalize_domain(not_not_five), five);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let zero_to_ten = Domain::Range {
            min: inclusive(0),
            max: inclusive(10),
        };
        let union = Domain::Union(Arc::new(vec![zero_to_ten, Domain::Unconstrained]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let half_open = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(half_open.to_string(), "[10, 20)");

        let listed = Domain::Enumeration(Arc::new((1..=3).map(num).collect()));
        assert_eq!(listed.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let adult = comparison("age", ComparisonComputation::GreaterThan, num(18));
        assert_eq!(
            domain_of(&adult, "age"),
            Domain::Range {
                min: exclusive(18),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let working_age = Constraint::And(
            Box::new(comparison("age", ComparisonComputation::GreaterThan, num(18))),
            Box::new(comparison("age", ComparisonComputation::LessThan, num(65))),
        );
        assert_eq!(
            domain_of(&working_age, "age"),
            Domain::Range {
                min: exclusive(18),
                max: exclusive(65),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = comparison(
            "status",
            ComparisonComputation::Is,
            LiteralValue::text("active".to_string()),
        );
        assert_eq!(
            domain_of(&active, "status"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::text("active".to_string())]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_data() {
        let flag = Constraint::Data(data("is_active"));
        assert_eq!(
            domain_of(&flag, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_data() {
        let not_flag = Constraint::Not(Box::new(Constraint::Data(data("is_active"))));
        assert_eq!(
            domain_of(&not_flag, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }
}