1use std::fmt;
4
5use crate::types::{PtxType, RegKind};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
12pub struct Register {
13 pub kind: RegKind,
15 pub index: u32,
17 pub ptx_type: PtxType,
19}
20
21impl Register {
22 pub fn name(&self) -> String {
24 format!("{}{}", self.kind.prefix(), self.index)
25 }
26}
27
28impl fmt::Display for Register {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 write!(f, "{}{}", self.kind.prefix(), self.index)
31 }
32}
33
34#[derive(Debug)]
41pub struct RegisterAllocator {
42 counters: [u32; 7],
43 allocated: Vec<Register>,
44}
45
46impl Default for RegisterAllocator {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl RegisterAllocator {
53 pub fn new() -> Self {
55 Self {
56 counters: [0; 7],
57 allocated: Vec::new(),
58 }
59 }
60
61 pub fn alloc(&mut self, ptx_type: PtxType) -> Register {
63 let kind = ptx_type.reg_kind();
64 let idx = kind.counter_index();
65 let index = self.counters[idx];
66 self.counters[idx] += 1;
67 let reg = Register {
68 kind,
69 index,
70 ptx_type,
71 };
72 self.allocated.push(reg);
73 reg
74 }
75
76 pub fn alloc_packed_half2(&mut self) -> Register {
90 self.alloc(PtxType::U32)
91 }
92
93 pub fn allocated(&self) -> &[Register] {
95 &self.allocated
96 }
97
98 pub fn into_allocated(self) -> Vec<Register> {
100 self.allocated
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// `S32` and `U32` registers draw their indices from the same counter.
    #[test]
    fn sequential_indices_within_kind() {
        let mut regs = RegisterAllocator::new();
        let requests = [PtxType::S32, PtxType::U32, PtxType::S32];
        for (expected_index, ty) in (0u32..).zip(requests) {
            let reg = regs.alloc(ty);
            assert_eq!(reg.index, expected_index);
            assert_eq!(reg.name(), format!("%r{expected_index}"));
        }
    }

    /// The first register of every kind gets index 0 and its own prefix.
    #[test]
    fn independent_counters_per_kind() {
        let mut regs = RegisterAllocator::new();
        let cases = [
            (PtxType::S32, "%r0"),
            (PtxType::F32, "%f0"),
            (PtxType::U64, "%rd0"),
            (PtxType::Pred, "%p0"),
            (PtxType::F64, "%fd0"),
            (PtxType::F16, "%h0"),
            (PtxType::BF16, "%hb0"),
        ];
        for (ty, expected_name) in cases {
            let reg = regs.alloc(ty);
            assert_eq!(reg.index, 0);
            assert_eq!(reg.name(), expected_name);
        }
    }

    /// `into_allocated` yields the registers in the order they were handed out.
    #[test]
    fn into_allocated_preserves_order() {
        let mut regs = RegisterAllocator::new();
        let handed_out: Vec<Register> = vec![PtxType::F32, PtxType::S32, PtxType::F32]
            .into_iter()
            .map(|ty| regs.alloc(ty))
            .collect();
        // Vec equality checks both the length (3) and each element in turn.
        assert_eq!(handed_out.len(), 3);
        assert_eq!(regs.into_allocated(), handed_out);
    }

    /// `F16` and `BF16` use separate kinds, prefixes and counters.
    #[test]
    fn f16_bf16_register_allocation() {
        let mut regs = RegisterAllocator::new();
        let cases = [
            (PtxType::F16, "%h", RegKind::H),
            (PtxType::BF16, "%hb", RegKind::Hb),
        ];
        for (ty, prefix, expected_kind) in cases {
            for expected_index in 0u32..2 {
                let reg = regs.alloc(ty);
                assert_eq!(reg.index, expected_index);
                assert_eq!(reg.name(), format!("{prefix}{expected_index}"));
                if expected_index == 0 {
                    assert_eq!(reg.kind, expected_kind);
                }
            }
        }
    }

    /// `Display` prints the kind prefix followed by the index.
    #[test]
    fn register_display() {
        let fd7 = Register {
            kind: RegKind::Fd,
            index: 7,
            ptx_type: PtxType::F64,
        };
        assert_eq!(fd7.to_string(), "%fd7");
    }
}