1use crate::{
2 fun::{Book, Definition, Name, Pattern, Rule, Source, Term},
3 maybe_grow, multi_iterator,
4};
5use std::collections::{BTreeMap, HashSet};
6
/// Separator placed between a definition's name and the index of a combinator floated
/// out of it, producing generated names of the form `{def_name}__C{index}`.
pub const NAME_SEP: &str = "__C";
8
9impl Book {
10 pub fn float_combinators(&mut self, max_size: usize) {
34 let book = self.clone();
35 let mut ctx = FloatCombinatorsCtx::new(&book, max_size);
36
37 for (def_name, def) in self.defs.iter_mut() {
38 if let Some(main) = self.entrypoint.as_ref() {
42 if def_name == main {
43 continue;
44 }
45 }
46
47 let source = def.source.clone();
48 let check = def.check;
49 let body = &mut def.rule_mut().body;
50 ctx.reset();
51 ctx.def_size = body.size();
52 body.float_combinators(&mut ctx, def_name, source, check);
53 }
54
55 self.defs.extend(ctx.combinators.into_iter().map(|(nam, (_, def))| (nam, def)));
56 }
57}
58
59struct FloatCombinatorsCtx<'b> {
60 pub combinators: BTreeMap<Name, (bool, Definition)>,
61 pub name_gen: usize,
62 pub seen: HashSet<Name>,
63 pub book: &'b Book,
64 pub max_size: usize,
65 pub def_size: usize,
66}
67
68impl<'b> FloatCombinatorsCtx<'b> {
69 fn new(book: &'b Book, max_size: usize) -> Self {
70 Self {
71 combinators: Default::default(),
72 name_gen: 0,
73 seen: Default::default(),
74 book,
75 max_size,
76 def_size: 0,
77 }
78 }
79
80 fn reset(&mut self) {
81 self.def_size = 0;
82 self.name_gen = 0;
83 self.seen = Default::default();
84 }
85}
86
87impl Term {
88 fn float_combinators(
89 &mut self,
90 ctx: &mut FloatCombinatorsCtx,
91 def_name: &Name,
92 source: Source,
93 check: bool,
94 ) {
95 maybe_grow(|| {
96 for child in self.float_children_mut() {
98 child.float_combinators(ctx, def_name, source.clone(), check);
99 }
100
101 let mut size = self.size();
102 let is_combinator = self.is_combinator();
103
104 for child in self.float_children_mut() {
106 let child_is_safe = child.is_safe(ctx);
107 let child_size = child.size();
108
109 let extract_for_size = if is_combinator { size > ctx.max_size } else { ctx.def_size > ctx.max_size };
110
111 if child.is_combinator() && child_size > 0 && (!child_is_safe || extract_for_size) {
112 ctx.def_size -= child_size;
113 size -= child_size;
114 child.float(ctx, def_name, source.clone(), check, child_is_safe);
115 }
116 }
117 })
118 }
119
120 fn float(
122 &mut self,
123 ctx: &mut FloatCombinatorsCtx,
124 def_name: &Name,
125 source: Source,
126 check: bool,
127 is_safe: bool,
128 ) {
129 let comb_name = Name::new(format!("{}{}{}", def_name, NAME_SEP, ctx.name_gen));
130 ctx.name_gen += 1;
131
132 let comb_ref = Term::Ref { nam: comb_name.clone() };
133 let extracted_term = std::mem::replace(self, comb_ref);
134
135 let rules = vec![Rule { body: extracted_term, pats: Vec::new() }];
136 let rule = Definition::new_gen(comb_name.clone(), rules, source, check);
137 ctx.combinators.insert(comb_name, (is_safe, rule));
138 }
139}
140
141impl Term {
142 fn is_safe(&self, ctx: &mut FloatCombinatorsCtx) -> bool {
151 maybe_grow(|| match self {
152 Term::Num { .. }
153 | Term::Era
154 | Term::Err
155 | Term::Fan { .. }
156 | Term::App { .. }
157 | Term::Oper { .. }
158 | Term::Swt { .. } => self.children().all(|c| c.is_safe(ctx)),
159 Term::Lam { .. } => self.is_safe_lambda(ctx),
160 Term::Ref { nam } => {
161 if ctx.book.ctrs.contains_key(nam) {
163 return true;
164 }
165 if ctx.seen.contains(nam) {
167 return false;
168 }
169 ctx.seen.insert(nam.clone());
170
171 let safe = if let Some(def) = ctx.book.defs.get(nam) {
173 def.rule().body.is_safe(ctx)
174 } else if let Some((safe, _)) = ctx.combinators.get(nam) {
175 *safe
176 } else {
177 false
178 };
179
180 ctx.seen.remove(nam);
181 safe
182 }
183 _ => false,
186 })
187 }
188
189 fn is_safe_lambda(&self, ctx: &mut FloatCombinatorsCtx) -> bool {
193 let mut current = self;
194 let mut scope = Vec::new();
195
196 while let Term::Lam { pat, bod, .. } = current {
197 scope.extend(pat.binds().filter_map(|x| x.as_ref()));
198 current = bod;
199 }
200
201 match current {
202 Term::Var { nam } if scope.contains(&nam) => true,
203 Term::Ref { .. } => true,
204 term => term.is_safe(ctx),
205 }
206 }
207
208 pub fn has_unscoped_diff(&self) -> bool {
209 let (declared, used) = self.unscoped_vars();
210 declared.difference(&used).count() != 0 || used.difference(&declared).count() != 0
211 }
212
213 fn is_combinator(&self) -> bool {
214 self.free_vars().is_empty() && !self.has_unscoped_diff() && !matches!(self, Term::Ref { .. })
215 }
216
217 fn base_size(&self) -> usize {
218 match self {
219 Term::Let { pat, .. } => pat.size(),
220 Term::Fan { els, .. } => els.len() - 1,
221 Term::Mat { arms, .. } => arms.len(),
222 Term::Swt { arms, .. } => 2 * (arms.len() - 1),
223 Term::Lam { .. } => 1,
224 Term::App { .. } => 1,
225 Term::Oper { .. } => 1,
226 Term::Var { .. } => 0,
227 Term::Link { .. } => 0,
228 Term::Use { .. } => 0,
229 Term::Num { .. } => 0,
230 Term::Ref { .. } => 0,
231 Term::Era => 0,
232 Term::Bend { .. }
233 | Term::Fold { .. }
234 | Term::Nat { .. }
235 | Term::Str { .. }
236 | Term::List { .. }
237 | Term::With { .. }
238 | Term::Ask { .. }
239 | Term::Open { .. }
240 | Term::Def { .. }
241 | Term::Err => unreachable!(),
242 }
243 }
244
245 fn size(&self) -> usize {
246 maybe_grow(|| {
247 let children_size: usize = self.children().map(|c| c.size()).sum();
248 self.base_size() + children_size
249 })
250 }
251
252 pub fn float_children_mut(&mut self) -> impl Iterator<Item = &mut Term> {
253 multi_iterator!(FloatIter { Zero, Two, Vec, Mat, App, Swt });
254 match self {
255 Term::App { .. } => {
256 let mut next = Some(self);
257 FloatIter::App(std::iter::from_fn(move || {
258 let cur = next.take();
259 if let Some(Term::App { fun, arg, .. }) = cur {
260 next = Some(&mut *fun);
261 Some(&mut **arg)
262 } else {
263 cur
264 }
265 }))
266 }
267 Term::Mat { arg, bnd: _, with_bnd: _, with_arg, arms } => FloatIter::Mat(
268 [arg.as_mut()].into_iter().chain(with_arg.iter_mut()).chain(arms.iter_mut().map(|r| &mut r.2)),
269 ),
270 Term::Swt { arg, bnd: _, with_bnd: _, with_arg, pred: _, arms } => {
271 FloatIter::Swt([arg.as_mut()].into_iter().chain(with_arg.iter_mut()).chain(arms.iter_mut()))
272 }
273 Term::Fan { els, .. } | Term::List { els } => FloatIter::Vec(els),
274 Term::Let { val: fst, nxt: snd, .. }
275 | Term::Use { val: fst, nxt: snd, .. }
276 | Term::Oper { fst, snd, .. } => FloatIter::Two([fst.as_mut(), snd.as_mut()]),
277 Term::Lam { bod, .. } => bod.float_children_mut(),
278 Term::Var { .. }
279 | Term::Link { .. }
280 | Term::Num { .. }
281 | Term::Nat { .. }
282 | Term::Str { .. }
283 | Term::Ref { .. }
284 | Term::Era
285 | Term::Err => FloatIter::Zero([]),
286 Term::With { .. }
287 | Term::Ask { .. }
288 | Term::Bend { .. }
289 | Term::Fold { .. }
290 | Term::Open { .. }
291 | Term::Def { .. } => {
292 unreachable!()
293 }
294 }
295 }
296}
297
298impl Pattern {
299 fn size(&self) -> usize {
300 match self {
301 Pattern::Var(_) => 0,
302 Pattern::Chn(_) => 0,
303 Pattern::Fan(_, _, pats) => pats.len() - 1 + pats.iter().map(|p| p.size()).sum::<usize>(),
304
305 Pattern::Num(_) | Pattern::Lst(_) | Pattern::Str(_) | Pattern::Ctr(_, _) => unreachable!(),
306 }
307 }
308}