1use crate::{
2 fun::{Book, FanKind, Name, Pattern, Tag, Term},
3 maybe_grow, multi_iterator,
4};
5use std::collections::HashMap;
6
7impl Book {
20 pub fn linearize_vars(&mut self) {
21 for def in self.defs.values_mut() {
22 def.rule_mut().body.linearize_vars();
23 }
24 }
25}
26
27impl Term {
28 pub fn linearize_vars(&mut self) {
29 term_to_linear(self, &mut HashMap::new());
30 }
31}
32
/// Rewrites `term` in place so that every bound variable is used at most once.
///
/// Occurrences are counted per name in `var_uses`. Each `Term::Var` visit
/// increments its name's count and renames that occurrence with [`dup_name`].
/// The first occurrence therefore keeps its name, and later ones become
/// `x_2`, `x_3`, ...
///
/// After the scope of a binder has been traversed, the binder is fixed up
/// according to how many occurrences were counted:
/// - 0 uses: the binder is erased (`None` / `Pattern::Var(None)`);
/// - 1 use: the binder is kept. The exception is a `let` with a plain variable
///   pattern, which is inlined by substituting `val` into `nxt`;
/// - N > 1 uses: a dup pattern built by [`duplicate_pat`] introduces one copy
///   per renamed occurrence.
///
/// NOTE(review): `var_uses` is keyed only by name, and entries are never
/// removed or reset when a binder's scope ends. This appears to assume that
/// variable names are unique within the term (no shadowing or reuse) — confirm
/// that an earlier pass guarantees it.
fn term_to_linear(term: &mut Term, var_uses: &mut HashMap<Name, u64>) {
  // NOTE(review): presumably `maybe_grow` guards against stack overflow on
  // deeply nested terms — confirm.
  maybe_grow(|| {
    // A `let` with a plain variable pattern is handled specially, so that a
    // single use can be inlined instead of kept as a `let`.
    if let Term::Let { pat, val, nxt } = term {
      if let Pattern::Var(Some(nam)) = pat.as_ref() {
        // `nxt` is linearized before `val` (the reverse of source order), so the
        // number of uses of `nam` is known before deciding how to rewrite.
        // This also means that other variables occurring in both `nxt` and `val`
        // get their lower `dup_name` indices in `nxt`.
        term_to_linear(nxt, var_uses);

        let uses = get_var_uses(Some(nam), var_uses);
        term_to_linear(val, var_uses);
        match uses {
          // Unused: keep the `let` (and `val`) but erase its binder.
          0 => {
            let Term::Let { pat, .. } = term else { unreachable!() };
            **pat = Pattern::Var(None);
          }
          // Used once: substitute `val` for `nam` in `nxt`, then replace the
          // whole `let` with `nxt`.
          1 => {
            nxt.subst(nam, val.as_ref());
            *term = std::mem::take(nxt.as_mut());
          }
          // Used several times: turn the pattern into a dup of `uses` copies,
          // named the same way `dup_name` renamed each occurrence.
          _ => {
            let new_pat = duplicate_pat(nam, uses);
            // Re-destructure `term` to reach the `let` pattern and replace it.
            let Term::Let { pat, .. } = term else { unreachable!() };
            *pat = new_pat;
          }
        }
        return;
      }
    }
    // Variable occurrence: count it, then rename it after its occurrence index.
    if let Term::Var { nam } = term {
      let instantiated_count = var_uses.entry(nam.clone()).or_default();
      *instantiated_count += 1;
      *nam = dup_name(nam, *instantiated_count);
      return;
    }

    // Any other term: linearize each child, then fix up the binders that are
    // in scope in that child.
    for (child, binds) in term.children_mut_with_binds_mut() {
      term_to_linear(child, var_uses);

      for bind in binds {
        let uses = get_var_uses(bind.as_ref(), var_uses);
        match uses {
          // Erase binding
          0 => *bind = None,
          // Keep as-is
          1 => (),
          // Duplicate binding: wrap the child in a `Term::Let` whose dup
          // pattern (from `duplicate_pat`) takes the bound variable as value.
          uses => {
            debug_assert!(uses > 1);
            let nam = bind.as_ref().unwrap();
            *child = Term::Let {
              pat: duplicate_pat(nam, uses),
              val: Box::new(Term::Var { nam: nam.clone() }),
              nxt: Box::new(std::mem::take(child)),
            }
          }
        }
      }
    }
  })
}
94
95fn get_var_uses(nam: Option<&Name>, var_uses: &HashMap<Name, u64>) -> u64 {
96 nam.and_then(|nam| var_uses.get(nam).copied()).unwrap_or_default()
97}
98
99fn duplicate_pat(nam: &Name, uses: u64) -> Box<Pattern> {
100 Box::new(Pattern::Fan(
101 FanKind::Dup,
102 Tag::Auto,
103 (1..uses + 1).map(|i| Pattern::Var(Some(dup_name(nam, i)))).collect(),
104 ))
105}
106
107fn dup_name(nam: &Name, uses: u64) -> Name {
108 if uses == 1 {
109 nam.clone()
110 } else {
111 Name::new(format!("{nam}_{uses}"))
112 }
113}
114
impl Term {
  /// Returns mutable references to each direct child of this term. Each child
  /// is paired with an iterator over the variable binders in scope in it.
  ///
  /// Binders come from the `nam` of `Use` and from the patterns of `Let`, `Ask`
  /// and `Lam` (via `Pattern::binds_mut`). They are always paired with the
  /// child they scope over (`nxt` / `bod`), never with `val`. Every other child
  /// is paired with an empty binder iterator.
  ///
  /// # Panics
  ///
  /// Panics (via `unreachable!`) on `Mat`, `Fold`, `Bend`, `Open` and `Def`,
  /// which are expected to have been removed by earlier passes. In debug
  /// builds, it also asserts that a `Swt` has no `bnd`, `with_bnd`, `with_arg`
  /// or `pred`.
  pub fn children_mut_with_binds_mut(
    &mut self,
  ) -> impl DoubleEndedIterator<Item = (&mut Term, impl DoubleEndedIterator<Item = &mut Option<Name>>)> {
    // NOTE(review): `multi_iterator!` is a project macro. It presumably defines
    // an enum with one iterator type per listed variant, so every match arm
    // below can return the same opaque type — confirm.
    multi_iterator!(ChildrenIter { Zero, One, Two, Vec, Swt });
    multi_iterator!(BindsIter { Zero, One, Pat });
    match self {
      // The scrutinee followed by every arm. None of them carry binders here.
      Term::Swt { arg, bnd, with_bnd, with_arg, pred, arms } => {
        debug_assert!(bnd.is_none());
        debug_assert!(with_bnd.is_empty());
        debug_assert!(with_arg.is_empty());
        debug_assert!(pred.is_none());
        ChildrenIter::Swt(
          [(arg.as_mut(), BindsIter::Zero([]))]
            .into_iter()
            .chain(arms.iter_mut().map(|x| (x, BindsIter::Zero([])))),
        )
      }
      Term::Fan { els, .. } | Term::List { els } => {
        ChildrenIter::Vec(els.iter_mut().map(|el| (el, BindsIter::Zero([]))))
      }
      // `nam` is bound only in `nxt`.
      Term::Use { nam, val, nxt } => {
        ChildrenIter::Two([(val.as_mut(), BindsIter::Zero([])), (nxt.as_mut(), BindsIter::One([nam]))])
      }
      // The pattern's binders are in scope only in `nxt`.
      Term::Let { pat, val, nxt, .. } | Term::Ask { pat, val, nxt, .. } => ChildrenIter::Two([
        (val.as_mut(), BindsIter::Zero([])),
        (nxt.as_mut(), BindsIter::Pat(pat.binds_mut())),
      ]),
      Term::App { fun: fst, arg: snd, .. } | Term::Oper { fst, snd, .. } => {
        ChildrenIter::Two([(fst.as_mut(), BindsIter::Zero([])), (snd.as_mut(), BindsIter::Zero([]))])
      }
      Term::Lam { pat, bod, .. } => ChildrenIter::One([(bod.as_mut(), BindsIter::Pat(pat.binds_mut()))]),
      Term::With { bod, .. } => ChildrenIter::One([(bod.as_mut(), BindsIter::Zero([]))]),
      // Leaf terms have no children.
      Term::Var { .. }
      | Term::Link { .. }
      | Term::Num { .. }
      | Term::Nat { .. }
      | Term::Str { .. }
      | Term::Ref { .. }
      | Term::Era
      | Term::Err => ChildrenIter::Zero([]),
      Term::Mat { .. } => unreachable!("'match' should be removed in earlier pass"),
      Term::Fold { .. } => unreachable!("'fold' should be removed in earlier pass"),
      Term::Bend { .. } => unreachable!("'bend' should be removed in earlier pass"),
      Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
      Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
    }
  }
}