// bend/imp/to_fun.rs

1use super::{AssignPattern, Definition, Expr, InPlaceOp, Stmt};
2use crate::{
3  diagnostics::Diagnostics,
4  fun::{
5    self,
6    builtins::{LCONS, LNIL},
7    parser::ParseBook,
8    Book, Name,
9  },
10};
11
12impl ParseBook {
13  // TODO: Change all functions to return diagnostics
14  pub fn to_fun(mut self) -> Result<Book, Diagnostics> {
15    for (name, mut def) in std::mem::take(&mut self.imp_defs) {
16      def.order_kwargs(&self)?;
17      def.gen_map_get();
18
19      if self.fun_defs.contains_key(&name) {
20        panic!("Def names collision should be checked at parse time")
21      }
22
23      self.fun_defs.insert(name, def.to_fun()?);
24    }
25
26    let ParseBook { fun_defs: defs, hvm_defs, adts, ctrs, import_ctx, .. } = self;
27    Ok(Book { defs, hvm_defs, adts, ctrs, entrypoint: None, imports: import_ctx.to_imports() })
28  }
29}
30
31impl Definition {
32  pub fn to_fun(self) -> Result<fun::Definition, Diagnostics> {
33    let body = self.body.into_fun().map_err(|e| {
34      let mut diags = Diagnostics::default();
35      diags.add_function_error(e, self.name.clone(), self.source.clone());
36      diags
37    })?;
38
39    let body = match body {
40      StmtToFun::Return(term) => term,
41      StmtToFun::Assign(..) => {
42        let mut diags = Diagnostics::default();
43        diags.add_function_error(
44          "Function doesn't end with a return statement",
45          self.name,
46          self.source.clone(),
47        );
48        return Err(diags);
49      }
50    };
51
52    let rule =
53      fun::Rule { pats: self.args.into_iter().map(|param| fun::Pattern::Var(Some(param))).collect(), body };
54
55    let def = fun::Definition {
56      name: self.name,
57      typ: self.typ,
58      check: self.check,
59      rules: vec![rule],
60      source: self.source,
61    };
62    Ok(def)
63  }
64}
65
66impl AssignPattern {
67  pub fn into_fun(self) -> fun::Pattern {
68    match self {
69      AssignPattern::Eraser => fun::Pattern::Var(None),
70      AssignPattern::Var(name) => fun::Pattern::Var(Some(name)),
71      AssignPattern::Chn(name) => fun::Pattern::Chn(name),
72      AssignPattern::Tup(names) => fun::Pattern::Fan(
73        fun::FanKind::Tup,
74        fun::Tag::Static,
75        names.into_iter().map(Self::into_fun).collect(),
76      ),
77      AssignPattern::Sup(names) => {
78        fun::Pattern::Fan(fun::FanKind::Dup, fun::Tag::Auto, names.into_iter().map(Self::into_fun).collect())
79      }
80      AssignPattern::MapSet(..) => unreachable!(),
81    }
82  }
83}
84
85#[derive(Debug)]
86enum StmtToFun {
87  Return(fun::Term),
88  Assign(bool, fun::Pattern, fun::Term),
89}
90
91fn take(t: Stmt) -> Result<(bool, Option<fun::Pattern>, fun::Term), String> {
92  match t.into_fun()? {
93    StmtToFun::Return(ret) => Ok((false, None, ret)),
94    StmtToFun::Assign(x, pat, val) => Ok((x, Some(pat), val)),
95  }
96}
97
98fn wrap(nxt: Option<fun::Pattern>, term: fun::Term, ask: bool) -> StmtToFun {
99  if let Some(pat) = nxt {
100    StmtToFun::Assign(ask, pat, term)
101  } else {
102    StmtToFun::Return(term)
103  }
104}
105
106impl Stmt {
107  fn into_fun(self) -> Result<StmtToFun, String> {
108    // TODO: Refactor this to not repeat everything.
109    // TODO: When we have an error with an assignment, we should show the offending assignment (eg. "{pat} = ...").
110    let stmt_to_fun = match self {
111      Stmt::Assign { pat: AssignPattern::MapSet(map, key), val, nxt: Some(nxt) } => {
112        let (ask, nxt_pat, nxt) = take(*nxt)?;
113        let term = fun::Term::Let {
114          pat: Box::new(fun::Pattern::Var(Some(map.clone()))),
115          val: Box::new(fun::Term::call(
116            fun::Term::Ref { nam: fun::Name::new("Map/set") },
117            [fun::Term::Var { nam: map }, key.to_fun(), val.to_fun()],
118          )),
119          nxt: Box::new(nxt),
120        };
121        wrap(nxt_pat, term, ask)
122      }
123      Stmt::Assign { pat: AssignPattern::MapSet(..), val: _, nxt: None } => {
124        return Err("Branch ends with map assignment.".to_string())?;
125      }
126      Stmt::Assign { pat, val, nxt: Some(nxt) } => {
127        let pat = pat.into_fun();
128        let val = val.to_fun();
129        let (ask, nxt_pat, nxt) = take(*nxt)?;
130        let term = fun::Term::Let { pat: Box::new(pat), val: Box::new(val), nxt: Box::new(nxt) };
131        wrap(nxt_pat, term, ask)
132      }
133      Stmt::Assign { pat, val, nxt: None } => {
134        let pat = pat.into_fun();
135        let val = val.to_fun();
136        StmtToFun::Assign(false, pat, val)
137      }
138      Stmt::InPlace { op, pat, val, nxt } => {
139        let (ask, nxt_pat, nxt) = take(*nxt)?;
140        // if it is a mapper operation
141        if let InPlaceOp::Map = op {
142          let term = match &*pat {
143            AssignPattern::MapSet(map, key) => {
144              let rhs = fun::Term::call(
145                fun::Term::r#ref("Map/map"),
146                [fun::Term::Var { nam: map.clone() }, key.clone().to_fun(), val.clone().to_fun()],
147              );
148              fun::Term::Let {
149                pat: Box::new(fun::Pattern::Var(Some(map.clone()))),
150                val: Box::new(rhs),
151                nxt: Box::new(nxt),
152              }
153            }
154            _ => {
155              let rhs = fun::Term::call(val.to_fun(), [pat.clone().into_fun().to_term()]);
156              fun::Term::Let { pat: Box::new(pat.into_fun()), val: Box::new(rhs), nxt: Box::new(nxt) }
157            }
158          };
159
160          return Ok(wrap(nxt_pat, term, ask));
161        }
162
163        // otherwise
164        match *pat {
165          AssignPattern::Var(var) => {
166            let term = fun::Term::Let {
167              pat: Box::new(fun::Pattern::Var(Some(var.clone()))),
168              val: Box::new(fun::Term::Oper {
169                opr: op.to_lang_op(),
170                fst: Box::new(fun::Term::Var { nam: var }),
171                snd: Box::new(val.to_fun()),
172              }),
173              nxt: Box::new(nxt),
174            };
175            wrap(nxt_pat, term, ask)
176          }
177          AssignPattern::MapSet(map, key) => {
178            let temp = Name::new("%0");
179            let partial =
180              Expr::Opr { op: op.to_lang_op(), lhs: Box::new(Expr::Var { nam: temp.clone() }), rhs: val };
181            let map_fn = Expr::Lam { names: vec![(temp, false)], bod: Box::new(partial) };
182            let map_term = fun::Term::call(
183              fun::Term::r#ref("Map/map"),
184              [fun::Term::Var { nam: map.clone() }, key.to_fun(), map_fn.to_fun()],
185            );
186            let term = fun::Term::Let {
187              pat: Box::new(fun::Pattern::Var(Some(map))),
188              val: Box::new(map_term),
189              nxt: Box::new(nxt),
190            };
191            wrap(nxt_pat, term, ask)
192          }
193          _ => unreachable!(),
194        }
195      }
196      Stmt::If { cond, then, otherwise, nxt } => {
197        let (ask, pat, then, else_) = match (then.into_fun()?, otherwise.into_fun()?) {
198          (StmtToFun::Return(t), StmtToFun::Return(e)) => (false, None, t, e),
199          (StmtToFun::Assign(ask, tp, t), StmtToFun::Assign(ask_, ep, e)) if tp == ep => {
200            (ask && ask_, Some(tp), t, e)
201          }
202          (StmtToFun::Assign(..), StmtToFun::Assign(..)) => {
203            return Err("'if' branches end with different assignments.".to_string())?;
204          }
205          (StmtToFun::Return(..), StmtToFun::Assign(..)) => {
206            return Err(
207              "Expected 'else' branch from 'if' to return, but it ends with assignment.".to_string(),
208            )?;
209          }
210          (StmtToFun::Assign(..), StmtToFun::Return(..)) => {
211            return Err(
212              "Expected 'else' branch from 'if' to end with assignment, but it returns.".to_string(),
213            )?;
214          }
215        };
216        let arms = vec![else_, then];
217        let term = fun::Term::Swt {
218          arg: Box::new(cond.to_fun()),
219          bnd: Some(Name::new("%pred")),
220          with_bnd: vec![],
221          with_arg: vec![],
222          pred: Some(Name::new("%pred-1")),
223          arms,
224        };
225        wrap_nxt_assign_stmt(term, nxt, pat, ask)?
226      }
227      Stmt::Match { arg, bnd, with_bnd, with_arg, arms, nxt } => {
228        let arg = arg.to_fun();
229        let mut fun_arms = vec![];
230        let mut arms = arms.into_iter();
231        let fst = arms.next().unwrap();
232        let (fst_ask, fst_pat, fst_rgt) = take(fst.rgt)?;
233        let with_arg = with_arg.into_iter().map(Expr::to_fun).collect();
234        fun_arms.push((fst.lft, vec![], fst_rgt));
235        for arm in arms {
236          let (arm_ask, arm_pat, arm_rgt) = take(arm.rgt)?;
237          match (&arm_pat, &fst_pat) {
238            (Some(arm_pat), Some(fst_pat)) if arm_pat != fst_pat || arm_ask != fst_ask => {
239              return Err("'match' arms end with different assignments.".to_string())?;
240            }
241            (Some(_), None) => {
242              return Err("Expected 'match' arms to end with assignment, but it returns.".to_string())?;
243            }
244            (None, Some(_)) => {
245              return Err("Expected 'match' arms to return, but it ends with assignment.".to_string())?;
246            }
247            (Some(_), Some(_)) => fun_arms.push((arm.lft, vec![], arm_rgt)),
248            (None, None) => fun_arms.push((arm.lft, vec![], arm_rgt)),
249          }
250        }
251        let term = fun::Term::Mat { arg: Box::new(arg), bnd, with_bnd, with_arg, arms: fun_arms };
252        wrap_nxt_assign_stmt(term, nxt, fst_pat, fst_ask)?
253      }
254      Stmt::Switch { arg, bnd, with_bnd, with_arg, arms, nxt } => {
255        let arg = arg.to_fun();
256        let mut fun_arms = vec![];
257        let mut arms = arms.into_iter();
258        let fst = arms.next().unwrap();
259        let (fst_ask, fst_pat, fst) = take(fst)?;
260        let with_arg = with_arg.into_iter().map(Expr::to_fun).collect();
261        fun_arms.push(fst);
262        for arm in arms {
263          let (arm_ask, arm_pat, arm) = take(arm)?;
264          match (&arm_pat, &fst_pat) {
265            (Some(arm_pat), Some(fst_pat)) if arm_pat != fst_pat || arm_ask != fst_ask => {
266              return Err("'switch' arms end with different assignments.".to_string())?;
267            }
268            (Some(_), None) => {
269              return Err("Expected 'switch' arms to end with assignment, but it returns.".to_string())?;
270            }
271            (None, Some(_)) => {
272              return Err("Expected 'switch' arms to return, but it ends with assignment.".to_string())?;
273            }
274            (Some(_), Some(_)) => fun_arms.push(arm),
275            (None, None) => fun_arms.push(arm),
276          }
277        }
278        let pred = Some(Name::new(format!("{}-{}", bnd.clone().unwrap(), fun_arms.len() - 1)));
279        let term = fun::Term::Swt { arg: Box::new(arg), bnd, with_bnd, with_arg, pred, arms: fun_arms };
280        wrap_nxt_assign_stmt(term, nxt, fst_pat, fst_ask)?
281      }
282      Stmt::Fold { arg, bnd, with_bnd, with_arg, arms, nxt } => {
283        let arg = arg.to_fun();
284        let mut fun_arms = vec![];
285        let mut arms = arms.into_iter();
286        let fst = arms.next().unwrap();
287        let (fst_ask, fst_pat, fst_rgt) = take(fst.rgt)?;
288        fun_arms.push((fst.lft, vec![], fst_rgt));
289        let with_arg = with_arg.into_iter().map(Expr::to_fun).collect();
290        for arm in arms {
291          let (arm_ask, arm_pat, arm_rgt) = take(arm.rgt)?;
292          match (&arm_pat, &fst_pat) {
293            (Some(arm_pat), Some(fst_pat)) if arm_pat != fst_pat || arm_ask != fst_ask => {
294              return Err("'fold' arms end with different assignments.".to_string())?;
295            }
296            (Some(_), None) => {
297              return Err("Expected 'fold' arms to end with assignment, but it returns.".to_string())?;
298            }
299            (None, Some(_)) => {
300              return Err("Expected 'fold' arms to return, but it ends with assignment.".to_string())?;
301            }
302            (Some(_), Some(_)) => fun_arms.push((arm.lft, vec![], arm_rgt)),
303            (None, None) => fun_arms.push((arm.lft, vec![], arm_rgt)),
304          }
305        }
306        let term = fun::Term::Fold { arg: Box::new(arg), bnd, with_bnd, with_arg, arms: fun_arms };
307        wrap_nxt_assign_stmt(term, nxt, fst_pat, fst_ask)?
308      }
309      Stmt::Bend { bnd, arg, cond, step, base, nxt } => {
310        let arg = arg.into_iter().map(Expr::to_fun).collect();
311        let cond = cond.to_fun();
312        let (ask, pat, step, base) = match (step.into_fun()?, base.into_fun()?) {
313          (StmtToFun::Return(s), StmtToFun::Return(b)) => (false, None, s, b),
314          (StmtToFun::Assign(aa, sp, s), StmtToFun::Assign(ba, bp, b)) if sp == bp => {
315            (aa && ba, Some(sp), s, b)
316          }
317          (StmtToFun::Assign(..), StmtToFun::Assign(..)) => {
318            return Err("'bend' branches end with different assignments.".to_string())?;
319          }
320          (StmtToFun::Return(..), StmtToFun::Assign(..)) => {
321            return Err(
322              "Expected 'else' branch from 'bend' to return, but it ends with assignment.".to_string(),
323            )?;
324          }
325          (StmtToFun::Assign(..), StmtToFun::Return(..)) => {
326            return Err(
327              "Expected 'else' branch from 'bend' to end with assignment, but it returns.".to_string(),
328            )?;
329          }
330        };
331        let term =
332          fun::Term::Bend { bnd, arg, cond: Box::new(cond), step: Box::new(step), base: Box::new(base) };
333        wrap_nxt_assign_stmt(term, nxt, pat, ask)?
334      }
335      Stmt::With { typ, bod, nxt } => {
336        let (ask, pat, bod) = take(*bod)?;
337        let term = fun::Term::With { typ, bod: Box::new(bod) };
338        wrap_nxt_assign_stmt(term, nxt, pat, ask)?
339      }
340      Stmt::Ask { pat, val, nxt: Some(nxt) } => {
341        let (ask, nxt_pat, nxt) = take(*nxt)?;
342        let term =
343          fun::Term::Ask { pat: Box::new(pat.into_fun()), val: Box::new(val.to_fun()), nxt: Box::new(nxt) };
344        wrap(nxt_pat, term, ask)
345      }
346      Stmt::Ask { pat, val, nxt: None } => {
347        let pat = pat.into_fun();
348        let val = val.to_fun();
349        StmtToFun::Assign(true, pat, val)
350      }
351      Stmt::Open { typ, var, nxt } => {
352        let (ask, nxt_pat, nxt) = take(*nxt)?;
353        let term = fun::Term::Open { typ, var, bod: Box::new(nxt) };
354        wrap(nxt_pat, term, ask)
355      }
356      Stmt::Use { nam, val, nxt } => {
357        let (ask, nxt_pat, nxt) = take(*nxt)?;
358        let term = fun::Term::Use { nam: Some(nam), val: Box::new(val.to_fun()), nxt: Box::new(nxt) };
359        wrap(nxt_pat, term, ask)
360      }
361      Stmt::Return { term } => StmtToFun::Return(term.to_fun()),
362      Stmt::LocalDef { def, nxt } => {
363        let (ask, nxt_pat, nxt) = take(*nxt)?;
364        let def = def.to_fun().map_err(|e| e.display_only_messages().to_string())?;
365        let term = fun::Term::Def { def, nxt: Box::new(nxt) };
366        wrap(nxt_pat, term, ask)
367      }
368      Stmt::Err => unreachable!(),
369    };
370    Ok(stmt_to_fun)
371  }
372}
373
374impl Expr {
375  pub fn to_fun(self) -> fun::Term {
376    match self {
377      Expr::Era => fun::Term::Era,
378      Expr::Var { nam } => fun::Term::Var { nam },
379      Expr::Chn { nam } => fun::Term::Link { nam },
380      Expr::Num { val } => fun::Term::Num { val },
381      Expr::Call { fun, args, kwargs } => {
382        assert!(kwargs.is_empty());
383        let args = args.into_iter().map(Self::to_fun);
384        fun::Term::call(fun.to_fun(), args)
385      }
386      Expr::Lam { names, bod } => names.into_iter().rfold(bod.to_fun(), |acc, (name, link)| fun::Term::Lam {
387        tag: fun::Tag::Static,
388        pat: Box::new(if link { fun::Pattern::Chn(name) } else { fun::Pattern::Var(Some(name)) }),
389        bod: Box::new(acc),
390      }),
391      Expr::Opr { op, lhs, rhs } => {
392        fun::Term::Oper { opr: op, fst: Box::new(lhs.to_fun()), snd: Box::new(rhs.to_fun()) }
393      }
394      Expr::Str { val } => fun::Term::Str { val },
395      Expr::Lst { els } => fun::Term::List { els: els.into_iter().map(Self::to_fun).collect() },
396      Expr::Tup { els } => fun::Term::Fan {
397        fan: fun::FanKind::Tup,
398        tag: fun::Tag::Static,
399        els: els.into_iter().map(Self::to_fun).collect(),
400      },
401      Expr::Sup { els } => fun::Term::Fan {
402        fan: fun::FanKind::Dup,
403        tag: fun::Tag::Auto,
404        els: els.into_iter().map(Self::to_fun).collect(),
405      },
406      Expr::Ctr { name, args, kwargs } => {
407        assert!(kwargs.is_empty());
408        let args = args.into_iter().map(Self::to_fun);
409        fun::Term::call(fun::Term::Var { nam: name }, args)
410      }
411      Expr::LstMap { term, bind, iter, cond } => {
412        const ITER_TAIL: &str = "%iter.tail";
413        const ITER_HEAD: &str = "%iter.head";
414
415        let cons_branch = fun::Term::call(
416          fun::Term::r#ref(LCONS),
417          [term.to_fun(), fun::Term::Var { nam: Name::new(ITER_TAIL) }],
418        );
419        let cons_branch = if let Some(cond) = cond {
420          fun::Term::Swt {
421            arg: Box::new(cond.to_fun()),
422            bnd: Some(Name::new("%comprehension")),
423            with_bnd: vec![],
424            with_arg: vec![],
425            pred: Some(Name::new("%comprehension-1")),
426            arms: vec![fun::Term::Var { nam: Name::new(ITER_TAIL) }, cons_branch],
427          }
428        } else {
429          cons_branch
430        };
431        let cons_branch = fun::Term::Let {
432          pat: Box::new(fun::Pattern::Var(Some(bind))),
433          val: Box::new(fun::Term::Var { nam: Name::new(ITER_HEAD) }),
434          nxt: Box::new(cons_branch),
435        };
436
437        fun::Term::Fold {
438          bnd: Some(Name::new("%iter")),
439          arg: Box::new(iter.to_fun()),
440          with_bnd: vec![],
441          with_arg: vec![],
442          arms: vec![
443            (Some(Name::new(LNIL)), vec![], fun::Term::r#ref(LNIL)),
444            (Some(Name::new(LCONS)), vec![], cons_branch),
445          ],
446        }
447      }
448      Expr::Map { entries } => map_init(entries),
449      Expr::MapGet { .. } => unreachable!(),
450      Expr::TreeNode { left, right } => {
451        let left = left.to_fun();
452        let right = right.to_fun();
453        fun::Term::call(fun::Term::r#ref("Tree/Node"), [left, right])
454      }
455      Expr::TreeLeaf { val } => {
456        let val = val.to_fun();
457        fun::Term::app(fun::Term::r#ref("Tree/Leaf"), val)
458      }
459    }
460  }
461}
462
463fn map_init(entries: Vec<(Expr, Expr)>) -> fun::Term {
464  let mut map = fun::Term::Ref { nam: fun::Name::new("Map/empty") };
465  for (key, value) in entries {
466    map =
467      fun::Term::call(fun::Term::Ref { nam: fun::Name::new("Map/set") }, [map, key.to_fun(), value.to_fun()]);
468  }
469  map
470}
471
472/// If the statement was a return, returns it, erroring if there is another after it.
473/// Otherwise, turns it into a 'let' and returns the next statement.
474fn wrap_nxt_assign_stmt(
475  term: fun::Term,
476  nxt: Option<Box<Stmt>>,
477  pat: Option<fun::Pattern>,
478  ask: bool,
479) -> Result<StmtToFun, String> {
480  if let Some(nxt) = nxt {
481    if let Some(pat) = pat {
482      let (ask_nxt, nxt_pat, nxt) = take(*nxt)?;
483      let term = if ask {
484        fun::Term::Ask { pat: Box::new(pat), val: Box::new(term), nxt: Box::new(nxt) }
485      } else {
486        fun::Term::Let { pat: Box::new(pat), val: Box::new(term), nxt: Box::new(nxt) }
487      };
488      Ok(wrap(nxt_pat, term, ask_nxt))
489    } else {
490      Err("Statement ends with return but is not at end of function.".to_string())?
491    }
492  } else if let Some(pat) = pat {
493    Ok(StmtToFun::Assign(ask, pat, term))
494  } else {
495    Ok(StmtToFun::Return(term))
496  }
497}