1use crate::{
2 fun::{Book, Name, Pattern, Term},
3 maybe_grow,
4};
5use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};
6
7impl Book {
10 pub fn linearize_match_binds(&mut self) {
27 for def in self.defs.values_mut() {
28 for rule in def.rules.iter_mut() {
29 rule.body.linearize_match_binds();
30 }
31 }
32 }
33}
34
35impl Term {
36 pub fn linearize_match_binds(&mut self) {
39 self.linearize_match_binds_go(vec![]);
40 }
41
42 fn linearize_match_binds_go(&mut self, mut bind_terms: Vec<Term>) {
43 maybe_grow(|| match self {
44 Term::Lam { pat, bod, .. } if !pat.has_unscoped() => {
47 let bod = std::mem::take(bod.as_mut());
48 let term = std::mem::replace(self, bod);
49 bind_terms.push(term);
50 self.linearize_match_binds_go(bind_terms);
51 }
52 Term::Let { val, nxt, .. } | Term::Use { val, nxt, .. } => {
53 val.linearize_match_binds_go(vec![]);
54 if val.has_unscoped() {
55 nxt.linearize_match_binds_go(vec![]);
57 self.wrap_with_bind_terms(bind_terms);
58 } else {
59 let nxt = std::mem::take(nxt.as_mut());
60 let term = std::mem::replace(self, nxt);
61 bind_terms.push(term);
62 self.linearize_match_binds_go(bind_terms);
63 }
64 }
65
66 Term::Mat { .. } | Term::Swt { .. } => {
68 self.linearize_binds_single_match(bind_terms);
69 }
70
71 term => {
74 for child in term.children_mut() {
75 child.linearize_match_binds_go(vec![]);
76 }
77 term.wrap_with_bind_terms(bind_terms);
79 }
80 })
81 }
82
83 fn linearize_binds_single_match(&mut self, mut bind_terms: Vec<Term>) {
84 let (used_vars, with_bnd, with_arg, arms) = match self {
85 Term::Mat { arg, bnd: _, with_bnd, with_arg, arms } => {
86 let vars = arg.free_vars().into_keys().collect::<HashSet<_>>();
87 let arms = arms.iter_mut().map(|arm| &mut arm.2).collect::<Vec<_>>();
88 (vars, with_bnd, with_arg, arms)
89 }
90 Term::Swt { arg, bnd: _, with_bnd, with_arg, pred: _, arms } => {
91 let vars = arg.free_vars().into_keys().collect::<HashSet<_>>();
92 let arms = arms.iter_mut().collect();
93 (vars, with_bnd, with_arg, arms)
94 }
95 _ => unreachable!(),
96 };
97
98 for (bnd, arg) in with_bnd.iter().zip(with_arg.iter()) {
100 let term = Term::Let {
101 pat: Box::new(Pattern::Var(bnd.clone())),
102 val: Box::new(arg.clone()),
103 nxt: Box::new(Term::Err),
104 };
105 bind_terms.push(term)
106 }
107
108 let (mut non_linearized, linearized) = fixed_and_linearized_terms(used_vars, bind_terms);
109
110 for arm in arms {
112 arm.wrap_with_bind_terms(linearized.clone());
113 arm.linearize_match_binds_go(vec![]);
114 }
115
116 let linearized_binds = linearized
118 .iter()
119 .flat_map(|t| match t {
120 Term::Lam { pat, .. } | Term::Let { pat, .. } => pat.binds().flatten().cloned().collect::<Vec<_>>(),
121 Term::Use { nam, .. } => {
122 if let Some(nam) = nam {
123 vec![nam.clone()]
124 } else {
125 vec![]
126 }
127 }
128 _ => unreachable!(),
129 })
130 .collect::<BTreeSet<_>>();
131 update_with_clause(with_bnd, with_arg, &linearized_binds);
132
133 non_linearized.retain(|term| {
136 if let Term::Let { pat, .. } = term {
137 if let Pattern::Var(bnd) = pat.as_ref() {
138 if with_bnd.contains(bnd) {
139 return false;
140 }
141 }
142 }
143 true
144 });
145
146 self.wrap_with_bind_terms(non_linearized);
148 }
149
150 fn wrap_with_bind_terms(
164 &mut self,
165 bind_terms: impl IntoIterator<IntoIter = impl DoubleEndedIterator<Item = Term>>,
166 ) {
167 *self = bind_terms.into_iter().rfold(std::mem::take(self), |acc, mut term| {
168 match &mut term {
169 Term::Lam { bod: nxt, .. } | Term::Let { nxt, .. } | Term::Use { nxt, .. } => {
170 *nxt.as_mut() = acc;
171 }
172 _ => unreachable!(),
173 }
174 term
175 });
176 }
177}
178
179fn fixed_and_linearized_terms(used_in_arg: HashSet<Name>, bind_terms: Vec<Term>) -> (Vec<Term>, Vec<Term>) {
208 let fixed_binds = binds_fixed_by_dependency(used_in_arg, &bind_terms);
209
210 let mut fixed = VecDeque::new();
211 let mut linearized = VecDeque::new();
212 let mut stop = false;
213 for term in bind_terms.into_iter().rev() {
214 let to_linearize = match &term {
215 Term::Use { nam, .. } => nam.as_ref().map_or(true, |nam| !fixed_binds.contains(nam)),
216 Term::Let { pat, .. } => pat.binds().flatten().all(|nam| !fixed_binds.contains(nam)),
217 Term::Lam { pat, .. } => pat.binds().flatten().all(|nam| !fixed_binds.contains(nam)),
218 _ => unreachable!(),
219 };
220 let to_linearize = to_linearize && !stop;
221 if to_linearize {
222 linearized.push_front(term);
223 } else {
224 if matches!(term, Term::Lam { .. }) {
225 stop = true;
226 }
227 fixed.push_front(term);
228 }
229 }
230 (fixed.into_iter().collect(), linearized.into_iter().collect())
231}
232
233fn binds_fixed_by_dependency(used_in_arg: HashSet<Name>, bind_terms: &[Term]) -> HashSet<Name> {
236 let mut fixed_binds = used_in_arg;
237
238 let mut binds = vec![];
240 let mut dependency_digraph = HashMap::new();
241 for term in bind_terms {
242 let (term_binds, term_uses) = match term {
244 Term::Lam { pat, .. } => {
245 let binds = pat.binds().flatten().cloned().collect::<Vec<_>>();
246 (binds, vec![])
247 }
248 Term::Let { pat, val, .. } => {
249 let binds = pat.binds().flatten().cloned().collect::<Vec<_>>();
250 let uses = val.free_vars().into_keys().collect();
251 (binds, uses)
252 }
253 Term::Use { nam, val, .. } => {
254 let binds = if let Some(nam) = nam { vec![nam.clone()] } else { vec![] };
255 let uses = val.free_vars().into_keys().collect();
256 (binds, uses)
257 }
258 _ => unreachable!(),
259 };
260
261 for bind in term_binds {
262 dependency_digraph.insert(bind.clone(), term_uses.clone());
263 binds.push(bind);
264 }
265 }
266
267 for (bind, deps) in dependency_digraph.iter() {
269 if deps.iter().any(|dep| !binds.contains(dep)) {
270 fixed_binds.insert(bind.clone());
271 }
272 }
273
274 let mut dependency_graph: HashMap<Name, HashSet<Name>> =
276 HashMap::from_iter(binds.iter().map(|k| (k.clone(), HashSet::new())));
277 for (bind, deps) in dependency_digraph {
278 for dep in deps {
279 if !binds.contains(&dep) {
280 dependency_graph.insert(dep.clone(), HashSet::new());
281 }
282 dependency_graph.get_mut(&dep).unwrap().insert(bind.clone());
283 dependency_graph.get_mut(&bind).unwrap().insert(dep);
284 }
285 }
286
287 let mut used_component = HashSet::new();
289 let mut visited = HashSet::new();
290 let mut to_visit = fixed_binds.iter().collect::<Vec<_>>();
291 while let Some(node) = to_visit.pop() {
292 if visited.contains(node) {
293 continue;
294 }
295 used_component.insert(node.clone());
296 visited.insert(node);
297
298 if let Some(deps) = dependency_graph.get(node) {
300 to_visit.extend(deps);
301 }
302 }
303
304 let mut fixed_start = false;
306 let mut fixed_lams = HashSet::new();
307 for term in bind_terms.iter().rev() {
308 if let Term::Lam { pat, .. } = term {
309 if pat.binds().flatten().any(|p| used_component.contains(p)) {
310 fixed_start = true;
311 }
312 if fixed_start {
313 for bind in pat.binds().flatten() {
314 fixed_lams.insert(bind.clone());
315 }
316 }
317 }
318 }
319
320 let mut fixed_binds = used_component;
321
322 let mut visited = HashSet::new();
324 let mut to_visit = fixed_lams.iter().collect::<Vec<_>>();
325 while let Some(node) = to_visit.pop() {
326 if visited.contains(node) {
327 continue;
328 }
329 fixed_binds.insert(node.clone());
330 visited.insert(node);
331
332 if let Some(deps) = dependency_graph.get(node) {
334 to_visit.extend(deps);
335 }
336 }
337
338 fixed_binds
339}
340
341fn update_with_clause(
342 with_bnd: &mut Vec<Option<Name>>,
343 with_arg: &mut Vec<Term>,
344 vars_to_lift: &BTreeSet<Name>,
345) {
346 let mut to_remove = Vec::new();
347 for i in 0..with_bnd.len() {
348 if let Some(with_bnd) = &with_bnd[i] {
349 if vars_to_lift.contains(with_bnd) {
350 to_remove.push(i);
351 }
352 }
353 }
354 for (removed, to_remove) in to_remove.into_iter().enumerate() {
355 with_bnd.remove(to_remove - removed);
356 with_arg.remove(to_remove - removed);
357 }
358}
359impl Book {
362 pub fn linearize_matches(&mut self) {
364 for def in self.defs.values_mut() {
365 for rule in def.rules.iter_mut() {
366 rule.body.linearize_matches();
367 }
368 }
369 }
370}
371
372impl Term {
373 fn linearize_matches(&mut self) {
374 maybe_grow(|| {
375 for child in self.children_mut() {
376 child.linearize_matches();
377 }
378
379 if matches!(self, Term::Mat { .. } | Term::Swt { .. }) {
380 lift_match_vars(self);
381 }
382 })
383 }
384}
385
386pub fn lift_match_vars(match_term: &mut Term) -> &mut Term {
394 let (with_bnd, with_arg, arms) = match match_term {
396 Term::Mat { arg: _, bnd: _, with_bnd, with_arg, arms: rules } => {
397 let args =
398 rules.iter().map(|(_, binds, body)| (binds.iter().flatten().cloned().collect(), body)).collect();
399 (with_bnd.clone(), with_arg.clone(), args)
400 }
401 Term::Swt { arg: _, bnd: _, with_bnd, with_arg, pred, arms } => {
402 let (succ, nums) = arms.split_last_mut().unwrap();
403 let mut arms = nums.iter().map(|body| (vec![], body)).collect::<Vec<_>>();
404 arms.push((vec![pred.clone().unwrap()], succ));
405 (with_bnd.clone(), with_arg.clone(), arms)
406 }
407 _ => unreachable!(),
408 };
409
410 let mut free_vars = Vec::<Vec<_>>::new();
412 for (binds, body) in arms {
413 let mut arm_free_vars = body.free_vars();
414 for bind in binds {
415 arm_free_vars.shift_remove(&bind);
416 }
417 free_vars.push(arm_free_vars.into_keys().collect());
418 }
419
420 let vars_to_lift: BTreeSet<Name> = free_vars.into_iter().flatten().collect();
423
424 match match_term {
426 Term::Mat { arg: _, bnd: _, with_bnd, with_arg, arms } => {
427 update_with_clause(with_bnd, with_arg, &vars_to_lift);
428 for arm in arms {
429 let old_body = std::mem::take(&mut arm.2);
430 arm.2 = Term::rfold_lams(old_body, vars_to_lift.iter().cloned().map(Some));
431 }
432 }
433 Term::Swt { arg: _, bnd: _, with_bnd, with_arg, pred: _, arms } => {
434 update_with_clause(with_bnd, with_arg, &vars_to_lift);
435 for arm in arms {
436 let old_body = std::mem::take(arm);
437 *arm = Term::rfold_lams(old_body, vars_to_lift.iter().cloned().map(Some));
438 }
439 }
440 _ => unreachable!(),
441 }
442
443 let args = vars_to_lift
445 .into_iter()
446 .map(|nam| {
447 if let Some(idx) = with_bnd.iter().position(|x| x == &nam) {
448 with_arg[idx].clone()
449 } else {
450 Term::Var { nam }
451 }
452 })
453 .collect::<Vec<_>>();
454 let term = Term::call(std::mem::take(match_term), args);
455 *match_term = term;
456
457 get_match_reference(match_term)
458}
459
460fn get_match_reference(mut match_term: &mut Term) -> &mut Term {
464 loop {
465 match match_term {
466 Term::App { tag: _, fun, arg: _ } => match_term = fun.as_mut(),
467 Term::Swt { .. } | Term::Mat { .. } => return match_term,
468 _ => unreachable!(),
469 }
470 }
471}
472
473impl Book {
476 pub fn linearize_match_with(&mut self) {
478 for def in self.defs.values_mut() {
479 for rule in def.rules.iter_mut() {
480 rule.body.linearize_match_with();
481 }
482 }
483 }
484}
485
486impl Term {
487 fn linearize_match_with(&mut self) {
488 maybe_grow(|| {
489 for child in self.children_mut() {
490 child.linearize_match_with();
491 }
492 });
493 match self {
494 Term::Mat { arg: _, bnd: _, with_bnd, with_arg, arms } => {
495 for rule in arms {
496 rule.2 = Term::rfold_lams(std::mem::take(&mut rule.2), with_bnd.clone().into_iter());
497 }
498 *with_bnd = vec![];
499 let call_args = std::mem::take(with_arg).into_iter();
500 *self = Term::call(std::mem::take(self), call_args);
501 }
502 Term::Swt { arg: _, bnd: _, with_bnd, with_arg, pred: _, arms } => {
503 for rule in arms {
504 *rule = Term::rfold_lams(std::mem::take(rule), with_bnd.clone().into_iter());
505 }
506 *with_bnd = vec![];
507 let call_args = std::mem::take(with_arg).into_iter();
508 *self = Term::call(std::mem::take(self), call_args);
509 }
510 _ => {}
511 }
512 }
513}