bend/fun/transform/
desugar_fold.rs1use std::collections::HashSet;
2
3use crate::{
4 diagnostics::Diagnostics,
5 fun::{Adts, Constructors, Ctx, Definition, Name, Pattern, Rule, Source, Term},
6 maybe_grow,
7};
8
9impl Ctx<'_> {
10 pub fn desugar_fold(&mut self) -> Result<(), Diagnostics> {
31 let mut new_defs = vec![];
32 for def in self.book.defs.values_mut() {
33 let mut fresh = 0;
34 for rule in def.rules.iter_mut() {
35 let mut ctx = DesugarFoldCtx {
36 def_name: &def.name,
37 fresh: &mut fresh,
38 new_defs: &mut new_defs,
39 ctrs: &self.book.ctrs,
40 adts: &self.book.adts,
41 source: def.source.clone(),
42 check: def.check,
43 };
44
45 let res = rule.body.desugar_fold(&mut ctx);
46 if let Err(e) = res {
47 self.info.add_function_error(e, def.name.clone(), def.source.clone());
48 }
49 }
50 }
51
52 self.book.defs.extend(new_defs.into_iter().map(|def| (def.name.clone(), def)));
53
54 self.info.fatal(())
55 }
56}
57
58struct DesugarFoldCtx<'a> {
59 pub def_name: &'a Name,
60 pub fresh: &'a mut usize,
61 pub new_defs: &'a mut Vec<Definition>,
62 pub ctrs: &'a Constructors,
63 pub adts: &'a Adts,
64 pub source: Source,
65 pub check: bool,
66}
67
68impl Term {
69 fn desugar_fold(&mut self, ctx: &mut DesugarFoldCtx<'_>) -> Result<(), String> {
70 maybe_grow(|| {
71 for child in self.children_mut() {
72 child.desugar_fold(ctx)?;
73 }
74
75 if let Term::Fold { .. } = self {
76 if self.has_unscoped_diff() {
78 return Err("Can't have non self-contained unscoped variables in a 'fold'".into());
79 }
80 let Term::Fold { bnd: _, arg, with_bnd, with_arg, arms } = self else { unreachable!() };
81
82 let mut free_vars = HashSet::new();
84 for arm in arms.iter() {
85 let mut arm_free_vars = arm.2.free_vars().into_keys().collect::<HashSet<_>>();
86 for field in arm.1.iter().flatten() {
87 arm_free_vars.remove(field);
88 }
89 free_vars.extend(arm_free_vars);
90 }
91 for var in with_bnd.iter().flatten() {
92 free_vars.remove(var);
93 }
94 let free_vars = free_vars.into_iter().collect::<Vec<_>>();
95
96 let new_nam = Name::new(format!("{}__fold{}", ctx.def_name, ctx.fresh));
97 *ctx.fresh += 1;
98
99 let ctr = arms[0].0.as_ref().unwrap();
101 let adt_nam = ctx.ctrs.get(ctr).unwrap();
102 let ctrs = &ctx.adts.get(adt_nam).unwrap().ctrs;
103 for arm in arms.iter_mut() {
104 let ctr = arm.0.as_ref().unwrap();
105 let recursive = arm
106 .1
107 .iter()
108 .zip(&ctrs.get(ctr).unwrap().fields)
109 .filter_map(|(var, field)| if field.rec { Some(var.as_ref().unwrap().clone()) } else { None })
110 .collect::<HashSet<_>>();
111 arm.2.call_recursive(&new_nam, &recursive, &free_vars);
112 }
113
114 let x_nam = Name::new("%x");
116 let body = Term::Mat {
117 arg: Box::new(Term::Var { nam: x_nam.clone() }),
118 bnd: None,
119 with_bnd: with_bnd.clone(),
120 with_arg: with_bnd.iter().map(|nam| Term::var_or_era(nam.clone())).collect(),
121 arms: std::mem::take(arms),
122 };
123 let body = Term::rfold_lams(body, with_bnd.iter().cloned());
124 let body = Term::rfold_lams(body, free_vars.iter().map(|nam| Some(nam.clone())));
125 let body = Term::lam(Pattern::Var(Some(x_nam)), body);
126
127 let def = Definition::new_gen(
128 new_nam.clone(),
129 vec![Rule { pats: vec![], body }],
130 ctx.source.clone(),
131 ctx.check,
132 );
133 ctx.new_defs.push(def);
134
135 let call = Term::call(Term::Ref { nam: new_nam.clone() }, [std::mem::take(arg.as_mut())]);
137 let call = Term::call(call, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
138 let call = Term::call(call, with_arg.iter().cloned());
139 *self = call;
140 }
141 Ok(())
142 })
143 }
144
145 fn call_recursive(&mut self, def_name: &Name, recursive: &HashSet<Name>, free_vars: &[Name]) {
146 maybe_grow(|| {
147 for child in self.children_mut() {
148 child.call_recursive(def_name, recursive, free_vars);
149 }
150
151 if let Term::Var { nam } = self {
153 if recursive.contains(nam) {
154 let call = Term::call(Term::Ref { nam: def_name.clone() }, [std::mem::take(self)]);
155 let call = Term::call(call, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
156 *self = call;
157 }
158 }
159 })
160 }
161}