1use crate::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
7
/// Stateless constant-folding optimizer for [`Expr`] trees.
///
/// The type carries no data: every pass is an associated function, e.g.
/// `Optimizer::optimize(&expr)`. Because it is a zero-sized value it is also
/// `Copy` and comparable, following the Rust API guidelines on common traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Optimizer;
11
12impl Optimizer {
13 pub fn new() -> Self {
15 Self
16 }
17
18 pub fn optimize(expr: &Expr) -> Expr {
20 Self::fold_constants(expr)
21 }
22
23 pub fn fold_constants(expr: &Expr) -> Expr {
25 match expr {
26 Expr::BinaryOp { op, left, right } => {
28 let left_folded = Self::fold_constants(left);
29 let right_folded = Self::fold_constants(right);
30
31 if let (Expr::Literal(left_lit), Expr::Literal(right_lit)) =
33 (&left_folded, &right_folded)
34 {
35 if let Some(folded) = Self::fold_binary_op(*op, left_lit, right_lit) {
36 return folded;
37 }
38 }
39
40 Expr::BinaryOp {
41 op: *op,
42 left: Box::new(left_folded),
43 right: Box::new(right_folded),
44 }
45 }
46
47 Expr::UnaryOp { op, operand } => {
49 let operand_folded = Self::fold_constants(operand);
50
51 if let Expr::Literal(lit) = &operand_folded {
52 if let Some(folded) = Self::fold_unary_op(*op, lit) {
53 return folded;
54 }
55 }
56
57 Expr::UnaryOp {
58 op: *op,
59 operand: Box::new(operand_folded),
60 }
61 }
62
63 Expr::Array(elements) => {
65 let folded: Vec<Expr> = elements.iter().map(Self::fold_constants).collect();
66 Expr::Array(folded)
67 }
68
69 Expr::Object(fields) => {
70 let folded: Vec<(String, Expr)> = fields
71 .iter()
72 .map(|(k, v)| (k.clone(), Self::fold_constants(v)))
73 .collect();
74 Expr::Object(folded)
75 }
76
77 Expr::FieldAccess { receiver, field } => Expr::FieldAccess {
78 receiver: Box::new(Self::fold_constants(receiver)),
79 field: field.clone(),
80 },
81
82 Expr::FunctionCall { name, args } => Expr::FunctionCall {
83 name: name.clone(),
84 args: args.iter().map(Self::fold_constants).collect(),
85 },
86
87 Expr::Lambda { param, body } => Expr::Lambda {
88 param: param.clone(),
89 body: Box::new(Self::fold_constants(body)),
90 },
91
92 Expr::Let { name, value, body } => Expr::Let {
93 name: name.clone(),
94 value: Box::new(Self::fold_constants(value)),
95 body: Box::new(Self::fold_constants(body)),
96 },
97
98 Expr::If {
99 condition,
100 then_branch,
101 else_branch,
102 } => Expr::If {
103 condition: Box::new(Self::fold_constants(condition)),
104 then_branch: Box::new(Self::fold_constants(then_branch)),
105 else_branch: Box::new(Self::fold_constants(else_branch)),
106 },
107
108 Expr::Pipe { value, functions } => Expr::Pipe {
109 value: Box::new(Self::fold_constants(value)),
110 functions: functions.iter().map(Self::fold_constants).collect(),
111 },
112
113 Expr::Alternative {
114 primary,
115 alternative,
116 } => Expr::Alternative {
117 primary: Box::new(Self::fold_constants(primary)),
118 alternative: Box::new(Self::fold_constants(alternative)),
119 },
120
121 Expr::Guard { condition, body } => Expr::Guard {
122 condition: Box::new(Self::fold_constants(condition)),
123 body: Box::new(Self::fold_constants(body)),
124 },
125
126 expr => expr.clone(),
128 }
129 }
130
131 fn fold_binary_op(op: BinaryOperator, left: &Literal, right: &Literal) -> Option<Expr> {
133 match (left, right) {
134 (Literal::Integer(l), Literal::Integer(r)) => {
135 match op {
136 BinaryOperator::Add => {
137 let result = l.checked_add(*r)?;
138 Some(Expr::Literal(Literal::Integer(result)))
139 }
140 BinaryOperator::Sub => {
141 let result = l.checked_sub(*r)?;
142 Some(Expr::Literal(Literal::Integer(result)))
143 }
144 BinaryOperator::Mul => {
145 let result = l.checked_mul(*r)?;
146 Some(Expr::Literal(Literal::Integer(result)))
147 }
148 BinaryOperator::Div if *r != 0 => {
149 let result = l.checked_div(*r)?;
150 Some(Expr::Literal(Literal::Integer(result)))
151 }
152 BinaryOperator::Mod if *r != 0 => Some(Expr::Literal(Literal::Integer(l % r))),
153 BinaryOperator::Pow => {
154 if *r < 0 || *r > 31 {
155 return None; }
157 let result = l.checked_pow(*r as u32)?;
158 Some(Expr::Literal(Literal::Integer(result)))
159 }
160 BinaryOperator::Eq => Some(Expr::Literal(Literal::Boolean(l == r))),
161 BinaryOperator::Neq => Some(Expr::Literal(Literal::Boolean(l != r))),
162 BinaryOperator::Lt => Some(Expr::Literal(Literal::Boolean(l < r))),
163 BinaryOperator::Lte => Some(Expr::Literal(Literal::Boolean(l <= r))),
164 BinaryOperator::Gt => Some(Expr::Literal(Literal::Boolean(l > r))),
165 BinaryOperator::Gte => Some(Expr::Literal(Literal::Boolean(l >= r))),
166 #[allow(clippy::needless_return)]
167 _ => return None,
168 }
169 }
170
171 (Literal::Float(l), Literal::Float(r)) => {
172 let result = match op {
173 BinaryOperator::Add => l + r,
174 BinaryOperator::Sub => l - r,
175 BinaryOperator::Mul => l * r,
176 BinaryOperator::Div if *r != 0.0 => l / r,
177 BinaryOperator::Mod if *r != 0.0 => l % r,
178 BinaryOperator::Pow => l.powf(*r),
179 BinaryOperator::Eq => {
180 return Some(Expr::Literal(Literal::Boolean(
181 (l - r).abs() < f64::EPSILON,
182 )))
183 }
184 BinaryOperator::Neq => {
185 return Some(Expr::Literal(Literal::Boolean(
186 (l - r).abs() >= f64::EPSILON,
187 )))
188 }
189 BinaryOperator::Lt => return Some(Expr::Literal(Literal::Boolean(l < r))),
190 BinaryOperator::Lte => return Some(Expr::Literal(Literal::Boolean(l <= r))),
191 BinaryOperator::Gt => return Some(Expr::Literal(Literal::Boolean(l > r))),
192 BinaryOperator::Gte => return Some(Expr::Literal(Literal::Boolean(l >= r))),
193 _ => return None,
194 };
195 Some(Expr::Literal(Literal::Float(result)))
196 }
197
198 (Literal::Boolean(l), Literal::Boolean(r)) => match op {
199 BinaryOperator::And => Some(Expr::Literal(Literal::Boolean(*l && *r))),
200 BinaryOperator::Or => Some(Expr::Literal(Literal::Boolean(*l || *r))),
201 BinaryOperator::Eq => Some(Expr::Literal(Literal::Boolean(l == r))),
202 BinaryOperator::Neq => Some(Expr::Literal(Literal::Boolean(l != r))),
203 _ => None,
204 },
205
206 _ => None,
207 }
208 }
209
210 fn fold_unary_op(op: UnaryOperator, lit: &Literal) -> Option<Expr> {
212 match op {
213 UnaryOperator::Not => {
214 if let Literal::Boolean(b) = lit {
215 Some(Expr::Literal(Literal::Boolean(!b)))
216 } else {
217 None
218 }
219 }
220 UnaryOperator::Neg => match lit {
221 Literal::Integer(n) => Some(Expr::Literal(Literal::Integer(-n))),
222 Literal::Float(f) => Some(Expr::Literal(Literal::Float(-f))),
223 _ => None,
224 },
225 UnaryOperator::Plus => match lit {
226 Literal::Integer(n) => Some(Expr::Literal(Literal::Integer(*n))),
227 Literal::Float(f) => Some(Expr::Literal(Literal::Float(*f))),
228 _ => None,
229 },
230 }
231 }
232}
233
234impl Default for Optimizer {
235 fn default() -> Self {
236 Self::new()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_creation() {
        let _opt = Optimizer::new();
    }

    #[test]
    fn test_fold_integer_addition() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Literal(Literal::Integer(5))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 8),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_integer_subtraction() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Sub,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 7),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_integer_multiplication() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Mul,
            left: Box::new(Expr::Literal(Literal::Integer(4))),
            right: Box::new(Expr::Literal(Literal::Integer(3))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 12),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_float_addition() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Literal(Literal::Float(2.5))),
            right: Box::new(Expr::Literal(Literal::Float(1.5))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Float(f)) => assert!((f - 4.0).abs() < f64::EPSILON),
            _ => panic!("Expected folded float literal"),
        }
    }

    #[test]
    fn test_fold_boolean_and() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::And,
            left: Box::new(Expr::Literal(Literal::Boolean(true))),
            right: Box::new(Expr::Literal(Literal::Boolean(false))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Boolean(b)) => assert!(!b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_boolean_or() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Or,
            left: Box::new(Expr::Literal(Literal::Boolean(false))),
            right: Box::new(Expr::Literal(Literal::Boolean(true))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Boolean(b)) => assert!(b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_integer_comparison() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Gt,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(5))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Boolean(b)) => assert!(b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_unary_not() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(Expr::Literal(Literal::Boolean(true))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Boolean(b)) => assert!(!b),
            _ => panic!("Expected folded boolean literal"),
        }
    }

    #[test]
    fn test_fold_unary_negate() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Neg,
            operand: Box::new(Expr::Literal(Literal::Integer(42))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, -42),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_no_fold_identifier() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::Identifier("x".to_string())),
            right: Box::new(Expr::Literal(Literal::Integer(5))),
        };

        // `matches!` only yields a bool; it must be asserted to test anything.
        let folded = Optimizer::optimize(&expr);
        assert!(
            matches!(folded, Expr::BinaryOp { .. }),
            "Expected the binary op to be kept, got {:?}",
            folded
        );
    }

    #[test]
    fn test_fold_nested_constants() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(Expr::BinaryOp {
                op: BinaryOperator::Mul,
                left: Box::new(Expr::Literal(Literal::Integer(2))),
                right: Box::new(Expr::Literal(Literal::Integer(3))),
            }),
            right: Box::new(Expr::Literal(Literal::Integer(4))),
        };

        match Optimizer::optimize(&expr) {
            Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 10),
            _ => panic!("Expected folded integer literal"),
        }
    }

    #[test]
    fn test_fold_array_constants() {
        let expr = Expr::Array(vec![
            Expr::Literal(Literal::Integer(1)),
            Expr::BinaryOp {
                op: BinaryOperator::Add,
                left: Box::new(Expr::Literal(Literal::Integer(2))),
                right: Box::new(Expr::Literal(Literal::Integer(3))),
            },
        ]);

        match Optimizer::optimize(&expr) {
            Expr::Array(elements) => {
                assert_eq!(elements.len(), 2);
                match &elements[1] {
                    Expr::Literal(Literal::Integer(n)) => assert_eq!(*n, 5),
                    _ => panic!("Expected folded constant in array"),
                }
            }
            _ => panic!("Expected array expression"),
        }
    }

    #[test]
    fn test_fold_division_by_zero() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Div,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(0))),
        };

        let folded = Optimizer::optimize(&expr);
        assert!(
            matches!(folded, Expr::BinaryOp { .. }),
            "Division by zero must not be folded, got {:?}",
            folded
        );
    }

    #[test]
    fn test_no_fold_modulo_by_zero() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Mod,
            left: Box::new(Expr::Literal(Literal::Integer(10))),
            right: Box::new(Expr::Literal(Literal::Integer(0))),
        };

        let folded = Optimizer::optimize(&expr);
        assert!(
            matches!(folded, Expr::BinaryOp { .. }),
            "Modulo by zero must not be folded, got {:?}",
            folded
        );
    }

    #[test]
    fn test_no_fold_negative_exponent() {
        let expr = Expr::BinaryOp {
            op: BinaryOperator::Pow,
            left: Box::new(Expr::Literal(Literal::Integer(2))),
            right: Box::new(Expr::Literal(Literal::Integer(-1))),
        };

        let folded = Optimizer::optimize(&expr);
        assert!(
            matches!(folded, Expr::BinaryOp { .. }),
            "Integer pow with a negative exponent must not be folded, got {:?}",
            folded
        );
    }
}