// bend/fun/transform/desugar_with_blocks.rs

1use crate::{
2  diagnostics::Diagnostics,
3  fun::{Ctx, Name, Pattern, Term},
4  maybe_grow,
5};
6use std::collections::HashSet;
7
8impl Ctx<'_> {
9  /// Converts `ask` terms inside `with` blocks into calls to a monadic bind operation.
10  pub fn desugar_with_blocks(&mut self) -> Result<(), Diagnostics> {
11    let def_names = self.book.defs.keys().cloned().collect::<HashSet<_>>();
12
13    for def in self.book.defs.values_mut() {
14      for rule in def.rules.iter_mut() {
15        if let Err(e) = rule.body.desugar_with_blocks(None, &def_names) {
16          self.info.add_function_error(e, def.name.clone(), def.source.clone());
17        }
18      }
19    }
20
21    self.info.fatal(())
22  }
23}
24
25impl Term {
26  pub fn desugar_with_blocks(
27    &mut self,
28    cur_block: Option<&Name>,
29    def_names: &HashSet<Name>,
30  ) -> Result<(), String> {
31    maybe_grow(|| {
32      if let Term::With { typ, bod } = self {
33        bod.desugar_with_blocks(Some(typ), def_names)?;
34        let wrap_ref = Term::r#ref(&format!("{typ}/wrap"));
35        *self = Term::Use { nam: Some(Name::new("wrap")), val: Box::new(wrap_ref), nxt: std::mem::take(bod) };
36      }
37
38      if let Term::Ask { pat, val, nxt } = self {
39        if let Some(typ) = cur_block {
40          let bind_nam = Name::new(format!("{typ}/bind"));
41
42          if def_names.contains(&bind_nam) {
43            let nxt = Term::lam(*pat.clone(), std::mem::take(nxt));
44            let nxt = nxt.defer();
45
46            *self = Term::call(Term::Ref { nam: bind_nam }, [*val.clone(), nxt]);
47          } else {
48            return Err(format!("Could not find definition {bind_nam} for type {typ}."));
49          }
50        } else {
51          return Err(format!("Monadic bind operation '{pat} <- ...' used outside of a `do` block."));
52        }
53      }
54
55      for children in self.children_mut() {
56        children.desugar_with_blocks(cur_block, def_names)?;
57      }
58
59      Ok(())
60    })
61  }
62
63  /// Converts a term with free vars `(f x1 .. xn)` into a deferred
64  /// call that passes those vars to the term.
65  ///
66  /// Ex: `(f x1 .. xn)` becomes `@x (x @x1 .. @xn (f x1 .. x2) x1 .. x2)`.
67  ///
68  /// The user must call this lazy thunk by calling the builtin
69  /// `undefer` function, or by applying `@x x` to the term.
70  fn defer(self) -> Term {
71    let free_vars = self.free_vars().into_keys().collect::<Vec<_>>();
72    let term = Term::rfold_lams(self, free_vars.iter().cloned().map(Some));
73    let term = Term::call(Term::Var { nam: Name::new("%x") }, [term]);
74    let term = Term::call(term, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
75    Term::lam(Pattern::Var(Some(Name::new("%x"))), term)
76  }
77}