use std::collections::HashMap;
use crate::ir::{Instruction, Operand, PtxType, Register, WmmaOp};
/// Result of [`analyze_register_pressure`] for one instruction sequence.
#[derive(Clone, Debug)]
pub struct RegisterPressureReport {
    /// For each register type, the largest number of registers of that type
    /// that were live at the same instruction.
    pub peak_by_type: HashMap<PtxType, usize>,
    /// Largest value found in `live_at_instruction` (0 for an empty input).
    pub total_peak: usize,
    /// Number of live registers at each instruction; entry `i` corresponds to
    /// instruction `i` of the analyzed slice.
    pub live_at_instruction: Vec<usize>,
    /// `true` when `total_peak` is greater than `SPILL_THRESHOLD`.
    pub spill_risk: bool,
    /// `SM80_REGS_PER_SM / (THREADS_PER_WARP * total_peak)`, or `None` when
    /// no register is ever live.
    pub estimated_max_warps_per_sm: Option<u32>,
}
/// Peak live-register count above which a report sets `spill_risk`.
/// NOTE(review): presumably the per-thread register limit of the target GPU — confirm.
const SPILL_THRESHOLD: usize = 255;

/// Register budget per SM used as the numerator of the occupancy estimate.
/// NOTE(review): the name suggests this is the SM80 register-file size — confirm.
const SM80_REGS_PER_SM: u32 = 64 * 1024;

/// Threads per warp; multiplied by the peak register count to get the
/// registers one warp is assumed to need.
const THREADS_PER_WARP: u32 = 32;
/// Estimates register pressure for a linear sequence of instructions.
///
/// Registers are identified by name. A register's live range starts at the
/// first instruction that defines it and ends at the last instruction that
/// uses it. If the register is never used, or every use comes before the
/// first definition, the range is just the defining instruction, so a defined
/// register is always counted at least once. Registers that are only used and
/// never defined inside `instructions` are not counted.
///
/// The slice is scanned in order; branches and labels are not followed, so
/// ranges that wrap around a loop back-edge are not extended.
///
/// A register's type is taken from its first appearance (definition or use).
///
/// # Returns
///
/// A [`RegisterPressureReport`] holding the live count per instruction, the
/// overall and per-type peaks, whether the peak exceeds `SPILL_THRESHOLD`, and
/// `SM80_REGS_PER_SM / (THREADS_PER_WARP * total_peak)` as the warp estimate
/// (`None` when nothing is ever live). An empty slice yields an all-zero report.
pub fn analyze_register_pressure(instructions: &[Instruction]) -> RegisterPressureReport {
    /// What is known about one named register after scanning the sequence.
    struct LiveRange {
        ty: PtxType,
        first_def: Option<usize>,
        last_use: Option<usize>,
    }

    if instructions.is_empty() {
        return RegisterPressureReport {
            peak_by_type: HashMap::new(),
            total_peak: 0,
            live_at_instruction: Vec::new(),
            spill_risk: false,
            estimated_max_warps_per_sm: None,
        };
    }

    // Pass 1: collect first definition, last use and type per register name.
    // Names are borrowed from the instructions, so nothing is cloned.
    let mut ranges: HashMap<&str, LiveRange> = HashMap::new();
    for (idx, inst) in instructions.iter().enumerate() {
        for reg in defs(inst) {
            let range = ranges.entry(reg.name.as_str()).or_insert(LiveRange {
                ty: reg.ty,
                first_def: None,
                last_use: None,
            });
            if range.first_def.is_none() {
                range.first_def = Some(idx);
            }
        }
        for reg in uses(inst) {
            let range = ranges.entry(reg.name.as_str()).or_insert(LiveRange {
                ty: reg.ty,
                first_def: None,
                last_use: None,
            });
            range.last_use = Some(idx);
        }
    }

    // Pass 2: turn every defined register into a begin event and an end event.
    let num_instructions = instructions.len();
    let mut begins: Vec<Vec<PtxType>> = vec![Vec::new(); num_instructions];
    let mut ends: Vec<Vec<PtxType>> = vec![Vec::new(); num_instructions];
    for range in ranges.values() {
        if let Some(start) = range.first_def {
            // Clamp so a register used only before its first definition is
            // still live at the defining instruction, like an unused one.
            let end = range.last_use.map_or(start, |last| last.max(start));
            begins[start].push(range.ty);
            ends[end].push(range.ty);
        }
    }

    // Pass 3: sweep once over the instructions. Ranges are inclusive, so
    // registers ending at `i` are released only after `i` has been recorded.
    let mut live_at_instruction = Vec::with_capacity(num_instructions);
    let mut live_by_type: HashMap<PtxType, usize> = HashMap::new();
    let mut peak_by_type: HashMap<PtxType, usize> = HashMap::new();
    let mut live_now: usize = 0;
    let mut total_peak: usize = 0;
    for (begun, ended) in begins.iter().zip(&ends) {
        for ty in begun {
            let live = live_by_type.entry(*ty).or_insert(0);
            *live += 1;
            // A per-type count can only reach a new maximum on an increment.
            let peak = peak_by_type.entry(*ty).or_insert(0);
            *peak = (*peak).max(*live);
        }
        live_now += begun.len();
        live_at_instruction.push(live_now);
        total_peak = total_peak.max(live_now);

        for ty in ended {
            if let Some(live) = live_by_type.get_mut(ty) {
                *live -= 1;
            }
        }
        live_now -= ended.len();
    }

    let spill_risk = total_peak > SPILL_THRESHOLD;
    let estimated_max_warps_per_sm = if total_peak == 0 {
        None
    } else {
        // Saturate instead of truncating if the peak does not fit in `u32`.
        let peak_u32 = u32::try_from(total_peak).unwrap_or(u32::MAX);
        let regs_per_warp = THREADS_PER_WARP.saturating_mul(peak_u32);
        SM80_REGS_PER_SM.checked_div(regs_per_warp)
    };

    RegisterPressureReport {
        peak_by_type,
        total_peak,
        live_at_instruction,
        spill_risk,
        estimated_max_warps_per_sm,
    }
}
fn defs(inst: &Instruction) -> Vec<&Register> {
match inst {
Instruction::Add { dst, .. }
| Instruction::Sub { dst, .. }
| Instruction::Mul { dst, .. }
| Instruction::Mad { dst, .. }
| Instruction::MadLo { dst, .. }
| Instruction::MadHi { dst, .. }
| Instruction::MadWide { dst, .. }
| Instruction::Fma { dst, .. }
| Instruction::Neg { dst, .. }
| Instruction::Abs { dst, .. }
| Instruction::Min { dst, .. }
| Instruction::Max { dst, .. }
| Instruction::Brev { dst, .. }
| Instruction::Clz { dst, .. }
| Instruction::Popc { dst, .. }
| Instruction::Bfind { dst, .. }
| Instruction::Bfe { dst, .. }
| Instruction::Bfi { dst, .. }
| Instruction::Rcp { dst, .. }
| Instruction::Rsqrt { dst, .. }
| Instruction::Sqrt { dst, .. }
| Instruction::Ex2 { dst, .. }
| Instruction::Lg2 { dst, .. }
| Instruction::Sin { dst, .. }
| Instruction::Cos { dst, .. }
| Instruction::Shl { dst, .. }
| Instruction::Shr { dst, .. }
| Instruction::Div { dst, .. }
| Instruction::Rem { dst, .. }
| Instruction::And { dst, .. }
| Instruction::Or { dst, .. }
| Instruction::Xor { dst, .. }
| Instruction::SetP { dst, .. }
| Instruction::Load { dst, .. }
| Instruction::Cvt { dst, .. }
| Instruction::MovSpecial { dst, .. }
| Instruction::LoadParam { dst, .. }
| Instruction::Atom { dst, .. }
| Instruction::AtomCas { dst, .. }
| Instruction::Dp4a { dst, .. }
| Instruction::Dp2a { dst, .. }
| Instruction::Tex1d { dst, .. }
| Instruction::Tex2d { dst, .. }
| Instruction::Tex3d { dst, .. }
| Instruction::SurfLoad { dst, .. }
| Instruction::Redux { dst, .. }
| Instruction::ElectSync { dst, .. } => vec![dst],
Instruction::Ldmatrix { dst_regs, .. } => dst_regs.iter().collect(),
Instruction::Store { .. }
| Instruction::CpAsync { .. }
| Instruction::CpAsyncCommit
| Instruction::CpAsyncWait { .. }
| Instruction::Branch { .. }
| Instruction::Label(_)
| Instruction::Return
| Instruction::BarSync { .. }
| Instruction::BarArrive { .. }
| Instruction::FenceAcqRel { .. }
| Instruction::TmaLoad { .. }
| Instruction::Red { .. }
| Instruction::SurfStore { .. }
| Instruction::Comment(_)
| Instruction::Raw(_)
| Instruction::Pragma(_)
| Instruction::Stmatrix { .. }
| Instruction::Setmaxnreg { .. }
| Instruction::Griddepcontrol { .. }
| Instruction::FenceProxy { .. }
| Instruction::MbarrierInit { .. }
| Instruction::MbarrierArrive { .. }
| Instruction::MbarrierWait { .. }
| Instruction::Tcgen05Mma { .. }
| Instruction::BarrierCluster
| Instruction::FenceCluster
| Instruction::CpAsyncBulk { .. } => vec![],
Instruction::Wmma { op, fragments, .. } => match op {
WmmaOp::LoadA | WmmaOp::LoadB | WmmaOp::Mma => fragments.iter().collect(),
WmmaOp::StoreD => vec![],
},
Instruction::Mma { d_regs, .. } | Instruction::Wgmma { d_regs, .. } => {
d_regs.iter().collect()
}
}
}
#[allow(clippy::too_many_lines)]
fn uses(inst: &Instruction) -> Vec<&Register> {
match inst {
Instruction::Add { a, b, .. }
| Instruction::Sub { a, b, .. }
| Instruction::Mul { a, b, .. }
| Instruction::Min { a, b, .. }
| Instruction::Max { a, b, .. }
| Instruction::Div { a, b, .. }
| Instruction::Rem { a, b, .. }
| Instruction::And { a, b, .. }
| Instruction::Or { a, b, .. }
| Instruction::Xor { a, b, .. }
| Instruction::SetP { a, b, .. }
| Instruction::Shl {
src: a, amount: b, ..
}
| Instruction::Shr {
src: a, amount: b, ..
} => {
let mut regs = operand_regs(a);
regs.extend(operand_regs(b));
regs
}
Instruction::Mad { a, b, c, .. }
| Instruction::MadLo { a, b, c, .. }
| Instruction::MadHi { a, b, c, .. }
| Instruction::MadWide { a, b, c, .. }
| Instruction::Fma { a, b, c, .. }
| Instruction::Dp4a { a, b, c, .. }
| Instruction::Dp2a { a, b, c, .. } => {
let mut regs = operand_regs(a);
regs.extend(operand_regs(b));
regs.extend(operand_regs(c));
regs
}
Instruction::Neg { src, .. }
| Instruction::Abs { src, .. }
| Instruction::Brev { src, .. }
| Instruction::Clz { src, .. }
| Instruction::Popc { src, .. }
| Instruction::Bfind { src, .. }
| Instruction::Rcp { src, .. }
| Instruction::Rsqrt { src, .. }
| Instruction::Sqrt { src, .. }
| Instruction::Ex2 { src, .. }
| Instruction::Lg2 { src, .. }
| Instruction::Sin { src, .. }
| Instruction::Cos { src, .. }
| Instruction::Cvt { src, .. }
| Instruction::Redux { src, .. } => operand_regs(src),
Instruction::Bfe {
src, start, len, ..
} => {
let mut regs = operand_regs(src);
regs.extend(operand_regs(start));
regs.extend(operand_regs(len));
regs
}
Instruction::Bfi {
insert,
base,
start,
len,
..
} => {
let mut regs = operand_regs(insert);
regs.extend(operand_regs(base));
regs.extend(operand_regs(start));
regs.extend(operand_regs(len));
regs
}
Instruction::Load { addr, .. } | Instruction::MbarrierArrive { addr, .. } => {
operand_regs(addr)
}
Instruction::Store { addr, src, .. } => {
let mut regs = operand_regs(addr);
regs.push(src);
regs
}
Instruction::CpAsync {
dst_shared,
src_global,
..
} => {
let mut regs = operand_regs(dst_shared);
regs.extend(operand_regs(src_global));
regs
}
Instruction::CpAsyncCommit
| Instruction::CpAsyncWait { .. }
| Instruction::Label(_)
| Instruction::Return
| Instruction::BarSync { .. }
| Instruction::BarArrive { .. }
| Instruction::FenceAcqRel { .. }
| Instruction::MovSpecial { .. }
| Instruction::LoadParam { .. }
| Instruction::ElectSync { .. }
| Instruction::Setmaxnreg { .. }
| Instruction::Griddepcontrol { .. }
| Instruction::FenceProxy { .. }
| Instruction::BarrierCluster
| Instruction::FenceCluster
| Instruction::Comment(_)
| Instruction::Raw(_)
| Instruction::Pragma(_) => vec![],
Instruction::Branch { predicate, .. } => {
if let Some((reg, _negated)) = predicate {
vec![reg]
} else {
vec![]
}
}
Instruction::Wmma {
op,
fragments,
addr,
stride,
..
} => {
let mut regs: Vec<&Register> = Vec::new();
match op {
WmmaOp::LoadA | WmmaOp::LoadB => {
if let Some(a) = addr {
regs.extend(operand_regs(a));
}
if let Some(s) = stride {
regs.extend(operand_regs(s));
}
}
WmmaOp::StoreD => {
regs.extend(fragments.iter());
if let Some(a) = addr {
regs.extend(operand_regs(a));
}
if let Some(s) = stride {
regs.extend(operand_regs(s));
}
}
WmmaOp::Mma => {
regs.extend(fragments.iter());
}
}
regs
}
Instruction::Mma {
a_regs,
b_regs,
c_regs,
..
} => {
let mut regs: Vec<&Register> = Vec::new();
regs.extend(a_regs.iter());
regs.extend(b_regs.iter());
regs.extend(c_regs.iter());
regs
}
Instruction::Wgmma { desc_a, desc_b, .. } => vec![desc_a, desc_b],
Instruction::TmaLoad {
dst_shared,
desc,
coords,
barrier,
..
} => {
let mut regs = operand_regs(dst_shared);
regs.push(desc);
regs.extend(coords.iter());
regs.push(barrier);
regs
}
Instruction::Atom { addr, src, .. } | Instruction::Red { addr, src, .. } => {
let mut regs = operand_regs(addr);
regs.extend(operand_regs(src));
regs
}
Instruction::AtomCas {
addr,
compare,
value,
..
} => {
let mut regs = operand_regs(addr);
regs.extend(operand_regs(compare));
regs.extend(operand_regs(value));
regs
}
Instruction::Tex1d { coord, .. } | Instruction::SurfLoad { coord, .. } => {
operand_regs(coord)
}
Instruction::Tex2d {
coord_x, coord_y, ..
} => {
let mut regs = operand_regs(coord_x);
regs.extend(operand_regs(coord_y));
regs
}
Instruction::Tex3d {
coord_x,
coord_y,
coord_z,
..
} => {
let mut regs = operand_regs(coord_x);
regs.extend(operand_regs(coord_y));
regs.extend(operand_regs(coord_z));
regs
}
Instruction::SurfStore { coord, src, .. } => {
let mut regs = operand_regs(coord);
regs.push(src);
regs
}
Instruction::Stmatrix { dst_addr, src, .. } => {
let mut regs = operand_regs(dst_addr);
regs.push(src);
regs
}
Instruction::MbarrierInit { addr, count, .. } => {
let mut regs = operand_regs(addr);
regs.extend(operand_regs(count));
regs
}
Instruction::MbarrierWait { addr, phase, .. } => {
let mut regs = operand_regs(addr);
regs.extend(operand_regs(phase));
regs
}
Instruction::Tcgen05Mma { a_desc, b_desc } => vec![a_desc, b_desc],
Instruction::CpAsyncBulk {
dst_smem,
src_gmem,
desc,
} => vec![dst_smem, src_gmem, desc],
Instruction::Ldmatrix { src_addr, .. } => operand_regs(src_addr),
}
}
/// Returns the register referenced by `op`, if any, as a zero- or one-element list.
///
/// A register operand yields that register, an address operand yields its
/// base register, and immediates and symbols yield nothing.
fn operand_regs(op: &Operand) -> Vec<&Register> {
    let referenced = match op {
        Operand::Register(reg) => Some(reg),
        Operand::Address { base, .. } => Some(base),
        Operand::Immediate(_) | Operand::Symbol(_) => None,
    };
    referenced.into_iter().collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{
        CacheQualifier, CmpOp, ImmValue, Instruction, MemorySpace, MulMode, Operand, PtxType,
        Register, SpecialReg, VectorWidth,
    };

    /// Builds a register with the given name and type.
    fn reg(name: &str, ty: PtxType) -> Register {
        Register {
            name: name.to_string(),
            ty,
        }
    }

    /// Builds a register operand.
    fn reg_op(name: &str, ty: PtxType) -> Operand {
        Operand::Register(reg(name, ty))
    }

    /// Builds a `u32` immediate operand.
    fn imm_u32(val: u32) -> Operand {
        Operand::Immediate(ImmValue::U32(val))
    }

    /// Defines the `U32` register `dst` from a special register.
    fn mov_special(dst: &str, special: SpecialReg) -> Instruction {
        Instruction::MovSpecial {
            dst: reg(dst, PtxType::U32),
            special,
        }
    }

    /// Global, unqualified, scalar store of register `src` (of type `ty`)
    /// through the `U64` base register `base`.
    fn store_global(ty: PtxType, base: &str, offset: Option<i64>, src: &str) -> Instruction {
        Instruction::Store {
            space: MemorySpace::Global,
            qualifier: CacheQualifier::None,
            vec: VectorWidth::V1,
            ty,
            addr: Operand::Address {
                base: reg(base, PtxType::U64),
                offset,
            },
            src: reg(src, ty),
        }
    }

    /// Defines `%r0..%r{count-1}` and then stores each of them through `%rd0`,
    /// so all `count` registers are live at the same time.
    fn define_then_store(count: usize) -> Vec<Instruction> {
        let definitions = (0..count).map(|i| mov_special(&format!("%r{i}"), SpecialReg::TidX));
        let stores = (0..count).map(|i| {
            let offset = i64::try_from(i).unwrap_or(i64::MAX) * 4;
            store_global(PtxType::U32, "%rd0", Some(offset), &format!("%r{i}"))
        });
        definitions.chain(stores).collect()
    }

    #[test]
    fn test_empty_instructions() {
        let report = analyze_register_pressure(&[]);
        assert_eq!(report.total_peak, 0);
        assert!(report.live_at_instruction.is_empty());
        assert!(!report.spill_risk);
        assert!(report.estimated_max_warps_per_sm.is_none());
    }

    #[test]
    fn test_single_add() {
        // Only the destination is defined here; the two sources are not counted.
        let instructions = vec![Instruction::Add {
            ty: PtxType::F32,
            dst: reg("%f0", PtxType::F32),
            a: reg_op("%f1", PtxType::F32),
            b: reg_op("%f2", PtxType::F32),
        }];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.live_at_instruction.len(), 1);
        assert_eq!(report.total_peak, 1);
    }

    #[test]
    fn test_sequence_peak_pressure() {
        let instructions = vec![
            mov_special("%r0", SpecialReg::TidX),
            Instruction::Cvt {
                rnd: None,
                dst_ty: PtxType::F32,
                src_ty: PtxType::U32,
                dst: reg("%f0", PtxType::F32),
                src: reg_op("%r0", PtxType::U32),
            },
            Instruction::Add {
                ty: PtxType::F32,
                dst: reg("%f1", PtxType::F32),
                a: reg_op("%f0", PtxType::F32),
                b: imm_u32(1),
            },
            store_global(PtxType::F32, "%rd0", None, "%f1"),
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.live_at_instruction, vec![1, 2, 2, 1]);
        assert_eq!(report.total_peak, 2);
    }

    #[test]
    fn test_register_reuse_reduces_peak() {
        let instructions = vec![
            mov_special("%r0", SpecialReg::TidX),
            Instruction::Add {
                ty: PtxType::U32,
                dst: reg("%r1", PtxType::U32),
                a: reg_op("%r0", PtxType::U32),
                b: imm_u32(1),
            },
            Instruction::Add {
                ty: PtxType::U32,
                dst: reg("%r2", PtxType::U32),
                a: reg_op("%r1", PtxType::U32),
                b: imm_u32(2),
            },
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.total_peak, 2);
        assert_eq!(report.live_at_instruction, vec![1, 2, 2]);
    }

    #[test]
    fn test_non_overlapping_lifetimes() {
        let instructions = vec![
            mov_special("%r0", SpecialReg::TidX),
            store_global(PtxType::U32, "%rd0", None, "%r0"),
            mov_special("%r1", SpecialReg::TidY),
            store_global(PtxType::U32, "%rd1", None, "%r1"),
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.total_peak, 1);
    }

    #[test]
    fn test_spill_risk_detection() {
        let report = analyze_register_pressure(&define_then_store(256));
        assert!(report.spill_risk);
        assert!(report.total_peak > SPILL_THRESHOLD);
    }

    #[test]
    fn test_no_spill_risk() {
        let instructions = vec![
            mov_special("%r0", SpecialReg::TidX),
            store_global(PtxType::U32, "%rd0", None, "%r0"),
        ];
        let report = analyze_register_pressure(&instructions);
        assert!(!report.spill_risk);
    }

    #[test]
    fn test_peak_by_type() {
        let instructions = vec![
            mov_special("%r0", SpecialReg::TidX),
            Instruction::Cvt {
                rnd: None,
                dst_ty: PtxType::F32,
                src_ty: PtxType::U32,
                dst: reg("%f0", PtxType::F32),
                src: reg_op("%r0", PtxType::U32),
            },
            Instruction::Add {
                ty: PtxType::F32,
                dst: reg("%f1", PtxType::F32),
                a: reg_op("%f0", PtxType::F32),
                b: imm_u32(0),
            },
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.peak_by_type.get(&PtxType::U32), Some(&1));
        assert_eq!(report.peak_by_type.get(&PtxType::F32), Some(&2));
    }

    #[test]
    fn test_occupancy_estimation() {
        let report = analyze_register_pressure(&define_then_store(32));
        assert_eq!(report.total_peak, 32);
        assert_eq!(report.estimated_max_warps_per_sm, Some(64));
    }

    #[test]
    fn test_occupancy_high_register_usage() {
        let report = analyze_register_pressure(&define_then_store(128));
        assert_eq!(report.total_peak, 128);
        assert_eq!(report.estimated_max_warps_per_sm, Some(16));
    }

    #[test]
    fn test_mad_three_operands() {
        let instructions = vec![
            Instruction::Mad {
                ty: PtxType::S32,
                mode: MulMode::Lo,
                dst: reg("%r0", PtxType::S32),
                a: reg_op("%r1", PtxType::S32),
                b: reg_op("%r2", PtxType::S32),
                c: reg_op("%r3", PtxType::S32),
            },
            store_global(PtxType::S32, "%rd0", None, "%r0"),
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.total_peak, 1);
    }

    #[test]
    fn test_branch_predicate_use() {
        let instructions = vec![
            Instruction::SetP {
                cmp: CmpOp::Lt,
                ty: PtxType::U32,
                dst: reg("%p0", PtxType::Pred),
                a: reg_op("%r0", PtxType::U32),
                b: imm_u32(10),
            },
            Instruction::Branch {
                target: "label1".to_string(),
                predicate: Some((reg("%p0", PtxType::Pred), false)),
            },
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.total_peak, 1);
    }

    #[test]
    fn test_mma_register_pressure() {
        use crate::ir::MmaShape;

        let f32_regs = |names: [&str; 4]| -> Vec<Register> {
            names.iter().map(|name| reg(name, PtxType::F32)).collect()
        };
        // Only the four destination registers are defined in this sequence.
        let instructions = vec![Instruction::Mma {
            shape: MmaShape::M16N8K16,
            a_ty: PtxType::F16,
            b_ty: PtxType::F16,
            c_ty: PtxType::F32,
            d_ty: PtxType::F32,
            d_regs: f32_regs(["%f0", "%f1", "%f2", "%f3"]),
            a_regs: vec![reg("%f10", PtxType::F16), reg("%f11", PtxType::F16)],
            b_regs: vec![reg("%f20", PtxType::F16)],
            c_regs: f32_regs(["%f30", "%f31", "%f32", "%f33"]),
        }];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.total_peak, 4);
    }

    #[test]
    fn test_live_at_instruction_length() {
        let instructions = vec![
            mov_special("%r0", SpecialReg::TidX),
            Instruction::Return,
            Instruction::Label("exit".to_string()),
        ];
        let report = analyze_register_pressure(&instructions);
        assert_eq!(report.live_at_instruction.len(), 3);
    }

    #[test]
    fn test_register_pressure_under_limit_no_warning() {
        let count: usize = 10;
        let report = analyze_register_pressure(&define_then_store(count));
        assert_eq!(
            report.total_peak, count,
            "peak should equal number of simultaneously-live registers"
        );
        assert!(
            !report.spill_risk,
            "10 registers must not trigger spill risk"
        );
    }

    #[test]
    fn test_register_pressure_at_limit_no_warning() {
        let count: usize = SPILL_THRESHOLD;
        let report = analyze_register_pressure(&define_then_store(count));
        assert_eq!(
            report.total_peak, 255,
            "peak must be exactly 255 at boundary"
        );
        assert!(
            !report.spill_risk,
            "exactly 255 registers must NOT trigger spill risk (threshold is > 255)"
        );
    }

    #[test]
    fn test_register_pressure_over_limit_warns() {
        let count: usize = SPILL_THRESHOLD + 1;
        let report = analyze_register_pressure(&define_then_store(count));
        assert!(
            report.total_peak > SPILL_THRESHOLD,
            "256 simultaneously-live registers must exceed spill threshold"
        );
        assert!(
            report.spill_risk,
            "256 simultaneously-live registers must trigger spill risk"
        );
    }

    #[test]
    fn test_register_count_matches_allocations() {
        let count: usize = 5;
        let report = analyze_register_pressure(&define_then_store(count));
        assert_eq!(
            report.total_peak, count,
            "peak must match the number of allocated registers"
        );
        assert_eq!(
            report.peak_by_type.get(&PtxType::U32).copied().unwrap_or(0),
            count,
            "peak U32 count must match the number of U32 allocations"
        );
    }
}