Skip to main content

kaio_core/ir/
register.rs

//! PTX register types and virtual register allocator.
2
3use std::fmt;
4
5use crate::types::{PtxType, RegKind};
6
7/// A virtual PTX register with a kind prefix and unique index.
8///
9/// Register names follow PTX conventions: `%r0` (32-bit int), `%rd0`
10/// (64-bit int), `%f0` (f32), `%fd0` (f64), `%p0` (predicate).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
12pub struct Register {
13    /// Which register name prefix to use.
14    pub kind: RegKind,
15    /// Unique index within this kind.
16    pub index: u32,
17    /// The PTX type this register was declared with.
18    pub ptx_type: PtxType,
19}
20
21impl Register {
22    /// The PTX register name (e.g. `%r0`, `%fd3`, `%p1`).
23    pub fn name(&self) -> String {
24        format!("{}{}", self.kind.prefix(), self.index)
25    }
26}
27
28impl fmt::Display for Register {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        write!(f, "{}{}", self.kind.prefix(), self.index)
31    }
32}
33
34/// Allocates virtual registers with unique names per kind prefix.
35///
36/// Each call to [`alloc`](Self::alloc) returns a [`Register`] with a
37/// monotonically increasing index within its [`RegKind`]. The allocator
38/// tracks all allocations so they can be emitted as `.reg` declarations
39/// in the PTX kernel prelude.
40#[derive(Debug)]
41pub struct RegisterAllocator {
42    counters: [u32; 7],
43    allocated: Vec<Register>,
44}
45
46impl Default for RegisterAllocator {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl RegisterAllocator {
53    /// Create a new allocator with all counters at zero.
54    pub fn new() -> Self {
55        Self {
56            counters: [0; 7],
57            allocated: Vec::new(),
58        }
59    }
60
61    /// Allocate a fresh register for the given PTX type.
62    pub fn alloc(&mut self, ptx_type: PtxType) -> Register {
63        let kind = ptx_type.reg_kind();
64        let idx = kind.counter_index();
65        let index = self.counters[idx];
66        self.counters[idx] += 1;
67        let reg = Register {
68            kind,
69            index,
70            ptx_type,
71        };
72        self.allocated.push(reg);
73        reg
74    }
75
76    /// Allocate a `.b32` register intended to hold **two fp16 values
77    /// packed into 32 bits** — the storage format `mma.sync.m16n8k16.f16`
78    /// expects for its A and B fragment operands.
79    ///
80    /// This is a thin semantic alias for `alloc(PtxType::U32)` — the
81    /// register lives in the `%r` class at the PTX level. The separate
82    /// method exists so call sites in `fragment.rs` make their intent
83    /// explicit and so grep'ing for `alloc_packed_half2` finds every
84    /// fragment-storage allocation in the codebase.
85    ///
86    /// No new [`RegKind`] variant is introduced — the register really is
87    /// `.b32` at the hardware level, and inventing a fake kind would lie
88    /// about the PTX reality.
89    pub fn alloc_packed_half2(&mut self) -> Register {
90        self.alloc(PtxType::U32)
91    }
92
93    /// All registers allocated so far, in allocation order.
94    pub fn allocated(&self) -> &[Register] {
95        &self.allocated
96    }
97
98    /// Consume the allocator and return all allocated registers.
99    pub fn into_allocated(self) -> Vec<Register> {
100        self.allocated
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Types that map to the same register kind share a single counter.
    #[test]
    fn sequential_indices_within_kind() {
        let mut alloc = RegisterAllocator::new();
        // S32 and U32 both live in RegKind::R, so the indices keep counting.
        let cases = [
            (PtxType::S32, "%r0"),
            (PtxType::U32, "%r1"),
            (PtxType::S32, "%r2"),
        ];
        for (expected_index, (ty, expected_name)) in (0u32..).zip(cases) {
            let reg = alloc.alloc(ty);
            assert_eq!(reg.index, expected_index);
            assert_eq!(reg.name(), expected_name);
        }
    }

    /// The first register of every kind gets index 0, whatever was
    /// allocated for other kinds beforehand.
    #[test]
    fn independent_counters_per_kind() {
        let mut alloc = RegisterAllocator::new();
        let cases = [
            (PtxType::S32, "%r0"),
            (PtxType::F32, "%f0"),
            (PtxType::U64, "%rd0"),
            (PtxType::Pred, "%p0"),
            (PtxType::F64, "%fd0"),
            (PtxType::F16, "%h0"),
            (PtxType::BF16, "%hb0"),
        ];
        for (ty, expected_name) in cases {
            let reg = alloc.alloc(ty);
            // Independent counters: each kind starts from zero.
            assert_eq!(reg.index, 0);
            assert_eq!(reg.name(), expected_name);
        }
    }

    /// `into_allocated` yields exactly the registers handed out, in order.
    #[test]
    fn into_allocated_preserves_order() {
        let mut alloc = RegisterAllocator::new();
        let handed_out = vec![
            alloc.alloc(PtxType::F32),
            alloc.alloc(PtxType::S32),
            alloc.alloc(PtxType::F32),
        ];
        let regs = alloc.into_allocated();
        assert_eq!(regs.len(), 3);
        assert_eq!(regs, handed_out);
    }

    /// f16 and bf16 registers use separate kinds with separate counters.
    #[test]
    fn f16_bf16_register_allocation() {
        let mut alloc = RegisterAllocator::new();
        let [h0, h1] = [alloc.alloc(PtxType::F16), alloc.alloc(PtxType::F16)];
        let [hb0, hb1] = [alloc.alloc(PtxType::BF16), alloc.alloc(PtxType::BF16)];

        // The bf16 counter is unaffected by the two f16 allocations.
        assert_eq!((h0.index, h1.index), (0, 1));
        assert_eq!((hb0.index, hb1.index), (0, 1));

        assert_eq!(h0.name(), "%h0");
        assert_eq!(h1.name(), "%h1");
        assert_eq!(hb0.name(), "%hb0");
        assert_eq!(hb1.name(), "%hb1");

        // Each type maps to its own register kind.
        assert_eq!(h0.kind, RegKind::H);
        assert_eq!(hb0.kind, RegKind::Hb);
    }

    /// `Display` renders the kind prefix followed by the index.
    #[test]
    fn register_display() {
        let reg = Register {
            kind: RegKind::Fd,
            index: 7,
            ptx_type: PtxType::F64,
        };
        assert_eq!(reg.to_string(), "%fd7");
    }
}