// bend/fun/transform/desugar_bend.rs

1use crate::{
2  diagnostics::Diagnostics,
3  fun::{Ctx, Definition, Name, Rule, Source, Term},
4  maybe_grow,
5};
6use indexmap::IndexMap;
7
/// Name that the generated code binds, inside a bend's recursive step, to the
/// bend's own function already applied to its captured variables.
pub const RECURSIVE_KW: &str = "fork";

/// Infix placed between the enclosing definition's name and a per-definition
/// counter when naming the functions generated from bends (e.g. `main__bend0`).
const NEW_FN_SEP: &str = "__bend";
10
11impl Ctx<'_> {
12  pub fn desugar_bend(&mut self) -> Result<(), Diagnostics> {
13    let mut new_defs = IndexMap::new();
14    for def in self.book.defs.values_mut() {
15      let mut fresh = 0;
16      for rule in def.rules.iter_mut() {
17        if let Err(err) =
18          rule.body.desugar_bend(&def.name, &mut fresh, &mut new_defs, def.source.clone(), def.check)
19        {
20          self.info.add_function_error(err, def.name.clone(), def.source.clone());
21          break;
22        }
23      }
24    }
25
26    self.book.defs.extend(new_defs);
27
28    self.info.fatal(())
29  }
30}
31
32impl Term {
33  fn desugar_bend(
34    &mut self,
35    def_name: &Name,
36    fresh: &mut usize,
37    new_defs: &mut IndexMap<Name, Definition>,
38    source: Source,
39    check: bool,
40  ) -> Result<(), String> {
41    maybe_grow(|| {
42      // Recursively encode bends in the children
43      for child in self.children_mut() {
44        child.desugar_bend(def_name, fresh, new_defs, source.clone(), check)?;
45      }
46
47      // Convert a bend into a new recursive function and call it.
48      if let Term::Bend { .. } = self {
49        // Can't have unmatched unscoped because this'll be extracted
50        if self.has_unscoped_diff() {
51          return Err("Can't have non self-contained unscoped variables in a 'bend'".into());
52        }
53        let Term::Bend { bnd, arg, cond, step, base } = self else { unreachable!() };
54
55        let new_nam = Name::new(format!("{}{}{}", def_name, NEW_FN_SEP, fresh));
56        *fresh += 1;
57
58        // Gather the free variables
59        // They will be implicitly captured by the new function
60        let mut free_vars = step.free_vars();
61        free_vars.shift_remove(&Name::new(RECURSIVE_KW));
62        free_vars.extend(base.free_vars());
63        free_vars.extend(cond.free_vars());
64        for bnd in bnd.iter().flatten() {
65          free_vars.shift_remove(bnd);
66        }
67        let free_vars = free_vars.into_keys().collect::<Vec<_>>();
68
69        // Add a substitution of `fork`, a use term with a partially applied recursive call
70        let step = Term::Use {
71          nam: Some(Name::new(RECURSIVE_KW)),
72          val: Box::new(Term::call(
73            Term::Ref { nam: new_nam.clone() },
74            free_vars.iter().cloned().map(|nam| Term::Var { nam }),
75          )),
76          nxt: Box::new(std::mem::take(step.as_mut())),
77        };
78
79        // Create the function body for the bend.
80        let body = Term::Swt {
81          arg: Box::new(std::mem::take(cond)),
82          bnd: Some(Name::new("_")),
83          with_bnd: vec![],
84          with_arg: vec![],
85          pred: Some(Name::new("_-1")),
86          arms: vec![std::mem::take(base.as_mut()), step],
87        };
88        let body = Term::rfold_lams(body, std::mem::take(bnd).into_iter());
89        let body = Term::rfold_lams(body, free_vars.iter().cloned().map(Some));
90
91        // Make a definition from the new function
92        let def = Definition::new_gen(new_nam.clone(), vec![Rule { pats: vec![], body }], source, check);
93        new_defs.insert(new_nam.clone(), def);
94
95        // Call the new function in the original term.
96        let call =
97          Term::call(Term::Ref { nam: new_nam }, free_vars.iter().map(|v| Term::Var { nam: v.clone() }));
98        *self = Term::call(call, arg.drain(..));
99      }
100
101      Ok(())
102    })
103  }
104}