1use lamina_mir::{Function, Instruction, Operand, Register, VirtualReg};
2use std::collections::{HashMap, HashSet};
3
/// Type-erased handle for a physical register.
///
/// Used by [`RegisterAllocatorDyn`] so callers can talk about physical
/// registers without knowing the allocator's concrete register type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PhysRegHandle {
    /// A register identified by a static name.
    Named(&'static str),
}

impl PhysRegHandle {
    /// Returns the register name carried by this handle.
    ///
    /// With `Named` as the only variant this is always `Some`; the `Option`
    /// return type is part of the public signature and is kept as-is.
    pub fn as_named(self) -> Option<&'static str> {
        // Irrefutable while `Named` is the sole variant; adding a variant
        // turns this into a compile error, just like a non-exhaustive match.
        let Self::Named(name) = self;
        Some(name)
    }
}
20
/// Conversion between a concrete physical-register type and the type-erased
/// [`PhysRegHandle`].
///
/// Implementors must be `Copy + Eq`.
pub trait PhysRegConvertible: Copy + Eq {
    /// Wraps this register in a [`PhysRegHandle`].
    fn into_handle(self) -> PhysRegHandle;

    /// Recovers the concrete register from `handle`.
    ///
    /// Returns `None` when the handle cannot be represented as `Self`.
    fn from_handle(handle: PhysRegHandle) -> Option<Self>
    where
        Self: Sized;
}
32
33impl PhysRegConvertible for &'static str {
34 fn into_handle(self) -> PhysRegHandle {
35 PhysRegHandle::Named(self)
36 }
37
38 fn from_handle(handle: PhysRegHandle) -> Option<Self> {
39 match handle {
40 PhysRegHandle::Named(name) => Some(name),
41 }
42 }
43}
44
/// Statically-typed register-allocator interface.
///
/// The precise contract of each method is defined by the implementor; the
/// notes below describe only what the signatures show. Every implementor also
/// receives [`RegisterAllocatorDyn`] through the blanket impl in this file.
pub trait RegisterAllocator {
    /// Concrete physical-register type; must convert to and from
    /// [`PhysRegHandle`].
    type PhysReg: PhysRegConvertible;

    /// Requests a scratch register.
    ///
    /// NOTE(review): `None` presumably means no scratch register is
    /// available — confirm against implementors.
    fn alloc_scratch(&mut self) -> Option<Self::PhysReg>;

    /// Hands back a scratch register, presumably one obtained from
    /// [`alloc_scratch`](Self::alloc_scratch) — confirm against implementors.
    fn free_scratch(&mut self, phys: Self::PhysReg);

    /// Looks up the physical register mapped to `vreg`, if any.
    fn get_mapping(&self, vreg: &VirtualReg) -> Option<Self::PhysReg>;

    /// Returns a physical register for `vreg`; takes `&mut self`, so it may
    /// create the mapping. `None` presumably means no register could be
    /// assigned — confirm against implementors.
    fn ensure_mapping(&mut self, vreg: VirtualReg) -> Option<Self::PhysReg>;

    /// Looks up the physical register associated with an arbitrary `Register`
    /// operand, if any.
    fn mapped_for_register(&self, reg: &Register) -> Option<Self::PhysReg>;

    /// Marks `phys` as occupied.
    fn occupy(&mut self, phys: Self::PhysReg);

    /// Marks `phys` as no longer occupied.
    fn release(&mut self, phys: Self::PhysReg);

    /// Reports whether `phys` is currently marked occupied.
    fn is_occupied(&self, phys: Self::PhysReg) -> bool;
}
88
/// Type-erased counterpart of [`RegisterAllocator`].
///
/// Every method mirrors the `RegisterAllocator` method of the same name
/// (without the `_dyn` suffix) but exchanges [`PhysRegHandle`] values instead
/// of the allocator's associated `PhysReg` type. A blanket impl in this file
/// provides it for every `RegisterAllocator`.
pub trait RegisterAllocatorDyn {
    /// Handle-based form of `RegisterAllocator::alloc_scratch`.
    fn alloc_scratch_dyn(&mut self) -> Option<PhysRegHandle>;
    /// Handle-based form of `RegisterAllocator::free_scratch`.
    fn free_scratch_dyn(&mut self, phys: PhysRegHandle);
    /// Handle-based form of `RegisterAllocator::get_mapping`.
    fn get_mapping_dyn(&self, vreg: &VirtualReg) -> Option<PhysRegHandle>;
    /// Handle-based form of `RegisterAllocator::ensure_mapping`.
    fn ensure_mapping_dyn(&mut self, vreg: VirtualReg) -> Option<PhysRegHandle>;
    /// Handle-based form of `RegisterAllocator::mapped_for_register`.
    fn mapped_for_register_dyn(&self, reg: &Register) -> Option<PhysRegHandle>;
    /// Handle-based form of `RegisterAllocator::occupy`.
    fn occupy_dyn(&mut self, phys: PhysRegHandle);
    /// Handle-based form of `RegisterAllocator::release`.
    fn release_dyn(&mut self, phys: PhysRegHandle);
    /// Handle-based form of `RegisterAllocator::is_occupied`.
    fn is_occupied_dyn(&self, phys: PhysRegHandle) -> bool;
}
101
/// Where a virtual register ends up after allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Allocation<R: Copy> {
    /// The virtual register lives in physical register `R`.
    Register(R),
    /// The virtual register is spilled to the slot with this offset.
    ///
    /// The allocators in this file hand out `-8`, `-16`, `-24`, ... in spill
    /// order. NOTE(review): presumably a frame-relative byte offset — confirm
    /// against the code generator that consumes it.
    Spill(i32),
}
110
/// Live range of one virtual register over the linearised instruction stream.
///
/// Both bounds are inclusive instruction positions as assigned by
/// [`LinearScanAllocator::compute_intervals`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LiveInterval {
    /// The virtual register this interval describes.
    pub vreg: VirtualReg,
    /// Position at which the register is first seen (0 for parameters).
    pub start: usize,
    /// Last position at which the register is defined or used.
    pub end: usize,
}
122
/// Namespace for the linear-scan register allocator.
///
/// Stateless: all functionality is exposed as associated functions
/// (`compute_intervals` to build live intervals, `allocate` to assign
/// registers or spill slots).
pub struct LinearScanAllocator;
132
133impl LinearScanAllocator {
134 pub fn compute_intervals(function: &Function) -> Vec<LiveInterval> {
143 let mut intervals: HashMap<VirtualReg, LiveInterval> = HashMap::new();
144 let mut pos: usize = 0;
145
146 for param in &function.sig.params {
148 if let Register::Virtual(v) = ¶m.reg {
149 intervals.entry(*v).or_insert(LiveInterval {
150 vreg: *v,
151 start: 0,
152 end: 0,
153 });
154 }
155 }
156
157 for block in &function.blocks {
158 for instr in &block.instructions {
159 Self::scan_instruction(instr, pos, &mut intervals);
160 pos += 1;
161 }
162 }
163
164 let mut result: Vec<LiveInterval> = intervals.into_values().collect();
165 result.sort_by_key(|i| i.start);
166 result
167 }
168
169 fn scan_instruction(
170 instr: &Instruction,
171 pos: usize,
172 intervals: &mut HashMap<VirtualReg, LiveInterval>,
173 ) {
174 if let Some(def) = Self::def_reg(instr)
176 && let Register::Virtual(v) = def
177 {
178 let entry = intervals.entry(v).or_insert(LiveInterval {
179 vreg: v,
180 start: pos,
181 end: pos,
182 });
183 if pos > entry.end {
185 entry.end = pos;
186 }
187 }
188
189 for used in Self::use_regs(instr) {
191 if let Register::Virtual(v) = used {
192 let entry = intervals.entry(v).or_insert(LiveInterval {
193 vreg: v,
194 start: pos,
195 end: pos,
196 });
197 if pos > entry.end {
198 entry.end = pos;
199 }
200 }
201 }
202 }
203
204 fn def_reg(instr: &Instruction) -> Option<Register> {
206 match instr {
207 Instruction::IntBinary { dst, .. }
208 | Instruction::FloatBinary { dst, .. }
209 | Instruction::FloatUnary { dst, .. }
210 | Instruction::IntCmp { dst, .. }
211 | Instruction::FloatCmp { dst, .. }
212 | Instruction::Select { dst, .. }
213 | Instruction::Load { dst, .. }
214 | Instruction::Lea { dst, .. }
215 | Instruction::VectorOp { dst, .. } => Some(dst.clone()),
216
217 Instruction::Call { ret: Some(ret), .. } => Some(ret.clone()),
218
219 #[cfg(feature = "nightly")]
220 Instruction::SimdBinary { dst, .. }
221 | Instruction::SimdUnary { dst, .. }
222 | Instruction::SimdTernary { dst, .. }
223 | Instruction::SimdShuffle { dst, .. }
224 | Instruction::SimdExtract { dst, .. }
225 | Instruction::SimdInsert { dst, .. }
226 | Instruction::SimdLoad { dst, .. } => Some(dst.clone()),
227
228 #[cfg(feature = "nightly")]
229 Instruction::AtomicLoad { dst, .. }
230 | Instruction::AtomicBinary { dst, .. }
231 | Instruction::AtomicCompareExchange { dst, .. } => Some(dst.clone()),
232
233 _ => None,
234 }
235 }
236
237 fn use_regs(instr: &Instruction) -> Vec<Register> {
239 let mut uses = Vec::new();
240
241 let push_op = |uses: &mut Vec<Register>, op: &Operand| {
242 if let Operand::Register(r) = op {
243 uses.push(r.clone());
244 }
245 };
246
247 let push_addr = |uses: &mut Vec<Register>, addr: &lamina_mir::AddressMode| match addr {
248 lamina_mir::AddressMode::BaseOffset { base, .. } => uses.push(base.clone()),
249 lamina_mir::AddressMode::BaseIndexScale { base, index, .. } => {
250 uses.push(base.clone());
251 uses.push(index.clone());
252 }
253 };
254
255 match instr {
256 Instruction::IntBinary { lhs, rhs, .. }
257 | Instruction::FloatBinary { lhs, rhs, .. }
258 | Instruction::IntCmp { lhs, rhs, .. }
259 | Instruction::FloatCmp { lhs, rhs, .. } => {
260 push_op(&mut uses, lhs);
261 push_op(&mut uses, rhs);
262 }
263
264 Instruction::FloatUnary { src, .. } => push_op(&mut uses, src),
265
266 Instruction::Select {
267 cond,
268 true_val,
269 false_val,
270 ..
271 } => {
272 uses.push(cond.clone());
273 push_op(&mut uses, true_val);
274 push_op(&mut uses, false_val);
275 }
276
277 Instruction::Load { addr, .. } => push_addr(&mut uses, addr),
278
279 Instruction::Store { src, addr, .. } => {
280 push_op(&mut uses, src);
281 push_addr(&mut uses, addr);
282 }
283
284 Instruction::Lea { base, .. } => uses.push(base.clone()),
285
286 Instruction::VectorOp { operands, .. } => {
287 for op in operands {
288 push_op(&mut uses, op);
289 }
290 }
291
292 Instruction::Call { args, .. } | Instruction::TailCall { args, .. } => {
293 for op in args {
294 push_op(&mut uses, op);
295 }
296 }
297
298 Instruction::Ret { value: Some(v) } => push_op(&mut uses, v),
299
300 Instruction::Br { cond, .. } | Instruction::Switch { value: cond, .. } => {
301 uses.push(cond.clone());
302 }
303
304 _ => {}
305 }
306
307 uses
308 }
309
310 pub fn allocate<R: Copy + Eq>(
319 intervals: &[LiveInterval],
320 available_regs: &[R],
321 ) -> HashMap<VirtualReg, Allocation<R>> {
322 let mut result: HashMap<VirtualReg, Allocation<R>> = HashMap::new();
323 let mut active: Vec<(&LiveInterval, R)> = Vec::new();
325 let mut free: Vec<R> = available_regs.to_vec();
326 let mut next_spill: i32 = -8; for interval in intervals {
329 let current_start = interval.start;
331 let mut freed: Vec<R> = Vec::new();
332 active.retain(|(ai, reg)| {
333 if ai.end < current_start {
334 freed.push(*reg);
335 false
336 } else {
337 true
338 }
339 });
340 free.extend(freed);
341
342 if let Some(reg) = free.pop() {
343 result.insert(interval.vreg, Allocation::Register(reg));
345 let pos = active
347 .binary_search_by_key(&interval.end, |(ai, _)| ai.end)
348 .unwrap_or_else(|i| i);
349 active.insert(pos, (interval, reg));
350 } else {
351 match active.last().cloned() {
354 Some((spill_interval, spill_reg)) if spill_interval.end > interval.end => {
355 result.insert(spill_interval.vreg, Allocation::Spill(next_spill));
357 next_spill -= 8;
358 active.pop();
359 result.insert(interval.vreg, Allocation::Register(spill_reg));
360 let pos = active
361 .binary_search_by_key(&interval.end, |(ai, _)| ai.end)
362 .unwrap_or_else(|i| i);
363 active.insert(pos, (interval, spill_reg));
364 }
365 _ => {
366 result.insert(interval.vreg, Allocation::Spill(next_spill));
368 next_spill -= 8;
369 }
370 }
371 }
372 }
373
374 result
375 }
376}
377
378#[inline]
380pub fn intervals_interfere(a: &LiveInterval, b: &LiveInterval) -> bool {
381 a.start <= b.end && b.start <= a.end
382}
383
/// Namespace for the greedy graph-colouring register allocator.
///
/// Stateless: the allocator is exposed through the associated function
/// `allocate`, which works on the same [`LiveInterval`] input as
/// [`LinearScanAllocator`].
pub struct GraphColorAllocator;
391
392impl GraphColorAllocator {
393 pub fn allocate<R: Copy + Eq + std::hash::Hash>(
397 intervals: &[LiveInterval],
398 available_regs: &[R],
399 ) -> HashMap<VirtualReg, Allocation<R>> {
400 if intervals.is_empty() {
401 return HashMap::new();
402 }
403
404 let mut order: Vec<usize> = (0..intervals.len()).collect();
405 order.sort_by(|&i, &j| {
406 let deg_i = intervals
407 .iter()
408 .enumerate()
409 .filter(|(k, other)| *k != i && intervals_interfere(&intervals[i], other))
410 .count();
411 let deg_j = intervals
412 .iter()
413 .enumerate()
414 .filter(|(k, other)| *k != j && intervals_interfere(&intervals[j], other))
415 .count();
416 deg_j.cmp(°_i).then_with(|| i.cmp(&j))
417 });
418
419 let mut result: HashMap<VirtualReg, Allocation<R>> = HashMap::new();
420 let mut next_spill: i32 = -8;
421
422 for idx in order {
423 let interval = &intervals[idx];
424 let mut blocked: HashSet<R> = HashSet::new();
425 for (j, other) in intervals.iter().enumerate() {
426 if j == idx || !intervals_interfere(interval, other) {
427 continue;
428 }
429 if let Some(Allocation::Register(r)) = result.get(&other.vreg) {
430 blocked.insert(*r);
431 }
432 }
433
434 let mut picked: Option<R> = None;
435 for reg in available_regs {
436 if !blocked.contains(reg) {
437 picked = Some(*reg);
438 break;
439 }
440 }
441
442 match picked {
443 Some(r) => {
444 result.insert(interval.vreg, Allocation::Register(r));
445 }
446 None => {
447 result.insert(interval.vreg, Allocation::Spill(next_spill));
448 next_spill -= 8;
449 }
450 }
451 }
452
453 result
454 }
455}
456
457impl<T> RegisterAllocatorDyn for T
458where
459 T: RegisterAllocator,
460{
461 fn alloc_scratch_dyn(&mut self) -> Option<PhysRegHandle> {
462 self.alloc_scratch().map(|reg| reg.into_handle())
463 }
464
465 fn free_scratch_dyn(&mut self, phys: PhysRegHandle) {
466 if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
467 self.free_scratch(reg);
468 } else {
469 debug_assert!(false, "failed to decode physical register handle");
470 }
471 }
472
473 fn get_mapping_dyn(&self, vreg: &VirtualReg) -> Option<PhysRegHandle> {
474 self.get_mapping(vreg).map(|reg| reg.into_handle())
475 }
476
477 fn ensure_mapping_dyn(&mut self, vreg: VirtualReg) -> Option<PhysRegHandle> {
478 self.ensure_mapping(vreg).map(|reg| reg.into_handle())
479 }
480
481 fn mapped_for_register_dyn(&self, reg: &Register) -> Option<PhysRegHandle> {
482 self.mapped_for_register(reg).map(|r| r.into_handle())
483 }
484
485 fn occupy_dyn(&mut self, phys: PhysRegHandle) {
486 if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
487 self.occupy(reg);
488 } else {
489 debug_assert!(false, "failed to decode physical register handle");
490 }
491 }
492
493 fn release_dyn(&mut self, phys: PhysRegHandle) {
494 if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
495 self.release(reg);
496 } else {
497 debug_assert!(false, "failed to decode physical register handle");
498 }
499 }
500
501 fn is_occupied_dyn(&self, phys: PhysRegHandle) -> bool {
502 if let Some(reg) = <T::PhysReg as PhysRegConvertible>::from_handle(phys) {
503 self.is_occupied(reg)
504 } else {
505 debug_assert!(false, "failed to decode physical register handle");
506 false
507 }
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use lamina_mir::{
        Block, FunctionBuilder, Instruction, IntBinOp, MirType, Operand, Register, ScalarType,
        VirtualReg,
    };

    /// Builds `add(v0, v1)`: one block computing `v2 = v0 + v1` and returning `v2`.
    fn make_add_function() -> Function {
        let v0 = Register::Virtual(VirtualReg::gpr(0));
        let v1 = Register::Virtual(VirtualReg::gpr(1));
        let v2 = Register::Virtual(VirtualReg::gpr(2));
        let i64_ty = MirType::Scalar(ScalarType::I64);

        FunctionBuilder::new("add")
            .param(v0.clone(), i64_ty)
            .param(v1.clone(), i64_ty)
            .returns(i64_ty)
            .block("entry")
            .instr(Instruction::IntBinary {
                op: IntBinOp::Add,
                ty: i64_ty,
                dst: v2.clone(),
                lhs: Operand::Register(v0),
                rhs: Operand::Register(v1),
            })
            .instr(Instruction::Ret {
                value: Some(Operand::Register(v2)),
            })
            .build()
    }

    /// Builds a function with a chain of eight adds, `v[i+2] = v[i] + v[i+1]`,
    /// so that more than two virtual registers are live at the same time.
    fn make_high_pressure_function(name: &'static str) -> Function {
        let i64_ty = MirType::Scalar(ScalarType::I64);
        let mut func = FunctionBuilder::new(name).returns(i64_ty).build();
        let mut block = Block::new("entry");

        for i in 0u32..8 {
            let vi = Register::Virtual(VirtualReg::gpr(i));
            let vj = Register::Virtual(VirtualReg::gpr(i + 1));
            let vd = Register::Virtual(VirtualReg::gpr(i + 2));
            block.push(Instruction::IntBinary {
                op: IntBinOp::Add,
                ty: i64_ty,
                dst: vd,
                lhs: Operand::Register(vi),
                rhs: Operand::Register(vj),
            });
        }
        block.push(Instruction::Ret { value: None });
        func.add_block(block);
        func
    }

    #[test]
    fn test_compute_intervals_basic() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);

        assert!(!intervals.is_empty());
        let vreg_ids: Vec<u32> = intervals.iter().map(|i| i.vreg.id).collect();
        assert!(vreg_ids.contains(&0));
        assert!(vreg_ids.contains(&1));
        assert!(vreg_ids.contains(&2));
    }

    #[test]
    fn test_compute_intervals_sorted_by_start() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let starts: Vec<usize> = intervals.iter().map(|i| i.start).collect();
        let mut sorted = starts.clone();
        sorted.sort_unstable();
        assert_eq!(starts, sorted, "intervals should be sorted by start");
    }

    #[test]
    fn test_intervals_interfere_is_inclusive() {
        let a = LiveInterval {
            vreg: VirtualReg::gpr(0),
            start: 0,
            end: 2,
        };
        let touching = LiveInterval {
            vreg: VirtualReg::gpr(1),
            start: 2,
            end: 4,
        };
        let disjoint = LiveInterval {
            vreg: VirtualReg::gpr(2),
            start: 3,
            end: 5,
        };

        assert!(intervals_interfere(&a, &touching));
        assert!(intervals_interfere(&touching, &a));
        assert!(!intervals_interfere(&a, &disjoint));
        assert!(!intervals_interfere(&disjoint, &a));
    }

    #[test]
    fn test_allocate_fits_in_registers() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13", "r14", "r15"];
        let alloc = LinearScanAllocator::allocate(&intervals, &regs);

        for interval in &intervals {
            let a = alloc
                .get(&interval.vreg)
                .expect("every vreg should be allocated");
            assert!(
                matches!(a, Allocation::Register(_)),
                "vreg {:?} should be in a register, got {:?}",
                interval.vreg,
                a
            );
        }
    }

    #[test]
    fn test_allocate_spills_when_registers_exhausted() {
        let func = make_high_pressure_function("spill_test");

        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13"];
        let alloc = LinearScanAllocator::allocate(&intervals, &regs);

        let has_spill = alloc.values().any(|a| matches!(a, Allocation::Spill(_)));
        assert!(has_spill, "expected spills when registers are exhausted");

        for a in alloc.values() {
            if let Allocation::Spill(offset) = a {
                assert!(*offset < 0, "spill offset should be negative");
                assert_eq!(offset % 8, 0, "spill offset should be 8-byte aligned");
            }
        }
    }

    #[test]
    fn graph_color_fits_three_vregs_without_spill() {
        let func = make_add_function();
        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13", "r14", "r15"];
        let gc = GraphColorAllocator::allocate(&intervals, &regs);
        for interval in &intervals {
            let a = gc.get(&interval.vreg).expect("allocated");
            assert!(
                matches!(a, Allocation::Register(_)),
                "graph color should keep simple add in registers, got {:?}",
                a
            );
        }
    }

    #[test]
    fn graph_color_spills_when_register_count_exhausted() {
        let func = make_high_pressure_function("spill_gc");

        let intervals = LinearScanAllocator::compute_intervals(&func);
        let regs = ["r12", "r13"];
        let gc = GraphColorAllocator::allocate(&intervals, &regs);
        let has_spill = gc.values().any(|a| matches!(a, Allocation::Spill(_)));
        assert!(
            has_spill,
            "graph color should spill when k=2 and pressure is high"
        );
    }
}