Skip to main content

lemma/inversion/
domain.rs

1//! Domain types and operations for inversion
2//!
3//! Provides:
4//! - `Domain` and `Bound` types for representing concrete value constraints
5//! - Domain operations: intersection, union, normalization
6//! - `extract_domains_from_constraint()`: extracts domains from constraints
7
8use crate::planning::semantics::{
9    ComparisonComputation, FactPath, LiteralValue, SemanticConversionTarget, ValueKind,
10};
11use crate::{LemmaResult, OperationResult};
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
/// Set of values a single fact may take, as derived from inversion constraints.
///
/// Variants may nest (e.g. a `Union` of `Range`s, or a `Complement` of an `Enumeration`).
/// Range checks go through this module's literal comparison, which panics when asked to
/// compare literals of different kinds, so a domain is expected to hold values of a single
/// literal kind.
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// A single contiguous range from `min` to `max`; inclusivity is carried by each `Bound`.
    Range { min: Bound, max: Bound },

    /// Set union of the contained domains. Intended to hold disjoint ranges; disjointness
    /// is not enforced by the code visible here.
    Union(Arc<Vec<Domain>>),

    /// Exactly the listed values. An empty list means no value is allowed.
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is *not* in the inner domain.
    Complement(Box<Domain>),

    /// Any value (no constraints).
    Unconstrained,

    /// Empty domain (no valid values); represents unsatisfiable constraints.
    Empty,
}
41
42impl Domain {
43    /// Check if this domain is satisfiable (has at least one valid value)
44    ///
45    /// Returns false for Empty domains and empty Enumerations.
46    pub fn is_satisfiable(&self) -> bool {
47        match self {
48            Domain::Empty => false,
49            Domain::Enumeration(values) => !values.is_empty(),
50            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51            Domain::Range { min, max } => !bounds_contradict(min, max),
52            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53            Domain::Unconstrained => true,
54        }
55    }
56
57    /// Check if this domain is empty (unsatisfiable)
58    pub fn is_empty(&self) -> bool {
59        !self.is_satisfiable()
60    }
61
62    /// Intersect this domain with another, returning Empty if no overlap
63    pub fn intersect(&self, other: &Domain) -> Domain {
64        domain_intersection(self.clone(), other.clone()).unwrap_or(Domain::Empty)
65    }
66
67    /// Check if a value is contained in this domain
68    pub fn contains(&self, value: &LiteralValue) -> bool {
69        match self {
70            Domain::Empty => false,
71            Domain::Unconstrained => true,
72            Domain::Enumeration(values) => values.contains(value),
73            Domain::Range { min, max } => value_within(value, min, max),
74            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
75            Domain::Complement(inner) => !inner.contains(value),
76        }
77    }
78}
79
/// One end of a [`Domain::Range`].
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// Closed end: the value itself belongs to the range (`[v` as a minimum, `v]` as a maximum).
    Inclusive(Arc<LiteralValue>),

    /// Open end: the value itself is excluded (`(v` as a minimum, `v)` as a maximum).
    Exclusive(Arc<LiteralValue>),

    /// No limit on this side (-infinity when used as `min`, +infinity when used as `max`).
    Unbounded,
}
92
93impl fmt::Display for Domain {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            Domain::Empty => write!(f, "empty"),
97            Domain::Unconstrained => write!(f, "any"),
98            Domain::Enumeration(vals) => {
99                write!(f, "{{")?;
100                for (i, v) in vals.iter().enumerate() {
101                    if i > 0 {
102                        write!(f, ", ")?;
103                    }
104                    write!(f, "{}", v)?;
105                }
106                write!(f, "}}")
107            }
108            Domain::Range { min, max } => {
109                let (l_bracket, r_bracket) = match (min, max) {
110                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
111                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
112                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
113                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
114                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
115                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
116                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
117                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
118                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
119                };
120
121                let min_str = match min {
122                    Bound::Unbounded => "-inf".to_string(),
123                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
124                };
125                let max_str = match max {
126                    Bound::Unbounded => "+inf".to_string(),
127                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128                };
129                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
130            }
131            Domain::Union(parts) => {
132                for (i, p) in parts.iter().enumerate() {
133                    if i > 0 {
134                        write!(f, " | ")?;
135                    }
136                    write!(f, "{}", p)?;
137                }
138                Ok(())
139            }
140            Domain::Complement(inner) => write!(f, "not ({})", inner),
141        }
142    }
143}
144
145impl fmt::Display for Bound {
146    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147        match self {
148            Bound::Unbounded => write!(f, "inf"),
149            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
150            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
151        }
152    }
153}
154
155impl Serialize for Domain {
156    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157    where
158        S: Serializer,
159    {
160        match self {
161            Domain::Empty => {
162                let mut st = serializer.serialize_struct("domain", 1)?;
163                st.serialize_field("type", "empty")?;
164                st.end()
165            }
166            Domain::Unconstrained => {
167                let mut st = serializer.serialize_struct("domain", 1)?;
168                st.serialize_field("type", "unconstrained")?;
169                st.end()
170            }
171            Domain::Enumeration(vals) => {
172                let mut st = serializer.serialize_struct("domain", 2)?;
173                st.serialize_field("type", "enumeration")?;
174                st.serialize_field("values", vals)?;
175                st.end()
176            }
177            Domain::Range { min, max } => {
178                let mut st = serializer.serialize_struct("domain", 3)?;
179                st.serialize_field("type", "range")?;
180                st.serialize_field("min", min)?;
181                st.serialize_field("max", max)?;
182                st.end()
183            }
184            Domain::Union(parts) => {
185                let mut st = serializer.serialize_struct("domain", 2)?;
186                st.serialize_field("type", "union")?;
187                st.serialize_field("parts", parts)?;
188                st.end()
189            }
190            Domain::Complement(inner) => {
191                let mut st = serializer.serialize_struct("domain", 2)?;
192                st.serialize_field("type", "complement")?;
193                st.serialize_field("inner", inner)?;
194                st.end()
195            }
196        }
197    }
198}
199
200impl Serialize for Bound {
201    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
202    where
203        S: Serializer,
204    {
205        match self {
206            Bound::Unbounded => {
207                let mut st = serializer.serialize_struct("bound", 1)?;
208                st.serialize_field("type", "unbounded")?;
209                st.end()
210            }
211            Bound::Inclusive(v) => {
212                let mut st = serializer.serialize_struct("bound", 2)?;
213                st.serialize_field("type", "inclusive")?;
214                st.serialize_field("value", v.as_ref())?;
215                st.end()
216            }
217            Bound::Exclusive(v) => {
218                let mut st = serializer.serialize_struct("bound", 2)?;
219                st.serialize_field("type", "exclusive")?;
220                st.serialize_field("value", v.as_ref())?;
221                st.end()
222            }
223        }
224    }
225}
226
227/// Extract domains for all facts mentioned in a constraint
228pub fn extract_domains_from_constraint(
229    constraint: &Constraint,
230) -> LemmaResult<HashMap<FactPath, Domain>> {
231    let all_facts = constraint.collect_facts();
232    let mut domains = HashMap::new();
233
234    for fact_path in all_facts {
235        let domain =
236            extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
237        domains.insert(fact_path, domain);
238    }
239
240    Ok(domains)
241}
242
243fn extract_domain_for_fact(
244    constraint: &Constraint,
245    fact_path: &FactPath,
246) -> LemmaResult<Option<Domain>> {
247    let domain = match constraint {
248        Constraint::True => return Ok(None),
249        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
250
251        Constraint::Comparison { fact, op, value } => {
252            if fact == fact_path {
253                Some(comparison_to_domain(op, value.as_ref())?)
254            } else {
255                None
256            }
257        }
258
259        Constraint::Fact(fp) => {
260            if fp == fact_path {
261                Some(Domain::Enumeration(Arc::new(vec![
262                    LiteralValue::from_bool(true),
263                ])))
264            } else {
265                None
266            }
267        }
268
269        Constraint::And(left, right) => {
270            let left_domain = extract_domain_for_fact(left, fact_path)?;
271            let right_domain = extract_domain_for_fact(right, fact_path)?;
272            match (left_domain, right_domain) {
273                (None, None) => None,
274                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
275                (Some(a), Some(b)) => match domain_intersection(a, b) {
276                    Some(domain) => Some(domain),
277                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
278                },
279            }
280        }
281
282        Constraint::Or(left, right) => {
283            let left_domain = extract_domain_for_fact(left, fact_path)?;
284            let right_domain = extract_domain_for_fact(right, fact_path)?;
285            union_optional_domains(left_domain, right_domain)
286        }
287
288        Constraint::Not(inner) => {
289            // Handle not (fact == value)
290            if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
291                if fact == fact_path && op.is_equal() {
292                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
293                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
294                    )))));
295                }
296            }
297
298            // Handle not (boolean_fact)
299            if let Constraint::Fact(fp) = inner.as_ref() {
300                if fp == fact_path {
301                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
302                        LiteralValue::from_bool(false),
303                    ]))));
304                }
305            }
306
307            extract_domain_for_fact(inner, fact_path)?
308                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
309        }
310    };
311
312    Ok(domain.map(normalize_domain))
313}
314
315fn comparison_to_domain(op: &ComparisonComputation, value: &LiteralValue) -> LemmaResult<Domain> {
316    if op.is_equal() {
317        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
318    }
319    if op.is_not_equal() {
320        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
321            vec![value.clone()],
322        )))));
323    }
324    match op {
325        ComparisonComputation::LessThan => Ok(Domain::Range {
326            min: Bound::Unbounded,
327            max: Bound::Exclusive(Arc::new(value.clone())),
328        }),
329        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
330            min: Bound::Unbounded,
331            max: Bound::Inclusive(Arc::new(value.clone())),
332        }),
333        ComparisonComputation::GreaterThan => Ok(Domain::Range {
334            min: Bound::Exclusive(Arc::new(value.clone())),
335            max: Bound::Unbounded,
336        }),
337        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
338            min: Bound::Inclusive(Arc::new(value.clone())),
339            max: Bound::Unbounded,
340        }),
341        _ => unreachable!(
342            "BUG: unsupported comparison operator for domain extraction: {:?}",
343            op
344        ),
345    }
346}
347
348/// Compute the domain for a single comparison-atom used by inversion constraints.
349///
350/// This is used by numeric-aware constraint simplification to derive implications/exclusions
351/// between comparison atoms on the same fact.
352pub(crate) fn domain_for_comparison_atom(
353    op: &ComparisonComputation,
354    value: &LiteralValue,
355) -> LemmaResult<Domain> {
356    comparison_to_domain(op, value)
357}
358
359impl Domain {
360    /// Proven subset check for the atom-domain forms we generate from comparisons:
361    /// - Range
362    /// - Enumeration
363    /// - Complement(Enumeration) (used for != / is not)
364    ///
365    /// Returns false when the relationship cannot be proven with these forms.
366    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
367        match (self, other) {
368            (Domain::Empty, _) => true,
369            (_, Domain::Unconstrained) => true,
370            (Domain::Unconstrained, _) => false,
371
372            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
373            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
374                vals.iter().all(|v| value_within(v, min, max))
375            }
376
377            (
378                Domain::Range {
379                    min: amin,
380                    max: amax,
381                },
382                Domain::Range {
383                    min: bmin,
384                    max: bmax,
385                },
386            ) => range_within_range(amin, amax, bmin, bmax),
387
388            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
389            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
390                Domain::Enumeration(excluded) => {
391                    excluded.iter().all(|p| !value_within(p, min, max))
392                }
393                _ => false,
394            },
395
396            // {v} ⊆ not({p}) when v is not excluded
397            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
398                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
399                _ => false,
400            },
401
402            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
403            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
404                match (a_inner.as_ref(), b_inner.as_ref()) {
405                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
406                        excluded_b.iter().all(|v| excluded_a.contains(v))
407                    }
408                    _ => false,
409                }
410            }
411
412            _ => false,
413        }
414    }
415}
416
417fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
418    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
419}
420
421fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
422    match (a, b) {
423        (_, Bound::Unbounded) => true,
424        (Bound::Unbounded, _) => false,
425        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
426        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
427        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
428            let c = lit_cmp(av.as_ref(), bv.as_ref());
429            c >= 0
430        }
431        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
432            // a >= (b) only if a's value > b's value
433            lit_cmp(av.as_ref(), bv.as_ref()) > 0
434        }
435    }
436}
437
438fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
439    match (a, b) {
440        (Bound::Unbounded, Bound::Unbounded) => true,
441        (_, Bound::Unbounded) => true,
442        (Bound::Unbounded, _) => false,
443        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
444        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
445        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
446            // (a) <= [b] when a <= b
447            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
448        }
449        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
450            // [a] <= (b) only if a < b
451            lit_cmp(av.as_ref(), bv.as_ref()) < 0
452        }
453    }
454}
455
456fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
457    match (a, b) {
458        (None, None) => None,
459        (Some(d), None) | (None, Some(d)) => Some(d),
460        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
461    }
462}
463
464fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
465    use std::cmp::Ordering;
466
467    match (&a.value, &b.value) {
468        (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
469            Ordering::Less => -1,
470            Ordering::Equal => 0,
471            Ordering::Greater => 1,
472        },
473
474        (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
475            Ordering::Less => -1,
476            Ordering::Equal => 0,
477            Ordering::Greater => 1,
478        },
479
480        (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
481            Ordering::Less => -1,
482            Ordering::Equal => 0,
483            Ordering::Greater => 1,
484        },
485
486        (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
487            Ordering::Less => -1,
488            Ordering::Equal => 0,
489            Ordering::Greater => 1,
490        },
491
492        (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
493            Ordering::Less => -1,
494            Ordering::Equal => 0,
495            Ordering::Greater => 1,
496        },
497
498        (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub)) => {
499            let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
500            let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
501            match a_sec.cmp(&b_sec) {
502                Ordering::Less => -1,
503                Ordering::Equal => 0,
504                Ordering::Greater => 1,
505            }
506        }
507
508        (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
509            Ordering::Less => -1,
510            Ordering::Equal => 0,
511            Ordering::Greater => 1,
512        },
513
514        (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub)) => {
515            if a.lemma_type != b.lemma_type {
516                unreachable!(
517                    "BUG: lit_cmp compared different scale types ({} vs {})",
518                    a.lemma_type.name(),
519                    b.lemma_type.name()
520                );
521            }
522
523            if lua.eq_ignore_ascii_case(lub) {
524                return match la.cmp(lb) {
525                    Ordering::Less => -1,
526                    Ordering::Equal => 0,
527                    Ordering::Greater => 1,
528                };
529            }
530
531            // Convert b to a's unit for comparison
532            let target = SemanticConversionTarget::ScaleUnit(lua.clone());
533            let converted = crate::computation::convert_unit(b, &target);
534            let converted_value = match converted {
535                OperationResult::Value(lit) => match lit.value {
536                    ValueKind::Scale(v, _) => v,
537                    _ => unreachable!("BUG: scale unit conversion returned non-scale value"),
538                },
539                OperationResult::Veto(msg) => {
540                    unreachable!("BUG: scale unit conversion vetoed unexpectedly: {:?}", msg)
541                }
542            };
543
544            match la.cmp(&converted_value) {
545                Ordering::Less => -1,
546                Ordering::Equal => 0,
547                Ordering::Greater => 1,
548            }
549        }
550
551        _ => unreachable!(
552            "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
553            a.get_type(),
554            b.get_type()
555        ),
556    }
557}
558
559fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
560    let ge_min = match min {
561        Bound::Unbounded => true,
562        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
563        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
564    };
565    let le_max = match max {
566        Bound::Unbounded => true,
567        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
568        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
569    };
570    ge_min && le_max
571}
572
573fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
574    match (min, max) {
575        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
576        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
577        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
578        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
579        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
580    }
581}
582
583fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
584    match (min1, min2) {
585        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
586        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
587            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
588                Bound::Inclusive(v1)
589            } else {
590                Bound::Inclusive(v2)
591            }
592        }
593        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
594            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
595                Bound::Inclusive(v1)
596            } else {
597                Bound::Exclusive(v2)
598            }
599        }
600        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
601            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
602                Bound::Exclusive(v1)
603            } else {
604                Bound::Inclusive(v2)
605            }
606        }
607        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
608            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
609                Bound::Exclusive(v1)
610            } else {
611                Bound::Exclusive(v2)
612            }
613        }
614    }
615}
616
617fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
618    match (max1, max2) {
619        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
620        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
621            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
622                Bound::Inclusive(v1)
623            } else {
624                Bound::Inclusive(v2)
625            }
626        }
627        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
628            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
629                Bound::Inclusive(v1)
630            } else {
631                Bound::Exclusive(v2)
632            }
633        }
634        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
635            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
636                Bound::Exclusive(v1)
637            } else {
638                Bound::Inclusive(v2)
639            }
640        }
641        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
642            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
643                Bound::Exclusive(v1)
644            } else {
645                Bound::Exclusive(v2)
646            }
647        }
648    }
649}
650
651fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
652    let a = normalize_domain(a);
653    let b = normalize_domain(b);
654
655    let result = match (a, b) {
656        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
657        (Domain::Empty, _) | (_, Domain::Empty) => None,
658
659        (
660            Domain::Range {
661                min: min1,
662                max: max1,
663            },
664            Domain::Range {
665                min: min2,
666                max: max2,
667            },
668        ) => {
669            let min = compute_intersection_min(min1, min2);
670            let max = compute_intersection_max(max1, max2);
671
672            if bounds_contradict(&min, &max) {
673                None
674            } else {
675                Some(Domain::Range { min, max })
676            }
677        }
678        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
679            let filtered: Vec<LiteralValue> =
680                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
681            if filtered.is_empty() {
682                None
683            } else {
684                Some(Domain::Enumeration(Arc::new(filtered)))
685            }
686        }
687        (Domain::Enumeration(vs), Domain::Range { min, max })
688        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
689            let mut kept = Vec::new();
690            for v in vs.iter() {
691                if value_within(v, &min, &max) {
692                    kept.push(v.clone());
693                }
694            }
695            if kept.is_empty() {
696                None
697            } else {
698                Some(Domain::Enumeration(Arc::new(kept)))
699            }
700        }
701        (Domain::Enumeration(vs), Domain::Complement(inner))
702        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
703            match *inner.clone() {
704                Domain::Enumeration(excluded) => {
705                    let mut kept = Vec::new();
706                    for v in vs.iter() {
707                        if !excluded.contains(v) {
708                            kept.push(v.clone());
709                        }
710                    }
711                    if kept.is_empty() {
712                        None
713                    } else {
714                        Some(Domain::Enumeration(Arc::new(kept)))
715                    }
716                }
717                Domain::Range { min, max } => {
718                    // Filter enumeration values that are NOT in the range
719                    let mut kept = Vec::new();
720                    for v in vs.iter() {
721                        if !value_within(v, &min, &max) {
722                            kept.push(v.clone());
723                        }
724                    }
725                    if kept.is_empty() {
726                        None
727                    } else {
728                        Some(Domain::Enumeration(Arc::new(kept)))
729                    }
730                }
731                _ => {
732                    // For other complement types, normalize and recurse
733                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
734                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
735                }
736            }
737        }
738        (Domain::Union(v1), Domain::Union(v2)) => {
739            let mut acc: Vec<Domain> = Vec::new();
740            for a in v1.iter() {
741                for b in v2.iter() {
742                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
743                        acc.push(ix);
744                    }
745                }
746            }
747            if acc.is_empty() {
748                None
749            } else {
750                Some(Domain::Union(Arc::new(acc)))
751            }
752        }
753        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
754            let mut acc: Vec<Domain> = Vec::new();
755            for a in vs.iter() {
756                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
757                    acc.push(ix);
758                }
759            }
760            if acc.is_empty() {
761                None
762            } else if acc.len() == 1 {
763                Some(acc.remove(0))
764            } else {
765                Some(Domain::Union(Arc::new(acc)))
766            }
767        }
768        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
769        (Domain::Range { min, max }, Domain::Complement(inner))
770        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
771            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
772            _ => {
773                // Normalize the complement (not just the inner value) and recurse.
774                // If normalization doesn't change it, we must not recurse infinitely.
775                let normalized_complement = normalize_domain(Domain::Complement(inner));
776                if matches!(&normalized_complement, Domain::Complement(_)) {
777                    None
778                } else {
779                    domain_intersection(Domain::Range { min, max }, normalized_complement)
780                }
781            }
782        },
783        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
784            match (a_inner.as_ref(), b_inner.as_ref()) {
785                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
786                    // not(A) ∩ not(B) == not(A ∪ B)
787                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
788                    excluded.extend(b_ex.iter().cloned());
789                    Some(normalize_domain(Domain::Complement(Box::new(
790                        Domain::Enumeration(Arc::new(excluded)),
791                    ))))
792                }
793                _ => None,
794            }
795        }
796    };
797    result.map(normalize_domain)
798}
799
800fn range_minus_excluded_points(
801    min: Bound,
802    max: Bound,
803    excluded: &Arc<Vec<LiteralValue>>,
804) -> Option<Domain> {
805    // Start with a single range and iteratively split on excluded points that fall within it.
806    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
807
808    for p in excluded.iter() {
809        let mut next: Vec<(Bound, Bound)> = Vec::new();
810
811        for (rmin, rmax) in parts {
812            if !value_within(p, &rmin, &rmax) {
813                next.push((rmin, rmax));
814                continue;
815            }
816
817            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
818            let left_max = Bound::Exclusive(Arc::new(p.clone()));
819            if !bounds_contradict(&rmin, &left_max) {
820                next.push((rmin.clone(), left_max));
821            }
822
823            // Right part: (p, rmax)
824            let right_min = Bound::Exclusive(Arc::new(p.clone()));
825            if !bounds_contradict(&right_min, &rmax) {
826                next.push((right_min, rmax.clone()));
827            }
828        }
829
830        parts = next;
831        if parts.is_empty() {
832            return None;
833        }
834    }
835
836    if parts.is_empty() {
837        None
838    } else if parts.len() == 1 {
839        let (min, max) = parts.remove(0);
840        Some(Domain::Range { min, max })
841    } else {
842        Some(Domain::Union(Arc::new(
843            parts
844                .into_iter()
845                .map(|(min, max)| Domain::Range { min, max })
846                .collect(),
847        )))
848    }
849}
850
851fn invert_bound(bound: Bound) -> Bound {
852    match bound {
853        Bound::Unbounded => Bound::Unbounded,
854        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
855        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
856    }
857}
858
859fn normalize_domain(d: Domain) -> Domain {
860    match d {
861        Domain::Complement(inner) => {
862            let normalized_inner = normalize_domain(*inner);
863            match normalized_inner {
864                Domain::Complement(double_inner) => *double_inner,
865                Domain::Range { min, max } => match (&min, &max) {
866                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
867                    (Bound::Unbounded, max) => Domain::Range {
868                        min: invert_bound(max.clone()),
869                        max: Bound::Unbounded,
870                    },
871                    (min, Bound::Unbounded) => Domain::Range {
872                        min: Bound::Unbounded,
873                        max: invert_bound(min.clone()),
874                    },
875                    (min, max) => Domain::Union(Arc::new(vec![
876                        Domain::Range {
877                            min: Bound::Unbounded,
878                            max: invert_bound(min.clone()),
879                        },
880                        Domain::Range {
881                            min: invert_bound(max.clone()),
882                            max: Bound::Unbounded,
883                        },
884                    ])),
885                },
886                Domain::Enumeration(vals) => {
887                    if vals.len() == 1 {
888                        if let Some(lit) = vals.first() {
889                            if let ValueKind::Boolean(true) = &lit.value {
890                                return Domain::Enumeration(Arc::new(vec![
891                                    LiteralValue::from_bool(false),
892                                ]));
893                            }
894                            if let ValueKind::Boolean(false) = &lit.value {
895                                return Domain::Enumeration(Arc::new(vec![
896                                    LiteralValue::from_bool(true),
897                                ]));
898                            }
899                        }
900                    }
901                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
902                }
903                Domain::Unconstrained => Domain::Empty,
904                Domain::Empty => Domain::Unconstrained,
905                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
906            }
907        }
908        Domain::Empty => Domain::Empty,
909        Domain::Union(parts) => {
910            let mut flat: Vec<Domain> = Vec::new();
911            for p in parts.iter().cloned() {
912                let normalized = normalize_domain(p);
913                match normalized {
914                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
915                    Domain::Unconstrained => return Domain::Unconstrained,
916                    Domain::Enumeration(vals) if vals.is_empty() => {}
917                    other => flat.push(other),
918                }
919            }
920
921            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
922            let mut ranges: Vec<Domain> = Vec::new();
923            let mut others: Vec<Domain> = Vec::new();
924
925            for domain in flat {
926                match domain {
927                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
928                    Domain::Range { .. } => ranges.push(domain),
929                    other => others.push(other),
930                }
931            }
932
933            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
934                -1 => Ordering::Less,
935                0 => Ordering::Equal,
936                _ => Ordering::Greater,
937            });
938            all_enum_values.dedup();
939
940            all_enum_values.retain(|v| {
941                !ranges.iter().any(|r| {
942                    if let Domain::Range { min, max } = r {
943                        value_within(v, min, max)
944                    } else {
945                        false
946                    }
947                })
948            });
949
950            let mut result: Vec<Domain> = Vec::new();
951            result.extend(ranges);
952            result = merge_ranges(result);
953
954            if !all_enum_values.is_empty() {
955                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
956            }
957            result.extend(others);
958
959            result.sort_by(|a, b| match (a, b) {
960                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
961                (Domain::Range { .. }, _) => Ordering::Less,
962                (_, Domain::Range { .. }) => Ordering::Greater,
963                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
964                (Domain::Enumeration(_), _) => Ordering::Less,
965                (_, Domain::Enumeration(_)) => Ordering::Greater,
966                _ => Ordering::Equal,
967            });
968
969            if result.is_empty() {
970                Domain::Enumeration(Arc::new(vec![]))
971            } else if result.len() == 1 {
972                result.remove(0)
973            } else {
974                Domain::Union(Arc::new(result))
975            }
976        }
977        Domain::Enumeration(values) => {
978            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
979            sorted.sort_by(|a, b| match lit_cmp(a, b) {
980                -1 => Ordering::Less,
981                0 => Ordering::Equal,
982                _ => Ordering::Greater,
983            });
984            sorted.dedup();
985            Domain::Enumeration(Arc::new(sorted))
986        }
987        other => other,
988    }
989}
990
991fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
992    let mut result = Vec::new();
993    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
994    let mut others = Vec::new();
995
996    for d in domains {
997        match d {
998            Domain::Range { min, max } => ranges.push((min, max)),
999            other => others.push(other),
1000        }
1001    }
1002
1003    if ranges.is_empty() {
1004        return others;
1005    }
1006
1007    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1008
1009    let mut merged: Vec<(Bound, Bound)> = Vec::new();
1010    let mut current = ranges[0].clone();
1011
1012    for next in ranges.iter().skip(1) {
1013        if ranges_adjacent_or_overlap(&current, next) {
1014            current = (
1015                min_bound(&current.0, &next.0),
1016                max_bound(&current.1, &next.1),
1017            );
1018        } else {
1019            merged.push(current);
1020            current = next.clone();
1021        }
1022    }
1023    merged.push(current);
1024
1025    for (min, max) in merged {
1026        result.push(Domain::Range { min, max });
1027    }
1028    result.extend(others);
1029
1030    result
1031}
1032
1033fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1034    match (a, b) {
1035        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1036        (Bound::Unbounded, _) => Ordering::Less,
1037        (_, Bound::Unbounded) => Ordering::Greater,
1038        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1039        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1040            -1 => Ordering::Less,
1041            0 => Ordering::Equal,
1042            _ => Ordering::Greater,
1043        },
1044        (Bound::Inclusive(v1), Bound::Exclusive(v2))
1045        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1046            -1 => Ordering::Less,
1047            0 => {
1048                if matches!(a, Bound::Inclusive(_)) {
1049                    Ordering::Less
1050                } else {
1051                    Ordering::Greater
1052                }
1053            }
1054            _ => Ordering::Greater,
1055        },
1056    }
1057}
1058
1059fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1060    match (&r1.1, &r2.0) {
1061        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1062        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1063        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1064        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1065        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1066    }
1067}
1068
1069fn min_bound(a: &Bound, b: &Bound) -> Bound {
1070    match (a, b) {
1071        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1072        _ => {
1073            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1074                a.clone()
1075            } else {
1076                b.clone()
1077            }
1078        }
1079    }
1080}
1081
1082fn max_bound(a: &Bound, b: &Bound) -> Bound {
1083    match (a, b) {
1084        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1085        _ => {
1086            if matches!(compare_bounds(a, b), Ordering::Greater) {
1087                a.clone()
1088            } else {
1089                b.clone()
1090            }
1091        }
1092    }
1093}
1094
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Build a numeric literal from an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Build a fact path with no enclosing segments.
    fn fact(name: &str) -> FactPath {
        FactPath::new(vec![], name.to_string())
    }

    /// Inclusive bound at integer `n`.
    fn incl(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    /// Exclusive bound at integer `n`.
    fn excl(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Comparison constraint on fact `name`.
    fn compare(name: &str, op: ComparisonComputation, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            fact: fact(name),
            op,
            value: Arc::new(value),
        }
    }

    /// Extract domains from `constraint` and assert the one recorded for `name`.
    fn assert_domain(constraint: &Constraint, name: &str, expected: Domain) {
        let domains = extract_domains_from_constraint(constraint).unwrap();
        let actual = domains.get(&fact(name)).unwrap();
        assert_eq!(*actual, expected);
    }

    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let wrapped = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        assert_eq!(normalize_domain(wrapped), inner);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: incl(0),
                max: incl(10),
            },
            Domain::Unconstrained,
        ]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: incl(10),
            max: excl(20),
        };
        assert_eq!(range.to_string(), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new((1..=3).map(num).collect()));
        assert_eq!(enumeration.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = compare("age", ComparisonComputation::GreaterThan, num(18));
        assert_domain(
            &constraint,
            "age",
            Domain::Range {
                min: excl(18),
                max: Bound::Unbounded,
            },
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(compare("age", ComparisonComputation::GreaterThan, num(18))),
            Box::new(compare("age", ComparisonComputation::LessThan, num(65))),
        );
        assert_domain(
            &constraint,
            "age",
            Domain::Range {
                min: excl(18),
                max: excl(65),
            },
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let active = || LiteralValue::text("active".to_string());
        let constraint = compare("status", ComparisonComputation::Equal, active());
        assert_domain(
            &constraint,
            "status",
            Domain::Enumeration(Arc::new(vec![active()])),
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));
        assert_domain(
            &constraint,
            "is_active",
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(true)])),
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));
        assert_domain(
            &constraint,
            "is_active",
            Domain::Enumeration(Arc::new(vec![LiteralValue::from_bool(false)])),
        );
    }
}