1#[cfg(test)]
33use decy_hir::HirType;
34use decy_hir::{BinaryOperator, HirExpression, HirFunction, HirStatement};
35
36const MAX_ITERATIONS: usize = 3;
38
39pub fn optimize_function(func: &HirFunction) -> HirFunction {
44 let mut body = func.body().to_vec();
45 let mut changed = true;
46 let mut iterations = 0;
47
48 while changed && iterations < MAX_ITERATIONS {
49 changed = false;
50 let prev = body.clone();
51
52 body = body.into_iter().map(fold_constants_stmt).collect();
53 body = remove_dead_branches(body);
54 body = eliminate_temporaries(body);
55
56 if body != prev {
57 changed = true;
58 }
59 iterations += 1;
60 }
61
62 let mut result = HirFunction::new_with_body(
63 func.name().to_string(),
64 func.return_type().clone(),
65 func.parameters().to_vec(),
66 body,
67 );
68 result.set_cuda_qualifier(func.cuda_qualifier());
70 result
71}
72
73fn fold_constants_stmt(stmt: HirStatement) -> HirStatement {
79 match stmt {
80 HirStatement::VariableDeclaration { name, var_type, initializer } => {
81 HirStatement::VariableDeclaration {
82 name,
83 var_type,
84 initializer: initializer.map(fold_constants_expr),
85 }
86 }
87 HirStatement::Return(expr) => HirStatement::Return(expr.map(fold_constants_expr)),
88 HirStatement::Assignment { target, value } => {
89 HirStatement::Assignment { target, value: fold_constants_expr(value) }
90 }
91 HirStatement::If { condition, then_block, else_block } => HirStatement::If {
92 condition: fold_constants_expr(condition),
93 then_block: then_block.into_iter().map(fold_constants_stmt).collect(),
94 else_block: else_block
95 .map(|block| block.into_iter().map(fold_constants_stmt).collect()),
96 },
97 HirStatement::While { condition, body } => HirStatement::While {
98 condition: fold_constants_expr(condition),
99 body: body.into_iter().map(fold_constants_stmt).collect(),
100 },
101 HirStatement::For { init, condition, increment, body } => HirStatement::For {
102 init: init.into_iter().map(fold_constants_stmt).collect(),
103 condition: condition.map(fold_constants_expr),
104 increment: increment.into_iter().map(fold_constants_stmt).collect(),
105 body: body.into_iter().map(fold_constants_stmt).collect(),
106 },
107 HirStatement::Expression(expr) => HirStatement::Expression(fold_constants_expr(expr)),
108 other => other,
110 }
111}
112
113fn fold_constants_expr(expr: HirExpression) -> HirExpression {
115 match expr {
116 HirExpression::BinaryOp { op, left, right } => {
117 let left = fold_constants_expr(*left);
118 let right = fold_constants_expr(*right);
119
120 if let (HirExpression::IntLiteral(l), HirExpression::IntLiteral(r)) = (&left, &right) {
122 if let Some(result) = fold_int_binary(*l, op, *r) {
123 return HirExpression::IntLiteral(result);
124 }
125 }
126
127 HirExpression::BinaryOp { op, left: Box::new(left), right: Box::new(right) }
128 }
129 HirExpression::UnaryOp { op, operand } => {
130 let operand = fold_constants_expr(*operand);
131 if let (decy_hir::UnaryOperator::Minus, HirExpression::IntLiteral(v)) = (op, &operand) {
132 return HirExpression::IntLiteral(-v);
133 }
134 HirExpression::UnaryOp { op, operand: Box::new(operand) }
135 }
136 HirExpression::FunctionCall { function, arguments } => HirExpression::FunctionCall {
138 function,
139 arguments: arguments.into_iter().map(fold_constants_expr).collect(),
140 },
141 other => other,
142 }
143}
144
145fn fold_int_binary(left: i32, op: BinaryOperator, right: i32) -> Option<i32> {
147 match op {
148 BinaryOperator::Add => left.checked_add(right),
149 BinaryOperator::Subtract => left.checked_sub(right),
150 BinaryOperator::Multiply => left.checked_mul(right),
151 BinaryOperator::Divide => {
152 if right != 0 {
153 left.checked_div(right)
154 } else {
155 None
156 }
157 }
158 BinaryOperator::Modulo => {
159 if right != 0 {
160 left.checked_rem(right)
161 } else {
162 None
163 }
164 }
165 BinaryOperator::LeftShift => {
166 if (0..32).contains(&right) {
167 Some(left << right)
168 } else {
169 None
170 }
171 }
172 BinaryOperator::RightShift => {
173 if (0..32).contains(&right) {
174 Some(left >> right)
175 } else {
176 None
177 }
178 }
179 BinaryOperator::BitwiseAnd => Some(left & right),
180 BinaryOperator::BitwiseOr => Some(left | right),
181 BinaryOperator::BitwiseXor => Some(left ^ right),
182 _ => None,
183 }
184}
185
186fn remove_dead_branches(stmts: Vec<HirStatement>) -> Vec<HirStatement> {
192 let mut result = Vec::new();
193
194 for stmt in stmts {
195 match stmt {
196 HirStatement::If { condition, then_block, else_block } => {
197 if let Some(always_true) = is_constant_truthy(&condition) {
198 if always_true {
199 result.extend(
201 then_block.into_iter().map(fold_constants_stmt).collect::<Vec<_>>(),
202 );
203 } else if let Some(else_body) = else_block {
204 result.extend(
206 else_body.into_iter().map(fold_constants_stmt).collect::<Vec<_>>(),
207 );
208 }
209 } else {
211 result.push(HirStatement::If {
213 condition,
214 then_block: remove_dead_branches(then_block),
215 else_block: else_block.map(remove_dead_branches),
216 });
217 }
218 }
219 HirStatement::While { condition, body } => {
220 if let Some(false) = is_constant_truthy(&condition) {
222 } else {
224 result
225 .push(HirStatement::While { condition, body: remove_dead_branches(body) });
226 }
227 }
228 other => result.push(other),
229 }
230 }
231
232 result
233}
234
235fn is_constant_truthy(expr: &HirExpression) -> Option<bool> {
237 match expr {
238 HirExpression::IntLiteral(0) => Some(false),
239 HirExpression::IntLiteral(_) => Some(true),
240 _ => None,
241 }
242}
243
244fn eliminate_temporaries(stmts: Vec<HirStatement>) -> Vec<HirStatement> {
255 if stmts.len() < 2 {
256 return stmts;
257 }
258
259 let mut result = Vec::new();
260 let mut skip_next = false;
261
262 for i in 0..stmts.len() {
263 if skip_next {
264 skip_next = false;
265 continue;
266 }
267
268 if i + 1 < stmts.len() {
270 if let (
271 HirStatement::VariableDeclaration { name, initializer: Some(init_expr), .. },
272 HirStatement::Return(Some(HirExpression::Variable(ret_var))),
273 ) = (&stmts[i], &stmts[i + 1])
274 {
275 if name == ret_var
276 && count_uses(name, &stmts[i + 2..]) == 0
277 && !is_allocation_expr(init_expr)
278 {
279 result.push(HirStatement::Return(Some(init_expr.clone())));
281 skip_next = true;
282 continue;
283 }
284 }
285 }
286
287 result.push(stmts[i].clone());
288 }
289
290 result
291}
292
293fn is_allocation_expr(expr: &HirExpression) -> bool {
296 match expr {
297 HirExpression::Malloc { .. }
298 | HirExpression::Calloc { .. }
299 | HirExpression::Realloc { .. } => true,
300 HirExpression::Cast { expr: inner, .. } => is_allocation_expr(inner),
301 HirExpression::FunctionCall { function, .. } => {
302 matches!(function.as_str(), "malloc" | "calloc" | "realloc")
303 }
304 _ => false,
305 }
306}
307
308fn count_uses(name: &str, stmts: &[HirStatement]) -> usize {
310 let mut count = 0;
311 for stmt in stmts {
312 count += count_uses_in_stmt(name, stmt);
313 }
314 count
315}
316
317fn count_uses_in_stmt(name: &str, stmt: &HirStatement) -> usize {
319 match stmt {
320 HirStatement::Return(Some(expr)) => count_uses_in_expr(name, expr),
321 HirStatement::Assignment { value, .. } => count_uses_in_expr(name, value),
322 HirStatement::Expression(expr) => count_uses_in_expr(name, expr),
323 HirStatement::If { condition, then_block, else_block } => {
324 let mut c = count_uses_in_expr(name, condition);
325 for s in then_block {
326 c += count_uses_in_stmt(name, s);
327 }
328 if let Some(block) = else_block {
329 for s in block {
330 c += count_uses_in_stmt(name, s);
331 }
332 }
333 c
334 }
335 _ => 0,
336 }
337}
338
339fn count_uses_in_expr(name: &str, expr: &HirExpression) -> usize {
341 match expr {
342 HirExpression::Variable(v) if v == name => 1,
343 HirExpression::BinaryOp { left, right, .. } => {
344 count_uses_in_expr(name, left) + count_uses_in_expr(name, right)
345 }
346 HirExpression::FunctionCall { arguments, .. } => {
347 arguments.iter().map(|a| count_uses_in_expr(name, a)).sum()
348 }
349 HirExpression::UnaryOp { operand, .. } => count_uses_in_expr(name, operand),
350 _ => 0,
351 }
352}
353
354#[cfg(test)]
355mod tests {
356 use super::*;
357
358 #[test]
359 fn test_constant_folding_add() {
360 let expr = HirExpression::BinaryOp {
361 op: BinaryOperator::Add,
362 left: Box::new(HirExpression::IntLiteral(2)),
363 right: Box::new(HirExpression::IntLiteral(3)),
364 };
365 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(5));
366 }
367
368 #[test]
369 fn test_constant_folding_multiply() {
370 let expr = HirExpression::BinaryOp {
371 op: BinaryOperator::Multiply,
372 left: Box::new(HirExpression::IntLiteral(4)),
373 right: Box::new(HirExpression::IntLiteral(5)),
374 };
375 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(20));
376 }
377
378 #[test]
379 fn test_constant_folding_nested() {
380 let expr = HirExpression::BinaryOp {
382 op: BinaryOperator::Multiply,
383 left: Box::new(HirExpression::BinaryOp {
384 op: BinaryOperator::Add,
385 left: Box::new(HirExpression::IntLiteral(2)),
386 right: Box::new(HirExpression::IntLiteral(3)),
387 }),
388 right: Box::new(HirExpression::IntLiteral(4)),
389 };
390 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(20));
391 }
392
393 #[test]
394 fn test_constant_folding_division_by_zero() {
395 let expr = HirExpression::BinaryOp {
396 op: BinaryOperator::Divide,
397 left: Box::new(HirExpression::IntLiteral(10)),
398 right: Box::new(HirExpression::IntLiteral(0)),
399 };
400 match fold_constants_expr(expr) {
402 HirExpression::BinaryOp { .. } => {} other => panic!("Expected BinaryOp, got {:?}", other),
404 }
405 }
406
407 #[test]
408 fn test_constant_folding_non_literal() {
409 let expr = HirExpression::BinaryOp {
410 op: BinaryOperator::Add,
411 left: Box::new(HirExpression::Variable("x".to_string())),
412 right: Box::new(HirExpression::IntLiteral(3)),
413 };
414 match fold_constants_expr(expr) {
416 HirExpression::BinaryOp { .. } => {} other => panic!("Expected BinaryOp, got {:?}", other),
418 }
419 }
420
421 #[test]
422 fn test_dead_branch_removal_true() {
423 let stmts = vec![HirStatement::If {
425 condition: HirExpression::IntLiteral(1),
426 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(42)))],
427 else_block: None,
428 }];
429
430 let result = remove_dead_branches(stmts);
431 assert_eq!(result.len(), 1);
432 assert_eq!(result[0], HirStatement::Return(Some(HirExpression::IntLiteral(42))));
433 }
434
435 #[test]
436 fn test_dead_branch_removal_false_no_else() {
437 let stmts = vec![HirStatement::If {
439 condition: HirExpression::IntLiteral(0),
440 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(42)))],
441 else_block: None,
442 }];
443
444 let result = remove_dead_branches(stmts);
445 assert!(result.is_empty());
446 }
447
448 #[test]
449 fn test_dead_branch_removal_false_with_else() {
450 let stmts = vec![HirStatement::If {
452 condition: HirExpression::IntLiteral(0),
453 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(42)))],
454 else_block: Some(vec![HirStatement::Return(Some(HirExpression::IntLiteral(99)))]),
455 }];
456
457 let result = remove_dead_branches(stmts);
458 assert_eq!(result.len(), 1);
459 assert_eq!(result[0], HirStatement::Return(Some(HirExpression::IntLiteral(99))));
460 }
461
462 #[test]
463 fn test_dead_while_zero() {
464 let stmts = vec![HirStatement::While {
466 condition: HirExpression::IntLiteral(0),
467 body: vec![HirStatement::Break],
468 }];
469
470 let result = remove_dead_branches(stmts);
471 assert!(result.is_empty());
472 }
473
474 #[test]
475 fn test_temp_elimination_return() {
476 let stmts = vec![
478 HirStatement::VariableDeclaration {
479 name: "tmp".to_string(),
480 var_type: HirType::Int,
481 initializer: Some(HirExpression::IntLiteral(42)),
482 },
483 HirStatement::Return(Some(HirExpression::Variable("tmp".to_string()))),
484 ];
485
486 let result = eliminate_temporaries(stmts);
487 assert_eq!(result.len(), 1);
488 assert_eq!(result[0], HirStatement::Return(Some(HirExpression::IntLiteral(42))));
489 }
490
491 #[test]
492 fn test_optimize_function_combined() {
493 let func = HirFunction::new_with_body(
495 "test".to_string(),
496 HirType::Int,
497 vec![],
498 vec![
499 HirStatement::VariableDeclaration {
500 name: "x".to_string(),
501 var_type: HirType::Int,
502 initializer: Some(HirExpression::BinaryOp {
503 op: BinaryOperator::Add,
504 left: Box::new(HirExpression::IntLiteral(2)),
505 right: Box::new(HirExpression::IntLiteral(3)),
506 }),
507 },
508 HirStatement::If {
509 condition: HirExpression::IntLiteral(1),
510 then_block: vec![HirStatement::Return(Some(HirExpression::Variable(
511 "x".to_string(),
512 )))],
513 else_block: None,
514 },
515 ],
516 );
517
518 let optimized = optimize_function(&func);
519 let body = optimized.body();
522 assert!(body.len() <= 2, "Expected at most 2 statements, got {}", body.len());
523 }
524
525 #[test]
526 fn test_unary_minus_folding() {
527 let expr = HirExpression::UnaryOp {
528 op: decy_hir::UnaryOperator::Minus,
529 operand: Box::new(HirExpression::IntLiteral(42)),
530 };
531 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(-42));
532 }
533
534 #[test]
535 fn test_bitwise_folding() {
536 let expr = HirExpression::BinaryOp {
537 op: BinaryOperator::BitwiseAnd,
538 left: Box::new(HirExpression::IntLiteral(0xFF)),
539 right: Box::new(HirExpression::IntLiteral(0x0F)),
540 };
541 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(0x0F));
542 }
543
544 #[test]
549 fn test_constant_folding_subtract() {
550 let expr = HirExpression::BinaryOp {
551 op: BinaryOperator::Subtract,
552 left: Box::new(HirExpression::IntLiteral(10)),
553 right: Box::new(HirExpression::IntLiteral(3)),
554 };
555 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(7));
556 }
557
558 #[test]
559 fn test_constant_folding_divide() {
560 let expr = HirExpression::BinaryOp {
561 op: BinaryOperator::Divide,
562 left: Box::new(HirExpression::IntLiteral(20)),
563 right: Box::new(HirExpression::IntLiteral(4)),
564 };
565 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(5));
566 }
567
568 #[test]
569 fn test_constant_folding_modulo() {
570 let expr = HirExpression::BinaryOp {
571 op: BinaryOperator::Modulo,
572 left: Box::new(HirExpression::IntLiteral(17)),
573 right: Box::new(HirExpression::IntLiteral(5)),
574 };
575 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(2));
576 }
577
578 #[test]
579 fn test_constant_folding_modulo_by_zero() {
580 let expr = HirExpression::BinaryOp {
581 op: BinaryOperator::Modulo,
582 left: Box::new(HirExpression::IntLiteral(17)),
583 right: Box::new(HirExpression::IntLiteral(0)),
584 };
585 match fold_constants_expr(expr) {
586 HirExpression::BinaryOp { .. } => {}
587 other => panic!("Expected BinaryOp, got {:?}", other),
588 }
589 }
590
591 #[test]
592 fn test_constant_folding_left_shift() {
593 let expr = HirExpression::BinaryOp {
594 op: BinaryOperator::LeftShift,
595 left: Box::new(HirExpression::IntLiteral(1)),
596 right: Box::new(HirExpression::IntLiteral(4)),
597 };
598 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(16));
599 }
600
601 #[test]
602 fn test_constant_folding_left_shift_overflow() {
603 let expr = HirExpression::BinaryOp {
604 op: BinaryOperator::LeftShift,
605 left: Box::new(HirExpression::IntLiteral(1)),
606 right: Box::new(HirExpression::IntLiteral(32)),
607 };
608 match fold_constants_expr(expr) {
610 HirExpression::BinaryOp { .. } => {}
611 other => panic!("Expected BinaryOp, got {:?}", other),
612 }
613 }
614
615 #[test]
616 fn test_constant_folding_right_shift() {
617 let expr = HirExpression::BinaryOp {
618 op: BinaryOperator::RightShift,
619 left: Box::new(HirExpression::IntLiteral(16)),
620 right: Box::new(HirExpression::IntLiteral(2)),
621 };
622 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(4));
623 }
624
625 #[test]
626 fn test_constant_folding_right_shift_overflow() {
627 let expr = HirExpression::BinaryOp {
628 op: BinaryOperator::RightShift,
629 left: Box::new(HirExpression::IntLiteral(16)),
630 right: Box::new(HirExpression::IntLiteral(-1)),
631 };
632 match fold_constants_expr(expr) {
633 HirExpression::BinaryOp { .. } => {}
634 other => panic!("Expected BinaryOp, got {:?}", other),
635 }
636 }
637
638 #[test]
639 fn test_constant_folding_bitwise_or() {
640 let expr = HirExpression::BinaryOp {
641 op: BinaryOperator::BitwiseOr,
642 left: Box::new(HirExpression::IntLiteral(0xF0)),
643 right: Box::new(HirExpression::IntLiteral(0x0F)),
644 };
645 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(0xFF));
646 }
647
648 #[test]
649 fn test_constant_folding_bitwise_xor() {
650 let expr = HirExpression::BinaryOp {
651 op: BinaryOperator::BitwiseXor,
652 left: Box::new(HirExpression::IntLiteral(0xFF)),
653 right: Box::new(HirExpression::IntLiteral(0x0F)),
654 };
655 assert_eq!(fold_constants_expr(expr), HirExpression::IntLiteral(0xF0));
656 }
657
658 #[test]
659 fn test_constant_folding_unsupported_op() {
660 let expr = HirExpression::BinaryOp {
662 op: BinaryOperator::Equal,
663 left: Box::new(HirExpression::IntLiteral(5)),
664 right: Box::new(HirExpression::IntLiteral(5)),
665 };
666 match fold_constants_expr(expr) {
667 HirExpression::BinaryOp { .. } => {}
668 other => panic!("Expected BinaryOp, got {:?}", other),
669 }
670 }
671
672 #[test]
677 fn test_constant_folding_function_call_args() {
678 let expr = HirExpression::FunctionCall {
680 function: "foo".to_string(),
681 arguments: vec![HirExpression::BinaryOp {
682 op: BinaryOperator::Add,
683 left: Box::new(HirExpression::IntLiteral(2)),
684 right: Box::new(HirExpression::IntLiteral(3)),
685 }],
686 };
687 match fold_constants_expr(expr) {
688 HirExpression::FunctionCall { function, arguments } => {
689 assert_eq!(function, "foo");
690 assert_eq!(arguments, vec![HirExpression::IntLiteral(5)]);
691 }
692 other => panic!("Expected FunctionCall, got {:?}", other),
693 }
694 }
695
696 #[test]
701 fn test_constant_folding_for_loop() {
702 let stmt = HirStatement::For {
704 init: vec![HirStatement::VariableDeclaration {
705 name: "i".to_string(),
706 var_type: HirType::Int,
707 initializer: Some(HirExpression::IntLiteral(0)),
708 }],
709 condition: Some(HirExpression::BinaryOp {
710 op: BinaryOperator::Add,
711 left: Box::new(HirExpression::IntLiteral(2)),
712 right: Box::new(HirExpression::IntLiteral(3)),
713 }),
714 increment: vec![HirStatement::Expression(HirExpression::Variable("i".to_string()))],
715 body: vec![HirStatement::Return(Some(HirExpression::IntLiteral(1)))],
716 };
717
718 match fold_constants_stmt(stmt) {
719 HirStatement::For { condition, body, .. } => {
720 assert_eq!(condition, Some(HirExpression::IntLiteral(5)));
721 assert!(!body.is_empty());
722 }
723 other => panic!("Expected For, got {:?}", other),
724 }
725 }
726
727 #[test]
732 fn test_constant_folding_expression_stmt() {
733 let stmt = HirStatement::Expression(HirExpression::BinaryOp {
734 op: BinaryOperator::Add,
735 left: Box::new(HirExpression::IntLiteral(1)),
736 right: Box::new(HirExpression::IntLiteral(2)),
737 });
738 match fold_constants_stmt(stmt) {
739 HirStatement::Expression(HirExpression::IntLiteral(3)) => {}
740 other => panic!("Expected Expression(IntLiteral(3)), got {:?}", other),
741 }
742 }
743
744 #[test]
745 fn test_constant_folding_assignment() {
746 let stmt = HirStatement::Assignment {
747 target: "x".to_string(),
748 value: HirExpression::BinaryOp {
749 op: BinaryOperator::Multiply,
750 left: Box::new(HirExpression::IntLiteral(6)),
751 right: Box::new(HirExpression::IntLiteral(7)),
752 },
753 };
754 match fold_constants_stmt(stmt) {
755 HirStatement::Assignment { value, .. } => {
756 assert_eq!(value, HirExpression::IntLiteral(42));
757 }
758 other => panic!("Expected Assignment, got {:?}", other),
759 }
760 }
761
762 #[test]
763 fn test_constant_folding_pass_through() {
764 let stmt = HirStatement::Break;
766 assert_eq!(fold_constants_stmt(stmt), HirStatement::Break);
767 }
768
769 #[test]
774 fn test_unary_not_folding_not_applied() {
775 let expr = HirExpression::UnaryOp {
777 op: decy_hir::UnaryOperator::LogicalNot,
778 operand: Box::new(HirExpression::IntLiteral(1)),
779 };
780 match fold_constants_expr(expr) {
781 HirExpression::UnaryOp { .. } => {}
782 other => panic!("Expected UnaryOp, got {:?}", other),
783 }
784 }
785
786 #[test]
787 fn test_unary_minus_on_variable() {
788 let expr = HirExpression::UnaryOp {
790 op: decy_hir::UnaryOperator::Minus,
791 operand: Box::new(HirExpression::Variable("x".to_string())),
792 };
793 match fold_constants_expr(expr) {
794 HirExpression::UnaryOp { .. } => {}
795 other => panic!("Expected UnaryOp, got {:?}", other),
796 }
797 }
798
799 #[test]
804 fn test_is_allocation_expr_malloc() {
805 assert!(is_allocation_expr(&HirExpression::Malloc {
806 size: Box::new(HirExpression::IntLiteral(4)),
807 }));
808 }
809
810 #[test]
811 fn test_is_allocation_expr_calloc() {
812 assert!(is_allocation_expr(&HirExpression::Calloc {
813 count: Box::new(HirExpression::IntLiteral(10)),
814 element_type: Box::new(HirType::Int),
815 }));
816 }
817
818 #[test]
819 fn test_is_allocation_expr_realloc() {
820 assert!(is_allocation_expr(&HirExpression::Realloc {
821 pointer: Box::new(HirExpression::Variable("p".to_string())),
822 new_size: Box::new(HirExpression::IntLiteral(64)),
823 }));
824 }
825
826 #[test]
827 fn test_is_allocation_expr_cast_wrapping_malloc() {
828 assert!(is_allocation_expr(&HirExpression::Cast {
829 target_type: HirType::Pointer(Box::new(HirType::Int)),
830 expr: Box::new(HirExpression::Malloc { size: Box::new(HirExpression::IntLiteral(4)) }),
831 }));
832 }
833
834 #[test]
835 fn test_is_allocation_expr_function_call_malloc() {
836 assert!(is_allocation_expr(&HirExpression::FunctionCall {
837 function: "malloc".to_string(),
838 arguments: vec![HirExpression::IntLiteral(4)],
839 }));
840 }
841
842 #[test]
843 fn test_is_allocation_expr_function_call_calloc() {
844 assert!(is_allocation_expr(&HirExpression::FunctionCall {
845 function: "calloc".to_string(),
846 arguments: vec![HirExpression::IntLiteral(10), HirExpression::IntLiteral(4)],
847 }));
848 }
849
850 #[test]
851 fn test_is_allocation_expr_regular_call() {
852 assert!(!is_allocation_expr(&HirExpression::FunctionCall {
853 function: "printf".to_string(),
854 arguments: vec![],
855 }));
856 }
857
858 #[test]
859 fn test_is_allocation_expr_literal() {
860 assert!(!is_allocation_expr(&HirExpression::IntLiteral(42)));
861 }
862
863 #[test]
868 fn test_count_uses_empty() {
869 assert_eq!(count_uses("x", &[]), 0);
870 }
871
872 #[test]
873 fn test_count_uses_in_return() {
874 let stmts = vec![HirStatement::Return(Some(HirExpression::Variable("x".to_string())))];
875 assert_eq!(count_uses("x", &stmts), 1);
876 assert_eq!(count_uses("y", &stmts), 0);
877 }
878
879 #[test]
880 fn test_count_uses_in_assignment() {
881 let stmts = vec![HirStatement::Assignment {
882 target: "y".to_string(),
883 value: HirExpression::Variable("x".to_string()),
884 }];
885 assert_eq!(count_uses("x", &stmts), 1);
886 }
887
888 #[test]
889 fn test_count_uses_in_expression_stmt() {
890 let stmts = vec![HirStatement::Expression(HirExpression::FunctionCall {
891 function: "foo".to_string(),
892 arguments: vec![
893 HirExpression::Variable("x".to_string()),
894 HirExpression::Variable("x".to_string()),
895 ],
896 })];
897 assert_eq!(count_uses("x", &stmts), 2);
898 }
899
900 #[test]
901 fn test_count_uses_in_if_with_else() {
902 let stmts = vec![HirStatement::If {
903 condition: HirExpression::Variable("x".to_string()),
904 then_block: vec![HirStatement::Return(Some(HirExpression::Variable("x".to_string())))],
905 else_block: Some(vec![HirStatement::Return(Some(HirExpression::Variable(
906 "x".to_string(),
907 )))]),
908 }];
909 assert_eq!(count_uses("x", &stmts), 3); }
911
912 #[test]
913 fn test_count_uses_in_expr_binary_op() {
914 let expr = HirExpression::BinaryOp {
915 op: BinaryOperator::Add,
916 left: Box::new(HirExpression::Variable("x".to_string())),
917 right: Box::new(HirExpression::Variable("x".to_string())),
918 };
919 assert_eq!(count_uses_in_expr("x", &expr), 2);
920 }
921
922 #[test]
923 fn test_count_uses_in_expr_unary_op() {
924 let expr = HirExpression::UnaryOp {
925 op: decy_hir::UnaryOperator::Minus,
926 operand: Box::new(HirExpression::Variable("x".to_string())),
927 };
928 assert_eq!(count_uses_in_expr("x", &expr), 1);
929 }
930
931 #[test]
932 fn test_count_uses_in_expr_non_matching() {
933 let expr = HirExpression::IntLiteral(42);
934 assert_eq!(count_uses_in_expr("x", &expr), 0);
935 }
936
937 #[test]
938 fn test_count_uses_in_stmt_break() {
939 assert_eq!(count_uses_in_stmt("x", &HirStatement::Break), 0);
940 }
941
942 #[test]
943 fn test_count_uses_in_stmt_return_none() {
944 assert_eq!(count_uses_in_stmt("x", &HirStatement::Return(None)), 0);
945 }
946
947 #[test]
952 fn test_temp_elimination_single_stmt() {
953 let stmts = vec![HirStatement::Return(Some(HirExpression::IntLiteral(1)))];
955 let result = eliminate_temporaries(stmts.clone());
956 assert_eq!(result, stmts);
957 }
958
959 #[test]
960 fn test_temp_elimination_no_match() {
961 let stmts = vec![
963 HirStatement::Expression(HirExpression::IntLiteral(1)),
964 HirStatement::Return(Some(HirExpression::IntLiteral(2))),
965 ];
966 let result = eliminate_temporaries(stmts.clone());
967 assert_eq!(result, stmts);
968 }
969
970 #[test]
971 fn test_temp_elimination_allocation_preserved() {
972 let stmts = vec![
974 HirStatement::VariableDeclaration {
975 name: "p".to_string(),
976 var_type: HirType::Pointer(Box::new(HirType::Int)),
977 initializer: Some(HirExpression::Malloc {
978 size: Box::new(HirExpression::IntLiteral(4)),
979 }),
980 },
981 HirStatement::Return(Some(HirExpression::Variable("p".to_string()))),
982 ];
983 let result = eliminate_temporaries(stmts);
984 assert_eq!(result.len(), 2); }
986
987 #[test]
988 fn test_temp_elimination_multi_use_preserved() {
989 let stmts = vec![
991 HirStatement::VariableDeclaration {
992 name: "x".to_string(),
993 var_type: HirType::Int,
994 initializer: Some(HirExpression::IntLiteral(42)),
995 },
996 HirStatement::Return(Some(HirExpression::Variable("x".to_string()))),
997 HirStatement::Expression(HirExpression::Variable("x".to_string())),
998 ];
999 let result = eliminate_temporaries(stmts);
1000 assert_eq!(result.len(), 3); }
1002
1003 #[test]
1008 fn test_dead_branch_non_constant_if() {
1009 let stmts = vec![HirStatement::If {
1011 condition: HirExpression::Variable("x".to_string()),
1012 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(1)))],
1013 else_block: Some(vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))]),
1014 }];
1015 let result = remove_dead_branches(stmts);
1016 assert_eq!(result.len(), 1);
1017 match &result[0] {
1018 HirStatement::If { .. } => {}
1019 other => panic!("Expected If, got {:?}", other),
1020 }
1021 }
1022
1023 #[test]
1024 fn test_dead_branch_while_non_constant() {
1025 let stmts = vec![HirStatement::While {
1027 condition: HirExpression::Variable("x".to_string()),
1028 body: vec![HirStatement::Break],
1029 }];
1030 let result = remove_dead_branches(stmts);
1031 assert_eq!(result.len(), 1);
1032 match &result[0] {
1033 HirStatement::While { .. } => {}
1034 other => panic!("Expected While, got {:?}", other),
1035 }
1036 }
1037
1038 #[test]
1039 fn test_dead_branch_while_nonzero_constant() {
1040 let stmts = vec![HirStatement::While {
1042 condition: HirExpression::IntLiteral(1),
1043 body: vec![HirStatement::Break],
1044 }];
1045 let result = remove_dead_branches(stmts);
1046 assert_eq!(result.len(), 1);
1047 }
1048
1049 #[test]
1054 fn test_optimize_no_change_single_iteration() {
1055 let func = HirFunction::new_with_body(
1057 "noop".to_string(),
1058 HirType::Int,
1059 vec![],
1060 vec![HirStatement::Return(Some(HirExpression::IntLiteral(0)))],
1061 );
1062 let optimized = optimize_function(&func);
1063 assert_eq!(optimized.body().len(), 1);
1064 }
1065
1066 #[test]
1067 fn test_optimize_empty_function() {
1068 let func = HirFunction::new_with_body("empty".to_string(), HirType::Void, vec![], vec![]);
1069 let optimized = optimize_function(&func);
1070 assert!(optimized.body().is_empty());
1071 }
1072
1073 #[test]
1078 fn test_fold_constants_stmt_while() {
1079 let stmt = HirStatement::While {
1081 condition: HirExpression::BinaryOp {
1082 op: BinaryOperator::Add,
1083 left: Box::new(HirExpression::IntLiteral(2)),
1084 right: Box::new(HirExpression::IntLiteral(3)),
1085 },
1086 body: vec![HirStatement::Assignment {
1087 target: "x".to_string(),
1088 value: HirExpression::BinaryOp {
1089 op: BinaryOperator::Add,
1090 left: Box::new(HirExpression::IntLiteral(1)),
1091 right: Box::new(HirExpression::IntLiteral(1)),
1092 },
1093 }],
1094 };
1095 match fold_constants_stmt(stmt) {
1096 HirStatement::While { condition, body } => {
1097 assert_eq!(condition, HirExpression::IntLiteral(5));
1098 match &body[0] {
1099 HirStatement::Assignment { value, .. } => {
1100 assert_eq!(*value, HirExpression::IntLiteral(2));
1101 }
1102 other => panic!("Expected Assignment, got {:?}", other),
1103 }
1104 }
1105 other => panic!("Expected While, got {:?}", other),
1106 }
1107 }
1108
1109 #[test]
1110 fn test_fold_constants_stmt_return_some() {
1111 let stmt = HirStatement::Return(Some(HirExpression::BinaryOp {
1112 op: BinaryOperator::Add,
1113 left: Box::new(HirExpression::IntLiteral(10)),
1114 right: Box::new(HirExpression::IntLiteral(20)),
1115 }));
1116 match fold_constants_stmt(stmt) {
1117 HirStatement::Return(Some(HirExpression::IntLiteral(30))) => {}
1118 other => panic!("Expected Return(Some(30)), got {:?}", other),
1119 }
1120 }
1121
1122 #[test]
1123 fn test_fold_constants_stmt_return_none() {
1124 let stmt = HirStatement::Return(None);
1125 assert_eq!(fold_constants_stmt(stmt), HirStatement::Return(None));
1126 }
1127
1128 #[test]
1129 fn test_fold_constants_stmt_var_decl_none_init() {
1130 let stmt = HirStatement::VariableDeclaration {
1131 name: "x".to_string(),
1132 var_type: HirType::Int,
1133 initializer: None,
1134 };
1135 match fold_constants_stmt(stmt) {
1136 HirStatement::VariableDeclaration { initializer: None, .. } => {}
1137 other => panic!("Expected VarDecl with None init, got {:?}", other),
1138 }
1139 }
1140
1141 #[test]
1142 fn test_fold_constants_stmt_if_with_else() {
1143 let stmt = HirStatement::If {
1144 condition: HirExpression::BinaryOp {
1145 op: BinaryOperator::Add,
1146 left: Box::new(HirExpression::IntLiteral(1)),
1147 right: Box::new(HirExpression::IntLiteral(2)),
1148 },
1149 then_block: vec![HirStatement::Return(Some(HirExpression::BinaryOp {
1150 op: BinaryOperator::Multiply,
1151 left: Box::new(HirExpression::IntLiteral(3)),
1152 right: Box::new(HirExpression::IntLiteral(4)),
1153 }))],
1154 else_block: Some(vec![HirStatement::Return(Some(HirExpression::BinaryOp {
1155 op: BinaryOperator::Subtract,
1156 left: Box::new(HirExpression::IntLiteral(10)),
1157 right: Box::new(HirExpression::IntLiteral(5)),
1158 }))]),
1159 };
1160 match fold_constants_stmt(stmt) {
1161 HirStatement::If { condition, then_block, else_block } => {
1162 assert_eq!(condition, HirExpression::IntLiteral(3));
1163 match &then_block[0] {
1164 HirStatement::Return(Some(HirExpression::IntLiteral(12))) => {}
1165 other => panic!("Expected Return(12), got {:?}", other),
1166 }
1167 let else_stmts = else_block.unwrap();
1168 match &else_stmts[0] {
1169 HirStatement::Return(Some(HirExpression::IntLiteral(5))) => {}
1170 other => panic!("Expected Return(5), got {:?}", other),
1171 }
1172 }
1173 other => panic!("Expected If, got {:?}", other),
1174 }
1175 }
1176
1177 #[test]
1178 fn test_is_allocation_expr_function_call_realloc() {
1179 assert!(is_allocation_expr(&HirExpression::FunctionCall {
1180 function: "realloc".to_string(),
1181 arguments: vec![
1182 HirExpression::Variable("p".to_string()),
1183 HirExpression::IntLiteral(128),
1184 ],
1185 }));
1186 }
1187
1188 #[test]
1189 fn test_fold_constants_stmt_for_none_condition() {
1190 let stmt = HirStatement::For {
1192 init: vec![],
1193 condition: None,
1194 increment: vec![],
1195 body: vec![HirStatement::Break],
1196 };
1197 match fold_constants_stmt(stmt) {
1198 HirStatement::For { condition: None, .. } => {}
1199 other => panic!("Expected For with None condition, got {:?}", other),
1200 }
1201 }
1202
1203 #[test]
1204 fn test_count_uses_in_if_no_else() {
1205 let stmts = vec![HirStatement::If {
1207 condition: HirExpression::Variable("x".to_string()),
1208 then_block: vec![HirStatement::Return(Some(HirExpression::Variable("x".to_string())))],
1209 else_block: None,
1210 }];
1211 assert_eq!(count_uses("x", &stmts), 2); }
1213
1214 #[test]
1215 fn test_count_uses_in_var_decl_stmt() {
1216 let stmts = vec![HirStatement::VariableDeclaration {
1218 name: "y".to_string(),
1219 var_type: HirType::Int,
1220 initializer: Some(HirExpression::Variable("x".to_string())),
1221 }];
1222 assert_eq!(count_uses("x", &stmts), 0);
1224 }
1225
1226 #[test]
1227 fn test_count_uses_in_while_stmt() {
1228 let stmts = vec![HirStatement::While {
1230 condition: HirExpression::Variable("x".to_string()),
1231 body: vec![HirStatement::Expression(HirExpression::Variable("x".to_string()))],
1232 }];
1233 assert_eq!(count_uses("x", &stmts), 0);
1235 }
1236
1237 #[test]
1238 fn test_temp_elimination_name_mismatch() {
1239 let stmts = vec![
1241 HirStatement::VariableDeclaration {
1242 name: "tmp".to_string(),
1243 var_type: HirType::Int,
1244 initializer: Some(HirExpression::IntLiteral(42)),
1245 },
1246 HirStatement::Return(Some(HirExpression::Variable("other".to_string()))),
1247 ];
1248 let result = eliminate_temporaries(stmts);
1249 assert_eq!(result.len(), 2); }
1251
1252 #[test]
1253 fn test_temp_elimination_no_initializer() {
1254 let stmts = vec![
1256 HirStatement::VariableDeclaration {
1257 name: "x".to_string(),
1258 var_type: HirType::Int,
1259 initializer: None,
1260 },
1261 HirStatement::Return(Some(HirExpression::Variable("x".to_string()))),
1262 ];
1263 let result = eliminate_temporaries(stmts);
1264 assert_eq!(result.len(), 2); }
1266
1267 #[test]
1268 fn test_dead_branch_nested_if_recurse() {
1269 let stmts = vec![HirStatement::If {
1271 condition: HirExpression::Variable("x".to_string()),
1272 then_block: vec![HirStatement::If {
1273 condition: HirExpression::IntLiteral(0),
1274 then_block: vec![HirStatement::Return(Some(HirExpression::IntLiteral(99)))],
1275 else_block: None,
1276 }],
1277 else_block: None,
1278 }];
1279 let result = remove_dead_branches(stmts);
1280 assert_eq!(result.len(), 1);
1281 match &result[0] {
1282 HirStatement::If { then_block, .. } => {
1283 assert!(then_block.is_empty(), "Nested dead branch should be removed");
1284 }
1285 other => panic!("Expected If, got {:?}", other),
1286 }
1287 }
1288
1289 #[test]
1290 fn test_is_allocation_expr_cast_non_alloc() {
1291 assert!(!is_allocation_expr(&HirExpression::Cast {
1293 target_type: HirType::Int,
1294 expr: Box::new(HirExpression::IntLiteral(42)),
1295 }));
1296 }
1297
1298 #[test]
1303 fn test_fold_constants_stmt_var_decl_with_foldable_init() {
1304 let stmt = HirStatement::VariableDeclaration {
1306 name: "x".to_string(),
1307 var_type: HirType::Int,
1308 initializer: Some(HirExpression::BinaryOp {
1309 op: BinaryOperator::Add,
1310 left: Box::new(HirExpression::IntLiteral(2)),
1311 right: Box::new(HirExpression::IntLiteral(3)),
1312 }),
1313 };
1314 match fold_constants_stmt(stmt) {
1315 HirStatement::VariableDeclaration {
1316 name,
1317 initializer: Some(HirExpression::IntLiteral(5)),
1318 ..
1319 } => {
1320 assert_eq!(name, "x");
1321 }
1322 other => panic!("Expected VarDecl with folded init=5, got {:?}", other),
1323 }
1324 }
1325
1326 #[test]
1327 fn test_fold_constants_stmt_for_with_foldable_increment() {
1328 let stmt = HirStatement::For {
1331 init: vec![HirStatement::VariableDeclaration {
1332 name: "i".to_string(),
1333 var_type: HirType::Int,
1334 initializer: Some(HirExpression::BinaryOp {
1335 op: BinaryOperator::Multiply,
1336 left: Box::new(HirExpression::IntLiteral(2)),
1337 right: Box::new(HirExpression::IntLiteral(0)),
1338 }),
1339 }],
1340 condition: Some(HirExpression::Variable("i".to_string())),
1341 increment: vec![HirStatement::Assignment {
1342 target: "i".to_string(),
1343 value: HirExpression::BinaryOp {
1344 op: BinaryOperator::Add,
1345 left: Box::new(HirExpression::Variable("i".to_string())),
1346 right: Box::new(HirExpression::BinaryOp {
1347 op: BinaryOperator::Add,
1348 left: Box::new(HirExpression::IntLiteral(1)),
1349 right: Box::new(HirExpression::IntLiteral(1)),
1350 }),
1351 },
1352 }],
1353 body: vec![HirStatement::Break],
1354 };
1355 match fold_constants_stmt(stmt) {
1356 HirStatement::For { init, increment, .. } => {
1357 match &init[0] {
1359 HirStatement::VariableDeclaration {
1360 initializer: Some(HirExpression::IntLiteral(0)),
1361 ..
1362 } => {}
1363 other => panic!("Expected init folded to 0, got {:?}", other),
1364 }
1365 match &increment[0] {
1367 HirStatement::Assignment { value, .. } => {
1368 match value {
1370 HirExpression::BinaryOp { right, .. } => {
1371 assert_eq!(**right, HirExpression::IntLiteral(2));
1372 }
1373 other => panic!("Expected BinaryOp, got {:?}", other),
1374 }
1375 }
1376 other => panic!("Expected Assignment, got {:?}", other),
1377 }
1378 }
1379 other => panic!("Expected For, got {:?}", other),
1380 }
1381 }
1382
1383 #[test]
1384 fn test_count_uses_in_stmt_with_while() {
1385 let stmt = HirStatement::While {
1387 condition: HirExpression::Variable("x".to_string()),
1388 body: vec![HirStatement::Assignment {
1389 target: "y".to_string(),
1390 value: HirExpression::Variable("x".to_string()),
1391 }],
1392 };
1393 assert_eq!(count_uses_in_stmt("x", &stmt), 0);
1395 }
1396
1397 #[test]
1398 fn test_count_uses_in_stmt_if_with_else() {
1399 let stmt = HirStatement::If {
1400 condition: HirExpression::Variable("x".to_string()),
1401 then_block: vec![HirStatement::Return(Some(HirExpression::Variable("x".to_string())))],
1402 else_block: Some(vec![HirStatement::Expression(HirExpression::Variable(
1403 "x".to_string(),
1404 ))]),
1405 };
1406 assert_eq!(count_uses_in_stmt("x", &stmt), 3);
1408 }
1409
1410 #[test]
1411 fn test_count_uses_in_expr_function_call() {
1412 let expr = HirExpression::FunctionCall {
1413 function: "foo".to_string(),
1414 arguments: vec![
1415 HirExpression::Variable("x".to_string()),
1416 HirExpression::Variable("y".to_string()),
1417 HirExpression::Variable("x".to_string()),
1418 ],
1419 };
1420 assert_eq!(count_uses_in_expr("x", &expr), 2);
1421 assert_eq!(count_uses_in_expr("y", &expr), 1);
1422 assert_eq!(count_uses_in_expr("z", &expr), 0);
1423 }
1424
1425 #[test]
1426 fn test_count_uses_in_expr_unary_op_negation() {
1427 let expr = HirExpression::UnaryOp {
1428 op: decy_hir::UnaryOperator::Minus,
1429 operand: Box::new(HirExpression::Variable("x".to_string())),
1430 };
1431 assert_eq!(count_uses_in_expr("x", &expr), 1);
1432 assert_eq!(count_uses_in_expr("y", &expr), 0);
1433 }
1434}