bend/fun/transform/
encode_adts.rs

1use crate::{
2  fun::{Book, Definition, Name, Num, Pattern, Rule, Source, Term},
3  AdtEncoding,
4};
5
6impl Book {
7  /// Defines a function for each constructor in each ADT in the book.
8  pub fn encode_adts(&mut self, adt_encoding: AdtEncoding) {
9    let mut defs = vec![];
10
11    for (_, adt) in self.adts.iter() {
12      for (ctr_idx, (ctr_name, ctr)) in adt.ctrs.iter().enumerate() {
13        let ctrs: Vec<_> = adt.ctrs.keys().cloned().collect();
14
15        let body = match adt_encoding {
16          AdtEncoding::Scott => encode_ctr_scott(ctr.fields.iter().map(|f| &f.nam), ctrs, ctr_name),
17          AdtEncoding::NumScott => {
18            let tag = encode_num_scott_tag(ctr_idx as u32, ctr_name, adt.source.clone());
19            defs.push((tag.name.clone(), tag.clone()));
20            encode_ctr_num_scott(ctr.fields.iter().map(|f| &f.nam), &tag.name)
21          }
22        };
23
24        let rules = vec![Rule { pats: vec![], body }];
25        let def = Definition {
26          name: ctr_name.clone(),
27          typ: ctr.typ.clone(),
28          check: true,
29          rules,
30          source: adt.source.clone(),
31        };
32        defs.push((ctr_name.clone(), def));
33      }
34    }
35    self.defs.extend(defs);
36  }
37}
38
39fn encode_ctr_scott<'a>(
40  ctr_args: impl DoubleEndedIterator<Item = &'a Name> + Clone,
41  ctrs: Vec<Name>,
42  ctr_name: &Name,
43) -> Term {
44  let ctr = Term::Var { nam: ctr_name.clone() };
45  let app = Term::call(ctr, ctr_args.clone().cloned().map(|nam| Term::Var { nam }));
46  let lam = Term::rfold_lams(app, ctrs.into_iter().map(Some));
47  Term::rfold_lams(lam, ctr_args.cloned().map(Some))
48}
49
50fn encode_ctr_num_scott<'a>(ctr_args: impl DoubleEndedIterator<Item = &'a Name> + Clone, tag: &str) -> Term {
51  let nam = Name::new("%x");
52  // λa1 .. λan λx (x TAG a1 .. an)
53  let term = Term::Var { nam: nam.clone() };
54  let tag = Term::r#ref(tag);
55  let term = Term::app(term, tag);
56  let term = Term::call(term, ctr_args.clone().cloned().map(|nam| Term::Var { nam }));
57  let term = Term::lam(Pattern::Var(Some(nam)), term);
58  Term::rfold_lams(term, ctr_args.cloned().map(Some))
59}
60
61fn encode_num_scott_tag(tag: u32, ctr_name: &Name, source: Source) -> Definition {
62  let tag_nam = Name::new(format!("{ctr_name}/tag"));
63  let rules = vec![Rule { pats: vec![], body: Term::Num { val: Num::U24(tag) } }];
64  Definition::new_gen(tag_nam.clone(), rules, source, true)
65}