use super::types::PtxType;
use std::collections::HashMap;
use std::fmt;
/// PTX special registers: the predefined, read-only registers such as
/// `%tid`, `%ctaid` and `%clock`.
///
/// Each variant corresponds to exactly one PTX register name, returned by
/// `PtxReg::to_ptx_string`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxReg {
/// `%tid.x`: thread index within the CTA, x component.
TidX,
/// `%tid.y`: thread index within the CTA, y component.
TidY,
/// `%tid.z`: thread index within the CTA, z component.
TidZ,
/// `%ctaid.x`: CTA (block) index within the grid, x component.
CtaIdX,
/// `%ctaid.y`: CTA (block) index within the grid, y component.
CtaIdY,
/// `%ctaid.z`: CTA (block) index within the grid, z component.
CtaIdZ,
/// `%ntid.x`: number of threads per CTA, x dimension.
NtidX,
/// `%ntid.y`: number of threads per CTA, y dimension.
NtidY,
/// `%ntid.z`: number of threads per CTA, z dimension.
NtidZ,
/// `%nctaid.x`: number of CTAs in the grid, x dimension.
NctaIdX,
/// `%nctaid.y`: number of CTAs in the grid, y dimension.
NctaIdY,
/// `%nctaid.z`: number of CTAs in the grid, z dimension.
NctaIdZ,
/// `%warpid`: warp identifier.
WarpId,
/// `%laneid`: lane index within the warp.
LaneId,
/// `%smid`: identifier of the streaming multiprocessor.
SmId,
/// `%clock`: 32-bit cycle counter.
Clock,
/// `%clock64`: 64-bit cycle counter.
Clock64,
}
impl PtxReg {
    /// Returns the PTX assembly name of this special register, e.g. `"%tid.x"`.
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            Self::TidX => "%tid.x",
            Self::TidY => "%tid.y",
            Self::TidZ => "%tid.z",
            Self::CtaIdX => "%ctaid.x",
            Self::CtaIdY => "%ctaid.y",
            Self::CtaIdZ => "%ctaid.z",
            Self::NtidX => "%ntid.x",
            Self::NtidY => "%ntid.y",
            Self::NtidZ => "%ntid.z",
            Self::NctaIdX => "%nctaid.x",
            Self::NctaIdY => "%nctaid.y",
            Self::NctaIdZ => "%nctaid.z",
            Self::WarpId => "%warpid",
            Self::LaneId => "%laneid",
            Self::SmId => "%smid",
            Self::Clock => "%clock",
            Self::Clock64 => "%clock64",
        }
    }

    /// Returns the PTX data type of this special register.
    ///
    /// `%clock64` is 64-bit; every other register in this enum is 32-bit.
    /// The match is deliberately exhaustive (no `_` arm). Adding a variant
    /// then forces its width to be chosen explicitly. A new 64-bit register
    /// can no longer silently be treated as `U32`.
    #[must_use]
    pub const fn data_type(self) -> PtxType {
        match self {
            Self::Clock64 => PtxType::U64,
            Self::TidX
            | Self::TidY
            | Self::TidZ
            | Self::CtaIdX
            | Self::CtaIdY
            | Self::CtaIdZ
            | Self::NtidX
            | Self::NtidY
            | Self::NtidZ
            | Self::NctaIdX
            | Self::NctaIdY
            | Self::NctaIdZ
            | Self::WarpId
            | Self::LaneId
            | Self::SmId
            | Self::Clock => PtxType::U32,
        }
    }
}
/// A virtual register, identified by a per-type index plus the PTX type that
/// selects its textual prefix (e.g. an `F32` register with id 5 renders as
/// `%f5`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VirtualReg {
/// Index of the register within its type. When handed out by
/// `RegisterAllocator::allocate_virtual`, ids restart at 0 for each type.
id: u32,
/// PTX type of the register; also determines its name prefix.
ty: PtxType,
}
impl VirtualReg {
    /// Builds a virtual register from its per-type index and PTX type.
    #[must_use]
    pub const fn new(id: u32, ty: PtxType) -> Self {
        Self { ty, id }
    }

    /// The register's index within its type.
    #[must_use]
    pub const fn id(self) -> u32 {
        let Self { id, .. } = self;
        id
    }

    /// The register's PTX type.
    #[must_use]
    pub const fn ty(self) -> PtxType {
        let Self { ty, .. } = self;
        ty
    }

    /// Renders the register as `<type prefix><id>` (e.g. `%f5`) into a new
    /// `String`.
    #[must_use]
    pub fn to_ptx_string(self) -> String {
        let mut rendered = String::new();
        self.write_to(&mut rendered)
            .expect("writing into a String cannot fail");
        rendered
    }

    /// Writes the register as `<type prefix><id>` into `w` without allocating.
    ///
    /// Any error comes from the writer `w` itself.
    #[inline]
    pub fn write_to<W: fmt::Write>(self, w: &mut W) -> fmt::Result {
        let Self { id, ty } = self;
        write!(w, "{}{}", ty.register_prefix(), id)
    }
}
impl fmt::Display for VirtualReg {
    /// Formats the register as `<type prefix><id>`, e.g. `%fd7` or `%rd2`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { id, ty } = *self;
        let prefix = ty.register_prefix();
        write!(f, "{prefix}{id}")
    }
}
/// A physical register number wrapped in a newtype.
///
/// NOTE(review): in the code visible here, `PhysicalReg` is only constructed
/// by tests; presumably reserved for a later physical-assignment stage —
/// confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PhysicalReg(pub u32);
/// A half-open instruction interval `[start, end)` over which a register is
/// considered live.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LiveRange {
    /// First instruction index covered by the range (inclusive).
    pub start: usize,
    /// One past the last covered instruction index (exclusive).
    pub end: usize,
}

impl LiveRange {
    /// Creates a range from its bounds. The bounds are stored as given; no
    /// `start <= end` check is performed.
    #[must_use]
    pub const fn new(start: usize, end: usize) -> Self {
        Self { start, end }
    }

    /// Returns `true` unless one range ends at or before the point where the
    /// other begins.
    ///
    /// Because ranges are half-open, `[0, 5)` and `[5, 10)` do not overlap.
    #[must_use]
    pub const fn overlaps(&self, other: &Self) -> bool {
        let disjoint = self.end <= other.start || other.end <= self.start;
        !disjoint
    }
}
/// Register-usage summary, as produced by `RegisterAllocator::pressure_report`.
#[derive(Debug, Clone, PartialEq)]
pub struct RegisterPressure {
/// Register count. NOTE(review): `pressure_report` fills this with the total
/// number of allocated virtual registers, not the peak number simultaneously
/// live — the name overstates what is measured.
pub max_live: usize,
/// Number of spilled registers reported by the allocator.
pub spill_count: usize,
/// `max_live` divided by 256 when produced by `pressure_report`.
pub utilization: f64,
}
/// Hands out typed virtual registers, records a live range for each, and
/// emits the matching PTX `.reg` declarations.
#[derive(Debug, Clone)]
pub struct RegisterAllocator {
/// Next id to hand out for each register type.
type_counters: HashMap<PtxType, u32>,
/// Live range of every allocated register, keyed by `(type, id)`.
live_ranges: HashMap<(PtxType, u32), LiveRange>,
/// Every register allocated so far, in allocation order.
allocated: Vec<VirtualReg>,
/// Current instruction index; advanced by `next_instruction` and used to
/// stamp live ranges.
current_instruction: usize,
/// Spill counter reported by `pressure_report`. NOTE(review): never
/// incremented in the code visible here.
spill_count: usize,
}
impl RegisterAllocator {
    /// Creates an empty allocator positioned at instruction 0.
    #[must_use]
    pub fn new() -> Self {
        Self {
            type_counters: HashMap::new(),
            live_ranges: HashMap::new(),
            allocated: Vec::new(),
            current_instruction: 0,
            spill_count: 0,
        }
    }

    /// Allocates a fresh virtual register of type `ty`.
    ///
    /// Ids are numbered independently per type, starting at 0. For example,
    /// the first `F32` and the first `U32` register both get id 0. The new
    /// register's live range starts as `[current, current + 1)`.
    pub fn allocate_virtual(&mut self, ty: PtxType) -> VirtualReg {
        // Single hash lookup: read the next id and bump the counter in place.
        let counter = self.type_counters.entry(ty).or_insert(0);
        let id = *counter;
        *counter += 1;

        let vreg = VirtualReg::new(id, ty);
        self.allocated.push(vreg);

        let start = self.current_instruction;
        self.live_ranges
            .insert((ty, id), LiveRange::new(start, start + 1));
        vreg
    }

    /// Extends `vreg`'s live range so it covers the current instruction.
    ///
    /// Does nothing if `vreg` was not allocated by this allocator.
    pub fn extend_live_range(&mut self, vreg: VirtualReg) {
        if let Some(range) = self.live_ranges.get_mut(&(vreg.ty(), vreg.id())) {
            range.end = self.current_instruction + 1;
        }
    }

    /// Advances the current instruction index by one.
    pub fn next_instruction(&mut self) {
        self.current_instruction += 1;
    }

    /// Summarizes register usage.
    ///
    /// `max_live` is the total number of virtual registers allocated. Live
    /// ranges are not consulted, so this is not the peak number of
    /// simultaneously live registers.
    #[must_use]
    pub fn pressure_report(&self) -> RegisterPressure {
        // NOTE(review): presumably the per-thread register budget used as the
        // utilization denominator — TODO confirm against the target hardware.
        const REGISTER_BUDGET: f64 = 256.0;

        let max_live = self.allocated.len();
        RegisterPressure {
            max_live,
            spill_count: self.spill_count,
            utilization: max_live as f64 / REGISTER_BUDGET,
        }
    }

    /// Emits one PTX `.reg` declaration line per register type in use.
    ///
    /// Each line has the form `.reg <decl type> <prefix><<count>>;`. Ids are
    /// handed out contiguously from 0 for each type, so one line covers every
    /// register of that type. Lines appear in the order each type was first
    /// allocated, which keeps the generated PTX deterministic.
    #[must_use]
    pub fn emit_declarations(&self) -> String {
        use std::fmt::Write as _;

        // Tally registers per type in first-allocation order. Iterating a
        // `HashMap` here would make the declaration order vary between runs.
        let mut counts: Vec<(PtxType, usize)> = Vec::new();
        for vreg in &self.allocated {
            let ty = vreg.ty();
            match counts.iter().position(|&(seen, _)| seen == ty) {
                Some(index) => counts[index].1 += 1,
                None => counts.push((ty, 1)),
            }
        }

        let mut decls = String::new();
        for (ty, count) in counts {
            writeln!(
                decls,
                " .reg {} {}<{}>;",
                ty.register_declaration_type(),
                ty.register_prefix(),
                count
            )
            .expect("writing into a String cannot fail");
        }
        decls
    }
}
impl Default for RegisterAllocator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_special_register_strings() {
        assert_eq!(PtxReg::TidX.to_ptx_string(), "%tid.x");
        assert_eq!(PtxReg::CtaIdX.to_ptx_string(), "%ctaid.x");
        assert_eq!(PtxReg::NtidX.to_ptx_string(), "%ntid.x");
        assert_eq!(PtxReg::LaneId.to_ptx_string(), "%laneid");
        assert_eq!(PtxReg::WarpId.to_ptx_string(), "%warpid");
    }

    #[test]
    fn test_special_register_types() {
        assert_eq!(PtxReg::TidX.data_type(), PtxType::U32);
        assert_eq!(PtxReg::Clock.data_type(), PtxType::U32);
        assert_eq!(PtxReg::Clock64.data_type(), PtxType::U64);
    }

    #[test]
    fn test_virtual_register_creation() {
        let vreg = VirtualReg::new(0, PtxType::F32);
        assert_eq!(vreg.id(), 0);
        assert_eq!(vreg.ty(), PtxType::F32);
    }

    #[test]
    fn test_virtual_register_string() {
        let vreg = VirtualReg::new(5, PtxType::F32);
        assert_eq!(vreg.to_ptx_string(), "%f5");
        let vreg_u32 = VirtualReg::new(3, PtxType::U32);
        assert_eq!(vreg_u32.to_ptx_string(), "%r3");
        let vreg_pred = VirtualReg::new(1, PtxType::Pred);
        assert_eq!(vreg_pred.to_ptx_string(), "%p1");
    }

    #[test]
    fn test_live_range_overlap() {
        let r1 = LiveRange::new(0, 5);
        let r2 = LiveRange::new(3, 8);
        let r3 = LiveRange::new(5, 10);
        let r4 = LiveRange::new(10, 15);
        assert!(r1.overlaps(&r2));
        assert!(r2.overlaps(&r1));
        // Ranges are half-open, so ranges that merely touch do not overlap.
        assert!(!r1.overlaps(&r3));
        assert!(!r3.overlaps(&r1));
        assert!(!r1.overlaps(&r4));
    }

    #[test]
    fn test_register_allocator() {
        let mut alloc = RegisterAllocator::new();
        let r1 = alloc.allocate_virtual(PtxType::F32);
        let r2 = alloc.allocate_virtual(PtxType::F32);
        let r3 = alloc.allocate_virtual(PtxType::U32);
        assert_eq!(r1.id(), 0);
        assert_eq!(r2.id(), 1);
        assert_eq!(r3.id(), 0);
        assert_eq!(r1.ty(), PtxType::F32);
        assert_eq!(r3.ty(), PtxType::U32);
    }

    #[test]
    fn test_pressure_report() {
        let mut alloc = RegisterAllocator::new();
        let _ = alloc.allocate_virtual(PtxType::F32);
        let _ = alloc.allocate_virtual(PtxType::F32);
        let _ = alloc.allocate_virtual(PtxType::F32);
        let report = alloc.pressure_report();
        assert_eq!(report.max_live, 3);
        assert_eq!(report.spill_count, 0);
    }

    #[test]
    fn test_emit_declarations() {
        let mut alloc = RegisterAllocator::new();
        let _ = alloc.allocate_virtual(PtxType::F32);
        let _ = alloc.allocate_virtual(PtxType::F32);
        let _ = alloc.allocate_virtual(PtxType::U32);
        let decls = alloc.emit_declarations();
        assert!(decls.contains(".reg .f32"));
        assert!(decls.contains(".reg .u32") || decls.contains(".reg .s32"));
    }

    #[test]
    fn test_u32_s32_separate_prefixes() {
        let mut alloc = RegisterAllocator::new();
        let u0 = alloc.allocate_virtual(PtxType::U32);
        let u1 = alloc.allocate_virtual(PtxType::U32);
        let u2 = alloc.allocate_virtual(PtxType::U32);
        let s0 = alloc.allocate_virtual(PtxType::S32);
        let s1 = alloc.allocate_virtual(PtxType::S32);
        assert_eq!(u0.id(), 0);
        assert_eq!(u1.id(), 1);
        assert_eq!(u2.id(), 2);
        assert_eq!(s0.id(), 0);
        assert_eq!(s1.id(), 1);
        assert_eq!(u0.to_ptx_string(), "%r0");
        assert_eq!(s0.to_ptx_string(), "%ri0");
        assert_ne!(u0.to_ptx_string(), s0.to_ptx_string());
        let decls = alloc.emit_declarations();
        assert!(decls.contains(".reg .u32 %r<3>"), "Missing u32 decl in:\n{decls}");
        assert!(decls.contains(".reg .s32 %ri<2>"), "Missing s32 decl in:\n{decls}");
    }

    #[test]
    fn test_all_special_registers() {
        assert_eq!(PtxReg::TidY.to_ptx_string(), "%tid.y");
        assert_eq!(PtxReg::TidZ.to_ptx_string(), "%tid.z");
        assert_eq!(PtxReg::CtaIdY.to_ptx_string(), "%ctaid.y");
        assert_eq!(PtxReg::CtaIdZ.to_ptx_string(), "%ctaid.z");
        assert_eq!(PtxReg::NtidY.to_ptx_string(), "%ntid.y");
        assert_eq!(PtxReg::NtidZ.to_ptx_string(), "%ntid.z");
        assert_eq!(PtxReg::NctaIdX.to_ptx_string(), "%nctaid.x");
        assert_eq!(PtxReg::NctaIdY.to_ptx_string(), "%nctaid.y");
        assert_eq!(PtxReg::NctaIdZ.to_ptx_string(), "%nctaid.z");
        assert_eq!(PtxReg::SmId.to_ptx_string(), "%smid");
        assert_eq!(PtxReg::Clock.to_ptx_string(), "%clock");
        assert_eq!(PtxReg::Clock64.to_ptx_string(), "%clock64");
    }

    #[test]
    fn test_allocate_records_live_range() {
        let mut alloc = RegisterAllocator::new();
        alloc.next_instruction();
        let vreg = alloc.allocate_virtual(PtxType::F32);
        // A fresh register is live for exactly the instruction it was allocated at.
        assert_eq!(
            alloc.live_ranges[&(vreg.ty(), vreg.id())],
            LiveRange::new(1, 2)
        );
    }

    #[test]
    fn test_extend_live_range() {
        let mut alloc = RegisterAllocator::new();
        let vreg = alloc.allocate_virtual(PtxType::F32);
        alloc.next_instruction();
        alloc.next_instruction();
        alloc.extend_live_range(vreg);
        // The range now reaches through instruction 2 (exclusive end 3).
        assert_eq!(
            alloc.live_ranges[&(vreg.ty(), vreg.id())],
            LiveRange::new(0, 3)
        );
        let report = alloc.pressure_report();
        assert_eq!(report.max_live, 1);
    }

    #[test]
    fn test_extend_unknown_register_is_noop() {
        let mut alloc = RegisterAllocator::new();
        alloc.extend_live_range(VirtualReg::new(42, PtxType::F32));
        assert!(alloc.live_ranges.is_empty());
        assert_eq!(alloc.pressure_report().max_live, 0);
    }

    #[test]
    fn test_next_instruction() {
        let mut alloc = RegisterAllocator::new();
        let _ = alloc.allocate_virtual(PtxType::F32);
        alloc.next_instruction();
        let _ = alloc.allocate_virtual(PtxType::F32);
        alloc.next_instruction();
        let _ = alloc.allocate_virtual(PtxType::F32);
        assert_eq!(alloc.pressure_report().max_live, 3);
    }

    #[test]
    fn test_virtual_register_display() {
        let vreg = VirtualReg::new(7, PtxType::F64);
        let display_str = format!("{}", vreg);
        assert_eq!(display_str, "%fd7");
        let vreg_u64 = VirtualReg::new(2, PtxType::U64);
        let display_str_u64 = format!("{}", vreg_u64);
        assert_eq!(display_str_u64, "%rd2");
    }

    #[test]
    fn test_virtual_register_write_to() {
        let vreg = VirtualReg::new(3, PtxType::F32);
        let mut buffer = String::new();
        vreg.write_to(&mut buffer).unwrap();
        assert_eq!(buffer, "%f3");
    }

    #[test]
    fn test_register_allocator_default() {
        let alloc = RegisterAllocator::default();
        assert_eq!(alloc.pressure_report().max_live, 0);
    }

    #[test]
    fn test_physical_reg() {
        let preg = PhysicalReg(5);
        assert_eq!(preg.0, 5);
    }

    #[test]
    fn test_register_pressure_fields() {
        let pressure = RegisterPressure {
            max_live: 10,
            spill_count: 0,
            utilization: 0.039,
        };
        assert_eq!(pressure.max_live, 10);
        assert_eq!(pressure.spill_count, 0);
        assert!((pressure.utilization - 0.039).abs() < f64::EPSILON);
    }

    #[test]
    fn test_live_range_fields() {
        let range = LiveRange::new(5, 15);
        assert_eq!(range.start, 5);
        assert_eq!(range.end, 15);
    }
}