1use decy_hir::{BinaryOperator, HirExpression, HirFunction, HirStatement};
33#[cfg(test)]
34use decy_hir::HirType;
35
36const MAX_ITERATIONS: usize = 3;
38
39pub fn optimize_function(func: &HirFunction) -> HirFunction {
44 let mut body = func.body().to_vec();
45 let mut changed = true;
46 let mut iterations = 0;
47
48 while changed && iterations < MAX_ITERATIONS {
49 changed = false;
50 let prev = body.clone();
51
52 body = body.into_iter().map(fold_constants_stmt).collect();
53 body = remove_dead_branches(body);
54 body = eliminate_temporaries(body);
55
56 if body != prev {
57 changed = true;
58 }
59 iterations += 1;
60 }
61
62 HirFunction::new_with_body(
63 func.name().to_string(),
64 func.return_type().clone(),
65 func.parameters().to_vec(),
66 body,
67 )
68}
69
70fn fold_constants_stmt(stmt: HirStatement) -> HirStatement {
76 match stmt {
77 HirStatement::VariableDeclaration {
78 name,
79 var_type,
80 initializer,
81 } => HirStatement::VariableDeclaration {
82 name,
83 var_type,
84 initializer: initializer.map(fold_constants_expr),
85 },
86 HirStatement::Return(expr) => HirStatement::Return(expr.map(fold_constants_expr)),
87 HirStatement::Assignment { target, value } => HirStatement::Assignment {
88 target,
89 value: fold_constants_expr(value),
90 },
91 HirStatement::If {
92 condition,
93 then_block,
94 else_block,
95 } => HirStatement::If {
96 condition: fold_constants_expr(condition),
97 then_block: then_block.into_iter().map(fold_constants_stmt).collect(),
98 else_block: else_block
99 .map(|block| block.into_iter().map(fold_constants_stmt).collect()),
100 },
101 HirStatement::While { condition, body } => HirStatement::While {
102 condition: fold_constants_expr(condition),
103 body: body.into_iter().map(fold_constants_stmt).collect(),
104 },
105 HirStatement::For {
106 init,
107 condition,
108 increment,
109 body,
110 } => HirStatement::For {
111 init: init.into_iter().map(fold_constants_stmt).collect(),
112 condition: condition.map(fold_constants_expr),
113 increment: increment.into_iter().map(fold_constants_stmt).collect(),
114 body: body.into_iter().map(fold_constants_stmt).collect(),
115 },
116 HirStatement::Expression(expr) => HirStatement::Expression(fold_constants_expr(expr)),
117 other => other,
119 }
120}
121
122fn fold_constants_expr(expr: HirExpression) -> HirExpression {
124 match expr {
125 HirExpression::BinaryOp { op, left, right } => {
126 let left = fold_constants_expr(*left);
127 let right = fold_constants_expr(*right);
128
129 if let (HirExpression::IntLiteral(l), HirExpression::IntLiteral(r)) = (&left, &right) {
131 if let Some(result) = fold_int_binary(*l, op, *r) {
132 return HirExpression::IntLiteral(result);
133 }
134 }
135
136 HirExpression::BinaryOp {
137 op,
138 left: Box::new(left),
139 right: Box::new(right),
140 }
141 }
142 HirExpression::UnaryOp { op, operand } => {
143 let operand = fold_constants_expr(*operand);
144 if let (decy_hir::UnaryOperator::Minus, HirExpression::IntLiteral(v)) =
145 (op, &operand)
146 {
147 return HirExpression::IntLiteral(-v);
148 }
149 HirExpression::UnaryOp {
150 op,
151 operand: Box::new(operand),
152 }
153 }
154 HirExpression::FunctionCall {
156 function,
157 arguments,
158 } => HirExpression::FunctionCall {
159 function,
160 arguments: arguments.into_iter().map(fold_constants_expr).collect(),
161 },
162 other => other,
163 }
164}
165
166fn fold_int_binary(left: i32, op: BinaryOperator, right: i32) -> Option<i32> {
168 match op {
169 BinaryOperator::Add => left.checked_add(right),
170 BinaryOperator::Subtract => left.checked_sub(right),
171 BinaryOperator::Multiply => left.checked_mul(right),
172 BinaryOperator::Divide => {
173 if right != 0 {
174 left.checked_div(right)
175 } else {
176 None
177 }
178 }
179 BinaryOperator::Modulo => {
180 if right != 0 {
181 left.checked_rem(right)
182 } else {
183 None
184 }
185 }
186 BinaryOperator::LeftShift => {
187 if (0..32).contains(&right) {
188 Some(left << right)
189 } else {
190 None
191 }
192 }
193 BinaryOperator::RightShift => {
194 if (0..32).contains(&right) {
195 Some(left >> right)
196 } else {
197 None
198 }
199 }
200 BinaryOperator::BitwiseAnd => Some(left & right),
201 BinaryOperator::BitwiseOr => Some(left | right),
202 BinaryOperator::BitwiseXor => Some(left ^ right),
203 _ => None,
204 }
205}
206
207fn remove_dead_branches(stmts: Vec<HirStatement>) -> Vec<HirStatement> {
213 let mut result = Vec::new();
214
215 for stmt in stmts {
216 match stmt {
217 HirStatement::If {
218 condition,
219 then_block,
220 else_block,
221 } => {
222 if let Some(always_true) = is_constant_truthy(&condition) {
223 if always_true {
224 result.extend(
226 then_block
227 .into_iter()
228 .map(fold_constants_stmt)
229 .collect::<Vec<_>>(),
230 );
231 } else if let Some(else_body) = else_block {
232 result.extend(
234 else_body
235 .into_iter()
236 .map(fold_constants_stmt)
237 .collect::<Vec<_>>(),
238 );
239 }
240 } else {
242 result.push(HirStatement::If {
244 condition,
245 then_block: remove_dead_branches(then_block),
246 else_block: else_block.map(remove_dead_branches),
247 });
248 }
249 }
250 HirStatement::While { condition, body } => {
251 if let Some(false) = is_constant_truthy(&condition) {
253 } else {
255 result.push(HirStatement::While {
256 condition,
257 body: remove_dead_branches(body),
258 });
259 }
260 }
261 other => result.push(other),
262 }
263 }
264
265 result
266}
267
268fn is_constant_truthy(expr: &HirExpression) -> Option<bool> {
270 match expr {
271 HirExpression::IntLiteral(0) => Some(false),
272 HirExpression::IntLiteral(_) => Some(true),
273 _ => None,
274 }
275}
276
277fn eliminate_temporaries(stmts: Vec<HirStatement>) -> Vec<HirStatement> {
288 if stmts.len() < 2 {
289 return stmts;
290 }
291
292 let mut result = Vec::new();
293 let mut skip_next = false;
294
295 for i in 0..stmts.len() {
296 if skip_next {
297 skip_next = false;
298 continue;
299 }
300
301 if i + 1 < stmts.len() {
303 if let (
304 HirStatement::VariableDeclaration {
305 name,
306 initializer: Some(init_expr),
307 ..
308 },
309 HirStatement::Return(Some(HirExpression::Variable(ret_var))),
310 ) = (&stmts[i], &stmts[i + 1])
311 {
312 if name == ret_var
313 && count_uses(name, &stmts[i + 2..]) == 0
314 && !is_allocation_expr(init_expr)
315 {
316 result.push(HirStatement::Return(Some(init_expr.clone())));
318 skip_next = true;
319 continue;
320 }
321 }
322 }
323
324 result.push(stmts[i].clone());
325 }
326
327 result
328}
329
330fn is_allocation_expr(expr: &HirExpression) -> bool {
333 match expr {
334 HirExpression::Malloc { .. }
335 | HirExpression::Calloc { .. }
336 | HirExpression::Realloc { .. } => true,
337 HirExpression::Cast { expr: inner, .. } => is_allocation_expr(inner),
338 HirExpression::FunctionCall { function, .. } => {
339 matches!(function.as_str(), "malloc" | "calloc" | "realloc")
340 }
341 _ => false,
342 }
343}
344
345fn count_uses(name: &str, stmts: &[HirStatement]) -> usize {
347 let mut count = 0;
348 for stmt in stmts {
349 count += count_uses_in_stmt(name, stmt);
350 }
351 count
352}
353
354fn count_uses_in_stmt(name: &str, stmt: &HirStatement) -> usize {
356 match stmt {
357 HirStatement::Return(Some(expr)) => count_uses_in_expr(name, expr),
358 HirStatement::Assignment { value, .. } => count_uses_in_expr(name, value),
359 HirStatement::Expression(expr) => count_uses_in_expr(name, expr),
360 HirStatement::If {
361 condition,
362 then_block,
363 else_block,
364 } => {
365 let mut c = count_uses_in_expr(name, condition);
366 for s in then_block {
367 c += count_uses_in_stmt(name, s);
368 }
369 if let Some(block) = else_block {
370 for s in block {
371 c += count_uses_in_stmt(name, s);
372 }
373 }
374 c
375 }
376 _ => 0,
377 }
378}
379
380fn count_uses_in_expr(name: &str, expr: &HirExpression) -> usize {
382 match expr {
383 HirExpression::Variable(v) if v == name => 1,
384 HirExpression::BinaryOp { left, right, .. } => {
385 count_uses_in_expr(name, left) + count_uses_in_expr(name, right)
386 }
387 HirExpression::FunctionCall { arguments, .. } => {
388 arguments.iter().map(|a| count_uses_in_expr(name, a)).sum()
389 }
390 HirExpression::UnaryOp { operand, .. } => count_uses_in_expr(name, operand),
391 _ => 0,
392 }
393}
394
395#[cfg(test)]
396mod tests {
397 use super::*;
398
399 #[test]
400 fn test_constant_folding_add() {
401 let expr = HirExpression::BinaryOp {
402 op: BinaryOperator::Add,
403 left: Box::new(HirExpression::IntLiteral(2)),
404 right: Box::new(HirExpression::IntLiteral(3)),
405 };
406 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(5));
407 }
408
409 #[test]
410 fn test_constant_folding_multiply() {
411 let expr = HirExpression::BinaryOp {
412 op: BinaryOperator::Multiply,
413 left: Box::new(HirExpression::IntLiteral(4)),
414 right: Box::new(HirExpression::IntLiteral(5)),
415 };
416 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(20));
417 }
418
419 #[test]
420 fn test_constant_folding_nested() {
421 let expr = HirExpression::BinaryOp {
423 op: BinaryOperator::Multiply,
424 left: Box::new(HirExpression::BinaryOp {
425 op: BinaryOperator::Add,
426 left: Box::new(HirExpression::IntLiteral(2)),
427 right: Box::new(HirExpression::IntLiteral(3)),
428 }),
429 right: Box::new(HirExpression::IntLiteral(4)),
430 };
431 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(20));
432 }
433
434 #[test]
435 fn test_constant_folding_division_by_zero() {
436 let expr = HirExpression::BinaryOp {
437 op: BinaryOperator::Divide,
438 left: Box::new(HirExpression::IntLiteral(10)),
439 right: Box::new(HirExpression::IntLiteral(0)),
440 };
441 match fold_constants_expr(expr) {
443 HirExpression::BinaryOp { .. } => {} other => panic!("Expected BinaryOp, got {:?}", other),
445 }
446 }
447
448 #[test]
449 fn test_constant_folding_non_literal() {
450 let expr = HirExpression::BinaryOp {
451 op: BinaryOperator::Add,
452 left: Box::new(HirExpression::Variable("x".to_string())),
453 right: Box::new(HirExpression::IntLiteral(3)),
454 };
455 match fold_constants_expr(expr) {
457 HirExpression::BinaryOp { .. } => {} other => panic!("Expected BinaryOp, got {:?}", other),
459 }
460 }
461
462 #[test]
463 fn test_dead_branch_removal_true() {
464 let stmts = vec![HirStatement::If {
466 condition: HirExpression::IntLiteral(1),
467 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(42)))],
468 else_block: None,
469 }];
470
471 let result = remove_dead_branches(stmts);
472 assert_eq!(result.len(), 1);
473 assert_eq!(
474 result[0],
475 HirStatement::Return(Some(HirExpression::IntLiteral(42)))
476 );
477 }
478
479 #[test]
480 fn test_dead_branch_removal_false_no_else() {
481 let stmts = vec![HirStatement::If {
483 condition: HirExpression::IntLiteral(0),
484 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(42)))],
485 else_block: None,
486 }];
487
488 let result = remove_dead_branches(stmts);
489 assert!(result.is_empty());
490 }
491
492 #[test]
493 fn test_dead_branch_removal_false_with_else() {
494 let stmts = vec![HirStatement::If {
496 condition: HirExpression::IntLiteral(0),
497 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(42)))],
498 else_block: Some(vec![HirStatement::Return(Some(
499 HirExpression::IntLiteral(99),
500 ))]),
501 }];
502
503 let result = remove_dead_branches(stmts);
504 assert_eq!(result.len(), 1);
505 assert_eq!(
506 result[0],
507 HirStatement::Return(Some(HirExpression::IntLiteral(99)))
508 );
509 }
510
511 #[test]
512 fn test_dead_while_zero() {
513 let stmts = vec![HirStatement::While {
515 condition: HirExpression::IntLiteral(0),
516 body: vec![HirStatement::Break],
517 }];
518
519 let result = remove_dead_branches(stmts);
520 assert!(result.is_empty());
521 }
522
523 #[test]
524 fn test_temp_elimination_return() {
525 let stmts = vec![
527 HirStatement::VariableDeclaration {
528 name: "tmp".to_string(),
529 var_type: HirType::Int,
530 initializer: Some(HirExpression::IntLiteral(42)),
531 },
532 HirStatement::Return(Some(HirExpression::Variable("tmp".to_string()))),
533 ];
534
535 let result = eliminate_temporaries(stmts);
536 assert_eq!(result.len(), 1);
537 assert_eq!(
538 result[0],
539 HirStatement::Return(Some(HirExpression::IntLiteral(42)))
540 );
541 }
542
543 #[test]
544 fn test_optimize_function_combined() {
545 let func = HirFunction::new_with_body(
547 "test".to_string(),
548 HirType::Int,
549 vec![],
550 vec![
551 HirStatement::VariableDeclaration {
552 name: "x".to_string(),
553 var_type: HirType::Int,
554 initializer: Some(HirExpression::BinaryOp {
555 op: BinaryOperator::Add,
556 left: Box::new(HirExpression::IntLiteral(2)),
557 right: Box::new(HirExpression::IntLiteral(3)),
558 }),
559 },
560 HirStatement::If {
561 condition: HirExpression::IntLiteral(1),
562 then_block: vec![HirStatement::Return(Some(HirExpression::Variable(
563 "x".to_string(),
564 )))],
565 else_block: None,
566 },
567 ],
568 );
569
570 let optimized = optimize_function(&func);
571 let body = optimized.body();
574 assert!(
575 body.len() <= 2,
576 "Expected at most 2 statements, got {}",
577 body.len()
578 );
579 }
580
581 #[test]
582 fn test_unary_minus_folding() {
583 let expr = HirExpression::UnaryOp {
584 op: decy_hir::UnaryOperator::Minus,
585 operand: Box::new(HirExpression::IntLiteral(42)),
586 };
587 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(-42));
588 }
589
590 #[test]
591 fn test_bitwise_folding() {
592 let expr = HirExpression::BinaryOp {
593 op: BinaryOperator::BitwiseAnd,
594 left: Box::new(HirExpression::IntLiteral(0xFF)),
595 right: Box::new(HirExpression::IntLiteral(0x0F)),
596 };
597 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(0x0F));
598 }
599
600 #[test]
605 fn test_constant_folding_subtract() {
606 let expr = HirExpression::BinaryOp {
607 op: BinaryOperator::Subtract,
608 left: Box::new(HirExpression::IntLiteral(10)),
609 right: Box::new(HirExpression::IntLiteral(3)),
610 };
611 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(7));
612 }
613
614 #[test]
615 fn test_constant_folding_divide() {
616 let expr = HirExpression::BinaryOp {
617 op: BinaryOperator::Divide,
618 left: Box::new(HirExpression::IntLiteral(20)),
619 right: Box::new(HirExpression::IntLiteral(4)),
620 };
621 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(5));
622 }
623
624 #[test]
625 fn test_constant_folding_modulo() {
626 let expr = HirExpression::BinaryOp {
627 op: BinaryOperator::Modulo,
628 left: Box::new(HirExpression::IntLiteral(17)),
629 right: Box::new(HirExpression::IntLiteral(5)),
630 };
631 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(2));
632 }
633
634 #[test]
635 fn test_constant_folding_modulo_by_zero() {
636 let expr = HirExpression::BinaryOp {
637 op: BinaryOperator::Modulo,
638 left: Box::new(HirExpression::IntLiteral(17)),
639 right: Box::new(HirExpression::IntLiteral(0)),
640 };
641 match fold_constants_expr(expr) {
642 HirExpression::BinaryOp { .. } => {}
643 other => panic!("Expected BinaryOp, got {:?}", other),
644 }
645 }
646
647 #[test]
648 fn test_constant_folding_left_shift() {
649 let expr = HirExpression::BinaryOp {
650 op: BinaryOperator::LeftShift,
651 left: Box::new(HirExpression::IntLiteral(1)),
652 right: Box::new(HirExpression::IntLiteral(4)),
653 };
654 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(16));
655 }
656
657 #[test]
658 fn test_constant_folding_left_shift_overflow() {
659 let expr = HirExpression::BinaryOp {
660 op: BinaryOperator::LeftShift,
661 left: Box::new(HirExpression::IntLiteral(1)),
662 right: Box::new(HirExpression::IntLiteral(32)),
663 };
664 match fold_constants_expr(expr) {
666 HirExpression::BinaryOp { .. } => {}
667 other => panic!("Expected BinaryOp, got {:?}", other),
668 }
669 }
670
671 #[test]
672 fn test_constant_folding_right_shift() {
673 let expr = HirExpression::BinaryOp {
674 op: BinaryOperator::RightShift,
675 left: Box::new(HirExpression::IntLiteral(16)),
676 right: Box::new(HirExpression::IntLiteral(2)),
677 };
678 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(4));
679 }
680
681 #[test]
682 fn test_constant_folding_right_shift_overflow() {
683 let expr = HirExpression::BinaryOp {
684 op: BinaryOperator::RightShift,
685 left: Box::new(HirExpression::IntLiteral(16)),
686 right: Box::new(HirExpression::IntLiteral(-1)),
687 };
688 match fold_constants_expr(expr) {
689 HirExpression::BinaryOp { .. } => {}
690 other => panic!("Expected BinaryOp, got {:?}", other),
691 }
692 }
693
694 #[test]
695 fn test_constant_folding_bitwise_or() {
696 let expr = HirExpression::BinaryOp {
697 op: BinaryOperator::BitwiseOr,
698 left: Box::new(HirExpression::IntLiteral(0xF0)),
699 right: Box::new(HirExpression::IntLiteral(0x0F)),
700 };
701 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(0xFF));
702 }
703
704 #[test]
705 fn test_constant_folding_bitwise_xor() {
706 let expr = HirExpression::BinaryOp {
707 op: BinaryOperator::BitwiseXor,
708 left: Box::new(HirExpression::IntLiteral(0xFF)),
709 right: Box::new(HirExpression::IntLiteral(0x0F)),
710 };
711 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(0xF0));
712 }
713
714 #[test]
715 fn test_constant_folding_unsupported_op() {
716 let expr = HirExpression::BinaryOp {
718 op: BinaryOperator::Equal,
719 left: Box::new(HirExpression::IntLiteral(5)),
720 right: Box::new(HirExpression::IntLiteral(5)),
721 };
722 match fold_constants_expr(expr) {
723 HirExpression::BinaryOp { .. } => {}
724 other => panic!("Expected BinaryOp, got {:?}", other),
725 }
726 }
727
728 #[test]
733 fn test_constant_folding_function_call_args() {
734 let expr = HirExpression::FunctionCall {
736 function: "foo".to_string(),
737 arguments: vec![HirExpression::BinaryOp {
738 op: BinaryOperator::Add,
739 left: Box::new(HirExpression::IntLiteral(2)),
740 right: Box::new(HirExpression::IntLiteral(3)),
741 }],
742 };
743 match fold_constants_expr(expr) {
744 HirExpression::FunctionCall {
745 function,
746 arguments,
747 } => {
748 assert_eq!(function, "foo");
749 assert_eq!(arguments, vec![HirExpression::IntLiteral(5)]);
750 }
751 other => panic!("Expected FunctionCall, got {:?}", other),
752 }
753 }
754
755 #[test]
760 fn test_constant_folding_for_loop() {
761 let stmt = HirStatement::For {
763 init: vec![HirStatement::VariableDeclaration {
764 name: "i".to_string(),
765 var_type: HirType::Int,
766 initializer: Some(HirExpression::IntLiteral(0)),
767 }],
768 condition: Some(HirExpression::BinaryOp {
769 op: BinaryOperator::Add,
770 left: Box::new(HirExpression::IntLiteral(2)),
771 right: Box::new(HirExpression::IntLiteral(3)),
772 }),
773 increment: vec![HirStatement::Expression(HirExpression::Variable(
774 "i".to_string(),
775 ))],
776 body: vec![HirStatement::Return(Some(HirExpression::IntLiteral(1)))],
777 };
778
779 match fold_constants_stmt(stmt) {
780 HirStatement::For {
781 condition, body, ..
782 } => {
783 assert_eq!(condition, Some(HirExpression::IntLiteral(5)));
784 assert!(!body.is_empty());
785 }
786 other => panic!("Expected For, got {:?}", other),
787 }
788 }
789
790 #[test]
795 fn test_constant_folding_expression_stmt() {
796 let stmt = HirStatement::Expression(HirExpression::BinaryOp {
797 op: BinaryOperator::Add,
798 left: Box::new(HirExpression::IntLiteral(1)),
799 right: Box::new(HirExpression::IntLiteral(2)),
800 });
801 match fold_constants_stmt(stmt) {
802 HirStatement::Expression(HirExpression::IntLiteral(3)) => {}
803 other => panic!("Expected Expression(IntLiteral(3)), got {:?}", other),
804 }
805 }
806
807 #[test]
808 fn test_constant_folding_assignment() {
809 let stmt = HirStatement::Assignment {
810 target: "x".to_string(),
811 value: HirExpression::BinaryOp {
812 op: BinaryOperator::Multiply,
813 left: Box::new(HirExpression::IntLiteral(6)),
814 right: Box::new(HirExpression::IntLiteral(7)),
815 },
816 };
817 match fold_constants_stmt(stmt) {
818 HirStatement::Assignment { value, .. } => {
819 assert_eq!(value, HirExpression::IntLiteral(42));
820 }
821 other => panic!("Expected Assignment, got {:?}", other),
822 }
823 }
824
825 #[test]
826 fn test_constant_folding_pass_through() {
827 let stmt = HirStatement::Break;
829 assert_eq!(fold_constants_stmt(stmt), HirStatement::Break);
830 }
831
832 #[test]
837 fn test_unary_not_folding_not_applied() {
838 let expr = HirExpression::UnaryOp {
840 op: decy_hir::UnaryOperator::LogicalNot,
841 operand: Box::new(HirExpression::IntLiteral(1)),
842 };
843 match fold_constants_expr(expr) {
844 HirExpression::UnaryOp { .. } => {}
845 other => panic!("Expected UnaryOp, got {:?}", other),
846 }
847 }
848
849 #[test]
850 fn test_unary_minus_on_variable() {
851 let expr = HirExpression::UnaryOp {
853 op: decy_hir::UnaryOperator::Minus,
854 operand: Box::new(HirExpression::Variable("x".to_string())),
855 };
856 match fold_constants_expr(expr) {
857 HirExpression::UnaryOp { .. } => {}
858 other => panic!("Expected UnaryOp, got {:?}", other),
859 }
860 }
861
862 #[test]
867 fn test_is_allocation_expr_malloc() {
868 assert!(is_allocation_expr(&HirExpression::Malloc {
869 size: Box::new(HirExpression::IntLiteral(4)),
870 }));
871 }
872
873 #[test]
874 fn test_is_allocation_expr_calloc() {
875 assert!(is_allocation_expr(&HirExpression::Calloc {
876 count: Box::new(HirExpression::IntLiteral(10)),
877 element_type: Box::new(HirType::Int),
878 }));
879 }
880
881 #[test]
882 fn test_is_allocation_expr_realloc() {
883 assert!(is_allocation_expr(&HirExpression::Realloc {
884 pointer: Box::new(HirExpression::Variable("p".to_string())),
885 new_size: Box::new(HirExpression::IntLiteral(64)),
886 }));
887 }
888
889 #[test]
890 fn test_is_allocation_expr_cast_wrapping_malloc() {
891 assert!(is_allocation_expr(&HirExpression::Cast {
892 target_type: HirType::Pointer(Box::new(HirType::Int)),
893 expr: Box::new(HirExpression::Malloc {
894 size: Box::new(HirExpression::IntLiteral(4)),
895 }),
896 }));
897 }
898
899 #[test]
900 fn test_is_allocation_expr_function_call_malloc() {
901 assert!(is_allocation_expr(&HirExpression::FunctionCall {
902 function: "malloc".to_string(),
903 arguments: vec![HirExpression::IntLiteral(4)],
904 }));
905 }
906
907 #[test]
908 fn test_is_allocation_expr_function_call_calloc() {
909 assert!(is_allocation_expr(&HirExpression::FunctionCall {
910 function: "calloc".to_string(),
911 arguments: vec![HirExpression::IntLiteral(10), HirExpression::IntLiteral(4)],
912 }));
913 }
914
915 #[test]
916 fn test_is_allocation_expr_regular_call() {
917 assert!(!is_allocation_expr(&HirExpression::FunctionCall {
918 function: "printf".to_string(),
919 arguments: vec![],
920 }));
921 }
922
923 #[test]
924 fn test_is_allocation_expr_literal() {
925 assert!(!is_allocation_expr(&HirExpression::IntLiteral(42)));
926 }
927
928 #[test]
933 fn test_count_uses_empty() {
934 assert_eq!(count_uses("x", &[]), 0);
935 }
936
937 #[test]
938 fn test_count_uses_in_return() {
939 let stmts = vec![HirStatement::Return(Some(HirExpression::Variable(
940 "x".to_string(),
941 )))];
942 assert_eq!(count_uses("x", &stmts), 1);
943 assert_eq!(count_uses("y", &stmts), 0);
944 }
945
946 #[test]
947 fn test_count_uses_in_assignment() {
948 let stmts = vec![HirStatement::Assignment {
949 target: "y".to_string(),
950 value: HirExpression::Variable("x".to_string()),
951 }];
952 assert_eq!(count_uses("x", &stmts), 1);
953 }
954
955 #[test]
956 fn test_count_uses_in_expression_stmt() {
957 let stmts = vec![HirStatement::Expression(HirExpression::FunctionCall {
958 function: "foo".to_string(),
959 arguments: vec![
960 HirExpression::Variable("x".to_string()),
961 HirExpression::Variable("x".to_string()),
962 ],
963 })];
964 assert_eq!(count_uses("x", &stmts), 2);
965 }
966
967 #[test]
968 fn test_count_uses_in_if_with_else() {
969 let stmts = vec![HirStatement::If {
970 condition: HirExpression::Variable("x".to_string()),
971 then_block: vec![HirStatement::Return(Some(HirExpression::Variable(
972 "x".to_string(),
973 )))],
974 else_block: Some(vec![HirStatement::Return(Some(HirExpression::Variable(
975 "x".to_string(),
976 )))]),
977 }];
978 assert_eq!(count_uses("x", &stmts), 3); }
980
981 #[test]
982 fn test_count_uses_in_expr_binary_op() {
983 let expr = HirExpression::BinaryOp {
984 op: BinaryOperator::Add,
985 left: Box::new(HirExpression::Variable("x".to_string())),
986 right: Box::new(HirExpression::Variable("x".to_string())),
987 };
988 assert_eq!(count_uses_in_expr("x", &expr), 2);
989 }
990
991 #[test]
992 fn test_count_uses_in_expr_unary_op() {
993 let expr = HirExpression::UnaryOp {
994 op: decy_hir::UnaryOperator::Minus,
995 operand: Box::new(HirExpression::Variable("x".to_string())),
996 };
997 assert_eq!(count_uses_in_expr("x", &expr), 1);
998 }
999
1000 #[test]
1001 fn test_count_uses_in_expr_non_matching() {
1002 let expr = HirExpression::IntLiteral(42);
1003 assert_eq!(count_uses_in_expr("x", &expr), 0);
1004 }
1005
1006 #[test]
1007 fn test_count_uses_in_stmt_break() {
1008 assert_eq!(count_uses_in_stmt("x", &HirStatement::Break), 0);
1009 }
1010
1011 #[test]
1012 fn test_count_uses_in_stmt_return_none() {
1013 assert_eq!(count_uses_in_stmt("x", &HirStatement::Return(None)), 0);
1014 }
1015
1016 #[test]
1021 fn test_temp_elimination_single_stmt() {
1022 let stmts = vec![HirStatement::Return(Some(HirExpression::IntLiteral(1)))];
1024 let result = eliminate_temporaries(stmts.clone());
1025 assert_eq!(result, stmts);
1026 }
1027
1028 #[test]
1029 fn test_temp_elimination_no_match() {
1030 let stmts = vec![
1032 HirStatement::Expression(HirExpression::IntLiteral(1)),
1033 HirStatement::Return(Some(HirExpression::IntLiteral(2))),
1034 ];
1035 let result = eliminate_temporaries(stmts.clone());
1036 assert_eq!(result, stmts);
1037 }
1038
1039 #[test]
1040 fn test_temp_elimination_allocation_preserved() {
1041 let stmts = vec![
1043 HirStatement::VariableDeclaration {
1044 name: "p".to_string(),
1045 var_type: HirType::Pointer(Box::new(HirType::Int)),
1046 initializer: Some(HirExpression::Malloc {
1047 size: Box::new(HirExpression::IntLiteral(4)),
1048 }),
1049 },
1050 HirStatement::Return(Some(HirExpression::Variable("p".to_string()))),
1051 ];
1052 let result = eliminate_temporaries(stmts);
1053 assert_eq!(result.len(), 2); }
1055
1056 #[test]
1057 fn test_temp_elimination_multi_use_preserved() {
1058 let stmts = vec![
1060 HirStatement::VariableDeclaration {
1061 name: "x".to_string(),
1062 var_type: HirType::Int,
1063 initializer: Some(HirExpression::IntLiteral(42)),
1064 },
1065 HirStatement::Return(Some(HirExpression::Variable("x".to_string()))),
1066 HirStatement::Expression(HirExpression::Variable("x".to_string())),
1067 ];
1068 let result = eliminate_temporaries(stmts);
1069 assert_eq!(result.len(), 3); }
1071
1072 #[test]
1077 fn test_dead_branch_non_constant_if() {
1078 let stmts = vec![HirStatement::If {
1080 condition: HirExpression::Variable("x".to_string()),
1081 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(1)))],
1082 else_block: Some(vec![HirStatement::Return(Some(
1083 HirExpression::IntLiteral(0),
1084 ))]),
1085 }];
1086 let result = remove_dead_branches(stmts);
1087 assert_eq!(result.len(), 1);
1088 match &result[0] {
1089 HirStatement::If { .. } => {}
1090 other => panic!("Expected If, got {:?}", other),
1091 }
1092 }
1093
1094 #[test]
1095 fn test_dead_branch_while_non_constant() {
1096 let stmts = vec![HirStatement::While {
1098 condition: HirExpression::Variable("x".to_string()),
1099 body: vec![HirStatement::Break],
1100 }];
1101 let result = remove_dead_branches(stmts);
1102 assert_eq!(result.len(), 1);
1103 match &result[0] {
1104 HirStatement::While { .. } => {}
1105 other => panic!("Expected While, got {:?}", other),
1106 }
1107 }
1108
1109 #[test]
1110 fn test_dead_branch_while_nonzero_constant() {
1111 let stmts = vec![HirStatement::While {
1113 condition: HirExpression::IntLiteral(1),
1114 body: vec![HirStatement::Break],
1115 }];
1116 let result = remove_dead_branches(stmts);
1117 assert_eq!(result.len(), 1);
1118 }
1119
1120 #[test]
1125 fn test_optimize_no_change_single_iteration() {
1126 let func = HirFunction::new_with_body(
1128 "noop".to_string(),
1129 HirType::Int,
1130 vec![],
1131 vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))],
1132 );
1133 let optimized = optimize_function(&func);
1134 assert_eq!(optimized.body().len(), 1);
1135 }
1136
1137 #[test]
1138 fn test_optimize_empty_function() {
1139 let func = HirFunction::new_with_body(
1140 "empty".to_string(),
1141 HirType::Void,
1142 vec![],
1143 vec![],
1144 );
1145 let optimized = optimize_function(&func);
1146 assert!(optimized.body().is_empty());
1147 }
1148
1149 #[test]
1154 fn test_fold_constants_stmt_while() {
1155 let stmt = HirStatement::While {
1157 condition: HirExpression::BinaryOp {
1158 op: BinaryOperator::Add,
1159 left: Box::new(HirExpression::IntLiteral(2)),
1160 right: Box::new(HirExpression::IntLiteral(3)),
1161 },
1162 body: vec![HirStatement::Assignment {
1163 target: "x".to_string(),
1164 value: HirExpression::BinaryOp {
1165 op: BinaryOperator::Add,
1166 left: Box::new(HirExpression::IntLiteral(1)),
1167 right: Box::new(HirExpression::IntLiteral(1)),
1168 },
1169 }],
1170 };
1171 match fold_constants_stmt(stmt) {
1172 HirStatement::While { condition, body } => {
1173 assert_eq!(condition, HirExpression::IntLiteral(5));
1174 match &body[0] {
1175 HirStatement::Assignment { value, .. } => {
1176 assert_eq!(*value, HirExpression::IntLiteral(2));
1177 }
1178 other => panic!("Expected Assignment, got {:?}", other),
1179 }
1180 }
1181 other => panic!("Expected While, got {:?}", other),
1182 }
1183 }
1184
1185 #[test]
1186 fn test_fold_constants_stmt_return_some() {
1187 let stmt = HirStatement::Return(Some(HirExpression::BinaryOp {
1188 op: BinaryOperator::Add,
1189 left: Box::new(HirExpression::IntLiteral(10)),
1190 right: Box::new(HirExpression::IntLiteral(20)),
1191 }));
1192 match fold_constants_stmt(stmt) {
1193 HirStatement::Return(Some(HirExpression::IntLiteral(30))) => {}
1194 other => panic!("Expected Return(Some(30)), got {:?}", other),
1195 }
1196 }
1197
1198 #[test]
1199 fn test_fold_constants_stmt_return_none() {
1200 let stmt = HirStatement::Return(None);
1201 assert_eq!(fold_constants_stmt(stmt), HirStatement::Return(None));
1202 }
1203
1204 #[test]
1205 fn test_fold_constants_stmt_var_decl_none_init() {
1206 let stmt = HirStatement::VariableDeclaration {
1207 name: "x".to_string(),
1208 var_type: HirType::Int,
1209 initializer: None,
1210 };
1211 match fold_constants_stmt(stmt) {
1212 HirStatement::VariableDeclaration { initializer: None, .. } => {}
1213 other => panic!("Expected VarDecl with None init, got {:?}", other),
1214 }
1215 }
1216
/// Folding an `if` folds its condition and recurses into both the `then`
/// and the `else` blocks: `1 + 2` -> 3, `3 * 4` -> 12, `10 - 5` -> 5.
#[test]
fn test_fold_constants_stmt_if_with_else() {
    // Builds `l <op> r` over two integer literals.
    let const_binop = |op, l, r| HirExpression::BinaryOp {
        op,
        left: Box::new(HirExpression::IntLiteral(l)),
        right: Box::new(HirExpression::IntLiteral(r)),
    };
    let stmt = HirStatement::If {
        condition: const_binop(BinaryOperator::Add, 1, 2),
        then_block: vec![HirStatement::Return(Some(const_binop(
            BinaryOperator::Multiply,
            3,
            4,
        )))],
        else_block: Some(vec![HirStatement::Return(Some(const_binop(
            BinaryOperator::Subtract,
            10,
            5,
        )))]),
    };

    let (condition, then_block, else_block) = match fold_constants_stmt(stmt) {
        HirStatement::If {
            condition,
            then_block,
            else_block,
        } => (condition, then_block, else_block),
        other => panic!("Expected If, got {:?}", other),
    };

    assert_eq!(condition, HirExpression::IntLiteral(3));

    let first_then = &then_block[0];
    assert!(
        matches!(first_then, HirStatement::Return(Some(HirExpression::IntLiteral(12)))),
        "Expected Return(12), got {:?}",
        first_then
    );

    let else_stmts = else_block.unwrap();
    let first_else = &else_stmts[0];
    assert!(
        matches!(first_else, HirStatement::Return(Some(HirExpression::IntLiteral(5)))),
        "Expected Return(5), got {:?}",
        first_else
    );
}
1256
/// A call to `realloc(p, 128)` is classified as an allocation expression.
#[test]
fn test_is_allocation_expr_function_call_realloc() {
    let realloc_call = HirExpression::FunctionCall {
        function: String::from("realloc"),
        arguments: vec![
            HirExpression::Variable(String::from("p")),
            HirExpression::IntLiteral(128),
        ],
    };
    assert!(is_allocation_expr(&realloc_call));
}
1267
/// A `for` with no condition (`for (;;)`) keeps `condition: None` after folding.
#[test]
fn test_fold_constants_stmt_for_none_condition() {
    let endless_for = HirStatement::For {
        init: Vec::new(),
        condition: None,
        increment: Vec::new(),
        body: vec![HirStatement::Break],
    };
    let folded = fold_constants_stmt(endless_for);
    assert!(
        matches!(folded, HirStatement::For { condition: None, .. }),
        "Expected For with None condition, got {:?}",
        folded
    );
}
1281 }
1282
/// `count_uses` counts the reference in the `if` condition plus the one in
/// the `return` inside the `then` block, giving 2.
#[test]
fn test_count_uses_in_if_no_else() {
    let var_x = || HirExpression::Variable("x".to_string());
    let stmts = vec![HirStatement::If {
        condition: var_x(),
        then_block: vec![HirStatement::Return(Some(var_x()))],
        else_block: None,
    }];
    assert_eq!(count_uses("x", &stmts), 2);
}
1295
/// Pins current behaviour: `count_uses` reports 0 uses of `x` when `x` only
/// appears in another variable's initializer (`int y = x;`).
// NOTE(review): confirm this is intended. If `count_uses` drives temporary
// elimination, an uncounted use in an initializer could make a live variable
// look unused.
#[test]
fn test_count_uses_in_var_decl_stmt() {
    let decl_y_from_x = HirStatement::VariableDeclaration {
        name: "y".to_string(),
        var_type: HirType::Int,
        initializer: Some(HirExpression::Variable("x".to_string())),
    };
    let stmts = vec![decl_y_from_x];
    assert_eq!(count_uses("x", &stmts), 0);
}
1307
/// Pins current behaviour: `count_uses` reports 0 uses of `x` for a `while`
/// loop, even though `x` appears in its condition and body.
// NOTE(review): confirm that not descending into `while` loops is intentional.
#[test]
fn test_count_uses_in_while_stmt() {
    let var_x = || HirExpression::Variable("x".to_string());
    let loop_stmt = HirStatement::While {
        condition: var_x(),
        body: vec![HirStatement::Expression(var_x())],
    };
    let stmts = vec![loop_stmt];
    assert_eq!(count_uses("x", &stmts), 0);
}
1320
/// When the following `return` names a different variable than the
/// declared temporary, `eliminate_temporaries` keeps both statements.
#[test]
fn test_temp_elimination_name_mismatch() {
    let decl_tmp = HirStatement::VariableDeclaration {
        name: "tmp".to_string(),
        var_type: HirType::Int,
        initializer: Some(HirExpression::IntLiteral(42)),
    };
    let return_other = HirStatement::Return(Some(HirExpression::Variable("other".to_string())));
    let result = eliminate_temporaries(vec![decl_tmp, return_other]);
    assert_eq!(result.len(), 2);
}
1335
/// A declaration with no initializer is not merged into the following
/// `return x;`; both statements survive `eliminate_temporaries`.
#[test]
fn test_temp_elimination_no_initializer() {
    let bare_decl = HirStatement::VariableDeclaration {
        name: "x".to_string(),
        var_type: HirType::Int,
        initializer: None,
    };
    let return_x = HirStatement::Return(Some(HirExpression::Variable("x".to_string())));
    let result = eliminate_temporaries(vec![bare_decl, return_x]);
    assert_eq!(result.len(), 2);
}
1350
/// `remove_dead_branches` recurses into a live `if`: an inner `if (0)` with
/// no `else` is dropped from the outer `then` block, while the outer `if`
/// (non-constant condition) is kept.
#[test]
fn test_dead_branch_nested_if_recurse() {
    let dead_inner = HirStatement::If {
        condition: HirExpression::IntLiteral(0),
        then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(99)))],
        else_block: None,
    };
    let live_outer = HirStatement::If {
        condition: HirExpression::Variable("x".to_string()),
        then_block: vec![dead_inner],
        else_block: None,
    };

    let result = remove_dead_branches(vec![live_outer]);
    assert_eq!(result.len(), 1);

    if let HirStatement::If { then_block, .. } = &result[0] {
        assert!(then_block.is_empty(), "Nested dead branch should be removed");
    } else {
        panic!("Expected If, got {:?}", &result[0]);
    }
}
1372
/// A cast whose operand is a plain literal (`(int)42`) is not an allocation.
#[test]
fn test_is_allocation_expr_cast_non_alloc() {
    let literal_cast = HirExpression::Cast {
        target_type: HirType::Int,
        expr: Box::new(HirExpression::IntLiteral(42)),
    };
    let classified_as_alloc = is_allocation_expr(&literal_cast);
    assert!(!classified_as_alloc);
}
1381
/// A declaration's constant initializer `2 + 3` is folded to `5`, and the
/// declared name is preserved.
#[test]
fn test_fold_constants_stmt_var_decl_with_foldable_init() {
    let two_plus_three = HirExpression::BinaryOp {
        op: BinaryOperator::Add,
        left: Box::new(HirExpression::IntLiteral(2)),
        right: Box::new(HirExpression::IntLiteral(3)),
    };
    let folded = fold_constants_stmt(HirStatement::VariableDeclaration {
        name: "x".to_string(),
        var_type: HirType::Int,
        initializer: Some(two_plus_three),
    });

    if let HirStatement::VariableDeclaration {
        name,
        initializer: Some(HirExpression::IntLiteral(5)),
        ..
    } = &folded
    {
        assert_eq!(name, "x");
    } else {
        panic!("Expected VarDecl with folded init=5, got {:?}", folded);
    }
}
1409
/// Folding a `for` statement recurses into both the init and increment
/// clauses: `int i = 2 * 0` becomes `int i = 0`, and in `i = i + (1 + 1)`
/// the constant right operand becomes `2`.
#[test]
fn test_fold_constants_stmt_for_with_foldable_increment() {
    let lit = |n| Box::new(HirExpression::IntLiteral(n));
    let var_i = || HirExpression::Variable("i".to_string());

    let init_decl = HirStatement::VariableDeclaration {
        name: "i".to_string(),
        var_type: HirType::Int,
        initializer: Some(HirExpression::BinaryOp {
            op: BinaryOperator::Multiply,
            left: lit(2),
            right: lit(0),
        }),
    };
    let bump = HirStatement::Assignment {
        target: "i".to_string(),
        value: HirExpression::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(var_i()),
            right: Box::new(HirExpression::BinaryOp {
                op: BinaryOperator::Add,
                left: lit(1),
                right: lit(1),
            }),
        },
    };
    let stmt = HirStatement::For {
        init: vec![init_decl],
        condition: Some(var_i()),
        increment: vec![bump],
        body: vec![HirStatement::Break],
    };

    let (init, increment) = match fold_constants_stmt(stmt) {
        HirStatement::For {
            init, increment, ..
        } => (init, increment),
        other => panic!("Expected For, got {:?}", other),
    };

    // `2 * 0` in the init declaration folds to 0.
    let first_init = &init[0];
    assert!(
        matches!(
            first_init,
            HirStatement::VariableDeclaration {
                initializer: Some(HirExpression::IntLiteral(0)),
                ..
            }
        ),
        "Expected init folded to 0, got {:?}",
        first_init
    );

    // The constant sub-expression `1 + 1` in the increment folds to 2.
    let assigned = match &increment[0] {
        HirStatement::Assignment { value, .. } => value,
        other => panic!("Expected Assignment, got {:?}", other),
    };
    match assigned {
        HirExpression::BinaryOp { right, .. } => {
            assert_eq!(**right, HirExpression::IntLiteral(2));
        }
        other => panic!("Expected BinaryOp, got {:?}", other),
    }
}
1470
/// Pins current behaviour: `count_uses_in_stmt` reports 0 uses of `x` for a
/// `while` loop, although `x` appears in both its condition and its body.
// NOTE(review): confirm that not descending into `while` loops is intentional.
#[test]
fn test_count_uses_in_stmt_with_while() {
    let var_x = || HirExpression::Variable("x".to_string());
    let copy_x_into_y = HirStatement::Assignment {
        target: "y".to_string(),
        value: var_x(),
    };
    let stmt = HirStatement::While {
        condition: var_x(),
        body: vec![copy_x_into_y],
    };
    assert_eq!(count_uses_in_stmt("x", &stmt), 0);
}
1484
/// `count_uses_in_stmt` sums the uses in an `if` condition, its `then` block
/// and its `else` block: one each, 3 in total.
#[test]
fn test_count_uses_in_stmt_if_with_else() {
    let var_x = || HirExpression::Variable("x".to_string());
    let stmt = HirStatement::If {
        condition: var_x(),
        then_block: vec![HirStatement::Return(Some(var_x()))],
        else_block: Some(vec![HirStatement::Expression(var_x())]),
    };
    assert_eq!(count_uses_in_stmt("x", &stmt), 3);
}
1499
/// Every argument of a call is inspected: in `foo(x, y, x)` the name `x`
/// occurs twice, `y` once and `z` never.
#[test]
fn test_count_uses_in_expr_function_call() {
    let var = |name: &str| HirExpression::Variable(name.to_string());
    let call = HirExpression::FunctionCall {
        function: "foo".to_string(),
        arguments: vec![var("x"), var("y"), var("x")],
    };
    for (name, expected) in [("x", 2), ("y", 1), ("z", 0)] {
        assert_eq!(count_uses_in_expr(name, &call), expected);
    }
}
1514
/// The operand of a unary minus is inspected: `-x` uses `x` once and `y` never.
#[test]
fn test_count_uses_in_expr_unary_op_negation() {
    let negated_x = HirExpression::UnaryOp {
        op: decy_hir::UnaryOperator::Minus,
        operand: Box::new(HirExpression::Variable("x".to_string())),
    };
    for (name, expected) in [("x", 1), ("y", 0)] {
        assert_eq!(count_uses_in_expr(name, &negated_x), expected);
    }
}
1524}