Skip to main content

lemma/inversion/
domain.rs

1//! Domain types and operations for inversion
2//!
3//! Provides:
4//! - `Domain` and `Bound` types for representing concrete value constraints
5//! - Domain operations: intersection, union, normalization
6//! - `extract_domains_from_constraint()`: extracts domains from constraints
7
8use crate::{
9    BooleanValue, ComparisonComputation, FactPath, LemmaResult, LiteralValue, OperationResult,
10    Value,
11};
12use serde::ser::{Serialize, SerializeStruct, Serializer};
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt;
16use std::sync::Arc;
17
18use super::constraint::Constraint;
19
/// Domain specification for valid values
///
/// Describes the set of values a fact may take. Note that more than one variant can
/// describe "no values": besides `Empty`, an `Enumeration` with no entries and a `Range`
/// whose bounds contradict are also unsatisfiable (see `Domain::is_satisfiable`).
#[derive(Debug, Clone, PartialEq)]
pub enum Domain {
    /// A single continuous range between a lower (`min`) and upper (`max`) bound
    Range { min: Bound, max: Bound },

    /// Union of the member domains. Members are arbitrary `Domain`s; the type itself does
    /// not enforce that they are ranges or that they are disjoint.
    Union(Arc<Vec<Domain>>),

    /// Specific enumerated values only (an empty list admits no value)
    Enumeration(Arc<Vec<LiteralValue>>),

    /// Every value that is NOT contained in the inner domain
    Complement(Box<Domain>),

    /// Any value (no constraints)
    Unconstrained,

    /// Empty domain (no valid values) - represents unsatisfiable constraints
    Empty,
}
41
42impl Domain {
43    /// Check if this domain is satisfiable (has at least one valid value)
44    ///
45    /// Returns false for Empty domains and empty Enumerations.
46    pub fn is_satisfiable(&self) -> bool {
47        match self {
48            Domain::Empty => false,
49            Domain::Enumeration(values) => !values.is_empty(),
50            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
51            Domain::Range { min, max } => !bounds_contradict(min, max),
52            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
53            Domain::Unconstrained => true,
54        }
55    }
56
57    /// Check if this domain is empty (unsatisfiable)
58    pub fn is_empty(&self) -> bool {
59        !self.is_satisfiable()
60    }
61
62    /// Intersect this domain with another, returning Empty if no overlap
63    pub fn intersect(&self, other: &Domain) -> Domain {
64        domain_intersection(self.clone(), other.clone()).unwrap_or(Domain::Empty)
65    }
66
67    /// Check if a value is contained in this domain
68    pub fn contains(&self, value: &LiteralValue) -> bool {
69        match self {
70            Domain::Empty => false,
71            Domain::Unconstrained => true,
72            Domain::Enumeration(values) => values.contains(value),
73            Domain::Range { min, max } => value_within(value, min, max),
74            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
75            Domain::Complement(inner) => !inner.contains(value),
76        }
77    }
78}
79
/// Bound specification for ranges
///
/// Used for both ends of a `Domain::Range`; whether a bound acts as a lower or an upper
/// bound is determined by the field it occupies (`min` or `max`).
#[derive(Debug, Clone, PartialEq)]
pub enum Bound {
    /// Inclusive bound: the value itself belongs to the range
    /// (written `[value` as a lower bound, `value]` as an upper bound)
    Inclusive(Arc<LiteralValue>),

    /// Exclusive bound: the value itself does not belong to the range
    /// (written `(value` as a lower bound, `value)` as an upper bound)
    Exclusive(Arc<LiteralValue>),

    /// Unbounded (-infinity as a lower bound, +infinity as an upper bound)
    Unbounded,
}
92
93impl fmt::Display for Domain {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            Domain::Empty => write!(f, "empty"),
97            Domain::Unconstrained => write!(f, "any"),
98            Domain::Enumeration(vals) => {
99                write!(f, "{{")?;
100                for (i, v) in vals.iter().enumerate() {
101                    if i > 0 {
102                        write!(f, ", ")?;
103                    }
104                    write!(f, "{}", v)?;
105                }
106                write!(f, "}}")
107            }
108            Domain::Range { min, max } => {
109                let (l_bracket, r_bracket) = match (min, max) {
110                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
111                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
112                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
113                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
114                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
115                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
116                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
117                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
118                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
119                };
120
121                let min_str = match min {
122                    Bound::Unbounded => "-inf".to_string(),
123                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
124                };
125                let max_str = match max {
126                    Bound::Unbounded => "+inf".to_string(),
127                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
128                };
129                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
130            }
131            Domain::Union(parts) => {
132                for (i, p) in parts.iter().enumerate() {
133                    if i > 0 {
134                        write!(f, " | ")?;
135                    }
136                    write!(f, "{}", p)?;
137                }
138                Ok(())
139            }
140            Domain::Complement(inner) => write!(f, "not ({})", inner),
141        }
142    }
143}
144
145impl fmt::Display for Bound {
146    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147        match self {
148            Bound::Unbounded => write!(f, "inf"),
149            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
150            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
151        }
152    }
153}
154
155impl Serialize for Domain {
156    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
157    where
158        S: Serializer,
159    {
160        match self {
161            Domain::Empty => {
162                let mut st = serializer.serialize_struct("domain", 1)?;
163                st.serialize_field("type", "empty")?;
164                st.end()
165            }
166            Domain::Unconstrained => {
167                let mut st = serializer.serialize_struct("domain", 1)?;
168                st.serialize_field("type", "unconstrained")?;
169                st.end()
170            }
171            Domain::Enumeration(vals) => {
172                let mut st = serializer.serialize_struct("domain", 2)?;
173                st.serialize_field("type", "enumeration")?;
174                st.serialize_field("values", vals)?;
175                st.end()
176            }
177            Domain::Range { min, max } => {
178                let mut st = serializer.serialize_struct("domain", 3)?;
179                st.serialize_field("type", "range")?;
180                st.serialize_field("min", min)?;
181                st.serialize_field("max", max)?;
182                st.end()
183            }
184            Domain::Union(parts) => {
185                let mut st = serializer.serialize_struct("domain", 2)?;
186                st.serialize_field("type", "union")?;
187                st.serialize_field("parts", parts)?;
188                st.end()
189            }
190            Domain::Complement(inner) => {
191                let mut st = serializer.serialize_struct("domain", 2)?;
192                st.serialize_field("type", "complement")?;
193                st.serialize_field("inner", inner)?;
194                st.end()
195            }
196        }
197    }
198}
199
200impl Serialize for Bound {
201    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
202    where
203        S: Serializer,
204    {
205        match self {
206            Bound::Unbounded => {
207                let mut st = serializer.serialize_struct("bound", 1)?;
208                st.serialize_field("type", "unbounded")?;
209                st.end()
210            }
211            Bound::Inclusive(v) => {
212                let mut st = serializer.serialize_struct("bound", 2)?;
213                st.serialize_field("type", "inclusive")?;
214                st.serialize_field("value", v.as_ref())?;
215                st.end()
216            }
217            Bound::Exclusive(v) => {
218                let mut st = serializer.serialize_struct("bound", 2)?;
219                st.serialize_field("type", "exclusive")?;
220                st.serialize_field("value", v.as_ref())?;
221                st.end()
222            }
223        }
224    }
225}
226
227/// Extract domains for all facts mentioned in a constraint
228pub fn extract_domains_from_constraint(
229    constraint: &Constraint,
230) -> LemmaResult<HashMap<FactPath, Domain>> {
231    let all_facts = constraint.collect_facts();
232    let mut domains = HashMap::new();
233
234    for fact_path in all_facts {
235        let domain =
236            extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
237        domains.insert(fact_path, domain);
238    }
239
240    Ok(domains)
241}
242
243fn extract_domain_for_fact(
244    constraint: &Constraint,
245    fact_path: &FactPath,
246) -> LemmaResult<Option<Domain>> {
247    let domain = match constraint {
248        Constraint::True => return Ok(None),
249        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
250
251        Constraint::Comparison { fact, op, value } => {
252            if fact == fact_path {
253                Some(comparison_to_domain(op, value.as_ref())?)
254            } else {
255                None
256            }
257        }
258
259        Constraint::Fact(fp) => {
260            if fp == fact_path {
261                Some(Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
262                    BooleanValue::True,
263                )])))
264            } else {
265                None
266            }
267        }
268
269        Constraint::And(left, right) => {
270            let left_domain = extract_domain_for_fact(left, fact_path)?;
271            let right_domain = extract_domain_for_fact(right, fact_path)?;
272            match (left_domain, right_domain) {
273                (None, None) => None,
274                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
275                (Some(a), Some(b)) => match domain_intersection(a, b) {
276                    Some(domain) => Some(domain),
277                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
278                },
279            }
280        }
281
282        Constraint::Or(left, right) => {
283            let left_domain = extract_domain_for_fact(left, fact_path)?;
284            let right_domain = extract_domain_for_fact(right, fact_path)?;
285            union_optional_domains(left_domain, right_domain)
286        }
287
288        Constraint::Not(inner) => {
289            // Handle not (fact == value)
290            if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
291                if fact == fact_path && op.is_equal() {
292                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
293                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
294                    )))));
295                }
296            }
297
298            // Handle not (boolean_fact)
299            if let Constraint::Fact(fp) = inner.as_ref() {
300                if fp == fact_path {
301                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
302                        LiteralValue::boolean(BooleanValue::False),
303                    ]))));
304                }
305            }
306
307            extract_domain_for_fact(inner, fact_path)?
308                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
309        }
310    };
311
312    Ok(domain.map(normalize_domain))
313}
314
315fn comparison_to_domain(op: &ComparisonComputation, value: &LiteralValue) -> LemmaResult<Domain> {
316    if op.is_equal() {
317        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
318    }
319    if op.is_not_equal() {
320        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
321            vec![value.clone()],
322        )))));
323    }
324    match op {
325        ComparisonComputation::LessThan => Ok(Domain::Range {
326            min: Bound::Unbounded,
327            max: Bound::Exclusive(Arc::new(value.clone())),
328        }),
329        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
330            min: Bound::Unbounded,
331            max: Bound::Inclusive(Arc::new(value.clone())),
332        }),
333        ComparisonComputation::GreaterThan => Ok(Domain::Range {
334            min: Bound::Exclusive(Arc::new(value.clone())),
335            max: Bound::Unbounded,
336        }),
337        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
338            min: Bound::Inclusive(Arc::new(value.clone())),
339            max: Bound::Unbounded,
340        }),
341        _ => unreachable!(
342            "BUG: unsupported comparison operator for domain extraction: {:?}",
343            op
344        ),
345    }
346}
347
348/// Compute the domain for a single comparison-atom used by inversion constraints.
349///
350/// This is used by numeric-aware constraint simplification to derive implications/exclusions
351/// between comparison atoms on the same fact.
352pub(crate) fn domain_for_comparison_atom(
353    op: &ComparisonComputation,
354    value: &LiteralValue,
355) -> LemmaResult<Domain> {
356    comparison_to_domain(op, value)
357}
358
359impl Domain {
360    /// Proven subset check for the atom-domain forms we generate from comparisons:
361    /// - Range
362    /// - Enumeration
363    /// - Complement(Enumeration) (used for != / is not)
364    ///
365    /// Returns false when the relationship cannot be proven with these forms.
366    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
367        match (self, other) {
368            (Domain::Empty, _) => true,
369            (_, Domain::Unconstrained) => true,
370            (Domain::Unconstrained, _) => false,
371
372            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
373            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
374                vals.iter().all(|v| value_within(v, min, max))
375            }
376
377            (
378                Domain::Range {
379                    min: amin,
380                    max: amax,
381                },
382                Domain::Range {
383                    min: bmin,
384                    max: bmax,
385                },
386            ) => range_within_range(amin, amax, bmin, bmax),
387
388            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
389            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
390                Domain::Enumeration(excluded) => {
391                    excluded.iter().all(|p| !value_within(p, min, max))
392                }
393                _ => false,
394            },
395
396            // {v} ⊆ not({p}) when v is not excluded
397            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
398                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
399                _ => false,
400            },
401
402            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
403            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
404                match (a_inner.as_ref(), b_inner.as_ref()) {
405                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
406                        excluded_b.iter().all(|v| excluded_a.contains(v))
407                    }
408                    _ => false,
409                }
410            }
411
412            _ => false,
413        }
414    }
415}
416
417fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
418    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
419}
420
421fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
422    match (a, b) {
423        (_, Bound::Unbounded) => true,
424        (Bound::Unbounded, _) => false,
425        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
426        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
427        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
428            let c = lit_cmp(av.as_ref(), bv.as_ref());
429            c >= 0
430        }
431        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
432            // a >= (b) only if a's value > b's value
433            lit_cmp(av.as_ref(), bv.as_ref()) > 0
434        }
435    }
436}
437
438fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
439    match (a, b) {
440        (Bound::Unbounded, Bound::Unbounded) => true,
441        (_, Bound::Unbounded) => true,
442        (Bound::Unbounded, _) => false,
443        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
444        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
445        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
446            // (a) <= [b] when a <= b
447            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
448        }
449        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
450            // [a] <= (b) only if a < b
451            lit_cmp(av.as_ref(), bv.as_ref()) < 0
452        }
453    }
454}
455
456fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
457    match (a, b) {
458        (None, None) => None,
459        (Some(d), None) | (None, Some(d)) => Some(d),
460        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
461    }
462}
463
464fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
465    use std::cmp::Ordering;
466
467    match (&a.value, &b.value) {
468        (Value::Number(la), Value::Number(lb)) => match la.cmp(lb) {
469            Ordering::Less => -1,
470            Ordering::Equal => 0,
471            Ordering::Greater => 1,
472        },
473
474        (Value::Boolean(la), Value::Boolean(lb)) => match bool::from(la).cmp(&bool::from(lb)) {
475            Ordering::Less => -1,
476            Ordering::Equal => 0,
477            Ordering::Greater => 1,
478        },
479
480        (Value::Text(la), Value::Text(lb)) => match la.cmp(lb) {
481            Ordering::Less => -1,
482            Ordering::Equal => 0,
483            Ordering::Greater => 1,
484        },
485
486        (Value::Date(la), Value::Date(lb)) => match la.cmp(lb) {
487            Ordering::Less => -1,
488            Ordering::Equal => 0,
489            Ordering::Greater => 1,
490        },
491
492        (Value::Time(la), Value::Time(lb)) => match la.cmp(lb) {
493            Ordering::Less => -1,
494            Ordering::Equal => 0,
495            Ordering::Greater => 1,
496        },
497
498        (Value::Duration(la, lua), Value::Duration(lb, lub)) => {
499            let a_sec = crate::computation::units::duration_to_seconds(*la, lua);
500            let b_sec = crate::computation::units::duration_to_seconds(*lb, lub);
501            match a_sec.cmp(&b_sec) {
502                Ordering::Less => -1,
503                Ordering::Equal => 0,
504                Ordering::Greater => 1,
505            }
506        }
507
508        (Value::Ratio(la, _), Value::Ratio(lb, _)) => match la.cmp(lb) {
509            Ordering::Less => -1,
510            Ordering::Equal => 0,
511            Ordering::Greater => 1,
512        },
513
514        (Value::Scale(la, lua), Value::Scale(lb, lub)) => {
515            if a.lemma_type != b.lemma_type {
516                unreachable!(
517                    "BUG: lit_cmp compared different scale types ({} vs {})",
518                    a.lemma_type.name(),
519                    b.lemma_type.name()
520                );
521            }
522
523            match (lua, lub) {
524                (None, None) => match la.cmp(lb) {
525                    Ordering::Less => -1,
526                    Ordering::Equal => 0,
527                    Ordering::Greater => 1,
528                },
529                (Some(lua), Some(lub)) => {
530                    if lua.eq_ignore_ascii_case(lub) {
531                        return match la.cmp(lb) {
532                            Ordering::Less => -1,
533                            Ordering::Equal => 0,
534                            Ordering::Greater => 1,
535                        };
536                    }
537
538                    let target = crate::semantic::ConversionTarget::ScaleUnit(lua.clone());
539                    let converted = crate::computation::convert_unit(b, &target);
540                    let converted_value = match converted {
541                        OperationResult::Value(lit) => match lit.value {
542                            Value::Scale(v, _) => v,
543                            _ => {
544                                unreachable!("BUG: scale unit conversion returned non-scale value")
545                            }
546                        },
547                        OperationResult::Veto(msg) => unreachable!(
548                            "BUG: scale unit conversion vetoed unexpectedly: {:?}",
549                            msg
550                        ),
551                    };
552
553                    match la.cmp(&converted_value) {
554                        Ordering::Less => -1,
555                        Ordering::Equal => 0,
556                        Ordering::Greater => 1,
557                    }
558                }
559                _ => unreachable!("BUG: lit_cmp saw scale value missing unit on one side"),
560            }
561        }
562
563        _ => unreachable!(
564            "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
565            a.get_type(),
566            b.get_type()
567        ),
568    }
569}
570
571fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
572    let ge_min = match min {
573        Bound::Unbounded => true,
574        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
575        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
576    };
577    let le_max = match max {
578        Bound::Unbounded => true,
579        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
580        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
581    };
582    ge_min && le_max
583}
584
585fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
586    match (min, max) {
587        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
588        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
589        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
590        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
591        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
592    }
593}
594
595fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
596    match (min1, min2) {
597        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
598        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
599            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
600                Bound::Inclusive(v1)
601            } else {
602                Bound::Inclusive(v2)
603            }
604        }
605        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
606            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
607                Bound::Inclusive(v1)
608            } else {
609                Bound::Exclusive(v2)
610            }
611        }
612        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
613            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
614                Bound::Exclusive(v1)
615            } else {
616                Bound::Inclusive(v2)
617            }
618        }
619        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
620            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
621                Bound::Exclusive(v1)
622            } else {
623                Bound::Exclusive(v2)
624            }
625        }
626    }
627}
628
629fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
630    match (max1, max2) {
631        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
632        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
633            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
634                Bound::Inclusive(v1)
635            } else {
636                Bound::Inclusive(v2)
637            }
638        }
639        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
640            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
641                Bound::Inclusive(v1)
642            } else {
643                Bound::Exclusive(v2)
644            }
645        }
646        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
647            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
648                Bound::Exclusive(v1)
649            } else {
650                Bound::Inclusive(v2)
651            }
652        }
653        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
654            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
655                Bound::Exclusive(v1)
656            } else {
657                Bound::Exclusive(v2)
658            }
659        }
660    }
661}
662
663fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
664    let a = normalize_domain(a);
665    let b = normalize_domain(b);
666
667    let result = match (a, b) {
668        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
669        (Domain::Empty, _) | (_, Domain::Empty) => None,
670
671        (
672            Domain::Range {
673                min: min1,
674                max: max1,
675            },
676            Domain::Range {
677                min: min2,
678                max: max2,
679            },
680        ) => {
681            let min = compute_intersection_min(min1, min2);
682            let max = compute_intersection_max(max1, max2);
683
684            if bounds_contradict(&min, &max) {
685                None
686            } else {
687                Some(Domain::Range { min, max })
688            }
689        }
690        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
691            let filtered: Vec<LiteralValue> =
692                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
693            if filtered.is_empty() {
694                None
695            } else {
696                Some(Domain::Enumeration(Arc::new(filtered)))
697            }
698        }
699        (Domain::Enumeration(vs), Domain::Range { min, max })
700        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
701            let mut kept = Vec::new();
702            for v in vs.iter() {
703                if value_within(v, &min, &max) {
704                    kept.push(v.clone());
705                }
706            }
707            if kept.is_empty() {
708                None
709            } else {
710                Some(Domain::Enumeration(Arc::new(kept)))
711            }
712        }
713        (Domain::Enumeration(vs), Domain::Complement(inner))
714        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
715            match *inner.clone() {
716                Domain::Enumeration(excluded) => {
717                    let mut kept = Vec::new();
718                    for v in vs.iter() {
719                        if !excluded.contains(v) {
720                            kept.push(v.clone());
721                        }
722                    }
723                    if kept.is_empty() {
724                        None
725                    } else {
726                        Some(Domain::Enumeration(Arc::new(kept)))
727                    }
728                }
729                Domain::Range { min, max } => {
730                    // Filter enumeration values that are NOT in the range
731                    let mut kept = Vec::new();
732                    for v in vs.iter() {
733                        if !value_within(v, &min, &max) {
734                            kept.push(v.clone());
735                        }
736                    }
737                    if kept.is_empty() {
738                        None
739                    } else {
740                        Some(Domain::Enumeration(Arc::new(kept)))
741                    }
742                }
743                _ => {
744                    // For other complement types, normalize and recurse
745                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
746                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
747                }
748            }
749        }
750        (Domain::Union(v1), Domain::Union(v2)) => {
751            let mut acc: Vec<Domain> = Vec::new();
752            for a in v1.iter() {
753                for b in v2.iter() {
754                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
755                        acc.push(ix);
756                    }
757                }
758            }
759            if acc.is_empty() {
760                None
761            } else {
762                Some(Domain::Union(Arc::new(acc)))
763            }
764        }
765        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
766            let mut acc: Vec<Domain> = Vec::new();
767            for a in vs.iter() {
768                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
769                    acc.push(ix);
770                }
771            }
772            if acc.is_empty() {
773                None
774            } else if acc.len() == 1 {
775                Some(acc.remove(0))
776            } else {
777                Some(Domain::Union(Arc::new(acc)))
778            }
779        }
780        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
781        (Domain::Range { min, max }, Domain::Complement(inner))
782        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
783            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
784            _ => {
785                // Normalize the complement (not just the inner value) and recurse.
786                // If normalization doesn't change it, we must not recurse infinitely.
787                let normalized_complement = normalize_domain(Domain::Complement(inner));
788                if matches!(&normalized_complement, Domain::Complement(_)) {
789                    None
790                } else {
791                    domain_intersection(Domain::Range { min, max }, normalized_complement)
792                }
793            }
794        },
795        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
796            match (a_inner.as_ref(), b_inner.as_ref()) {
797                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
798                    // not(A) ∩ not(B) == not(A ∪ B)
799                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
800                    excluded.extend(b_ex.iter().cloned());
801                    Some(normalize_domain(Domain::Complement(Box::new(
802                        Domain::Enumeration(Arc::new(excluded)),
803                    ))))
804                }
805                _ => None,
806            }
807        }
808    };
809    result.map(normalize_domain)
810}
811
812fn range_minus_excluded_points(
813    min: Bound,
814    max: Bound,
815    excluded: &Arc<Vec<LiteralValue>>,
816) -> Option<Domain> {
817    // Start with a single range and iteratively split on excluded points that fall within it.
818    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
819
820    for p in excluded.iter() {
821        let mut next: Vec<(Bound, Bound)> = Vec::new();
822
823        for (rmin, rmax) in parts {
824            if !value_within(p, &rmin, &rmax) {
825                next.push((rmin, rmax));
826                continue;
827            }
828
829            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
830            let left_max = Bound::Exclusive(Arc::new(p.clone()));
831            if !bounds_contradict(&rmin, &left_max) {
832                next.push((rmin.clone(), left_max));
833            }
834
835            // Right part: (p, rmax)
836            let right_min = Bound::Exclusive(Arc::new(p.clone()));
837            if !bounds_contradict(&right_min, &rmax) {
838                next.push((right_min, rmax.clone()));
839            }
840        }
841
842        parts = next;
843        if parts.is_empty() {
844            return None;
845        }
846    }
847
848    if parts.is_empty() {
849        None
850    } else if parts.len() == 1 {
851        let (min, max) = parts.remove(0);
852        Some(Domain::Range { min, max })
853    } else {
854        Some(Domain::Union(Arc::new(
855            parts
856                .into_iter()
857                .map(|(min, max)| Domain::Range { min, max })
858                .collect(),
859        )))
860    }
861}
862
863fn invert_bound(bound: Bound) -> Bound {
864    match bound {
865        Bound::Unbounded => Bound::Unbounded,
866        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
867        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
868    }
869}
870
871fn normalize_domain(d: Domain) -> Domain {
872    match d {
873        Domain::Complement(inner) => {
874            let normalized_inner = normalize_domain(*inner);
875            match normalized_inner {
876                Domain::Complement(double_inner) => *double_inner,
877                Domain::Range { min, max } => match (&min, &max) {
878                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
879                    (Bound::Unbounded, max) => Domain::Range {
880                        min: invert_bound(max.clone()),
881                        max: Bound::Unbounded,
882                    },
883                    (min, Bound::Unbounded) => Domain::Range {
884                        min: Bound::Unbounded,
885                        max: invert_bound(min.clone()),
886                    },
887                    (min, max) => Domain::Union(Arc::new(vec![
888                        Domain::Range {
889                            min: Bound::Unbounded,
890                            max: invert_bound(min.clone()),
891                        },
892                        Domain::Range {
893                            min: invert_bound(max.clone()),
894                            max: Bound::Unbounded,
895                        },
896                    ])),
897                },
898                Domain::Enumeration(vals) => {
899                    if vals.len() == 1 {
900                        if let Some(lit) = vals.first() {
901                            if let Value::Boolean(BooleanValue::True) = &lit.value {
902                                return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
903                                    BooleanValue::False,
904                                )]));
905                            }
906                            if let Value::Boolean(BooleanValue::False) = &lit.value {
907                                return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
908                                    BooleanValue::True,
909                                )]));
910                            }
911                        }
912                    }
913                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
914                }
915                Domain::Unconstrained => Domain::Empty,
916                Domain::Empty => Domain::Unconstrained,
917                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
918            }
919        }
920        Domain::Empty => Domain::Empty,
921        Domain::Union(parts) => {
922            let mut flat: Vec<Domain> = Vec::new();
923            for p in parts.iter().cloned() {
924                let normalized = normalize_domain(p);
925                match normalized {
926                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
927                    Domain::Unconstrained => return Domain::Unconstrained,
928                    Domain::Enumeration(vals) if vals.is_empty() => {}
929                    other => flat.push(other),
930                }
931            }
932
933            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
934            let mut ranges: Vec<Domain> = Vec::new();
935            let mut others: Vec<Domain> = Vec::new();
936
937            for domain in flat {
938                match domain {
939                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
940                    Domain::Range { .. } => ranges.push(domain),
941                    other => others.push(other),
942                }
943            }
944
945            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
946                -1 => Ordering::Less,
947                0 => Ordering::Equal,
948                _ => Ordering::Greater,
949            });
950            all_enum_values.dedup();
951
952            all_enum_values.retain(|v| {
953                !ranges.iter().any(|r| {
954                    if let Domain::Range { min, max } = r {
955                        value_within(v, min, max)
956                    } else {
957                        false
958                    }
959                })
960            });
961
962            let mut result: Vec<Domain> = Vec::new();
963            result.extend(ranges);
964            result = merge_ranges(result);
965
966            if !all_enum_values.is_empty() {
967                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
968            }
969            result.extend(others);
970
971            result.sort_by(|a, b| match (a, b) {
972                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
973                (Domain::Range { .. }, _) => Ordering::Less,
974                (_, Domain::Range { .. }) => Ordering::Greater,
975                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
976                (Domain::Enumeration(_), _) => Ordering::Less,
977                (_, Domain::Enumeration(_)) => Ordering::Greater,
978                _ => Ordering::Equal,
979            });
980
981            if result.is_empty() {
982                Domain::Enumeration(Arc::new(vec![]))
983            } else if result.len() == 1 {
984                result.remove(0)
985            } else {
986                Domain::Union(Arc::new(result))
987            }
988        }
989        Domain::Enumeration(values) => {
990            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
991            sorted.sort_by(|a, b| match lit_cmp(a, b) {
992                -1 => Ordering::Less,
993                0 => Ordering::Equal,
994                _ => Ordering::Greater,
995            });
996            sorted.dedup();
997            Domain::Enumeration(Arc::new(sorted))
998        }
999        other => other,
1000    }
1001}
1002
1003fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
1004    let mut result = Vec::new();
1005    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1006    let mut others = Vec::new();
1007
1008    for d in domains {
1009        match d {
1010            Domain::Range { min, max } => ranges.push((min, max)),
1011            other => others.push(other),
1012        }
1013    }
1014
1015    if ranges.is_empty() {
1016        return others;
1017    }
1018
1019    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1020
1021    let mut merged: Vec<(Bound, Bound)> = Vec::new();
1022    let mut current = ranges[0].clone();
1023
1024    for next in ranges.iter().skip(1) {
1025        if ranges_adjacent_or_overlap(&current, next) {
1026            current = (
1027                min_bound(&current.0, &next.0),
1028                max_bound(&current.1, &next.1),
1029            );
1030        } else {
1031            merged.push(current);
1032            current = next.clone();
1033        }
1034    }
1035    merged.push(current);
1036
1037    for (min, max) in merged {
1038        result.push(Domain::Range { min, max });
1039    }
1040    result.extend(others);
1041
1042    result
1043}
1044
1045fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1046    match (a, b) {
1047        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1048        (Bound::Unbounded, _) => Ordering::Less,
1049        (_, Bound::Unbounded) => Ordering::Greater,
1050        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1051        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1052            -1 => Ordering::Less,
1053            0 => Ordering::Equal,
1054            _ => Ordering::Greater,
1055        },
1056        (Bound::Inclusive(v1), Bound::Exclusive(v2))
1057        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1058            -1 => Ordering::Less,
1059            0 => {
1060                if matches!(a, Bound::Inclusive(_)) {
1061                    Ordering::Less
1062                } else {
1063                    Ordering::Greater
1064                }
1065            }
1066            _ => Ordering::Greater,
1067        },
1068    }
1069}
1070
1071fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1072    match (&r1.1, &r2.0) {
1073        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1074        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1075        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1076        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1077        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1078    }
1079}
1080
1081fn min_bound(a: &Bound, b: &Bound) -> Bound {
1082    match (a, b) {
1083        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1084        _ => {
1085            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1086                a.clone()
1087            } else {
1088                b.clone()
1089            }
1090        }
1091    }
1092}
1093
1094fn max_bound(a: &Bound, b: &Bound) -> Bound {
1095    match (a, b) {
1096        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1097        _ => {
1098            if matches!(compare_bounds(a, b), Ordering::Greater) {
1099                a.clone()
1100            } else {
1101                b.clone()
1102            }
1103        }
1104    }
1105}
1106
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Numeric literal shorthand.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Local fact path shorthand.
    fn fact(name: &str) -> FactPath {
        FactPath::local(name.to_string())
    }

    /// Inclusive bound over a numeric literal.
    fn incl(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    /// Exclusive bound over a numeric literal.
    fn excl(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Runs domain extraction on `constraint` and returns a copy of the domain recorded
    /// for the local fact `name`. Panics if extraction fails or the fact is absent.
    fn domain_for(constraint: &Constraint, name: &str) -> Domain {
        let domains = extract_domains_from_constraint(constraint).unwrap();
        domains.get(&fact(name)).cloned().unwrap()
    }

    #[test]
    fn test_normalize_double_complement() {
        let five = Domain::Enumeration(Arc::new(vec![num(5)]));
        let wrapped = Domain::Complement(Box::new(Domain::Complement(Box::new(five.clone()))));
        assert_eq!(normalize_domain(wrapped), five);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let bounded = Domain::Range {
            min: incl(0),
            max: incl(10),
        };
        let union = Domain::Union(Arc::new(vec![bounded, Domain::Unconstrained]));
        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let half_open = Domain::Range {
            min: incl(10),
            max: excl(20),
        };
        assert_eq!(half_open.to_string(), "[10, 20)");

        let listed = Domain::Enumeration(Arc::new((1..=3).map(num).collect()));
        assert_eq!(listed.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let over_18 = Constraint::Comparison {
            fact: fact("age"),
            op: ComparisonComputation::GreaterThan,
            value: Arc::new(num(18)),
        };

        let expected = Domain::Range {
            min: excl(18),
            max: Bound::Unbounded,
        };
        assert_eq!(domain_for(&over_18, "age"), expected);
    }

    #[test]
    fn test_extract_domain_from_and() {
        let lower = Constraint::Comparison {
            fact: fact("age"),
            op: ComparisonComputation::GreaterThan,
            value: Arc::new(num(18)),
        };
        let upper = Constraint::Comparison {
            fact: fact("age"),
            op: ComparisonComputation::LessThan,
            value: Arc::new(num(65)),
        };
        let between = Constraint::And(Box::new(lower), Box::new(upper));

        let expected = Domain::Range {
            min: excl(18),
            max: excl(65),
        };
        assert_eq!(domain_for(&between, "age"), expected);
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let is_active = Constraint::Comparison {
            fact: fact("status"),
            op: ComparisonComputation::Equal,
            value: Arc::new(LiteralValue::text("active".to_string())),
        };

        let expected =
            Domain::Enumeration(Arc::new(vec![LiteralValue::text("active".to_string())]));
        assert_eq!(domain_for(&is_active, "status"), expected);
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let bare_fact = Constraint::Fact(fact("is_active"));

        let expected =
            Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(BooleanValue::True)]));
        assert_eq!(domain_for(&bare_fact, "is_active"), expected);
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let negated = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));

        let expected =
            Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(BooleanValue::False)]));
        assert_eq!(domain_for(&negated, "is_active"), expected);
    }
}