1#![allow(dead_code, clippy::result_unit_err)] use oxiz_core::ast::{TermId, TermKind, TermManager};
8use rustc_hash::{FxHashMap, FxHashSet};
9use std::collections::VecDeque;
10
11pub struct NelsonOppenCombiner {
13 shared_terms: FxHashSet<TermId>,
15 equality_classes: UnionFind,
17 pending_equalities: VecDeque<(TermId, TermId)>,
19 theory_assignments: FxHashMap<TermId, TheoryId>,
21 stats: NelsonOppenStats,
23}
24
/// Numeric identifier of a theory taking part in the combination.
#[derive(Clone, Copy, Debug)]
#[derive(Eq, Hash, PartialEq)]
pub struct TheoryId(pub usize);
28
/// Counters collected by a `NelsonOppenCombiner`.
#[derive(Default, Clone, Debug)]
pub struct NelsonOppenStats {
    /// Number of shared-term registrations recorded.
    pub shared_terms_count: usize,
    /// Number of asserted equalities that merged two previously distinct classes.
    pub equalities_propagated: usize,
    /// Number of theory conflicts.
    /// NOTE(review): nothing in the visible code updates this counter.
    pub theory_conflicts: usize,
    /// Number of terms visited by `purify_term`.
    pub purifications: usize,
}
41
42impl NelsonOppenCombiner {
43 pub fn new() -> Self {
45 Self {
46 shared_terms: FxHashSet::default(),
47 equality_classes: UnionFind::new(),
48 pending_equalities: VecDeque::new(),
49 theory_assignments: FxHashMap::default(),
50 stats: NelsonOppenStats::default(),
51 }
52 }
53
54 pub fn register_shared_term(&mut self, term_id: TermId, theory1: TheoryId, _theory2: TheoryId) {
56 self.shared_terms.insert(term_id);
57 self.theory_assignments.insert(term_id, theory1);
58 self.equality_classes.make_set(term_id);
59 self.stats.shared_terms_count += 1;
60 }
61
62 pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), ()> {
66 if !self.shared_terms.contains(&lhs) || !self.shared_terms.contains(&rhs) {
67 return Err(()); }
69
70 if self.equality_classes.find(lhs) == self.equality_classes.find(rhs) {
72 return Ok(());
73 }
74
75 self.equality_classes.union(lhs, rhs);
77 self.pending_equalities.push_back((lhs, rhs));
78 self.stats.equalities_propagated += 1;
79
80 Ok(())
81 }
82
83 pub fn purify_term(&mut self, term_id: TermId, tm: &mut TermManager) -> Result<TermId, String> {
87 self.stats.purifications += 1;
88
89 let term = tm.get(term_id).ok_or("term not found")?.clone();
91
92 match &term.kind {
93 TermKind::Apply { func: _, args } => {
94 let mut purified_args = Vec::new();
95
96 for &arg in args {
97 let purified_arg = self.purify_term(arg, tm)?;
98 purified_args.push(purified_arg);
99 }
100
101 let needs_purification = purified_args
103 .iter()
104 .enumerate()
105 .any(|(i, &purified)| self.get_theory(purified) != self.get_theory(args[i]));
106
107 if needs_purification {
108 Ok(term_id) } else {
122 Ok(term_id)
123 }
124 }
125 _ => Ok(term_id),
126 }
127 }
128
129 pub fn get_pending_equalities(&mut self) -> Vec<(TermId, TermId)> {
131 let mut result = Vec::new();
132 while let Some(eq) = self.pending_equalities.pop_front() {
133 result.push(eq);
134 }
135 result
136 }
137
138 pub fn are_equal(&self, lhs: TermId, rhs: TermId) -> bool {
140 self.equality_classes.find(lhs) == self.equality_classes.find(rhs)
141 }
142
143 pub fn get_equivalence_class(&self, term_id: TermId) -> Vec<TermId> {
145 let rep = self.equality_classes.find(term_id);
146 self.shared_terms
147 .iter()
148 .filter(|&&t| self.equality_classes.find(t) == rep)
149 .copied()
150 .collect()
151 }
152
153 fn get_theory(&self, term_id: TermId) -> Option<TheoryId> {
155 self.theory_assignments.get(&term_id).copied()
156 }
157
158 pub fn convexity_closure(&mut self) -> Vec<(TermId, TermId)> {
163 let mut implied_equalities = Vec::new();
164
165 let mut classes: FxHashMap<TermId, Vec<TermId>> = FxHashMap::default();
167 for &term in &self.shared_terms {
168 let rep = self.equality_classes.find(term);
169 classes.entry(rep).or_default().push(term);
170 }
171
172 for (_rep, terms) in classes {
174 if terms.len() > 1 {
175 for i in 0..terms.len() {
177 for j in (i + 1)..terms.len() {
178 implied_equalities.push((terms[i], terms[j]));
179 }
180 }
181 }
182 }
183
184 implied_equalities
185 }
186
187 pub fn stats(&self) -> &NelsonOppenStats {
189 &self.stats
190 }
191
192 pub fn reset(&mut self) {
194 self.shared_terms.clear();
195 self.equality_classes = UnionFind::new();
196 self.pending_equalities.clear();
197 self.theory_assignments.clear();
198 self.stats = NelsonOppenStats::default();
199 }
200}
201
202impl Default for NelsonOppenCombiner {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[derive(Debug, Clone)]
210struct UnionFind {
211 parent: FxHashMap<TermId, TermId>,
212 rank: FxHashMap<TermId, usize>,
213}
214
215impl UnionFind {
216 fn new() -> Self {
217 Self {
218 parent: FxHashMap::default(),
219 rank: FxHashMap::default(),
220 }
221 }
222
223 fn make_set(&mut self, x: TermId) {
224 self.parent.insert(x, x);
225 self.rank.insert(x, 0);
226 }
227
228 fn find(&self, x: TermId) -> TermId {
229 let mut current = x;
230 while let Some(&parent) = self.parent.get(¤t) {
231 if parent == current {
232 return current;
233 }
234 current = parent;
235 }
236 x }
238
239 fn union(&mut self, x: TermId, y: TermId) {
240 let x_root = self.find(x);
241 let y_root = self.find(y);
242
243 if x_root == y_root {
244 return;
245 }
246
247 let x_rank = *self.rank.get(&x_root).unwrap_or(&0);
248 let y_rank = *self.rank.get(&y_root).unwrap_or(&0);
249
250 if x_rank < y_rank {
251 self.parent.insert(x_root, y_root);
252 } else if x_rank > y_rank {
253 self.parent.insert(y_root, x_root);
254 } else {
255 self.parent.insert(y_root, x_root);
256 self.rank.insert(x_root, x_rank + 1);
257 }
258 }
259}
260
// Plain index aliases. Nothing in the visible code refers to them; they rely
// on the file-level `dead_code` allowance.
type FuncId = usize;
type ConstId = usize;
type SortId = usize;
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a combiner with every id in `ids` registered between theories 0 and 1.
    fn combiner_with(ids: &[TermId]) -> NelsonOppenCombiner {
        let mut combiner = NelsonOppenCombiner::new();
        for &id in ids {
            combiner.register_shared_term(id, TheoryId(0), TheoryId(1));
        }
        combiner
    }

    /// True if `pairs` contains `{a, b}` in either orientation.
    fn has_pair(pairs: &[(TermId, TermId)], a: TermId, b: TermId) -> bool {
        pairs
            .iter()
            .any(|&(x, y)| (x == a && y == b) || (x == b && y == a))
    }

    #[test]
    fn test_nelson_oppen_creation() {
        let combiner = NelsonOppenCombiner::new();
        assert_eq!(combiner.stats.shared_terms_count, 0);
    }

    #[test]
    fn test_register_shared_term() {
        let term_id = TermId(0);
        let combiner = combiner_with(&[term_id]);

        assert_eq!(combiner.stats.shared_terms_count, 1);
        assert!(combiner.shared_terms.contains(&term_id));
    }

    #[test]
    fn test_assert_equality() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);

        assert!(combiner.assert_equality(t1, t2).is_ok());
        assert!(combiner.are_equal(t1, t2));
        assert_eq!(combiner.stats.equalities_propagated, 1);
    }

    #[test]
    fn test_assert_equality_rejects_unregistered_terms() {
        let t1 = TermId(0);
        let mut combiner = combiner_with(&[t1]);

        assert!(combiner.assert_equality(t1, TermId(5)).is_err());
        assert!(combiner.assert_equality(TermId(5), t1).is_err());
        assert_eq!(combiner.stats.equalities_propagated, 0);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_repeated_equality_is_noop() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);

        combiner.assert_equality(t1, t2).unwrap();
        combiner.assert_equality(t2, t1).unwrap();

        assert_eq!(combiner.stats.equalities_propagated, 1);
        assert_eq!(combiner.get_pending_equalities().len(), 1);
    }

    #[test]
    fn test_get_pending_equalities_drains_queue() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);

        combiner.assert_equality(t1, t2).unwrap();

        assert_eq!(combiner.get_pending_equalities(), vec![(t1, t2)]);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_get_equivalence_class() {
        let (t1, t2, t3) = (TermId(0), TermId(1), TermId(2));
        let mut combiner = combiner_with(&[t1, t2, t3]);

        combiner.assert_equality(t1, t2).unwrap();

        let class = combiner.get_equivalence_class(t1);
        assert_eq!(class.len(), 2);
        assert!(class.contains(&t1));
        assert!(class.contains(&t2));
        assert_eq!(combiner.get_equivalence_class(t3), vec![t3]);
    }

    #[test]
    fn test_convexity_closure() {
        let (t1, t2, t3) = (TermId(0), TermId(1), TermId(2));
        let mut combiner = combiner_with(&[t1, t2, t3]);

        combiner.assert_equality(t1, t2).unwrap();
        combiner.assert_equality(t2, t3).unwrap();

        // One class of three terms yields exactly the three unordered pairs.
        let implied = combiner.convexity_closure();
        assert_eq!(implied.len(), 3);
        assert!(has_pair(&implied, t1, t2));
        assert!(has_pair(&implied, t1, t3));
        assert!(has_pair(&implied, t2, t3));
    }

    #[test]
    fn test_convexity_closure_singletons_yield_nothing() {
        let mut combiner = combiner_with(&[TermId(0), TermId(1)]);
        assert!(combiner.convexity_closure().is_empty());
    }

    #[test]
    fn test_reset() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with(&[t1, t2]);
        combiner.assert_equality(t1, t2).unwrap();

        combiner.reset();

        assert_eq!(combiner.stats.shared_terms_count, 0);
        assert_eq!(combiner.stats.equalities_propagated, 0);
        assert!(combiner.shared_terms.is_empty());
        assert!(combiner.get_pending_equalities().is_empty());
        assert!(!combiner.are_equal(t1, t2));
    }
}