use super::types::PtxType;
use std::collections::HashMap;
use std::fmt;
/// Named PTX registers that this module can refer to by name.
///
/// Each variant maps to a fixed `%`-prefixed name (see `to_ptx_string`) and
/// has an associated value type (see `data_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxReg {
/// Rendered as `%tid.x`.
TidX,
/// Rendered as `%tid.y`.
TidY,
/// Rendered as `%tid.z`.
TidZ,
/// Rendered as `%ctaid.x`.
CtaIdX,
/// Rendered as `%ctaid.y`.
CtaIdY,
/// Rendered as `%ctaid.z`.
CtaIdZ,
/// Rendered as `%ntid.x`.
NtidX,
/// Rendered as `%ntid.y`.
NtidY,
/// Rendered as `%ntid.z`.
NtidZ,
/// Rendered as `%nctaid.x`.
NctaIdX,
/// Rendered as `%nctaid.y`.
NctaIdY,
/// Rendered as `%nctaid.z`.
NctaIdZ,
/// Rendered as `%warpid`.
WarpId,
/// Rendered as `%laneid`.
LaneId,
/// Rendered as `%smid`.
SmId,
/// Rendered as `%clock`.
Clock,
/// Rendered as `%clock64`; the only variant whose `data_type` is `PtxType::U64`.
Clock64,
}
impl PtxReg {
    /// Returns the fixed PTX name of this register, including the leading `%`
    /// (for example `%tid.x` for [`PtxReg::TidX`]).
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        let name: &'static str = match self {
            // `%tid.*`
            Self::TidX => "%tid.x",
            Self::TidY => "%tid.y",
            Self::TidZ => "%tid.z",
            // `%ctaid.*`
            Self::CtaIdX => "%ctaid.x",
            Self::CtaIdY => "%ctaid.y",
            Self::CtaIdZ => "%ctaid.z",
            // `%ntid.*`
            Self::NtidX => "%ntid.x",
            Self::NtidY => "%ntid.y",
            Self::NtidZ => "%ntid.z",
            // `%nctaid.*`
            Self::NctaIdX => "%nctaid.x",
            Self::NctaIdY => "%nctaid.y",
            Self::NctaIdZ => "%nctaid.z",
            // Registers without a component suffix.
            Self::WarpId => "%warpid",
            Self::LaneId => "%laneid",
            Self::SmId => "%smid",
            Self::Clock => "%clock",
            Self::Clock64 => "%clock64",
        };
        name
    }

    /// Returns the value type of this register.
    ///
    /// [`PtxReg::Clock64`] is `PtxType::U64`; every other variant is
    /// `PtxType::U32`. The match is spelled out exhaustively so that adding a
    /// variant forces a decision here.
    #[must_use]
    pub const fn data_type(self) -> PtxType {
        match self {
            Self::Clock64 => PtxType::U64,
            Self::TidX
            | Self::TidY
            | Self::TidZ
            | Self::CtaIdX
            | Self::CtaIdY
            | Self::CtaIdZ
            | Self::NtidX
            | Self::NtidY
            | Self::NtidZ
            | Self::NctaIdX
            | Self::NctaIdY
            | Self::NctaIdZ
            | Self::WarpId
            | Self::LaneId
            | Self::SmId
            | Self::Clock => PtxType::U32,
        }
    }
}
/// A virtual register: a numeric id paired with the `PtxType` it holds.
///
/// Its textual form is the type's register prefix followed by the id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VirtualReg {
// Numeric id; `RegisterAllocator::allocate_virtual` numbers these per type.
id: u32,
// Type of the register; also selects the prefix used when printing.
ty: PtxType,
}
impl VirtualReg {
#[must_use]
pub const fn new(id: u32, ty: PtxType) -> Self {
Self { id, ty }
}
#[must_use]
pub const fn id(self) -> u32 {
self.id
}
#[must_use]
pub const fn ty(self) -> PtxType {
self.ty
}
#[must_use]
pub fn to_ptx_string(self) -> String {
format!("{}{}", self.ty.register_prefix(), self.id)
}
#[inline]
pub fn write_to<W: fmt::Write>(self, w: &mut W) -> fmt::Result {
write!(w, "{}{}", self.ty.register_prefix(), self.id)
}
}
/// Formats the register as its type's register prefix followed by its id,
/// the same text `VirtualReg::write_to` produces.
impl fmt::Display for VirtualReg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Formatter` implements `fmt::Write`, so the shared writer can be reused.
        self.write_to(f)
    }
}
/// A physical register, identified by a raw `u32` index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PhysicalReg(pub u32);
/// A span of instruction indices, from `start` to `end`.
///
/// [`LiveRange::overlaps`] treats the span as half-open: `start` is included
/// and `end` is excluded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LiveRange {
    /// First instruction index covered by the range.
    pub start: usize,
    /// Instruction index at which the range stops (excluded by `overlaps`).
    pub end: usize,
}

impl LiveRange {
    /// Creates a range from `start` to `end`; the bounds are stored as given.
    #[must_use]
    pub const fn new(start: usize, end: usize) -> Self {
        LiveRange { start, end }
    }

    /// Returns `true` when the two ranges share at least one position.
    ///
    /// Two ranges are disjoint exactly when one of them ends at or before the
    /// point where the other begins, so ranges that merely touch
    /// (`a.end == b.start`) do not overlap.
    #[must_use]
    pub const fn overlaps(&self, other: &Self) -> bool {
        let self_ends_first = self.end <= other.start;
        let other_ends_first = other.end <= self.start;
        !(self_ends_first || other_ends_first)
    }
}
/// Summary figures produced by `RegisterAllocator::pressure_report`.
#[derive(Debug, Clone, PartialEq)]
pub struct RegisterPressure {
/// Register count used for the report. `pressure_report` fills this with the
/// total number of virtual registers allocated so far.
pub max_live: usize,
/// Spill count copied from the allocator.
pub spill_count: usize,
/// `max_live` divided by `256.0`, as computed by `pressure_report`.
pub utilization: f64,
}
/// Hands out virtual registers per type and records a live range for each.
#[derive(Debug, Clone)]
pub struct RegisterAllocator {
// Next unused id for each register type.
type_counters: HashMap<PtxType, u32>,
// Live range of every allocated register, keyed by (type, id).
live_ranges: HashMap<(PtxType, u32), LiveRange>,
// Every register handed out, in allocation order.
allocated: Vec<VirtualReg>,
// Index of the instruction currently being processed; advanced by `next_instruction`.
current_instruction: usize,
// Spill counter reported by `pressure_report`; starts at 0 in `new`.
spill_count: usize,
}
impl RegisterAllocator {
    /// Creates an allocator with no registers allocated, positioned at
    /// instruction 0 and with a spill count of 0.
    #[must_use]
    pub fn new() -> Self {
        Self {
            type_counters: HashMap::new(),
            live_ranges: HashMap::new(),
            allocated: Vec::new(),
            current_instruction: 0,
            spill_count: 0,
        }
    }

    /// Allocates a fresh virtual register of type `ty`.
    ///
    /// Ids are numbered independently per type, starting at 0. The new
    /// register is given an initial live range covering only the current
    /// instruction (`current..current + 1`).
    pub fn allocate_virtual(&mut self, ty: PtxType) -> VirtualReg {
        // Single lookup: read the next free id and bump the counter in place.
        let next_id = self.type_counters.entry(ty).or_insert(0);
        let id = *next_id;
        *next_id += 1;

        let vreg = VirtualReg::new(id, ty);
        self.allocated.push(vreg);
        self.live_ranges.insert(
            (ty, id),
            LiveRange::new(self.current_instruction, self.current_instruction + 1),
        );
        vreg
    }

    /// Extends the live range of `vreg` so that it ends just past the current
    /// instruction.
    ///
    /// Registers that are not known to this allocator are ignored.
    pub fn extend_live_range(&mut self, vreg: VirtualReg) {
        if let Some(range) = self.live_ranges.get_mut(&(vreg.ty(), vreg.id())) {
            range.end = self.current_instruction + 1;
        }
    }

    /// Advances the allocator to the next instruction index.
    pub fn next_instruction(&mut self) {
        self.current_instruction += 1;
    }

    /// Builds a summary of register usage.
    ///
    /// `max_live` is the total number of registers allocated so far; it is not
    /// derived from the recorded live ranges, so it is an upper bound on the
    /// number of simultaneously live registers. `utilization` is that count
    /// divided by 256.
    // NOTE(review): 256 is presumably the per-thread register budget - confirm.
    #[must_use]
    pub fn pressure_report(&self) -> RegisterPressure {
        let max_live = self.allocated.len();
        RegisterPressure {
            max_live,
            spill_count: self.spill_count,
            utilization: max_live as f64 / 256.0,
        }
    }

    /// Emits one `.reg` declaration line per register type in use.
    ///
    /// Each line declares as many registers as were allocated for that type.
    /// Lines are emitted in the order in which each type was first allocated,
    /// so the output is identical from run to run (iterating a `HashMap`
    /// directly would yield an arbitrary order).
    #[must_use]
    pub fn emit_declarations(&self) -> String {
        // Per-type register counts, kept in first-allocation order.
        let mut counts: Vec<(PtxType, usize)> = Vec::new();
        // Position of each type inside `counts`.
        let mut slot_of: HashMap<PtxType, usize> = HashMap::new();

        for vreg in &self.allocated {
            let ty = vreg.ty();
            let slot = *slot_of.entry(ty).or_insert_with(|| {
                counts.push((ty, 0));
                counts.len() - 1
            });
            counts[slot].1 += 1;
        }

        let mut decls = String::new();
        for (ty, count) in counts {
            decls.push_str(&format!(
                " .reg {} {}<{}>;\n",
                ty.register_declaration_type(),
                ty.register_prefix(),
                count
            ));
        }
        decls
    }
}
impl Default for RegisterAllocator {
fn default() -> Self {
Self::new()
}
}
// Unit tests live in the separate `tests` submodule and are only compiled
// for test builds.
#[cfg(test)]
mod tests;