// lemma/inversion/domain.rs

//! Domain types and operations for inversion.
//!
//! Provides:
//! - `Domain` and `Bound` types for representing concrete value constraints
//! - Domain operations: intersection, union, normalization
//! - `extract_domains_from_constraint()`: extracts per-fact domains from a constraint

8use crate::computation::comparison_operation;
9use crate::parsing::ast::Span;
10use crate::{
11    BooleanValue, ComparisonComputation, FactPath, LemmaError, LemmaResult, LiteralValue,
12    OperationResult, Value,
13};
14use serde::ser::{Serialize, SerializeStruct, Serializer};
15use std::cmp::Ordering;
16use std::collections::HashMap;
17use std::fmt;
18use std::sync::Arc;
19
20use super::constraint::Constraint;
21
22/// Domain specification for valid values
23#[derive(Debug, Clone, PartialEq)]
24pub enum Domain {
25    /// A single continuous range
26    Range { min: Bound, max: Bound },
27
28    /// Multiple disjoint ranges
29    Union(Arc<Vec<Domain>>),
30
31    /// Specific enumerated values only
32    Enumeration(Arc<Vec<LiteralValue>>),
33
34    /// Everything except these constraints
35    Complement(Box<Domain>),
36
37    /// Any value (no constraints)
38    Unconstrained,
39
40    /// Empty domain (no valid values) - represents unsatisfiable constraints
41    Empty,
42}
43
44impl Domain {
45    /// Check if this domain is satisfiable (has at least one valid value)
46    ///
47    /// Returns false for Empty domains and empty Enumerations.
48    pub fn is_satisfiable(&self) -> bool {
49        match self {
50            Domain::Empty => false,
51            Domain::Enumeration(values) => !values.is_empty(),
52            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
53            Domain::Range { min, max } => !bounds_contradict(min, max),
54            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
55            Domain::Unconstrained => true,
56        }
57    }
58
59    /// Check if this domain is empty (unsatisfiable)
60    pub fn is_empty(&self) -> bool {
61        !self.is_satisfiable()
62    }
63
64    /// Intersect this domain with another, returning Empty if no overlap
65    pub fn intersect(&self, other: &Domain) -> Domain {
66        domain_intersection(self.clone(), other.clone()).unwrap_or(Domain::Empty)
67    }
68
69    /// Check if a value is contained in this domain
70    pub fn contains(&self, value: &LiteralValue) -> bool {
71        match self {
72            Domain::Empty => false,
73            Domain::Unconstrained => true,
74            Domain::Enumeration(values) => values.contains(value),
75            Domain::Range { min, max } => value_within(value, min, max),
76            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
77            Domain::Complement(inner) => !inner.contains(value),
78        }
79    }
80}
81
82/// Bound specification for ranges
83#[derive(Debug, Clone, PartialEq)]
84pub enum Bound {
85    /// Inclusive bound [value
86    Inclusive(Arc<LiteralValue>),
87
88    /// Exclusive bound (value
89    Exclusive(Arc<LiteralValue>),
90
91    /// Unbounded (-infinity or +infinity)
92    Unbounded,
93}
94
95impl fmt::Display for Domain {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        match self {
98            Domain::Empty => write!(f, "empty"),
99            Domain::Unconstrained => write!(f, "any"),
100            Domain::Enumeration(vals) => {
101                write!(f, "{{")?;
102                for (i, v) in vals.iter().enumerate() {
103                    if i > 0 {
104                        write!(f, ", ")?;
105                    }
106                    write!(f, "{}", v)?;
107                }
108                write!(f, "}}")
109            }
110            Domain::Range { min, max } => {
111                let (l_bracket, r_bracket) = match (min, max) {
112                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
113                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
114                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
115                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
116                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
117                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
118                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
119                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
120                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
121                };
122
123                let min_str = match min {
124                    Bound::Unbounded => "-inf".to_string(),
125                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
126                };
127                let max_str = match max {
128                    Bound::Unbounded => "+inf".to_string(),
129                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
130                };
131                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
132            }
133            Domain::Union(parts) => {
134                for (i, p) in parts.iter().enumerate() {
135                    if i > 0 {
136                        write!(f, " | ")?;
137                    }
138                    write!(f, "{}", p)?;
139                }
140                Ok(())
141            }
142            Domain::Complement(inner) => write!(f, "not ({})", inner),
143        }
144    }
145}
146
147impl fmt::Display for Bound {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        match self {
150            Bound::Unbounded => write!(f, "inf"),
151            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
152            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
153        }
154    }
155}
156
157impl Serialize for Domain {
158    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
159    where
160        S: Serializer,
161    {
162        match self {
163            Domain::Empty => {
164                let mut st = serializer.serialize_struct("domain", 1)?;
165                st.serialize_field("type", "empty")?;
166                st.end()
167            }
168            Domain::Unconstrained => {
169                let mut st = serializer.serialize_struct("domain", 1)?;
170                st.serialize_field("type", "unconstrained")?;
171                st.end()
172            }
173            Domain::Enumeration(vals) => {
174                let mut st = serializer.serialize_struct("domain", 2)?;
175                st.serialize_field("type", "enumeration")?;
176                st.serialize_field("values", vals)?;
177                st.end()
178            }
179            Domain::Range { min, max } => {
180                let mut st = serializer.serialize_struct("domain", 3)?;
181                st.serialize_field("type", "range")?;
182                st.serialize_field("min", min)?;
183                st.serialize_field("max", max)?;
184                st.end()
185            }
186            Domain::Union(parts) => {
187                let mut st = serializer.serialize_struct("domain", 2)?;
188                st.serialize_field("type", "union")?;
189                st.serialize_field("parts", parts)?;
190                st.end()
191            }
192            Domain::Complement(inner) => {
193                let mut st = serializer.serialize_struct("domain", 2)?;
194                st.serialize_field("type", "complement")?;
195                st.serialize_field("inner", inner)?;
196                st.end()
197            }
198        }
199    }
200}
201
202impl Serialize for Bound {
203    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
204    where
205        S: Serializer,
206    {
207        match self {
208            Bound::Unbounded => {
209                let mut st = serializer.serialize_struct("bound", 1)?;
210                st.serialize_field("type", "unbounded")?;
211                st.end()
212            }
213            Bound::Inclusive(v) => {
214                let mut st = serializer.serialize_struct("bound", 2)?;
215                st.serialize_field("type", "inclusive")?;
216                st.serialize_field("value", v.as_ref())?;
217                st.end()
218            }
219            Bound::Exclusive(v) => {
220                let mut st = serializer.serialize_struct("bound", 2)?;
221                st.serialize_field("type", "exclusive")?;
222                st.serialize_field("value", v.as_ref())?;
223                st.end()
224            }
225        }
226    }
227}
228
229/// Extract domains for all facts mentioned in a constraint
230pub fn extract_domains_from_constraint(
231    constraint: &Constraint,
232) -> LemmaResult<HashMap<FactPath, Domain>> {
233    let all_facts = constraint.collect_facts();
234    let mut domains = HashMap::new();
235
236    for fact_path in all_facts {
237        let domain =
238            extract_domain_for_fact(constraint, &fact_path)?.unwrap_or(Domain::Unconstrained);
239        domains.insert(fact_path, domain);
240    }
241
242    Ok(domains)
243}
244
245fn extract_domain_for_fact(
246    constraint: &Constraint,
247    fact_path: &FactPath,
248) -> LemmaResult<Option<Domain>> {
249    let domain = match constraint {
250        Constraint::True => return Ok(None),
251        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
252
253        Constraint::Comparison { fact, op, value } => {
254            if fact == fact_path {
255                Some(comparison_to_domain(op, value.as_ref())?)
256            } else {
257                None
258            }
259        }
260
261        Constraint::Fact(fp) => {
262            if fp == fact_path {
263                Some(Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
264                    BooleanValue::True,
265                )])))
266            } else {
267                None
268            }
269        }
270
271        Constraint::And(left, right) => {
272            let left_domain = extract_domain_for_fact(left, fact_path)?;
273            let right_domain = extract_domain_for_fact(right, fact_path)?;
274            match (left_domain, right_domain) {
275                (None, None) => None,
276                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
277                (Some(a), Some(b)) => match domain_intersection(a, b) {
278                    Some(domain) => Some(domain),
279                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
280                },
281            }
282        }
283
284        Constraint::Or(left, right) => {
285            let left_domain = extract_domain_for_fact(left, fact_path)?;
286            let right_domain = extract_domain_for_fact(right, fact_path)?;
287            union_optional_domains(left_domain, right_domain)
288        }
289
290        Constraint::Not(inner) => {
291            // Handle not (fact == value)
292            if let Constraint::Comparison { fact, op, value } = inner.as_ref() {
293                if fact == fact_path && op.is_equal() {
294                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
295                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
296                    )))));
297                }
298            }
299
300            // Handle not (boolean_fact)
301            if let Constraint::Fact(fp) = inner.as_ref() {
302                if fp == fact_path {
303                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
304                        LiteralValue::boolean(BooleanValue::False),
305                    ]))));
306                }
307            }
308
309            extract_domain_for_fact(inner, fact_path)?
310                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
311        }
312    };
313
314    Ok(domain.map(normalize_domain))
315}
316
317fn comparison_to_domain(op: &ComparisonComputation, value: &LiteralValue) -> LemmaResult<Domain> {
318    if op.is_equal() {
319        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
320    }
321    if op.is_not_equal() {
322        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
323            vec![value.clone()],
324        )))));
325    }
326    match op {
327        ComparisonComputation::LessThan => Ok(Domain::Range {
328            min: Bound::Unbounded,
329            max: Bound::Exclusive(Arc::new(value.clone())),
330        }),
331        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
332            min: Bound::Unbounded,
333            max: Bound::Inclusive(Arc::new(value.clone())),
334        }),
335        ComparisonComputation::GreaterThan => Ok(Domain::Range {
336            min: Bound::Exclusive(Arc::new(value.clone())),
337            max: Bound::Unbounded,
338        }),
339        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
340            min: Bound::Inclusive(Arc::new(value.clone())),
341            max: Bound::Unbounded,
342        }),
343        _ => Err(LemmaError::engine(
344            format!(
345                "Unsupported comparison operator for domain extraction: {:?}",
346                op
347            ),
348            Span {
349                start: 0,
350                end: 0,
351                line: 1,
352                col: 0,
353            },
354            "<unknown>",
355            Arc::from(""),
356            "<unknown>",
357            1,
358            None::<String>,
359        )),
360    }
361}
362
363/// Compute the domain for a single comparison-atom used by inversion constraints.
364///
365/// This is used by numeric-aware constraint simplification to derive implications/exclusions
366/// between comparison atoms on the same fact.
367pub(crate) fn domain_for_comparison_atom(
368    op: &ComparisonComputation,
369    value: &LiteralValue,
370) -> LemmaResult<Domain> {
371    comparison_to_domain(op, value)
372}
373
374impl Domain {
375    /// Proven subset check for the atom-domain forms we generate from comparisons:
376    /// - Range
377    /// - Enumeration
378    /// - Complement(Enumeration) (used for != / is not)
379    ///
380    /// Returns false when the relationship cannot be proven with these forms.
381    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
382        match (self, other) {
383            (Domain::Empty, _) => true,
384            (_, Domain::Unconstrained) => true,
385            (Domain::Unconstrained, _) => false,
386
387            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
388            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
389                vals.iter().all(|v| value_within(v, min, max))
390            }
391
392            (
393                Domain::Range {
394                    min: amin,
395                    max: amax,
396                },
397                Domain::Range {
398                    min: bmin,
399                    max: bmax,
400                },
401            ) => range_within_range(amin, amax, bmin, bmax),
402
403            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
404            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
405                Domain::Enumeration(excluded) => {
406                    excluded.iter().all(|p| !value_within(p, min, max))
407                }
408                _ => false,
409            },
410
411            // {v} ⊆ not({p}) when v is not excluded
412            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
413                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
414                _ => false,
415            },
416
417            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
418            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
419                match (a_inner.as_ref(), b_inner.as_ref()) {
420                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
421                        excluded_b.iter().all(|v| excluded_a.contains(v))
422                    }
423                    _ => false,
424                }
425            }
426
427            _ => false,
428        }
429    }
430}
431
432fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
433    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
434}
435
436fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
437    match (a, b) {
438        (_, Bound::Unbounded) => true,
439        (Bound::Unbounded, _) => false,
440        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
441        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
442        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
443            let c = lit_cmp(av.as_ref(), bv.as_ref());
444            c >= 0
445        }
446        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
447            // a >= (b) only if a's value > b's value
448            lit_cmp(av.as_ref(), bv.as_ref()) > 0
449        }
450    }
451}
452
453fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
454    match (a, b) {
455        (Bound::Unbounded, Bound::Unbounded) => true,
456        (_, Bound::Unbounded) => true,
457        (Bound::Unbounded, _) => false,
458        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
459        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
460        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
461            // (a) <= [b] when a <= b
462            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
463        }
464        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
465            // [a] <= (b) only if a < b
466            lit_cmp(av.as_ref(), bv.as_ref()) < 0
467        }
468    }
469}
470
471fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
472    match (a, b) {
473        (None, None) => None,
474        (Some(d), None) | (None, Some(d)) => Some(d),
475        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
476    }
477}
478
479fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
480    if let OperationResult::Value(lit) =
481        comparison_operation(a, &ComparisonComputation::LessThan, b)
482    {
483        if let Value::Boolean(BooleanValue::True) = &lit.value {
484            return -1;
485        }
486    }
487    if let OperationResult::Value(lit) = comparison_operation(a, &ComparisonComputation::Equal, b) {
488        if let Value::Boolean(BooleanValue::True) = &lit.value {
489            return 0;
490        }
491    }
492    1
493}
494
495fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
496    let ge_min = match min {
497        Bound::Unbounded => true,
498        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
499        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
500    };
501    let le_max = match max {
502        Bound::Unbounded => true,
503        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
504        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
505    };
506    ge_min && le_max
507}
508
509fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
510    match (min, max) {
511        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
512        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
513        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
514        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
515        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
516    }
517}
518
519fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
520    match (min1, min2) {
521        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
522        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
523            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
524                Bound::Inclusive(v1)
525            } else {
526                Bound::Inclusive(v2)
527            }
528        }
529        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
530            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
531                Bound::Inclusive(v1)
532            } else {
533                Bound::Exclusive(v2)
534            }
535        }
536        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
537            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
538                Bound::Exclusive(v1)
539            } else {
540                Bound::Inclusive(v2)
541            }
542        }
543        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
544            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
545                Bound::Exclusive(v1)
546            } else {
547                Bound::Exclusive(v2)
548            }
549        }
550    }
551}
552
553fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
554    match (max1, max2) {
555        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
556        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
557            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
558                Bound::Inclusive(v1)
559            } else {
560                Bound::Inclusive(v2)
561            }
562        }
563        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
564            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
565                Bound::Inclusive(v1)
566            } else {
567                Bound::Exclusive(v2)
568            }
569        }
570        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
571            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
572                Bound::Exclusive(v1)
573            } else {
574                Bound::Inclusive(v2)
575            }
576        }
577        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
578            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
579                Bound::Exclusive(v1)
580            } else {
581                Bound::Exclusive(v2)
582            }
583        }
584    }
585}
586
587fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
588    let a = normalize_domain(a);
589    let b = normalize_domain(b);
590
591    let result = match (a, b) {
592        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
593        (Domain::Empty, _) | (_, Domain::Empty) => None,
594
595        (
596            Domain::Range {
597                min: min1,
598                max: max1,
599            },
600            Domain::Range {
601                min: min2,
602                max: max2,
603            },
604        ) => {
605            let min = compute_intersection_min(min1, min2);
606            let max = compute_intersection_max(max1, max2);
607
608            if bounds_contradict(&min, &max) {
609                None
610            } else {
611                Some(Domain::Range { min, max })
612            }
613        }
614        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
615            let filtered: Vec<LiteralValue> =
616                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
617            if filtered.is_empty() {
618                None
619            } else {
620                Some(Domain::Enumeration(Arc::new(filtered)))
621            }
622        }
623        (Domain::Enumeration(vs), Domain::Range { min, max })
624        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
625            let mut kept = Vec::new();
626            for v in vs.iter() {
627                if value_within(v, &min, &max) {
628                    kept.push(v.clone());
629                }
630            }
631            if kept.is_empty() {
632                None
633            } else {
634                Some(Domain::Enumeration(Arc::new(kept)))
635            }
636        }
637        (Domain::Enumeration(vs), Domain::Complement(inner))
638        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
639            match *inner.clone() {
640                Domain::Enumeration(excluded) => {
641                    let mut kept = Vec::new();
642                    for v in vs.iter() {
643                        if !excluded.contains(v) {
644                            kept.push(v.clone());
645                        }
646                    }
647                    if kept.is_empty() {
648                        None
649                    } else {
650                        Some(Domain::Enumeration(Arc::new(kept)))
651                    }
652                }
653                Domain::Range { min, max } => {
654                    // Filter enumeration values that are NOT in the range
655                    let mut kept = Vec::new();
656                    for v in vs.iter() {
657                        if !value_within(v, &min, &max) {
658                            kept.push(v.clone());
659                        }
660                    }
661                    if kept.is_empty() {
662                        None
663                    } else {
664                        Some(Domain::Enumeration(Arc::new(kept)))
665                    }
666                }
667                _ => {
668                    // For other complement types, normalize and recurse
669                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
670                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
671                }
672            }
673        }
674        (Domain::Union(v1), Domain::Union(v2)) => {
675            let mut acc: Vec<Domain> = Vec::new();
676            for a in v1.iter() {
677                for b in v2.iter() {
678                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
679                        acc.push(ix);
680                    }
681                }
682            }
683            if acc.is_empty() {
684                None
685            } else {
686                Some(Domain::Union(Arc::new(acc)))
687            }
688        }
689        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
690            let mut acc: Vec<Domain> = Vec::new();
691            for a in vs.iter() {
692                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
693                    acc.push(ix);
694                }
695            }
696            if acc.is_empty() {
697                None
698            } else if acc.len() == 1 {
699                Some(acc.remove(0))
700            } else {
701                Some(Domain::Union(Arc::new(acc)))
702            }
703        }
704        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
705        (Domain::Range { min, max }, Domain::Complement(inner))
706        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
707            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
708            _ => {
709                // Normalize the complement (not just the inner value) and recurse.
710                // If normalization doesn't change it, we must not recurse infinitely.
711                let normalized_complement = normalize_domain(Domain::Complement(inner));
712                if matches!(&normalized_complement, Domain::Complement(_)) {
713                    None
714                } else {
715                    domain_intersection(Domain::Range { min, max }, normalized_complement)
716                }
717            }
718        },
719        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
720            match (a_inner.as_ref(), b_inner.as_ref()) {
721                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
722                    // not(A) ∩ not(B) == not(A ∪ B)
723                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
724                    excluded.extend(b_ex.iter().cloned());
725                    Some(normalize_domain(Domain::Complement(Box::new(
726                        Domain::Enumeration(Arc::new(excluded)),
727                    ))))
728                }
729                _ => None,
730            }
731        }
732    };
733    result.map(normalize_domain)
734}
735
736fn range_minus_excluded_points(
737    min: Bound,
738    max: Bound,
739    excluded: &Arc<Vec<LiteralValue>>,
740) -> Option<Domain> {
741    // Start with a single range and iteratively split on excluded points that fall within it.
742    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
743
744    for p in excluded.iter() {
745        let mut next: Vec<(Bound, Bound)> = Vec::new();
746
747        for (rmin, rmax) in parts {
748            if !value_within(p, &rmin, &rmax) {
749                next.push((rmin, rmax));
750                continue;
751            }
752
753            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
754            let left_max = Bound::Exclusive(Arc::new(p.clone()));
755            if !bounds_contradict(&rmin, &left_max) {
756                next.push((rmin.clone(), left_max));
757            }
758
759            // Right part: (p, rmax)
760            let right_min = Bound::Exclusive(Arc::new(p.clone()));
761            if !bounds_contradict(&right_min, &rmax) {
762                next.push((right_min, rmax.clone()));
763            }
764        }
765
766        parts = next;
767        if parts.is_empty() {
768            return None;
769        }
770    }
771
772    if parts.is_empty() {
773        None
774    } else if parts.len() == 1 {
775        let (min, max) = parts.remove(0);
776        Some(Domain::Range { min, max })
777    } else {
778        Some(Domain::Union(Arc::new(
779            parts
780                .into_iter()
781                .map(|(min, max)| Domain::Range { min, max })
782                .collect(),
783        )))
784    }
785}
786
787fn invert_bound(bound: Bound) -> Bound {
788    match bound {
789        Bound::Unbounded => Bound::Unbounded,
790        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
791        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
792    }
793}
794
795fn normalize_domain(d: Domain) -> Domain {
796    match d {
797        Domain::Complement(inner) => {
798            let normalized_inner = normalize_domain(*inner);
799            match normalized_inner {
800                Domain::Complement(double_inner) => *double_inner,
801                Domain::Range { min, max } => match (&min, &max) {
802                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
803                    (Bound::Unbounded, max) => Domain::Range {
804                        min: invert_bound(max.clone()),
805                        max: Bound::Unbounded,
806                    },
807                    (min, Bound::Unbounded) => Domain::Range {
808                        min: Bound::Unbounded,
809                        max: invert_bound(min.clone()),
810                    },
811                    (min, max) => Domain::Union(Arc::new(vec![
812                        Domain::Range {
813                            min: Bound::Unbounded,
814                            max: invert_bound(min.clone()),
815                        },
816                        Domain::Range {
817                            min: invert_bound(max.clone()),
818                            max: Bound::Unbounded,
819                        },
820                    ])),
821                },
822                Domain::Enumeration(vals) => {
823                    if vals.len() == 1 {
824                        if let Some(lit) = vals.first() {
825                            if let Value::Boolean(BooleanValue::True) = &lit.value {
826                                return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
827                                    BooleanValue::False,
828                                )]));
829                            }
830                            if let Value::Boolean(BooleanValue::False) = &lit.value {
831                                return Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(
832                                    BooleanValue::True,
833                                )]));
834                            }
835                        }
836                    }
837                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
838                }
839                Domain::Unconstrained => Domain::Empty,
840                Domain::Empty => Domain::Unconstrained,
841                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
842            }
843        }
844        Domain::Empty => Domain::Empty,
845        Domain::Union(parts) => {
846            let mut flat: Vec<Domain> = Vec::new();
847            for p in parts.iter().cloned() {
848                let normalized = normalize_domain(p);
849                match normalized {
850                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
851                    Domain::Unconstrained => return Domain::Unconstrained,
852                    Domain::Enumeration(vals) if vals.is_empty() => {}
853                    other => flat.push(other),
854                }
855            }
856
857            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
858            let mut ranges: Vec<Domain> = Vec::new();
859            let mut others: Vec<Domain> = Vec::new();
860
861            for domain in flat {
862                match domain {
863                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
864                    Domain::Range { .. } => ranges.push(domain),
865                    other => others.push(other),
866                }
867            }
868
869            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
870                -1 => Ordering::Less,
871                0 => Ordering::Equal,
872                _ => Ordering::Greater,
873            });
874            all_enum_values.dedup();
875
876            all_enum_values.retain(|v| {
877                !ranges.iter().any(|r| {
878                    if let Domain::Range { min, max } = r {
879                        value_within(v, min, max)
880                    } else {
881                        false
882                    }
883                })
884            });
885
886            let mut result: Vec<Domain> = Vec::new();
887            result.extend(ranges);
888            result = merge_ranges(result);
889
890            if !all_enum_values.is_empty() {
891                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
892            }
893            result.extend(others);
894
895            result.sort_by(|a, b| match (a, b) {
896                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
897                (Domain::Range { .. }, _) => Ordering::Less,
898                (_, Domain::Range { .. }) => Ordering::Greater,
899                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
900                (Domain::Enumeration(_), _) => Ordering::Less,
901                (_, Domain::Enumeration(_)) => Ordering::Greater,
902                _ => Ordering::Equal,
903            });
904
905            if result.is_empty() {
906                Domain::Enumeration(Arc::new(vec![]))
907            } else if result.len() == 1 {
908                result.remove(0)
909            } else {
910                Domain::Union(Arc::new(result))
911            }
912        }
913        Domain::Enumeration(values) => {
914            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
915            sorted.sort_by(|a, b| match lit_cmp(a, b) {
916                -1 => Ordering::Less,
917                0 => Ordering::Equal,
918                _ => Ordering::Greater,
919            });
920            sorted.dedup();
921            Domain::Enumeration(Arc::new(sorted))
922        }
923        other => other,
924    }
925}
926
927fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
928    let mut result = Vec::new();
929    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
930    let mut others = Vec::new();
931
932    for d in domains {
933        match d {
934            Domain::Range { min, max } => ranges.push((min, max)),
935            other => others.push(other),
936        }
937    }
938
939    if ranges.is_empty() {
940        return others;
941    }
942
943    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
944
945    let mut merged: Vec<(Bound, Bound)> = Vec::new();
946    let mut current = ranges[0].clone();
947
948    for next in ranges.iter().skip(1) {
949        if ranges_adjacent_or_overlap(&current, next) {
950            current = (
951                min_bound(&current.0, &next.0),
952                max_bound(&current.1, &next.1),
953            );
954        } else {
955            merged.push(current);
956            current = next.clone();
957        }
958    }
959    merged.push(current);
960
961    for (min, max) in merged {
962        result.push(Domain::Range { min, max });
963    }
964    result.extend(others);
965
966    result
967}
968
969fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
970    match (a, b) {
971        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
972        (Bound::Unbounded, _) => Ordering::Less,
973        (_, Bound::Unbounded) => Ordering::Greater,
974        (Bound::Inclusive(v1), Bound::Inclusive(v2))
975        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
976            -1 => Ordering::Less,
977            0 => Ordering::Equal,
978            _ => Ordering::Greater,
979        },
980        (Bound::Inclusive(v1), Bound::Exclusive(v2))
981        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
982            -1 => Ordering::Less,
983            0 => {
984                if matches!(a, Bound::Inclusive(_)) {
985                    Ordering::Less
986                } else {
987                    Ordering::Greater
988                }
989            }
990            _ => Ordering::Greater,
991        },
992    }
993}
994
995fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
996    match (&r1.1, &r2.0) {
997        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
998        (Bound::Inclusive(v1), Bound::Inclusive(v2))
999        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1000        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1001        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1002    }
1003}
1004
1005fn min_bound(a: &Bound, b: &Bound) -> Bound {
1006    match (a, b) {
1007        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1008        _ => {
1009            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1010                a.clone()
1011            } else {
1012                b.clone()
1013            }
1014        }
1015    }
1016}
1017
1018fn max_bound(a: &Bound, b: &Bound) -> Bound {
1019    match (a, b) {
1020        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1021        _ => {
1022            if matches!(compare_bounds(a, b), Ordering::Greater) {
1023                a.clone()
1024            } else {
1025                b.clone()
1026            }
1027        }
1028    }
1029}
1030
// Unit tests for domain normalization, display formatting, and constraint-to-domain extraction.
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    fn fact(name: &str) -> FactPath {
        FactPath::local(name.to_string())
    }

    #[test]
    fn test_normalize_double_complement() {
        let inner = Domain::Enumeration(Arc::new(vec![num(5)]));
        let double = Domain::Complement(Box::new(Domain::Complement(Box::new(inner.clone()))));
        let normalized = normalize_domain(double);
        assert_eq!(normalized, inner);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(0))),
                max: Bound::Inclusive(Arc::new(num(10))),
            },
            Domain::Unconstrained,
        ]));
        let normalized = normalize_domain(union);
        assert_eq!(normalized, Domain::Unconstrained);
    }

    #[test]
    fn test_normalize_complement_of_half_bounded_range() {
        // not(x >= 10) is x < 10
        let at_least_ten = Domain::Range {
            min: Bound::Inclusive(Arc::new(num(10))),
            max: Bound::Unbounded,
        };
        let normalized = normalize_domain(Domain::Complement(Box::new(at_least_ten)));
        assert_eq!(
            normalized,
            Domain::Range {
                min: Bound::Unbounded,
                max: Bound::Exclusive(Arc::new(num(10))),
            }
        );
    }

    #[test]
    fn test_normalize_complement_flips_single_boolean() {
        let only_true =
            Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(BooleanValue::True)]));
        let normalized = normalize_domain(Domain::Complement(Box::new(only_true)));
        assert_eq!(
            normalized,
            Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(BooleanValue::False)]))
        );
    }

    #[test]
    fn test_normalize_complement_swaps_empty_and_unconstrained() {
        assert_eq!(
            normalize_domain(Domain::Complement(Box::new(Domain::Unconstrained))),
            Domain::Empty
        );
        assert_eq!(
            normalize_domain(Domain::Complement(Box::new(Domain::Empty))),
            Domain::Unconstrained
        );
    }

    #[test]
    fn test_normalize_union_merges_touching_ranges() {
        // [0, 5) and [5, 10] share the point 5 from the right, so they form one range [0, 10].
        let union = Domain::Union(Arc::new(vec![
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(5))),
                max: Bound::Inclusive(Arc::new(num(10))),
            },
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(0))),
                max: Bound::Exclusive(Arc::new(num(5))),
            },
        ]));
        assert_eq!(
            normalize_domain(union),
            Domain::Range {
                min: Bound::Inclusive(Arc::new(num(0))),
                max: Bound::Inclusive(Arc::new(num(10))),
            }
        );
    }

    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: Bound::Inclusive(Arc::new(num(10))),
            max: Bound::Exclusive(Arc::new(num(20))),
        };
        assert_eq!(format!("{}", range), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(format!("{}", enumeration), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = Constraint::Comparison {
            fact: fact("age"),
            op: ComparisonComputation::GreaterThan,
            value: Arc::new(num(18)),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&fact("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Unbounded,
            }
        );
    }

    #[test]
    fn test_extract_domain_from_and() {
        let constraint = Constraint::And(
            Box::new(Constraint::Comparison {
                fact: fact("age"),
                op: ComparisonComputation::GreaterThan,
                value: Arc::new(num(18)),
            }),
            Box::new(Constraint::Comparison {
                fact: fact("age"),
                op: ComparisonComputation::LessThan,
                value: Arc::new(num(65)),
            }),
        );

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let age_domain = domains.get(&fact("age")).unwrap();

        assert_eq!(
            *age_domain,
            Domain::Range {
                min: Bound::Exclusive(Arc::new(num(18))),
                max: Bound::Exclusive(Arc::new(num(65))),
            }
        );
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let constraint = Constraint::Comparison {
            fact: fact("status"),
            op: ComparisonComputation::Equal,
            value: Arc::new(LiteralValue::text("active".to_string())),
        };

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let status_domain = domains.get(&fact("status")).unwrap();

        assert_eq!(
            *status_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::text("active".to_string())]))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_fact() {
        let constraint = Constraint::Fact(fact("is_active"));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&fact("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(BooleanValue::True)]))
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_fact() {
        let constraint = Constraint::Not(Box::new(Constraint::Fact(fact("is_active"))));

        let domains = extract_domains_from_constraint(&constraint).unwrap();
        let is_active_domain = domains.get(&fact("is_active")).unwrap();

        assert_eq!(
            *is_active_domain,
            Domain::Enumeration(Arc::new(vec![LiteralValue::boolean(BooleanValue::False)]))
        );
    }
}