bend/fun/transform/
desugar_fold.rs

1use std::collections::HashSet;
2
3use crate::{
4  diagnostics::Diagnostics,
5  fun::{Adts, Constructors, Ctx, Definition, Name, Pattern, Rule, Source, Term},
6  maybe_grow,
7};
8
9impl Ctx<'_> {
10  /// Desugars `fold` expressions into recursive `match`es.
11  /// ```bend
12  /// foo xs =
13  ///   ...
14  ///   fold bind = init with x1 x2 {
15  ///     Type/Ctr1: (Foo bind.rec_fld bind.fld x1 x2 free_var)
16  ///     Type/Ctr2: (Bar bind.fld x1 x2)
17  ///   }
18  /// ```
19  /// Desugars to:
20  /// ```bend
21  /// foo xs =
22  ///   ...
23  ///   (foo__fold0 init x1 x2 free_var)
24  ///
25  /// foo__fold0 = @bind match bind {
26  ///   Type/Ctr1: (Foo (foo_fold0 bind.rec_fld x1 x2 free_var) bind.fld x1 x2 free_var)
27  ///   Type/Ctr2: (Bar bind.fld x1 x2)
28  /// }
29  /// ```
30  pub fn desugar_fold(&mut self) -> Result<(), Diagnostics> {
31    let mut new_defs = vec![];
32    for def in self.book.defs.values_mut() {
33      let mut fresh = 0;
34      for rule in def.rules.iter_mut() {
35        let mut ctx = DesugarFoldCtx {
36          def_name: &def.name,
37          fresh: &mut fresh,
38          new_defs: &mut new_defs,
39          ctrs: &self.book.ctrs,
40          adts: &self.book.adts,
41          source: def.source.clone(),
42          check: def.check,
43        };
44
45        let res = rule.body.desugar_fold(&mut ctx);
46        if let Err(e) = res {
47          self.info.add_function_error(e, def.name.clone(), def.source.clone());
48        }
49      }
50    }
51
52    self.book.defs.extend(new_defs.into_iter().map(|def| (def.name.clone(), def)));
53
54    self.info.fatal(())
55  }
56}
57
58struct DesugarFoldCtx<'a> {
59  pub def_name: &'a Name,
60  pub fresh: &'a mut usize,
61  pub new_defs: &'a mut Vec<Definition>,
62  pub ctrs: &'a Constructors,
63  pub adts: &'a Adts,
64  pub source: Source,
65  pub check: bool,
66}
67
68impl Term {
69  fn desugar_fold(&mut self, ctx: &mut DesugarFoldCtx<'_>) -> Result<(), String> {
70    maybe_grow(|| {
71      for child in self.children_mut() {
72        child.desugar_fold(ctx)?;
73      }
74
75      if let Term::Fold { .. } = self {
76        // Can't have unmatched unscoped because this'll be extracted
77        if self.has_unscoped_diff() {
78          return Err("Can't have non self-contained unscoped variables in a 'fold'".into());
79        }
80        let Term::Fold { bnd: _, arg, with_bnd, with_arg, arms } = self else { unreachable!() };
81
82        // Gather the free variables
83        let mut free_vars = HashSet::new();
84        for arm in arms.iter() {
85          let mut arm_free_vars = arm.2.free_vars().into_keys().collect::<HashSet<_>>();
86          for field in arm.1.iter().flatten() {
87            arm_free_vars.remove(field);
88          }
89          free_vars.extend(arm_free_vars);
90        }
91        for var in with_bnd.iter().flatten() {
92          free_vars.remove(var);
93        }
94        let free_vars = free_vars.into_iter().collect::<Vec<_>>();
95
96        let new_nam = Name::new(format!("{}__fold{}", ctx.def_name, ctx.fresh));
97        *ctx.fresh += 1;
98
99        // Substitute the implicit recursive calls to call the new function
100        let ctr = arms[0].0.as_ref().unwrap();
101        let adt_nam = ctx.ctrs.get(ctr).unwrap();
102        let ctrs = &ctx.adts.get(adt_nam).unwrap().ctrs;
103        for arm in arms.iter_mut() {
104          let ctr = arm.0.as_ref().unwrap();
105          let recursive = arm
106            .1
107            .iter()
108            .zip(&ctrs.get(ctr).unwrap().fields)
109            .filter_map(|(var, field)| if field.rec { Some(var.as_ref().unwrap().clone()) } else { None })
110            .collect::<HashSet<_>>();
111          arm.2.call_recursive(&new_nam, &recursive, &free_vars);
112        }
113
114        // Create the new function
115        let x_nam = Name::new("%x");
116        let body = Term::Mat {
117          arg: Box::new(Term::Var { nam: x_nam.clone() }),
118          bnd: None,
119          with_bnd: with_bnd.clone(),
120          with_arg: with_bnd.iter().map(|nam| Term::var_or_era(nam.clone())).collect(),
121          arms: std::mem::take(arms),
122        };
123        let body = Term::rfold_lams(body, with_bnd.iter().cloned());
124        let body = Term::rfold_lams(body, free_vars.iter().map(|nam| Some(nam.clone())));
125        let body = Term::lam(Pattern::Var(Some(x_nam)), body);
126
127        let def = Definition::new_gen(
128          new_nam.clone(),
129          vec![Rule { pats: vec![], body }],
130          ctx.source.clone(),
131          ctx.check,
132        );
133        ctx.new_defs.push(def);
134
135        // Call the new function
136        let call = Term::call(Term::Ref { nam: new_nam.clone() }, [std::mem::take(arg.as_mut())]);
137        let call = Term::call(call, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
138        let call = Term::call(call, with_arg.iter().cloned());
139        *self = call;
140      }
141      Ok(())
142    })
143  }
144
145  fn call_recursive(&mut self, def_name: &Name, recursive: &HashSet<Name>, free_vars: &[Name]) {
146    maybe_grow(|| {
147      for child in self.children_mut() {
148        child.call_recursive(def_name, recursive, free_vars);
149      }
150
151      // If we found a recursive field, replace with a call to the new function.
152      if let Term::Var { nam } = self {
153        if recursive.contains(nam) {
154          let call = Term::call(Term::Ref { nam: def_name.clone() }, [std::mem::take(self)]);
155          let call = Term::call(call, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
156          *self = call;
157        }
158      }
159    })
160  }
161}