Skip to main content

wave_compiler/regalloc/
spill.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Spill code generation for register allocation.
5//!
6//! When a virtual register cannot be assigned a physical register,
//! spill code is inserted: stores after definitions and loads before uses.
8//! The spilled value lives in local (scratchpad) memory.
9
10use crate::lir::instruction::LirInst;
11use crate::lir::operand::{MemWidth, VReg};
12
13/// Generate spill code for a set of spilled virtual registers.
14///
15/// Inserts local stores after definitions and local loads before uses
16/// of spilled registers, using a fresh `VReg` for each reload.
17pub fn insert_spill_code(
18    instructions: &mut Vec<LirInst>,
19    spilled: &[VReg],
20    next_vreg: &mut u32,
21    spill_slot_base: u32,
22) -> u32 {
23    if spilled.is_empty() {
24        return 0;
25    }
26
27    let mut spill_offsets: std::collections::HashMap<VReg, u32> = std::collections::HashMap::new();
28    for (i, &vreg) in spilled.iter().enumerate() {
29        #[allow(clippy::cast_possible_truncation)]
30        let slot_index = i as u32;
31        spill_offsets.insert(vreg, spill_slot_base + slot_index * 4);
32    }
33
34    let mut new_insts: Vec<LirInst> = Vec::new();
35    let mut spill_count = 0u32;
36
37    for inst in instructions.iter() {
38        let mut reload_map: std::collections::HashMap<VReg, VReg> =
39            std::collections::HashMap::new();
40
41        for src in inst.src_vregs() {
42            if let Some(&offset) = spill_offsets.get(&src) {
43                let reload_vreg = VReg(*next_vreg);
44                *next_vreg += 1;
45
46                let addr_vreg = VReg(*next_vreg);
47                *next_vreg += 1;
48
49                new_insts.push(LirInst::MovImm {
50                    dest: addr_vreg,
51                    value: offset,
52                });
53                new_insts.push(LirInst::LocalLoad {
54                    dest: reload_vreg,
55                    addr: addr_vreg,
56                    width: MemWidth::W32,
57                });
58                reload_map.insert(src, reload_vreg);
59                spill_count += 1;
60            }
61        }
62
63        new_insts.push(inst.clone());
64
65        if let Some(dest) = inst.dest_vreg() {
66            if let Some(&offset) = spill_offsets.get(&dest) {
67                let addr_vreg = VReg(*next_vreg);
68                *next_vreg += 1;
69
70                new_insts.push(LirInst::MovImm {
71                    dest: addr_vreg,
72                    value: offset,
73                });
74                new_insts.push(LirInst::LocalStore {
75                    addr: addr_vreg,
76                    value: dest,
77                    width: MemWidth::W32,
78                });
79                spill_count += 1;
80            }
81        }
82    }
83
84    *instructions = new_insts;
85    spill_count
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-instruction fixture: load the constant 42 into `v0`, then halt.
    fn const_then_halt() -> Vec<LirInst> {
        vec![
            LirInst::MovImm {
                dest: VReg(0),
                value: 42,
            },
            LirInst::Halt,
        ]
    }

    #[test]
    fn test_no_spills_no_change() {
        let mut program = const_then_halt();
        let len_before = program.len();
        let mut fresh = 1;

        let inserted = insert_spill_code(&mut program, &[], &mut fresh, 0);

        // Nothing spilled: no spill ops reported and the stream keeps its length.
        assert_eq!(inserted, 0);
        assert_eq!(program.len(), len_before);
    }

    #[test]
    fn test_spill_inserts_stores() {
        let mut program = const_then_halt();
        let mut fresh = 1;

        let inserted = insert_spill_code(&mut program, &[VReg(0)], &mut fresh, 0);

        // Spilling the defined register must add code, including a write-back.
        assert!(inserted > 0);
        assert!(program.len() > 2);
        assert!(program
            .iter()
            .any(|inst| matches!(inst, LirInst::LocalStore { .. })));
    }
}