1use crate::fun::Name;
2
3use super::{AssignPattern, Definition, Expr, Stmt};
4
5impl Definition {
6 pub fn gen_map_get(&mut self) {
11 self.body.gen_map_get(&mut 0);
12 }
13}
14
15impl Stmt {
16 fn gen_map_get(&mut self, id: &mut usize) {
17 match self {
18 Stmt::LocalDef { def, nxt } => {
19 nxt.gen_map_get(id);
20 def.gen_map_get()
21 }
22 Stmt::Assign { pat, val, nxt } => {
23 let key_substitutions =
24 if let AssignPattern::MapSet(_, key) = pat { key.substitute_map_gets(id) } else { Vec::new() };
25
26 if let Some(nxt) = nxt {
27 nxt.gen_map_get(id);
28 }
29
30 let substitutions = val.substitute_map_gets(id);
31 if !substitutions.is_empty() {
32 *self = gen_get(self, substitutions);
33 }
34
35 if !key_substitutions.is_empty() {
36 *self = gen_get(self, key_substitutions);
37 }
38 }
39 Stmt::Ask { pat: _, val, nxt } => {
40 if let Some(nxt) = nxt {
41 nxt.gen_map_get(id);
42 }
43 let substitutions = val.substitute_map_gets(id);
44 if !substitutions.is_empty() {
45 *self = gen_get(self, substitutions);
46 }
47 }
48 Stmt::InPlace { op: _, pat, val, nxt } => {
49 let key_substitutions = if let AssignPattern::MapSet(_, key) = &mut **pat {
50 key.substitute_map_gets(id)
51 } else {
52 Vec::new()
53 };
54
55 nxt.gen_map_get(id);
56
57 let substitutions = val.substitute_map_gets(id);
58 if !substitutions.is_empty() {
59 *self = gen_get(self, substitutions);
60 }
61
62 if !key_substitutions.is_empty() {
63 *self = gen_get(self, key_substitutions);
64 }
65 }
66 Stmt::If { cond, then, otherwise, nxt } => {
67 then.gen_map_get(id);
68 otherwise.gen_map_get(id);
69 if let Some(nxt) = nxt {
70 nxt.gen_map_get(id);
71 }
72 let substitutions = cond.substitute_map_gets(id);
73 if !substitutions.is_empty() {
74 *self = gen_get(self, substitutions);
75 }
76 }
77 Stmt::Match { bnd: _, arg, with_bnd: _, with_arg, arms, nxt }
78 | Stmt::Fold { bnd: _, arg, arms, with_bnd: _, with_arg, nxt } => {
79 for arm in arms.iter_mut() {
80 arm.rgt.gen_map_get(id);
81 }
82 if let Some(nxt) = nxt {
83 nxt.gen_map_get(id);
84 }
85 let mut substitutions = arg.substitute_map_gets(id);
86 for arg in with_arg {
87 substitutions.extend(arg.substitute_map_gets(id));
88 }
89 if !substitutions.is_empty() {
90 *self = gen_get(self, substitutions);
91 }
92 }
93 Stmt::Switch { bnd: _, arg, with_bnd: _, with_arg, arms, nxt } => {
94 for arm in arms.iter_mut() {
95 arm.gen_map_get(id);
96 }
97 if let Some(nxt) = nxt {
98 nxt.gen_map_get(id);
99 }
100 let mut substitutions = arg.substitute_map_gets(id);
101 for arg in with_arg {
102 substitutions.extend(arg.substitute_map_gets(id));
103 }
104 if !substitutions.is_empty() {
105 *self = gen_get(self, substitutions);
106 }
107 }
108 Stmt::Bend { bnd: _, arg: init, cond, step, base, nxt } => {
109 step.gen_map_get(id);
110 base.gen_map_get(id);
111 if let Some(nxt) = nxt {
112 nxt.gen_map_get(id);
113 }
114 let mut substitutions = cond.substitute_map_gets(id);
115 for init in init {
116 substitutions.extend(init.substitute_map_gets(id));
117 }
118 if !substitutions.is_empty() {
119 *self = gen_get(self, substitutions);
120 }
121 }
122 Stmt::With { typ: _, bod, nxt } => {
123 bod.gen_map_get(id);
124 if let Some(nxt) = nxt {
125 nxt.gen_map_get(id);
126 }
127 }
128 Stmt::Return { term } => {
129 let substitutions = term.substitute_map_gets(id);
130 if !substitutions.is_empty() {
131 *self = gen_get(self, substitutions);
132 }
133 }
134 Stmt::Open { typ: _, var: _, nxt } => {
135 nxt.gen_map_get(id);
136 }
137 Stmt::Use { nam: _, val: bod, nxt } => {
138 nxt.gen_map_get(id);
139 let substitutions = bod.substitute_map_gets(id);
140 if !substitutions.is_empty() {
141 *self = gen_get(self, substitutions);
142 }
143 }
144 Stmt::Err => {}
145 }
146 }
147}
148
149type Substitutions = Vec<(Name, Name, Box<Expr>)>;
150
151impl Expr {
152 fn substitute_map_gets(&mut self, id: &mut usize) -> Substitutions {
153 fn go(e: &mut Expr, substitutions: &mut Substitutions, id: &mut usize) {
154 match e {
155 Expr::MapGet { nam, key } => {
156 go(key, substitutions, id);
157 let new_var = gen_map_var(id);
158 substitutions.push((new_var.clone(), nam.clone(), key.clone()));
159 *e = Expr::Var { nam: new_var };
160 }
161 Expr::Call { fun, args, kwargs } => {
162 go(fun, substitutions, id);
163 for arg in args {
164 go(arg, substitutions, id);
165 }
166 for (_, arg) in kwargs {
167 go(arg, substitutions, id);
168 }
169 }
170 Expr::Lam { bod, .. } => {
171 go(bod, substitutions, id);
172 }
173 Expr::Opr { lhs, rhs, .. } => {
174 go(lhs, substitutions, id);
175 go(rhs, substitutions, id);
176 }
177 Expr::Lst { els } | Expr::Tup { els } | Expr::Sup { els } => {
178 for el in els {
179 go(el, substitutions, id);
180 }
181 }
182 Expr::Ctr { kwargs, .. } => {
183 for (_, arg) in kwargs.iter_mut() {
184 go(arg, substitutions, id);
185 }
186 }
187 Expr::LstMap { term, iter, cond, .. } => {
188 go(term, substitutions, id);
189 go(iter, substitutions, id);
190 if let Some(cond) = cond {
191 go(cond, substitutions, id);
192 }
193 }
194 Expr::Map { entries } => {
195 for (_, entry) in entries {
196 go(entry, substitutions, id);
197 }
198 }
199 Expr::TreeNode { left, right } => {
200 go(left, substitutions, id);
201 go(right, substitutions, id);
202 }
203 Expr::TreeLeaf { val } => {
204 go(val, substitutions, id);
205 }
206 Expr::Era | Expr::Str { .. } | Expr::Var { .. } | Expr::Chn { .. } | Expr::Num { .. } => {}
207 }
208 }
209 let mut substitutions = Substitutions::new();
210 go(self, &mut substitutions, id);
211 substitutions
212 }
213}
214
215fn gen_get(current: &mut Stmt, substitutions: Substitutions) -> Stmt {
216 substitutions.into_iter().rfold(std::mem::take(current), |acc, next| {
217 let (var, map_var, key) = next;
218 let map_get_call = Expr::Var { nam: Name::new("Map/get") };
219 let map_get_call = Expr::Call {
220 fun: Box::new(map_get_call),
221 args: vec![Expr::Var { nam: map_var.clone() }, *key],
222 kwargs: Vec::new(),
223 };
224 let pat = AssignPattern::Tup(vec![AssignPattern::Var(var), AssignPattern::Var(map_var)]);
225
226 Stmt::Assign { pat, val: Box::new(map_get_call), nxt: Some(Box::new(acc)) }
227 })
228}
229
230fn gen_map_var(id: &mut usize) -> Name {
231 let name = Name::new(format!("map/get%{}", id));
232 *id += 1;
233 name
234}