use std::fmt;
use crate::ir::{Instruction, MemorySpace, Operand, PtxType};
/// Number of shared-memory banks assumed by this analysis; word indices are
/// reduced modulo this value to obtain a bank number.
const NUM_BANKS: u32 = 32;
/// Width of a single bank in bytes; byte offsets are divided by this value to
/// obtain a word index.
const BANK_WIDTH_BYTES: u32 = 4;
/// A single shared-memory load or store found in the instruction stream.
#[derive(Debug, Clone)]
pub struct SharedMemAccess {
/// Position of the instruction in the slice passed to the analysis.
pub instruction_index: usize,
/// Name of the base register or symbol the address was built from; empty
/// when the address operand is an immediate.
pub base_reg: String,
/// Classification of the address offset (constant, strided, or unknown).
pub offset: AccessOffset,
/// Size of the accessed type as reported by `PtxType::size_bytes`.
pub access_width: u32,
/// `true` for a store instruction, `false` for a load.
pub is_store: bool,
}
/// How the offset part of a shared-memory address is classified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessOffset {
/// A fixed offset; the analysis reports such accesses as broadcasts.
Constant(i64),
/// An offset that advances by `stride` per thread; the analysis passes
/// `stride` to the conflict-degree computation as a byte count.
Strided {
/// Starting offset of the pattern.
/// NOTE(review): not read by the analysis code in this file; presumably bytes — confirm.
base: i64,
/// Per-thread stride in bytes.
stride: i64,
},
/// The offset could not be determined (symbol or immediate address operands).
Unknown,
}
/// Result of analysing an instruction sequence for shared-memory bank conflicts.
#[derive(Debug, Clone)]
pub struct BankConflictReport {
/// Every shared-memory load/store that was found, in instruction order.
pub accesses: Vec<SharedMemAccess>,
/// Findings produced by the analysis. Constant-offset accesses also appear
/// here as [`ConflictType::Broadcast`] entries with degree 1.
pub conflicts: Vec<BankConflict>,
/// Number of accesses counted as conflict-free (this includes broadcast and
/// unknown-offset accesses).
pub conflict_free_count: usize,
/// Total number of shared-memory accesses; equals `accesses.len()` when
/// produced by `analyze_bank_conflicts`.
pub total_shared_accesses: usize,
}
/// One finding attached to a single shared-memory instruction.
#[derive(Debug, Clone)]
pub struct BankConflict {
/// Index of the offending instruction in the analysed slice.
pub instruction_index: usize,
/// Kind of finding (n-way, broadcast, or strided).
pub conflict_type: ConflictType,
/// Bank numbers touched by the access.
pub affected_banks: Vec<u32>,
/// Conflict degree; 1 means no serialization.
pub degree: u32,
/// Human-readable hint on how to address the finding.
pub suggestion: String,
}
/// Category of a bank-conflict finding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConflictType {
/// An n-way conflict; the payload is the number of ways.
NWay(u32),
/// A broadcast access; displayed as "no conflict".
Broadcast,
/// A conflict caused by a regular per-thread stride.
StridedConflict {
/// The stride in bytes that causes the conflict.
stride: i64,
},
}
/// Renders the report as a short multi-line summary: a header, the three
/// counters, and one line per recorded finding.
impl fmt::Display for BankConflictReport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.total_shared_accesses;
        let clean = self.conflict_free_count;
        let findings = &self.conflicts;

        // Summary section.
        writeln!(f, "Bank Conflict Report")?;
        writeln!(f, " Total shared memory accesses: {}", total)?;
        writeln!(f, " Conflict-free accesses: {}", clean)?;
        writeln!(f, " Conflicts detected: {}", findings.len())?;

        // One line per finding; stops at the first formatter error.
        findings.iter().enumerate().try_for_each(|(i, conflict)| {
            writeln!(
                f,
                " [{i}] instruction {}: {}-way, suggestion: {}",
                conflict.instruction_index, conflict.degree, conflict.suggestion
            )
        })
    }
}
/// Human-readable description of each conflict category.
impl fmt::Display for ConflictType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Broadcast => f.write_str("broadcast (no conflict)"),
            Self::StridedConflict { stride } => {
                write!(f, "strided conflict (stride={stride} bytes)")
            }
            Self::NWay(n) => write!(f, "{n}-way bank conflict"),
        }
    }
}
/// Maps a byte offset to the bank that holds it.
///
/// The offset is first converted to a word index (floor division by
/// `BANK_WIDTH_BYTES`) and then reduced modulo `NUM_BANKS`. Euclidean
/// division is used for both steps, so negative offsets wrap around to a
/// bank in `0..NUM_BANKS` instead of producing a negative result.
#[must_use]
pub fn byte_offset_to_bank(byte_offset: i64) -> u32 {
    let bytes_per_word = i64::from(BANK_WIDTH_BYTES);
    let bank_count = i64::from(NUM_BANKS);
    let bank = byte_offset
        .div_euclid(bytes_per_word)
        .rem_euclid(bank_count);
    // `bank` lies in `0..NUM_BANKS`, so narrowing to `u32` cannot lose bits.
    #[allow(clippy::cast_possible_truncation)]
    {
        bank as u32
    }
}
/// Estimates the bank-conflict degree of a warp-wide strided access.
///
/// Thread `tid` is modelled as touching byte offset `tid * stride_bytes`,
/// with each access fitting inside one `BANK_WIDTH_BYTES`-wide word. The
/// degree is the number of *distinct words* that land in the busiest bank;
/// threads that hit the same word share it and do not add to the degree.
///
/// * `stride_bytes == 0`: every thread reads the same word, so the result is 1.
/// * Word-aligned strides use the closed form `warp_size / (NUM_BANKS / gcd)`,
///   clamped to at least 1, and `warp_size` when the stride is a multiple of
///   a full bank line.
/// * Strides that are not a multiple of `BANK_WIDTH_BYTES` are evaluated
///   thread by thread. Previously they were floored to whole words, which
///   reported a `warp_size`-way conflict for 1-3 byte strides and gave
///   different answers for `+s` and `-s`.
///
/// # Parameters
/// * `stride_bytes` - signed per-thread stride in bytes.
/// * `warp_size` - number of threads issuing the access together.
///
/// # Returns
/// The conflict degree; 1 means conflict-free.
#[must_use]
pub fn stride_conflict_degree(stride_bytes: i64, warp_size: u32) -> u32 {
    if stride_bytes == 0 {
        return 1;
    }
    let bank_width = i64::from(BANK_WIDTH_BYTES);
    if stride_bytes.rem_euclid(bank_width) != 0 {
        // Sub-word or otherwise unaligned stride: the word-granular closed
        // form below does not apply, so count the touched words directly.
        return unaligned_stride_conflict_degree(stride_bytes, warp_size);
    }
    // The stride is an exact multiple of the bank width from here on.
    let stride_words = stride_bytes.div_euclid(bank_width);
    let stride_abs = stride_words.unsigned_abs();
    let effective = stride_abs % u64::from(NUM_BANKS);
    if effective == 0 {
        // Every thread hits the same bank with a different word.
        return warp_size;
    }
    let g = gcd_u64(effective, u64::from(NUM_BANKS));
    let distinct_banks = u64::from(NUM_BANKS) / g;
    let degree = u64::from(warp_size) / distinct_banks;
    u32::try_from(degree.max(1)).unwrap_or(warp_size)
}

/// Conflict degree for a stride that is not a multiple of the bank width,
/// computed by walking the warp and counting distinct words per bank.
fn unaligned_stride_conflict_degree(stride_bytes: i64, warp_size: u32) -> u32 {
    let bank_width = i128::from(BANK_WIDTH_BYTES);
    let num_banks = i128::from(NUM_BANKS);
    // i128 keeps `tid * stride` exact for every (u32, i64) pair.
    let stride = i128::from(stride_bytes);

    let mut words_per_bank = [0_u32; NUM_BANKS as usize];
    let mut previous_word = None;
    for tid in 0..warp_size {
        let word = (i128::from(tid) * stride).div_euclid(bank_width);
        // Offsets are monotonic in `tid`, so threads sharing a word are
        // consecutive; comparing with the previous word is enough to dedupe.
        if previous_word == Some(word) {
            continue;
        }
        previous_word = Some(word);
        let bank = usize::try_from(word.rem_euclid(num_banks))
            .expect("rem_euclid by NUM_BANKS yields an index in 0..NUM_BANKS");
        words_per_bank[bank] += 1;
    }
    words_per_bank.iter().copied().max().unwrap_or(1).max(1)
}
/// Suggests row padding for a shared-memory tile.
///
/// Returns `Some(BANK_WIDTH_BYTES)` when `current_row_bytes` is a non-zero
/// multiple of one full bank line (`NUM_BANKS * BANK_WIDTH_BYTES` bytes),
/// and `None` otherwise (including for a zero-length row).
#[must_use]
pub const fn suggest_padding(current_row_bytes: u32) -> Option<u32> {
    let bank_line = NUM_BANKS * BANK_WIDTH_BYTES;
    let needs_padding = current_row_bytes != 0 && current_row_bytes % bank_line == 0;
    if needs_padding {
        Some(BANK_WIDTH_BYTES)
    } else {
        None
    }
}
/// Scans `instructions` for shared-memory loads/stores and classifies each one.
///
/// * Strided accesses whose conflict degree exceeds 1 are recorded as
///   [`ConflictType::StridedConflict`]; otherwise they count as conflict-free.
/// * Constant-offset accesses count as conflict-free and are additionally
///   recorded as a degree-1 [`ConflictType::Broadcast`] entry.
/// * Accesses with an unknown offset count as conflict-free.
///
/// `warp_size` is forwarded to the stride analysis. Entries in the returned
/// report follow instruction order.
pub fn analyze_bank_conflicts(instructions: &[Instruction], warp_size: u32) -> BankConflictReport {
    // Pass 1: collect every shared-memory access in program order.
    let accesses: Vec<_> = instructions
        .iter()
        .enumerate()
        .filter_map(|(idx, inst)| extract_shared_access(inst, idx))
        .collect();

    // Pass 2: classify each access.
    let mut conflicts = Vec::new();
    let mut conflict_free_count = 0_usize;
    for access in &accesses {
        let instruction_index = access.instruction_index;
        match access.offset {
            AccessOffset::Unknown => conflict_free_count += 1,
            AccessOffset::Constant(offset) => {
                conflict_free_count += 1;
                conflicts.push(BankConflict {
                    instruction_index,
                    conflict_type: ConflictType::Broadcast,
                    affected_banks: vec![byte_offset_to_bank(offset)],
                    degree: 1,
                    suggestion: "Broadcast access -- no conflict.".to_string(),
                });
            }
            AccessOffset::Strided { stride, .. } => {
                let degree = stride_conflict_degree(stride, warp_size);
                if degree > 1 {
                    conflicts.push(BankConflict {
                        instruction_index,
                        conflict_type: ConflictType::StridedConflict { stride },
                        affected_banks: compute_affected_banks_strided(stride, warp_size),
                        degree,
                        suggestion: strided_suggestion(stride, degree),
                    });
                } else {
                    conflict_free_count += 1;
                }
            }
        }
    }

    BankConflictReport {
        // Read the length before `accesses` is moved into the report.
        total_shared_accesses: accesses.len(),
        accesses,
        conflicts,
        conflict_free_count,
    }
}
/// Returns a [`SharedMemAccess`] when `inst` is a load or store in the shared
/// memory space, and `None` for every other instruction.
fn extract_shared_access(inst: &Instruction, idx: usize) -> Option<SharedMemAccess> {
    let (ty, addr, is_store) = match inst {
        Instruction::Store {
            space: MemorySpace::Shared,
            ty,
            addr,
            ..
        } => (ty, addr, true),
        Instruction::Load {
            space: MemorySpace::Shared,
            ty,
            addr,
            ..
        } => (ty, addr, false),
        _ => return None,
    };
    Some(build_access(idx, *ty, addr, is_store))
}
/// Assembles a [`SharedMemAccess`] record from the pieces of a load/store.
///
/// The address operand is classified into a base name and an offset kind,
/// and the access width is taken from `ty.size_bytes()`.
fn build_access(
    instruction_index: usize,
    ty: PtxType,
    addr: &Operand,
    is_store: bool,
) -> SharedMemAccess {
    let (base_reg, offset) = classify_address(addr);
    // The size is narrowed to `u32` with a plain cast, as before.
    #[allow(clippy::cast_possible_truncation)]
    let access_width = ty.size_bytes() as u32;
    SharedMemAccess {
        instruction_index,
        base_reg,
        offset,
        access_width,
        is_store,
    }
}
/// Splits an address operand into a base name and an offset classification.
///
/// * `Address { base, offset }`: base register name, constant offset (0 when absent).
/// * `Register`: register name, constant offset 0.
/// * `Symbol`: symbol name, unknown offset.
/// * `Immediate`: empty name, unknown offset.
fn classify_address(addr: &Operand) -> (String, AccessOffset) {
    let (name, kind) = match addr {
        Operand::Immediate(_) => return (String::new(), AccessOffset::Unknown),
        Operand::Symbol(sym) => (sym, AccessOffset::Unknown),
        Operand::Register(reg) => (&reg.name, AccessOffset::Constant(0)),
        Operand::Address { base, offset } => {
            (&base.name, AccessOffset::Constant(offset.unwrap_or(0)))
        }
    };
    (name.clone(), kind)
}
/// Lists the distinct banks touched when thread `tid` accesses byte offset
/// `tid * stride_bytes` for every `tid` in `0..warp_size`.
///
/// The result is sorted ascending and contains no duplicates.
fn compute_affected_banks_strided(stride_bytes: i64, warp_size: u32) -> Vec<u32> {
    let mut banks: Vec<u32> = (0..warp_size)
        .map(|tid| {
            // The bank depends only on the offset modulo one bank line, and
            // that value survives two's-complement wrap-around. Using
            // `wrapping_mul` therefore yields the exact bank while avoiding
            // the overflow panic a plain `*` raises in debug builds for very
            // large strides.
            byte_offset_to_bank(i64::from(tid).wrapping_mul(stride_bytes))
        })
        .collect();
    // Sort first so `dedup` removes every repeat, not just adjacent ones.
    banks.sort_unstable();
    banks.dedup();
    banks
}
/// Builds the hint text for a strided conflict.
///
/// Power-of-two strides (by absolute value) get a padding recommendation;
/// any other stride gets a generic hint to vary per-thread offsets.
fn strided_suggestion(stride_bytes: i64, degree: u32) -> String {
    if !stride_bytes.unsigned_abs().is_power_of_two() {
        return format!(
            "Consider using different offsets per thread to reduce \
             {degree}-way conflict (stride={stride_bytes}B)"
        );
    }
    format!(
        "Add {BANK_WIDTH_BYTES}-byte padding per row to break bank alignment \
         ({degree}-way conflict, stride={stride_bytes}B)"
    )
}
/// Greatest common divisor of `a` and `b` via Euclid's algorithm.
///
/// `gcd_u64(x, 0)` and `gcd_u64(0, x)` both return `x`; `gcd_u64(0, 0)` is 0.
const fn gcd_u64(a: u64, b: u64) -> u64 {
    if b == 0 {
        a
    } else {
        gcd_u64(b, a % b)
    }
}
// Unit tests for the bank mapping, the stride analysis, the padding hint,
// the instruction-level analysis, and the `Display` implementations.
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{CacheQualifier, MemorySpace, Operand, PtxType, Register, VectorWidth};
// Builds a shared-memory load of type `ty` from `[%rd_smem + offset]`.
fn shared_load(offset: i64, ty: PtxType) -> Instruction {
Instruction::Load {
space: MemorySpace::Shared,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty,
dst: Register {
name: "%r0".into(),
ty: PtxType::U32,
},
addr: Operand::Address {
base: Register {
name: "%rd_smem".into(),
ty: PtxType::U64,
},
offset: Some(offset),
},
}
}
// Builds a shared-memory store of type `ty` to `[%rd_smem + offset]`.
fn shared_store(offset: i64, ty: PtxType) -> Instruction {
Instruction::Store {
space: MemorySpace::Shared,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty,
addr: Operand::Address {
base: Register {
name: "%rd_smem".into(),
ty: PtxType::U64,
},
offset: Some(offset),
},
src: Register {
name: "%r0".into(),
ty: PtxType::U32,
},
}
}
// Builds a global-memory load, which the analysis must ignore.
fn global_load() -> Instruction {
Instruction::Load {
space: MemorySpace::Global,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
dst: Register {
name: "%f0".into(),
ty: PtxType::F32,
},
addr: Operand::Address {
base: Register {
name: "%rd0".into(),
ty: PtxType::U64,
},
offset: None,
},
}
}
// Consecutive 4-byte words map to consecutive banks and wrap after 128 bytes.
#[test]
fn byte_offset_to_bank_basic() {
assert_eq!(byte_offset_to_bank(0), 0);
assert_eq!(byte_offset_to_bank(4), 1);
assert_eq!(byte_offset_to_bank(8), 2);
assert_eq!(byte_offset_to_bank(124), 31);
assert_eq!(byte_offset_to_bank(128), 0); assert_eq!(byte_offset_to_bank(132), 1);
}
// Negative offsets wrap around to a valid bank rather than going negative.
#[test]
fn byte_offset_to_bank_negative() {
assert_eq!(byte_offset_to_bank(-4), 31);
assert_eq!(byte_offset_to_bank(-128), 0);
}
// A one-word stride gives every thread its own bank.
#[test]
fn stride_4_bytes_no_conflict() {
assert_eq!(stride_conflict_degree(4, 32), 1);
}
// A full bank-line stride puts every thread on the same bank.
#[test]
fn stride_128_bytes_full_conflict() {
assert_eq!(stride_conflict_degree(128, 32), 32);
}
// A two-word stride uses only half of the banks.
#[test]
fn stride_8_bytes_2way_conflict() {
assert_eq!(stride_conflict_degree(8, 32), 2);
}
// An eight-word stride uses only four banks.
#[test]
fn stride_32_bytes_8way_conflict() {
assert_eq!(stride_conflict_degree(32, 32), 8);
}
// Any multiple of the bank line behaves like the 128-byte case.
#[test]
fn stride_256_bytes_full_conflict() {
assert_eq!(stride_conflict_degree(256, 32), 32);
}
// Zero stride means all threads read the same address.
#[test]
fn stride_zero_broadcast() {
assert_eq!(stride_conflict_degree(0, 32), 1);
}
// Rows that are a multiple of 128 bytes get a 4-byte padding suggestion.
#[test]
fn suggest_padding_multiple_of_128() {
assert_eq!(suggest_padding(128), Some(4));
assert_eq!(suggest_padding(256), Some(4));
assert_eq!(suggest_padding(512), Some(4));
}
// Rows that are not a multiple of 128 bytes (or are empty) get no suggestion.
#[test]
fn suggest_padding_not_needed() {
assert_eq!(suggest_padding(64), None);
assert_eq!(suggest_padding(132), None);
assert_eq!(suggest_padding(0), None);
}
// An empty instruction list yields an empty report.
#[test]
fn empty_instructions() {
let report = analyze_bank_conflicts(&[], 32);
assert_eq!(report.total_shared_accesses, 0);
assert_eq!(report.conflict_free_count, 0);
assert!(report.conflicts.is_empty());
assert!(report.accesses.is_empty());
}
// Non-shared accesses are not recorded.
#[test]
fn no_shared_memory_accesses() {
let instructions = vec![global_load()];
let report = analyze_bank_conflicts(&instructions, 32);
assert_eq!(report.total_shared_accesses, 0);
assert!(report.conflicts.is_empty());
}
// A constant offset is counted as conflict-free AND listed as a Broadcast entry.
#[test]
fn constant_offset_broadcast_detection() {
let instructions = vec![shared_load(64, PtxType::F32)];
let report = analyze_bank_conflicts(&instructions, 32);
assert_eq!(report.total_shared_accesses, 1);
assert_eq!(report.conflict_free_count, 1);
assert_eq!(report.conflicts.len(), 1);
assert_eq!(report.conflicts[0].conflict_type, ConflictType::Broadcast);
assert_eq!(report.conflicts[0].degree, 1);
}
// Loads and stores are distinguished through `is_store`.
#[test]
fn store_vs_load_distinction() {
let instructions = vec![shared_load(0, PtxType::F32), shared_store(4, PtxType::F32)];
let report = analyze_bank_conflicts(&instructions, 32);
assert_eq!(report.total_shared_accesses, 2);
assert!(!report.accesses[0].is_store);
assert!(report.accesses[1].is_store);
}
// A symbol address operand is classified as Unknown and counted as conflict-free.
#[test]
fn unknown_offset_symbol() {
let inst = Instruction::Load {
space: MemorySpace::Shared,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
dst: Register {
name: "%r0".into(),
ty: PtxType::U32,
},
addr: Operand::Symbol("unknown_addr".into()),
};
let report = analyze_bank_conflicts(&[inst], 32);
assert_eq!(report.total_shared_accesses, 1);
assert_eq!(report.accesses[0].offset, AccessOffset::Unknown);
assert_eq!(report.conflict_free_count, 1);
}
// Only the shared accesses in a mixed stream are counted.
#[test]
fn mixed_accesses() {
let instructions = vec![
shared_load(0, PtxType::F32), global_load(), shared_store(64, PtxType::F32), ];
let report = analyze_bank_conflicts(&instructions, 32);
assert_eq!(report.total_shared_accesses, 2);
assert_eq!(report.conflict_free_count, 2);
}
// The report's Display output contains the counters and the per-conflict degree.
#[test]
fn report_display() {
let report = BankConflictReport {
accesses: Vec::new(),
conflicts: vec![BankConflict {
instruction_index: 3,
conflict_type: ConflictType::NWay(8),
affected_banks: vec![0, 8, 16, 24],
degree: 8,
suggestion: "Add padding".to_string(),
}],
conflict_free_count: 5,
total_shared_accesses: 6,
};
let display = format!("{report}");
assert!(display.contains("Total shared memory accesses: 6"));
assert!(display.contains("Conflict-free accesses: 5"));
assert!(display.contains("Conflicts detected: 1"));
assert!(display.contains("8-way"));
}
// Each ConflictType variant has a fixed textual form.
#[test]
fn conflict_type_display() {
assert_eq!(format!("{}", ConflictType::NWay(4)), "4-way bank conflict");
assert_eq!(
format!("{}", ConflictType::Broadcast),
"broadcast (no conflict)"
);
assert_eq!(
format!("{}", ConflictType::StridedConflict { stride: 8 }),
"strided conflict (stride=8 bytes)"
);
}
// The recorded access width follows the size of the PTX type.
#[test]
fn access_width_from_type() {
let inst_u8 = shared_load(0, PtxType::U8);
let inst_f64 = shared_load(0, PtxType::F64);
let inst_b128 = shared_load(0, PtxType::B128);
let report = analyze_bank_conflicts(&[inst_u8, inst_f64, inst_b128], 32);
assert_eq!(report.accesses[0].access_width, 1);
assert_eq!(report.accesses[1].access_width, 8);
assert_eq!(report.accesses[2].access_width, 16);
}
}