Skip to main content

ringkernel_ir/
optimize.rs

1//! Optimization passes for the IR.
2//!
3//! This module provides optimization passes that transform IR modules:
4//!
5//! - **Dead Code Elimination (DCE)**: Remove instructions whose results are never used
6//! - **Constant Folding**: Evaluate operations on constants at compile time
7//! - **Constant Propagation**: Replace uses of constants with their values
8//!
9//! # Example
10//!
11//! ```ignore
12//! use ringkernel_ir::{IrModule, optimize};
13//!
14//! let module = build_ir();
15//! let optimized = optimize::run_all_passes(&module);
16//! ```
17
18use std::collections::{HashMap, HashSet};
19
20use crate::{
21    BinaryOp, BlockId, CompareOp, ConstantValue, IrModule, IrNode, IrType, ScalarType, Terminator,
22    UnaryOp, ValueId,
23};
24
25// ============================================================================
26// OPTIMIZATION PASS INTERFACE
27// ============================================================================
28
29/// An optimization pass that transforms an IR module.
30pub trait OptimizationPass {
31    /// Run the optimization pass on the module.
32    fn run(&self, module: &mut IrModule) -> OptimizationResult;
33
34    /// Get the name of this pass.
35    fn name(&self) -> &'static str;
36}
37
/// Summary of what an optimization pass did to a module.
///
/// The [`Default`] value reports "nothing changed" with every counter at zero.
#[derive(Debug, Clone, Default)]
pub struct OptimizationResult {
    /// `true` if the pass altered the module in any way.
    pub changed: bool,
    /// How many instructions the pass deleted.
    pub instructions_removed: usize,
    /// How many instructions the pass rewrote in place.
    pub instructions_modified: usize,
    /// How many blocks the pass deleted.
    pub blocks_removed: usize,
}

impl OptimizationResult {
    /// Builds a result reporting that the module was left untouched.
    pub fn unchanged() -> Self {
        Self {
            changed: false,
            instructions_removed: 0,
            instructions_modified: 0,
            blocks_removed: 0,
        }
    }

    /// Builds a result that flags a change but leaves every counter at zero.
    pub fn changed() -> Self {
        Self {
            changed: true,
            instructions_removed: 0,
            instructions_modified: 0,
            blocks_removed: 0,
        }
    }

    /// Folds `other` into `self`.
    ///
    /// The `changed` flags are OR-ed together and each counter is summed.
    pub fn merge(&mut self, other: OptimizationResult) {
        let OptimizationResult {
            changed,
            instructions_removed,
            instructions_modified,
            blocks_removed,
        } = other;

        self.changed = self.changed || changed;
        self.instructions_removed += instructions_removed;
        self.instructions_modified += instructions_modified;
        self.blocks_removed += blocks_removed;
    }
}
73
74// ============================================================================
75// DEAD CODE ELIMINATION
76// ============================================================================
77
78/// Dead Code Elimination pass.
79///
80/// Removes instructions whose results are never used.
81pub struct DeadCodeElimination;
82
83impl DeadCodeElimination {
84    /// Create a new DCE pass.
85    pub fn new() -> Self {
86        Self
87    }
88
89    /// Find all values that are used in the module.
90    fn find_used_values(&self, module: &IrModule) -> HashSet<ValueId> {
91        let mut used = HashSet::new();
92
93        // Parameters are always used
94        for param in &module.parameters {
95            used.insert(param.value_id);
96        }
97
98        // Traverse all blocks
99        for block in module.blocks.values() {
100            // Collect uses from instructions
101            for inst in &block.instructions {
102                self.collect_uses(&inst.node, &mut used);
103            }
104
105            // Collect uses from terminator
106            if let Some(ref term) = block.terminator {
107                self.collect_terminator_uses(term, &mut used);
108            }
109        }
110
111        used
112    }
113
114    /// Collect all value uses from a node.
115    fn collect_uses(&self, node: &IrNode, used: &mut HashSet<ValueId>) {
116        match node {
117            IrNode::BinaryOp(_, lhs, rhs) => {
118                used.insert(*lhs);
119                used.insert(*rhs);
120            }
121            IrNode::UnaryOp(_, operand) => {
122                used.insert(*operand);
123            }
124            IrNode::Compare(_, lhs, rhs) => {
125                used.insert(*lhs);
126                used.insert(*rhs);
127            }
128            IrNode::Cast(_, value, _) => {
129                used.insert(*value);
130            }
131            IrNode::Load(ptr) => {
132                used.insert(*ptr);
133            }
134            IrNode::Store(ptr, value) => {
135                used.insert(*ptr);
136                used.insert(*value);
137            }
138            IrNode::GetElementPtr(base, indices) => {
139                used.insert(*base);
140                for idx in indices {
141                    used.insert(*idx);
142                }
143            }
144            IrNode::Select(cond, then_val, else_val) => {
145                used.insert(*cond);
146                used.insert(*then_val);
147                used.insert(*else_val);
148            }
149            IrNode::Phi(incoming) => {
150                for (_, value) in incoming {
151                    used.insert(*value);
152                }
153            }
154            IrNode::Atomic(_, ptr, value) => {
155                used.insert(*ptr);
156                used.insert(*value);
157            }
158            IrNode::AtomicCas(ptr, expected, desired) => {
159                used.insert(*ptr);
160                used.insert(*expected);
161                used.insert(*desired);
162            }
163            IrNode::WarpVote(_, pred) => {
164                used.insert(*pred);
165            }
166            IrNode::WarpShuffle(_, value, lane) => {
167                used.insert(*value);
168                used.insert(*lane);
169            }
170            IrNode::WarpReduce(_, value) => {
171                used.insert(*value);
172            }
173            IrNode::Math(_, args) => {
174                for arg in args {
175                    used.insert(*arg);
176                }
177            }
178            IrNode::Call(_, args) => {
179                for arg in args {
180                    used.insert(*arg);
181                }
182            }
183            IrNode::K2HEnqueue(value) => {
184                used.insert(*value);
185            }
186            IrNode::K2KSend(dest, msg) => {
187                used.insert(*dest);
188                used.insert(*msg);
189            }
190            IrNode::HlcUpdate(ts) => {
191                used.insert(*ts);
192            }
193            IrNode::ExtractField(value, _) => {
194                used.insert(*value);
195            }
196            IrNode::InsertField(base, _, value) => {
197                used.insert(*base);
198                used.insert(*value);
199            }
200            // No uses for these nodes
201            IrNode::Constant(_)
202            | IrNode::Parameter(_)
203            | IrNode::Undef
204            | IrNode::ThreadId(_)
205            | IrNode::BlockId(_)
206            | IrNode::BlockDim(_)
207            | IrNode::GridDim(_)
208            | IrNode::GlobalThreadId(_)
209            | IrNode::WarpId
210            | IrNode::LaneId
211            | IrNode::Barrier
212            | IrNode::MemoryFence(_)
213            | IrNode::GridSync
214            | IrNode::Alloca(_)
215            | IrNode::SharedAlloc(_, _)
216            | IrNode::H2KDequeue
217            | IrNode::H2KIsEmpty
218            | IrNode::K2KRecv
219            | IrNode::K2KTryRecv
220            | IrNode::HlcNow
221            | IrNode::HlcTick => {}
222        }
223    }
224
225    /// Collect uses from a terminator.
226    fn collect_terminator_uses(&self, term: &Terminator, used: &mut HashSet<ValueId>) {
227        match term {
228            Terminator::Return(Some(value)) => {
229                used.insert(*value);
230            }
231            Terminator::CondBranch(cond, _, _) => {
232                used.insert(*cond);
233            }
234            Terminator::Switch(value, _, _) => {
235                used.insert(*value);
236            }
237            Terminator::Return(None) | Terminator::Branch(_) | Terminator::Unreachable => {}
238        }
239    }
240
241    /// Check if an instruction has side effects and cannot be removed.
242    fn has_side_effects(&self, node: &IrNode) -> bool {
243        matches!(
244            node,
245            IrNode::Store(_, _)
246                | IrNode::Atomic(_, _, _)
247                | IrNode::AtomicCas(_, _, _)
248                | IrNode::Barrier
249                | IrNode::MemoryFence(_)
250                | IrNode::GridSync
251                | IrNode::Call(_, _)
252                | IrNode::K2HEnqueue(_)
253                | IrNode::K2KSend(_, _)
254                | IrNode::HlcTick
255                | IrNode::HlcUpdate(_)
256        )
257    }
258}
259
260impl Default for DeadCodeElimination {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266impl OptimizationPass for DeadCodeElimination {
267    fn run(&self, module: &mut IrModule) -> OptimizationResult {
268        let used = self.find_used_values(module);
269        let mut result = OptimizationResult::unchanged();
270
271        // Remove unused instructions from each block
272        for block in module.blocks.values_mut() {
273            let original_len = block.instructions.len();
274
275            block.instructions.retain(|inst| {
276                // Keep if result is used OR has side effects
277                used.contains(&inst.result) || self.has_side_effects(&inst.node)
278            });
279
280            let removed = original_len - block.instructions.len();
281            if removed > 0 {
282                result.changed = true;
283                result.instructions_removed += removed;
284            }
285        }
286
287        result
288    }
289
290    fn name(&self) -> &'static str {
291        "dead-code-elimination"
292    }
293}
294
295// ============================================================================
296// CONSTANT FOLDING
297// ============================================================================
298
299/// Constant Folding pass.
300///
301/// Evaluates operations on constants at compile time.
302pub struct ConstantFolding {
303    /// Map from value IDs to their constant values (used for incremental folding).
304    #[allow(dead_code)]
305    constants: HashMap<ValueId, ConstantValue>,
306}
307
308impl ConstantFolding {
309    /// Create a new constant folding pass.
310    pub fn new() -> Self {
311        Self {
312            constants: HashMap::new(),
313        }
314    }
315
316    /// Try to fold a binary operation.
317    fn fold_binary_op(
318        &self,
319        op: BinaryOp,
320        lhs: &ConstantValue,
321        rhs: &ConstantValue,
322    ) -> Option<ConstantValue> {
323        match (lhs, rhs) {
324            (ConstantValue::I32(l), ConstantValue::I32(r)) => {
325                Some(ConstantValue::I32(Self::fold_binary_i32(op, *l, *r)?))
326            }
327            (ConstantValue::U32(l), ConstantValue::U32(r)) => {
328                Some(ConstantValue::U32(Self::fold_binary_u32(op, *l, *r)?))
329            }
330            (ConstantValue::F32(l), ConstantValue::F32(r)) => {
331                Some(ConstantValue::F32(Self::fold_binary_f32(op, *l, *r)?))
332            }
333            (ConstantValue::I64(l), ConstantValue::I64(r)) => {
334                Some(ConstantValue::I64(Self::fold_binary_i64(op, *l, *r)?))
335            }
336            (ConstantValue::U64(l), ConstantValue::U64(r)) => {
337                Some(ConstantValue::U64(Self::fold_binary_u64(op, *l, *r)?))
338            }
339            (ConstantValue::F64(l), ConstantValue::F64(r)) => {
340                Some(ConstantValue::F64(Self::fold_binary_f64(op, *l, *r)?))
341            }
342            _ => None,
343        }
344    }
345
346    fn fold_binary_i32(op: BinaryOp, l: i32, r: i32) -> Option<i32> {
347        Some(match op {
348            BinaryOp::Add => l.wrapping_add(r),
349            BinaryOp::Sub => l.wrapping_sub(r),
350            BinaryOp::Mul => l.wrapping_mul(r),
351            BinaryOp::Div => l.checked_div(r)?,
352            BinaryOp::Rem => l.checked_rem(r)?,
353            BinaryOp::And => l & r,
354            BinaryOp::Or => l | r,
355            BinaryOp::Xor => l ^ r,
356            BinaryOp::Shl => l.wrapping_shl(r as u32),
357            BinaryOp::Shr => l.wrapping_shr(r as u32),
358            BinaryOp::Sar => l >> (r as u32),
359            BinaryOp::Min => l.min(r),
360            BinaryOp::Max => l.max(r),
361            _ => return None,
362        })
363    }
364
365    fn fold_binary_u32(op: BinaryOp, l: u32, r: u32) -> Option<u32> {
366        Some(match op {
367            BinaryOp::Add => l.wrapping_add(r),
368            BinaryOp::Sub => l.wrapping_sub(r),
369            BinaryOp::Mul => l.wrapping_mul(r),
370            BinaryOp::Div => l.checked_div(r)?,
371            BinaryOp::Rem => l.checked_rem(r)?,
372            BinaryOp::And => l & r,
373            BinaryOp::Or => l | r,
374            BinaryOp::Xor => l ^ r,
375            BinaryOp::Shl => l.wrapping_shl(r),
376            BinaryOp::Shr => l.wrapping_shr(r),
377            BinaryOp::Sar => l >> r,
378            BinaryOp::Min => l.min(r),
379            BinaryOp::Max => l.max(r),
380            _ => return None,
381        })
382    }
383
384    fn fold_binary_i64(op: BinaryOp, l: i64, r: i64) -> Option<i64> {
385        Some(match op {
386            BinaryOp::Add => l.wrapping_add(r),
387            BinaryOp::Sub => l.wrapping_sub(r),
388            BinaryOp::Mul => l.wrapping_mul(r),
389            BinaryOp::Div => l.checked_div(r)?,
390            BinaryOp::Rem => l.checked_rem(r)?,
391            BinaryOp::And => l & r,
392            BinaryOp::Or => l | r,
393            BinaryOp::Xor => l ^ r,
394            BinaryOp::Shl => l.wrapping_shl(r as u32),
395            BinaryOp::Shr => l.wrapping_shr(r as u32),
396            BinaryOp::Sar => l >> (r as u32),
397            BinaryOp::Min => l.min(r),
398            BinaryOp::Max => l.max(r),
399            _ => return None,
400        })
401    }
402
403    fn fold_binary_u64(op: BinaryOp, l: u64, r: u64) -> Option<u64> {
404        Some(match op {
405            BinaryOp::Add => l.wrapping_add(r),
406            BinaryOp::Sub => l.wrapping_sub(r),
407            BinaryOp::Mul => l.wrapping_mul(r),
408            BinaryOp::Div => l.checked_div(r)?,
409            BinaryOp::Rem => l.checked_rem(r)?,
410            BinaryOp::And => l & r,
411            BinaryOp::Or => l | r,
412            BinaryOp::Xor => l ^ r,
413            BinaryOp::Shl => l.wrapping_shl(r as u32),
414            BinaryOp::Shr => l.wrapping_shr(r as u32),
415            BinaryOp::Sar => l >> (r as u32),
416            BinaryOp::Min => l.min(r),
417            BinaryOp::Max => l.max(r),
418            _ => return None,
419        })
420    }
421
422    fn fold_binary_f32(op: BinaryOp, l: f32, r: f32) -> Option<f32> {
423        Some(match op {
424            BinaryOp::Add => l + r,
425            BinaryOp::Sub => l - r,
426            BinaryOp::Mul => l * r,
427            BinaryOp::Div => l / r,
428            BinaryOp::Rem => l % r,
429            BinaryOp::Min => l.min(r),
430            BinaryOp::Max => l.max(r),
431            BinaryOp::Pow => l.powf(r),
432            _ => return None,
433        })
434    }
435
436    fn fold_binary_f64(op: BinaryOp, l: f64, r: f64) -> Option<f64> {
437        Some(match op {
438            BinaryOp::Add => l + r,
439            BinaryOp::Sub => l - r,
440            BinaryOp::Mul => l * r,
441            BinaryOp::Div => l / r,
442            BinaryOp::Rem => l % r,
443            BinaryOp::Min => l.min(r),
444            BinaryOp::Max => l.max(r),
445            BinaryOp::Pow => l.powf(r),
446            _ => return None,
447        })
448    }
449
450    /// Try to fold a unary operation.
451    fn fold_unary_op(&self, op: UnaryOp, operand: &ConstantValue) -> Option<ConstantValue> {
452        match operand {
453            ConstantValue::I32(v) => Some(ConstantValue::I32(Self::fold_unary_i32(op, *v)?)),
454            ConstantValue::U32(v) => Some(ConstantValue::U32(Self::fold_unary_u32(op, *v)?)),
455            ConstantValue::F32(v) => Some(ConstantValue::F32(Self::fold_unary_f32(op, *v)?)),
456            ConstantValue::F64(v) => Some(ConstantValue::F64(Self::fold_unary_f64(op, *v)?)),
457            ConstantValue::Bool(v) => {
458                if op == UnaryOp::LogicalNot {
459                    Some(ConstantValue::Bool(!v))
460                } else {
461                    None
462                }
463            }
464            _ => None,
465        }
466    }
467
468    fn fold_unary_i32(op: UnaryOp, v: i32) -> Option<i32> {
469        Some(match op {
470            UnaryOp::Neg => -v,
471            UnaryOp::Not => !v,
472            UnaryOp::Abs => v.abs(),
473            UnaryOp::Sign => v.signum(),
474            _ => return None,
475        })
476    }
477
478    fn fold_unary_u32(op: UnaryOp, v: u32) -> Option<u32> {
479        Some(match op {
480            UnaryOp::Not => !v,
481            _ => return None,
482        })
483    }
484
485    fn fold_unary_f32(op: UnaryOp, v: f32) -> Option<f32> {
486        Some(match op {
487            UnaryOp::Neg => -v,
488            UnaryOp::Abs => v.abs(),
489            UnaryOp::Sqrt => v.sqrt(),
490            UnaryOp::Rsqrt => 1.0 / v.sqrt(),
491            UnaryOp::Floor => v.floor(),
492            UnaryOp::Ceil => v.ceil(),
493            UnaryOp::Round => v.round(),
494            UnaryOp::Trunc => v.trunc(),
495            UnaryOp::Sign => v.signum(),
496            _ => return None,
497        })
498    }
499
500    fn fold_unary_f64(op: UnaryOp, v: f64) -> Option<f64> {
501        Some(match op {
502            UnaryOp::Neg => -v,
503            UnaryOp::Abs => v.abs(),
504            UnaryOp::Sqrt => v.sqrt(),
505            UnaryOp::Rsqrt => 1.0 / v.sqrt(),
506            UnaryOp::Floor => v.floor(),
507            UnaryOp::Ceil => v.ceil(),
508            UnaryOp::Round => v.round(),
509            UnaryOp::Trunc => v.trunc(),
510            UnaryOp::Sign => v.signum(),
511            _ => return None,
512        })
513    }
514
515    /// Try to fold a comparison.
516    fn fold_compare(
517        &self,
518        op: CompareOp,
519        lhs: &ConstantValue,
520        rhs: &ConstantValue,
521    ) -> Option<ConstantValue> {
522        let result = match (lhs, rhs) {
523            (ConstantValue::I32(l), ConstantValue::I32(r)) => Self::compare_i32(op, *l, *r),
524            (ConstantValue::U32(l), ConstantValue::U32(r)) => Self::compare_u32(op, *l, *r),
525            (ConstantValue::F32(l), ConstantValue::F32(r)) => Self::compare_f32(op, *l, *r),
526            (ConstantValue::Bool(l), ConstantValue::Bool(r)) => match op {
527                CompareOp::Eq => *l == *r,
528                CompareOp::Ne => *l != *r,
529                _ => return None,
530            },
531            _ => return None,
532        };
533        Some(ConstantValue::Bool(result))
534    }
535
536    fn compare_i32(op: CompareOp, l: i32, r: i32) -> bool {
537        match op {
538            CompareOp::Eq => l == r,
539            CompareOp::Ne => l != r,
540            CompareOp::Lt => l < r,
541            CompareOp::Le => l <= r,
542            CompareOp::Gt => l > r,
543            CompareOp::Ge => l >= r,
544        }
545    }
546
547    fn compare_u32(op: CompareOp, l: u32, r: u32) -> bool {
548        match op {
549            CompareOp::Eq => l == r,
550            CompareOp::Ne => l != r,
551            CompareOp::Lt => l < r,
552            CompareOp::Le => l <= r,
553            CompareOp::Gt => l > r,
554            CompareOp::Ge => l >= r,
555        }
556    }
557
558    fn compare_f32(op: CompareOp, l: f32, r: f32) -> bool {
559        match op {
560            CompareOp::Eq => l == r,
561            CompareOp::Ne => l != r,
562            CompareOp::Lt => l < r,
563            CompareOp::Le => l <= r,
564            CompareOp::Gt => l > r,
565            CompareOp::Ge => l >= r,
566        }
567    }
568
569    /// Get a constant value for a value ID if available (for future use).
570    #[allow(dead_code)]
571    fn get_constant<'a>(&'a self, id: ValueId, module: &'a IrModule) -> Option<&'a ConstantValue> {
572        // First check our map
573        if let Some(c) = self.constants.get(&id) {
574            return Some(c);
575        }
576
577        // Then check if it's defined as a constant in the module
578        if let Some(value) = module.get_value(id) {
579            if let IrNode::Constant(ref c) = value.node {
580                return Some(c);
581            }
582        }
583
584        None
585    }
586}
587
588impl Default for ConstantFolding {
589    fn default() -> Self {
590        Self::new()
591    }
592}
593
594impl OptimizationPass for ConstantFolding {
595    fn run(&self, module: &mut IrModule) -> OptimizationResult {
596        let mut result = OptimizationResult::unchanged();
597        let mut constants = HashMap::new();
598
599        // First pass: collect all constants
600        for value in module.values.values() {
601            if let IrNode::Constant(ref c) = value.node {
602                constants.insert(value.id, c.clone());
603            }
604        }
605
606        // Second pass: fold operations
607        for block in module.blocks.values_mut() {
608            for inst in &mut block.instructions {
609                let folded = match &inst.node {
610                    IrNode::BinaryOp(op, lhs, rhs) => {
611                        let lhs_const = constants.get(lhs);
612                        let rhs_const = constants.get(rhs);
613
614                        if let (Some(l), Some(r)) = (lhs_const, rhs_const) {
615                            Self::new().fold_binary_op(*op, l, r)
616                        } else {
617                            None
618                        }
619                    }
620                    IrNode::UnaryOp(op, operand) => {
621                        if let Some(c) = constants.get(operand) {
622                            Self::new().fold_unary_op(*op, c)
623                        } else {
624                            None
625                        }
626                    }
627                    IrNode::Compare(op, lhs, rhs) => {
628                        let lhs_const = constants.get(lhs);
629                        let rhs_const = constants.get(rhs);
630
631                        if let (Some(l), Some(r)) = (lhs_const, rhs_const) {
632                            Self::new().fold_compare(*op, l, r)
633                        } else {
634                            None
635                        }
636                    }
637                    IrNode::Select(cond, then_val, else_val) => {
638                        if let Some(ConstantValue::Bool(c)) = constants.get(cond) {
639                            // Fold to one branch
640                            let selected = if *c { then_val } else { else_val };
641                            constants.get(selected).cloned()
642                        } else {
643                            None
644                        }
645                    }
646                    _ => None,
647                };
648
649                if let Some(constant) = folded {
650                    // Replace instruction with constant
651                    let new_type = constant.ir_type();
652                    inst.node = IrNode::Constant(constant.clone());
653                    inst.result_type = new_type;
654                    constants.insert(inst.result, constant);
655                    result.changed = true;
656                    result.instructions_modified += 1;
657                }
658            }
659        }
660
661        result
662    }
663
664    fn name(&self) -> &'static str {
665        "constant-folding"
666    }
667}
668
669// ============================================================================
670// DEAD BLOCK ELIMINATION
671// ============================================================================
672
673/// Dead Block Elimination pass.
674///
675/// Removes unreachable blocks from the control flow graph.
676pub struct DeadBlockElimination;
677
678impl DeadBlockElimination {
679    /// Create a new dead block elimination pass.
680    pub fn new() -> Self {
681        Self
682    }
683
684    /// Find all reachable blocks starting from the entry.
685    fn find_reachable_blocks(&self, module: &IrModule) -> HashSet<BlockId> {
686        let mut reachable = HashSet::new();
687        let mut worklist = vec![module.entry_block];
688
689        while let Some(block_id) = worklist.pop() {
690            if !reachable.insert(block_id) {
691                continue;
692            }
693
694            if let Some(block) = module.get_block(block_id) {
695                // Add successors to worklist
696                match &block.terminator {
697                    Some(Terminator::Branch(target)) => {
698                        worklist.push(*target);
699                    }
700                    Some(Terminator::CondBranch(_, then_target, else_target)) => {
701                        worklist.push(*then_target);
702                        worklist.push(*else_target);
703                    }
704                    Some(Terminator::Switch(_, default, cases)) => {
705                        worklist.push(*default);
706                        for (_, target) in cases {
707                            worklist.push(*target);
708                        }
709                    }
710                    _ => {}
711                }
712            }
713        }
714
715        reachable
716    }
717}
718
719impl Default for DeadBlockElimination {
720    fn default() -> Self {
721        Self::new()
722    }
723}
724
725impl OptimizationPass for DeadBlockElimination {
726    fn run(&self, module: &mut IrModule) -> OptimizationResult {
727        let reachable = self.find_reachable_blocks(module);
728        let mut result = OptimizationResult::unchanged();
729
730        // Collect unreachable blocks
731        let unreachable: Vec<BlockId> = module
732            .blocks
733            .keys()
734            .filter(|id| !reachable.contains(id))
735            .copied()
736            .collect();
737
738        // Remove unreachable blocks
739        for block_id in unreachable {
740            module.blocks.remove(&block_id);
741            result.changed = true;
742            result.blocks_removed += 1;
743        }
744
745        result
746    }
747
748    fn name(&self) -> &'static str {
749        "dead-block-elimination"
750    }
751}
752
753// ============================================================================
754// ALGEBRAIC SIMPLIFICATION
755// ============================================================================
756
757/// Algebraic Simplification pass.
758///
759/// Simplifies expressions using algebraic identities:
760/// - x + 0 = x
761/// - x * 1 = x
762/// - x * 0 = 0
763/// - x - x = 0
764/// - x / 1 = x
765/// - x & 0 = 0
766/// - x | 0 = x
767/// - etc.
768pub struct AlgebraicSimplification;
769
770impl AlgebraicSimplification {
771    /// Create a new algebraic simplification pass.
772    pub fn new() -> Self {
773        Self
774    }
775
776    /// Check if a constant is zero.
777    fn is_zero(c: &ConstantValue) -> bool {
778        match c {
779            ConstantValue::I32(0) => true,
780            ConstantValue::U32(0) => true,
781            ConstantValue::I64(0) => true,
782            ConstantValue::U64(0) => true,
783            ConstantValue::F32(f) => *f == 0.0,
784            ConstantValue::F64(f) => *f == 0.0,
785            _ => false,
786        }
787    }
788
789    /// Check if a constant is one.
790    fn is_one(c: &ConstantValue) -> bool {
791        match c {
792            ConstantValue::I32(1) => true,
793            ConstantValue::U32(1) => true,
794            ConstantValue::I64(1) => true,
795            ConstantValue::U64(1) => true,
796            ConstantValue::F32(f) => *f == 1.0,
797            ConstantValue::F64(f) => *f == 1.0,
798            _ => false,
799        }
800    }
801
802    /// Create a zero constant of the given type.
803    fn zero_for_type(ty: &IrType) -> Option<ConstantValue> {
804        Some(match ty {
805            IrType::Scalar(ScalarType::I32) => ConstantValue::I32(0),
806            IrType::Scalar(ScalarType::U32) => ConstantValue::U32(0),
807            IrType::Scalar(ScalarType::I64) => ConstantValue::I64(0),
808            IrType::Scalar(ScalarType::U64) => ConstantValue::U64(0),
809            IrType::Scalar(ScalarType::F32) => ConstantValue::F32(0.0),
810            IrType::Scalar(ScalarType::F64) => ConstantValue::F64(0.0),
811            _ => return None,
812        })
813    }
814}
815
816impl Default for AlgebraicSimplification {
817    fn default() -> Self {
818        Self::new()
819    }
820}
821
822impl OptimizationPass for AlgebraicSimplification {
823    fn run(&self, module: &mut IrModule) -> OptimizationResult {
824        let mut result = OptimizationResult::unchanged();
825
826        // Collect constants
827        let mut constants = HashMap::new();
828        for value in module.values.values() {
829            if let IrNode::Constant(ref c) = value.node {
830                constants.insert(value.id, c.clone());
831            }
832        }
833
834        // Simplify operations
835        for block in module.blocks.values_mut() {
836            for inst in &mut block.instructions {
837                let simplified = match &inst.node {
838                    IrNode::BinaryOp(op, lhs, rhs) => {
839                        let lhs_const = constants.get(lhs);
840                        let rhs_const = constants.get(rhs);
841
842                        match op {
843                            // x + 0 = x
844                            BinaryOp::Add if rhs_const.is_some_and(Self::is_zero) => {
845                                Some(IrNode::Parameter(0)) // Placeholder, replaced below
846                            }
847                            // 0 + x = x
848                            BinaryOp::Add if lhs_const.is_some_and(Self::is_zero) => {
849                                Some(IrNode::Parameter(1))
850                            }
851                            // x * 1 = x
852                            BinaryOp::Mul if rhs_const.is_some_and(Self::is_one) => {
853                                Some(IrNode::Parameter(0))
854                            }
855                            // 1 * x = x
856                            BinaryOp::Mul if lhs_const.is_some_and(Self::is_one) => {
857                                Some(IrNode::Parameter(1))
858                            }
859                            // x * 0 = 0
860                            BinaryOp::Mul
861                                if rhs_const.is_some_and(Self::is_zero)
862                                    || lhs_const.is_some_and(Self::is_zero) =>
863                            {
864                                Self::zero_for_type(&inst.result_type).map(IrNode::Constant)
865                            }
866                            // x - 0 = x
867                            BinaryOp::Sub if rhs_const.is_some_and(Self::is_zero) => {
868                                Some(IrNode::Parameter(0))
869                            }
870                            // x / 1 = x
871                            BinaryOp::Div if rhs_const.is_some_and(Self::is_one) => {
872                                Some(IrNode::Parameter(0))
873                            }
874                            // x & 0 = 0
875                            BinaryOp::And if rhs_const.is_some_and(Self::is_zero) => {
876                                Self::zero_for_type(&inst.result_type).map(IrNode::Constant)
877                            }
878                            // x | 0 = x
879                            BinaryOp::Or if rhs_const.is_some_and(Self::is_zero) => {
880                                Some(IrNode::Parameter(0))
881                            }
882                            // x ^ 0 = x
883                            BinaryOp::Xor if rhs_const.is_some_and(Self::is_zero) => {
884                                Some(IrNode::Parameter(0))
885                            }
886                            _ => None,
887                        }
888                    }
889                    _ => None,
890                };
891
892                // Apply simplification
893                if let Some(simplified_node) = simplified {
894                    match simplified_node {
895                        IrNode::Parameter(0) => {
896                            // Replace with lhs
897                            // Note: Full value propagation requires SSA-form copy propagation
898                            // which is a more complex optimization. For now, we only handle
899                            // constant folding cases. The instruction remains unchanged.
900                        }
901                        IrNode::Parameter(1) => {
902                            // Replace with rhs
903                            // Same limitation as above - would need copy propagation pass
904                        }
905                        IrNode::Constant(c) => {
906                            inst.node = IrNode::Constant(c.clone());
907                            constants.insert(inst.result, c);
908                            result.changed = true;
909                            result.instructions_modified += 1;
910                        }
911                        _ => {}
912                    }
913                }
914            }
915        }
916
917        result
918    }
919
920    fn name(&self) -> &'static str {
921        "algebraic-simplification"
922    }
923}
924
925// ============================================================================
926// PASS MANAGER
927// ============================================================================
928
929/// Runs optimization passes on an IR module.
930pub struct PassManager {
931    passes: Vec<Box<dyn OptimizationPass>>,
932    max_iterations: usize,
933}
934
935impl PassManager {
936    /// Create a new pass manager with default passes.
937    pub fn new() -> Self {
938        Self {
939            passes: vec![
940                Box::new(ConstantFolding::new()),
941                Box::new(AlgebraicSimplification::new()),
942                Box::new(DeadCodeElimination::new()),
943                Box::new(DeadBlockElimination::new()),
944            ],
945            max_iterations: 10,
946        }
947    }
948
949    /// Create an empty pass manager.
950    pub fn empty() -> Self {
951        Self {
952            passes: Vec::new(),
953            max_iterations: 10,
954        }
955    }
956
957    /// Add a pass to the manager.
958    pub fn add_pass<P: OptimizationPass + 'static>(&mut self, pass: P) -> &mut Self {
959        self.passes.push(Box::new(pass));
960        self
961    }
962
963    /// Set the maximum number of iterations.
964    pub fn max_iterations(&mut self, n: usize) -> &mut Self {
965        self.max_iterations = n;
966        self
967    }
968
969    /// Run all passes on the module.
970    pub fn run(&self, module: &mut IrModule) -> OptimizationResult {
971        let mut total_result = OptimizationResult::unchanged();
972
973        for iteration in 0..self.max_iterations {
974            let mut changed = false;
975
976            for pass in &self.passes {
977                let pass_result = pass.run(module);
978                changed |= pass_result.changed;
979                total_result.merge(pass_result);
980            }
981
982            if !changed {
983                break;
984            }
985
986            // Safety check
987            if iteration == self.max_iterations - 1 {
988                eprintln!(
989                    "Warning: optimization reached max iterations ({})",
990                    self.max_iterations
991                );
992            }
993        }
994
995        total_result
996    }
997}
998
999impl Default for PassManager {
1000    fn default() -> Self {
1001        Self::new()
1002    }
1003}
1004
1005// ============================================================================
1006// CONVENIENCE FUNCTIONS
1007// ============================================================================
1008
1009/// Run all standard optimization passes on a module.
1010pub fn optimize(module: &mut IrModule) -> OptimizationResult {
1011    PassManager::new().run(module)
1012}
1013
1014/// Run only DCE on a module.
1015pub fn run_dce(module: &mut IrModule) -> OptimizationResult {
1016    DeadCodeElimination::new().run(module)
1017}
1018
1019/// Run only constant folding on a module.
1020pub fn run_constant_folding(module: &mut IrModule) -> OptimizationResult {
1021    ConstantFolding::new().run(module)
1022}
1023
1024// ============================================================================
1025// TESTS
1026// ============================================================================
1027
#[cfg(test)]
mod tests {
    use super::*;
    use crate::IrBuilder;

    /// DCE reports a removal when a computed value is never consumed.
    #[test]
    fn test_dce_removes_unused() {
        let mut ir = IrBuilder::new("test");

        // Operands for the dead computation (constants are stored in the
        // values map, not as instructions).
        let ten = ir.const_i32(10);
        let twenty = ir.const_i32(20);

        // Nothing reads this sum, yet it adds an instruction to the block.
        let _dead_sum = ir.add(ten, twenty);

        // This product is returned, so it must survive.
        let five = ir.const_i32(5);
        let live_product = ir.mul(five, five);
        ir.ret_value(live_product);

        let mut module = ir.build();
        let outcome = DeadCodeElimination::new().run(&mut module);

        // The unused add instruction should have been dropped.
        assert!(outcome.changed);
        assert!(outcome.instructions_removed > 0);
    }

    /// A binary operation on two constants is reported as modified.
    #[test]
    fn test_constant_folding_binary() {
        let mut ir = IrBuilder::new("test");

        // 2 + 3 is expected to fold to 5.
        let two = ir.const_i32(2);
        let three = ir.const_i32(3);
        let total = ir.add(two, three);
        ir.ret_value(total);

        let mut module = ir.build();
        let outcome = ConstantFolding::new().run(&mut module);

        assert!(outcome.changed);
        assert!(outcome.instructions_modified > 0);
    }

    /// A unary operation on a constant is reported as a change.
    #[test]
    fn test_constant_folding_unary() {
        let mut ir = IrBuilder::new("test");

        // Negating the constant 5 is expected to fold.
        let five = ir.const_i32(5);
        let negated = ir.neg(five);
        ir.ret_value(negated);

        let mut module = ir.build();
        let outcome = ConstantFolding::new().run(&mut module);

        assert!(outcome.changed);
    }

    /// The default pipeline reports a change on optimizable input.
    #[test]
    fn test_pass_manager() {
        let mut ir = IrBuilder::new("test");

        // A foldable sum plus a constant that nothing uses.
        let two = ir.const_i32(2);
        let three = ir.const_i32(3);
        let total = ir.add(two, three);
        let _never_read = ir.const_i32(999);
        ir.ret_value(total);

        let mut module = ir.build();
        let outcome = PassManager::new().run(&mut module);

        assert!(outcome.changed);
    }

    /// `merge` keeps `changed` set and sums every counter.
    #[test]
    fn test_optimization_result_merge() {
        let mut accumulated = OptimizationResult {
            changed: true,
            instructions_removed: 5,
            instructions_modified: 3,
            blocks_removed: 1,
        };

        // Unchanged result that still carries counters; the remaining fields
        // (`changed: false`, `blocks_removed: 0`) come from the default value.
        let incoming = OptimizationResult {
            instructions_removed: 2,
            instructions_modified: 1,
            ..OptimizationResult::unchanged()
        };

        accumulated.merge(incoming);

        assert!(accumulated.changed);
        assert_eq!(accumulated.instructions_removed, 7);
        assert_eq!(accumulated.instructions_modified, 4);
        assert_eq!(accumulated.blocks_removed, 1);
    }
}