use std::collections::HashMap;
use std::fmt;
use super::types::PtxType;
/// A named register paired with the type it was requested for.
///
/// Equality and hashing consider both the name and the type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Register {
/// Register name as it is printed by `Display`; names produced by
/// `RegisterAllocator::alloc` have the form `%<prefix><index>`, e.g. `%f0`.
pub name: String,
/// The type passed to the allocation call that created this register.
pub ty: PtxType,
}
/// Formats a register as its bare name (e.g. `%f0`).
///
/// Width, fill, alignment and precision flags supplied by the caller
/// (`{:<8}`, `{:>6}`, ...) are honoured, exactly as they would be for a
/// plain string.
impl fmt::Display for Register {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and silently
        // drop the caller's padding flags; `pad` applies them to the name.
        f.pad(&self.name)
    }
}
/// Hands out sequentially numbered registers, grouped by a short name prefix,
/// and remembers enough to emit one `.reg` declaration per prefix afterwards.
#[derive(Debug)]
pub struct RegisterAllocator {
    /// Next unused index for each name prefix (`"p"`, `"f"`, `"rd"`, `"r"`).
    counters: HashMap<&'static str, u32>,
    /// Every distinct `(prefix, type)` pair that has been allocated, in
    /// first-seen order. The first entry for a prefix decides the type used
    /// in that prefix's declaration.
    used_types: Vec<(&'static str, PtxType)>,
}
impl RegisterAllocator {
    /// Creates an allocator that has not handed out any registers yet.
    #[must_use]
    pub fn new() -> Self {
        let counters = HashMap::new();
        let used_types = Vec::new();
        Self {
            counters,
            used_types,
        }
    }

    /// Allocates the next register for `ty`.
    ///
    /// The name is `%<prefix><index>`, where the prefix comes from
    /// [`Self::prefix_for`] and the index counts up from zero per prefix.
    /// Types that map to the same prefix share one counter, so an `F32`
    /// allocation followed by an `F64` allocation yields `%f0` and `%f1`.
    pub fn alloc(&mut self, ty: PtxType) -> Register {
        let class = Self::prefix_for(ty);

        // Take the current index for this prefix, then advance the counter.
        let slot = self.counters.entry(class).or_default();
        let index = *slot;
        *slot += 1;

        // Record each distinct (prefix, type) pairing once, in first-seen order.
        let pairing = (class, ty);
        if !self.used_types.contains(&pairing) {
            self.used_types.push(pairing);
        }

        let name = format!("%{class}{index}");
        Register { name, ty }
    }

    /// Allocates `count` registers of type `ty`, returned in allocation order.
    ///
    /// A `count` of zero returns an empty vector and leaves the allocator
    /// untouched.
    pub fn alloc_group(&mut self, ty: PtxType, count: u32) -> Vec<Register> {
        let mut group = Vec::with_capacity(count as usize);
        for _ in 0..count {
            group.push(self.alloc(ty));
        }
        group
    }

    /// Builds one `.reg <type> %<prefix><<count>>;` line per prefix that has
    /// at least one allocation, sorted lexicographically.
    ///
    /// The declared type is `reg_type()` of the first type allocated under the
    /// prefix; `.b32` is used if no type was recorded for it.
    ///
    /// NOTE(review): when several types share a prefix (e.g. `F32` and `F64`
    /// under `f`), all of them are covered by the single declaration of the
    /// first-seen type - confirm this is intended.
    #[must_use]
    pub fn emit_declarations(&self) -> Vec<String> {
        let mut lines: Vec<String> = self
            .counters
            .iter()
            .filter(|&(_, &total)| total != 0)
            .map(|(&class, &total)| {
                let reg_ty = self
                    .used_types
                    .iter()
                    .find_map(|(seen, ty)| {
                        (*seen == class).then(|| ty.reg_type().as_ptx_str())
                    })
                    .unwrap_or(".b32");
                format!(".reg {reg_ty} %{class}<{total}>;")
            })
            .collect();

        // Map iteration order is arbitrary; sorting makes the output stable.
        lines.sort();
        lines
    }

    /// Maps a type to the name prefix its registers are numbered under:
    /// `rd` for the 64-bit integer/bit types, `f` for every float type,
    /// `p` for predicates, and `r` for everything else.
    const fn prefix_for(ty: PtxType) -> &'static str {
        match ty {
            PtxType::U64 | PtxType::S64 | PtxType::B64 => "rd",
            PtxType::F64
            | PtxType::F32
            | PtxType::BF16x2
            | PtxType::BF16
            | PtxType::F16x2
            | PtxType::F16 => "f",
            PtxType::Pred => "p",
            _ => "r",
        }
    }
}
impl Default for RegisterAllocator {
/// Returns the same empty allocator as `RegisterAllocator::new`.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the names of `regs` in order, for compact comparisons.
    fn names(regs: &[Register]) -> Vec<&str> {
        regs.iter().map(|reg| reg.name.as_str()).collect()
    }

    /// Every float type draws from the single `f` counter.
    #[test]
    fn alloc_float_registers() {
        let mut regs = RegisterAllocator::new();
        let allocated = [
            regs.alloc(PtxType::F32),
            regs.alloc(PtxType::F32),
            regs.alloc(PtxType::F64),
        ];

        assert_eq!(names(&allocated), ["%f0", "%f1", "%f2"]);
        assert_eq!(allocated[0].ty, PtxType::F32);
        assert_eq!(allocated[2].ty, PtxType::F64);
    }

    /// 32-bit integer types fall through to the `r` prefix.
    #[test]
    fn alloc_integer_registers() {
        let mut regs = RegisterAllocator::new();
        let allocated = [regs.alloc(PtxType::U32), regs.alloc(PtxType::S32)];

        assert_eq!(names(&allocated), ["%r0", "%r1"]);
    }

    /// 64-bit integer and bit types share the `rd` prefix.
    #[test]
    fn alloc_64bit_registers() {
        let mut regs = RegisterAllocator::new();
        let allocated = [
            regs.alloc(PtxType::U64),
            regs.alloc(PtxType::S64),
            regs.alloc(PtxType::B64),
        ];

        assert_eq!(names(&allocated), ["%rd0", "%rd1", "%rd2"]);
    }

    /// Predicates are numbered under the `p` prefix.
    #[test]
    fn alloc_predicate_registers() {
        let mut regs = RegisterAllocator::new();
        let allocated = [regs.alloc(PtxType::Pred), regs.alloc(PtxType::Pred)];

        assert_eq!(names(&allocated), ["%p0", "%p1"]);
    }

    /// A group is numbered consecutively, starting at the current counter.
    #[test]
    fn alloc_group() {
        let mut regs = RegisterAllocator::new();
        let group = regs.alloc_group(PtxType::F32, 4);

        assert_eq!(group.len(), 4);
        assert_eq!(group.first().map(|reg| reg.name.as_str()), Some("%f0"));
        assert_eq!(group.last().map(|reg| reg.name.as_str()), Some("%f3"));
    }

    /// One declaration per prefix in use, returned in sorted order.
    #[test]
    fn emit_declarations_sorted() {
        let mut regs = RegisterAllocator::new();
        for ty in [
            PtxType::F32,
            PtxType::F32,
            PtxType::U32,
            PtxType::Pred,
            PtxType::U64,
        ] {
            regs.alloc(ty);
        }

        let decls = regs.emit_declarations();
        assert_eq!(decls.len(), 4);

        let joined = decls.join("\n");
        for (needle, label) in [
            ("%f<2>", "f"),
            ("%p<1>", "p"),
            ("%r<1>", "r"),
            ("%rd<1>", "rd"),
        ] {
            assert!(joined.contains(needle), "missing {label} decl: {joined}");
        }

        // A list is sorted exactly when it equals its own sorted copy.
        let mut expected_order = decls.clone();
        expected_order.sort();
        assert_eq!(decls, expected_order, "declarations not sorted: {decls:?}");
    }

    /// `Display` prints just the register name.
    #[test]
    fn register_display() {
        let reg = Register {
            name: String::from("%f0"),
            ty: PtxType::F32,
        };

        assert_eq!(reg.to_string(), "%f0");
    }
}