bend/fun/transform/encode_match_terms.rs

1use crate::{
2  fun::{Book, MatchRule, Name, Pattern, Term},
3  maybe_grow, AdtEncoding,
4};
5
6impl Book {
7  /// Encodes pattern matching expressions in the book into their
8  /// core form. Must be run after [`Ctr::fix_match_terms`].
9  ///
10  /// ADT matches are encoded based on `adt_encoding`.
11  ///
12  /// Num matches are encoded as a sequence of native num matches (on 0 and 1+).
13  pub fn encode_matches(&mut self, adt_encoding: AdtEncoding) {
14    for def in self.defs.values_mut() {
15      for rule in &mut def.rules {
16        rule.body.encode_matches(adt_encoding);
17      }
18    }
19  }
20}
21
22impl Term {
23  pub fn encode_matches(&mut self, adt_encoding: AdtEncoding) {
24    maybe_grow(|| {
25      for child in self.children_mut() {
26        child.encode_matches(adt_encoding)
27      }
28
29      if let Term::Mat { arg, bnd: _, with_bnd, with_arg, arms } = self {
30        assert!(with_bnd.is_empty());
31        assert!(with_arg.is_empty());
32        let arg = std::mem::take(arg.as_mut());
33        let rules = std::mem::take(arms);
34        *self = encode_match(arg, rules, adt_encoding);
35      } else if let Term::Swt { arg, bnd: _, with_bnd, with_arg, pred, arms } = self {
36        assert!(with_bnd.is_empty());
37        assert!(with_arg.is_empty());
38        let arg = std::mem::take(arg.as_mut());
39        let pred = std::mem::take(pred);
40        let rules = std::mem::take(arms);
41        *self = encode_switch(arg, pred, rules);
42      }
43    })
44  }
45}
46
47fn encode_match(arg: Term, rules: Vec<MatchRule>, adt_encoding: AdtEncoding) -> Term {
48  match adt_encoding {
49    AdtEncoding::Scott => {
50      let arms = rules.into_iter().map(|rule| Term::rfold_lams(rule.2, rule.1.into_iter()));
51      Term::call(arg, arms)
52    }
53    AdtEncoding::NumScott => {
54      fn make_switches(arms: &mut [Term]) -> Term {
55        maybe_grow(|| match arms {
56          [] => Term::Err,
57          [arm] => Term::lam(Pattern::Var(None), std::mem::take(arm)),
58          [arm, rest @ ..] => Term::lam(
59            Pattern::Var(Some(Name::new("%tag"))),
60            Term::Swt {
61              arg: Box::new(Term::Var { nam: Name::new("%tag") }),
62              bnd: None,
63              with_bnd: vec![],
64              with_arg: vec![],
65              pred: None,
66              arms: vec![std::mem::take(arm), make_switches(rest)],
67            },
68          ),
69        })
70      }
71      let mut arms =
72        rules.into_iter().map(|rule| Term::rfold_lams(rule.2, rule.1.into_iter())).collect::<Vec<_>>();
73      let term = if arms.len() == 1 {
74        // λx (x λtag switch tag {0: Ctr0; _: * })
75        let arm = arms.pop().unwrap();
76        let term = Term::Swt {
77          arg: Box::new(Term::Var { nam: Name::new("%tag") }),
78          bnd: None,
79          with_bnd: vec![],
80          with_arg: vec![],
81          pred: None,
82          arms: vec![arm, Term::Era],
83        };
84        Term::lam(Pattern::Var(Some(Name::new("%tag"))), term)
85      } else {
86        // λx (x λtag switch tag {0: Ctr0; _: switch tag-1 { ... } })
87        make_switches(arms.as_mut_slice())
88      };
89      Term::call(arg, [term])
90    }
91  }
92}
93
94/// Convert into a sequence of native switches, decrementing by 1 each switch.
95/// switch n {0: A; 1: B; _: (C n-2)} converted to
96/// switch n {0: A; _: @%x match %x {0: B; _: @n-2 (C n-2)}}
97fn encode_switch(arg: Term, pred: Option<Name>, mut rules: Vec<Term>) -> Term {
98  // Create the cascade of switches
99  let match_var = Name::new("%x");
100  let (succ, nums) = rules.split_last_mut().unwrap();
101  let last_arm = Term::lam(Pattern::Var(pred), std::mem::take(succ));
102  nums.iter_mut().enumerate().rfold(last_arm, |term, (i, rule)| {
103    let arms = vec![std::mem::take(rule), term];
104    if i == 0 {
105      Term::Swt {
106        arg: Box::new(arg.clone()),
107        bnd: None,
108        with_bnd: vec![],
109        with_arg: vec![],
110        pred: None,
111        arms,
112      }
113    } else {
114      let swt = Term::Swt {
115        arg: Box::new(Term::Var { nam: match_var.clone() }),
116        bnd: None,
117        with_bnd: vec![],
118        with_arg: vec![],
119        pred: None,
120        arms,
121      };
122      Term::lam(Pattern::Var(Some(match_var.clone())), swt)
123    }
124  })
125}