// slotted_egraphs/explain/proof.rs

1use crate::*;
2
3use std::hash::{Hash, Hasher};
4
5#[derive(Clone, PartialEq, Eq, Hash)]
6pub struct Equation {
7    pub l: AppliedId,
8    pub r: AppliedId,
9}
10
11#[derive(Clone, Debug)]
12pub struct ExplicitProof(pub Option<String>);
13#[derive(Clone, Debug)]
14pub struct ReflexivityProof;
15#[derive(Clone, Debug)]
16pub struct SymmetryProof(pub ProvenEq);
17#[derive(Clone, Debug)]
18pub struct TransitivityProof(pub ProvenEq, pub ProvenEq);
19#[derive(Clone, Debug)]
20pub struct CongruenceProof(pub Vec<ProvenEq>);
21
22#[derive(Debug, Clone)]
23pub enum Proof {
24    Explicit(ExplicitProof),
25    Reflexivity(ReflexivityProof),
26    Symmetry(SymmetryProof),
27    Transitivity(TransitivityProof),
28    Congruence(CongruenceProof),
29    // Both global renaming within equations and alpha-equivalence will be handled in the other rules too.
30    // All equations will be understood as an arbitrary representative from its global renaming equivalence class.
31    // So f(x, y) = g(x, y) is conceptually the same equation as f(a, b) = g(a, b).
32    // In other words, whenever you use an equation, you always do it using "match_app_id".
33}
34
35pub type ProvenEq = Arc<ProvenEqRaw>;
36
37#[derive(Debug, Clone)]
38pub struct ProvenEqRaw {
39    // fields are intentionally private so that only this module can construct instances for it.
40    // These equations should always be fully "syn", i.e. they should not have any missing slot arguments, even redundant slots have to be passed explicitly.
41    eq: Equation,
42    proof: Proof,
43}
44
45impl ProvenEqRaw {
46    pub fn equ(&self) -> Equation {
47        (**self).clone()
48    }
49
50    pub(crate) fn check<L: Language, N: Analysis<L>>(&self, eg: &EGraph<L, N>) {
51        let Equation { l, r } = self.equ();
52        eg.check_syn_applied_id(&l);
53        eg.check_syn_applied_id(&r);
54    }
55}
56
57impl PartialEq for ProvenEqRaw {
58    // TODO normalize slotnames before this?
59    fn eq(&self, other: &Self) -> bool {
60        self.eq == other.eq
61    }
62}
63
64impl Eq for ProvenEqRaw {}
65
66impl Hash for ProvenEqRaw {
67    fn hash<H: Hasher>(&self, hasher: &mut H) {
68        // TODO normalize slotnames before this?
69        self.eq.hash(hasher);
70    }
71}
72
73impl ExplicitProof {
74    pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
75        let eq = eq.clone();
76        let proof = Proof::Explicit(self.clone());
77        reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
78    }
79}
80
81impl ReflexivityProof {
82    pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
83        assert_eq!(eq.l, eq.r);
84
85        let eq = eq.clone();
86        let proof = Proof::Reflexivity(self.clone());
87        reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
88    }
89}
90
91impl SymmetryProof {
92    pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
93        let SymmetryProof(x) = self;
94
95        let flipped = Equation {
96            l: x.r.clone(),
97            r: x.l.clone(),
98        };
99        assert_match_equation(eq, &flipped);
100
101        let eq = eq.clone();
102        let proof = Proof::Symmetry(self.clone());
103        reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
104    }
105}
106
107impl TransitivityProof {
108    pub(crate) fn check(&self, eq: &Equation, reg: &ProofRegistry) -> ProvenEq {
109        let TransitivityProof(eq1, eq2) = self;
110
111        let mut theta1 = {
112            // eq1.l*theta1 == eq.l
113            // -> theta1 == eq1.l^-1 * eq.l
114            eq1.l.m.inverse().compose_partial(&eq.l.m)
115        };
116        let mut theta2 = {
117            // eq2.r*theta2 == eq.r
118            // -> theta2 == eq2.r^-1 * eq.r
119            eq2.r.m.inverse().compose_partial(&eq.r.m)
120        };
121
122        let recompute_theta1 = |theta1: &mut SlotMap, theta2: &SlotMap| {
123            // eq1.r*theta1 == eq2.l*theta2
124            // -> theta1 == eq1.r^-1 * eq2.l * theta2
125            *theta1 = theta1
126                .try_union(
127                    &eq1.r
128                        .m
129                        .inverse()
130                        .compose_partial(&eq2.l.m)
131                        .compose_partial(theta2),
132                )
133                .unwrap();
134        };
135
136        let recompute_theta2 = |theta1: &SlotMap, theta2: &mut SlotMap| {
137            // eq1.r*theta1 == eq2.l*theta2
138            // -> theta2 == eq2.l^-1 * eq1.r * theta2
139            *theta2 = theta2
140                .try_union(
141                    &eq2.l
142                        .m
143                        .inverse()
144                        .compose_partial(&eq1.r.m)
145                        .compose_partial(theta1),
146                )
147                .unwrap();
148        };
149
150        recompute_theta1(&mut theta1, &theta2);
151        recompute_theta2(&theta1, &mut theta2);
152
153        for x in eq1.slots() {
154            if !theta1.contains_key(x) {
155                theta1.insert(x, Slot::fresh());
156            }
157        }
158        recompute_theta2(&theta1, &mut theta2);
159        for x in eq2.slots() {
160            if !theta2.contains_key(x) {
161                theta2.insert(x, Slot::fresh());
162            }
163        }
164
165        let renamed_eq1 = eq1.apply_slotmap(&theta1);
166        let renamed_eq2 = eq2.apply_slotmap(&theta2);
167
168        assert_eq!(renamed_eq1.l, eq.l);
169        assert_eq!(renamed_eq2.r, eq.r);
170        assert_eq!(renamed_eq1.r, renamed_eq2.l);
171
172        let eq = eq.clone();
173        let proof = Proof::Transitivity(self.clone());
174        reg.insert(Arc::new(ProvenEqRaw { eq, proof }))
175    }
176}
177
178// replaces 'private' slots with enumerated slot-names, like a shape.
179pub(crate) fn alpha_normalize<L: Language>(n: &L) -> L {
180    let (sh, bij) = n.weak_shape();
181    if CHECKS {
182        let all_slots: SmallHashSet<_> = sh.all_slot_occurrences().into_iter().collect();
183        assert!(&bij.values().is_disjoint(&all_slots));
184    }
185    sh.apply_slotmap(&bij)
186}
187
188impl CongruenceProof {
189    pub fn check<L: Language, N: Analysis<L>>(&self, eq: &Equation, eg: &EGraph<L, N>) -> ProvenEq {
190        let CongruenceProof(child_proofs) = self;
191
192        let l = alpha_normalize(&eg.get_syn_node(&eq.l));
193        let r = alpha_normalize(&eg.get_syn_node(&eq.r));
194
195        let null_l = nullify_app_ids(&l);
196        let null_r = nullify_app_ids(&r);
197        assert_eq!(null_l, null_r);
198
199        let l_v = l.applied_id_occurrences();
200        let r_v = r.applied_id_occurrences();
201
202        assert_eq!(l_v.len(), child_proofs.len());
203        assert_eq!(r_v.len(), child_proofs.len());
204
205        let l_v = l_v.into_iter().cloned();
206        let r_v = r_v.into_iter().cloned();
207
208        let c_v = child_proofs.into_iter();
209        for ((ll, rr), prf) in l_v.zip(r_v).zip(c_v) {
210            let eq1 = &Equation { l: ll, r: rr };
211            let eq2 = prf.deref();
212            assert_match_equation(eq1, eq2);
213        }
214
215        let eq = eq.clone();
216        let proof = Proof::Congruence(self.clone());
217        eg.proof_registry
218            .insert(Arc::new(ProvenEqRaw { eq, proof }))
219    }
220}
221
222impl Equation {
223    pub fn slots(&self) -> SmallHashSet<Slot> {
224        &self.l.slots() | &self.r.slots()
225    }
226
227    #[track_caller]
228    pub fn apply_slotmap(&self, m: &SlotMap) -> Self {
229        Equation {
230            l: self.l.apply_slotmap(&m),
231            r: self.r.apply_slotmap(&m),
232        }
233    }
234
235    pub fn apply_slotmap_fresh(&self, m: &SlotMap) -> Self {
236        let mut m = m.clone();
237        for s in &self.l.slots() | &self.r.slots() {
238            if !m.contains_key(s) {
239                m.insert(s, Slot::fresh());
240            }
241        }
242        Equation {
243            l: self.l.apply_slotmap(&m),
244            r: self.r.apply_slotmap(&m),
245        }
246    }
247}
248
249impl Deref for ProvenEqRaw {
250    type Target = Equation;
251
252    fn deref(&self) -> &Equation {
253        &self.eq
254    }
255}
256
257impl ProvenEqRaw {
258    pub fn proof(&self) -> &Proof {
259        &self.proof
260    }
261}
262
263// returns the global renaming theta, s.t. a.apply_slotmap(theta) = b, if it exists.
264#[track_caller]
265pub(crate) fn match_app_id(a: &AppliedId, b: &AppliedId) -> SlotMap {
266    if CHECKS {
267        assert_eq!(a.id, b.id);
268        assert_eq!(
269            a.m.keys(),
270            b.m.keys(),
271            "match_app_id failed: different set of arguments"
272        );
273    }
274
275    // a.m :: slots(i) -> A
276    // b.m :: slots(i) -> B
277    // theta :: A -> B
278    let theta = a.m.inverse().compose(&b.m);
279
280    if CHECKS {
281        assert_eq!(&a.apply_slotmap(&theta), b);
282    }
283
284    theta
285}
286
287// returns the bijective renaming theta, s.t. a.apply_slotmap(theta) = b, if it exists.
288pub(crate) fn assert_match_equation(a: &Equation, b: &Equation) -> SlotMap {
289    let theta_l = match_app_id(&a.l, &b.l);
290    let theta_r = match_app_id(&a.r, &b.r);
291
292    let theta = theta_l.try_union(&theta_r).unwrap_or_else(|| panic!("trying to union {theta_l:?} with {theta_r:?} while trying to match '{a:?}' against '{b:?}'"));
293
294    if CHECKS {
295        assert!(theta.is_bijection(), "trying to unify {theta_l:?} with {theta_r:?}, in assert_match_equation(\n  {a:?},\n  {b:?}\n)");
296
297        assert_eq!(&a.apply_slotmap(&theta), b);
298    }
299
300    theta
301}
302
303pub(crate) fn assert_proves_equation(peq: &ProvenEq, eq: &Equation) {
304    let mut e: Equation = (***peq).clone();
305
306    for s in e.l.m.keys() {
307        if !eq.l.m.contains_key(s) {
308            e.l.m.remove(s);
309        }
310    }
311
312    for s in e.r.m.keys() {
313        if !eq.r.m.contains_key(s) {
314            e.r.m.remove(s);
315        }
316    }
317
318    assert_match_equation(&e, eq);
319}