1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{replace_subtree, CoreExpr, CoreFrame};
3
4pub struct BetaReduce;
7
8impl Pass for BetaReduce {
9 fn run(&self, expr: &mut CoreExpr) -> Changed {
10 if expr.nodes.is_empty() {
11 return false;
12 }
13 match try_beta_reduce(expr) {
14 Some(new_expr) => {
15 *expr = new_expr;
16 true
17 }
18 None => false,
19 }
20 }
21
22 fn name(&self) -> &str {
23 "BetaReduce"
24 }
25}
26
27fn try_beta_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
28 try_beta_at(expr, expr.nodes.len() - 1)
30}
31
32fn try_beta_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
33 match &expr.nodes[idx] {
34 CoreFrame::App { fun, arg } => {
35 if let CoreFrame::Lam { binder, body } = &expr.nodes[*fun] {
37 let body_tree = expr.extract_subtree(*body);
39 let arg_tree = expr.extract_subtree(*arg);
40 let substituted = tidepool_repr::subst::subst(&body_tree, *binder, &arg_tree);
41 Some(replace_subtree(expr, idx, &substituted))
42 } else {
43 try_beta_at(expr, *fun).or_else(|| try_beta_at(expr, *arg))
45 }
46 }
47 other => {
49 let mut result = None;
50 match other {
54 CoreFrame::Var(_) | CoreFrame::Lit(_) => {}
55 CoreFrame::App { .. } => {
56 return None;
58 }
59 CoreFrame::Lam { body, .. } => {
60 result = try_beta_at(expr, *body);
61 }
62 CoreFrame::LetNonRec { rhs, body, .. } => {
63 result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
64 }
65 CoreFrame::LetRec { bindings, body } => {
66 for (_, rhs) in bindings {
67 result = try_beta_at(expr, *rhs);
68 if result.is_some() {
69 break;
70 }
71 }
72 if result.is_none() {
73 result = try_beta_at(expr, *body);
74 }
75 }
76 CoreFrame::Case {
77 scrutinee, alts, ..
78 } => {
79 result = try_beta_at(expr, *scrutinee);
80 if result.is_none() {
81 for alt in alts {
82 result = try_beta_at(expr, alt.body);
83 if result.is_some() {
84 break;
85 }
86 }
87 }
88 }
89 CoreFrame::Con { fields, .. } => {
90 for field in fields {
91 result = try_beta_at(expr, *field);
92 if result.is_some() {
93 break;
94 }
95 }
96 }
97 CoreFrame::Join { rhs, body, .. } => {
98 result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
99 }
100 CoreFrame::Jump { args, .. } => {
101 for arg in args {
102 result = try_beta_at(expr, *arg);
103 if result.is_some() {
104 break;
105 }
106 }
107 }
108 CoreFrame::PrimOp { args, .. } => {
109 for arg in args {
110 result = try_beta_at(expr, *arg);
111 if result.is_some() {
112 break;
113 }
114 }
115 }
116 }
117 result
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, VecHeap};
    use tidepool_repr::{Literal, VarId};

    /// `(\x. x) 42` reduces to the literal `42`.
    #[test]
    fn test_beta_identity() {
        let x = VarId(1);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: x, body: 0 }, // 1: \x. x
            CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
            CoreFrame::App { fun: 1, arg: 2 },     // 3: (\x. x) 42
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);

        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    /// `(\x. \y. x) 1` reduces to `\y. 1`.
    #[test]
    fn test_beta_const() {
        let x = VarId(1);
        let y = VarId(2);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: y, body: 0 }, // 1: \y. x
            CoreFrame::Lam { binder: x, body: 1 }, // 2: \x. \y. x
            CoreFrame::Lit(Literal::LitInt(1)),    // 3: 1
            CoreFrame::App { fun: 2, arg: 3 },     // 4: (\x. \y. x) 1
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);

        assert!(changed);
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                assert_eq!(*binder, y);
                assert_eq!(
                    expr.nodes[*body],
                    CoreFrame::Lit(Literal::LitInt(1)),
                    "Body should be 1"
                );
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    /// A lone lambda contains no application, so the pass reports no change.
    #[test]
    fn test_beta_no_redex() {
        let x = VarId(1);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: x, body: 0 }, // 1: \x. x
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// `(\x. \y. x) y` must not capture the free `y`: the inner binder has to
    /// be renamed while the body still refers to the original `y`.
    #[test]
    fn test_beta_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: y, body: 0 }, // 1: \y. x
            CoreFrame::Lam { binder: x, body: 1 }, // 2: \x. \y. x
            CoreFrame::Var(y),                     // 3: y (free)
            CoreFrame::App { fun: 2, arg: 3 },     // 4: (\x. \y. x) y
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);

        assert!(changed);
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                // The binder must have been renamed away from the free `y`.
                assert_ne!(*binder, y);
                // The body is the substituted argument, i.e. the free `y`.
                assert_eq!(expr.nodes[*body], CoreFrame::Var(y), "Body should be Var(y)");
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    /// `(\x. x + x) 21` evaluates to `42` both before and after reduction.
    #[test]
    fn test_beta_preserves_eval() {
        let x = VarId(1);
        let nodes = vec![
            CoreFrame::Var(x), // 0: x
            CoreFrame::PrimOp {
                op: tidepool_repr::PrimOpKind::IntAdd,
                args: vec![0, 0],
            }, // 1: x + x
            CoreFrame::Lam { binder: x, body: 1 }, // 2: \x. x + x
            CoreFrame::Lit(Literal::LitInt(21)),   // 3: 21
            CoreFrame::App { fun: 2, arg: 3 },     // 4: (\x. x + x) 21
        ];
        let expr_orig = CoreExpr { nodes };
        let mut expr_reduced = expr_orig.clone();
        let pass = BetaReduce;
        // Without this check the comparison below would pass vacuously if the
        // pass left the expression untouched.
        assert!(
            pass.run(&mut expr_reduced),
            "Expected the root redex to be reduced"
        );

        let mut heap = VecHeap::new();
        let env = Env::new();

        let val_orig = eval(&expr_orig, &env, &mut heap).expect("Original eval failed");
        let val_reduced = eval(&expr_reduced, &env, &mut heap).expect("Reduced eval failed");

        match (&val_orig, &val_reduced) {
            (tidepool_eval::Value::Lit(l1), tidepool_eval::Value::Lit(l2)) => {
                assert_eq!(l1, l2);
            }
            _ => panic!(
                "Expected literal results, got {:?} and {:?}",
                val_orig, val_reduced
            ),
        }

        match val_orig {
            tidepool_eval::Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 42),
            _ => panic!("Expected 42"),
        }
    }

    /// `(\x. x) ((\y. y) 42)` reaches the literal `42` after repeated runs.
    #[test]
    fn test_beta_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let nodes = vec![
            CoreFrame::Var(y),                     // 0: y
            CoreFrame::Lam { binder: y, body: 0 }, // 1: \y. y
            CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
            CoreFrame::App { fun: 1, arg: 2 },     // 3: (\y. y) 42
            CoreFrame::Var(x),                     // 4: x
            CoreFrame::Lam { binder: x, body: 4 }, // 5: \x. x
            CoreFrame::App { fun: 5, arg: 3 },     // 6: (\x. x) ((\y. y) 42)
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;

        // Each run reduces at most one redex, so iterate to a fixed point.
        while pass.run(&mut expr) {}

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }
}