Skip to main content

oxiz_solver/
nelson_oppen.rs

1//! Nelson-Oppen Theory Combination.
2#![allow(dead_code, clippy::result_unit_err)] // Under development
3//!
4//! Implements the Nelson-Oppen framework for combining decision procedures
5//! of disjoint theories through equality sharing.
6
7#[allow(unused_imports)]
8use crate::prelude::*;
9use oxiz_core::ast::{TermId, TermKind, TermManager};
10
11/// Nelson-Oppen theory combination engine.
12pub struct NelsonOppenCombiner {
13    /// Shared terms between theories
14    shared_terms: FxHashSet<TermId>,
15    /// Equality classes for shared terms
16    equality_classes: UnionFind,
17    /// Pending equalities to propagate
18    pending_equalities: VecDeque<(TermId, TermId)>,
19    /// Already-propagated equalities (normalized so lhs <= rhs).
20    /// Prevents the fixed-point loop from re-discovering known equalities.
21    propagated_equalities: FxHashSet<(TermId, TermId)>,
22    /// Theory assignments for shared terms
23    theory_assignments: FxHashMap<TermId, TheoryId>,
24    /// Statistics
25    stats: NelsonOppenStats,
26    /// Counter for generating fresh variable names during purification
27    fresh_var_counter: u64,
28}
29
/// Identifier of one theory taking part in the combination.
///
/// A thin, copyable newtype over a `usize` index.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TheoryId(pub usize);
33
/// Counters collected by a Nelson-Oppen combiner.
///
/// The `Default` value has every counter at zero.
#[derive(Clone, Debug, Default)]
pub struct NelsonOppenStats {
    /// How many shared-term registrations have been performed.
    pub shared_terms_count: usize,
    /// How many equalities caused two equivalence classes to merge.
    pub equalities_propagated: usize,
    /// Number of theory conflicts detected (not currently incremented by the combiner).
    pub theory_conflicts: usize,
    /// How many purification steps have been performed.
    pub purifications: usize,
}
46
47impl NelsonOppenCombiner {
48    /// Create a new Nelson-Oppen combiner.
49    pub fn new() -> Self {
50        Self {
51            shared_terms: FxHashSet::default(),
52            equality_classes: UnionFind::new(),
53            pending_equalities: VecDeque::new(),
54            propagated_equalities: FxHashSet::default(),
55            theory_assignments: FxHashMap::default(),
56            stats: NelsonOppenStats::default(),
57            fresh_var_counter: 0,
58        }
59    }
60
61    /// Register a shared term between theories.
62    pub fn register_shared_term(&mut self, term_id: TermId, theory1: TheoryId, _theory2: TheoryId) {
63        self.shared_terms.insert(term_id);
64        self.theory_assignments.insert(term_id, theory1);
65        self.equality_classes.make_set(term_id);
66        self.stats.shared_terms_count += 1;
67    }
68
69    /// Normalize an equality pair so that the smaller TermId comes first.
70    /// This ensures (a,b) and (b,a) are treated as the same equality.
71    fn normalize_pair(lhs: TermId, rhs: TermId) -> (TermId, TermId) {
72        if lhs <= rhs { (lhs, rhs) } else { (rhs, lhs) }
73    }
74
75    /// Assert an equality between shared terms.
76    ///
77    /// Returns Ok(()) if consistent, Err(()) if conflict detected.
78    pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), ()> {
79        if !self.shared_terms.contains(&lhs) || !self.shared_terms.contains(&rhs) {
80            return Err(()); // Only shared terms can be equated
81        }
82
83        // Normalize and check if this equality was already propagated
84        let key = Self::normalize_pair(lhs, rhs);
85        if self.propagated_equalities.contains(&key) {
86            return Ok(());
87        }
88
89        // Check if already in same equivalence class
90        if self.equality_classes.find(lhs) == self.equality_classes.find(rhs) {
91            self.propagated_equalities.insert(key);
92            return Ok(());
93        }
94
95        // Merge equivalence classes
96        self.equality_classes.union(lhs, rhs);
97        self.pending_equalities.push_back((lhs, rhs));
98        self.propagated_equalities.insert(key);
99        self.stats.equalities_propagated += 1;
100
101        Ok(())
102    }
103
104    /// Generate a fresh variable name for purification.
105    fn fresh_var_name(&mut self) -> String {
106        let name = format!("_no_purify_{}", self.fresh_var_counter);
107        self.fresh_var_counter += 1;
108        name
109    }
110
111    /// Purify a term by introducing fresh variables for sub-terms.
112    ///
113    /// Purification ensures each theory sees only its own symbols.
114    /// When a subterm belongs to a different theory than the parent application,
115    /// it is replaced by a fresh shared variable, and an equality constraint
116    /// is recorded between the fresh variable and the original subterm.
117    pub fn purify_term(&mut self, term_id: TermId, tm: &mut TermManager) -> Result<TermId, String> {
118        self.stats.purifications += 1;
119
120        // Recursively purify sub-terms
121        let term = tm.get(term_id).ok_or("term not found")?.clone();
122
123        match &term.kind {
124            TermKind::Apply { func, args } => {
125                let func_spur = *func;
126                let original_args: Vec<TermId> = args.iter().copied().collect();
127                let mut purified_args = Vec::new();
128
129                for &arg in &original_args {
130                    let purified_arg = self.purify_term(arg, tm)?;
131                    purified_args.push(purified_arg);
132                }
133
134                // Check if any argument changed theory
135                let needs_purification = purified_args.iter().enumerate().any(|(i, &purified)| {
136                    self.get_theory(purified) != self.get_theory(original_args[i])
137                });
138
139                if needs_purification {
140                    // Create a fresh variable with the same sort as this term
141                    let sort = term.sort;
142                    let fresh_name = self.fresh_var_name();
143                    let fresh_var = tm.mk_var(&fresh_name, sort);
144
145                    // Register the fresh variable as shared between the relevant theories
146                    self.register_shared_term(fresh_var, TheoryId(0), TheoryId(1));
147
148                    // Build the purified application term using the func spur
149                    // mk_apply expects &str but we have a Spur. Use the original term's
150                    // function name from the interner.
151                    let func_name = tm.resolve_str(func_spur).to_string();
152                    let purified_app = tm.mk_apply(&func_name, purified_args, sort);
153
154                    // Record equality: fresh_var = purified_app
155                    // This equality will be propagated through pending_equalities
156                    let _ = self.assert_equality(fresh_var, purified_app);
157
158                    Ok(fresh_var)
159                } else {
160                    Ok(term_id)
161                }
162            }
163            _ => Ok(term_id),
164        }
165    }
166
167    /// Get pending equalities to propagate to theories.
168    pub fn get_pending_equalities(&mut self) -> Vec<(TermId, TermId)> {
169        let mut result = Vec::new();
170        while let Some(eq) = self.pending_equalities.pop_front() {
171            result.push(eq);
172        }
173        result
174    }
175
176    /// Check if two terms are in the same equivalence class.
177    pub fn are_equal(&self, lhs: TermId, rhs: TermId) -> bool {
178        self.equality_classes.find(lhs) == self.equality_classes.find(rhs)
179    }
180
181    /// Get all terms in the equivalence class of a term.
182    pub fn get_equivalence_class(&self, term_id: TermId) -> Vec<TermId> {
183        let rep = self.equality_classes.find(term_id);
184        self.shared_terms
185            .iter()
186            .filter(|&&t| self.equality_classes.find(t) == rep)
187            .copied()
188            .collect()
189    }
190
191    /// Get theory assignment for a term.
192    fn get_theory(&self, term_id: TermId) -> Option<TheoryId> {
193        self.theory_assignments.get(&term_id).copied()
194    }
195
196    /// Convexity closure: generate implied equalities.
197    ///
198    /// For convex theories, if we have equalities in each class,
199    /// we must propagate all pairwise equalities.
200    /// Only returns equalities that have NOT already been propagated.
201    pub fn convexity_closure(&mut self) -> Vec<(TermId, TermId)> {
202        let mut implied_equalities = Vec::new();
203
204        // Group terms by equivalence class
205        let mut classes: FxHashMap<TermId, Vec<TermId>> = FxHashMap::default();
206        for &term in &self.shared_terms {
207            let rep = self.equality_classes.find(term);
208            classes.entry(rep).or_default().push(term);
209        }
210
211        // For each equivalence class with multiple elements
212        for (_rep, terms) in classes {
213            if terms.len() > 1 {
214                // Generate all pairwise equalities, skipping already-propagated ones
215                for i in 0..terms.len() {
216                    for j in (i + 1)..terms.len() {
217                        let key = Self::normalize_pair(terms[i], terms[j]);
218                        if !self.propagated_equalities.contains(&key) {
219                            implied_equalities.push((terms[i], terms[j]));
220                        }
221                    }
222                }
223            }
224        }
225
226        implied_equalities
227    }
228
229    /// Get statistics.
230    pub fn stats(&self) -> &NelsonOppenStats {
231        &self.stats
232    }
233
234    /// Reset for next SMT check.
235    pub fn reset(&mut self) {
236        self.shared_terms.clear();
237        self.equality_classes = UnionFind::new();
238        self.pending_equalities.clear();
239        self.propagated_equalities.clear();
240        self.theory_assignments.clear();
241        self.stats = NelsonOppenStats::default();
242        self.fresh_var_counter = 0;
243    }
244}
245
246impl Default for NelsonOppenCombiner {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
252/// Union-Find data structure for equivalence classes.
253#[derive(Debug, Clone)]
254struct UnionFind {
255    parent: FxHashMap<TermId, TermId>,
256    rank: FxHashMap<TermId, usize>,
257}
258
259impl UnionFind {
260    fn new() -> Self {
261        Self {
262            parent: FxHashMap::default(),
263            rank: FxHashMap::default(),
264        }
265    }
266
267    fn make_set(&mut self, x: TermId) {
268        self.parent.insert(x, x);
269        self.rank.insert(x, 0);
270    }
271
272    fn find(&self, x: TermId) -> TermId {
273        let mut current = x;
274        while let Some(&parent) = self.parent.get(&current) {
275            if parent == current {
276                return current;
277            }
278            current = parent;
279        }
280        x // Not found, return itself
281    }
282
283    fn union(&mut self, x: TermId, y: TermId) {
284        let x_root = self.find(x);
285        let y_root = self.find(y);
286
287        if x_root == y_root {
288            return;
289        }
290
291        let x_rank = *self.rank.get(&x_root).unwrap_or(&0);
292        let y_rank = *self.rank.get(&y_root).unwrap_or(&0);
293
294        if x_rank < y_rank {
295            self.parent.insert(x_root, y_root);
296        } else if x_rank > y_rank {
297            self.parent.insert(y_root, x_root);
298        } else {
299            self.parent.insert(y_root, x_root);
300            self.rank.insert(x_root, x_rank + 1);
301        }
302    }
303}
304
305// Placeholder types (these would be defined elsewhere in the codebase)
306// Note: Using types from oxiz_core::ast instead
307// #[derive(Debug, Clone)]
308// struct Term {
309//     kind: TermKind,
310//     sort: SortId,
311// }
312//
313// #[derive(Debug, Clone)]
314// enum TermKind {
315//     Var(String),
316//     App(FuncId, Vec<TermId>),
317//     Const(ConstId),
318// }
319
/// Placeholder sort index (the combiner itself uses `oxiz_core::ast` types).
type SortId = usize;
/// Placeholder function-symbol index.
type FuncId = usize;
/// Placeholder constant index.
type ConstId = usize;
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a combiner with every id in `ids` registered between theories 0 and 1.
    fn combiner_with(ids: &[TermId]) -> NelsonOppenCombiner {
        let mut combiner = NelsonOppenCombiner::new();
        for &id in ids {
            combiner.register_shared_term(id, TheoryId(0), TheoryId(1));
        }
        combiner
    }

    #[test]
    fn test_nelson_oppen_creation() {
        let mut combiner = NelsonOppenCombiner::new();
        assert_eq!(combiner.stats.shared_terms_count, 0);
        assert_eq!(combiner.stats.equalities_propagated, 0);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_register_shared_term() {
        let term_id = TermId(0);
        let combiner = combiner_with(&[term_id]);

        assert_eq!(combiner.stats.shared_terms_count, 1);
        assert!(combiner.shared_terms.contains(&term_id));
    }

    #[test]
    fn test_assert_equality() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);

        assert!(combiner.assert_equality(t1, t2).is_ok());
        assert!(combiner.are_equal(t1, t2));
        assert_eq!(combiner.stats.equalities_propagated, 1);
        assert_eq!(combiner.get_pending_equalities(), vec![(t1, t2)]);
    }

    #[test]
    fn test_assert_equality_rejects_unregistered_terms() {
        let t1 = TermId(0);
        let mut combiner = combiner_with(&[t1]);

        assert!(combiner.assert_equality(t1, TermId(5)).is_err());
        assert!(combiner.assert_equality(TermId(5), TermId(6)).is_err());
        assert_eq!(combiner.stats.equalities_propagated, 0);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_repeated_equality_is_not_recounted() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);

        combiner.assert_equality(t1, t2).expect("test operation should succeed");
        // The reversed pair normalizes to the same key.
        combiner.assert_equality(t2, t1).expect("test operation should succeed");

        assert_eq!(combiner.stats.equalities_propagated, 1);
        assert_eq!(combiner.get_pending_equalities().len(), 1);
    }

    #[test]
    fn test_pending_equalities_are_drained() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);

        combiner.assert_equality(t1, t2).expect("test operation should succeed");
        assert_eq!(combiner.get_pending_equalities().len(), 1);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_get_equivalence_class() {
        let (t1, t2, t3) = (TermId(0), TermId(1), TermId(2));
        let mut combiner = combiner_with(&[t1, t2, t3]);

        combiner.assert_equality(t1, t2).expect("test operation should succeed");

        let class = combiner.get_equivalence_class(t1);
        assert_eq!(class.len(), 2);
        assert!(class.contains(&t1) && class.contains(&t2));
        assert_eq!(combiner.get_equivalence_class(t3), vec![t3]);
        assert!(combiner.get_equivalence_class(TermId(9)).is_empty());
    }

    #[test]
    fn test_convexity_closure() {
        let (t1, t2, t3) = (TermId(0), TermId(1), TermId(2));
        let mut combiner = combiner_with(&[t1, t2, t3]);

        combiner
            .assert_equality(t1, t2)
            .expect("test operation should succeed");
        combiner
            .assert_equality(t2, t3)
            .expect("test operation should succeed");

        // Only the transitively implied t1 = t3 is new.
        let implied = combiner.convexity_closure();
        assert_eq!(implied.len(), 1);
        let (lhs, rhs) = implied[0];
        assert_eq!(NelsonOppenCombiner::normalize_pair(lhs, rhs), (t1, t3));

        // Once asserted, the closure has nothing left to report.
        combiner
            .assert_equality(lhs, rhs)
            .expect("implied equality is consistent");
        assert!(combiner.convexity_closure().is_empty());
    }

    #[test]
    fn test_reset_clears_state() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);
        combiner.assert_equality(t1, t2).expect("test operation should succeed");

        combiner.reset();

        assert_eq!(combiner.stats.shared_terms_count, 0);
        assert_eq!(combiner.stats.equalities_propagated, 0);
        assert!(!combiner.are_equal(t1, t2));
        assert!(combiner.get_pending_equalities().is_empty());
        assert!(combiner.assert_equality(t1, t2).is_err());
    }
}