Skip to main content

lemma/inversion/
domain.rs

1//! Domain types and operations for inversion
2//!
3//! Provides:
4//! - `Domain` and `Bound` types for representing concrete value constraints
5//! - Domain operations: intersection, union, normalization
6//! - `extract_domains_from_constraint()`: extracts domains from constraints
7
8use crate::planning::semantics::{
9    ComparisonComputation, FactPath, LiteralValue, SemanticConversionTarget, ValueKind,
10};
11use crate::OperationResult;
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
20/// Domain specification for valid values
21#[derive(Debug, Clone, PartialEq)]
22pub enum Domain {
23    /// A single continuous range
24    Range { min: Bound, max: Bound },
25
26    /// Multiple disjoint ranges
27    Union(Arc<Vec<Domain>>),
28
29    /// Specific enumerated values only
30    Enumeration(Arc<Vec<LiteralValue>>),
31
32    /// Everything except these constraints
33    Complement(Box<Domain>),
34
35    /// Any value (no constraints)
36    Unconstrained,
37
38    /// Empty domain (no valid values) - represents unsatisfiable constraints
39    Empty,
40}
41
42impl Domain {
43    /// Check if this domain is satisfiable (has at least one valid value)
44    ///
45    /// Returns false for Empty domains and empty Enumerations.
46    pub fn is_satisfiable(&self) -> bool {
47        match self {
48            Domain::Empty => false,
49            Domain::Enumeration(values) => !values.is_empty(),
50            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51            Domain::Range { min, max } => !bounds_contradict(min, max),
52            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53            Domain::Unconstrained => true,
54        }
55    }
56
57    /// Check if this domain is empty (unsatisfiable)
58    pub fn is_empty(&self) -> bool {
59        !self.is_satisfiable()
60    }
61
62    /// Intersect this domain with another, returning Empty if no overlap.
63    /// `domain_intersection` returns `None` exactly when the result is empty.
64    pub fn intersect(&self, other: &Domain) -> Domain {
65        match domain_intersection(self.clone(), other.clone()) {
66            Some(d) => d,
67            None => Domain::Empty,
68        }
69    }
70
71    /// Check if a value is contained in this domain
72    pub fn contains(&self, value: &LiteralValue) -> bool {
73        match self {
74            Domain::Empty => false,
75            Domain::Unconstrained => true,
76            Domain::Enumeration(values) => values.contains(value),
77            Domain::Range { min, max } => value_within(value, min, max),
78            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
79            Domain::Complement(inner) => !inner.contains(value),
80        }
81    }
82}
83
84/// Bound specification for ranges
85#[derive(Debug, Clone, PartialEq)]
86pub enum Bound {
87    /// Inclusive bound [value
88    Inclusive(Arc<LiteralValue>),
89
90    /// Exclusive bound (value
91    Exclusive(Arc<LiteralValue>),
92
93    /// Unbounded (-infinity or +infinity)
94    Unbounded,
95}
96
97impl fmt::Display for Domain {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        match self {
100            Domain::Empty => write!(f, "empty"),
101            Domain::Unconstrained => write!(f, "any"),
102            Domain::Enumeration(vals) => {
103                write!(f, "{{")?;
104                for (i, v) in vals.iter().enumerate() {
105                    if i > 0 {
106                        write!(f, ", ")?;
107                    }
108                    write!(f, "{}", v)?;
109                }
110                write!(f, "}}")
111            }
112            Domain::Range { min, max } => {
113                let (l_bracket, r_bracket) = match (min, max) {
114                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
115                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
116                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
117                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
118                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
119                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
120                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
121                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
122                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
123                };
124
125                let min_str = match min {
126                    Bound::Unbounded => "-inf".to_string(),
127                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128                };
129                let max_str = match max {
130                    Bound::Unbounded => "+inf".to_string(),
131                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
132                };
133                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
134            }
135            Domain::Union(parts) => {
136                for (i, p) in parts.iter().enumerate() {
137                    if i > 0 {
138                        write!(f, " | ")?;
139                    }
140                    write!(f, "{}", p)?;
141                }
142                Ok(())
143            }
144            Domain::Complement(inner) => write!(f, "not ({})", inner),
145        }
146    }
147}
148
149impl fmt::Display for Bound {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        match self {
152            Bound::Unbounded => write!(f, "inf"),
153            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
154            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
155        }
156    }
157}
158
159impl Serialize for Domain {
160    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161    where
162        S: Serializer,
163    {
164        match self {
165            Domain::Empty => {
166                let mut st = serializer.serialize_struct("domain", 1)?;
167                st.serialize_field("type", "empty")?;
168                st.end()
169            }
170            Domain::Unconstrained => {
171                let mut st = serializer.serialize_struct("domain", 1)?;
172                st.serialize_field("type", "unconstrained")?;
173                st.end()
174            }
175            Domain::Enumeration(vals) => {
176                let mut st = serializer.serialize_struct("domain", 2)?;
177                st.serialize_field("type", "enumeration")?;
178                st.serialize_field("values", vals)?;
179                st.end()
180            }
181            Domain::Range { min, max } => {
182                let mut st = serializer.serialize_struct("domain", 3)?;
183                st.serialize_field("type", "range")?;
184                st.serialize_field("min", min)?;
185                st.serialize_field("max", max)?;
186                st.end()
187            }
188            Domain::Union(parts) => {
189                let mut st = serializer.serialize_struct("domain", 2)?;
190                st.serialize_field("type", "union")?;
191                st.serialize_field("parts", parts)?;
192                st.end()
193            }
194            Domain::Complement(inner) => {
195                let mut st = serializer.serialize_struct("domain", 2)?;
196                st.serialize_field("type", "complement")?;
197                st.serialize_field("inner", inner)?;
198                st.end()
199            }
200        }
201    }
202}
203
204impl Serialize for Bound {
205    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206    where
207        S: Serializer,
208    {
209        match self {
210            Bound::Unbounded => {
211                let mut st = serializer.serialize_struct("bound", 1)?;
212                st.serialize_field("type", "unbounded")?;
213                st.end()
214            }
215            Bound::Inclusive(v) => {
216                let mut st = serializer.serialize_struct("bound", 2)?;
217                st.serialize_field("type", "inclusive")?;
218                st.serialize_field("value", v.as_ref())?;
219                st.end()
220            }
221            Bound::Exclusive(v) => {
222                let mut st = serializer.serialize_struct("bound", 2)?;
223                st.serialize_field("type", "exclusive")?;
224                st.serialize_field("value", v.as_ref())?;
225                st.end()
226            }
227        }
228    }
229}
230
231/// Extract domains for all facts mentioned in a constraint
232pub fn extract_domains_from_constraint(
233    constraint: &Constraint,
234) -> Result<HashMap<FactPath, Domain>, crate::Error> {
235    let all_facts = constraint.collect_facts();
236    let mut domains = HashMap::new();
237
238    for fact_path in all_facts {
239        // None means the fact appears in the constraint but has no extractable
240        // bound (e.g. only used in equality with another fact). Treating it as
241        // Unconstrained is correct: the solver will enumerate values.
242        let domain =
243            extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
244        domains.insert(fact_path, domain);
245    }
246
247    Ok(domains)
248}
249
250fn extract_domain_for_fact(
251    constraint: &Constraint,
252    fact_path: &FactPath,
253) -> Result<Option<Domain>, crate::Error> {
254    let domain = match constraint {
255        Constraint::True => return Ok(None),
256        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
257
258        Constraint::Comparison { fact, op, value } => {
259            if fact == fact_path {
260                Some(comparison_to_domain(op, value.as_ref())?)
261            } else {
262                None
263            }
264        }
265
266        Constraint::Fact(fp) => {
267            if fp == fact_path {
268                Some(Domain::Enumeration(Arc::new(vec![
269                    LiteralValue::from_bool(true),
270                ])))
271            } else {
272                None
273            }
274        }
275
276        Constraint::And(left, right) => {
277            let left_domain = extract_domain_for_fact(left, fact_path)?;
278            let right_domain = extract_domain_for_fact(right, fact_path)?;
279            match (left_domain, right_domain) {
280                (None, None) => None,
281                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
282                (Some(a), Some(b)) => match domain_intersection(a, b) {
283                    Some(domain) => Some(domain),
284                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
285                },
286            }
287        }
288
289        Constraint::Or(left, right) => {
290            let left_domain = extract_domain_for_fact(left, fact_path)?;
291            let right_domain = extract_domain_for_fact(right, fact_path)?;
292            union_optional_domains(left_domain, right_domain)
293        }
294
295        Constraint::Not(inner) => {
296            // Handle not (fact == value)
297            if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
298                if fact == fact_path && op.is_equal() {
299                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
300                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
301                    )))));
302                }
303            }
304
305            // Handle not (boolean_fact)
306            if let Constraint::Fact(fp) = inner.as_ref() {
307                if fp == fact_path {
308                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
309                        LiteralValue::from_bool(false),
310                    ]))));
311                }
312            }
313
314            extract_domain_for_fact(inner, fact_path)?
315                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
316        }
317    };
318
319    Ok(domain.map(normalize_domain))
320}
321
322fn comparison_to_domain(
323    op: &ComparisonComputation,
324    value: &LiteralValue,
325) -> Result<Domain, crate::Error> {
326    if op.is_equal() {
327        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
328    }
329    if op.is_not_equal() {
330        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
331            vec![value.clone()],
332        )))));
333    }
334    match op {
335        ComparisonComputation::LessThan => Ok(Domain::Range {
336            min: Bound::Unbounded,
337            max: Bound::Exclusive(Arc::new(value.clone())),
338        }),
339        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
340            min: Bound::Unbounded,
341            max: Bound::Inclusive(Arc::new(value.clone())),
342        }),
343        ComparisonComputation::GreaterThan => Ok(Domain::Range {
344            min: Bound::Exclusive(Arc::new(value.clone())),
345            max: Bound::Unbounded,
346        }),
347        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
348            min: Bound::Inclusive(Arc::new(value.clone())),
349            max: Bound::Unbounded,
350        }),
351        _ => unreachable!(
352            "BUG: unsupported comparison operator for domain extraction: {:?}",
353            op
354        ),
355    }
356}
357
358/// Compute the domain for a single comparison-atom used by inversion constraints.
359///
360/// This is used by numeric-aware constraint simplification to derive implications/exclusions
361/// between comparison atoms on the same fact.
362pub(crate) fn domain_for_comparison_atom(
363    op: &ComparisonComputation,
364    value: &LiteralValue,
365) -> Result<Domain, crate::Error> {
366    comparison_to_domain(op, value)
367}
368
369impl Domain {
370    /// Proven subset check for the atom-domain forms we generate from comparisons:
371    /// - Range
372    /// - Enumeration
373    /// - Complement(Enumeration) (used for != / is not)
374    ///
375    /// Returns false when the relationship cannot be proven with these forms.
376    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
377        match (self, other) {
378            (Domain::Empty, _) => true,
379            (_, Domain::Unconstrained) => true,
380            (Domain::Unconstrained, _) => false,
381
382            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
383            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
384                vals.iter().all(|v| value_within(v, min, max))
385            }
386
387            (
388                Domain::Range {
389                    min: amin,
390                    max: amax,
391                },
392                Domain::Range {
393                    min: bmin,
394                    max: bmax,
395                },
396            ) => range_within_range(amin, amax, bmin, bmax),
397
398            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
399            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
400                Domain::Enumeration(excluded) => {
401                    excluded.iter().all(|p| !value_within(p, min, max))
402                }
403                _ => false,
404            },
405
406            // {v} ⊆ not({p}) when v is not excluded
407            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
408                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
409                _ => false,
410            },
411
412            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
413            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
414                match (a_inner.as_ref(), b_inner.as_ref()) {
415                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
416                        excluded_b.iter().all(|v| excluded_a.contains(v))
417                    }
418                    _ => false,
419                }
420            }
421
422            _ => false,
423        }
424    }
425}
426
427fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
428    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
429}
430
431fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
432    match (a, b) {
433        (_, Bound::Unbounded) => true,
434        (Bound::Unbounded, _) => false,
435        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
436        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
437        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
438            let c = lit_cmp(av.as_ref(), bv.as_ref());
439            c >= 0
440        }
441        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
442            // a >= (b) only if a's value > b's value
443            lit_cmp(av.as_ref(), bv.as_ref()) > 0
444        }
445    }
446}
447
448fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
449    match (a, b) {
450        (Bound::Unbounded, Bound::Unbounded) => true,
451        (_, Bound::Unbounded) => true,
452        (Bound::Unbounded, _) => false,
453        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
454        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
455        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
456            // (a) <= [b] when a <= b
457            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
458        }
459        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
460            // [a] <= (b) only if a < b
461            lit_cmp(av.as_ref(), bv.as_ref()) < 0
462        }
463    }
464}
465
466fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
467    match (a, b) {
468        (None, None) => None,
469        (Some(d), None) | (None, Some(d)) => Some(d),
470        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
471    }
472}
473
474fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
475    use std::cmp::Ordering;
476
477    match (&a.value, &b.value) {
478        (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
479            Ordering::Less => -1,
480            Ordering::Equal => 0,
481            Ordering::Greater => 1,
482        },
483
484        (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
485            Ordering::Less => -1,
486            Ordering::Equal => 0,
487            Ordering::Greater => 1,
488        },
489
490        (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
491            Ordering::Less => -1,
492            Ordering::Equal => 0,
493            Ordering::Greater => 1,
494        },
495
496        (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
497            Ordering::Less => -1,
498            Ordering::Equal => 0,
499            Ordering::Greater => 1,
500        },
501
502        (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
503            Ordering::Less => -1,
504            Ordering::Equal => 0,
505            Ordering::Greater => 1,
506        },
507
508        (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub)) => {
509            let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
510            let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
511            match a_sec.cmp(&b_sec) {
512                Ordering::Less => -1,
513                Ordering::Equal => 0,
514                Ordering::Greater => 1,
515            }
516        }
517
518        (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
519            Ordering::Less => -1,
520            Ordering::Equal => 0,
521            Ordering::Greater => 1,
522        },
523
524        (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub)) => {
525            if a.lemma_type != b.lemma_type {
526                unreachable!(
527                    "BUG: lit_cmp compared different scale types ({} vs {})",
528                    a.lemma_type.name(),
529                    b.lemma_type.name()
530                );
531            }
532
533            if lua.eq_ignore_ascii_case(lub) {
534                return match la.cmp(lb) {
535                    Ordering::Less => -1,
536                    Ordering::Equal => 0,
537                    Ordering::Greater => 1,
538                };
539            }
540
541            // Convert b to a's unit for comparison
542            let target = SemanticConversionTarget::ScaleUnit(lua.clone());
543            let converted = crate::computation::convert_unit(b, &target);
544            let converted_value = match converted {
545                OperationResult::Value(lit) => match lit.value {
546                    ValueKind::Scale(v, _) => v,
547                    _ => unreachable!("BUG: scale unit conversion returned non-scale value"),
548                },
549                OperationResult::Veto(msg) => {
550                    unreachable!("BUG: scale unit conversion vetoed unexpectedly: {:?}", msg)
551                }
552            };
553
554            match la.cmp(&converted_value) {
555                Ordering::Less => -1,
556                Ordering::Equal => 0,
557                Ordering::Greater => 1,
558            }
559        }
560
561        _ => unreachable!(
562            "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
563            a.get_type(),
564            b.get_type()
565        ),
566    }
567}
568
569fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
570    let ge_min = match min {
571        Bound::Unbounded => true,
572        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
573        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
574    };
575    let le_max = match max {
576        Bound::Unbounded => true,
577        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
578        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
579    };
580    ge_min && le_max
581}
582
583fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
584    match (min, max) {
585        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
586        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
587        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
588        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
589        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
590    }
591}
592
593fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
594    match (min1, min2) {
595        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
596        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
597            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
598                Bound::Inclusive(v1)
599            } else {
600                Bound::Inclusive(v2)
601            }
602        }
603        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
604            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
605                Bound::Inclusive(v1)
606            } else {
607                Bound::Exclusive(v2)
608            }
609        }
610        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
611            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
612                Bound::Exclusive(v1)
613            } else {
614                Bound::Inclusive(v2)
615            }
616        }
617        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
618            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
619                Bound::Exclusive(v1)
620            } else {
621                Bound::Exclusive(v2)
622            }
623        }
624    }
625}
626
627fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
628    match (max1, max2) {
629        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
630        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
631            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
632                Bound::Inclusive(v1)
633            } else {
634                Bound::Inclusive(v2)
635            }
636        }
637        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
638            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
639                Bound::Inclusive(v1)
640            } else {
641                Bound::Exclusive(v2)
642            }
643        }
644        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
645            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
646                Bound::Exclusive(v1)
647            } else {
648                Bound::Inclusive(v2)
649            }
650        }
651        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
652            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
653                Bound::Exclusive(v1)
654            } else {
655                Bound::Exclusive(v2)
656            }
657        }
658    }
659}
660
661fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
662    let a = normalize_domain(a);
663    let b = normalize_domain(b);
664
665    let result = match (a, b) {
666        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
667        (Domain::Empty, _) | (_, Domain::Empty) => None,
668
669        (
670            Domain::Range {
671                min: min1,
672                max: max1,
673            },
674            Domain::Range {
675                min: min2,
676                max: max2,
677            },
678        ) => {
679            let min = compute_intersection_min(min1, min2);
680            let max = compute_intersection_max(max1, max2);
681
682            if bounds_contradict(&min, &max) {
683                None
684            } else {
685                Some(Domain::Range { min, max })
686            }
687        }
688        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
689            let filtered: Vec<LiteralValue> =
690                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
691            if filtered.is_empty() {
692                None
693            } else {
694                Some(Domain::Enumeration(Arc::new(filtered)))
695            }
696        }
697        (Domain::Enumeration(vs), Domain::Range { min, max })
698        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
699            let mut kept = Vec::new();
700            for v in vs.iter() {
701                if value_within(v, &min, &max) {
702                    kept.push(v.clone());
703                }
704            }
705            if kept.is_empty() {
706                None
707            } else {
708                Some(Domain::Enumeration(Arc::new(kept)))
709            }
710        }
711        (Domain::Enumeration(vs), Domain::Complement(inner))
712        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
713            match *inner.clone() {
714                Domain::Enumeration(excluded) => {
715                    let mut kept = Vec::new();
716                    for v in vs.iter() {
717                        if !excluded.contains(v) {
718                            kept.push(v.clone());
719                        }
720                    }
721                    if kept.is_empty() {
722                        None
723                    } else {
724                        Some(Domain::Enumeration(Arc::new(kept)))
725                    }
726                }
727                Domain::Range { min, max } => {
728                    // Filter enumeration values that are NOT in the range
729                    let mut kept = Vec::new();
730                    for v in vs.iter() {
731                        if !value_within(v, &min, &max) {
732                            kept.push(v.clone());
733                        }
734                    }
735                    if kept.is_empty() {
736                        None
737                    } else {
738                        Some(Domain::Enumeration(Arc::new(kept)))
739                    }
740                }
741                _ => {
742                    // For other complement types, normalize and recurse
743                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
744                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
745                }
746            }
747        }
748        (Domain::Union(v1), Domain::Union(v2)) => {
749            let mut acc: Vec<Domain> = Vec::new();
750            for a in v1.iter() {
751                for b in v2.iter() {
752                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
753                        acc.push(ix);
754                    }
755                }
756            }
757            if acc.is_empty() {
758                None
759            } else {
760                Some(Domain::Union(Arc::new(acc)))
761            }
762        }
763        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
764            let mut acc: Vec<Domain> = Vec::new();
765            for a in vs.iter() {
766                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
767                    acc.push(ix);
768                }
769            }
770            if acc.is_empty() {
771                None
772            } else if acc.len() == 1 {
773                Some(acc.remove(0))
774            } else {
775                Some(Domain::Union(Arc::new(acc)))
776            }
777        }
778        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
779        (Domain::Range { min, max }, Domain::Complement(inner))
780        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
781            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
782            _ => {
783                // Normalize the complement (not just the inner value) and recurse.
784                // If normalization doesn't change it, we must not recurse infinitely.
785                let normalized_complement = normalize_domain(Domain::Complement(inner));
786                if matches!(&normalized_complement, Domain::Complement(_)) {
787                    None
788                } else {
789                    domain_intersection(Domain::Range { min, max }, normalized_complement)
790                }
791            }
792        },
793        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
794            match (a_inner.as_ref(), b_inner.as_ref()) {
795                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
796                    // not(A) ∩ not(B) == not(A ∪ B)
797                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
798                    excluded.extend(b_ex.iter().cloned());
799                    Some(normalize_domain(Domain::Complement(Box::new(
800                        Domain::Enumeration(Arc::new(excluded)),
801                    ))))
802                }
803                _ => None,
804            }
805        }
806    };
807    result.map(normalize_domain)
808}
809
810fn range_minus_excluded_points(
811    min: Bound,
812    max: Bound,
813    excluded: &Arc<Vec<LiteralValue>>,
814) -> Option<Domain> {
815    // Start with a single range and iteratively split on excluded points that fall within it.
816    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
817
818    for p in excluded.iter() {
819        let mut next: Vec<(Bound, Bound)> = Vec::new();
820
821        for (rmin, rmax) in parts {
822            if !value_within(p, &rmin, &rmax) {
823                next.push((rmin, rmax));
824                continue;
825            }
826
827            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
828            let left_max = Bound::Exclusive(Arc::new(p.clone()));
829            if !bounds_contradict(&rmin, &left_max) {
830                next.push((rmin.clone(), left_max));
831            }
832
833            // Right part: (p, rmax)
834            let right_min = Bound::Exclusive(Arc::new(p.clone()));
835            if !bounds_contradict(&right_min, &rmax) {
836                next.push((right_min, rmax.clone()));
837            }
838        }
839
840        parts = next;
841        if parts.is_empty() {
842            return None;
843        }
844    }
845
846    if parts.is_empty() {
847        None
848    } else if parts.len() == 1 {
849        let (min, max) = parts.remove(0);
850        Some(Domain::Range { min, max })
851    } else {
852        Some(Domain::Union(Arc::new(
853            parts
854                .into_iter()
855                .map(|(min, max)| Domain::Range { min, max })
856                .collect(),
857        )))
858    }
859}
860
861fn invert_bound(bound: Bound) -> Bound {
862    match bound {
863        Bound::Unbounded => Bound::Unbounded,
864        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
865        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
866    }
867}
868
869fn normalize_domain(d: Domain) -> Domain {
870    match d {
871        Domain::Complement(inner) => {
872            let normalized_inner = normalize_domain(*inner);
873            match normalized_inner {
874                Domain::Complement(double_inner) => *double_inner,
875                Domain::Range { min, max } => match (&min, &max) {
876                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
877                    (Bound::Unbounded, max) => Domain::Range {
878                        min: invert_bound(max.clone()),
879                        max: Bound::Unbounded,
880                    },
881                    (min, Bound::Unbounded) => Domain::Range {
882                        min: Bound::Unbounded,
883                        max: invert_bound(min.clone()),
884                    },
885                    (min, max) => Domain::Union(Arc::new(vec![
886                        Domain::Range {
887                            min: Bound::Unbounded,
888                            max: invert_bound(min.clone()),
889                        },
890                        Domain::Range {
891                            min: invert_bound(max.clone()),
892                            max: Bound::Unbounded,
893                        },
894                    ])),
895                },
896                Domain::Enumeration(vals) => {
897                    if vals.len() == 1 {
898                        if let Some(lit) = vals.first() {
899                            if let ValueKind::Boolean(true) = &lit.value {
900                                return Domain::Enumeration(Arc::new(vec![
901                                    LiteralValue::from_bool(false),
902                                ]));
903                            }
904                            if let ValueKind::Boolean(false) = &lit.value {
905                                return Domain::Enumeration(Arc::new(vec![
906                                    LiteralValue::from_bool(true),
907                                ]));
908                            }
909                        }
910                    }
911                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
912                }
913                Domain::Unconstrained => Domain::Empty,
914                Domain::Empty => Domain::Unconstrained,
915                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
916            }
917        }
918        Domain::Empty => Domain::Empty,
919        Domain::Union(parts) => {
920            let mut flat: Vec<Domain> = Vec::new();
921            for p in parts.iter().cloned() {
922                let normalized = normalize_domain(p);
923                match normalized {
924                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
925                    Domain::Unconstrained => return Domain::Unconstrained,
926                    Domain::Enumeration(vals) if vals.is_empty() => {}
927                    other => flat.push(other),
928                }
929            }
930
931            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
932            let mut ranges: Vec<Domain> = Vec::new();
933            let mut others: Vec<Domain> = Vec::new();
934
935            for domain in flat {
936                match domain {
937                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
938                    Domain::Range { .. } => ranges.push(domain),
939                    other => others.push(other),
940                }
941            }
942
943            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
944                -1 => Ordering::Less,
945                0 => Ordering::Equal,
946                _ => Ordering::Greater,
947            });
948            all_enum_values.dedup();
949
950            all_enum_values.retain(|v| {
951                !ranges.iter().any(|r| {
952                    if let Domain::Range { min, max } = r {
953                        value_within(v, min, max)
954                    } else {
955                        false
956                    }
957                })
958            });
959
960            let mut result: Vec<Domain> = Vec::new();
961            result.extend(ranges);
962            result = merge_ranges(result);
963
964            if !all_enum_values.is_empty() {
965                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
966            }
967            result.extend(others);
968
969            result.sort_by(|a, b| match (a, b) {
970                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
971                (Domain::Range { .. }, _) => Ordering::Less,
972                (_, Domain::Range { .. }) => Ordering::Greater,
973                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
974                (Domain::Enumeration(_), _) => Ordering::Less,
975                (_, Domain::Enumeration(_)) => Ordering::Greater,
976                _ => Ordering::Equal,
977            });
978
979            if result.is_empty() {
980                Domain::Enumeration(Arc::new(vec![]))
981            } else if result.len() == 1 {
982                result.remove(0)
983            } else {
984                Domain::Union(Arc::new(result))
985            }
986        }
987        Domain::Enumeration(values) => {
988            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
989            sorted.sort_by(|a, b| match lit_cmp(a, b) {
990                -1 => Ordering::Less,
991                0 => Ordering::Equal,
992                _ => Ordering::Greater,
993            });
994            sorted.dedup();
995            Domain::Enumeration(Arc::new(sorted))
996        }
997        other => other,
998    }
999}
1000
1001fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1002    let mut result = Vec::new();
1003    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1004    let mut others = Vec::new();
1005
1006    for d in domains {
1007        match d {
1008            Domain::Range { min, max } => ranges.push((min, max)),
1009            other => others.push(other),
1010        }
1011    }
1012
1013    if ranges.is_empty() {
1014        return others;
1015    }
1016
1017    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1018
1019    let mut merged: Vec<(Bound, Bound)> = Vec::new();
1020    let mut current = ranges[0].clone();
1021
1022    for next in ranges.iter().skip(1) {
1023        if ranges_adjacent_or_overlap(&current, next) {
1024            current = (
1025                min_bound(&current.0, &next.0),
1026                max_bound(&current.1, &next.1),
1027            );
1028        } else {
1029            merged.push(current);
1030            current = next.clone();
1031        }
1032    }
1033    merged.push(current);
1034
1035    for (min, max) in merged {
1036        result.push(Domain::Range { min, max });
1037    }
1038    result.extend(others);
1039
1040    result
1041}
1042
1043fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1044    match (a, b) {
1045        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1046        (Bound::Unbounded, _) => Ordering::Less,
1047        (_, Bound::Unbounded) => Ordering::Greater,
1048        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1049        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1050            -1 => Ordering::Less,
1051            0 => Ordering::Equal,
1052            _ => Ordering::Greater,
1053        },
1054        (Bound::Inclusive(v1), Bound::Exclusive(v2))
1055        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1056            -1 => Ordering::Less,
1057            0 => {
1058                if matches!(a, Bound::Inclusive(_)) {
1059                    Ordering::Less
1060                } else {
1061                    Ordering::Greater
1062                }
1063            }
1064            _ => Ordering::Greater,
1065        },
1066    }
1067}
1068
1069fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1070    match (&r1.1, &r2.0) {
1071        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1072        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1073        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1074        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1075        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1076    }
1077}
1078
1079fn min_bound(a: &Bound, b: &Bound) -> Bound {
1080    match (a, b) {
1081        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1082        _ => {
1083            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1084                a.clone()
1085            } else {
1086                b.clone()
1087            }
1088        }
1089    }
1090}
1091
1092fn max_bound(a: &Bound, b: &Bound) -> Bound {
1093    match (a, b) {
1094        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1095        _ => {
1096            if matches!(compare_bounds(a, b), Ordering::Greater) {
1097                a.clone()
1098            } else {
1099                b.clone()
1100            }
1101        }
1102    }
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Numeric literal for an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Fact path with no parent segments and the given name.
    fn fact(name: &str) -> FactPath {
        FactPath::new(vec![], name.to_string())
    }

    /// Inclusive bound at integer `n`.
    fn incl(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    /// Exclusive bound at integer `n`.
    fn excl(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Comparison constraint `name <op> n`.
    fn compare(name: &str, op: ComparisonComputation, n: i64) -> Constraint {
        Constraint::Comparison {
            fact: fact(name),
            op,
            value: Arc::new(num(n)),
        }
    }

    /// Domain extracted for fact `name`. Panics if extraction fails or the fact has no domain.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        let domains = extract_domains_from_constraint(constraint).unwrap();
        domains.get(&fact(name)).unwrap().clone()
    }

    #[test]
    fn test_normalize_double_complement() {
        let five = Domain::Enumeration(Arc::new(vec![num(5)]));
        let not_not_five =
            Domain::Complement(Box::new(Domain::Complement(Box::new(five.clone()))));
        assert_eq!(normalize_domain(not_not_five), five);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: incl(0),
                max: incl(10),
            },
            Domain::Unconstrained,
        ]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let half_open = Domain::Range {
            min: incl(10),
            max: excl(20),
        };
        assert_eq!(half_open.to_string(), "[10, 20)");

        let listed = Domain::Enumeration(Arc::new((1..=3).map(num).collect()));
        assert_eq!(listed.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let adult = compare("age", ComparisonComputation::GreaterThan, 18);
        assert_eq!(
            domain_of(&adult, "age"),
            Domain::Range {
                min: excl(18),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let working_age = Constraint::And(
            Box::new(compare("age", ComparisonComputation::GreaterThan, 18)),
            Box::new(compare("age", ComparisonComputation::LessThan, 65)),
        );
        assert_eq!(
            domain_of(&working_age, "age"),
            Domain::Range {
                min: excl(18),
                max: excl(65),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = || LiteralValue::text("active".to_string());
        let constraint = Constraint::Comparison {
            fact: fact("status"),
            op: ComparisonComputation::Equal,
            value: Arc::new(active()),
        };
        assert_eq!(
            domain_of(&constraint, "status"),
            Domain::Enumeration(Arc::new(vec![active()]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)]))
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));
        assert_eq!(
            domain_of(&constraint, "is_active"),
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)]))
        );
    }
}