1#![allow(dead_code)] use oxiz_core::ast::{TermId, TermKind, TermManager};
11use rustc_hash::FxHashMap;
12use std::collections::VecDeque;
13
14pub struct EqualityPropagator {
16 union_find: UnionFind,
18 congruence: CongruenceData,
20 pending: VecDeque<(TermId, TermId, Explanation)>,
22 watched: FxHashMap<TermId, Vec<EqualityWatch>>,
24 egraph: EGraph,
26 stats: EqualityPropStats,
28}
29
30#[derive(Debug, Clone)]
32pub struct UnionFind {
33 parent: FxHashMap<TermId, TermId>,
35 rank: FxHashMap<TermId, usize>,
37 size: FxHashMap<TermId, usize>,
39}
40
41#[derive(Debug, Clone)]
43pub struct CongruenceData {
44 use_list: FxHashMap<TermId, Vec<TermId>>,
46 lookup: FxHashMap<CongruenceKey, TermId>,
48 pending_congruences: VecDeque<(TermId, TermId)>,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Hash)]
54pub struct CongruenceKey {
55 pub function: TermKind,
57 pub args: Vec<TermId>,
59}
60
61#[derive(Debug, Clone)]
63pub struct EGraph {
64 eclass: FxHashMap<TermId, EClassId>,
66 nodes: FxHashMap<EClassId, Vec<TermId>>,
68 data: FxHashMap<EClassId, EClassData>,
70 next_id: EClassId,
72}
73
/// Identifier of an e-class, allocated sequentially by [`EGraph`].
pub type EClassId = usize;
76
77#[derive(Debug, Clone)]
79pub struct EClassData {
80 pub representative: TermId,
82 pub size: usize,
84 pub parents: Vec<EClassId>,
86}
87
88#[derive(Debug, Clone)]
90pub enum Explanation {
91 Given,
93 Reflexivity,
95 Transitivity(TermId, Box<Explanation>, Box<Explanation>),
97 Congruence(Vec<(TermId, TermId, Box<Explanation>)>),
99 TheoryPropagation(TheoryExplanation),
101}
102
103#[derive(Debug, Clone)]
105pub struct TheoryExplanation {
106 pub theory_id: usize,
108 pub antecedents: Vec<(TermId, TermId)>,
110}
111
112#[derive(Debug, Clone)]
114pub struct EqualityWatch {
115 pub lhs: TermId,
117 pub rhs: TermId,
119 pub callback: usize,
121}
122
/// Counters describing work done by the equality propagator.
#[derive(Clone, Debug, Default)]
pub struct EqualityPropStats {
    /// Successful class merges.
    pub equalities_propagated: usize,
    /// Congruent term pairs detected.
    pub congruences_found: usize,
    /// E-graph merge calls.
    pub egraph_merges: usize,
    /// Congruence explanations built.
    pub explanations_generated: usize,
    /// Watches found satisfied.
    pub watch_triggers: usize,
}
137
138impl UnionFind {
139 pub fn new() -> Self {
141 Self {
142 parent: FxHashMap::default(),
143 rank: FxHashMap::default(),
144 size: FxHashMap::default(),
145 }
146 }
147
148 pub fn find(&mut self, x: TermId) -> TermId {
150 if let std::collections::hash_map::Entry::Vacant(e) = self.parent.entry(x) {
151 e.insert(x);
152 self.rank.insert(x, 0);
153 self.size.insert(x, 1);
154 return x;
155 }
156
157 let parent = self.parent[&x];
158 if parent != x {
159 let root = self.find(parent);
161 self.parent.insert(x, root);
162 root
163 } else {
164 x
165 }
166 }
167
168 pub fn union(&mut self, x: TermId, y: TermId) -> bool {
170 let root_x = self.find(x);
171 let root_y = self.find(y);
172
173 if root_x == root_y {
174 return false; }
176
177 let rank_x = self.rank.get(&root_x).copied().unwrap_or(0);
178 let rank_y = self.rank.get(&root_y).copied().unwrap_or(0);
179
180 if rank_x < rank_y {
182 self.parent.insert(root_x, root_y);
183 let size_x = self.size.get(&root_x).copied().unwrap_or(1);
184 *self.size.entry(root_y).or_insert(1) += size_x;
185 } else if rank_x > rank_y {
186 self.parent.insert(root_y, root_x);
187 let size_y = self.size.get(&root_y).copied().unwrap_or(1);
188 *self.size.entry(root_x).or_insert(1) += size_y;
189 } else {
190 self.parent.insert(root_y, root_x);
191 *self.rank.entry(root_x).or_insert(0) += 1;
192 let size_y = self.size.get(&root_y).copied().unwrap_or(1);
193 *self.size.entry(root_x).or_insert(1) += size_y;
194 }
195
196 true
197 }
198
199 pub fn connected(&mut self, x: TermId, y: TermId) -> bool {
201 self.find(x) == self.find(y)
202 }
203
204 pub fn set_size(&mut self, x: TermId) -> usize {
206 let root = self.find(x);
207 self.size[&root]
208 }
209}
210
211impl EqualityPropagator {
212 pub fn new() -> Self {
214 Self {
215 union_find: UnionFind::new(),
216 congruence: CongruenceData::new(),
217 pending: VecDeque::new(),
218 watched: FxHashMap::default(),
219 egraph: EGraph::new(),
220 stats: EqualityPropStats::default(),
221 }
222 }
223
224 pub fn assert_equality(
226 &mut self,
227 lhs: TermId,
228 rhs: TermId,
229 explanation: Explanation,
230 tm: &TermManager,
231 ) -> Result<(), String> {
232 if self.union_find.connected(lhs, rhs) {
234 return Ok(());
235 }
236
237 self.pending.push_back((lhs, rhs, explanation));
239
240 self.propagate(tm)?;
242
243 Ok(())
244 }
245
246 fn propagate(&mut self, tm: &TermManager) -> Result<(), String> {
248 while let Some((lhs, rhs, explanation)) = self.pending.pop_front() {
249 self.propagate_equality(lhs, rhs, explanation, tm)?;
250 }
251
252 self.check_congruences(tm)?;
254
255 Ok(())
256 }
257
258 fn propagate_equality(
260 &mut self,
261 lhs: TermId,
262 rhs: TermId,
263 _explanation: Explanation,
264 _tm: &TermManager,
265 ) -> Result<(), String> {
266 if !self.union_find.union(lhs, rhs) {
268 return Ok(()); }
270
271 self.stats.equalities_propagated += 1;
272
273 self.egraph.merge(lhs, rhs);
275 self.stats.egraph_merges += 1;
276
277 self.congruence.merge_use_lists(lhs, rhs);
279
280 self.trigger_watches(lhs, rhs)?;
282
283 let lhs_parents = self.congruence.get_parents(lhs);
285 let rhs_parents = self.congruence.get_parents(rhs);
286
287 for lhs_parent in lhs_parents {
288 for &rhs_parent in &rhs_parents {
289 self.congruence
290 .pending_congruences
291 .push_back((lhs_parent, rhs_parent));
292 }
293 }
294
295 Ok(())
296 }
297
298 fn check_congruences(&mut self, tm: &TermManager) -> Result<(), String> {
300 while let Some((t1, t2)) = self.congruence.pending_congruences.pop_front() {
301 if self.are_congruent(t1, t2, tm)? {
303 self.stats.congruences_found += 1;
304
305 let explanation = self.generate_congruence_explanation(t1, t2, tm)?;
307
308 self.pending.push_back((t1, t2, explanation));
310 }
311 }
312
313 Ok(())
314 }
315
316 fn are_congruent(&mut self, t1: TermId, t2: TermId, tm: &TermManager) -> Result<bool, String> {
318 let term1 = tm.get(t1).ok_or("term not found")?;
319 let term2 = tm.get(t2).ok_or("term not found")?;
320
321 if std::mem::discriminant(&term1.kind) != std::mem::discriminant(&term2.kind) {
323 return Ok(false);
324 }
325
326 let args1 = self.get_args(&term1.kind);
328 let args2 = self.get_args(&term2.kind);
329
330 if args1.len() != args2.len() {
331 return Ok(false);
332 }
333
334 for (arg1, arg2) in args1.iter().zip(args2.iter()) {
336 if !self.union_find.connected(*arg1, *arg2) {
337 return Ok(false);
338 }
339 }
340
341 Ok(true)
342 }
343
344 fn generate_congruence_explanation(
346 &mut self,
347 t1: TermId,
348 t2: TermId,
349 tm: &TermManager,
350 ) -> Result<Explanation, String> {
351 let term1 = tm.get(t1).ok_or("term not found")?;
352 let term2 = tm.get(t2).ok_or("term not found")?;
353
354 let args1 = self.get_args(&term1.kind);
355 let args2 = self.get_args(&term2.kind);
356
357 let mut arg_explanations = Vec::new();
358
359 for (arg1, arg2) in args1.iter().zip(args2.iter()) {
360 let expl = self.explain_equality(*arg1, *arg2)?;
361 arg_explanations.push((*arg1, *arg2, Box::new(expl)));
362 }
363
364 self.stats.explanations_generated += 1;
365
366 Ok(Explanation::Congruence(arg_explanations))
367 }
368
369 pub fn explain_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<Explanation, String> {
371 if lhs == rhs {
372 return Ok(Explanation::Reflexivity);
373 }
374
375 if !self.union_find.connected(lhs, rhs) {
376 return Err("Terms are not equal".to_string());
377 }
378
379 Ok(Explanation::Given)
382 }
383
384 pub fn watch_equality(&mut self, lhs: TermId, rhs: TermId, callback: usize) {
386 let watch = EqualityWatch { lhs, rhs, callback };
387
388 self.watched.entry(lhs).or_default().push(watch.clone());
389 self.watched.entry(rhs).or_default().push(watch);
390 }
391
392 fn trigger_watches(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
394 let mut triggered = Vec::new();
395
396 if let Some(watches) = self.watched.get(&lhs) {
398 for watch in watches {
399 if self.union_find.connected(watch.lhs, watch.rhs) {
400 triggered.push(watch.callback);
401 }
402 }
403 }
404
405 if let Some(watches) = self.watched.get(&rhs) {
407 for watch in watches {
408 if self.union_find.connected(watch.lhs, watch.rhs) {
409 triggered.push(watch.callback);
410 }
411 }
412 }
413
414 self.stats.watch_triggers += triggered.len();
415
416 Ok(())
417 }
418
419 fn get_args(&self, kind: &TermKind) -> Vec<TermId> {
421 match kind {
422 TermKind::And(args) | TermKind::Or(args) => args.to_vec(),
423 TermKind::Not(arg) => vec![*arg],
424 TermKind::Eq(l, r) | TermKind::Le(l, r) | TermKind::Lt(l, r) => vec![*l, *r],
425 TermKind::Add(args) | TermKind::Mul(args) => args.to_vec(),
426 _ => vec![],
427 }
428 }
429
430 pub fn stats(&self) -> &EqualityPropStats {
432 &self.stats
433 }
434}
435
436impl CongruenceData {
437 pub fn new() -> Self {
439 Self {
440 use_list: FxHashMap::default(),
441 lookup: FxHashMap::default(),
442 pending_congruences: VecDeque::new(),
443 }
444 }
445
446 pub fn merge_use_lists(&mut self, t1: TermId, t2: TermId) {
448 let t1_uses = self.use_list.get(&t1).cloned().unwrap_or_default();
450 let t2_uses = self.use_list.get(&t2).cloned().unwrap_or_default();
451
452 let mut merged = t1_uses;
453 merged.extend(t2_uses);
454
455 self.use_list.insert(t1, merged.clone());
456 self.use_list.insert(t2, merged);
457 }
458
459 pub fn get_parents(&self, t: TermId) -> Vec<TermId> {
461 self.use_list.get(&t).cloned().unwrap_or_default()
462 }
463}
464
465impl EGraph {
466 pub fn new() -> Self {
468 Self {
469 eclass: FxHashMap::default(),
470 nodes: FxHashMap::default(),
471 data: FxHashMap::default(),
472 next_id: 0,
473 }
474 }
475
476 pub fn get_eclass(&mut self, term: TermId) -> EClassId {
478 if let Some(&id) = self.eclass.get(&term) {
479 id
480 } else {
481 let id = self.next_id;
482 self.next_id += 1;
483
484 self.eclass.insert(term, id);
485 self.nodes.insert(id, vec![term]);
486 self.data.insert(
487 id,
488 EClassData {
489 representative: term,
490 size: 1,
491 parents: Vec::new(),
492 },
493 );
494
495 id
496 }
497 }
498
499 pub fn merge(&mut self, t1: TermId, t2: TermId) {
501 let id1 = self.get_eclass(t1);
502 let id2 = self.get_eclass(t2);
503
504 if id1 == id2 {
505 return;
506 }
507
508 let size1 = self.data[&id1].size;
510 let size2 = self.data[&id2].size;
511
512 let (smaller, larger) = if size1 < size2 {
513 (id1, id2)
514 } else {
515 (id2, id1)
516 };
517
518 let smaller_nodes = self.nodes[&smaller].clone();
520 for &node in &smaller_nodes {
521 self.eclass.insert(node, larger);
522 }
523
524 if let Some(larger_nodes) = self.nodes.get_mut(&larger) {
526 larger_nodes.extend(smaller_nodes);
527 }
528 self.nodes.remove(&smaller);
529
530 let smaller_size = self.data.get(&smaller).map(|d| d.size).unwrap_or(0);
532 if let Some(larger_data) = self.data.get_mut(&larger) {
533 larger_data.size += smaller_size;
534 }
535 self.data.remove(&smaller);
536 }
537}
538
539impl Default for EqualityPropagator {
540 fn default() -> Self {
541 Self::new()
542 }
543}
544
545impl Default for UnionFind {
546 fn default() -> Self {
547 Self::new()
548 }
549}
550
551impl Default for CongruenceData {
552 fn default() -> Self {
553 Self::new()
554 }
555}
556
557impl Default for EGraph {
558 fn default() -> Self {
559 Self::new()
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_union_find() {
        let mut uf = UnionFind::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);
        let t3 = TermId::from(3);

        assert!(!uf.connected(t1, t2));

        assert!(uf.union(t1, t2));
        assert!(uf.connected(t1, t2));
        // Re-uniting an existing class reports no change.
        assert!(!uf.union(t2, t1));

        assert!(uf.union(t2, t3));
        assert!(uf.connected(t1, t3));

        assert_eq!(uf.set_size(t1), 3);
        assert_eq!(uf.set_size(t3), 3);
        // Unseen terms are singletons.
        assert_eq!(uf.set_size(TermId::from(9)), 1);
    }

    #[test]
    fn test_equality_propagator() {
        let prop = EqualityPropagator::new();
        assert_eq!(prop.stats.equalities_propagated, 0);
        assert_eq!(prop.stats().congruences_found, 0);
        assert_eq!(prop.stats().watch_triggers, 0);
    }

    #[test]
    fn test_explain_equality_without_assertions() {
        let mut prop = EqualityPropagator::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);

        assert!(matches!(
            prop.explain_equality(t1, t1),
            Ok(Explanation::Reflexivity)
        ));
        assert!(prop.explain_equality(t1, t2).is_err());
    }

    #[test]
    fn test_egraph() {
        let mut eg = EGraph::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);

        let id1 = eg.get_eclass(t1);
        let id2 = eg.get_eclass(t2);

        assert_ne!(id1, id2);
        // Lookups are stable.
        assert_eq!(eg.get_eclass(t1), id1);

        eg.merge(t1, t2);

        let id1_after = eg.get_eclass(t1);
        let id2_after = eg.get_eclass(t2);

        assert_eq!(id1_after, id2_after);
        assert_eq!(eg.data[&id1_after].size, 2);

        // Merging members of the same class is a no-op.
        eg.merge(t2, t1);
        assert_eq!(eg.data[&id1_after].size, 2);
    }
}
609}