1use std::fmt;
4
5use crate::types::{PtxType, RegKind};
6
/// A virtual PTX register, as handed out by [`RegisterAllocator`].
///
/// Its textual name is the kind's prefix immediately followed by `index`
/// (e.g. `%r0`, `%fd7`); see [`Register::name`] and the `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Register {
    /// Register class: supplies the name prefix and selects which allocator counter is used.
    pub kind: RegKind,
    /// Position within `kind`; the allocator assigns these sequentially starting at 0.
    pub index: u32,
    /// The PTX type the register was requested with. Several types can map to the same
    /// `kind` (e.g. `S32` and `U32` both produce `%r` registers).
    pub ptx_type: PtxType,
}
20
21impl Register {
22 pub fn name(&self) -> String {
24 format!("{}{}", self.kind.prefix(), self.index)
25 }
26}
27
28impl fmt::Display for Register {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 write!(f, "{}{}", self.kind.prefix(), self.index)
31 }
32}
33
/// Hands out sequentially numbered PTX registers, keeping an independent counter per
/// register kind.
///
/// Every register handed out is also recorded in allocation order and can be retrieved
/// with [`RegisterAllocator::allocated`] or [`RegisterAllocator::into_allocated`].
#[derive(Debug)]
pub struct RegisterAllocator {
    /// Next free index for each register kind, indexed by `RegKind::counter_index()`.
    // NOTE(review): the length 7 must exceed every value `counter_index()` can return
    // (the tests exercise seven distinct kinds); update it when a new kind is added.
    counters: [u32; 7],
    /// Every register allocated so far, in allocation order.
    allocated: Vec<Register>,
}
45
46impl Default for RegisterAllocator {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl RegisterAllocator {
53 pub fn new() -> Self {
55 Self {
56 counters: [0; 7],
57 allocated: Vec::new(),
58 }
59 }
60
61 pub fn alloc(&mut self, ptx_type: PtxType) -> Register {
63 let kind = ptx_type.reg_kind();
64 let idx = kind.counter_index();
65 let index = self.counters[idx];
66 self.counters[idx] += 1;
67 let reg = Register {
68 kind,
69 index,
70 ptx_type,
71 };
72 self.allocated.push(reg);
73 reg
74 }
75
76 pub fn alloc_packed_half2(&mut self) -> Register {
90 self.alloc(PtxType::U32)
91 }
92
93 pub fn alloc_packed_int8x4(&mut self) -> Register {
106 self.alloc(PtxType::U32)
107 }
108
109 pub fn allocated(&self) -> &[Register] {
111 &self.allocated
112 }
113
114 pub fn into_allocated(self) -> Vec<Register> {
116 self.allocated
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sequential_indices_within_kind() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc(PtxType::S32);
        let r1 = alloc.alloc(PtxType::U32);
        let r2 = alloc.alloc(PtxType::S32);
        assert_eq!(r0.index, 0);
        assert_eq!(r1.index, 1);
        assert_eq!(r2.index, 2);
        assert_eq!(r0.name(), "%r0");
        assert_eq!(r1.name(), "%r1");
        assert_eq!(r2.name(), "%r2");
    }

    #[test]
    fn independent_counters_per_kind() {
        let mut alloc = RegisterAllocator::new();
        let r = alloc.alloc(PtxType::S32);
        let f = alloc.alloc(PtxType::F32);
        let rd = alloc.alloc(PtxType::U64);
        let p = alloc.alloc(PtxType::Pred);
        let fd = alloc.alloc(PtxType::F64);
        let h = alloc.alloc(PtxType::F16);
        let hb = alloc.alloc(PtxType::BF16);
        assert_eq!(r.index, 0);
        assert_eq!(f.index, 0);
        assert_eq!(rd.index, 0);
        assert_eq!(p.index, 0);
        assert_eq!(fd.index, 0);
        assert_eq!(h.index, 0);
        assert_eq!(hb.index, 0);
        assert_eq!(r.name(), "%r0");
        assert_eq!(f.name(), "%f0");
        assert_eq!(rd.name(), "%rd0");
        assert_eq!(p.name(), "%p0");
        assert_eq!(fd.name(), "%fd0");
        assert_eq!(h.name(), "%h0");
        assert_eq!(hb.name(), "%hb0");
    }

    #[test]
    fn into_allocated_preserves_order() {
        let mut alloc = RegisterAllocator::new();
        let r0 = alloc.alloc(PtxType::F32);
        let r1 = alloc.alloc(PtxType::S32);
        let r2 = alloc.alloc(PtxType::F32);
        let regs = alloc.into_allocated();
        assert_eq!(regs.len(), 3);
        assert_eq!(regs[0], r0);
        assert_eq!(regs[1], r1);
        assert_eq!(regs[2], r2);
    }

    #[test]
    fn allocated_borrows_in_order() {
        let mut alloc = RegisterAllocator::default();
        assert!(alloc.allocated().is_empty());
        let a = alloc.alloc(PtxType::Pred);
        let b = alloc.alloc(PtxType::F64);
        assert_eq!(alloc.allocated(), &[a, b]);
        let c = alloc.alloc(PtxType::Pred);
        assert_eq!(c.name(), "%p1");
        assert_eq!(alloc.allocated(), &[a, b, c]);
    }

    #[test]
    fn packed_helpers_allocate_u32_registers() {
        let mut alloc = RegisterAllocator::new();
        let h2 = alloc.alloc_packed_half2();
        let i8x4 = alloc.alloc_packed_int8x4();
        let s = alloc.alloc(PtxType::S32);
        assert_eq!(h2.ptx_type, PtxType::U32);
        assert_eq!(i8x4.ptx_type, PtxType::U32);
        // Packed registers share the `%r` sequence with other 32-bit integer registers.
        assert_eq!(h2.name(), "%r0");
        assert_eq!(i8x4.name(), "%r1");
        assert_eq!(s.name(), "%r2");
    }

    #[test]
    fn f16_bf16_register_allocation() {
        let mut alloc = RegisterAllocator::new();
        let h0 = alloc.alloc(PtxType::F16);
        let h1 = alloc.alloc(PtxType::F16);
        let hb0 = alloc.alloc(PtxType::BF16);
        let hb1 = alloc.alloc(PtxType::BF16);
        assert_eq!(h0.index, 0);
        assert_eq!(h1.index, 1);
        assert_eq!(hb0.index, 0);
        assert_eq!(hb1.index, 1);
        assert_eq!(h0.name(), "%h0");
        assert_eq!(h1.name(), "%h1");
        assert_eq!(hb0.name(), "%hb0");
        assert_eq!(hb1.name(), "%hb1");
        assert_eq!(h0.kind, RegKind::H);
        assert_eq!(hb0.kind, RegKind::Hb);
    }

    #[test]
    fn register_display() {
        let reg = Register {
            kind: RegKind::Fd,
            index: 7,
            ptx_type: PtxType::F64,
        };
        assert_eq!(format!("{reg}"), "%fd7");
    }

    #[test]
    fn name_matches_display() {
        let mut alloc = RegisterAllocator::new();
        for ty in [PtxType::S32, PtxType::F32, PtxType::U64, PtxType::BF16] {
            let reg = alloc.alloc(ty);
            assert_eq!(reg.name(), reg.to_string());
        }
    }
}