Skip to main content

oxiz_theories/set/
solver.rs

1//! Set Constraint Solver
2//!
3//! Core implementation of the set theory solver using:
4//! - BDD-based set representation for symbolic sets
5//! - Constraint propagation for membership and subset
6//! - Conflict-driven reasoning for unsatisfiability
7
8#![allow(missing_docs)]
9
10use super::{
11    CardConstraint, CardConstraintKind, CardPropagator, MemberConstraint, MemberPropagator,
12    SetConflict, SetLiteral, SetProofStep, SetSort, SubsetConstraint, SubsetPropagator,
13};
14use crate::theory::{
15    EqualityNotification, Theory, TheoryCombination, TheoryId, TheoryResult as TR,
16};
17use oxiz_core::ast::TermId;
18use oxiz_core::error::Result;
19use rustc_hash::{FxHashMap, FxHashSet};
20use smallvec::SmallVec;
21use std::collections::VecDeque;
22
/// Opaque handle identifying a set variable inside the solver.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct SetVarId(pub u32);

impl SetVarId {
    /// Wrap a raw index into a `SetVarId`.
    pub fn new(id: u32) -> Self {
        SetVarId(id)
    }

    /// Return the raw index wrapped by this handle.
    pub fn id(&self) -> u32 {
        let SetVarId(raw) = *self;
        raw
    }
}
38
39/// Set variable representation
40#[derive(Debug, Clone)]
41pub struct SetVar {
42    /// Variable ID
43    pub id: SetVarId,
44    /// Variable name (for debugging)
45    pub name: String,
46    /// Sort of this set
47    pub sort: SetSort,
48    /// Current domain: known members
49    pub must_members: FxHashSet<u32>,
50    /// Current domain: known non-members
51    pub must_not_members: FxHashSet<u32>,
52    /// Possible members (for finite domain reasoning)
53    pub may_members: Option<FxHashSet<u32>>,
54    /// Cardinality bounds [lower, upper]
55    pub card_bounds: (Option<i64>, Option<i64>),
56    /// Is this set known to be empty?
57    pub is_empty: bool,
58    /// Is this set the universal set?
59    pub is_universal: bool,
60    /// Decision level when this variable was created
61    pub level: usize,
62}
63
64impl SetVar {
65    /// Create a new set variable
66    pub fn new(id: SetVarId, name: String, sort: SetSort, level: usize) -> Self {
67        Self {
68            id,
69            name,
70            sort,
71            must_members: FxHashSet::default(),
72            must_not_members: FxHashSet::default(),
73            may_members: None,
74            card_bounds: (Some(0), None),
75            is_empty: false,
76            is_universal: false,
77            level,
78        }
79    }
80
81    /// Add a must-member element
82    pub fn add_must_member(&mut self, elem: u32) -> bool {
83        if self.must_not_members.contains(&elem) {
84            return false; // Conflict
85        }
86        self.must_members.insert(elem);
87        true // Success (either newly inserted or already there)
88    }
89
90    /// Add a must-not-member element
91    pub fn add_must_not_member(&mut self, elem: u32) -> bool {
92        if self.must_members.contains(&elem) {
93            return false; // Conflict
94        }
95        self.must_not_members.insert(elem);
96        true // Success (either newly inserted or already there)
97    }
98
99    /// Check if element is definitely in the set
100    pub fn contains(&self, elem: u32) -> Option<bool> {
101        if self.must_members.contains(&elem) {
102            Some(true)
103        } else if self.must_not_members.contains(&elem) {
104            Some(false)
105        } else {
106            None
107        }
108    }
109
110    /// Get current cardinality bounds
111    pub fn cardinality_bounds(&self) -> (i64, Option<i64>) {
112        let lower = self.must_members.len() as i64;
113        let upper = if let Some(may) = &self.may_members {
114            Some(may.len() as i64)
115        } else {
116            self.card_bounds.1
117        };
118        (
119            lower.max(self.card_bounds.0.unwrap_or(0)),
120            match (upper, self.card_bounds.1) {
121                (Some(a), Some(b)) => Some(a.min(b)),
122                (Some(a), None) => Some(a),
123                (None, b) => b,
124            },
125        )
126    }
127
128    /// Check if cardinality is determined
129    pub fn cardinality_determined(&self) -> Option<i64> {
130        let (lower, upper) = self.cardinality_bounds();
131        if let Some(u) = upper
132            && lower == u
133        {
134            return Some(lower);
135        }
136        None
137    }
138
139    /// Mark set as empty
140    pub fn set_empty(&mut self) -> bool {
141        if !self.must_members.is_empty() {
142            return false; // Conflict
143        }
144        self.is_empty = true;
145        self.card_bounds.1 = Some(0);
146        true
147    }
148
149    /// Check if set is definitely empty
150    pub fn is_definitely_empty(&self) -> bool {
151        self.is_empty
152            || self.card_bounds.1 == Some(0)
153            || (self.may_members.as_ref().is_some_and(|m| m.is_empty()))
154    }
155
156    /// Tighten upper cardinality bound
157    pub fn tighten_upper_card(&mut self, bound: i64) -> bool {
158        match self.card_bounds.1 {
159            Some(current) if bound >= current => true,
160            _ => {
161                if bound < self.must_members.len() as i64 {
162                    return false; // Conflict
163                }
164                // Check conflict with lower bound
165                if let Some(lower) = self.card_bounds.0
166                    && bound < lower
167                {
168                    return false; // Conflict: upper < lower
169                }
170                self.card_bounds.1 = Some(bound);
171                true
172            }
173        }
174    }
175
176    /// Tighten lower cardinality bound
177    pub fn tighten_lower_card(&mut self, bound: i64) -> bool {
178        match self.card_bounds.0 {
179            Some(current) if bound <= current => true,
180            _ => {
181                if let Some(upper) = self.card_bounds.1
182                    && bound > upper
183                {
184                    return false; // Conflict
185                }
186                self.card_bounds.0 = Some(bound);
187                true
188            }
189        }
190    }
191}
192
193/// Set expression for constraints
194#[derive(Debug, Clone)]
195pub enum SetExpr {
196    /// Variable reference
197    Var(SetVarId),
198    /// Empty set
199    Empty,
200    /// Universal set
201    Universal,
202    /// Singleton set {elem}
203    Singleton(u32),
204    /// Union S1 ∪ S2
205    Union(Box<SetExpr>, Box<SetExpr>),
206    /// Intersection S1 ∩ S2
207    Intersection(Box<SetExpr>, Box<SetExpr>),
208    /// Difference S1 \ S2
209    Difference(Box<SetExpr>, Box<SetExpr>),
210    /// Complement ¬S
211    Complement(Box<SetExpr>),
212    /// Set comprehension {x | φ(x)}
213    Comprehension { var: u32, formula: Box<SetExpr> },
214}
215
216impl SetExpr {
217    /// Create a union expression
218    pub fn union(left: SetExpr, right: SetExpr) -> Self {
219        SetExpr::Union(Box::new(left), Box::new(right))
220    }
221
222    /// Create an intersection expression
223    pub fn intersection(left: SetExpr, right: SetExpr) -> Self {
224        SetExpr::Intersection(Box::new(left), Box::new(right))
225    }
226
227    /// Create a difference expression
228    pub fn difference(left: SetExpr, right: SetExpr) -> Self {
229        SetExpr::Difference(Box::new(left), Box::new(right))
230    }
231
232    /// Create a complement expression
233    pub fn complement(expr: SetExpr) -> Self {
234        SetExpr::Complement(Box::new(expr))
235    }
236
237    /// Get all set variables referenced in this expression
238    pub fn get_vars(&self) -> FxHashSet<SetVarId> {
239        let mut vars = FxHashSet::default();
240        self.collect_vars(&mut vars);
241        vars
242    }
243
244    fn collect_vars(&self, vars: &mut FxHashSet<SetVarId>) {
245        match self {
246            SetExpr::Var(v) => {
247                vars.insert(*v);
248            }
249            SetExpr::Union(l, r) | SetExpr::Intersection(l, r) | SetExpr::Difference(l, r) => {
250                l.collect_vars(vars);
251                r.collect_vars(vars);
252            }
253            SetExpr::Complement(e) => e.collect_vars(vars),
254            SetExpr::Comprehension { formula, .. } => formula.collect_vars(vars),
255            _ => {}
256        }
257    }
258}
259
260/// Set constraint
261#[derive(Debug, Clone)]
262pub enum SetConstraint {
263    /// x ∈ S
264    Member {
265        element: u32,
266        set: SetExpr,
267        sign: bool,
268    },
269    /// S1 ⊆ S2
270    Subset {
271        lhs: SetExpr,
272        rhs: SetExpr,
273        sign: bool,
274    },
275    /// S1 = S2
276    Equal { lhs: SetExpr, rhs: SetExpr },
277    /// |S| op k
278    Cardinality {
279        set: SetExpr,
280        op: CardConstraintKind,
281        bound: i64,
282    },
283    /// S1 ∩ S2 = ∅ (disjoint)
284    Disjoint { lhs: SetExpr, rhs: SetExpr },
285}
286
/// Tunable options of the set solver.
#[derive(Debug, Clone)]
pub struct SetConfig {
    /// Turn on aggressive propagation.
    pub aggressive_propagation: bool,
    /// Largest cardinality considered for finite-domain reasoning, if any.
    pub max_finite_card: Option<usize>,
    /// Turn on the BDD-based set representation.
    pub use_bdd: bool,
    /// Turn on conflict clause minimization.
    pub minimize_conflicts: bool,
    /// Turn on subset closure computation.
    pub compute_subset_closure: bool,
}

impl Default for SetConfig {
    /// All features enabled, with finite-domain reasoning capped at 1000.
    fn default() -> Self {
        let max_finite_card = Some(1000);
        SetConfig {
            aggressive_propagation: true,
            max_finite_card,
            use_bdd: true,
            minimize_conflicts: true,
            compute_subset_closure: true,
        }
    }
}
313
/// Counters collected by the set solver; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct SetStats {
    /// Set variables created.
    pub num_vars: usize,
    /// Constraints submitted.
    pub num_constraints: usize,
    /// Membership constraints processed.
    pub num_member_constraints: usize,
    /// Subset constraints processed.
    pub num_subset_constraints: usize,
    /// Cardinality constraints processed.
    pub num_card_constraints: usize,
    /// Propagation steps executed.
    pub num_propagations: usize,
    /// Conflicts encountered.
    pub num_conflicts: usize,
    /// Backtracks performed.
    pub num_backtracks: usize,
}
334
335/// Set solver result
336pub type SetResult<T> = std::result::Result<T, SetConflict>;
337
/// Sizes of the solver's collections captured by `push`, so that `pop` can
/// truncate them back.
#[derive(Debug, Clone)]
struct SolverState {
    /// Length of the variable list.
    num_vars: usize,
    /// Length of the general constraint list.
    num_constraints: usize,
    /// Length of the membership constraint list.
    num_member_constraints: usize,
    /// Length of the subset constraint list.
    num_subset_constraints: usize,
    /// Length of the cardinality constraint list.
    num_card_constraints: usize,
}
347
348/// Main set theory solver
349pub struct SetSolver {
350    /// Configuration
351    #[allow(dead_code)]
352    config: SetConfig,
353    /// Set variables
354    vars: Vec<SetVar>,
355    /// Variable name to ID mapping
356    var_names: FxHashMap<String, SetVarId>,
357    /// Membership constraints
358    member_constraints: Vec<MemberConstraint>,
359    /// Subset constraints
360    subset_constraints: Vec<SubsetConstraint>,
361    /// Cardinality constraints
362    card_constraints: Vec<CardConstraint>,
363    /// General constraints
364    constraints: Vec<SetConstraint>,
365    /// Propagation queue
366    propagation_queue: VecDeque<SetVarId>,
367    /// Membership propagator
368    member_prop: MemberPropagator,
369    /// Subset propagator
370    subset_prop: SubsetPropagator,
371    /// Cardinality propagator
372    card_prop: CardPropagator,
373    /// Current decision level
374    level: usize,
375    /// Trail of assignments (for backtracking)
376    trail: Vec<TrailEntry>,
377    /// Decision level boundaries in trail
378    level_boundaries: Vec<usize>,
379    /// Statistics
380    stats: SetStats,
381    /// Context stack for push/pop
382    context_stack: Vec<SolverState>,
383    /// Conflict clause (if UNSAT)
384    conflict: Option<SetConflict>,
385    /// Term to set variable mapping (for theory integration)
386    term_to_var: FxHashMap<TermId, SetVarId>,
387    /// Set variable to term mapping
388    var_to_term: FxHashMap<SetVarId, TermId>,
389}
390
391/// Trail entry for backtracking
392#[derive(Debug, Clone)]
393enum TrailEntry {
394    /// Variable assignment
395    VarAssign {
396        var: SetVarId,
397        snapshot: SetVarSnapshot,
398    },
399    /// Decision level marker
400    #[allow(dead_code)]
401    DecisionLevel(usize),
402}
403
404/// Snapshot of a set variable for backtracking
405#[derive(Debug, Clone)]
406struct SetVarSnapshot {
407    must_members: FxHashSet<u32>,
408    must_not_members: FxHashSet<u32>,
409    may_members: Option<FxHashSet<u32>>,
410    card_bounds: (Option<i64>, Option<i64>),
411    is_empty: bool,
412    is_universal: bool,
413}
414
415impl From<&SetVar> for SetVarSnapshot {
416    fn from(var: &SetVar) -> Self {
417        Self {
418            must_members: var.must_members.clone(),
419            must_not_members: var.must_not_members.clone(),
420            may_members: var.may_members.clone(),
421            card_bounds: var.card_bounds,
422            is_empty: var.is_empty,
423            is_universal: var.is_universal,
424        }
425    }
426}
427
428impl SetSolver {
429    /// Create a new set solver
430    pub fn new() -> Self {
431        Self::with_config(SetConfig::default())
432    }
433
434    /// Create a new set solver with configuration
435    pub fn with_config(config: SetConfig) -> Self {
436        Self {
437            config,
438            vars: Vec::new(),
439            var_names: FxHashMap::default(),
440            member_constraints: Vec::new(),
441            subset_constraints: Vec::new(),
442            card_constraints: Vec::new(),
443            constraints: Vec::new(),
444            propagation_queue: VecDeque::new(),
445            member_prop: MemberPropagator::new(),
446            subset_prop: SubsetPropagator::new(),
447            card_prop: CardPropagator::new(),
448            level: 0,
449            trail: Vec::new(),
450            level_boundaries: Vec::new(),
451            stats: SetStats::default(),
452            context_stack: Vec::new(),
453            conflict: None,
454            term_to_var: FxHashMap::default(),
455            var_to_term: FxHashMap::default(),
456        }
457    }
458
459    /// Create a new set variable
460    pub fn new_set_var(&mut self, name: &str, sort: SetSort) -> SetVarId {
461        let id = SetVarId(self.vars.len() as u32);
462        let var = SetVar::new(id, name.to_string(), sort, self.level);
463        self.vars.push(var);
464        self.var_names.insert(name.to_string(), id);
465        self.stats.num_vars += 1;
466        id
467    }
468
469    /// Get a variable by ID
470    pub fn get_var(&self, id: SetVarId) -> Option<&SetVar> {
471        self.vars.get(id.0 as usize)
472    }
473
474    /// Get a mutable variable by ID
475    pub fn get_var_mut(&mut self, id: SetVarId) -> Option<&mut SetVar> {
476        self.vars.get_mut(id.0 as usize)
477    }
478
479    /// Get a variable by name
480    pub fn get_var_by_name(&self, name: &str) -> Option<&SetVar> {
481        self.var_names.get(name).and_then(|id| self.get_var(*id))
482    }
483
484    /// Add a constraint
485    pub fn add_constraint(&mut self, constraint: SetConstraint) -> SetResult<()> {
486        self.stats.num_constraints += 1;
487
488        match &constraint {
489            SetConstraint::Member { element, set, sign } => {
490                self.add_member_constraint(*element, set, *sign)?;
491            }
492            SetConstraint::Subset { lhs, rhs, sign } => {
493                self.add_subset_constraint(lhs, rhs, *sign)?;
494            }
495            SetConstraint::Equal { lhs, rhs } => {
496                self.add_equal_constraint(lhs, rhs)?;
497            }
498            SetConstraint::Cardinality { set, op, bound } => {
499                self.add_cardinality_constraint(set, *op, *bound)?;
500            }
501            SetConstraint::Disjoint { lhs, rhs } => {
502                self.add_disjoint_constraint(lhs, rhs)?;
503            }
504        }
505
506        self.constraints.push(constraint);
507        Ok(())
508    }
509
510    /// Add a membership constraint: elem ∈ set or elem ∉ set
511    fn add_member_constraint(&mut self, element: u32, set: &SetExpr, sign: bool) -> SetResult<()> {
512        self.stats.num_member_constraints += 1;
513
514        // Extract the set variable
515        let set_var = match set {
516            SetExpr::Var(v) => *v,
517            _ => {
518                // For complex expressions, create an auxiliary variable
519                let aux_var =
520                    self.new_set_var(&format!("aux_member_{}", self.vars.len()), SetSort::IntSet);
521                self.add_equal_constraint(&SetExpr::Var(aux_var), set)?;
522                aux_var
523            }
524        };
525
526        // Save snapshot
527        if let Some(var) = self.get_var(set_var) {
528            self.trail.push(TrailEntry::VarAssign {
529                var: set_var,
530                snapshot: SetVarSnapshot::from(var),
531            });
532        }
533
534        // Apply the constraint
535        if let Some(var) = self.get_var_mut(set_var) {
536            let success = if sign {
537                var.add_must_member(element)
538            } else {
539                var.add_must_not_member(element)
540            };
541
542            if !success {
543                return Err(SetConflict {
544                    literals: vec![SetLiteral::Member {
545                        element,
546                        set: set_var,
547                        sign,
548                    }],
549                    reason: format!(
550                        "Conflict: element {} is both in and not in set {}",
551                        element, var.name
552                    ),
553                    proof_steps: vec![SetProofStep::Assume(SetLiteral::Member {
554                        element,
555                        set: set_var,
556                        sign,
557                    })],
558                });
559            }
560
561            // Queue for propagation
562            self.propagation_queue.push_back(set_var);
563        }
564
565        Ok(())
566    }
567
568    /// Add a subset constraint: lhs ⊆ rhs or lhs ⊈ rhs
569    fn add_subset_constraint(&mut self, lhs: &SetExpr, rhs: &SetExpr, sign: bool) -> SetResult<()> {
570        self.stats.num_subset_constraints += 1;
571
572        // Extract variables
573        let lhs_var = self.extract_var(lhs)?;
574        let rhs_var = self.extract_var(rhs)?;
575
576        if sign {
577            // lhs ⊆ rhs: all elements in lhs must be in rhs
578            self.propagate_subset(lhs_var, rhs_var)?;
579        } else {
580            // lhs ⊈ rhs: there exists an element in lhs but not in rhs
581            // This is handled lazily during conflict analysis
582        }
583
584        Ok(())
585    }
586
587    /// Add an equality constraint: lhs = rhs
588    fn add_equal_constraint(&mut self, lhs: &SetExpr, rhs: &SetExpr) -> SetResult<()> {
589        // S1 = S2 is equivalent to S1 ⊆ S2 ∧ S2 ⊆ S1
590        self.add_subset_constraint(lhs, rhs, true)?;
591        self.add_subset_constraint(rhs, lhs, true)?;
592        Ok(())
593    }
594
595    /// Add a cardinality constraint: |set| op bound
596    fn add_cardinality_constraint(
597        &mut self,
598        set: &SetExpr,
599        op: CardConstraintKind,
600        bound: i64,
601    ) -> SetResult<()> {
602        self.stats.num_card_constraints += 1;
603
604        let set_var = self.extract_var(set)?;
605
606        // Save snapshot
607        if let Some(var) = self.get_var(set_var) {
608            self.trail.push(TrailEntry::VarAssign {
609                var: set_var,
610                snapshot: SetVarSnapshot::from(var),
611            });
612        }
613
614        // Apply cardinality bounds
615        if let Some(var) = self.get_var_mut(set_var) {
616            let success = match op {
617                CardConstraintKind::Equal => {
618                    var.tighten_lower_card(bound) && var.tighten_upper_card(bound)
619                }
620                CardConstraintKind::Le => var.tighten_upper_card(bound),
621                CardConstraintKind::Lt => var.tighten_upper_card(bound - 1),
622                CardConstraintKind::Ge => var.tighten_lower_card(bound),
623                CardConstraintKind::Gt => var.tighten_lower_card(bound + 1),
624            };
625
626            if !success {
627                return Err(SetConflict {
628                    literals: vec![SetLiteral::Cardinality {
629                        set: set_var,
630                        op,
631                        bound,
632                    }],
633                    reason: format!(
634                        "Conflict: cardinality constraint |{}| {:?} {} is unsatisfiable",
635                        var.name, op, bound
636                    ),
637                    proof_steps: vec![SetProofStep::Assume(SetLiteral::Cardinality {
638                        set: set_var,
639                        op,
640                        bound,
641                    })],
642                });
643            }
644
645            self.propagation_queue.push_back(set_var);
646        }
647
648        Ok(())
649    }
650
651    /// Add a disjoint constraint: lhs ∩ rhs = ∅
652    fn add_disjoint_constraint(&mut self, lhs: &SetExpr, rhs: &SetExpr) -> SetResult<()> {
653        let lhs_var = self.extract_var(lhs)?;
654        let rhs_var = self.extract_var(rhs)?;
655
656        // Propagate: if x ∈ lhs, then x ∉ rhs - collect members first to avoid borrow checker issues
657        let lhs_members: Vec<u32> = self
658            .get_var(lhs_var)
659            .map(|s| s.must_members.iter().copied().collect())
660            .unwrap_or_default();
661
662        for elem in lhs_members {
663            self.add_member_constraint(elem, &SetExpr::Var(rhs_var), false)?;
664        }
665
666        // Similarly for rhs
667        let rhs_members: Vec<u32> = self
668            .get_var(rhs_var)
669            .map(|s| s.must_members.iter().copied().collect())
670            .unwrap_or_default();
671
672        for elem in rhs_members {
673            self.add_member_constraint(elem, &SetExpr::Var(lhs_var), false)?;
674        }
675
676        Ok(())
677    }
678
679    /// Extract a set variable from an expression (creating auxiliary if needed)
680    fn extract_var(&mut self, expr: &SetExpr) -> SetResult<SetVarId> {
681        match expr {
682            SetExpr::Var(v) => Ok(*v),
683            SetExpr::Empty => {
684                let var = self.new_set_var(&format!("empty_{}", self.vars.len()), SetSort::IntSet);
685                if let Some(v) = self.get_var_mut(var) {
686                    v.set_empty();
687                }
688                Ok(var)
689            }
690            _ => {
691                // Create auxiliary variable for complex expressions
692                let var = self.new_set_var(&format!("aux_{}", self.vars.len()), SetSort::IntSet);
693                // TODO: Add constraints to define the auxiliary variable
694                Ok(var)
695            }
696        }
697    }
698
699    /// Propagate subset constraint: lhs ⊆ rhs
700    fn propagate_subset(&mut self, lhs: SetVarId, rhs: SetVarId) -> SetResult<()> {
701        // Get members that must be in lhs
702        let lhs_must_members: SmallVec<[u32; 16]> = if let Some(lhs_var) = self.get_var(lhs) {
703            lhs_var.must_members.iter().copied().collect()
704        } else {
705            return Ok(());
706        };
707
708        // They must all be in rhs
709        for elem in lhs_must_members {
710            if let Some(rhs_var) = self.get_var(rhs)
711                && rhs_var.must_not_members.contains(&elem)
712            {
713                return Err(SetConflict {
714                    literals: vec![
715                        SetLiteral::Subset {
716                            lhs,
717                            rhs,
718                            sign: true,
719                        },
720                        SetLiteral::Member {
721                            element: elem,
722                            set: lhs,
723                            sign: true,
724                        },
725                        SetLiteral::Member {
726                            element: elem,
727                            set: rhs,
728                            sign: false,
729                        },
730                    ],
731                    reason: format!(
732                        "Conflict: element {} is in lhs but not in rhs for subset constraint",
733                        elem
734                    ),
735                    proof_steps: vec![SetProofStep::SubsetProp {
736                        from: lhs,
737                        mid: lhs,
738                        to: rhs,
739                    }],
740                });
741            }
742
743            self.add_member_constraint(elem, &SetExpr::Var(rhs), true)?;
744        }
745
746        // Get elements that must not be in rhs
747        let rhs_must_not_members: SmallVec<[u32; 16]> = if let Some(rhs_var) = self.get_var(rhs) {
748            rhs_var.must_not_members.iter().copied().collect()
749        } else {
750            return Ok(());
751        };
752
753        // They must all not be in lhs
754        for elem in rhs_must_not_members {
755            self.add_member_constraint(elem, &SetExpr::Var(lhs), false)?;
756        }
757
758        Ok(())
759    }
760
761    /// Run constraint propagation
762    pub fn propagate(&mut self) -> SetResult<()> {
763        while let Some(var_id) = self.propagation_queue.pop_front() {
764            self.stats.num_propagations += 1;
765
766            // Propagate membership
767            self.member_prop.propagate(var_id, &mut self.vars)?;
768
769            // Propagate subset
770            self.subset_prop
771                .propagate(var_id, &mut self.vars, &self.subset_constraints)?;
772
773            // Propagate cardinality
774            self.card_prop
775                .propagate(var_id, &mut self.vars, &self.card_constraints)?;
776
777            // Check for conflicts
778            if let Some(var) = self.get_var(var_id) {
779                // Check cardinality conflict
780                let (lower, upper) = var.cardinality_bounds();
781                let var_name = var.name.clone();
782                let is_empty = var.is_definitely_empty();
783                let has_must_members = !var.must_members.is_empty();
784
785                if let Some(u) = upper
786                    && lower > u
787                {
788                    self.stats.num_conflicts += 1;
789                    return Err(SetConflict {
790                        literals: vec![],
791                        reason: format!(
792                            "Cardinality conflict: |{}| must be in [{}, {}] which is empty",
793                            var_name, lower, u
794                        ),
795                        proof_steps: vec![SetProofStep::CardConflict {
796                            set: var_id,
797                            lower,
798                            upper: u,
799                        }],
800                    });
801                }
802
803                // Check empty set conflict
804                if is_empty && has_must_members {
805                    self.stats.num_conflicts += 1;
806                    return Err(SetConflict {
807                        literals: vec![],
808                        reason: format!("Empty set conflict: {} cannot be empty", var_name),
809                        proof_steps: vec![SetProofStep::EmptyConflict { set: var_id }],
810                    });
811                }
812            }
813        }
814
815        Ok(())
816    }
817
818    /// Check satisfiability
819    pub fn check(&mut self) -> SetResult<bool> {
820        // Run propagation
821        self.propagate()?;
822
823        // Check if all variables are determined
824        let all_determined = self
825            .vars
826            .iter()
827            .all(|v| v.cardinality_determined().is_some());
828
829        Ok(all_determined)
830    }
831
832    /// Push a new decision level
833    pub fn push(&mut self) {
834        let state = SolverState {
835            num_vars: self.vars.len(),
836            num_constraints: self.constraints.len(),
837            num_member_constraints: self.member_constraints.len(),
838            num_subset_constraints: self.subset_constraints.len(),
839            num_card_constraints: self.card_constraints.len(),
840        };
841        self.context_stack.push(state);
842        self.level += 1;
843        self.level_boundaries.push(self.trail.len());
844        self.trail.push(TrailEntry::DecisionLevel(self.level));
845    }
846
847    /// Pop a decision level
848    pub fn pop(&mut self) {
849        if let Some(state) = self.context_stack.pop() {
850            self.stats.num_backtracks += 1;
851            self.level = self.level.saturating_sub(1);
852
853            // Restore state
854            self.vars.truncate(state.num_vars);
855            self.constraints.truncate(state.num_constraints);
856            self.member_constraints
857                .truncate(state.num_member_constraints);
858            self.subset_constraints
859                .truncate(state.num_subset_constraints);
860            self.card_constraints.truncate(state.num_card_constraints);
861
862            // Restore trail
863            if let Some(boundary) = self.level_boundaries.pop() {
864                while self.trail.len() > boundary {
865                    if let Some(entry) = self.trail.pop()
866                        && let TrailEntry::VarAssign { var, snapshot } = entry
867                        && let Some(v) = self.get_var_mut(var)
868                    {
869                        v.must_members = snapshot.must_members;
870                        v.must_not_members = snapshot.must_not_members;
871                        v.may_members = snapshot.may_members;
872                        v.card_bounds = snapshot.card_bounds;
873                        v.is_empty = snapshot.is_empty;
874                        v.is_universal = snapshot.is_universal;
875                    }
876                }
877            }
878
879            // Clear propagation queue
880            self.propagation_queue.clear();
881        }
882    }
883
884    /// Get statistics
885    pub fn stats(&self) -> &SetStats {
886        &self.stats
887    }
888
889    /// Reset the solver
890    pub fn reset(&mut self) {
891        self.vars.clear();
892        self.var_names.clear();
893        self.member_constraints.clear();
894        self.subset_constraints.clear();
895        self.card_constraints.clear();
896        self.constraints.clear();
897        self.propagation_queue.clear();
898        self.level = 0;
899        self.trail.clear();
900        self.level_boundaries.clear();
901        self.stats = SetStats::default();
902        self.context_stack.clear();
903        self.conflict = None;
904        self.term_to_var.clear();
905        self.var_to_term.clear();
906    }
907
908    /// Register a term-to-variable mapping
909    pub fn register_term(&mut self, term: TermId, var: SetVarId) {
910        self.term_to_var.insert(term, var);
911        self.var_to_term.insert(var, term);
912    }
913
914    /// Get the set variable for a term
915    pub fn get_var_for_term(&self, term: TermId) -> Option<SetVarId> {
916        self.term_to_var.get(&term).copied()
917    }
918
919    /// Get model for a variable
920    pub fn get_model(&self, var: SetVarId) -> Option<FxHashSet<u32>> {
921        self.get_var(var).map(|v| v.must_members.clone())
922    }
923}
924
925impl Default for SetSolver {
926    fn default() -> Self {
927        Self::new()
928    }
929}
930
931impl Theory for SetSolver {
932    fn id(&self) -> TheoryId {
933        TheoryId::Bool // Use Bool for now, should add SetTheory variant
934    }
935
936    fn name(&self) -> &str {
937        "Set Theory"
938    }
939
940    fn can_handle(&self, _term: TermId) -> bool {
941        // TODO: Implement proper term type checking
942        true
943    }
944
945    fn assert_true(&mut self, _term: TermId) -> Result<TR> {
946        // TODO: Convert term to set constraint and add it
947        // For now, just push the term
948        self.push();
949        Ok(TR::Sat)
950    }
951
952    fn assert_false(&mut self, _term: TermId) -> Result<TR> {
953        // TODO: Convert term to negated set constraint and add it
954        self.push();
955        Ok(TR::Sat)
956    }
957
958    fn check(&mut self) -> Result<TR> {
959        match self.check() {
960            Ok(true) => Ok(TR::Sat),
961            Ok(false) => Ok(TR::Unknown),
962            Err(conflict) => {
963                self.conflict = Some(conflict.clone());
964                Ok(TR::Unsat(vec![]))
965            }
966        }
967    }
968
969    fn push(&mut self) {
970        self.push();
971    }
972
973    fn pop(&mut self) {
974        self.pop();
975    }
976
977    fn reset(&mut self) {
978        self.reset();
979    }
980
981    fn get_model(&self) -> Vec<(TermId, TermId)> {
982        // TODO: Convert set model to term pairs
983        Vec::new()
984    }
985}
986
987impl TheoryCombination for SetSolver {
988    fn notify_equality(&mut self, _eq: EqualityNotification) -> bool {
989        // TODO: Handle equality notifications from other theories
990        false
991    }
992
993    fn get_shared_equalities(&self) -> Vec<EqualityNotification> {
994        // TODO: Export shared equalities
995        Vec::new()
996    }
997
998    fn is_relevant(&self, term: TermId) -> bool {
999        self.term_to_var.contains_key(&term)
1000    }
1001}
1002
#[cfg(test)]
mod tests {
    use super::*;

    /// Two fresh variables get ids 0 and 1, and the variable counter reads 2.
    #[test]
    fn test_set_var_creation() {
        let mut solver = SetSolver::new();
        let first = solver.new_set_var("S1", SetSort::IntSet);
        let second = solver.new_set_var("S2", SetSort::IntSet);

        assert_eq!(first.id(), 0);
        assert_eq!(second.id(), 1);
        assert_eq!(solver.stats.num_vars, 2);
    }

    /// Asserting 42 ∈ S succeeds and records 42 in `must_members`.
    #[test]
    fn test_membership_constraint() {
        let mut solver = SetSolver::new();
        let s = solver.new_set_var("S", SetSort::IntSet);
        let target = SetExpr::Var(s);

        // 42 ∈ S
        assert!(solver.add_member_constraint(42, &target, true).is_ok());

        let set_var = solver.get_var(s).unwrap();
        assert!(set_var.must_members.contains(&42));
    }

    /// Asserting 42 ∈ S and then 42 ∉ S is rejected.
    #[test]
    fn test_membership_conflict() {
        let mut solver = SetSolver::new();
        let s = solver.new_set_var("S", SetSort::IntSet);
        let target = SetExpr::Var(s);

        // 42 ∈ S
        solver.add_member_constraint(42, &target, true).unwrap();

        // 42 ∉ S contradicts the first assertion.
        let contradiction = solver.add_member_constraint(42, &target, false);
        assert!(contradiction.is_err());
    }

    /// |S| ≤ 5 yields bounds (0, Some(5)).
    #[test]
    fn test_cardinality_bounds() {
        let mut solver = SetSolver::new();
        let s = solver.new_set_var("S", SetSort::IntSet);
        let target = SetExpr::Var(s);

        // |S| ≤ 5
        solver
            .add_cardinality_constraint(&target, CardConstraintKind::Le, 5)
            .unwrap();

        let (min_card, max_card) = solver.get_var(s).unwrap().cardinality_bounds();
        assert_eq!(min_card, 0);
        assert_eq!(max_card, Some(5));
    }

    /// |S| ≥ 10 followed by |S| ≤ 5 is rejected.
    #[test]
    fn test_cardinality_conflict() {
        let mut solver = SetSolver::new();
        let s = solver.new_set_var("S", SetSort::IntSet);
        let target = SetExpr::Var(s);

        // |S| ≥ 10
        solver
            .add_cardinality_constraint(&target, CardConstraintKind::Ge, 10)
            .unwrap();

        // |S| ≤ 5 cannot hold together with the lower bound.
        let contradiction = solver.add_cardinality_constraint(&target, CardConstraintKind::Le, 5);
        assert!(contradiction.is_err());
    }

    /// With 42 ∈ S1 and S1 ⊆ S2, propagation puts 42 into S2.
    #[test]
    fn test_subset_propagation() {
        let mut solver = SetSolver::new();
        let sub = solver.new_set_var("S1", SetSort::IntSet);
        let sup = solver.new_set_var("S2", SetSort::IntSet);
        let sub_expr = SetExpr::Var(sub);
        let sup_expr = SetExpr::Var(sup);

        // 42 ∈ S1
        solver.add_member_constraint(42, &sub_expr, true).unwrap();

        // S1 ⊆ S2
        solver
            .add_subset_constraint(&sub_expr, &sup_expr, true)
            .unwrap();

        solver.propagate().unwrap();

        let superset = solver.get_var(sup).unwrap();
        assert!(superset.must_members.contains(&42));
    }

    /// |S| = 0 makes S definitely empty.
    #[test]
    fn test_empty_set() {
        let mut solver = SetSolver::new();
        let s = solver.new_set_var("S", SetSort::IntSet);

        // |S| = 0
        solver
            .add_cardinality_constraint(&SetExpr::Var(s), CardConstraintKind::Equal, 0)
            .unwrap();

        assert!(solver.get_var(s).unwrap().is_definitely_empty());
    }

    /// With 42 ∈ S1 and S1 ∩ S2 = ∅, 42 is excluded from S2.
    #[test]
    fn test_disjoint_sets() {
        let mut solver = SetSolver::new();
        let left = solver.new_set_var("S1", SetSort::IntSet);
        let right = solver.new_set_var("S2", SetSort::IntSet);
        let left_expr = SetExpr::Var(left);
        let right_expr = SetExpr::Var(right);

        // 42 ∈ S1
        solver.add_member_constraint(42, &left_expr, true).unwrap();

        // S1 ∩ S2 = ∅
        solver
            .add_disjoint_constraint(&left_expr, &right_expr)
            .unwrap();

        let other = solver.get_var(right).unwrap();
        assert!(other.must_not_members.contains(&42));
    }

    /// A membership asserted after `push` is gone again after `pop`.
    #[test]
    fn test_push_pop() {
        let mut solver = SetSolver::new();
        let s = solver.new_set_var("S", SetSort::IntSet);
        let holds_42 =
            |solver: &SetSolver| solver.get_var(s).unwrap().must_members.contains(&42);

        solver.push();

        // 42 ∈ S, asserted inside the pushed scope.
        solver
            .add_member_constraint(42, &SetExpr::Var(s), true)
            .unwrap();
        assert!(holds_42(&solver));

        solver.pop();

        // The scope is gone, and so is the membership.
        assert!(!holds_42(&solver));
    }

    /// `get_vars` on S0 ∪ (S1 ∩ S2) reports exactly the three variables.
    #[test]
    fn test_set_expr_vars() {
        let inner = SetExpr::intersection(SetExpr::Var(SetVarId(1)), SetExpr::Var(SetVarId(2)));
        let expr = SetExpr::union(SetExpr::Var(SetVarId(0)), inner);

        let vars = expr.get_vars();
        assert_eq!(vars.len(), 3);
        for raw_id in 0..3 {
            assert!(vars.contains(&SetVarId(raw_id)));
        }
    }
}