// bend/imp/gen_map_get.rs

1use crate::fun::Name;
2
3use super::{AssignPattern, Definition, Expr, Stmt};
4
5impl Definition {
6  /// Generates a map from `Stmt` to `Substitutions` for each definition in the program.
7  /// Iterates over all definitions in the program and applies `gen_map_get` to their bodies.
8  /// It replaces `Expr::MapGet` expressions with variable accesses, introducing
9  /// new variables as necessary to hold intermediate results from map accesses.
10  pub fn gen_map_get(&mut self) {
11    self.body.gen_map_get(&mut 0);
12  }
13}
14
15impl Stmt {
16  fn gen_map_get(&mut self, id: &mut usize) {
17    match self {
18      Stmt::LocalDef { def, nxt } => {
19        nxt.gen_map_get(id);
20        def.gen_map_get()
21      }
22      Stmt::Assign { pat, val, nxt } => {
23        let key_substitutions =
24          if let AssignPattern::MapSet(_, key) = pat { key.substitute_map_gets(id) } else { Vec::new() };
25
26        if let Some(nxt) = nxt {
27          nxt.gen_map_get(id);
28        }
29
30        let substitutions = val.substitute_map_gets(id);
31        if !substitutions.is_empty() {
32          *self = gen_get(self, substitutions);
33        }
34
35        if !key_substitutions.is_empty() {
36          *self = gen_get(self, key_substitutions);
37        }
38      }
39      Stmt::Ask { pat: _, val, nxt } => {
40        if let Some(nxt) = nxt {
41          nxt.gen_map_get(id);
42        }
43        let substitutions = val.substitute_map_gets(id);
44        if !substitutions.is_empty() {
45          *self = gen_get(self, substitutions);
46        }
47      }
48      Stmt::InPlace { op: _, pat, val, nxt } => {
49        let key_substitutions = if let AssignPattern::MapSet(_, key) = &mut **pat {
50          key.substitute_map_gets(id)
51        } else {
52          Vec::new()
53        };
54
55        nxt.gen_map_get(id);
56
57        let substitutions = val.substitute_map_gets(id);
58        if !substitutions.is_empty() {
59          *self = gen_get(self, substitutions);
60        }
61
62        if !key_substitutions.is_empty() {
63          *self = gen_get(self, key_substitutions);
64        }
65      }
66      Stmt::If { cond, then, otherwise, nxt } => {
67        then.gen_map_get(id);
68        otherwise.gen_map_get(id);
69        if let Some(nxt) = nxt {
70          nxt.gen_map_get(id);
71        }
72        let substitutions = cond.substitute_map_gets(id);
73        if !substitutions.is_empty() {
74          *self = gen_get(self, substitutions);
75        }
76      }
77      Stmt::Match { bnd: _, arg, with_bnd: _, with_arg, arms, nxt }
78      | Stmt::Fold { bnd: _, arg, arms, with_bnd: _, with_arg, nxt } => {
79        for arm in arms.iter_mut() {
80          arm.rgt.gen_map_get(id);
81        }
82        if let Some(nxt) = nxt {
83          nxt.gen_map_get(id);
84        }
85        let mut substitutions = arg.substitute_map_gets(id);
86        for arg in with_arg {
87          substitutions.extend(arg.substitute_map_gets(id));
88        }
89        if !substitutions.is_empty() {
90          *self = gen_get(self, substitutions);
91        }
92      }
93      Stmt::Switch { bnd: _, arg, with_bnd: _, with_arg, arms, nxt } => {
94        for arm in arms.iter_mut() {
95          arm.gen_map_get(id);
96        }
97        if let Some(nxt) = nxt {
98          nxt.gen_map_get(id);
99        }
100        let mut substitutions = arg.substitute_map_gets(id);
101        for arg in with_arg {
102          substitutions.extend(arg.substitute_map_gets(id));
103        }
104        if !substitutions.is_empty() {
105          *self = gen_get(self, substitutions);
106        }
107      }
108      Stmt::Bend { bnd: _, arg: init, cond, step, base, nxt } => {
109        step.gen_map_get(id);
110        base.gen_map_get(id);
111        if let Some(nxt) = nxt {
112          nxt.gen_map_get(id);
113        }
114        let mut substitutions = cond.substitute_map_gets(id);
115        for init in init {
116          substitutions.extend(init.substitute_map_gets(id));
117        }
118        if !substitutions.is_empty() {
119          *self = gen_get(self, substitutions);
120        }
121      }
122      Stmt::With { typ: _, bod, nxt } => {
123        bod.gen_map_get(id);
124        if let Some(nxt) = nxt {
125          nxt.gen_map_get(id);
126        }
127      }
128      Stmt::Return { term } => {
129        let substitutions = term.substitute_map_gets(id);
130        if !substitutions.is_empty() {
131          *self = gen_get(self, substitutions);
132        }
133      }
134      Stmt::Open { typ: _, var: _, nxt } => {
135        nxt.gen_map_get(id);
136      }
137      Stmt::Use { nam: _, val: bod, nxt } => {
138        nxt.gen_map_get(id);
139        let substitutions = bod.substitute_map_gets(id);
140        if !substitutions.is_empty() {
141          *self = gen_get(self, substitutions);
142        }
143      }
144      Stmt::Err => {}
145    }
146  }
147}
148
149type Substitutions = Vec<(Name, Name, Box<Expr>)>;
150
151impl Expr {
152  fn substitute_map_gets(&mut self, id: &mut usize) -> Substitutions {
153    fn go(e: &mut Expr, substitutions: &mut Substitutions, id: &mut usize) {
154      match e {
155        Expr::MapGet { nam, key } => {
156          go(key, substitutions, id);
157          let new_var = gen_map_var(id);
158          substitutions.push((new_var.clone(), nam.clone(), key.clone()));
159          *e = Expr::Var { nam: new_var };
160        }
161        Expr::Call { fun, args, kwargs } => {
162          go(fun, substitutions, id);
163          for arg in args {
164            go(arg, substitutions, id);
165          }
166          for (_, arg) in kwargs {
167            go(arg, substitutions, id);
168          }
169        }
170        Expr::Lam { bod, .. } => {
171          go(bod, substitutions, id);
172        }
173        Expr::Opr { lhs, rhs, .. } => {
174          go(lhs, substitutions, id);
175          go(rhs, substitutions, id);
176        }
177        Expr::Lst { els } | Expr::Tup { els } | Expr::Sup { els } => {
178          for el in els {
179            go(el, substitutions, id);
180          }
181        }
182        Expr::Ctr { kwargs, .. } => {
183          for (_, arg) in kwargs.iter_mut() {
184            go(arg, substitutions, id);
185          }
186        }
187        Expr::LstMap { term, iter, cond, .. } => {
188          go(term, substitutions, id);
189          go(iter, substitutions, id);
190          if let Some(cond) = cond {
191            go(cond, substitutions, id);
192          }
193        }
194        Expr::Map { entries } => {
195          for (_, entry) in entries {
196            go(entry, substitutions, id);
197          }
198        }
199        Expr::TreeNode { left, right } => {
200          go(left, substitutions, id);
201          go(right, substitutions, id);
202        }
203        Expr::TreeLeaf { val } => {
204          go(val, substitutions, id);
205        }
206        Expr::Era | Expr::Str { .. } | Expr::Var { .. } | Expr::Chn { .. } | Expr::Num { .. } => {}
207      }
208    }
209    let mut substitutions = Substitutions::new();
210    go(self, &mut substitutions, id);
211    substitutions
212  }
213}
214
215fn gen_get(current: &mut Stmt, substitutions: Substitutions) -> Stmt {
216  substitutions.into_iter().rfold(std::mem::take(current), |acc, next| {
217    let (var, map_var, key) = next;
218    let map_get_call = Expr::Var { nam: Name::new("Map/get") };
219    let map_get_call = Expr::Call {
220      fun: Box::new(map_get_call),
221      args: vec![Expr::Var { nam: map_var.clone() }, *key],
222      kwargs: Vec::new(),
223    };
224    let pat = AssignPattern::Tup(vec![AssignPattern::Var(var), AssignPattern::Var(map_var)]);
225
226    Stmt::Assign { pat, val: Box::new(map_get_call), nxt: Some(Box::new(acc)) }
227  })
228}
229
230fn gen_map_var(id: &mut usize) -> Name {
231  let name = Name::new(format!("map/get%{}", id));
232  *id += 1;
233  name
234}