1#![allow(dead_code, clippy::result_unit_err)] #[allow(unused_imports)]
8use crate::prelude::*;
9use oxiz_core::ast::{TermId, TermKind, TermManager};
10
11pub struct NelsonOppenCombiner {
13 shared_terms: FxHashSet<TermId>,
15 equality_classes: UnionFind,
17 pending_equalities: VecDeque<(TermId, TermId)>,
19 propagated_equalities: FxHashSet<(TermId, TermId)>,
22 theory_assignments: FxHashMap<TermId, TheoryId>,
24 stats: NelsonOppenStats,
26 fresh_var_counter: u64,
28}
29
/// Newtype around a `usize` that identifies one of the combined theories.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TheoryId(pub usize);
33
/// Counters describing the work performed by a `NelsonOppenCombiner`.
#[derive(Default, Clone, Debug)]
pub struct NelsonOppenStats {
    /// Number of terms registered as shared between theories.
    pub shared_terms_count: usize,
    /// Number of asserted equalities that merged two equivalence classes.
    pub equalities_propagated: usize,
    /// Number of theory conflicts.
    /// NOTE(review): nothing in this module increments it — confirm who owns it.
    pub theory_conflicts: usize,
    /// Purification counter maintained by `NelsonOppenCombiner::purify_term`.
    pub purifications: usize,
}
46
47impl NelsonOppenCombiner {
48 pub fn new() -> Self {
50 Self {
51 shared_terms: FxHashSet::default(),
52 equality_classes: UnionFind::new(),
53 pending_equalities: VecDeque::new(),
54 propagated_equalities: FxHashSet::default(),
55 theory_assignments: FxHashMap::default(),
56 stats: NelsonOppenStats::default(),
57 fresh_var_counter: 0,
58 }
59 }
60
61 pub fn register_shared_term(&mut self, term_id: TermId, theory1: TheoryId, _theory2: TheoryId) {
63 self.shared_terms.insert(term_id);
64 self.theory_assignments.insert(term_id, theory1);
65 self.equality_classes.make_set(term_id);
66 self.stats.shared_terms_count += 1;
67 }
68
69 fn normalize_pair(lhs: TermId, rhs: TermId) -> (TermId, TermId) {
72 if lhs <= rhs { (lhs, rhs) } else { (rhs, lhs) }
73 }
74
75 pub fn assert_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<(), ()> {
79 if !self.shared_terms.contains(&lhs) || !self.shared_terms.contains(&rhs) {
80 return Err(()); }
82
83 let key = Self::normalize_pair(lhs, rhs);
85 if self.propagated_equalities.contains(&key) {
86 return Ok(());
87 }
88
89 if self.equality_classes.find(lhs) == self.equality_classes.find(rhs) {
91 self.propagated_equalities.insert(key);
92 return Ok(());
93 }
94
95 self.equality_classes.union(lhs, rhs);
97 self.pending_equalities.push_back((lhs, rhs));
98 self.propagated_equalities.insert(key);
99 self.stats.equalities_propagated += 1;
100
101 Ok(())
102 }
103
104 fn fresh_var_name(&mut self) -> String {
106 let name = format!("_no_purify_{}", self.fresh_var_counter);
107 self.fresh_var_counter += 1;
108 name
109 }
110
111 pub fn purify_term(&mut self, term_id: TermId, tm: &mut TermManager) -> Result<TermId, String> {
118 self.stats.purifications += 1;
119
120 let term = tm.get(term_id).ok_or("term not found")?.clone();
122
123 match &term.kind {
124 TermKind::Apply { func, args } => {
125 let func_spur = *func;
126 let original_args: Vec<TermId> = args.iter().copied().collect();
127 let mut purified_args = Vec::new();
128
129 for &arg in &original_args {
130 let purified_arg = self.purify_term(arg, tm)?;
131 purified_args.push(purified_arg);
132 }
133
134 let needs_purification = purified_args.iter().enumerate().any(|(i, &purified)| {
136 self.get_theory(purified) != self.get_theory(original_args[i])
137 });
138
139 if needs_purification {
140 let sort = term.sort;
142 let fresh_name = self.fresh_var_name();
143 let fresh_var = tm.mk_var(&fresh_name, sort);
144
145 self.register_shared_term(fresh_var, TheoryId(0), TheoryId(1));
147
148 let func_name = tm.resolve_str(func_spur).to_string();
152 let purified_app = tm.mk_apply(&func_name, purified_args, sort);
153
154 let _ = self.assert_equality(fresh_var, purified_app);
157
158 Ok(fresh_var)
159 } else {
160 Ok(term_id)
161 }
162 }
163 _ => Ok(term_id),
164 }
165 }
166
167 pub fn get_pending_equalities(&mut self) -> Vec<(TermId, TermId)> {
169 let mut result = Vec::new();
170 while let Some(eq) = self.pending_equalities.pop_front() {
171 result.push(eq);
172 }
173 result
174 }
175
176 pub fn are_equal(&self, lhs: TermId, rhs: TermId) -> bool {
178 self.equality_classes.find(lhs) == self.equality_classes.find(rhs)
179 }
180
181 pub fn get_equivalence_class(&self, term_id: TermId) -> Vec<TermId> {
183 let rep = self.equality_classes.find(term_id);
184 self.shared_terms
185 .iter()
186 .filter(|&&t| self.equality_classes.find(t) == rep)
187 .copied()
188 .collect()
189 }
190
191 fn get_theory(&self, term_id: TermId) -> Option<TheoryId> {
193 self.theory_assignments.get(&term_id).copied()
194 }
195
196 pub fn convexity_closure(&mut self) -> Vec<(TermId, TermId)> {
202 let mut implied_equalities = Vec::new();
203
204 let mut classes: FxHashMap<TermId, Vec<TermId>> = FxHashMap::default();
206 for &term in &self.shared_terms {
207 let rep = self.equality_classes.find(term);
208 classes.entry(rep).or_default().push(term);
209 }
210
211 for (_rep, terms) in classes {
213 if terms.len() > 1 {
214 for i in 0..terms.len() {
216 for j in (i + 1)..terms.len() {
217 let key = Self::normalize_pair(terms[i], terms[j]);
218 if !self.propagated_equalities.contains(&key) {
219 implied_equalities.push((terms[i], terms[j]));
220 }
221 }
222 }
223 }
224 }
225
226 implied_equalities
227 }
228
229 pub fn stats(&self) -> &NelsonOppenStats {
231 &self.stats
232 }
233
234 pub fn reset(&mut self) {
236 self.shared_terms.clear();
237 self.equality_classes = UnionFind::new();
238 self.pending_equalities.clear();
239 self.propagated_equalities.clear();
240 self.theory_assignments.clear();
241 self.stats = NelsonOppenStats::default();
242 self.fresh_var_counter = 0;
243 }
244}
245
246impl Default for NelsonOppenCombiner {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
252#[derive(Debug, Clone)]
254struct UnionFind {
255 parent: FxHashMap<TermId, TermId>,
256 rank: FxHashMap<TermId, usize>,
257}
258
259impl UnionFind {
260 fn new() -> Self {
261 Self {
262 parent: FxHashMap::default(),
263 rank: FxHashMap::default(),
264 }
265 }
266
267 fn make_set(&mut self, x: TermId) {
268 self.parent.insert(x, x);
269 self.rank.insert(x, 0);
270 }
271
272 fn find(&self, x: TermId) -> TermId {
273 let mut current = x;
274 while let Some(&parent) = self.parent.get(¤t) {
275 if parent == current {
276 return current;
277 }
278 current = parent;
279 }
280 x }
282
283 fn union(&mut self, x: TermId, y: TermId) {
284 let x_root = self.find(x);
285 let y_root = self.find(y);
286
287 if x_root == y_root {
288 return;
289 }
290
291 let x_rank = *self.rank.get(&x_root).unwrap_or(&0);
292 let y_rank = *self.rank.get(&y_root).unwrap_or(&0);
293
294 if x_rank < y_rank {
295 self.parent.insert(x_root, y_root);
296 } else if x_rank > y_rank {
297 self.parent.insert(y_root, x_root);
298 } else {
299 self.parent.insert(y_root, x_root);
300 self.rank.insert(x_root, x_rank + 1);
301 }
302 }
303}
304
/// Numeric identifier for a sort.
type SortId = usize;

/// Numeric identifier for a function symbol.
type FuncId = usize;

/// Numeric identifier for a constant.
type ConstId = usize;
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a combiner with every id in `ids` registered as shared between
    /// theories 0 and 1.
    fn combiner_with_shared(ids: &[TermId]) -> NelsonOppenCombiner {
        let mut combiner = NelsonOppenCombiner::new();
        for &id in ids {
            combiner.register_shared_term(id, TheoryId(0), TheoryId(1));
        }
        combiner
    }

    #[test]
    fn test_nelson_oppen_creation() {
        let combiner = NelsonOppenCombiner::new();
        assert_eq!(combiner.stats.shared_terms_count, 0);
    }

    #[test]
    fn test_register_shared_term() {
        let term_id = TermId(0);
        let combiner = combiner_with_shared(&[term_id]);

        assert_eq!(combiner.stats.shared_terms_count, 1);
        assert!(combiner.shared_terms.contains(&term_id));
    }

    #[test]
    fn test_assert_equality() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with_shared(&[t1, t2]);

        assert!(combiner.assert_equality(t1, t2).is_ok());
        assert!(combiner.are_equal(t1, t2));
        assert_eq!(combiner.stats.equalities_propagated, 1);
    }

    #[test]
    fn test_assert_equality_rejects_unshared_terms() {
        let mut combiner = combiner_with_shared(&[TermId(0)]);

        assert!(combiner.assert_equality(TermId(0), TermId(1)).is_err());
        assert!(!combiner.are_equal(TermId(0), TermId(1)));
        assert_eq!(combiner.stats().equalities_propagated, 0);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_redundant_equality_is_not_recounted() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with_shared(&[t1, t2]);

        combiner
            .assert_equality(t1, t2)
            .expect("test operation should succeed");
        combiner
            .assert_equality(t2, t1)
            .expect("test operation should succeed");

        assert_eq!(combiner.stats().equalities_propagated, 1);
    }

    #[test]
    fn test_pending_equalities_are_drained() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with_shared(&[t1, t2]);
        combiner
            .assert_equality(t1, t2)
            .expect("test operation should succeed");

        assert_eq!(combiner.get_pending_equalities(), vec![(t1, t2)]);
        assert!(combiner.get_pending_equalities().is_empty());
    }

    #[test]
    fn test_convexity_closure() {
        let (t1, t2, t3) = (TermId(0), TermId(1), TermId(2));
        let mut combiner = combiner_with_shared(&[t1, t2, t3]);

        combiner
            .assert_equality(t1, t2)
            .expect("test operation should succeed");
        combiner
            .assert_equality(t2, t3)
            .expect("test operation should succeed");

        // Only t1 = t3 is implied by transitivity without having been asserted.
        let implied = combiner.convexity_closure();
        assert_eq!(implied.len(), 1);
        let (a, b) = implied[0];
        assert_eq!(NelsonOppenCombiner::normalize_pair(a, b), (t1, t3));
    }

    #[test]
    fn test_equivalence_class() {
        let (t1, t2, t3) = (TermId(0), TermId(1), TermId(2));
        let mut combiner = combiner_with_shared(&[t1, t2, t3]);
        combiner
            .assert_equality(t1, t2)
            .expect("test operation should succeed");

        let class = combiner.get_equivalence_class(t1);
        assert_eq!(class.len(), 2);
        assert!(class.contains(&t1) && class.contains(&t2));
        assert_eq!(combiner.get_equivalence_class(t3), vec![t3]);
    }

    #[test]
    fn test_reset_clears_state() {
        let (t1, t2) = (TermId(0), TermId(1));
        let mut combiner = combiner_with_shared(&[t1, t2]);
        combiner
            .assert_equality(t1, t2)
            .expect("test operation should succeed");

        combiner.reset();

        assert_eq!(combiner.stats().shared_terms_count, 0);
        assert_eq!(combiner.stats().equalities_propagated, 0);
        assert!(!combiner.are_equal(t1, t2));
        assert!(combiner.get_pending_equalities().is_empty());
        assert!(combiner.assert_equality(t1, t2).is_err());
    }
}