egglog_bridge/
proof_format.rs

1//! A proof format for egglog programs, based on the Rocq format and checker from Tia Vu, Ryan
2//! Doegens, and Oliver Flatt.
3use std::{hash::Hash, io, rc::Rc};
4
5use crate::core_relations::Value;
6use crate::numeric_id::{DenseIdMap, NumericId, define_id};
7use indexmap::IndexSet;
8
9use crate::{FunctionId, rule::Variable};
10
11define_id!(pub TermProofId, u32, "an id identifying proofs of terms within a [`ProofStore`]");
12define_id!(pub EqProofId, u32, "an id identifying proofs of equality between two terms within a [`ProofStore`]");
13define_id!(pub TermId, u32, "an id identifying terms within a [`TermDag`]");
14
15#[derive(Clone, Debug)]
16struct HashCons<K, T> {
17    data: IndexSet<T>,
18    _marker: std::marker::PhantomData<K>,
19}
20
21impl<K, T> Default for HashCons<K, T> {
22    fn default() -> Self {
23        HashCons {
24            data: IndexSet::new(),
25            _marker: std::marker::PhantomData,
26        }
27    }
28}
29
30impl<K: NumericId, T: Clone + Eq + Hash> HashCons<K, T> {
31    fn get_or_insert(&mut self, value: &T) -> K {
32        if let Some((index, _)) = self.data.get_full(value) {
33            K::from_usize(index)
34        } else {
35            let id = K::from_usize(self.data.len());
36            self.data.insert(value.clone());
37            id
38        }
39    }
40
41    fn lookup(&self, id: K) -> Option<&T> {
42        self.data.get_index(id.index())
43    }
44}
45
46#[derive(Default, Clone)]
47pub struct TermDag {
48    store: HashCons<TermId, Term>,
49}
50
51impl TermDag {
52    /// Print the term in a human-readable format to the given writer.
53    pub fn print_term(&self, term: TermId, writer: &mut impl io::Write) -> io::Result<()> {
54        self.print_term_pretty(term, &PrettyPrintConfig::default(), writer)
55    }
56
57    /// Print the term with pretty-printing configuration.
58    pub fn print_term_pretty(
59        &self,
60        term: TermId,
61        config: &PrettyPrintConfig,
62        writer: &mut impl io::Write,
63    ) -> io::Result<()> {
64        let mut printer = PrettyPrinter::new(writer, config);
65        self.print_term_with_printer(term, &mut printer)
66    }
67
68    fn print_term_with_printer<W: io::Write>(
69        &self,
70        term: TermId,
71        printer: &mut PrettyPrinter<W>,
72    ) -> io::Result<()> {
73        let term = self.store.lookup(term).unwrap();
74        match term {
75            Term::Constant { id, rendered } => {
76                if let Some(rendered) = rendered {
77                    printer.write_str(rendered)?;
78                } else {
79                    printer.write_str(&format!("c{}", id.index()))?;
80                }
81            }
82            Term::Func { id, args } => {
83                printer.write_str(&format!("({id:?}"))?;
84                if !args.is_empty() {
85                    printer.increase_indent();
86                    for (i, arg) in args.iter().enumerate() {
87                        if i > 0 {
88                            printer.write_str(",")?;
89                        }
90                        printer.write_with_break(" ")?;
91                        self.print_term_with_printer(*arg, printer)?;
92                    }
93                    printer.decrease_indent();
94                }
95                printer.write_str(")")?;
96            }
97        }
98        Ok(())
99    }
100
101    /// Add the given [`Term`] to the store, returning its [`TermId`].
102    ///
103    /// The [`TermId`]s in this term should point into this same [`TermDag`].
104    pub fn get_or_insert(&mut self, term: &Term) -> TermId {
105        self.store.get_or_insert(term)
106    }
107
108    pub(crate) fn proj(&self, term: TermId, arg_idx: usize) -> TermId {
109        let term = self.store.lookup(term).unwrap();
110        match term {
111            Term::Func { args, .. } => {
112                if arg_idx < args.len() {
113                    args[arg_idx]
114                } else {
115                    panic!("Index out of bounds for function arguments")
116                }
117            }
118            _ => panic!("Cannot project a non-function term"),
119        }
120    }
121}
122
123#[derive(Clone, PartialEq, Eq, Hash, Debug)]
124pub enum Term {
125    Constant {
126        id: Value,
127        rendered: Option<Rc<str>>,
128    },
129    Func {
130        id: FunctionId,
131        args: Vec<TermId>,
132    },
133}
134
135/// A hash-cons store for proofs and terms related to an egglog program.
136#[derive(Clone, Default)]
137pub struct ProofStore {
138    eq_memo: HashCons<EqProofId, EqProof>,
139    term_memo: HashCons<TermProofId, TermProof>,
140    pub(crate) termdag: TermDag,
141}
142
143impl ProofStore {
144    /// Print a term proof with pretty-printing configuration.
145    pub fn print_term_proof_pretty(
146        &self,
147        term_pf: TermProofId,
148        config: &PrettyPrintConfig,
149        writer: &mut impl io::Write,
150    ) -> io::Result<()> {
151        let mut printer = PrettyPrinter::new(writer, config);
152        self.print_term_proof_with_printer(term_pf, &mut printer)
153    }
154
155    /// Print an equality proof with pretty-printing configuration.
156    pub fn print_eq_proof_pretty(
157        &self,
158        eq_pf: EqProofId,
159        config: &PrettyPrintConfig,
160        writer: &mut impl io::Write,
161    ) -> io::Result<()> {
162        let mut printer = PrettyPrinter::new(writer, config);
163        self.print_eq_proof_with_printer(eq_pf, &mut printer)
164    }
165
166    fn print_cong_with_printer<W: io::Write>(
167        &self,
168        cong_pf: &CongProof,
169        printer: &mut PrettyPrinter<W>,
170    ) -> io::Result<()> {
171        let CongProof {
172            pf_args_eq,
173            pf_f_args_ok,
174            old_term,
175            new_term,
176            func,
177        } = cong_pf;
178        printer.write_str(&format!("Cong({func:?}, "))?;
179        self.termdag.print_term_with_printer(*old_term, printer)?;
180        printer.write_str(" => ")?;
181        self.termdag.print_term_with_printer(*new_term, printer)?;
182        printer.write_str(" by (")?;
183        printer.increase_indent();
184        for (i, pf) in pf_args_eq.iter().enumerate() {
185            if i > 0 {
186                printer.write_str(",")?;
187            }
188            printer.write_with_break(" ")?;
189            self.print_eq_proof_with_printer(*pf, printer)?;
190        }
191        printer.write_with_break(") , old term exists by: ")?;
192        printer.increase_indent();
193        printer.newline()?;
194        self.print_term_proof_with_printer(*pf_f_args_ok, printer)?;
195        printer.decrease_indent();
196        printer.write_str(")")?;
197        printer.decrease_indent();
198        Ok(())
199    }
200
201    /// Print the equality proof in a human-readable format to the given writer.
202    pub fn print_eq_proof(&self, eq_pf: EqProofId, writer: &mut impl io::Write) -> io::Result<()> {
203        self.print_eq_proof_pretty(eq_pf, &PrettyPrintConfig::default(), writer)
204    }
205
206    fn print_eq_proof_with_printer<W: io::Write>(
207        &self,
208        eq_pf: EqProofId,
209        printer: &mut PrettyPrinter<W>,
210    ) -> io::Result<()> {
211        let eq_pf = self.eq_memo.lookup(eq_pf).unwrap();
212        match eq_pf {
213            EqProof::PRule {
214                rule_name,
215                subst,
216                body_pfs,
217                result_lhs,
218                result_rhs,
219            } => {
220                printer.write_str(&format!("PRule[Equality]({rule_name:?}, Subst {{"))?;
221                printer.increase_indent();
222                printer.newline()?;
223                for (i, (var, term)) in subst.iter().enumerate() {
224                    if i > 0 {
225                        printer.write_str(",")?;
226                    }
227                    printer.write_with_break(" ")?;
228                    printer.write_str(&format!("{var:?} => "))?;
229                    self.termdag.print_term_with_printer(*term, printer)?;
230                    printer.newline()?;
231                }
232                printer.newline()?;
233                printer.write_with_break("},")?;
234                printer.newline()?;
235                printer.write_with_break("Body Pfs: [")?;
236                printer.increase_indent();
237                for (i, pf) in body_pfs.iter().enumerate() {
238                    if i > 0 {
239                        printer.write_str(",")?;
240                    }
241                    printer.write_with_break(" ")?;
242                    match pf {
243                        Premise::TermOk(term_pf) => {
244                            printer.write_str("TermOk(")?;
245                            self.print_term_proof_with_printer(*term_pf, printer)?;
246                            printer.write_str(")")?;
247                        }
248                        Premise::Eq(eq_pf) => {
249                            printer.write_str("Eq(")?;
250                            self.print_eq_proof_with_printer(*eq_pf, printer)?;
251                            printer.write_str(")")?;
252                        }
253                    }
254                }
255                printer.decrease_indent();
256                printer.write_with_break("], ")?;
257                printer.newline()?;
258                printer.write_with_break(" Result: ")?;
259                self.termdag.print_term_with_printer(*result_lhs, printer)?;
260                printer.write_str(" = ")?;
261                self.termdag.print_term_with_printer(*result_rhs, printer)?;
262                printer.write_str(")")?;
263                printer.decrease_indent();
264            }
265            EqProof::PRefl { t_ok_pf, t } => {
266                printer.write_str("PRefl(")?;
267                self.print_term_proof_with_printer(*t_ok_pf, printer)?;
268                printer.write_str(", (term= ")?;
269                self.termdag.print_term_with_printer(*t, printer)?;
270                printer.write_str("))")?
271            }
272            EqProof::PSym { eq_pf } => {
273                printer.write_str("PSym(")?;
274                self.print_eq_proof_with_printer(*eq_pf, printer)?;
275                printer.write_str(")")?
276            }
277            EqProof::PTrans { pfxy, pfyz } => {
278                printer.write_str("PTrans(")?;
279                printer.increase_indent();
280                printer.increase_indent();
281                printer.newline()?;
282                self.print_eq_proof_with_printer(*pfxy, printer)?;
283                printer.decrease_indent();
284                printer.newline()?;
285                printer.write_with_break(" ... and then ... ")?;
286                printer.increase_indent();
287                printer.newline()?;
288                self.print_eq_proof_with_printer(*pfyz, printer)?;
289                printer.decrease_indent();
290                printer.decrease_indent();
291                printer.newline()?;
292                printer.write_str(")")?
293            }
294            EqProof::PCong(cong_pf) => {
295                printer.write_str("PCong[Equality](")?;
296                self.print_cong_with_printer(cong_pf, printer)?;
297                printer.write_str(")")?
298            }
299        }
300        printer.newline()
301    }
302
303    /// Print the term proof in a human-readable format to the given writer.
304    pub fn print_term_proof(
305        &self,
306        term_pf: TermProofId,
307        writer: &mut impl io::Write,
308    ) -> io::Result<()> {
309        self.print_term_proof_pretty(term_pf, &PrettyPrintConfig::default(), writer)
310    }
311
312    fn print_term_proof_with_printer<W: io::Write>(
313        &self,
314        term_pf: TermProofId,
315        printer: &mut PrettyPrinter<W>,
316    ) -> io::Result<()> {
317        let term_pf = self.term_memo.lookup(term_pf).unwrap();
318        match term_pf {
319            TermProof::PRule {
320                rule_name,
321                subst,
322                body_pfs,
323                result,
324            } => {
325                printer.write_str(&format!("PRule[Existence]({rule_name:?}, Subst {{"))?;
326                printer.increase_indent();
327                printer.newline()?;
328                for (i, (var, term)) in subst.iter().enumerate() {
329                    if i > 0 {
330                        printer.write_str(",")?;
331                    }
332                    printer.write_with_break(" ")?;
333                    printer.write_str(&format!("{var:?} => "))?;
334                    self.termdag.print_term_with_printer(*term, printer)?;
335                    printer.newline()?;
336                }
337                printer.newline()?;
338                printer.write_with_break("},")?;
339                printer.newline()?;
340                printer.write_with_break("Body Pfs: [")?;
341                printer.increase_indent();
342                for (i, pf) in body_pfs.iter().enumerate() {
343                    if i > 0 {
344                        printer.write_str(",")?;
345                    }
346                    printer.write_with_break(" ")?;
347                    match pf {
348                        Premise::TermOk(term_pf) => {
349                            printer.write_str("TermOk(")?;
350                            self.print_term_proof_with_printer(*term_pf, printer)?;
351                            printer.write_str(")")?;
352                        }
353                        Premise::Eq(eq_pf) => {
354                            printer.write_str("Eq(")?;
355                            self.print_eq_proof_with_printer(*eq_pf, printer)?;
356                            printer.write_str(")")?;
357                        }
358                    }
359                }
360                printer.decrease_indent();
361                printer.write_with_break("], Result: ")?;
362                self.termdag.print_term_with_printer(*result, printer)?;
363                printer.write_str(")")
364            }
365            TermProof::PProj {
366                pf_f_args_ok,
367                arg_idx,
368            } => {
369                printer.write_str("PProj(")?;
370                self.print_term_proof_with_printer(*pf_f_args_ok, printer)?;
371                printer.write_str(&format!(", {arg_idx})"))
372            }
373            TermProof::PCong(cong_pf) => {
374                printer.write_str("PCong[Exists](")?;
375                self.print_cong_with_printer(cong_pf, printer)?;
376                printer.write_str(")")
377            }
378            TermProof::PFiat { desc, term } => {
379                printer.write_str(&format!("PFiat({desc:?}"))?;
380                printer.write_str(", ")?;
381                self.termdag.print_term_with_printer(*term, printer)?;
382                printer.write_str(")")
383            }
384        }
385    }
386    pub(crate) fn intern_term(&mut self, prf: &TermProof) -> TermProofId {
387        self.term_memo.get_or_insert(prf)
388    }
389    pub(crate) fn intern_eq(&mut self, prf: &EqProof) -> EqProofId {
390        self.eq_memo.get_or_insert(prf)
391    }
392
393    pub(crate) fn refl(&mut self, proof: TermProofId, term: TermId) -> EqProofId {
394        self.intern_eq(&EqProof::PRefl {
395            t_ok_pf: proof,
396            t: term,
397        })
398    }
399
400    pub(crate) fn sym(&mut self, proof: EqProofId) -> EqProofId {
401        self.intern_eq(&EqProof::PSym { eq_pf: proof })
402    }
403
404    pub(crate) fn trans(&mut self, pfxy: EqProofId, pfyz: EqProofId) -> EqProofId {
405        self.intern_eq(&EqProof::PTrans { pfxy, pfyz })
406    }
407
408    pub(crate) fn sequence_proofs(&mut self, pfs: &[EqProofId]) -> EqProofId {
409        match pfs {
410            [] => panic!("Cannot sequence an empty list of proofs"),
411            [pf] => *pf,
412            [pf1, rest @ ..] => {
413                let mut cur = *pf1;
414                for pf in rest {
415                    cur = self.trans(cur, *pf);
416                }
417                cur
418            }
419        }
420    }
421}
422
423#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
424pub enum Premise {
425    TermOk(TermProofId),
426    Eq(EqProofId),
427}
428
429#[derive(Clone, PartialEq, Eq, Hash, Debug)]
430pub struct CongProof {
431    pub pf_args_eq: Vec<EqProofId>,
432    pub pf_f_args_ok: TermProofId,
433    pub old_term: TermId,
434    pub new_term: TermId,
435    pub func: FunctionId,
436}
437
438#[allow(clippy::enum_variant_names)]
439#[derive(Clone, PartialEq, Eq, Hash, Debug)]
440pub enum TermProof {
441    /// proves a Proposition based on a rule application
442    /// the subsitution gives the mapping from variables to terms
443    /// the body_pfs gives proofs for each of the conditions in the query of the rule
444    /// the act_pf gives a location in the action of the proposition
445    PRule {
446        rule_name: Rc<str>,
447        subst: DenseIdMap<Variable, TermId>,
448        body_pfs: Vec<Premise>,
449        result: TermId,
450    },
451    /// get a proof for the child of a term given a proof of a term
452    PProj {
453        pf_f_args_ok: TermProofId,
454        arg_idx: usize,
455    },
456    PCong(CongProof),
457    PFiat {
458        desc: Rc<str>,
459        term: TermId,
460    },
461}
462
463#[allow(clippy::enum_variant_names)]
464#[derive(Clone, PartialEq, Eq, Hash, Debug)]
465pub enum EqProof {
466    /// proves a Proposition based on a rule application
467    /// the subsitution gives the mapping from variables to terms
468    /// the body_pfs gives proofs for each of the conditions in the query of the rule
469    /// the act_pf gives a location in the action of the proposition
470    PRule {
471        rule_name: Rc<str>,
472        subst: DenseIdMap<Variable, TermId>,
473        body_pfs: Vec<Premise>,
474        result_lhs: TermId,
475        result_rhs: TermId,
476    },
477    /// A term is equal to itself- proves the proposition t = t
478    PRefl {
479        t_ok_pf: TermProofId,
480        t: TermId,
481    },
482    /// The symmetric equality of eq_pf
483    PSym {
484        eq_pf: EqProofId,
485    },
486    PTrans {
487        pfxy: EqProofId,
488        pfyz: EqProofId,
489    },
490    /// Proves f(x1, y1, ...) = f(x2, y2, ...) where f is fun_sym
491    /// A proof via congruence- one proof for each child of the term
492    /// pf_f_args_ok is a proof that the term with the lhs children is valid
493    PCong(CongProof),
494}
495
/// Layout settings for the proof and term pretty-printer.
#[derive(Debug, Clone)]
pub struct PrettyPrintConfig {
    /// Line length beyond which the printer tries to wrap.
    pub line_width: usize,
    /// Number of spaces added per indentation level.
    pub indent_size: usize,
}

impl Default for PrettyPrintConfig {
    /// A 512-column line width with 4 spaces per indentation level.
    fn default() -> Self {
        let line_width = 512;
        let indent_size = 4;
        PrettyPrintConfig {
            line_width,
            indent_size,
        }
    }
}
510
511struct PrettyPrinter<'w, W: io::Write> {
512    writer: &'w mut W,
513    config: &'w PrettyPrintConfig,
514    current_indent: usize,
515    current_line_pos: usize,
516}
517
518impl<'w, W: io::Write> PrettyPrinter<'w, W> {
519    fn new(writer: &'w mut W, config: &'w PrettyPrintConfig) -> Self {
520        Self {
521            writer,
522            config,
523            current_indent: 0,
524            current_line_pos: 0,
525        }
526    }
527
528    fn write_str(&mut self, s: &str) -> io::Result<()> {
529        write!(self.writer, "{s}")?;
530        self.current_line_pos += s.len();
531        Ok(())
532    }
533
534    fn newline(&mut self) -> io::Result<()> {
535        writeln!(self.writer)?;
536        self.current_line_pos = 0;
537        self.write_indent()?;
538        Ok(())
539    }
540
541    fn write_indent(&mut self) -> io::Result<()> {
542        for _ in 0..self.current_indent {
543            write!(self.writer, " ")?;
544        }
545        self.current_line_pos = self.current_indent;
546        Ok(())
547    }
548
549    fn increase_indent(&mut self) {
550        self.current_indent += self.config.indent_size;
551    }
552
553    fn decrease_indent(&mut self) {
554        self.current_indent = self.current_indent.saturating_sub(self.config.indent_size);
555    }
556
557    fn should_break(&self, additional_chars: usize) -> bool {
558        self.current_line_pos + additional_chars > self.config.line_width
559    }
560
561    fn write_with_break(&mut self, s: &str) -> io::Result<()> {
562        if self.should_break(s.len()) && self.current_line_pos > self.current_indent {
563            self.newline()?;
564            self.write_indent()?;
565        }
566        self.write_str(s)
567    }
568}