Skip to main content

lemma/inversion/
domain.rs

1//! Domain types and operations for inversion
2//!
3//! Provides:
4//! - `Domain` and `Bound` types for representing concrete value constraints
5//! - Domain operations: intersection, union, normalization
6//! - `extract_domains_from_constraint()`: extracts domains from constraints
7
8use crate::planning::semantics::{ComparisonComputation, DataPath, LiteralValue, ValueKind};
9use serde::ser::{Serialize, SerializeStruct, Serializer};
10use std::cmp::Ordering;
11use std::collections::HashMap;
12use std::fmt;
13use std::sync::Arc;
14
15use super::constraint::Constraint;
16
17/// Domain specification for valid values
18#[derive(Debug, Clone, PartialEq)]
19pub enum Domain {
20    /// A single continuous range
21    Range { min: Bound, max: Bound },
22
23    /// Multiple disjoint ranges
24    Union(Arc<Vec<Domain>>),
25
26    /// Specific enumerated values only
27    Enumeration(Arc<Vec<LiteralValue>>),
28
29    /// Everything except these constraints
30    Complement(Box<Domain>),
31
32    /// Any value (no constraints)
33    Unconstrained,
34
35    /// Empty domain (no valid values) - represents unsatisfiable constraints
36    Empty,
37}
38
39impl Domain {
40    /// Check if this domain is satisfiable (has at least one valid value)
41    ///
42    /// Returns false for Empty domains and empty Enumerations.
43    pub fn is_satisfiable(&self) -> bool {
44        match self {
45            Domain::Empty => false,
46            Domain::Enumeration(values) => !values.is_empty(),
47            Domain::Union(parts) => parts.iter().any(|p| p.is_satisfiable()),
48            Domain::Range { min, max } => !bounds_contradict(min, max),
49            Domain::Complement(inner) => !matches!(inner.as_ref(), Domain::Unconstrained),
50            Domain::Unconstrained => true,
51        }
52    }
53
54    /// Check if this domain is empty (unsatisfiable)
55    pub fn is_empty(&self) -> bool {
56        !self.is_satisfiable()
57    }
58
59    /// Intersect this domain with another, returning Empty if no overlap.
60    /// `domain_intersection` returns `None` exactly when the result is empty.
61    pub fn intersect(&self, other: &Domain) -> Domain {
62        match domain_intersection(self.clone(), other.clone()) {
63            Some(d) => d,
64            None => Domain::Empty,
65        }
66    }
67
68    /// Check if a value is contained in this domain
69    pub fn contains(&self, value: &LiteralValue) -> bool {
70        match self {
71            Domain::Empty => false,
72            Domain::Unconstrained => true,
73            Domain::Enumeration(values) => values.contains(value),
74            Domain::Range { min, max } => value_within(value, min, max),
75            Domain::Union(parts) => parts.iter().any(|p| p.contains(value)),
76            Domain::Complement(inner) => !inner.contains(value),
77        }
78    }
79}
80
81/// Bound specification for ranges
82#[derive(Debug, Clone, PartialEq)]
83pub enum Bound {
84    /// Inclusive bound [value
85    Inclusive(Arc<LiteralValue>),
86
87    /// Exclusive bound (value
88    Exclusive(Arc<LiteralValue>),
89
90    /// Unbounded (-infinity or +infinity)
91    Unbounded,
92}
93
94impl fmt::Display for Domain {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        match self {
97            Domain::Empty => write!(f, "empty"),
98            Domain::Unconstrained => write!(f, "any"),
99            Domain::Enumeration(vals) => {
100                write!(f, "{{")?;
101                for (i, v) in vals.iter().enumerate() {
102                    if i > 0 {
103                        write!(f, ", ")?;
104                    }
105                    write!(f, "{}", v)?;
106                }
107                write!(f, "}}")
108            }
109            Domain::Range { min, max } => {
110                let (l_bracket, r_bracket) = match (min, max) {
111                    (Bound::Inclusive(_), Bound::Inclusive(_)) => ('[', ']'),
112                    (Bound::Inclusive(_), Bound::Exclusive(_)) => ('[', ')'),
113                    (Bound::Exclusive(_), Bound::Inclusive(_)) => ('(', ']'),
114                    (Bound::Exclusive(_), Bound::Exclusive(_)) => ('(', ')'),
115                    (Bound::Unbounded, Bound::Inclusive(_)) => ('(', ']'),
116                    (Bound::Unbounded, Bound::Exclusive(_)) => ('(', ')'),
117                    (Bound::Inclusive(_), Bound::Unbounded) => ('[', ')'),
118                    (Bound::Exclusive(_), Bound::Unbounded) => ('(', ')'),
119                    (Bound::Unbounded, Bound::Unbounded) => ('(', ')'),
120                };
121
122                let min_str = match min {
123                    Bound::Unbounded => "-inf".to_string(),
124                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
125                };
126                let max_str = match max {
127                    Bound::Unbounded => "+inf".to_string(),
128                    Bound::Inclusive(v) | Bound::Exclusive(v) => v.as_ref().to_string(),
129                };
130                write!(f, "{}{}, {}{}", l_bracket, min_str, max_str, r_bracket)
131            }
132            Domain::Union(parts) => {
133                for (i, p) in parts.iter().enumerate() {
134                    if i > 0 {
135                        write!(f, " | ")?;
136                    }
137                    write!(f, "{}", p)?;
138                }
139                Ok(())
140            }
141            Domain::Complement(inner) => write!(f, "not ({})", inner),
142        }
143    }
144}
145
146impl fmt::Display for Bound {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        match self {
149            Bound::Unbounded => write!(f, "inf"),
150            Bound::Inclusive(v) => write!(f, "[{}", v.as_ref()),
151            Bound::Exclusive(v) => write!(f, "({}", v.as_ref()),
152        }
153    }
154}
155
156impl Serialize for Domain {
157    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
158    where
159        S: Serializer,
160    {
161        match self {
162            Domain::Empty => {
163                let mut st = serializer.serialize_struct("domain", 1)?;
164                st.serialize_field("type", "empty")?;
165                st.end()
166            }
167            Domain::Unconstrained => {
168                let mut st = serializer.serialize_struct("domain", 1)?;
169                st.serialize_field("type", "unconstrained")?;
170                st.end()
171            }
172            Domain::Enumeration(vals) => {
173                let mut st = serializer.serialize_struct("domain", 2)?;
174                st.serialize_field("type", "enumeration")?;
175                st.serialize_field("values", vals)?;
176                st.end()
177            }
178            Domain::Range { min, max } => {
179                let mut st = serializer.serialize_struct("domain", 3)?;
180                st.serialize_field("type", "range")?;
181                st.serialize_field("min", min)?;
182                st.serialize_field("max", max)?;
183                st.end()
184            }
185            Domain::Union(parts) => {
186                let mut st = serializer.serialize_struct("domain", 2)?;
187                st.serialize_field("type", "union")?;
188                st.serialize_field("parts", parts)?;
189                st.end()
190            }
191            Domain::Complement(inner) => {
192                let mut st = serializer.serialize_struct("domain", 2)?;
193                st.serialize_field("type", "complement")?;
194                st.serialize_field("inner", inner)?;
195                st.end()
196            }
197        }
198    }
199}
200
201impl Serialize for Bound {
202    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
203    where
204        S: Serializer,
205    {
206        match self {
207            Bound::Unbounded => {
208                let mut st = serializer.serialize_struct("bound", 1)?;
209                st.serialize_field("type", "unbounded")?;
210                st.end()
211            }
212            Bound::Inclusive(v) => {
213                let mut st = serializer.serialize_struct("bound", 2)?;
214                st.serialize_field("type", "inclusive")?;
215                st.serialize_field("value", v.as_ref())?;
216                st.end()
217            }
218            Bound::Exclusive(v) => {
219                let mut st = serializer.serialize_struct("bound", 2)?;
220                st.serialize_field("type", "exclusive")?;
221                st.serialize_field("value", v.as_ref())?;
222                st.end()
223            }
224        }
225    }
226}
227
228/// Extract domains for all data mentioned in a constraint
229pub fn extract_domains_from_constraint(
230    constraint: &Constraint,
231) -> Result<HashMap<DataPath, Domain>, crate::Error> {
232    let all_datas = constraint.collect_data();
233    let mut domains = HashMap::new();
234
235    for data_path in all_datas {
236        // None means the data appears in the constraint but has no extractable
237        // bound (e.g. only used in equality with another data). Treating it as
238        // Unconstrained is correct: the solver will enumerate values.
239        let domain =
240            extract_domain_for_data(constraint, &data_path)?.unwrap_or(Domain::Unconstrained);
241        domains.insert(data_path, domain);
242    }
243
244    Ok(domains)
245}
246
247fn extract_domain_for_data(
248    constraint: &Constraint,
249    data_path: &DataPath,
250) -> Result<Option<Domain>, crate::Error> {
251    let domain = match constraint {
252        Constraint::True => return Ok(None),
253        Constraint::False => Some(Domain::Enumeration(Arc::new(vec![]))),
254
255        Constraint::Comparison { data, op, value } => {
256            if data == data_path {
257                Some(comparison_to_domain(op, value.as_ref())?)
258            } else {
259                None
260            }
261        }
262
263        Constraint::Data(fp) => {
264            if fp == data_path {
265                Some(Domain::Enumeration(Arc::new(vec![
266                    LiteralValue::from_bool(true),
267                ])))
268            } else {
269                None
270            }
271        }
272
273        Constraint::And(left, right) => {
274            let left_domain = extract_domain_for_data(left, data_path)?;
275            let right_domain = extract_domain_for_data(right, data_path)?;
276            match (left_domain, right_domain) {
277                (None, None) => None,
278                (Some(d), None) | (None, Some(d)) => Some(normalize_domain(d)),
279                (Some(a), Some(b)) => match domain_intersection(a, b) {
280                    Some(domain) => Some(domain),
281                    None => Some(Domain::Enumeration(Arc::new(vec![]))),
282                },
283            }
284        }
285
286        Constraint::Or(left, right) => {
287            let left_domain = extract_domain_for_data(left, data_path)?;
288            let right_domain = extract_domain_for_data(right, data_path)?;
289            union_optional_domains(left_domain, right_domain)
290        }
291
292        Constraint::Not(inner) => {
293            // Handle not (data is value)
294            if let Constraint::Comparison { data, op, value } = inner.as_ref() {
295                if data == data_path && op.is_equal() {
296                    return Ok(Some(normalize_domain(Domain::Complement(Box::new(
297                        Domain::Enumeration(Arc::new(vec![value.as_ref().clone()])),
298                    )))));
299                }
300            }
301
302            // Handle not (boolean_data)
303            if let Constraint::Data(fp) = inner.as_ref() {
304                if fp == data_path {
305                    return Ok(Some(Domain::Enumeration(Arc::new(vec![
306                        LiteralValue::from_bool(false),
307                    ]))));
308                }
309            }
310
311            extract_domain_for_data(inner, data_path)?
312                .map(|domain| normalize_domain(Domain::Complement(Box::new(domain))))
313        }
314    };
315
316    Ok(domain.map(normalize_domain))
317}
318
319fn comparison_to_domain(
320    op: &ComparisonComputation,
321    value: &LiteralValue,
322) -> Result<Domain, crate::Error> {
323    if op.is_equal() {
324        return Ok(Domain::Enumeration(Arc::new(vec![value.clone()])));
325    }
326    if op.is_not_equal() {
327        return Ok(Domain::Complement(Box::new(Domain::Enumeration(Arc::new(
328            vec![value.clone()],
329        )))));
330    }
331    match op {
332        ComparisonComputation::LessThan => Ok(Domain::Range {
333            min: Bound::Unbounded,
334            max: Bound::Exclusive(Arc::new(value.clone())),
335        }),
336        ComparisonComputation::LessThanOrEqual => Ok(Domain::Range {
337            min: Bound::Unbounded,
338            max: Bound::Inclusive(Arc::new(value.clone())),
339        }),
340        ComparisonComputation::GreaterThan => Ok(Domain::Range {
341            min: Bound::Exclusive(Arc::new(value.clone())),
342            max: Bound::Unbounded,
343        }),
344        ComparisonComputation::GreaterThanOrEqual => Ok(Domain::Range {
345            min: Bound::Inclusive(Arc::new(value.clone())),
346            max: Bound::Unbounded,
347        }),
348        _ => unreachable!(
349            "BUG: unsupported comparison operator for domain extraction: {:?}",
350            op
351        ),
352    }
353}
354
355/// Compute the domain for a single comparison-atom used by inversion constraints.
356///
357/// This is used by numeric-aware constraint simplification to derive implications/exclusions
358/// between comparison atoms on the same data.
359pub(crate) fn domain_for_comparison_atom(
360    op: &ComparisonComputation,
361    value: &LiteralValue,
362) -> Result<Domain, crate::Error> {
363    comparison_to_domain(op, value)
364}
365
366impl Domain {
367    /// Proven subset check for the atom-domain forms we generate from comparisons:
368    /// - Range
369    /// - Enumeration
370    /// - Complement(Enumeration) (used for `is not`)
371    ///
372    /// Returns false when the relationship cannot be proven with these forms.
373    pub(crate) fn is_subset_of(&self, other: &Domain) -> bool {
374        match (self, other) {
375            (Domain::Empty, _) => true,
376            (_, Domain::Unconstrained) => true,
377            (Domain::Unconstrained, _) => false,
378
379            (Domain::Enumeration(a), Domain::Enumeration(b)) => a.iter().all(|v| b.contains(v)),
380            (Domain::Enumeration(vals), Domain::Range { min, max }) => {
381                vals.iter().all(|v| value_within(v, min, max))
382            }
383
384            (
385                Domain::Range {
386                    min: amin,
387                    max: amax,
388                },
389                Domain::Range {
390                    min: bmin,
391                    max: bmax,
392                },
393            ) => range_within_range(amin, amax, bmin, bmax),
394
395            // Range ⊆ not({p}) when the range does not include p (for all excluded points)
396            (Domain::Range { min, max }, Domain::Complement(inner)) => match inner.as_ref() {
397                Domain::Enumeration(excluded) => {
398                    excluded.iter().all(|p| !value_within(p, min, max))
399                }
400                _ => false,
401            },
402
403            // {v} ⊆ not({p}) when v is not excluded
404            (Domain::Enumeration(vals), Domain::Complement(inner)) => match inner.as_ref() {
405                Domain::Enumeration(excluded) => vals.iter().all(|v| !excluded.contains(v)),
406                _ => false,
407            },
408
409            // not(A) ⊆ not(B)  iff  B ⊆ A  (for enumeration complements)
410            (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
411                match (a_inner.as_ref(), b_inner.as_ref()) {
412                    (Domain::Enumeration(excluded_a), Domain::Enumeration(excluded_b)) => {
413                        excluded_b.iter().all(|v| excluded_a.contains(v))
414                    }
415                    _ => false,
416                }
417            }
418
419            _ => false,
420        }
421    }
422}
423
424fn range_within_range(amin: &Bound, amax: &Bound, bmin: &Bound, bmax: &Bound) -> bool {
425    lower_bound_geq(amin, bmin) && upper_bound_leq(amax, bmax)
426}
427
428fn lower_bound_geq(a: &Bound, b: &Bound) -> bool {
429    match (a, b) {
430        (_, Bound::Unbounded) => true,
431        (Bound::Unbounded, _) => false,
432        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
433        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) >= 0,
434        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
435            let c = lit_cmp(av.as_ref(), bv.as_ref());
436            c >= 0
437        }
438        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
439            // a >= (b) only if a's value > b's value
440            lit_cmp(av.as_ref(), bv.as_ref()) > 0
441        }
442    }
443}
444
445fn upper_bound_leq(a: &Bound, b: &Bound) -> bool {
446    match (a, b) {
447        (Bound::Unbounded, Bound::Unbounded) => true,
448        (_, Bound::Unbounded) => true,
449        (Bound::Unbounded, _) => false,
450        (Bound::Inclusive(av), Bound::Inclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
451        (Bound::Exclusive(av), Bound::Exclusive(bv)) => lit_cmp(av.as_ref(), bv.as_ref()) <= 0,
452        (Bound::Exclusive(av), Bound::Inclusive(bv)) => {
453            // (a) <= [b] when a <= b
454            lit_cmp(av.as_ref(), bv.as_ref()) <= 0
455        }
456        (Bound::Inclusive(av), Bound::Exclusive(bv)) => {
457            // [a] <= (b) only if a < b
458            lit_cmp(av.as_ref(), bv.as_ref()) < 0
459        }
460    }
461}
462
463fn union_optional_domains(a: Option<Domain>, b: Option<Domain>) -> Option<Domain> {
464    match (a, b) {
465        (None, None) => None,
466        (Some(d), None) | (None, Some(d)) => Some(d),
467        (Some(a), Some(b)) => Some(normalize_domain(Domain::Union(Arc::new(vec![a, b])))),
468    }
469}
470
471fn lit_cmp(a: &LiteralValue, b: &LiteralValue) -> i8 {
472    use std::cmp::Ordering;
473
474    match (&a.value, &b.value) {
475        (ValueKind::Number(la), ValueKind::Number(lb)) => match la.cmp(lb) {
476            Ordering::Less => -1,
477            Ordering::Equal => 0,
478            Ordering::Greater => 1,
479        },
480
481        (ValueKind::Boolean(la), ValueKind::Boolean(lb)) => match la.cmp(lb) {
482            Ordering::Less => -1,
483            Ordering::Equal => 0,
484            Ordering::Greater => 1,
485        },
486
487        (ValueKind::Text(la), ValueKind::Text(lb)) => match la.cmp(lb) {
488            Ordering::Less => -1,
489            Ordering::Equal => 0,
490            Ordering::Greater => 1,
491        },
492
493        (ValueKind::Date(la), ValueKind::Date(lb)) => match la.cmp(lb) {
494            Ordering::Less => -1,
495            Ordering::Equal => 0,
496            Ordering::Greater => 1,
497        },
498
499        (ValueKind::Time(la), ValueKind::Time(lb)) => match la.cmp(lb) {
500            Ordering::Less => -1,
501            Ordering::Equal => 0,
502            Ordering::Greater => 1,
503        },
504
505        (ValueKind::Quantity(la, sig_a), ValueKind::Quantity(lb, sig_b))
506            if a.lemma_type.is_calendar_like() && b.lemma_type.is_calendar_like() =>
507        {
508            let lu =
509                crate::planning::semantics::semantic_calendar_unit_from_quantity_signature(sig_a);
510            let lu2 =
511                crate::planning::semantics::semantic_calendar_unit_from_quantity_signature(sig_b);
512            let a_months = crate::computation::units::convert_calendar_magnitude(
513                *la,
514                &lu,
515                &crate::planning::semantics::SemanticCalendarUnit::Month,
516            );
517            let b_months = crate::computation::units::convert_calendar_magnitude(
518                *lb,
519                &lu2,
520                &crate::planning::semantics::SemanticCalendarUnit::Month,
521            );
522            match a_months.cmp(&b_months) {
523                Ordering::Less => -1,
524                Ordering::Equal => 0,
525                Ordering::Greater => 1,
526            }
527        }
528
529        (ValueKind::Ratio(la, _), ValueKind::Ratio(lb, _)) => match la.cmp(lb) {
530            Ordering::Less => -1,
531            Ordering::Equal => 0,
532            Ordering::Greater => 1,
533        },
534
535        (ValueKind::Quantity(la, _), ValueKind::Quantity(lb, _)) => {
536            let a_decomp = a
537                .lemma_type
538                .quantity_type_decomposition()
539                .expect("BUG: decomposition must be resolved after planning");
540            let b_decomp = b
541                .lemma_type
542                .quantity_type_decomposition()
543                .expect("BUG: decomposition must be resolved after planning");
544            if a_decomp != b_decomp {
545                unreachable!(
546                    "BUG: lit_cmp compared quantities with different decompositions ({:?} vs {:?})",
547                    a_decomp, b_decomp
548                );
549            }
550            match la.cmp(lb) {
551                Ordering::Less => -1,
552                Ordering::Equal => 0,
553                Ordering::Greater => 1,
554            }
555        }
556
557        _ => unreachable!(
558            "BUG: lit_cmp cannot compare different literal kinds ({:?} vs {:?})",
559            a.get_type(),
560            b.get_type()
561        ),
562    }
563}
564
565fn value_within(v: &LiteralValue, min: &Bound, max: &Bound) -> bool {
566    let ge_min = match min {
567        Bound::Unbounded => true,
568        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) >= 0,
569        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) > 0,
570    };
571    let le_max = match max {
572        Bound::Unbounded => true,
573        Bound::Inclusive(m) => lit_cmp(v, m.as_ref()) <= 0,
574        Bound::Exclusive(m) => lit_cmp(v, m.as_ref()) < 0,
575    };
576    ge_min && le_max
577}
578
579fn bounds_contradict(min: &Bound, max: &Bound) -> bool {
580    match (min, max) {
581        (Bound::Unbounded, _) | (_, Bound::Unbounded) => false,
582        (Bound::Inclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) > 0,
583        (Bound::Inclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
584        (Bound::Exclusive(a), Bound::Inclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
585        (Bound::Exclusive(a), Bound::Exclusive(b)) => lit_cmp(a.as_ref(), b.as_ref()) >= 0,
586    }
587}
588
589fn compute_intersection_min(min1: Bound, min2: Bound) -> Bound {
590    match (min1, min2) {
591        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
592        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
593            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
594                Bound::Inclusive(v1)
595            } else {
596                Bound::Inclusive(v2)
597            }
598        }
599        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
600            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
601                Bound::Inclusive(v1)
602            } else {
603                Bound::Exclusive(v2)
604            }
605        }
606        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
607            if lit_cmp(v1.as_ref(), v2.as_ref()) > 0 {
608                Bound::Exclusive(v1)
609            } else {
610                Bound::Inclusive(v2)
611            }
612        }
613        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
614            if lit_cmp(v1.as_ref(), v2.as_ref()) >= 0 {
615                Bound::Exclusive(v1)
616            } else {
617                Bound::Exclusive(v2)
618            }
619        }
620    }
621}
622
623fn compute_intersection_max(max1: Bound, max2: Bound) -> Bound {
624    match (max1, max2) {
625        (Bound::Unbounded, x) | (x, Bound::Unbounded) => x,
626        (Bound::Inclusive(v1), Bound::Inclusive(v2)) => {
627            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
628                Bound::Inclusive(v1)
629            } else {
630                Bound::Inclusive(v2)
631            }
632        }
633        (Bound::Inclusive(v1), Bound::Exclusive(v2)) => {
634            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
635                Bound::Inclusive(v1)
636            } else {
637                Bound::Exclusive(v2)
638            }
639        }
640        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => {
641            if lit_cmp(v1.as_ref(), v2.as_ref()) < 0 {
642                Bound::Exclusive(v1)
643            } else {
644                Bound::Inclusive(v2)
645            }
646        }
647        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => {
648            if lit_cmp(v1.as_ref(), v2.as_ref()) <= 0 {
649                Bound::Exclusive(v1)
650            } else {
651                Bound::Exclusive(v2)
652            }
653        }
654    }
655}
656
657fn domain_intersection(a: Domain, b: Domain) -> Option<Domain> {
658    let a = normalize_domain(a);
659    let b = normalize_domain(b);
660
661    let result = match (a, b) {
662        (Domain::Unconstrained, d) | (d, Domain::Unconstrained) => Some(d),
663        (Domain::Empty, _) | (_, Domain::Empty) => None,
664
665        (
666            Domain::Range {
667                min: min1,
668                max: max1,
669            },
670            Domain::Range {
671                min: min2,
672                max: max2,
673            },
674        ) => {
675            let min = compute_intersection_min(min1, min2);
676            let max = compute_intersection_max(max1, max2);
677
678            if bounds_contradict(&min, &max) {
679                None
680            } else {
681                Some(Domain::Range { min, max })
682            }
683        }
684        (Domain::Enumeration(v1), Domain::Enumeration(v2)) => {
685            let filtered: Vec<LiteralValue> =
686                v1.iter().filter(|x| v2.contains(x)).cloned().collect();
687            if filtered.is_empty() {
688                None
689            } else {
690                Some(Domain::Enumeration(Arc::new(filtered)))
691            }
692        }
693        (Domain::Enumeration(vs), Domain::Range { min, max })
694        | (Domain::Range { min, max }, Domain::Enumeration(vs)) => {
695            let mut kept = Vec::new();
696            for v in vs.iter() {
697                if value_within(v, &min, &max) {
698                    kept.push(v.clone());
699                }
700            }
701            if kept.is_empty() {
702                None
703            } else {
704                Some(Domain::Enumeration(Arc::new(kept)))
705            }
706        }
707        (Domain::Enumeration(vs), Domain::Complement(inner))
708        | (Domain::Complement(inner), Domain::Enumeration(vs)) => {
709            match *inner.clone() {
710                Domain::Enumeration(excluded) => {
711                    let mut kept = Vec::new();
712                    for v in vs.iter() {
713                        if !excluded.contains(v) {
714                            kept.push(v.clone());
715                        }
716                    }
717                    if kept.is_empty() {
718                        None
719                    } else {
720                        Some(Domain::Enumeration(Arc::new(kept)))
721                    }
722                }
723                Domain::Range { min, max } => {
724                    // Filter enumeration values that are NOT in the range
725                    let mut kept = Vec::new();
726                    for v in vs.iter() {
727                        if !value_within(v, &min, &max) {
728                            kept.push(v.clone());
729                        }
730                    }
731                    if kept.is_empty() {
732                        None
733                    } else {
734                        Some(Domain::Enumeration(Arc::new(kept)))
735                    }
736                }
737                _ => {
738                    // For other complement types, normalize and recurse
739                    let normalized = normalize_domain(Domain::Complement(Box::new(*inner)));
740                    domain_intersection(Domain::Enumeration(vs.clone()), normalized)
741                }
742            }
743        }
744        (Domain::Union(v1), Domain::Union(v2)) => {
745            let mut acc: Vec<Domain> = Vec::new();
746            for a in v1.iter() {
747                for b in v2.iter() {
748                    if let Some(ix) = domain_intersection(a.clone(), b.clone()) {
749                        acc.push(ix);
750                    }
751                }
752            }
753            if acc.is_empty() {
754                None
755            } else {
756                Some(Domain::Union(Arc::new(acc)))
757            }
758        }
759        (Domain::Union(vs), d) | (d, Domain::Union(vs)) => {
760            let mut acc: Vec<Domain> = Vec::new();
761            for a in vs.iter() {
762                if let Some(ix) = domain_intersection(a.clone(), d.clone()) {
763                    acc.push(ix);
764                }
765            }
766            if acc.is_empty() {
767                None
768            } else if acc.len() == 1 {
769                Some(acc.remove(0))
770            } else {
771                Some(Domain::Union(Arc::new(acc)))
772            }
773        }
774        // Range ∩ not({p1,p2,...})  =>  Range with excluded points removed (as union of ranges)
775        (Domain::Range { min, max }, Domain::Complement(inner))
776        | (Domain::Complement(inner), Domain::Range { min, max }) => match inner.as_ref() {
777            Domain::Enumeration(excluded) => range_minus_excluded_points(min, max, excluded),
778            _ => {
779                // Normalize the complement (not just the inner value) and recurse.
780                // If normalization doesn't change it, we must not recurse infinitely.
781                let normalized_complement = normalize_domain(Domain::Complement(inner));
782                if matches!(&normalized_complement, Domain::Complement(_)) {
783                    None
784                } else {
785                    domain_intersection(Domain::Range { min, max }, normalized_complement)
786                }
787            }
788        },
789        (Domain::Complement(a_inner), Domain::Complement(b_inner)) => {
790            match (a_inner.as_ref(), b_inner.as_ref()) {
791                (Domain::Enumeration(a_ex), Domain::Enumeration(b_ex)) => {
792                    // not(A) ∩ not(B) == not(A ∪ B)
793                    let mut excluded: Vec<LiteralValue> = a_ex.iter().cloned().collect();
794                    excluded.extend(b_ex.iter().cloned());
795                    Some(normalize_domain(Domain::Complement(Box::new(
796                        Domain::Enumeration(Arc::new(excluded)),
797                    ))))
798                }
799                _ => None,
800            }
801        }
802    };
803    result.map(normalize_domain)
804}
805
806fn range_minus_excluded_points(
807    min: Bound,
808    max: Bound,
809    excluded: &Arc<Vec<LiteralValue>>,
810) -> Option<Domain> {
811    // Start with a single range and iteratively split on excluded points that fall within it.
812    let mut parts: Vec<(Bound, Bound)> = vec![(min, max)];
813
814    for p in excluded.iter() {
815        let mut next: Vec<(Bound, Bound)> = Vec::new();
816
817        for (rmin, rmax) in parts {
818            if !value_within(p, &rmin, &rmax) {
819                next.push((rmin, rmax));
820                continue;
821            }
822
823            // Left part: [rmin, p) or [rmin, p] depending on rmin and exclusion
824            let left_max = Bound::Exclusive(Arc::new(p.clone()));
825            if !bounds_contradict(&rmin, &left_max) {
826                next.push((rmin.clone(), left_max));
827            }
828
829            // Right part: (p, rmax)
830            let right_min = Bound::Exclusive(Arc::new(p.clone()));
831            if !bounds_contradict(&right_min, &rmax) {
832                next.push((right_min, rmax.clone()));
833            }
834        }
835
836        parts = next;
837        if parts.is_empty() {
838            return None;
839        }
840    }
841
842    if parts.is_empty() {
843        None
844    } else if parts.len() == 1 {
845        let (min, max) = parts.remove(0);
846        Some(Domain::Range { min, max })
847    } else {
848        Some(Domain::Union(Arc::new(
849            parts
850                .into_iter()
851                .map(|(min, max)| Domain::Range { min, max })
852                .collect(),
853        )))
854    }
855}
856
857fn invert_bound(bound: Bound) -> Bound {
858    match bound {
859        Bound::Unbounded => Bound::Unbounded,
860        Bound::Inclusive(v) => Bound::Exclusive(v.clone()),
861        Bound::Exclusive(v) => Bound::Inclusive(v.clone()),
862    }
863}
864
865fn normalize_domain(d: Domain) -> Domain {
866    match d {
867        Domain::Complement(inner) => {
868            let normalized_inner = normalize_domain(*inner);
869            match normalized_inner {
870                Domain::Complement(double_inner) => *double_inner,
871                Domain::Range { min, max } => match (&min, &max) {
872                    (Bound::Unbounded, Bound::Unbounded) => Domain::Enumeration(Arc::new(vec![])),
873                    (Bound::Unbounded, max) => Domain::Range {
874                        min: invert_bound(max.clone()),
875                        max: Bound::Unbounded,
876                    },
877                    (min, Bound::Unbounded) => Domain::Range {
878                        min: Bound::Unbounded,
879                        max: invert_bound(min.clone()),
880                    },
881                    (min, max) => Domain::Union(Arc::new(vec![
882                        Domain::Range {
883                            min: Bound::Unbounded,
884                            max: invert_bound(min.clone()),
885                        },
886                        Domain::Range {
887                            min: invert_bound(max.clone()),
888                            max: Bound::Unbounded,
889                        },
890                    ])),
891                },
892                Domain::Enumeration(vals) => {
893                    if vals.len() == 1 {
894                        if let Some(lit) = vals.first() {
895                            if let ValueKind::Boolean(true) = &lit.value {
896                                return Domain::Enumeration(Arc::new(vec![
897                                    LiteralValue::from_bool(false),
898                                ]));
899                            }
900                            if let ValueKind::Boolean(false) = &lit.value {
901                                return Domain::Enumeration(Arc::new(vec![
902                                    LiteralValue::from_bool(true),
903                                ]));
904                            }
905                        }
906                    }
907                    Domain::Complement(Box::new(Domain::Enumeration(vals.clone())))
908                }
909                Domain::Unconstrained => Domain::Empty,
910                Domain::Empty => Domain::Unconstrained,
911                Domain::Union(parts) => Domain::Complement(Box::new(Domain::Union(parts.clone()))),
912            }
913        }
914        Domain::Empty => Domain::Empty,
915        Domain::Union(parts) => {
916            let mut flat: Vec<Domain> = Vec::new();
917            for p in parts.iter().cloned() {
918                let normalized = normalize_domain(p);
919                match normalized {
920                    Domain::Union(inner) => flat.extend(inner.iter().cloned()),
921                    Domain::Unconstrained => return Domain::Unconstrained,
922                    Domain::Enumeration(vals) if vals.is_empty() => {}
923                    other => flat.push(other),
924                }
925            }
926
927            let mut all_enum_values: Vec<LiteralValue> = Vec::new();
928            let mut ranges: Vec<Domain> = Vec::new();
929            let mut others: Vec<Domain> = Vec::new();
930
931            for domain in flat {
932                match domain {
933                    Domain::Enumeration(vals) => all_enum_values.extend(vals.iter().cloned()),
934                    Domain::Range { .. } => ranges.push(domain),
935                    other => others.push(other),
936                }
937            }
938
939            all_enum_values.sort_by(|a, b| match lit_cmp(a, b) {
940                -1 => Ordering::Less,
941                0 => Ordering::Equal,
942                _ => Ordering::Greater,
943            });
944            all_enum_values.dedup();
945
946            all_enum_values.retain(|v| {
947                !ranges.iter().any(|r| {
948                    if let Domain::Range { min, max } = r {
949                        value_within(v, min, max)
950                    } else {
951                        false
952                    }
953                })
954            });
955
956            let mut result: Vec<Domain> = Vec::new();
957            result.extend(ranges);
958            result = merge_ranges(result);
959
960            if !all_enum_values.is_empty() {
961                result.push(Domain::Enumeration(Arc::new(all_enum_values)));
962            }
963            result.extend(others);
964
965            result.sort_by(|a, b| match (a, b) {
966                (Domain::Range { .. }, Domain::Range { .. }) => Ordering::Equal,
967                (Domain::Range { .. }, _) => Ordering::Less,
968                (_, Domain::Range { .. }) => Ordering::Greater,
969                (Domain::Enumeration(_), Domain::Enumeration(_)) => Ordering::Equal,
970                (Domain::Enumeration(_), _) => Ordering::Less,
971                (_, Domain::Enumeration(_)) => Ordering::Greater,
972                _ => Ordering::Equal,
973            });
974
975            if result.is_empty() {
976                Domain::Enumeration(Arc::new(vec![]))
977            } else if result.len() == 1 {
978                result.remove(0)
979            } else {
980                Domain::Union(Arc::new(result))
981            }
982        }
983        Domain::Enumeration(values) => {
984            let mut sorted: Vec<LiteralValue> = values.iter().cloned().collect();
985            sorted.sort_by(|a, b| match lit_cmp(a, b) {
986                -1 => Ordering::Less,
987                0 => Ordering::Equal,
988                _ => Ordering::Greater,
989            });
990            sorted.dedup();
991            Domain::Enumeration(Arc::new(sorted))
992        }
993        other => other,
994    }
995}
996
997fn merge_ranges(domains: Vec<Domain>) -> Vec<Domain> {
998    let mut result = Vec::new();
999    let mut ranges: Vec<(Bound, Bound)> = Vec::new();
1000    let mut others = Vec::new();
1001
1002    for d in domains {
1003        match d {
1004            Domain::Range { min, max } => ranges.push((min, max)),
1005            other => others.push(other),
1006        }
1007    }
1008
1009    if ranges.is_empty() {
1010        return others;
1011    }
1012
1013    ranges.sort_by(|a, b| compare_bounds(&a.0, &b.0));
1014
1015    let mut merged: Vec<(Bound, Bound)> = Vec::new();
1016    let mut current = ranges[0].clone();
1017
1018    for next in ranges.iter().skip(1) {
1019        if ranges_adjacent_or_overlap(&current, next) {
1020            current = (
1021                min_bound(&current.0, &next.0),
1022                max_bound(&current.1, &next.1),
1023            );
1024        } else {
1025            merged.push(current);
1026            current = next.clone();
1027        }
1028    }
1029    merged.push(current);
1030
1031    for (min, max) in merged {
1032        result.push(Domain::Range { min, max });
1033    }
1034    result.extend(others);
1035
1036    result
1037}
1038
1039fn compare_bounds(a: &Bound, b: &Bound) -> Ordering {
1040    match (a, b) {
1041        (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal,
1042        (Bound::Unbounded, _) => Ordering::Less,
1043        (_, Bound::Unbounded) => Ordering::Greater,
1044        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1045        | (Bound::Exclusive(v1), Bound::Exclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1046            -1 => Ordering::Less,
1047            0 => Ordering::Equal,
1048            _ => Ordering::Greater,
1049        },
1050        (Bound::Inclusive(v1), Bound::Exclusive(v2))
1051        | (Bound::Exclusive(v1), Bound::Inclusive(v2)) => match lit_cmp(v1.as_ref(), v2.as_ref()) {
1052            -1 => Ordering::Less,
1053            0 => {
1054                if matches!(a, Bound::Inclusive(_)) {
1055                    Ordering::Less
1056                } else {
1057                    Ordering::Greater
1058                }
1059            }
1060            _ => Ordering::Greater,
1061        },
1062    }
1063}
1064
1065fn ranges_adjacent_or_overlap(r1: &(Bound, Bound), r2: &(Bound, Bound)) -> bool {
1066    match (&r1.1, &r2.0) {
1067        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
1068        (Bound::Inclusive(v1), Bound::Inclusive(v2))
1069        | (Bound::Inclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1070        (Bound::Exclusive(v1), Bound::Inclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) >= 0,
1071        (Bound::Exclusive(v1), Bound::Exclusive(v2)) => lit_cmp(v1.as_ref(), v2.as_ref()) > 0,
1072    }
1073}
1074
1075fn min_bound(a: &Bound, b: &Bound) -> Bound {
1076    match (a, b) {
1077        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1078        _ => {
1079            if matches!(compare_bounds(a, b), Ordering::Less | Ordering::Equal) {
1080                a.clone()
1081            } else {
1082                b.clone()
1083            }
1084        }
1085    }
1086}
1087
1088fn max_bound(a: &Bound, b: &Bound) -> Bound {
1089    match (a, b) {
1090        (Bound::Unbounded, _) | (_, Bound::Unbounded) => Bound::Unbounded,
1091        _ => {
1092            if matches!(compare_bounds(a, b), Ordering::Greater) {
1093                a.clone()
1094            } else {
1095                b.clone()
1096            }
1097        }
1098    }
1099}
1100
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the number literal `n / 1`.
    fn num(n: i64) -> LiteralValue {
        let rational = crate::computation::rational::RationalInteger::new(i128::from(n), 1);
        LiteralValue::number(rational)
    }

    /// Builds a data path with no segments and the given name.
    fn data(name: &str) -> DataPath {
        DataPath::new(Vec::new(), name.to_owned())
    }

    /// Inclusive bound on the number `n`.
    fn inclusive(n: i64) -> Bound {
        Bound::Inclusive(Arc::new(num(n)))
    }

    /// Exclusive bound on the number `n`.
    fn exclusive(n: i64) -> Bound {
        Bound::Exclusive(Arc::new(num(n)))
    }

    /// Comparison constraint `name <op> value`.
    fn comparison(name: &str, op: ComparisonComputation, value: LiteralValue) -> Constraint {
        Constraint::Comparison {
            data: data(name),
            op,
            value: Arc::new(value),
        }
    }

    /// Extracts the domains of `constraint` and returns the one stored for `name`.
    fn domain_of(constraint: &Constraint, name: &str) -> Domain {
        let domains = extract_domains_from_constraint(constraint).unwrap();
        domains.get(&data(name)).unwrap().clone()
    }

    /// Enumeration holding exactly one value.
    fn only(value: LiteralValue) -> Domain {
        Domain::Enumeration(Arc::new(vec![value]))
    }

    #[test]
    fn test_normalize_double_complement() {
        let inner = only(num(5));
        let wrapped_once = Domain::Complement(Box::new(inner.clone()));
        let wrapped_twice = Domain::Complement(Box::new(wrapped_once));

        assert_eq!(normalize_domain(wrapped_twice), inner);
    }

    #[test]
    fn test_normalize_union_absorbs_unconstrained() {
        let bounded = Domain::Range {
            min: inclusive(0),
            max: inclusive(10),
        };
        let union = Domain::Union(Arc::new(vec![bounded, Domain::Unconstrained]));

        assert_eq!(normalize_domain(union), Domain::Unconstrained);
    }

    #[test]
    fn test_domain_display() {
        let range = Domain::Range {
            min: inclusive(10),
            max: exclusive(20),
        };
        assert_eq!(range.to_string(), "[10, 20)");

        let enumeration = Domain::Enumeration(Arc::new(vec![num(1), num(2), num(3)]));
        assert_eq!(enumeration.to_string(), "{1, 2, 3}");
    }

    #[test]
    fn test_extract_domain_from_comparison() {
        let constraint = comparison("age", ComparisonComputation::GreaterThan, num(18));

        let expected = Domain::Range {
            min: exclusive(18),
            max: Bound::Unbounded,
        };
        assert_eq!(domain_of(&constraint, "age"), expected);
    }

    #[test]
    fn test_extract_domain_from_and() {
        let older_than_18 = comparison("age", ComparisonComputation::GreaterThan, num(18));
        let younger_than_65 = comparison("age", ComparisonComputation::LessThan, num(65));
        let constraint = Constraint::And(Box::new(older_than_18), Box::new(younger_than_65));

        let expected = Domain::Range {
            min: exclusive(18),
            max: exclusive(65),
        };
        assert_eq!(domain_of(&constraint, "age"), expected);
    }

    #[test]
    fn test_extract_domain_from_equality() {
        let constraint = comparison(
            "status",
            ComparisonComputation::Is,
            LiteralValue::text("active".to_string()),
        );

        assert_eq!(
            domain_of(&constraint, "status"),
            only(LiteralValue::text("active".to_string()))
        );
    }

    #[test]
    fn test_extract_domain_from_boolean_data() {
        let constraint = Constraint::Data(data("is_active"));

        assert_eq!(
            domain_of(&constraint, "is_active"),
            only(LiteralValue::from_bool(true))
        );
    }

    #[test]
    fn test_extract_domain_from_not_boolean_data() {
        let constraint = Constraint::Not(Box::new(Constraint::Data(data("is_active"))));

        assert_eq!(
            domain_of(&constraint, "is_active"),
            only(LiteralValue::from_bool(false))
        );
    }

    /// `lit_cmp` must order two quantity values that are dimensionally
    /// compatible but were produced through different code paths (a compound
    /// expression converted to `eur` versus a plain `eur` literal) and return
    /// the correct ordering rather than panic.
    ///
    /// Both values are obtained by loading and running a spec through the
    /// engine, so the test sees the values the real pipeline produces.
    #[test]
    fn inversion_lit_cmp_uses_signature_factor_for_compatible_dimensions() {
        use crate::engine::Engine;
        use crate::parsing::source::SourceType;
        use std::collections::HashMap;
        use std::path::PathBuf;
        use std::sync::Arc as StdArc;

        // compound_cost: 40 eur/minute * 2 hour, converted to eur.
        // fixed_cost:    1000 eur.
        let code = r#"spec lit_cmp_test
uses lemma units
data money: quantity
  -> unit eur 1
data rate: quantity
  -> unit eur_per_minute eur/minute
data r: 40 eur_per_minute
data h: 2 hour
rule compound_cost: (r * h) as eur
rule fixed_cost: 1000 eur
"#;

        let source = SourceType::Path(StdArc::new(PathBuf::from("t.lemma")));
        let mut engine = Engine::new();
        engine.load(code, source).expect("must load");

        let response = engine
            .run(None, "lit_cmp_test", None, HashMap::new(), true)
            .expect("must eval");

        // Pulls the evaluated value of a rule out of its trace.
        let rule_value = |rule: &str| {
            response
                .results
                .get(rule)
                .unwrap()
                .trace
                .as_ref()
                .expect("trace")
                .result
                .value()
                .unwrap()
                .clone()
        };
        let compound = rule_value("compound_cost");
        let fixed = rule_value("fixed_cost");

        // 40 eur/minute over 120 minutes = 4800 eur > 1000 eur
        let ordering = lit_cmp(&compound, &fixed);
        assert!(
            ordering > 0,
            "compound (4800 eur) must be greater than fixed (1000 eur); got {}",
            ordering
        );
    }
}