Skip to main content

kaio_core/ir/
register.rs

//! PTX register types and virtual register allocator.
2
3use std::fmt;
4
5use crate::types::{PtxType, RegKind};
6
7/// A virtual PTX register with a kind prefix and unique index.
8///
9/// Register names follow PTX conventions: `%r0` (32-bit int), `%rd0`
10/// (64-bit int), `%f0` (f32), `%fd0` (f64), `%p0` (predicate).
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
12pub struct Register {
13    /// Which register name prefix to use.
14    pub kind: RegKind,
15    /// Unique index within this kind.
16    pub index: u32,
17    /// The PTX type this register was declared with.
18    pub ptx_type: PtxType,
19}
20
21impl Register {
22    /// The PTX register name (e.g. `%r0`, `%fd3`, `%p1`).
23    pub fn name(&self) -> String {
24        format!("{}{}", self.kind.prefix(), self.index)
25    }
26}
27
28impl fmt::Display for Register {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        write!(f, "{}{}", self.kind.prefix(), self.index)
31    }
32}
33
34/// Allocates virtual registers with unique names per kind prefix.
35///
36/// Each call to [`alloc`](Self::alloc) returns a [`Register`] with a
37/// monotonically increasing index within its [`RegKind`]. The allocator
38/// tracks all allocations so they can be emitted as `.reg` declarations
39/// in the PTX kernel prelude.
40#[derive(Debug)]
41pub struct RegisterAllocator {
42    counters: [u32; 7],
43    allocated: Vec<Register>,
44}
45
46impl Default for RegisterAllocator {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl RegisterAllocator {
53    /// Create a new allocator with all counters at zero.
54    pub fn new() -> Self {
55        Self {
56            counters: [0; 7],
57            allocated: Vec::new(),
58        }
59    }
60
61    /// Allocate a fresh register for the given PTX type.
62    pub fn alloc(&mut self, ptx_type: PtxType) -> Register {
63        let kind = ptx_type.reg_kind();
64        let idx = kind.counter_index();
65        let index = self.counters[idx];
66        self.counters[idx] += 1;
67        let reg = Register {
68            kind,
69            index,
70            ptx_type,
71        };
72        self.allocated.push(reg);
73        reg
74    }
75
76    /// Allocate a `.b32` register intended to hold **two fp16 values
77    /// packed into 32 bits** — the storage format `mma.sync.m16n8k16.f16`
78    /// expects for its A and B fragment operands.
79    ///
80    /// This is a thin semantic alias for `alloc(PtxType::U32)` — the
81    /// register lives in the `%r` class at the PTX level. The separate
82    /// method exists so call sites in `fragment.rs` make their intent
83    /// explicit and so grep'ing for `alloc_packed_half2` finds every
84    /// fragment-storage allocation in the codebase.
85    ///
86    /// No new [`RegKind`] variant is introduced — the register really is
87    /// `.b32` at the hardware level, and inventing a fake kind would lie
88    /// about the PTX reality.
89    pub fn alloc_packed_half2(&mut self) -> Register {
90        self.alloc(PtxType::U32)
91    }
92
93    /// Allocate a `.b32` register intended to hold **four signed 8-bit
94    /// values packed into 32 bits** — the storage format
95    /// `mma.sync.m16n8k32.s8` expects for its A and B fragment operands.
96    ///
97    /// Same hardware register class as [`alloc_packed_half2`](Self::alloc_packed_half2)
98    /// (both live in `%r` / `.b32`); the separate method is a naming hook
99    /// so INT8-path call sites document their intent and so `grep` for
100    /// `alloc_packed_int8x4` finds every INT8 fragment-storage
101    /// allocation in the codebase.
102    ///
103    /// Introduced in Sprint 7.1 for the
104    /// `mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32` path.
105    pub fn alloc_packed_int8x4(&mut self) -> Register {
106        self.alloc(PtxType::U32)
107    }
108
109    /// All registers allocated so far, in allocation order.
110    pub fn allocated(&self) -> &[Register] {
111        &self.allocated
112    }
113
114    /// Consume the allocator and return all allocated registers.
115    pub fn into_allocated(self) -> Vec<Register> {
116        self.allocated
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Types that share a register class also share its counter
    /// (`S32` and `U32` both map to `%r`).
    #[test]
    fn sequential_indices_within_kind() {
        let mut alloc = RegisterAllocator::new();
        let regs = [
            alloc.alloc(PtxType::S32),
            alloc.alloc(PtxType::U32), // same RegKind::R
            alloc.alloc(PtxType::S32),
        ];
        for (expected, reg) in (0u32..).zip(regs.iter()) {
            assert_eq!(reg.index, expected);
            assert_eq!(reg.name(), format!("%r{expected}"));
        }
    }

    /// One allocation in each of seven distinct classes: every one is the
    /// first in its class, so every index must be 0.
    #[test]
    fn independent_counters_per_kind() {
        let mut alloc = RegisterAllocator::new();
        let cases = [
            (PtxType::S32, "%r0"),
            (PtxType::F32, "%f0"),
            (PtxType::U64, "%rd0"),
            (PtxType::Pred, "%p0"),
            (PtxType::F64, "%fd0"),
            (PtxType::F16, "%h0"),
            (PtxType::BF16, "%hb0"),
        ];
        for (ty, expected_name) in cases {
            let reg = alloc.alloc(ty);
            assert_eq!(reg.index, 0);
            assert_eq!(reg.name(), expected_name);
        }
    }

    /// `into_allocated` hands back exactly what `alloc` returned, in order.
    #[test]
    fn into_allocated_preserves_order() {
        let mut alloc = RegisterAllocator::new();
        let issued = vec![
            alloc.alloc(PtxType::F32),
            alloc.alloc(PtxType::S32),
            alloc.alloc(PtxType::F32),
        ];
        assert_eq!(alloc.into_allocated(), issued);
    }

    /// The f16 (`%h`) and bf16 (`%hb`) classes count independently.
    #[test]
    fn f16_bf16_register_allocation() {
        let mut alloc = RegisterAllocator::new();
        let h = [alloc.alloc(PtxType::F16), alloc.alloc(PtxType::F16)];
        let hb = [alloc.alloc(PtxType::BF16), alloc.alloc(PtxType::BF16)];
        for (expected, (half, bhalf)) in (0u32..).zip(h.iter().zip(hb.iter())) {
            assert_eq!(half.index, expected);
            assert_eq!(bhalf.index, expected);
            assert_eq!(half.name(), format!("%h{expected}"));
            assert_eq!(bhalf.name(), format!("%hb{expected}"));
        }
        assert_eq!(h[0].kind, RegKind::H);
        assert_eq!(hb[0].kind, RegKind::Hb);
    }

    /// `Display` renders prefix + index for a hand-built register.
    #[test]
    fn register_display() {
        let reg = Register {
            ptx_type: PtxType::F64,
            kind: RegKind::Fd,
            index: 7,
        };
        assert_eq!(reg.to_string(), "%fd7");
    }
}