oxiz-solver 0.2.0

Main CDCL(T) Solver API for OxiZ
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
//! Equality Propagation Engine for Theory Combination.
#![allow(dead_code)] // Under development
//!
//! Implements efficient equality propagation between theories using:
//! - Congruence closure with union-find
//! - E-graph for term canonicalization
//! - Equality explanation generation
//! - Watched equalities for lazy propagation

#[allow(unused_imports)]
use crate::prelude::*;
use oxiz_core::ast::{TermId, TermKind, TermManager};

/// Equality propagation engine.
///
/// Combines a union-find over term ids, congruence bookkeeping, an e-graph and
/// a set of watched equalities. New equalities are queued in `pending` and
/// processed when `assert_equality` is called.
pub struct EqualityPropagator {
    /// Union-find for equality classes
    union_find: UnionFind,
    /// Congruence closure data structures
    congruence: CongruenceData,
    /// Pending equalities to propagate, each paired with the explanation it
    /// was asserted (or derived) with
    pending: VecDeque<(TermId, TermId, Explanation)>,
    /// Watched equalities: term → watchers. Each watch is filed under both of
    /// its endpoints.
    watched: FxHashMap<TermId, Vec<EqualityWatch>>,
    /// E-graph for term canonicalization
    egraph: EGraph,
    /// Statistics
    stats: EqualityPropStats,
}

/// Union-find data structure for equivalence classes.
///
/// Elements are registered lazily: the first `find` on an unseen term makes
/// it a singleton class. Merging uses union by rank, and `find` compresses
/// paths.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// Parent pointers; a term is the root of its class when it is its own parent
    parent: FxHashMap<TermId, TermId>,
    /// Rank for union-by-rank (only read and updated for class roots)
    rank: FxHashMap<TermId, usize>,
    /// Size of equivalence class (kept up to date only for class roots)
    size: FxHashMap<TermId, usize>,
}

/// Congruence closure data.
#[derive(Debug, Clone)]
pub struct CongruenceData {
    /// Use list: term → terms that use it (its "parents" for congruence checks).
    /// NOTE(review): nothing in this file adds entries; the lists only change
    /// through `merge_use_lists`.
    use_list: FxHashMap<TermId, Vec<TermId>>,
    /// Lookup table: (function, args) → term.
    /// NOTE(review): not read or written anywhere in this file yet.
    lookup: FxHashMap<CongruenceKey, TermId>,
    /// Pending congruence checks: pairs of parent terms whose arguments may
    /// have become equal
    pending_congruences: VecDeque<(TermId, TermId)>,
}

/// Key for congruence lookup.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CongruenceKey {
    /// Function/operator
    pub function: TermKind,
    /// Canonical arguments (equivalence class representatives)
    pub args: Vec<TermId>,
}

/// E-graph for term canonicalization.
///
/// Each term belongs to exactly one e-class; e-classes are created on demand
/// and absorbed into one another by `merge`.
#[derive(Debug, Clone)]
pub struct EGraph {
    /// E-class membership: term → e-class
    eclass: FxHashMap<TermId, EClassId>,
    /// E-class contents: e-class → terms
    nodes: FxHashMap<EClassId, Vec<TermId>>,
    /// E-class data
    data: FxHashMap<EClassId, EClassData>,
    /// Next available e-class ID (ids are handed out sequentially and never reused)
    next_id: EClassId,
}

/// E-class identifier.
pub type EClassId = usize;

/// Data associated with an e-class.
#[derive(Debug, Clone)]
pub struct EClassData {
    /// Representative term (the term that created the e-class)
    pub representative: TermId,
    /// Size of e-class (number of member terms)
    pub size: usize,
    /// Parent e-classes (for congruence).
    /// NOTE(review): nothing in this file pushes entries yet.
    pub parents: Vec<EClassId>,
}

/// Explanation for an equality.
#[derive(Debug, Clone)]
pub enum Explanation {
    /// Given equality (axiom). Also what `explain_equality` currently returns
    /// for any connected, non-identical pair, since it does not trace proofs yet.
    Given,
    /// Equality by reflexivity
    Reflexivity,
    /// Equality by transitivity. NOTE(review): the `TermId` is presumably the
    /// intermediate term linking the two sub-explanations — confirm.
    Transitivity(TermId, Box<Explanation>, Box<Explanation>),
    /// Equality by congruence: one `(arg1, arg2, explanation)` entry per
    /// argument position of the two congruent terms
    Congruence(Vec<(TermId, TermId, Box<Explanation>)>),
    /// Theory propagation
    TheoryPropagation(TheoryExplanation),
}

/// Theory-specific explanation.
#[derive(Debug, Clone)]
pub struct TheoryExplanation {
    /// Theory ID (presumably an index into the solver's theory list — confirm)
    pub theory_id: usize,
    /// Antecedent equalities
    pub antecedents: Vec<(TermId, TermId)>,
}

/// Watched equality for lazy propagation.
#[derive(Debug, Clone)]
pub struct EqualityWatch {
    /// Left-hand side
    pub lhs: TermId,
    /// Right-hand side
    pub rhs: TermId,
    /// Callback ID (opaque, chosen by the caller of `watch_equality`)
    pub callback: usize,
}

/// Equality propagation statistics.
#[derive(Debug, Clone, Default)]
pub struct EqualityPropStats {
    /// Equalities propagated (successful union-find merges)
    pub equalities_propagated: usize,
    /// Congruences found
    pub congruences_found: usize,
    /// E-graph merges (incremented once per propagated equality)
    pub egraph_merges: usize,
    /// Explanations generated (congruence explanations)
    pub explanations_generated: usize,
    /// Watched equality triggers (number of callback ids collected)
    pub watch_triggers: usize,
}

impl UnionFind {
    /// Create a new union-find structure.
    pub fn new() -> Self {
        Self {
            parent: FxHashMap::default(),
            rank: FxHashMap::default(),
            size: FxHashMap::default(),
        }
    }

    /// Find the representative of a set.
    pub fn find(&mut self, x: TermId) -> TermId {
        if let crate::prelude::hash_map::Entry::Vacant(e) = self.parent.entry(x) {
            e.insert(x);
            self.rank.insert(x, 0);
            self.size.insert(x, 1);
            return x;
        }

        let parent = self.parent[&x];
        if parent != x {
            // Path compression
            let root = self.find(parent);
            self.parent.insert(x, root);
            root
        } else {
            x
        }
    }

    /// Union two sets.
    pub fn union(&mut self, x: TermId, y: TermId) -> bool {
        let root_x = self.find(x);
        let root_y = self.find(y);

        if root_x == root_y {
            return false; // Already in same set
        }

        let rank_x = self.rank.get(&root_x).copied().unwrap_or(0);
        let rank_y = self.rank.get(&root_y).copied().unwrap_or(0);

        // Union by rank
        if rank_x < rank_y {
            self.parent.insert(root_x, root_y);
            let size_x = self.size.get(&root_x).copied().unwrap_or(1);
            *self.size.entry(root_y).or_insert(1) += size_x;
        } else if rank_x > rank_y {
            self.parent.insert(root_y, root_x);
            let size_y = self.size.get(&root_y).copied().unwrap_or(1);
            *self.size.entry(root_x).or_insert(1) += size_y;
        } else {
            self.parent.insert(root_y, root_x);
            *self.rank.entry(root_x).or_insert(0) += 1;
            let size_y = self.size.get(&root_y).copied().unwrap_or(1);
            *self.size.entry(root_x).or_insert(1) += size_y;
        }

        true
    }

    /// Check if two elements are in the same set.
    pub fn connected(&mut self, x: TermId, y: TermId) -> bool {
        self.find(x) == self.find(y)
    }

    /// Get size of the set containing x.
    pub fn set_size(&mut self, x: TermId) -> usize {
        let root = self.find(x);
        self.size[&root]
    }
}

impl EqualityPropagator {
    /// Create a new equality propagator with no asserted equalities.
    pub fn new() -> Self {
        Self {
            union_find: UnionFind::new(),
            congruence: CongruenceData::new(),
            pending: VecDeque::new(),
            watched: FxHashMap::default(),
            egraph: EGraph::new(),
            stats: EqualityPropStats::default(),
        }
    }

    /// Assert that `lhs` and `rhs` are equal and propagate the consequences.
    ///
    /// Returns immediately if the two terms are already in the same class.
    /// Otherwise the equality is queued and propagated to a fixpoint, together
    /// with every congruence it induces.
    ///
    /// # Errors
    ///
    /// Returns an error if a term needed during congruence checking is not
    /// found in `tm`.
    pub fn assert_equality(
        &mut self,
        lhs: TermId,
        rhs: TermId,
        explanation: Explanation,
        tm: &TermManager,
    ) -> Result<(), String> {
        if self.union_find.connected(lhs, rhs) {
            return Ok(());
        }

        self.pending.push_back((lhs, rhs, explanation));
        self.propagate(tm)
    }

    /// Propagate pending equalities, and the congruences they induce, to a
    /// fixpoint.
    ///
    /// `check_congruences` enqueues newly derived equalities in `pending`, so
    /// the two phases alternate until no new equality is produced.
    fn propagate(&mut self, tm: &TermManager) -> Result<(), String> {
        loop {
            while let Some((lhs, rhs, explanation)) = self.pending.pop_front() {
                self.propagate_equality(lhs, rhs, explanation, tm)?;
            }

            self.check_congruences(tm)?;

            if self.pending.is_empty() {
                return Ok(());
            }
        }
    }

    /// Propagate a single equality.
    ///
    /// Does nothing if the terms are already in the same class. Otherwise it
    /// merges them in the union-find and the e-graph, merges their use lists,
    /// checks watches, and queues each (lhs-parent, rhs-parent) pair for a
    /// congruence check.
    fn propagate_equality(
        &mut self,
        lhs: TermId,
        rhs: TermId,
        _explanation: Explanation,
        _tm: &TermManager,
    ) -> Result<(), String> {
        if !self.union_find.union(lhs, rhs) {
            return Ok(()); // Already merged
        }

        self.stats.equalities_propagated += 1;

        self.egraph.merge(lhs, rhs);
        self.stats.egraph_merges += 1;

        // Snapshot each side's parents *before* merging the use lists: after
        // the merge both entries hold the same combined list, and pairing it
        // with itself only yields self-pairs and duplicates.
        let lhs_parents = self.congruence.get_parents(lhs);
        let rhs_parents = self.congruence.get_parents(rhs);
        self.congruence.merge_use_lists(lhs, rhs);

        self.trigger_watches(lhs, rhs)?;

        for &lhs_parent in &lhs_parents {
            for &rhs_parent in &rhs_parents {
                // A term is trivially congruent to itself; no need to check it.
                if lhs_parent != rhs_parent {
                    self.congruence
                        .pending_congruences
                        .push_back((lhs_parent, rhs_parent));
                }
            }
        }

        Ok(())
    }

    /// Drain the pending congruence checks.
    ///
    /// Every pair that is congruent but not yet in one class is queued as a
    /// new equality in `pending`, together with its congruence explanation.
    fn check_congruences(&mut self, tm: &TermManager) -> Result<(), String> {
        while let Some((t1, t2)) = self.congruence.pending_congruences.pop_front() {
            // Already equal: nothing new to derive (and nothing to count).
            if self.union_find.connected(t1, t2) {
                continue;
            }

            if self.are_congruent(t1, t2, tm)? {
                self.stats.congruences_found += 1;

                let explanation = self.generate_congruence_explanation(t1, t2, tm)?;
                self.pending.push_back((t1, t2, explanation));
            }
        }

        Ok(())
    }

    /// Check if two terms are congruent.
    ///
    /// Two terms are congruent when they have the same kind (discriminant)
    /// and pairwise-equal, non-empty argument lists.
    ///
    /// # Errors
    ///
    /// Returns an error if either term is not found in `tm`.
    fn are_congruent(&mut self, t1: TermId, t2: TermId, tm: &TermManager) -> Result<bool, String> {
        let term1 = tm.get(t1).ok_or("term not found")?;
        let term2 = tm.get(t2).ok_or("term not found")?;

        if core::mem::discriminant(&term1.kind) != core::mem::discriminant(&term2.kind) {
            return Ok(false);
        }

        let args1 = self.get_args(&term1.kind);
        let args2 = self.get_args(&term2.kind);

        if args1.len() != args2.len() {
            return Ok(false);
        }

        // `get_args` returns an empty list both for nullary terms and for every
        // kind it does not decompose (its `_` arm). Such terms may differ in
        // data the discriminant does not capture, so calling them congruent
        // would merge distinct terms; only identical ids are safe.
        if args1.is_empty() {
            return Ok(t1 == t2);
        }

        for (arg1, arg2) in args1.iter().zip(args2.iter()) {
            if !self.union_find.connected(*arg1, *arg2) {
                return Ok(false);
            }
        }

        Ok(true)
    }

    /// Build a congruence explanation for `t1 = t2`: one entry per argument
    /// position, each carrying the explanation of that argument equality.
    ///
    /// # Errors
    ///
    /// Returns an error if either term is missing from `tm`, or if some
    /// argument pair is not in the same class.
    fn generate_congruence_explanation(
        &mut self,
        t1: TermId,
        t2: TermId,
        tm: &TermManager,
    ) -> Result<Explanation, String> {
        let term1 = tm.get(t1).ok_or("term not found")?;
        let term2 = tm.get(t2).ok_or("term not found")?;

        let args1 = self.get_args(&term1.kind);
        let args2 = self.get_args(&term2.kind);

        let mut arg_explanations = Vec::with_capacity(args1.len().min(args2.len()));

        for (arg1, arg2) in args1.iter().zip(args2.iter()) {
            let expl = self.explain_equality(*arg1, *arg2)?;
            arg_explanations.push((*arg1, *arg2, Box::new(expl)));
        }

        self.stats.explanations_generated += 1;

        Ok(Explanation::Congruence(arg_explanations))
    }

    /// Explain why two terms are equal.
    ///
    /// Identical terms are explained by `Reflexivity`. Any other connected
    /// pair currently gets `Explanation::Given`; a full implementation would
    /// trace the union-find path.
    ///
    /// # Errors
    ///
    /// Returns an error if the terms are not in the same class.
    pub fn explain_equality(&mut self, lhs: TermId, rhs: TermId) -> Result<Explanation, String> {
        if lhs == rhs {
            return Ok(Explanation::Reflexivity);
        }

        if !self.union_find.connected(lhs, rhs) {
            return Err("Terms are not equal".to_string());
        }

        Ok(Explanation::Given)
    }

    /// Watch the equality `lhs = rhs` under the given callback id.
    ///
    /// The watch is filed under both endpoints, so a merge touching either
    /// one re-checks it.
    pub fn watch_equality(&mut self, lhs: TermId, rhs: TermId, callback: usize) {
        let watch = EqualityWatch { lhs, rhs, callback };

        self.watched.entry(lhs).or_default().push(watch.clone());
        self.watched.entry(rhs).or_default().push(watch);
    }

    /// Check the watches filed under `lhs` and `rhs` after they were merged.
    ///
    /// NOTE(review): the collected callback ids are only counted into
    /// `stats.watch_triggers`; nothing delivers them to a caller yet.
    fn trigger_watches(&mut self, lhs: TermId, rhs: TermId) -> Result<(), String> {
        let mut triggered = Vec::new();

        if let Some(watches) = self.watched.get(&lhs) {
            for watch in watches {
                if self.union_find.connected(watch.lhs, watch.rhs) {
                    triggered.push(watch.callback);
                }
            }
        }

        if let Some(watches) = self.watched.get(&rhs) {
            for watch in watches {
                // A watch that also has `lhs` as an endpoint is filed under
                // `lhs` too and was examined above; don't count it twice.
                if watch.lhs == lhs || watch.rhs == lhs {
                    continue;
                }
                if self.union_find.connected(watch.lhs, watch.rhs) {
                    triggered.push(watch.callback);
                }
            }
        }

        self.stats.watch_triggers += triggered.len();

        Ok(())
    }

    /// Arguments of a term kind. Returns an empty list for every kind not
    /// listed below.
    fn get_args(&self, kind: &TermKind) -> Vec<TermId> {
        match kind {
            TermKind::And(args) | TermKind::Or(args) => args.to_vec(),
            TermKind::Not(arg) => vec![*arg],
            TermKind::Eq(l, r) | TermKind::Le(l, r) | TermKind::Lt(l, r) => vec![*l, *r],
            TermKind::Add(args) | TermKind::Mul(args) => args.to_vec(),
            _ => vec![],
        }
    }

    /// Get statistics.
    pub fn stats(&self) -> &EqualityPropStats {
        &self.stats
    }
}

impl CongruenceData {
    /// Create empty congruence data: no use lists, no lookup entries and no
    /// pending congruence checks.
    pub fn new() -> Self {
        Self {
            use_list: FxHashMap::default(),
            lookup: FxHashMap::default(),
            pending_congruences: VecDeque::new(),
        }
    }

    /// Merge the use lists of `t1` and `t2` after the two terms become equal.
    ///
    /// Afterwards both terms map to the same combined list: the users of `t1`
    /// followed by the users of `t2`. Merging a term with itself is a no-op;
    /// it must not append the list to itself.
    pub fn merge_use_lists(&mut self, t1: TermId, t2: TermId) {
        if t1 == t2 {
            return;
        }

        // Take ownership of both lists instead of cloning them.
        let mut merged = self.use_list.remove(&t1).unwrap_or_default();
        merged.extend(self.use_list.remove(&t2).unwrap_or_default());

        // No users on either side: `get_parents` already returns an empty
        // list for absent keys, so don't store empty entries.
        if merged.is_empty() {
            return;
        }

        self.use_list.insert(t2, merged.clone());
        self.use_list.insert(t1, merged);
    }

    /// Terms recorded as users ("parents") of `t`; empty if none.
    pub fn get_parents(&self, t: TermId) -> Vec<TermId> {
        self.use_list.get(&t).cloned().unwrap_or_default()
    }
}

impl EGraph {
    /// Create an empty e-graph; e-class ids are handed out starting from 0.
    pub fn new() -> Self {
        Self {
            eclass: FxHashMap::default(),
            nodes: FxHashMap::default(),
            data: FxHashMap::default(),
            next_id: 0,
        }
    }

    /// Return the e-class of `term`.
    ///
    /// The first time a term is seen, a fresh singleton e-class is created
    /// with `term` as its representative.
    pub fn get_eclass(&mut self, term: TermId) -> EClassId {
        if let Some(&id) = self.eclass.get(&term) {
            id
        } else {
            let id = self.next_id;
            self.next_id += 1;

            self.eclass.insert(term, id);
            self.nodes.insert(id, vec![term]);
            self.data.insert(
                id,
                EClassData {
                    representative: term,
                    size: 1,
                    parents: Vec::new(),
                },
            );

            id
        }
    }

    /// Merge the e-classes of `t1` and `t2`.
    ///
    /// The smaller class (by `size`) is absorbed into the larger one; on a tie,
    /// `t2`'s class is absorbed into `t1`'s. The surviving class keeps its
    /// representative. It gains the absorbed class's member terms, size and
    /// parent list. Does nothing if both terms are already in one e-class.
    pub fn merge(&mut self, t1: TermId, t2: TermId) {
        let id1 = self.get_eclass(t1);
        let id2 = self.get_eclass(t2);

        if id1 == id2 {
            return;
        }

        let size1 = self.data[&id1].size;
        let size2 = self.data[&id2].size;

        let (smaller, larger) = if size1 < size2 {
            (id1, id2)
        } else {
            (id2, id1)
        };

        // Move (rather than clone) the absorbed members and re-point them.
        let smaller_nodes = self.nodes.remove(&smaller).unwrap_or_default();
        for &node in &smaller_nodes {
            self.eclass.insert(node, larger);
        }
        self.nodes.entry(larger).or_default().extend(smaller_nodes);

        // Fold the absorbed class's data into the survivor. The parent list
        // must be carried over, or congruence information for the absorbed
        // class would be lost.
        if let Some(smaller_data) = self.data.remove(&smaller) {
            if let Some(larger_data) = self.data.get_mut(&larger) {
                larger_data.size += smaller_data.size;
                larger_data.parents.extend(smaller_data.parents);
            }
        }
    }
}

impl Default for EqualityPropagator {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for UnionFind {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for CongruenceData {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for EGraph {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_union_find() {
        let mut uf = UnionFind::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);
        let t3 = TermId::from(3);

        assert!(!uf.connected(t1, t2));

        uf.union(t1, t2);
        assert!(uf.connected(t1, t2));

        uf.union(t2, t3);
        assert!(uf.connected(t1, t3));
    }

    /// `union` reports whether a merge happened, and class sizes track merges.
    #[test]
    fn test_union_find_return_value_and_sizes() {
        let mut uf = UnionFind::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);
        let t3 = TermId::from(3);

        assert_eq!(uf.set_size(t1), 1);
        assert!(uf.union(t1, t2));
        assert!(!uf.union(t2, t1));
        assert!(uf.union(t3, t1));
        assert_eq!(uf.set_size(t2), 3);
        assert_eq!(uf.find(t1), uf.find(t3));
    }

    #[test]
    fn test_equality_propagator() {
        let prop = EqualityPropagator::new();
        assert_eq!(prop.stats.equalities_propagated, 0);
        assert_eq!(prop.stats().congruences_found, 0);
    }

    /// Without any assertions only reflexive equalities can be explained.
    #[test]
    fn test_explain_equality_without_assertions() {
        let mut prop = EqualityPropagator::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);

        assert!(matches!(
            prop.explain_equality(t1, t1),
            Ok(Explanation::Reflexivity)
        ));
        assert!(prop.explain_equality(t1, t2).is_err());
    }

    #[test]
    fn test_egraph() {
        let mut eg = EGraph::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);

        let id1 = eg.get_eclass(t1);
        let id2 = eg.get_eclass(t2);

        assert_ne!(id1, id2);

        eg.merge(t1, t2);

        let id1_after = eg.get_eclass(t1);
        let id2_after = eg.get_eclass(t2);

        assert_eq!(id1_after, id2_after);
    }

    /// Merging an already-merged pair leaves the e-class unchanged.
    #[test]
    fn test_egraph_merge_is_idempotent() {
        let mut eg = EGraph::new();

        let t1 = TermId::from(1);
        let t2 = TermId::from(2);

        eg.merge(t1, t2);
        let merged = eg.get_eclass(t1);
        eg.merge(t2, t1);

        assert_eq!(eg.get_eclass(t2), merged);
        assert_eq!(eg.nodes[&merged].len(), 2);
        assert_eq!(eg.data[&merged].size, 2);
    }
}