Skip to main content

lemma/inversion/
domain.rs

1//! Domain types and operations for inversion
2//!
3//! Provides:
4//! - `Domain` and `Bound` types for representing concrete value constraints
5//! - Domain operations: intersection, union, normalization
6//! - `extract_domains_from_constraint()`: extracts domains from constraints
7
8use crate::planning::semantics::{
9    ComparisonComputation, DataPath, LiteralValue, SemanticConversionTarget, ValueKind,
10};
11use crate::OperationResult;
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
20/// Domain specification for valid values
21#[derive(Debug, Clone, PartialEq)]
22pub enum Domain {
23    /// A single continuous range
24    Range { min: Bound, max: Bound },
25
26    /// Multiple disjoint ranges
27    Union(Arc<Vec<Domain>>),
28
29    /// Specific enumerated values only
30    Enumeration(Arc<Vec<LiteralValue>>),
31
32    /// Everything except these constraints
33    Complement(Box<Domain>),
34
35    /// Any value (no constraints)
36    Unconstrained,
37
38    /// Empty domain (no valid values) - represents unsatisfiable constraints
39    Empty,
40}
41
42impl Domain {
43    /// Check if this domain is satisfiable (has at least one valid value)
44    ///
45    /// Returns false for Empty domains and empty Enumerations.
46    pub fn is_satisfiable(&self) -> bool {
47        match self {
48            Domain::Empty => false,
49            Domain::Enumeration(values) => !values.is_empty(),
50            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51            Domain::Range { min, max } => !bounds_contradict(min, max),
52            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53            Domain::Unconstrained => true,
54        }
55    }
56
57    /// Check if this domain is empty (unsatisfiable)
58    pub fn is_empty(&self) -> bool {
59        !self.is_satisfiable()
60    }
61
62    /// Intersect this domain with another, returning Empty if no overlap.
63    /// `domain_intersection` returns `None` exactly when the result is empty.
64    pub fn intersect(&self, other: &Domain) -> Domain {
65        match domain_intersection(self.clone(), other.clone()) {
66            Some(d) => d,
67            None => Domain::Empty,
68        }
69    }
70
71    /// Check if a value is contained in this domain
72    pub fn contains(&self, value: &LiteralValue) -> bool {
73        match self {
74            Domain::Empty => false,
75            Domain::Unconstrained => true,
76            Domain::Enumeration(values) => values.contains(value),
77            Domain::Range { min, max } => value_within(value, min, max),
78            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
79            Domain::Complement(inner) => !inner.contains(value),
80        }
81    }
82}
83
84/// Bound specification for ranges
85#[derive(Debug, Clone, PartialEq)]
86pub enum Bound {
87    /// Inclusive bound [value
88    Inclusive(Arc<LiteralValue>),
89
90    /// Exclusive bound (value
91    Exclusive(Arc<LiteralValue>),
92
93    /// Unbounded (-infinity or +infinity)
94    Unbounded,
95}
96
97impl fmt::Display for Domain {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        match self {
100            Domain::Empty => write!(f, "empty"),
101            Domain::Unconstrained => write!(f, "any"),
102            Domain::Enumeration(vals) => {
103                write!(f, "{{")?;
104                for (i, v) in vals.iter().enumerate() {
105                    if i > 0 {
106                        write!(f, ", ")?;
107                    }
108                    write!(f, "{}", v)?;
109                }
110                write!(f, "}}")
111            }
112            Domain::Range { min, max } => {
113                let (l_bracket, r_bracket) = match (min, max) {
114                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
115                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
116                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
117                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
118                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
119                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
120                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
121                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
122                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
123                };
124
125                let min_str = match min {
126                    Bound::Unbounded => "-inf".to_string(),
127                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128                };
129                let max_str = match max {
130                    Bound::Unbounded => "+inf".to_string(),
131                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
132                };
133                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
134            }
135            Domain::Union(parts) => {
136                for (i, p) in parts.iter().enumerate() {
137                    if i > 0 {
138                        write!(f, " | ")?;
139                    }
140                    write!(f, "{}", p)?;
141                }
142                Ok(())
143            }
144            Domain::Complement(inner) => write!(f, "not ({})", inner),
145        }
146    }
147}
148
149impl fmt::Display for Bound {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        match self {
152            Bound::Unbounded => write!(f, "inf"),
153            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
154            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
155        }
156    }
157}
158
159impl Serialize for Domain {
160    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161    where
162        S: Serializer,
163    {
164        match self {
165            Domain::Empty => {
166                let mut st = serializer.serialize_struct("domain", 1)?;
167                st.serialize_field("type", "empty")?;
168                st.end()
169            }
170            Domain::Unconstrained => {
171                let mut st = serializer.serialize_struct("domain", 1)?;
172                st.serialize_field("type", "unconstrained")?;
173                st.end()
174            }
175            Domain::Enumeration(vals) => {
176                let mut st = serializer.serialize_struct("domain", 2)?;
177                st.serialize_field("type", "enumeration")?;
178                st.serialize_field("values", vals)?;
179                st.end()
180            }
181            Domain::Range { min, max } => {
182                let mut st = serializer.serialize_struct("domain", 3)?;
183                st.serialize_field("type", "range")?;
184                st.serialize_field("min", min)?;
185                st.serialize_field("max", max)?;
186                st.end()
187            }
188            Domain::Union(parts) => {
189                let mut st = serializer.serialize_struct("domain", 2)?;
190                st.serialize_field("type", "union")?;
191                st.serialize_field("parts", parts)?;
192                st.end()
193            }
194            Domain::Complement(inner) => {
195                let mut st = serializer.serialize_struct("domain", 2)?;
196                st.serialize_field("type", "complement")?;
197                st.serialize_field("inner", inner)?;
198                st.end()
199            }
200        }
201    }
202}
203
204impl Serialize for Bound {
205    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206    where
207        S: Serializer,
208    {
209        match self {
210            Bound::Unbounded => {
211                let mut st = serializer.serialize_struct("bound", 1)?;
212                st.serialize_field("type", "unbounded")?;
213                st.end()
214            }
215            Bound::Inclusive(v) => {
216                let mut st = serializer.serialize_struct("bound", 2)?;
217                st.serialize_field("type", "inclusive")?;
218                st.serialize_field("value", v.as_ref())?;
219                st.end()
220            }
221            Bound::Exclusive(v) => {
222                let mut st = serializer.serialize_struct("bound", 2)?;
223                st.serialize_field("type", "exclusive")?;
224                st.serialize_field("value", v.as_ref())?;
225                st.end()
226            }
227        }
228    }
229}
230
231/// Extract domains for all data mentioned in a constraint
232pub fn extract_domains_from_constraint(
233    constraint: &Constraint,
234) -> Result<HashMap<DataPath, Domain>, crate::Error> {
235    let all_datas = constraint.collect_data();
236    let mut domains = HashMap::new();
237
238    for data_path in all_datas {
239        // None means the data appears in the constraint but has no extractable
240        // bound (e.g. only used in equality with another data). Treating it as
241        // Unconstrained is correct: the solver will enumerate values.
242        let domain =
243            extract_domain_for_data(constraint, &data_path)?.unwrap_or(Domain::Unconstrained);
244        domains.insert(data_path, domain);
245    }
246
247    Ok(domains)
248}
249
250fn extract_domain_for_data(
251    constraint: &Constraint,
252    data_path: &DataPath,
253) -> Result<Option<Domain>, crate::Error> {
254    let domain = match constraint {
255        Constraint::True => return Ok(None),
256        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
257
258        Constraint::Comparison { data, op, value } => {
259            if data == data_path {
260                Some(comparison_to_domain(op, value.as_ref())?)
261            } else {
262                None
263            }
264        }
265
266        Constraint::Data(fp) => {
267            if fp == data_path {
268                Some(Domain::Enumeration(Arc::new(vec![
269                    LiteralValue::from_bool(true),
270                ])))
271            } else {
272                None
273            }
274        }
275
276        Constraint::And(left, right) => {
277            let left_domain = extract_domain_for_data(left, data_path)?;
278            let right_domain = extract_domain_for_data(right, data_path)?;
279            match (left_domain, right_domain) {
280                (None, None) => None,
281                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
282                (Some(a), Some(b)) => match domain_intersection(a, b) {
283                    Some(domain) => Some(domain),
284                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
285                },
286            }
287        }
288
289        Constraint::Or(left, right) => {
290            let left_domain = extract_domain_for_data(left, data_path)?;
291            let right_domain = extract_domain_for_data(right, data_path)?;
292            union_optional_domains(left_domain, right_domain)
293        }
294
295        Constraint::Not(inner) => {
296            // Handle not (data is value)
297            if let Constraint::Comparison { data, op, value } = inner.as_ref() {
298                if data == data_path && op.is_equal() {
299                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
300                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
301                    )))));
302                }
303            }
304
305            // Handle not (boolean_data)
306            if let Constraint::Data(fp) = inner.as_ref() {
307                if fp == data_path {
308                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
309                        LiteralValue::from_bool(false),
310                    ]))));
311                }
312            }
313
314            extract_domain_for_data(inner, data_path)?
315                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
316        }
317    };
318
319    Ok(domain.map(normalize_domain))
320}
321
322fn comparison_to_domain(
323    op: &ComparisonComputation,
324    value: &LiteralValue,
325) -> Result<Domain, crate::Error> {
326    if op.is_equal() {
327        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
328    }
329    if op.is_not_equal() {
330        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
331            vec![value.clone()],
332        )))));
333    }
334    match op {
335        ComparisonComputation::LessThan => Ok(Domain::Range {
336            min: Bound::Unbounded,
337            max: Bound::Exclusive(Arc::new(value.clone())),
338        }),
339        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
340            min: Bound::Unbounded,
341            max: Bound::Inclusive(Arc::new(value.clone())),
342        }),
343        ComparisonComputation::GreaterThan => Ok(Domain::Range {
344            min: Bound::Exclusive(Arc::new(value.clone())),
345            max: Bound::Unbounded,
346        }),
347        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
348            min: Bound::Inclusive(Arc::new(value.clone())),
349            max: Bound::Unbounded,
350        }),
351        _ => unreachable!(
352            "BUG: unsupported comparison operator for domain extraction: {:?}",
353            op
354        ),
355    }
356}
357
358/// Compute the domain for a single comparison-atom used by inversion constraints.
359///
360/// This is used by numeric-aware constraint simplification to derive implications/exclusions
361/// between comparison atoms on the same data.
362pub(crate) fn domain_for_comparison_atom(
363    op: &ComparisonComputation,
364    value: &LiteralValue,
365) -> Result<Domain, crate::Error> {
366    comparison_to_domain(op, value)
367}
368
369impl Domain {
370    /// Proven subset check for the atom-domain forms we generate from comparisons:
371    /// - Range
372    /// - Enumeration
373    /// - Complement(Enumeration) (used for `is not`)
374    ///
375    /// Returns false when the relationship cannot be proven with these forms.
376    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
377        match (self, other) {
378            (Domain::Empty, _) => true,
379            (_, Domain::Unconstrained) => true,
380            (Domain::Unconstrained, _) => false,
381
382            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
383            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
384                vals.iter().all(|v| value_within(v, min, max))
385            }
386
387            (
388                Domain::Range {
389                    min: amin,
390                    max: amax,
391                },
392                Domain::Range {
393                    min: bmin,
394                    max: bmax,
395                },
396            ) => range_within_range(amin, amax, bmin, bmax),
397
398            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
399            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
400                Domain::Enumeration(excluded) => {
401                    excluded.iter().all(|p| !value_within(p, min, max))
402                }
403                _ => false,
404            },
405
406            // {v} ⊆ not({p}) when v is not excluded
407            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
408                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
409                _ => false,
410            },
411
412            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
413            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
414                match (a_inner.as_ref(), b_inner.as_ref()) {
415                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
416                        excluded_b.iter().all(|v| excluded_a.contains(v))
417                    }
418                    _ => false,
419                }
420            }
421
422            _ => false,
423        }
424    }
425}
426
427fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
428    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
429}
430
431fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
432    match (a, b) {
433        (_, Bound::Unbounded) => true,
434        (Bound::Unbounded, _) => false,
435        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
436        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
437        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
438            let c = lit_cmp(av.as_ref(), bv.as_ref());
439            c >= 0
440        }
441        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
442            // a >= (b) only if a's value > b's value
443            lit_cmp(av.as_ref(), bv.as_ref()) > 0
444        }
445    }
446}
447
448fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
449    match (a, b) {
450        (Bound::Unbounded, Bound::Unbounded) => true,
451        (_, Bound::Unbounded) => true,
452        (Bound::Unbounded, _) => false,
453        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
454        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
455        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
456            // (a) <= [b] when a <= b
457            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
458        }
459        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
460            // [a] <= (b) only if a < b
461            lit_cmp(av.as_ref(), bv.as_ref()) < 0
462        }
463    }
464}
465
466fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
467    match (a, b) {
468        (None, None) => None,
469        (Some(d), None) | (None, Some(d)) => Some(d),
470        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
471    }
472}
473
474fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
475    use std::cmp::Ordering;
476
477    match (&a.value, &b.value) {
478        (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
479            Ordering::Less => -1,
480            Ordering::Equal => 0,
481            Ordering::Greater => 1,
482        },
483
484        (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
485            Ordering::Less => -1,
486            Ordering::Equal => 0,
487            Ordering::Greater => 1,
488        },
489
490        (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
491            Ordering::Less => -1,
492            Ordering::Equal => 0,
493            Ordering::Greater => 1,
494        },
495
496        (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
497            Ordering::Less => -1,
498            Ordering::Equal => 0,
499            Ordering::Greater => 1,
500        },
501
502        (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
503            Ordering::Less => -1,
504            Ordering::Equal => 0,
505            Ordering::Greater => 1,
506        },
507
508        (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub)) => {
509            let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
510            let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
511            match a_sec.cmp(&b_sec) {
512                Ordering::Less => -1,
513                Ordering::Equal => 0,
514                Ordering::Greater => 1,
515            }
516        }
517
518        (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
519            Ordering::Less => -1,
520            Ordering::Equal => 0,
521            Ordering::Greater => 1,
522        },
523
524        (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub)) => {
525            if a.lemma_type != b.lemma_type {
526                unreachable!(
527                    "BUG: lit_cmp compared different scale types ({} vs {})",
528                    a.lemma_type.name(),
529                    b.lemma_type.name()
530                );
531            }
532
533            if lua.eq_ignore_ascii_case(lub) {
534                return match la.cmp(lb) {
535                    Ordering::Less => -1,
536                    Ordering::Equal => 0,
537                    Ordering::Greater => 1,
538                };
539            }
540
541            // Convert b to a's unit for comparison
542            let target = SemanticConversionTarget::ScaleUnit(lua.clone());
543            let converted = crate::computation::convert_unit(b, &target);
544            let converted_value = match converted {
545                OperationResult::Value(lit) => match lit.value {
546                    ValueKind::Scale(v, _) => v,
547                    _ => unreachable!("BUG: scale unit conversion returned non-scale value"),
548                },
549                OperationResult::Veto(reason) => {
550                    unreachable!(
551                        "BUG: scale unit conversion vetoed unexpectedly: {:?}",
552                        reason
553                    )
554                }
555            };
556
557            match la.cmp(&converted_value) {
558                Ordering::Less => -1,
559                Ordering::Equal => 0,
560                Ordering::Greater => 1,
561            }
562        }
563
564        _ => unreachable!(
565            "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
566            a.get_type(),
567            b.get_type()
568        ),
569    }
570}
571
572fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
573    let ge_min = match min {
574        Bound::Unbounded => true,
575        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
576        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
577    };
578    let le_max = match max {
579        Bound::Unbounded => true,
580        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
581        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
582    };
583    ge_min && le_max
584}
585
586fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
587    match (min, max) {
588        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
589        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
590        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
591        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
592        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
593    }
594}
595
596fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
597    match (min1, min2) {
598        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
599        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
600            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
601                Bound::Inclusive(v1)
602            } else {
603                Bound::Inclusive(v2)
604            }
605        }
606        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
607            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
608                Bound::Inclusive(v1)
609            } else {
610                Bound::Exclusive(v2)
611            }
612        }
613        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
614            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
615                Bound::Exclusive(v1)
616            } else {
617                Bound::Inclusive(v2)
618            }
619        }
620        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
621            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
622                Bound::Exclusive(v1)
623            } else {
624                Bound::Exclusive(v2)
625            }
626        }
627    }
628}
629
630fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
631    match (max1, max2) {
632        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
633        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
634            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
635                Bound::Inclusive(v1)
636            } else {
637                Bound::Inclusive(v2)
638            }
639        }
640        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
641            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
642                Bound::Inclusive(v1)
643            } else {
644                Bound::Exclusive(v2)
645            }
646        }
647        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
648            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
649                Bound::Exclusive(v1)
650            } else {
651                Bound::Inclusive(v2)
652            }
653        }
654        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
655            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
656                Bound::Exclusive(v1)
657            } else {
658                Bound::Exclusive(v2)
659            }
660        }
661    }
662}
663
664fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
665    let a = normalize_domain(a);
666    let b = normalize_domain(b);
667
668    let result = match (a, b) {
669        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
670        (Domain::Empty, _) | (_, Domain::Empty) => None,
671
672        (
673            Domain::Range {
674                min: min1,
675                max: max1,
676            },
677            Domain::Range {
678                min: min2,
679                max: max2,
680            },
681        ) => {
682            let min = compute_intersection_min(min1, min2);
683            let max = compute_intersection_max(max1, max2);
684
685            if bounds_contradict(&min, &max) {
686                None
687            } else {
688                Some(Domain::Range { min, max })
689            }
690        }
691        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
692            let filtered: Vec<LiteralValue> =
693                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
694            if filtered.is_empty() {
695                None
696            } else {
697                Some(Domain::Enumeration(Arc::new(filtered)))
698            }
699        }
700        (Domain::Enumeration(vs), Domain::Range { min, max })
701        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
702            let mut kept = Vec::new();
703            for v in vs.iter() {
704                if value_within(v, &min, &max) {
705                    kept.push(v.clone());
706                }
707            }
708            if kept.is_empty() {
709                None
710            } else {
711                Some(Domain::Enumeration(Arc::new(kept)))
712            }
713        }
714        (Domain::Enumeration(vs), Domain::Complement(inner))
715        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
716            match *inner.clone() {
717                Domain::Enumeration(excluded) => {
718                    let mut kept = Vec::new();
719                    for v in vs.iter() {
720                        if !excluded.contains(v) {
721                            kept.push(v.clone());
722                        }
723                    }
724                    if kept.is_empty() {
725                        None
726                    } else {
727                        Some(Domain::Enumeration(Arc::new(kept)))
728                    }
729                }
730                Domain::Range { min, max } => {
731                    // Filter enumeration values that are NOT in the range
732                    let mut kept = Vec::new();
733                    for v in vs.iter() {
734                        if !value_within(v, &min, &max) {
735                            kept.push(v.clone());
736                        }
737                    }
738                    if kept.is_empty() {
739                        None
740                    } else {
741                        Some(Domain::Enumeration(Arc::new(kept)))
742                    }
743                }
744                _ => {
745                    // For other complement types, normalize and recurse
746                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
747                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
748                }
749            }
750        }
751        (Domain::Union(v1), Domain::Union(v2)) => {
752            let mut acc: Vec<Domain> = Vec::new();
753            for a in v1.iter() {
754                for b in v2.iter() {
755                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
756                        acc.push(ix);
757                    }
758                }
759            }
760            if acc.is_empty() {
761                None
762            } else {
763                Some(Domain::Union(Arc::new(acc)))
764            }
765        }
766        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
767            let mut acc: Vec<Domain> = Vec::new();
768            for a in vs.iter() {
769                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
770                    acc.push(ix);
771                }
772            }
773            if acc.is_empty() {
774                None
775            } else if acc.len() == 1 {
776                Some(acc.remove(0))
777            } else {
778                Some(Domain::Union(Arc::new(acc)))
779            }
780        }
781        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
782        (Domain::Range { min, max }, Domain::Complement(inner))
783        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
784            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
785            _ => {
786                // Normalize the complement (not just the inner value) and recurse.
787                // If normalization doesn't change it, we must not recurse infinitely.
788                let normalized_complement = normalize_domain(Domain::Complement(inner));
789                if matches!(&normalized_complement, Domain::Complement(_)) {
790                    None
791                } else {
792                    domain_intersection(Domain::Range { min, max }, normalized_complement)
793                }
794            }
795        },
796        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
797            match (a_inner.as_ref(), b_inner.as_ref()) {
798                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
799                    // not(A) ∩ not(B) == not(A ∪ B)
800                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
801                    excluded.extend(b_ex.iter().cloned());
802                    Some(normalize_domain(Domain::Complement(Box::new(
803                        Domain::Enumeration(Arc::new(excluded)),
804                    ))))
805                }
806                _ => None,
807            }
808        }
809    };
810    result.map(normalize_domain)
811}
812
813fn range_minus_excluded_points(
814    min: Bound,
815    max: Bound,
816    excluded: &Arc<Vec<LiteralValue>>,
817) -> Option<Domain> {
818    // Start with a single range and iteratively split on excluded points that fall within it.
819    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
820
821    for p in excluded.iter() {
822        let mut next: Vec<(Bound, Bound)> = Vec::new();
823
824        for (rmin, rmax) in parts {
825            if !value_within(p, &rmin, &rmax) {
826                next.push((rmin, rmax));
827                continue;
828            }
829
830            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
831            let left_max = Bound::Exclusive(Arc::new(p.clone()));
832            if !bounds_contradict(&rmin, &left_max) {
833                next.push((rmin.clone(), left_max));
834            }
835
836            // Right part: (p, rmax)
837            let right_min = Bound::Exclusive(Arc::new(p.clone()));
838            if !bounds_contradict(&right_min, &rmax) {
839                next.push((right_min, rmax.clone()));
840            }
841        }
842
843        parts = next;
844        if parts.is_empty() {
845            return None;
846        }
847    }
848
849    if parts.is_empty() {
850        None
851    } else if parts.len() == 1 {
852        let (min, max) = parts.remove(0);
853        Some(Domain::Range { min, max })
854    } else {
855        Some(Domain::Union(Arc::new(
856            parts
857                .into_iter()
858                .map(|(min, max)| Domain::Range { min, max })
859                .collect(),
860        )))
861    }
862}
863
864fn invert_bound(bound: Bound) -> Bound {
865    match bound {
866        Bound::Unbounded => Bound::Unbounded,
867        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
868        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
869    }
870}
871
872fn normalize_domain(d: Domain) -> Domain {
873    match d {
874        Domain::Complement(inner) => {
875            let normalized_inner = normalize_domain(*inner);
876            match normalized_inner {
877                Domain::Complement(double_inner) => *double_inner,
878                Domain::Range { min, max } => match (&min, &max) {
879                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
880                    (Bound::Unbounded, max) => Domain::Range {
881                        min: invert_bound(max.clone()),
882                        max: Bound::Unbounded,
883                    },
884                    (min, Bound::Unbounded) => Domain::Range {
885                        min: Bound::Unbounded,
886                        max: invert_bound(min.clone()),
887                    },
888                    (min, max) => Domain::Union(Arc::new(vec![
889                        Domain::Range {
890                            min: Bound::Unbounded,
891                            max: invert_bound(min.clone()),
892                        },
893                        Domain::Range {
894                            min: invert_bound(max.clone()),
895                            max: Bound::Unbounded,
896                        },
897                    ])),
898                },
899                Domain::Enumeration(vals) => {
900                    if vals.len() == 1 {
901                        if let Some(lit) = vals.first() {
902                            if let ValueKind::Boolean(true) = &lit.value {
903                                return Domain::Enumeration(Arc::new(vec![
904                                    LiteralValue::from_bool(false),
905                                ]));
906                            }
907                            if let ValueKind::Boolean(false) = &lit.value {
908                                return Domain::Enumeration(Arc::new(vec![
909                                    LiteralValue::from_bool(true),
910                                ]));
911                            }
912                        }
913                    }
914                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
915                }
916                Domain::Unconstrained => Domain::Empty,
917                Domain::Empty => Domain::Unconstrained,
918                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
919            }
920        }
921        Domain::Empty => Domain::Empty,
922        Domain::Union(parts) => {
923            let mut flat: Vec<Domain> = Vec::new();
924            for p in parts.iter().cloned() {
925                let normalized = normalize_domain(p);
926                match normalized {
927                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
928                    Domain::Unconstrained => return Domain::Unconstrained,
929                    Domain::Enumeration(vals) if vals.is_empty() => {}
930                    other => flat.push(other),
931                }
932            }
933
934            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
935            let mut ranges: Vec<Domain> = Vec::new();
936            let mut others: Vec<Domain> = Vec::new();
937
938            for domain in flat {
939                match domain {
940                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
941                    Domain::Range { .. } => ranges.push(domain),
942                    other => others.push(other),
943                }
944            }
945
946            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
947                -1 => Ordering::Less,
948                0 => Ordering::Equal,
949                _ => Ordering::Greater,
950            });
951            all_enum_values.dedup();
952
953            all_enum_values.retain(|v| {
954                !ranges.iter().any(|r| {
955                    if let Domain::Range { min, max } = r {
956                        value_within(v, min, max)
957                    } else {
958                        false
959                    }
960                })
961            });
962
963            let mut result: Vec<Domain> = Vec::new();
964            result.extend(ranges);
965            result = merge_ranges(result);
966
967            if !all_enum_values.is_empty() {
968                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
969            }
970            result.extend(others);
971
972            result.sort_by(|a, b| match (a, b) {
973                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
974                (Domain::Range { .. }, _) => Ordering::Less,
975                (_, Domain::Range { .. }) => Ordering::Greater,
976                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
977                (Domain::Enumeration(_), _) => Ordering::Less,
978                (_, Domain::Enumeration(_)) => Ordering::Greater,
979                _ => Ordering::Equal,
980            });
981
982            if result.is_empty() {
983                Domain::Enumeration(Arc::new(vec![]))
984            } else if result.len() == 1 {
985                result.remove(0)
986            } else {
987                Domain::Union(Arc::new(result))
988            }
989        }
990        Domain::Enumeration(values) => {
991            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
992            sorted.sort_by(|a, b| match lit_cmp(a, b) {
993                -1 => Ordering::Less,
994                0 => Ordering::Equal,
995                _ => Ordering::Greater,
996            });
997            sorted.dedup();
998            Domain::Enumeration(Arc::new(sorted))
999        }
1000        other => other,
1001    }
1002}
1003
1004fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1005    let mut result = Vec::new();
1006    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1007    let mut others = Vec::new();
1008
1009    for d in domains {
1010        match d {
1011            Domain::Range { min, max } => ranges.push((min, max)),
1012            other => others.push(other),
1013        }
1014    }
1015
1016    if ranges.is_empty() {
1017        return others;
1018    }
1019
1020    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1021
1022    let mut merged: Vec<(Bound, Bound)> = Vec::new();
1023    let mut current = ranges[0].clone();
1024
1025    for next in ranges.iter().skip(1) {
1026        if ranges_adjacent_or_overlap(&current, next) {
1027            current = (
1028                min_bound(&current.0, &next.0),
1029                max_bound(&current.1, &next.1),
1030            );
1031        } else {
1032            merged.push(current);
1033            current = next.clone();
1034        }
1035    }
1036    merged.push(current);
1037
1038    for (min, max) in merged {
1039        result.push(Domain::Range { min, max });
1040    }
1041    result.extend(others);
1042
1043    result
1044}
1045
1046fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1047    match (a, b) {
1048        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1049        (Bound::Unbounded, _) => Ordering::Less,
1050        (_, Bound::Unbounded) => Ordering::Greater,
1051        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1052        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1053            -1 => Ordering::Less,
1054            0 => Ordering::Equal,
1055            _ => Ordering::Greater,
1056        },
1057        (Bound::Inclusive(v1), Bound::Exclusive(v2))
1058        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1059            -1 => Ordering::Less,
1060            0 => {
1061                if matches!(a, Bound::Inclusive(_)) {
1062                    Ordering::Less
1063                } else {
1064                    Ordering::Greater
1065                }
1066            }
1067            _ => Ordering::Greater,
1068        },
1069    }
1070}
1071
1072fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1073    match (&r1.1, &r2.0) {
1074        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1075        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1076        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1077        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1078        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1079    }
1080}
1081
1082fn min_bound(a: &Bound, b: &Bound) -> Bound {
1083    match (a, b) {
1084        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1085        _ => {
1086            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1087                a.clone()
1088            } else {
1089                b.clone()
1090            }
1091        }
1092    }
1093}
1094
1095fn max_bound(a: &Bound, b: &Bound) -> Bound {
1096    match (a, b) {
1097        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1098        _ => {
1099            if matches!(compare_bounds(a, b), Ordering::Greater) {
1100                a.clone()
1101            } else {
1102                b.clone()
1103            }
1104        }
1105    }
1106}
1107
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Numeric literal shorthand.
    fn number(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Top-level data path shorthand.
    fn path(name: &str) -> DataPath {
        DataPath::new(vec![], name.to_string())
    }

    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(number(n)))
    }

    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(number(n)))
    }

    /// Build `name <op> value`.
    fn comparison(op: ComparisonComputation, name: &str, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            data: path(name),
            op,
            value: Arc::new(value),
        }
    }

    /// Extract domains from `constraint` and return the one recorded for `name`.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        extract_domains_from_constraint(constraint)
            .unwrap()
            .get(&path(name))
            .cloned()
            .expect("constraint should yield a domain for the named data")
    }

    #[test]
    fn test_normalize_double_complement() {
        let five = Domain::Enumeration(Arc::new(vec![number(5)]));
        let wrapped = Domain::Complement(Box::new(Domain::Complement(Box::new(five.clone()))));
        assert_eq!(normalize_domain(wrapped), five);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let members = vec![
            Domain::Range {
                min: inclusive(0),
                max: inclusive(10),
            },
            Domain::Unconstrained,
        ];
        assert_eq!(
            normalize_domain(Domain::Union(Arc::new(members))),
            Domain::Unconstrained
        );
    }

    #[test]
    fn test_domain_display() {
        let half_open = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(half_open.to_string(), "[10, 20)");

        let listed = Domain::Enumeration(Arc::new((1..=3).map(number).collect()));
        assert_eq!(listed.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = comparison(ComparisonComputation::GreaterThan, "age", number(18));
        assert_eq!(
            domain_of(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(comparison(ComparisonComputation::GreaterThan, "age", number(18))),
            Box::new(comparison(ComparisonComputation::LessThan, "age", number(65))),
        );
        assert_eq!(
            domain_of(&constraint, "age"),
            Domain::Range {
                min: exclusive(18),
                max: exclusive(65),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = || LiteralValue::text("active".to_string());
        let constraint = comparison(ComparisonComputation::Is, "status", active());
        assert_eq!(
            domain_of(&constraint, "status"),
            Domain::Enumeration(Arc::new(vec![active()]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_data() {
        let constraint = Constraint::Data(path("is_active"));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_data() {
        let constraint = Constraint::Not(Box::new(Constraint::Data(path("is_active"))));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }
}