1use std::{hash::Hash, io, rc::Rc};
4
5use crate::core_relations::Value;
6use crate::numeric_id::{DenseIdMap, NumericId, define_id};
7use indexmap::IndexSet;
8
9use crate::{FunctionId, rule::Variable};
10
11define_id!(pub TermProofId, u32, "an id identifying proofs of terms within a [`ProofStore`]");
12define_id!(pub EqProofId, u32, "an id identifying proofs of equality between two terms within a [`ProofStore`]");
13define_id!(pub TermId, u32, "an id identifying terms within a [`TermDag`]");
14
15#[derive(Clone, Debug)]
16struct HashCons<K, T> {
17 data: IndexSet<T>,
18 _marker: std::marker::PhantomData<K>,
19}
20
21impl<K, T> Default for HashCons<K, T> {
22 fn default() -> Self {
23 HashCons {
24 data: IndexSet::new(),
25 _marker: std::marker::PhantomData,
26 }
27 }
28}
29
30impl<K: NumericId, T: Clone + Eq + Hash> HashCons<K, T> {
31 fn get_or_insert(&mut self, value: &T) -> K {
32 if let Some((index, _)) = self.data.get_full(value) {
33 K::from_usize(index)
34 } else {
35 let id = K::from_usize(self.data.len());
36 self.data.insert(value.clone());
37 id
38 }
39 }
40
41 fn lookup(&self, id: K) -> Option<&T> {
42 self.data.get_index(id.index())
43 }
44}
45
46#[derive(Default, Clone)]
47pub struct TermDag {
48 store: HashCons<TermId, Term>,
49}
50
51impl TermDag {
52 pub fn print_term(&self, term: TermId, writer: &mut impl io::Write) -> io::Result<()> {
54 self.print_term_pretty(term, &PrettyPrintConfig::default(), writer)
55 }
56
57 pub fn print_term_pretty(
59 &self,
60 term: TermId,
61 config: &PrettyPrintConfig,
62 writer: &mut impl io::Write,
63 ) -> io::Result<()> {
64 let mut printer = PrettyPrinter::new(writer, config);
65 self.print_term_with_printer(term, &mut printer)
66 }
67
68 fn print_term_with_printer<W: io::Write>(
69 &self,
70 term: TermId,
71 printer: &mut PrettyPrinter<W>,
72 ) -> io::Result<()> {
73 let term = self.store.lookup(term).unwrap();
74 match term {
75 Term::Constant { id, rendered } => {
76 if let Some(rendered) = rendered {
77 printer.write_str(rendered)?;
78 } else {
79 printer.write_str(&format!("c{}", id.index()))?;
80 }
81 }
82 Term::Func { id, args } => {
83 printer.write_str(&format!("({id:?}"))?;
84 if !args.is_empty() {
85 printer.increase_indent();
86 for (i, arg) in args.iter().enumerate() {
87 if i > 0 {
88 printer.write_str(",")?;
89 }
90 printer.write_with_break(" ")?;
91 self.print_term_with_printer(*arg, printer)?;
92 }
93 printer.decrease_indent();
94 }
95 printer.write_str(")")?;
96 }
97 }
98 Ok(())
99 }
100
101 pub fn get_or_insert(&mut self, term: &Term) -> TermId {
105 self.store.get_or_insert(term)
106 }
107
108 pub(crate) fn proj(&self, term: TermId, arg_idx: usize) -> TermId {
109 let term = self.store.lookup(term).unwrap();
110 match term {
111 Term::Func { args, .. } => {
112 if arg_idx < args.len() {
113 args[arg_idx]
114 } else {
115 panic!("Index out of bounds for function arguments")
116 }
117 }
118 _ => panic!("Cannot project a non-function term"),
119 }
120 }
121}
122
123#[derive(Clone, PartialEq, Eq, Hash, Debug)]
124pub enum Term {
125 Constant {
126 id: Value,
127 rendered: Option<Rc<str>>,
128 },
129 Func {
130 id: FunctionId,
131 args: Vec<TermId>,
132 },
133}
134
135#[derive(Clone, Default)]
137pub struct ProofStore {
138 eq_memo: HashCons<EqProofId, EqProof>,
139 term_memo: HashCons<TermProofId, TermProof>,
140 pub(crate) termdag: TermDag,
141}
142
143impl ProofStore {
144 pub fn print_term_proof_pretty(
146 &self,
147 term_pf: TermProofId,
148 config: &PrettyPrintConfig,
149 writer: &mut impl io::Write,
150 ) -> io::Result<()> {
151 let mut printer = PrettyPrinter::new(writer, config);
152 self.print_term_proof_with_printer(term_pf, &mut printer)
153 }
154
155 pub fn print_eq_proof_pretty(
157 &self,
158 eq_pf: EqProofId,
159 config: &PrettyPrintConfig,
160 writer: &mut impl io::Write,
161 ) -> io::Result<()> {
162 let mut printer = PrettyPrinter::new(writer, config);
163 self.print_eq_proof_with_printer(eq_pf, &mut printer)
164 }
165
166 fn print_cong_with_printer<W: io::Write>(
167 &self,
168 cong_pf: &CongProof,
169 printer: &mut PrettyPrinter<W>,
170 ) -> io::Result<()> {
171 let CongProof {
172 pf_args_eq,
173 pf_f_args_ok,
174 old_term,
175 new_term,
176 func,
177 } = cong_pf;
178 printer.write_str(&format!("Cong({func:?}, "))?;
179 self.termdag.print_term_with_printer(*old_term, printer)?;
180 printer.write_str(" => ")?;
181 self.termdag.print_term_with_printer(*new_term, printer)?;
182 printer.write_str(" by (")?;
183 printer.increase_indent();
184 for (i, pf) in pf_args_eq.iter().enumerate() {
185 if i > 0 {
186 printer.write_str(",")?;
187 }
188 printer.write_with_break(" ")?;
189 self.print_eq_proof_with_printer(*pf, printer)?;
190 }
191 printer.write_with_break(") , old term exists by: ")?;
192 printer.increase_indent();
193 printer.newline()?;
194 self.print_term_proof_with_printer(*pf_f_args_ok, printer)?;
195 printer.decrease_indent();
196 printer.write_str(")")?;
197 printer.decrease_indent();
198 Ok(())
199 }
200
201 pub fn print_eq_proof(&self, eq_pf: EqProofId, writer: &mut impl io::Write) -> io::Result<()> {
203 self.print_eq_proof_pretty(eq_pf, &PrettyPrintConfig::default(), writer)
204 }
205
206 fn print_eq_proof_with_printer<W: io::Write>(
207 &self,
208 eq_pf: EqProofId,
209 printer: &mut PrettyPrinter<W>,
210 ) -> io::Result<()> {
211 let eq_pf = self.eq_memo.lookup(eq_pf).unwrap();
212 match eq_pf {
213 EqProof::PRule {
214 rule_name,
215 subst,
216 body_pfs,
217 result_lhs,
218 result_rhs,
219 } => {
220 printer.write_str(&format!("PRule[Equality]({rule_name:?}, Subst {{"))?;
221 printer.increase_indent();
222 printer.newline()?;
223 for (i, (var, term)) in subst.iter().enumerate() {
224 if i > 0 {
225 printer.write_str(",")?;
226 }
227 printer.write_with_break(" ")?;
228 printer.write_str(&format!("{var:?} => "))?;
229 self.termdag.print_term_with_printer(*term, printer)?;
230 printer.newline()?;
231 }
232 printer.newline()?;
233 printer.write_with_break("},")?;
234 printer.newline()?;
235 printer.write_with_break("Body Pfs: [")?;
236 printer.increase_indent();
237 for (i, pf) in body_pfs.iter().enumerate() {
238 if i > 0 {
239 printer.write_str(",")?;
240 }
241 printer.write_with_break(" ")?;
242 match pf {
243 Premise::TermOk(term_pf) => {
244 printer.write_str("TermOk(")?;
245 self.print_term_proof_with_printer(*term_pf, printer)?;
246 printer.write_str(")")?;
247 }
248 Premise::Eq(eq_pf) => {
249 printer.write_str("Eq(")?;
250 self.print_eq_proof_with_printer(*eq_pf, printer)?;
251 printer.write_str(")")?;
252 }
253 }
254 }
255 printer.decrease_indent();
256 printer.write_with_break("], ")?;
257 printer.newline()?;
258 printer.write_with_break(" Result: ")?;
259 self.termdag.print_term_with_printer(*result_lhs, printer)?;
260 printer.write_str(" = ")?;
261 self.termdag.print_term_with_printer(*result_rhs, printer)?;
262 printer.write_str(")")?;
263 printer.decrease_indent();
264 }
265 EqProof::PRefl { t_ok_pf, t } => {
266 printer.write_str("PRefl(")?;
267 self.print_term_proof_with_printer(*t_ok_pf, printer)?;
268 printer.write_str(", (term= ")?;
269 self.termdag.print_term_with_printer(*t, printer)?;
270 printer.write_str("))")?
271 }
272 EqProof::PSym { eq_pf } => {
273 printer.write_str("PSym(")?;
274 self.print_eq_proof_with_printer(*eq_pf, printer)?;
275 printer.write_str(")")?
276 }
277 EqProof::PTrans { pfxy, pfyz } => {
278 printer.write_str("PTrans(")?;
279 printer.increase_indent();
280 printer.increase_indent();
281 printer.newline()?;
282 self.print_eq_proof_with_printer(*pfxy, printer)?;
283 printer.decrease_indent();
284 printer.newline()?;
285 printer.write_with_break(" ... and then ... ")?;
286 printer.increase_indent();
287 printer.newline()?;
288 self.print_eq_proof_with_printer(*pfyz, printer)?;
289 printer.decrease_indent();
290 printer.decrease_indent();
291 printer.newline()?;
292 printer.write_str(")")?
293 }
294 EqProof::PCong(cong_pf) => {
295 printer.write_str("PCong[Equality](")?;
296 self.print_cong_with_printer(cong_pf, printer)?;
297 printer.write_str(")")?
298 }
299 }
300 printer.newline()
301 }
302
303 pub fn print_term_proof(
305 &self,
306 term_pf: TermProofId,
307 writer: &mut impl io::Write,
308 ) -> io::Result<()> {
309 self.print_term_proof_pretty(term_pf, &PrettyPrintConfig::default(), writer)
310 }
311
312 fn print_term_proof_with_printer<W: io::Write>(
313 &self,
314 term_pf: TermProofId,
315 printer: &mut PrettyPrinter<W>,
316 ) -> io::Result<()> {
317 let term_pf = self.term_memo.lookup(term_pf).unwrap();
318 match term_pf {
319 TermProof::PRule {
320 rule_name,
321 subst,
322 body_pfs,
323 result,
324 } => {
325 printer.write_str(&format!("PRule[Existence]({rule_name:?}, Subst {{"))?;
326 printer.increase_indent();
327 printer.newline()?;
328 for (i, (var, term)) in subst.iter().enumerate() {
329 if i > 0 {
330 printer.write_str(",")?;
331 }
332 printer.write_with_break(" ")?;
333 printer.write_str(&format!("{var:?} => "))?;
334 self.termdag.print_term_with_printer(*term, printer)?;
335 printer.newline()?;
336 }
337 printer.newline()?;
338 printer.write_with_break("},")?;
339 printer.newline()?;
340 printer.write_with_break("Body Pfs: [")?;
341 printer.increase_indent();
342 for (i, pf) in body_pfs.iter().enumerate() {
343 if i > 0 {
344 printer.write_str(",")?;
345 }
346 printer.write_with_break(" ")?;
347 match pf {
348 Premise::TermOk(term_pf) => {
349 printer.write_str("TermOk(")?;
350 self.print_term_proof_with_printer(*term_pf, printer)?;
351 printer.write_str(")")?;
352 }
353 Premise::Eq(eq_pf) => {
354 printer.write_str("Eq(")?;
355 self.print_eq_proof_with_printer(*eq_pf, printer)?;
356 printer.write_str(")")?;
357 }
358 }
359 }
360 printer.decrease_indent();
361 printer.write_with_break("], Result: ")?;
362 self.termdag.print_term_with_printer(*result, printer)?;
363 printer.write_str(")")
364 }
365 TermProof::PProj {
366 pf_f_args_ok,
367 arg_idx,
368 } => {
369 printer.write_str("PProj(")?;
370 self.print_term_proof_with_printer(*pf_f_args_ok, printer)?;
371 printer.write_str(&format!(", {arg_idx})"))
372 }
373 TermProof::PCong(cong_pf) => {
374 printer.write_str("PCong[Exists](")?;
375 self.print_cong_with_printer(cong_pf, printer)?;
376 printer.write_str(")")
377 }
378 TermProof::PFiat { desc, term } => {
379 printer.write_str(&format!("PFiat({desc:?}"))?;
380 printer.write_str(", ")?;
381 self.termdag.print_term_with_printer(*term, printer)?;
382 printer.write_str(")")
383 }
384 }
385 }
386 pub(crate) fn intern_term(&mut self, prf: &TermProof) -> TermProofId {
387 self.term_memo.get_or_insert(prf)
388 }
389 pub(crate) fn intern_eq(&mut self, prf: &EqProof) -> EqProofId {
390 self.eq_memo.get_or_insert(prf)
391 }
392
393 pub(crate) fn refl(&mut self, proof: TermProofId, term: TermId) -> EqProofId {
394 self.intern_eq(&EqProof::PRefl {
395 t_ok_pf: proof,
396 t: term,
397 })
398 }
399
400 pub(crate) fn sym(&mut self, proof: EqProofId) -> EqProofId {
401 self.intern_eq(&EqProof::PSym { eq_pf: proof })
402 }
403
404 pub(crate) fn trans(&mut self, pfxy: EqProofId, pfyz: EqProofId) -> EqProofId {
405 self.intern_eq(&EqProof::PTrans { pfxy, pfyz })
406 }
407
408 pub(crate) fn sequence_proofs(&mut self, pfs: &[EqProofId]) -> EqProofId {
409 match pfs {
410 [] => panic!("Cannot sequence an empty list of proofs"),
411 [pf] => *pf,
412 [pf1, rest @ ..] => {
413 let mut cur = *pf1;
414 for pf in rest {
415 cur = self.trans(cur, *pf);
416 }
417 cur
418 }
419 }
420 }
421}
422
423#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
424pub enum Premise {
425 TermOk(TermProofId),
426 Eq(EqProofId),
427}
428
429#[derive(Clone, PartialEq, Eq, Hash, Debug)]
430pub struct CongProof {
431 pub pf_args_eq: Vec<EqProofId>,
432 pub pf_f_args_ok: TermProofId,
433 pub old_term: TermId,
434 pub new_term: TermId,
435 pub func: FunctionId,
436}
437
438#[allow(clippy::enum_variant_names)]
439#[derive(Clone, PartialEq, Eq, Hash, Debug)]
440pub enum TermProof {
441 PRule {
446 rule_name: Rc<str>,
447 subst: DenseIdMap<Variable, TermId>,
448 body_pfs: Vec<Premise>,
449 result: TermId,
450 },
451 PProj {
453 pf_f_args_ok: TermProofId,
454 arg_idx: usize,
455 },
456 PCong(CongProof),
457 PFiat {
458 desc: Rc<str>,
459 term: TermId,
460 },
461}
462
463#[allow(clippy::enum_variant_names)]
464#[derive(Clone, PartialEq, Eq, Hash, Debug)]
465pub enum EqProof {
466 PRule {
471 rule_name: Rc<str>,
472 subst: DenseIdMap<Variable, TermId>,
473 body_pfs: Vec<Premise>,
474 result_lhs: TermId,
475 result_rhs: TermId,
476 },
477 PRefl {
479 t_ok_pf: TermProofId,
480 t: TermId,
481 },
482 PSym {
484 eq_pf: EqProofId,
485 },
486 PTrans {
487 pfxy: EqProofId,
488 pfyz: EqProofId,
489 },
490 PCong(CongProof),
494}
495
/// Layout options for pretty-printing terms and proofs.
#[derive(Clone, Debug)]
pub struct PrettyPrintConfig {
    /// Soft line-width limit, in characters. Only text written through the
    /// printer's break-aware path wraps when it would exceed this width.
    pub line_width: usize,
    /// Number of spaces added per indentation level.
    pub indent_size: usize,
}

impl Default for PrettyPrintConfig {
    /// 512-character lines with 4-space indentation.
    fn default() -> Self {
        Self {
            line_width: 512,
            indent_size: 4,
        }
    }
}

/// Indentation-aware writer used by the term and proof printers.
///
/// Tracks the current indentation and the cursor's column so that
/// `write_with_break` can wrap lines that would exceed
/// [`PrettyPrintConfig::line_width`].
struct PrettyPrinter<'w, W: io::Write> {
    writer: &'w mut W,
    config: &'w PrettyPrintConfig,
    /// Indentation, in spaces, written after every newline.
    current_indent: usize,
    /// Column of the cursor on the current line, in characters.
    current_line_pos: usize,
}

impl<'w, W: io::Write> PrettyPrinter<'w, W> {
    /// Creates a printer positioned at column 0 with no indentation.
    fn new(writer: &'w mut W, config: &'w PrettyPrintConfig) -> Self {
        Self {
            writer,
            config,
            current_indent: 0,
            current_line_pos: 0,
        }
    }

    /// Writes `s` verbatim (never wraps) and advances the tracked column.
    ///
    /// The column is counted in `char`s rather than bytes, so non-ASCII text
    /// does not cause premature wrapping. It also restarts after the last
    /// `'\n'` in `s`, keeping the column correct when text spans lines.
    fn write_str(&mut self, s: &str) -> io::Result<()> {
        self.writer.write_all(s.as_bytes())?;
        self.current_line_pos = match s.rfind('\n') {
            // '\n' is a single byte, so `nl + 1` is a char boundary.
            Some(nl) => s[nl + 1..].chars().count(),
            None => self.current_line_pos + s.chars().count(),
        };
        Ok(())
    }

    /// Ends the current line and writes the current indentation on the next.
    fn newline(&mut self) -> io::Result<()> {
        writeln!(self.writer)?;
        self.write_indent()
    }

    /// Writes `current_indent` spaces and moves the column accordingly.
    fn write_indent(&mut self) -> io::Result<()> {
        write!(self.writer, "{:width$}", "", width = self.current_indent)?;
        self.current_line_pos = self.current_indent;
        Ok(())
    }

    /// Indents subsequent lines by one more level.
    fn increase_indent(&mut self) {
        self.current_indent += self.config.indent_size;
    }

    /// Removes one indentation level, never going below zero.
    fn decrease_indent(&mut self) {
        self.current_indent = self.current_indent.saturating_sub(self.config.indent_size);
    }

    /// Whether writing `additional_chars` more characters would pass the
    /// configured line width.
    fn should_break(&self, additional_chars: usize) -> bool {
        self.current_line_pos + additional_chars > self.config.line_width
    }

    /// Writes `s`, first starting a new (indented) line if `s` would overflow
    /// the current one.
    fn write_with_break(&mut self, s: &str) -> io::Result<()> {
        // Wrapping a line that holds nothing but indentation would only emit an
        // empty line, so require some content first.
        let has_content = self.current_line_pos > self.current_indent;
        if has_content && self.should_break(s.chars().count()) {
            // `newline` already writes the indentation for the new line, so no
            // extra `write_indent` is needed here.
            self.newline()?;
        }
        self.write_str(s)
    }
}
568}