Skip to main content

oxiz_solver/
nelson_oppen.rs

1//! Nelson-Oppen Theory Combination.
2#![allow(dead_code, clippy::result_unit_err)] // Under development
3//!
4//! Implements the Nelson-Oppen framework for combining decision procedures
5//! of disjoint theories through equality sharing.
6
7use oxiz_core::ast::{TermId, TermKind, TermManager};
8use rustc_hash::{FxHashMap, FxHashSet};
9use std::collections::VecDeque;
10
11/// Nelson-Oppen theory combination engine.
12pub struct NelsonOppenCombiner {
13    /// Shared terms between theories
14    shared_terms: FxHashSet<TermId>,
15    /// Equality classes for shared terms
16    equality_classes: UnionFind,
17    /// Pending equalities to propagate
18    pending_equalities: VecDeque<(TermId, TermId)>,
19    /// Theory assignments for shared terms
20    theory_assignments: FxHashMap<TermId, TheoryId>,
21    /// Statistics
22    stats: NelsonOppenStats,
23}
24
/// Theory identifier.
///
/// A `Copy` newtype around a `usize`, usable as a hash-map key.
/// NOTE(review): presumably an index into a solver-side table of theories;
/// this file never interprets the value — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TheoryId(pub usize);
28
/// Nelson-Oppen statistics.
///
/// All counters start at zero (`Default`) and are zeroed again by
/// `NelsonOppenCombiner::reset`.
#[derive(Debug, Clone, Default)]
pub struct NelsonOppenStats {
    /// Number of shared terms registered with the combiner.
    pub shared_terms_count: usize,
    /// Number of asserted equalities that merged two distinct equivalence
    /// classes (and were therefore queued for propagation).
    pub equalities_propagated: usize,
    /// Number of theory conflicts detected.
    /// NOTE(review): nothing in this file increments this counter.
    pub theory_conflicts: usize,
    /// Number of purification steps; bumped once per `purify_term` call,
    /// recursive calls on sub-terms included.
    pub purifications: usize,
}
41
42impl NelsonOppenCombiner {
43    /// Create a new Nelson-Oppen combiner.
44    pub fn new() -> Self {
45        Self {
46            shared_terms: FxHashSet::default(),
47            equality_classes: UnionFind::new(),
48            pending_equalities: VecDeque::new(),
49            theory_assignments: FxHashMap::default(),
50            stats: NelsonOppenStats::default(),
51        }
52    }
53
54    /// Register a shared term between theories.
55    pub fn register_shared_term(&mut self, term_id: TermId, theory1: TheoryId, _theory2: TheoryId) {
56        self.shared_terms.insert(term_id);
57        self.theory_assignments.insert(term_id, theory1);
58        self.equality_classes.make_set(term_id);
59        self.stats.shared_terms_count += 1;
60    }
61
62    /// Assert an equality between shared terms.
63    ///
64    /// Returns Ok(()) if consistent, Err(()) if conflict detected.
65    pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), ()> {
66        if !self.shared_terms.contains(&lhs) || !self.shared_terms.contains(&rhs) {
67            return Err(()); // Only shared terms can be equated
68        }
69
70        // Check if already in same equivalence class
71        if self.equality_classes.find(lhs) == self.equality_classes.find(rhs) {
72            return Ok(());
73        }
74
75        // Merge equivalence classes
76        self.equality_classes.union(lhs, rhs);
77        self.pending_equalities.push_back((lhs, rhs));
78        self.stats.equalities_propagated += 1;
79
80        Ok(())
81    }
82
83    /// Purify a term by introducing fresh variables for sub-terms.
84    ///
85    /// Purification ensures each theory sees only its own symbols.
86    pub fn purify_term(&mut self, term_id: TermId, tm: &mut TermManager) -> Result<TermId, String> {
87        self.stats.purifications += 1;
88
89        // Recursively purify sub-terms
90        let term = tm.get(term_id).ok_or("term not found")?.clone();
91
92        match &term.kind {
93            TermKind::Apply { func: _, args } => {
94                let mut purified_args = Vec::new();
95
96                for &arg in args {
97                    let purified_arg = self.purify_term(arg, tm)?;
98                    purified_args.push(purified_arg);
99                }
100
101                // Check if any argument changed theory
102                let needs_purification = purified_args
103                    .iter()
104                    .enumerate()
105                    .any(|(i, &purified)| self.get_theory(purified) != self.get_theory(args[i]));
106
107                if needs_purification {
108                    // Introduce fresh variable for this sub-term
109                    // TODO: TermManager doesn't have fresh_var - needs implementation
110                    // let fresh_var = tm.fresh_var(term.sort);
111                    // For now, just return the original term
112                    // self.register_shared_term(fresh_var, TheoryId(0), TheoryId(1));
113
114                    // Assert equality: fresh_var = (func purified_args)
115                    // TODO: mk_apply expects SmallVec, not Vec
116                    // let purified_app = tm.mk_apply(*func, purified_args.into());
117                    // self.assert_equality(fresh_var, purified_app)?;
118
119                    // Ok(fresh_var)
120                    Ok(term_id) // Placeholder
121                } else {
122                    Ok(term_id)
123                }
124            }
125            _ => Ok(term_id),
126        }
127    }
128
129    /// Get pending equalities to propagate to theories.
130    pub fn get_pending_equalities(&mut self) -> Vec<(TermId, TermId)> {
131        let mut result = Vec::new();
132        while let Some(eq) = self.pending_equalities.pop_front() {
133            result.push(eq);
134        }
135        result
136    }
137
138    /// Check if two terms are in the same equivalence class.
139    pub fn are_equal(&self, lhs: TermId, rhs: TermId) -> bool {
140        self.equality_classes.find(lhs) == self.equality_classes.find(rhs)
141    }
142
143    /// Get all terms in the equivalence class of a term.
144    pub fn get_equivalence_class(&self, term_id: TermId) -> Vec<TermId> {
145        let rep = self.equality_classes.find(term_id);
146        self.shared_terms
147            .iter()
148            .filter(|&&t| self.equality_classes.find(t) == rep)
149            .copied()
150            .collect()
151    }
152
153    /// Get theory assignment for a term.
154    fn get_theory(&self, term_id: TermId) -> Option<TheoryId> {
155        self.theory_assignments.get(&term_id).copied()
156    }
157
158    /// Convexity closure: generate implied equalities.
159    ///
160    /// For convex theories, if we have equalities in each class,
161    /// we must propagate all pairwise equalities.
162    pub fn convexity_closure(&mut self) -> Vec<(TermId, TermId)> {
163        let mut implied_equalities = Vec::new();
164
165        // Group terms by equivalence class
166        let mut classes: FxHashMap<TermId, Vec<TermId>> = FxHashMap::default();
167        for &term in &self.shared_terms {
168            let rep = self.equality_classes.find(term);
169            classes.entry(rep).or_default().push(term);
170        }
171
172        // For each equivalence class with multiple elements
173        for (_rep, terms) in classes {
174            if terms.len() > 1 {
175                // Generate all pairwise equalities
176                for i in 0..terms.len() {
177                    for j in (i + 1)..terms.len() {
178                        implied_equalities.push((terms[i], terms[j]));
179                    }
180                }
181            }
182        }
183
184        implied_equalities
185    }
186
187    /// Get statistics.
188    pub fn stats(&self) -> &NelsonOppenStats {
189        &self.stats
190    }
191
192    /// Reset for next SMT check.
193    pub fn reset(&mut self) {
194        self.shared_terms.clear();
195        self.equality_classes = UnionFind::new();
196        self.pending_equalities.clear();
197        self.theory_assignments.clear();
198        self.stats = NelsonOppenStats::default();
199    }
200}
201
202impl Default for NelsonOppenCombiner {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
/// Union-Find data structure for equivalence classes.
///
/// Elements are `TermId`s stored in hash maps, so the structure can hold an
/// arbitrary, sparse set of terms.
#[derive(Debug, Clone)]
struct UnionFind {
    /// Parent link of every known element; a root is mapped to itself.
    parent: FxHashMap<TermId, TermId>,
    /// Rank of each element, consulted by `union` to decide which root is
    /// attached beneath the other.
    rank: FxHashMap<TermId, usize>,
}
214
215impl UnionFind {
216    fn new() -> Self {
217        Self {
218            parent: FxHashMap::default(),
219            rank: FxHashMap::default(),
220        }
221    }
222
223    fn make_set(&mut self, x: TermId) {
224        self.parent.insert(x, x);
225        self.rank.insert(x, 0);
226    }
227
228    fn find(&self, x: TermId) -> TermId {
229        let mut current = x;
230        while let Some(&parent) = self.parent.get(&current) {
231            if parent == current {
232                return current;
233            }
234            current = parent;
235        }
236        x // Not found, return itself
237    }
238
239    fn union(&mut self, x: TermId, y: TermId) {
240        let x_root = self.find(x);
241        let y_root = self.find(y);
242
243        if x_root == y_root {
244            return;
245        }
246
247        let x_rank = *self.rank.get(&x_root).unwrap_or(&0);
248        let y_rank = *self.rank.get(&y_root).unwrap_or(&0);
249
250        if x_rank < y_rank {
251            self.parent.insert(x_root, y_root);
252        } else if x_rank > y_rank {
253            self.parent.insert(y_root, x_root);
254        } else {
255            self.parent.insert(y_root, x_root);
256            self.rank.insert(x_root, x_rank + 1);
257        }
258    }
259}
260
261// Placeholder types (these would be defined elsewhere in the codebase)
262// Note: Using types from oxiz_core::ast instead
263// #[derive(Debug, Clone)]
264// struct Term {
265//     kind: TermKind,
266//     sort: SortId,
267// }
268//
269// #[derive(Debug, Clone)]
270// enum TermKind {
271//     Var(String),
272//     App(FuncId, Vec<TermId>),
273//     Const(ConstId),
274// }
275
// Placeholder identifier aliases, all plain `usize`.
// NOTE(review): none of these aliases is referenced in the visible part of
// this file; confirm whether they are still needed.
type SortId = usize;
type FuncId = usize;
type ConstId = usize;
279
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh combiner starts with zeroed statistics.
    #[test]
    fn test_nelson_oppen_creation() {
        let combiner = NelsonOppenCombiner::new();
        assert_eq!(combiner.stats.shared_terms_count, 0);
    }

    /// Registering a term records it as shared and counts it.
    #[test]
    fn test_register_shared_term() {
        let mut combiner = NelsonOppenCombiner::new();
        let term_id = TermId(0);

        combiner.register_shared_term(term_id, TheoryId(0), TheoryId(1));

        assert_eq!(combiner.stats.shared_terms_count, 1);
        assert!(combiner.shared_terms.contains(&term_id));
    }

    /// Asserting an equality merges the classes and queues it exactly once.
    #[test]
    fn test_assert_equality() {
        let mut combiner = NelsonOppenCombiner::new();
        let t1 = TermId(0);
        let t2 = TermId(1);

        combiner.register_shared_term(t1, TheoryId(0), TheoryId(1));
        combiner.register_shared_term(t2, TheoryId(0), TheoryId(1));

        assert!(combiner.assert_equality(t1, t2).is_ok());
        assert!(combiner.are_equal(t1, t2));
        assert_eq!(combiner.stats.equalities_propagated, 1);

        // Re-asserting a known equality is accepted but not propagated again.
        assert!(combiner.assert_equality(t2, t1).is_ok());
        assert_eq!(combiner.stats.equalities_propagated, 1);

        // Fetching the pending equalities drains the queue.
        assert_eq!(combiner.get_pending_equalities(), vec![(t1, t2)]);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    /// Equalities that mention an unregistered term are rejected.
    #[test]
    fn test_assert_equality_requires_shared_terms() {
        let mut combiner = NelsonOppenCombiner::new();
        let t1 = TermId(0);
        let t2 = TermId(1);

        combiner.register_shared_term(t1, TheoryId(0), TheoryId(1));

        assert!(combiner.assert_equality(t1, t2).is_err());
        assert!(!combiner.are_equal(t1, t2));
        assert_eq!(combiner.stats.equalities_propagated, 0);
    }

    /// Three terms in one class yield all three pairwise equalities.
    #[test]
    fn test_convexity_closure() {
        let mut combiner = NelsonOppenCombiner::new();
        let t1 = TermId(0);
        let t2 = TermId(1);
        let t3 = TermId(2);

        combiner.register_shared_term(t1, TheoryId(0), TheoryId(1));
        combiner.register_shared_term(t2, TheoryId(0), TheoryId(1));
        combiner.register_shared_term(t3, TheoryId(0), TheoryId(1));

        combiner.assert_equality(t1, t2).unwrap();
        combiner.assert_equality(t2, t3).unwrap();

        let implied = combiner.convexity_closure();
        // C(3, 2) = 3 unordered pairs, each relating two distinct equal terms.
        assert_eq!(implied.len(), 3);
        assert!(implied
            .iter()
            .all(|&(a, b)| a != b && combiner.are_equal(a, b)));
    }
}
332}