use crate::ir::{Register, RegisterAllocator};
/// Per-thread register fragment for the `A` operand of an `mma.m16n8k16`
/// tile: four `.b32` registers, each holding a packed pair of f16 values
/// (filled by [`alloc_a`] via `alloc_packed_half2`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FragmentA {
    // Four 32-bit registers; ordering matches the mma operand order.
    pub regs: [Register; 4],
}
/// Per-thread register fragment for the `B` operand of an `mma.m16n8k16`
/// tile: two `.b32` registers, each holding a packed pair of f16 values
/// (filled by [`alloc_b`] via `alloc_packed_half2`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FragmentB {
    pub regs: [Register; 2],
}
/// Per-thread accumulator (`C`/`D`) fragment for `mma.m16n8k16`:
/// four f32 registers (allocated as `PtxType::F32` by [`alloc_c`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FragmentC {
    pub regs: [Register; 4],
}
/// Allocates the four packed-half2 `.b32` registers backing an `A`
/// fragment, in ascending allocation order.
pub fn alloc_a(alloc: &mut RegisterAllocator) -> FragmentA {
    FragmentA {
        // `from_fn` fills indices 0..4 in order, so register indices are
        // identical to pushing four sequential allocations.
        regs: std::array::from_fn(|_| alloc.alloc_packed_half2()),
    }
}
/// Allocates the two packed-half2 `.b32` registers backing a `B`
/// fragment, in ascending allocation order.
pub fn alloc_b(alloc: &mut RegisterAllocator) -> FragmentB {
    FragmentB {
        regs: std::array::from_fn(|_| alloc.alloc_packed_half2()),
    }
}
// Type name intentionally carries the tile shape, hence the lint allow.
#[allow(non_camel_case_types)]
/// Per-thread `A` fragment for `mma.m16n8k32` (packed-int8 operands):
/// four `.b32` registers, each holding four packed s8 values
/// (filled by [`alloc_a_M16N8K32`] via `alloc_packed_int8x4`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FragmentA_M16N8K32 {
    pub regs: [Register; 4],
}
#[allow(non_camel_case_types)]
/// Per-thread `B` fragment for `mma.m16n8k32`: two `.b32` registers of
/// four packed s8 values each (filled by [`alloc_b_M16N8K32`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FragmentB_M16N8K32 {
    pub regs: [Register; 2],
}
#[allow(non_camel_case_types)]
/// Per-thread accumulator fragment for `mma.m16n8k32`: four s32
/// registers (allocated as `PtxType::S32` by [`alloc_c_M16N8K32`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FragmentC_M16N8K32 {
    pub regs: [Register; 4],
}
/// Allocates the four f32 accumulator registers backing a `C` fragment,
/// in ascending allocation order.
pub fn alloc_c(alloc: &mut RegisterAllocator) -> FragmentC {
    use crate::types::PtxType;
    let regs = std::array::from_fn(|_| alloc.alloc(PtxType::F32));
    FragmentC { regs }
}
// Name carries the tile shape, hence the snake_case allow.
#[allow(non_snake_case)]
/// Allocates the four packed-int8x4 `.b32` registers backing an
/// m16n8k32 `A` fragment, in ascending allocation order.
pub fn alloc_a_M16N8K32(alloc: &mut RegisterAllocator) -> FragmentA_M16N8K32 {
    FragmentA_M16N8K32 {
        regs: std::array::from_fn(|_| alloc.alloc_packed_int8x4()),
    }
}
#[allow(non_snake_case)]
/// Allocates the two packed-int8x4 `.b32` registers backing an
/// m16n8k32 `B` fragment, in ascending allocation order.
pub fn alloc_b_M16N8K32(alloc: &mut RegisterAllocator) -> FragmentB_M16N8K32 {
    FragmentB_M16N8K32 {
        regs: std::array::from_fn(|_| alloc.alloc_packed_int8x4()),
    }
}
#[allow(non_snake_case)]
/// Allocates the four s32 accumulator registers backing an m16n8k32
/// `C` fragment, in ascending allocation order.
pub fn alloc_c_M16N8K32(alloc: &mut RegisterAllocator) -> FragmentC_M16N8K32 {
    use crate::types::PtxType;
    let regs = std::array::from_fn(|_| alloc.alloc(PtxType::S32));
    FragmentC_M16N8K32 { regs }
}
use crate::instr::{ArithOp, MemoryOp};
use crate::ir::{Operand, PtxInstruction, PtxKernel};
use crate::types::PtxType;
/// Fixed per-row byte stride used by the global-memory fragment loaders:
/// 32 bytes (= 16 f16 elements for the K=16 tile; also = 32 s8 elements
/// for the K=32 tile, so both variants reuse it).
/// NOTE(review): assumes the source tile is stored densely with this
/// stride — confirm against the code that stages the tiles.
const FRAGMENT_HALF_ROW_STRIDE_BYTES: i32 = 32;
/// Emits the div/rem pair that splits `tid_x` into a 4-thread group id
/// (`tid_x / 4`) and the thread's position within that group (`tid_x % 4`).
/// Returns `(group_id, thread_id_in_group)` as fresh u32 registers.
fn compute_group_thread_ids(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    tid_x: crate::ir::Register,
) -> (crate::ir::Register, crate::ir::Register) {
    // Allocate both destinations up front so register indices match
    // (group first, then lane-in-group).
    let quad_id = alloc.alloc(PtxType::U32);
    let lane_in_quad = alloc.alloc(PtxType::U32);
    // quad_id = tid_x / 4
    kernel.push(PtxInstruction::Arith(ArithOp::Div {
        dst: quad_id,
        lhs: Operand::Reg(tid_x),
        rhs: Operand::ImmU32(4),
        ty: PtxType::U32,
    }));
    // lane_in_quad = tid_x % 4
    kernel.push(PtxInstruction::Arith(ArithOp::Rem {
        dst: lane_in_quad,
        lhs: Operand::Reg(tid_x),
        rhs: Operand::ImmU32(4),
        ty: PtxType::U32,
    }));
    (quad_id, lane_in_quad)
}
/// Widens a 32-bit byte offset to 64 bits and adds it to `base`,
/// optionally tacking a constant `extra_bytes` on top.
///
/// Emits `cvt.u64.u32` followed by one add (or two adds when
/// `extra_bytes != 0`) and returns the u64 address register.
fn u64_addr_from_u32_offset(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    base: crate::ir::Register,
    offset_u32: crate::ir::Register,
    extra_bytes: u32,
) -> crate::ir::Register {
    // Widen the offset first; global pointers are 64-bit.
    let widened = alloc.alloc(PtxType::U64);
    kernel.push(PtxInstruction::Cvt {
        dst: widened,
        src: offset_u32,
        dst_ty: PtxType::U64,
        src_ty: PtxType::U32,
    });
    // Result register is allocated before any scratch, preserving
    // allocation order for callers that rely on register numbering.
    let result = alloc.alloc(PtxType::U64);
    match extra_bytes {
        0 => kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: result,
            lhs: Operand::Reg(base),
            rhs: Operand::Reg(widened),
            ty: PtxType::U64,
        })),
        extra => {
            let partial = alloc.alloc(PtxType::U64);
            kernel.push(PtxInstruction::Arith(ArithOp::Add {
                dst: partial,
                lhs: Operand::Reg(base),
                rhs: Operand::Reg(widened),
                ty: PtxType::U64,
            }));
            kernel.push(PtxInstruction::Arith(ArithOp::Add {
                dst: result,
                lhs: Operand::Reg(partial),
                rhs: Operand::ImmU32(extra),
                ty: PtxType::U64,
            }));
        }
    }
    result
}
/// Loads a row-major `A` fragment for `mma.m16n8k16` from global memory.
///
/// Per thread: `tid_x / 4` selects the row, `(tid_x % 4) * 4` the 4-byte
/// column chunk; the second register pair comes from 8 rows further down
/// and 16 bytes further along, giving four `.b32` loads total.
pub fn load_fragment_a_m16n8k16_global_row(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    matrix_base_global: crate::ir::Register,
    tid_x: crate::ir::Register,
) -> FragmentA {
    let (gid, lane) = compute_group_thread_ids(alloc, kernel, tid_x);
    // row_bytes = gid * 32 (fixed f16 row stride).
    let row_bytes = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: row_bytes,
        lhs: Operand::Reg(gid),
        rhs: Operand::ImmU32(FRAGMENT_HALF_ROW_STRIDE_BYTES as u32),
        ty: PtxType::U32,
    }));
    // off = lane * 4 + row_bytes
    let off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: off,
        a: Operand::Reg(lane),
        b: Operand::ImmU32(4),
        c: Operand::Reg(row_bytes),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    // off8 = off + 8 rows
    let off8 = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Add {
        dst: off8,
        lhs: Operand::Reg(off),
        rhs: Operand::ImmU32((8 * FRAGMENT_HALF_ROW_STRIDE_BYTES) as u32),
        ty: PtxType::U32,
    }));
    // (row, row+8) x (chunk, chunk+16B) — evaluated left to right so the
    // emitted instruction/register order matches the register order.
    let addrs = [
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off, 0),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off8, 0),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off, 16),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off8, 16),
    ];
    let frag = alloc_a(alloc);
    for (dst, addr) in frag.regs.into_iter().zip(addrs) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Loads a column-major `B` fragment for `mma.m16n8k16` from global
/// memory: two `.b32` loads per thread, 16 bytes apart within the
/// thread's column.
pub fn load_fragment_b_m16n8k16_global_col(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    matrix_base_global: crate::ir::Register,
    tid_x: crate::ir::Register,
) -> FragmentB {
    let (gid, lane) = compute_group_thread_ids(alloc, kernel, tid_x);
    // col_bytes = gid * 32 (fixed f16 column stride).
    let col_bytes = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: col_bytes,
        lhs: Operand::Reg(gid),
        rhs: Operand::ImmU32(FRAGMENT_HALF_ROW_STRIDE_BYTES as u32),
        ty: PtxType::U32,
    }));
    // off = lane * 4 + col_bytes
    let off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: off,
        a: Operand::Reg(lane),
        b: Operand::ImmU32(4),
        c: Operand::Reg(col_bytes),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    let addrs = [
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off, 0),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off, 16),
    ];
    let frag = alloc_b(alloc);
    for (dst, addr) in frag.regs.into_iter().zip(addrs) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Stores a row-major `C` accumulator fragment for `mma.m16n8k16` to
/// global memory with a caller-supplied row stride.
///
/// Each thread writes two consecutive f32 values at `(row, 2*lane)` and
/// two more 8 rows below — four `st.global.f32` in total.
pub fn store_fragment_c_m16n8k16_global_row(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    matrix_base_global: crate::ir::Register,
    tid_x: crate::ir::Register,
    fragment: FragmentC,
    row_stride_bytes: u32,
) {
    let (gid, lane) = compute_group_thread_ids(alloc, kernel, tid_x);
    // row_bytes = gid * row_stride_bytes
    let row_bytes = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: row_bytes,
        lhs: Operand::Reg(gid),
        rhs: Operand::ImmU32(row_stride_bytes),
        ty: PtxType::U32,
    }));
    // off = lane * 8 + row_bytes (each thread owns two adjacent f32s).
    let off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: off,
        a: Operand::Reg(lane),
        b: Operand::ImmU32(8),
        c: Operand::Reg(row_bytes),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    // off8 = off + 8 rows
    let off8 = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Add {
        dst: off8,
        lhs: Operand::Reg(off),
        rhs: Operand::ImmU32(8 * row_stride_bytes),
        ty: PtxType::U32,
    }));
    let addrs = [
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off, 0),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off, 4),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off8, 0),
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, off8, 4),
    ];
    for (src, addr) in fragment.regs.into_iter().zip(addrs) {
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr,
            src,
            ty: PtxType::F32,
        }));
    }
}
/// Loads a row-major `A` fragment for `mma.m16n8k32` (packed-int8
/// operands) from global memory.
///
/// Addressing mirrors the f16 loader: `tid_x / 4` selects the row,
/// `(tid_x % 4) * 4` the 4-byte chunk (four s8s), and the second register
/// pair is fetched 8 rows / 16 bytes further into the tile.
///
/// (The former `#[allow(non_snake_case)]` was removed: the function name
/// is already snake_case, so the allow suppressed nothing.)
pub fn load_fragment_a_m16n8k32_global_row(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    matrix_base_global: crate::ir::Register,
    tid_x: crate::ir::Register,
) -> FragmentA_M16N8K32 {
    let (group_id, tig) = compute_group_thread_ids(alloc, kernel, tid_x);
    // row_off = group_id * 32 (32 s8 elements per K=32 row).
    let row_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: row_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(FRAGMENT_HALF_ROW_STRIDE_BYTES as u32),
        ty: PtxType::U32,
    }));
    // base_off = tig * 4 + row_off
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(4),
        c: Operand::Reg(row_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    // Same column chunk, 8 rows further down.
    let base_off_plus_8rows = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Add {
        dst: base_off_plus_8rows,
        lhs: Operand::Reg(base_off),
        rhs: Operand::ImmU32((8 * FRAGMENT_HALF_ROW_STRIDE_BYTES) as u32),
        ty: PtxType::U32,
    }));
    let addr0 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off, 0);
    let addr1 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off_plus_8rows, 0);
    let addr2 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off, 16);
    let addr3 =
        u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off_plus_8rows, 16);
    let frag = alloc_a_M16N8K32(alloc);
    for (reg, addr) in frag.regs.iter().zip([addr0, addr1, addr2, addr3]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: *reg,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Loads a column-major `B` fragment for `mma.m16n8k32` (packed-int8
/// operands) from global memory: two `.b32` loads per thread, 16 bytes
/// apart within the thread's column.
///
/// (The former `#[allow(non_snake_case)]` was removed: the function name
/// is already snake_case, so the allow suppressed nothing.)
pub fn load_fragment_b_m16n8k32_global_col(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    matrix_base_global: crate::ir::Register,
    tid_x: crate::ir::Register,
) -> FragmentB_M16N8K32 {
    let (group_id, tig) = compute_group_thread_ids(alloc, kernel, tid_x);
    // col_off = group_id * 32
    let col_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: col_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(FRAGMENT_HALF_ROW_STRIDE_BYTES as u32),
        ty: PtxType::U32,
    }));
    // base_off = tig * 4 + col_off
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(4),
        c: Operand::Reg(col_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    let addr0 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off, 0);
    let addr1 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off, 16);
    let frag = alloc_b_M16N8K32(alloc);
    for (reg, addr) in frag.regs.iter().zip([addr0, addr1]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdGlobal {
            dst: *reg,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Stores a row-major s32 accumulator fragment for `mma.m16n8k32` to
/// global memory with a caller-supplied row stride. Each thread writes
/// two adjacent s32s at `(row, 2*tig)` plus the same pair 8 rows below.
///
/// (The former `#[allow(non_snake_case)]` was removed: the function name
/// is already snake_case, so the allow suppressed nothing.)
pub fn store_fragment_c_m16n8k32_global_row(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    matrix_base_global: crate::ir::Register,
    tid_x: crate::ir::Register,
    fragment: FragmentC_M16N8K32,
    row_stride_bytes: u32,
) {
    let (group_id, tig) = compute_group_thread_ids(alloc, kernel, tid_x);
    // row_off = group_id * row_stride_bytes
    let row_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: row_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(row_stride_bytes),
        ty: PtxType::U32,
    }));
    // base_off = tig * 8 + row_off (two adjacent s32s per thread).
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(8),
        c: Operand::Reg(row_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    let base_off_plus_8rows = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Add {
        dst: base_off_plus_8rows,
        lhs: Operand::Reg(base_off),
        rhs: Operand::ImmU32(8 * row_stride_bytes),
        ty: PtxType::U32,
    }));
    let addr0 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off, 0);
    let addr1 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off, 4);
    let addr2 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off_plus_8rows, 0);
    let addr3 = u64_addr_from_u32_offset(alloc, kernel, matrix_base_global, base_off_plus_8rows, 4);
    for (reg, addr) in fragment.regs.iter().zip([addr0, addr1, addr2, addr3]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::StGlobal {
            addr,
            src: *reg,
            ty: PtxType::S32,
        }));
    }
}
/// Adds a register byte offset (plus an optional constant) to a 32-bit
/// shared-memory base address. Shared-space addresses stay 32-bit, so no
/// widening conversion is needed — one add, or two when
/// `extra_bytes != 0`.
fn u32_shared_addr_from_offset(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    tile_base_shared: crate::ir::Register,
    offset_u32: crate::ir::Register,
    extra_bytes: u32,
) -> crate::ir::Register {
    let result = alloc.alloc(PtxType::U32);
    match extra_bytes {
        0 => kernel.push(PtxInstruction::Arith(ArithOp::Add {
            dst: result,
            lhs: Operand::Reg(tile_base_shared),
            rhs: Operand::Reg(offset_u32),
            ty: PtxType::U32,
        })),
        extra => {
            let partial = alloc.alloc(PtxType::U32);
            kernel.push(PtxInstruction::Arith(ArithOp::Add {
                dst: partial,
                lhs: Operand::Reg(tile_base_shared),
                rhs: Operand::Reg(offset_u32),
                ty: PtxType::U32,
            }));
            kernel.push(PtxInstruction::Arith(ArithOp::Add {
                dst: result,
                lhs: Operand::Reg(partial),
                rhs: Operand::ImmU32(extra),
                ty: PtxType::U32,
            }));
        }
    }
    result
}
/// Loads a row-major `A` fragment for `mma.m16n8k16` from shared memory.
///
/// Same addressing scheme as the global loader, with three differences:
/// * addresses are 32-bit shared-space offsets (`u32_shared_addr_from_offset`),
/// * the row stride is caller-supplied (`row_stride_bytes`), not fixed,
/// * callers may pass precomputed `(group_id, thread_id_in_group)`
///   registers via `group_tig_override` to avoid re-emitting the
///   div/rem pair when loading several fragments.
pub fn load_fragment_a_m16n8k16_shared_row(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    tile_base_shared: crate::ir::Register,
    tid_x: crate::ir::Register,
    row_stride_bytes: u32,
    group_tig_override: Option<(crate::ir::Register, crate::ir::Register)>,
) -> FragmentA {
    // Reuse the caller's ids when given; otherwise emit div/rem here.
    let (group_id, tig) = match group_tig_override {
        Some(pair) => pair,
        None => compute_group_thread_ids(alloc, kernel, tid_x),
    };
    // row_off = group_id * row_stride_bytes
    let row_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: row_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(row_stride_bytes),
        ty: PtxType::U32,
    }));
    // base_off = tig * 4 + row_off
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(4),
        c: Operand::Reg(row_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    // Same chunk, 8 rows further down the tile.
    let base_off_plus_8rows = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Add {
        dst: base_off_plus_8rows,
        lhs: Operand::Reg(base_off),
        rhs: Operand::ImmU32(8 * row_stride_bytes),
        ty: PtxType::U32,
    }));
    // Four addresses: (row, row+8) x (chunk, chunk+16B).
    let addr0 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 0);
    let addr1 =
        u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off_plus_8rows, 0);
    let addr2 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 16);
    let addr3 =
        u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off_plus_8rows, 16);
    let frag = alloc_a(alloc);
    for (reg, addr) in frag.regs.iter().zip([addr0, addr1, addr2, addr3]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
            dst: *reg,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Loads a column-major `B` fragment for `mma.m16n8k16` from shared
/// memory: two `ld.shared.b32` per thread, 16 bytes apart.
///
/// `col_stride_bytes` is caller-supplied; `group_tig_override` lets the
/// caller reuse already-computed `(group_id, thread_id_in_group)`
/// registers instead of re-emitting the div/rem pair.
pub fn load_fragment_b_m16n8k16_shared_col(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    tile_base_shared: crate::ir::Register,
    tid_x: crate::ir::Register,
    col_stride_bytes: u32,
    group_tig_override: Option<(crate::ir::Register, crate::ir::Register)>,
) -> FragmentB {
    let (group_id, tig) = match group_tig_override {
        Some(pair) => pair,
        None => compute_group_thread_ids(alloc, kernel, tid_x),
    };
    // col_off = group_id * col_stride_bytes
    let col_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: col_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(col_stride_bytes),
        ty: PtxType::U32,
    }));
    // base_off = tig * 4 + col_off
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(4),
        c: Operand::Reg(col_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    let addr0 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 0);
    let addr1 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 16);
    let frag = alloc_b(alloc);
    for (reg, addr) in frag.regs.iter().zip([addr0, addr1]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
            dst: *reg,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Loads a row-major `A` fragment for `mma.m16n8k32` (packed-int8
/// operands) from shared memory. Same addressing as the f16 shared
/// loader; `group_tig_override` reuses caller-computed ids.
///
/// (The former `#[allow(non_snake_case)]` was removed: the function name
/// is already snake_case, so the allow suppressed nothing.)
pub fn load_fragment_a_m16n8k32_shared_row(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    tile_base_shared: crate::ir::Register,
    tid_x: crate::ir::Register,
    row_stride_bytes: u32,
    group_tig_override: Option<(crate::ir::Register, crate::ir::Register)>,
) -> FragmentA_M16N8K32 {
    let (group_id, tig) = match group_tig_override {
        Some(pair) => pair,
        None => compute_group_thread_ids(alloc, kernel, tid_x),
    };
    // row_off = group_id * row_stride_bytes
    let row_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: row_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(row_stride_bytes),
        ty: PtxType::U32,
    }));
    // base_off = tig * 4 + row_off
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(4),
        c: Operand::Reg(row_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    let base_off_plus_8rows = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Add {
        dst: base_off_plus_8rows,
        lhs: Operand::Reg(base_off),
        rhs: Operand::ImmU32(8 * row_stride_bytes),
        ty: PtxType::U32,
    }));
    let addr0 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 0);
    let addr1 =
        u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off_plus_8rows, 0);
    let addr2 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 16);
    let addr3 =
        u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off_plus_8rows, 16);
    let frag = alloc_a_M16N8K32(alloc);
    for (reg, addr) in frag.regs.iter().zip([addr0, addr1, addr2, addr3]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
            dst: *reg,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
/// Loads a column-major `B` fragment for `mma.m16n8k32` (packed-int8
/// operands) from shared memory: two `ld.shared.b32` per thread, 16
/// bytes apart. `group_tig_override` reuses caller-computed ids.
///
/// (The former `#[allow(non_snake_case)]` was removed: the function name
/// is already snake_case, so the allow suppressed nothing.)
pub fn load_fragment_b_m16n8k32_shared_col(
    alloc: &mut crate::ir::RegisterAllocator,
    kernel: &mut PtxKernel,
    tile_base_shared: crate::ir::Register,
    tid_x: crate::ir::Register,
    col_stride_bytes: u32,
    group_tig_override: Option<(crate::ir::Register, crate::ir::Register)>,
) -> FragmentB_M16N8K32 {
    let (group_id, tig) = match group_tig_override {
        Some(pair) => pair,
        None => compute_group_thread_ids(alloc, kernel, tid_x),
    };
    // col_off = group_id * col_stride_bytes
    let col_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mul {
        dst: col_off,
        lhs: Operand::Reg(group_id),
        rhs: Operand::ImmU32(col_stride_bytes),
        ty: PtxType::U32,
    }));
    // base_off = tig * 4 + col_off
    let base_off = alloc.alloc(PtxType::U32);
    kernel.push(PtxInstruction::Arith(ArithOp::Mad {
        dst: base_off,
        a: Operand::Reg(tig),
        b: Operand::ImmU32(4),
        c: Operand::Reg(col_off),
        ty: PtxType::U32,
        mode: crate::instr::MadMode::Lo,
    }));
    let addr0 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 0);
    let addr1 = u32_shared_addr_from_offset(alloc, kernel, tile_base_shared, base_off, 16);
    let frag = alloc_b_M16N8K32(alloc);
    for (reg, addr) in frag.regs.iter().zip([addr0, addr1]) {
        kernel.push(PtxInstruction::Memory(MemoryOp::LdShared {
            dst: *reg,
            addr,
            ty: PtxType::U32,
        }));
    }
    frag
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{PtxType, RegKind};

    // alloc_a hands out four sequential .b32 (U32) registers of kind R.
    #[test]
    fn alloc_a_gives_four_b32_regs() {
        let mut a = RegisterAllocator::new();
        let frag = alloc_a(&mut a);
        for r in &frag.regs {
            assert_eq!(r.kind, RegKind::R);
            assert_eq!(r.ptx_type, PtxType::U32);
        }
        assert_eq!(frag.regs[0].index, 0);
        assert_eq!(frag.regs[1].index, 1);
        assert_eq!(frag.regs[2].index, 2);
        assert_eq!(frag.regs[3].index, 3);
    }

    // alloc_b hands out two .b32 (U32) registers of kind R.
    #[test]
    fn alloc_b_gives_two_b32_regs() {
        let mut a = RegisterAllocator::new();
        let frag = alloc_b(&mut a);
        for r in &frag.regs {
            assert_eq!(r.kind, RegKind::R);
            assert_eq!(r.ptx_type, PtxType::U32);
        }
    }

    // alloc_c hands out four f32 registers of kind F.
    #[test]
    fn alloc_c_gives_four_f32_regs() {
        let mut a = RegisterAllocator::new();
        let frag = alloc_c(&mut a);
        for r in &frag.regs {
            assert_eq!(r.kind, RegKind::F);
            assert_eq!(r.ptx_type, PtxType::F32);
        }
    }

    // The global A loader emits exactly four ld.global.b32 instructions.
    #[test]
    fn load_fragment_a_emits_four_b32_loads() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        let base = alloc.alloc(PtxType::U64);
        let tid = alloc.alloc(PtxType::U32);
        let frag = load_fragment_a_m16n8k16_global_row(&mut alloc, &mut kernel, base, tid);
        assert_eq!(frag.regs.len(), 4);
        let n_loads = kernel
            .body
            .iter()
            .filter(|instr| {
                matches!(
                    instr,
                    PtxInstruction::Memory(MemoryOp::LdGlobal {
                        ty: PtxType::U32,
                        ..
                    })
                )
            })
            .count();
        assert_eq!(n_loads, 4, "expected 4 ld.global.b32 for FragmentA");
    }

    // The global B loader emits exactly two ld.global.b32 instructions.
    #[test]
    fn load_fragment_b_emits_two_b32_loads() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        let base = alloc.alloc(PtxType::U64);
        let tid = alloc.alloc(PtxType::U32);
        let frag = load_fragment_b_m16n8k16_global_col(&mut alloc, &mut kernel, base, tid);
        assert_eq!(frag.regs.len(), 2);
        let n_loads = kernel
            .body
            .iter()
            .filter(|instr| {
                matches!(
                    instr,
                    PtxInstruction::Memory(MemoryOp::LdGlobal {
                        ty: PtxType::U32,
                        ..
                    })
                )
            })
            .count();
        assert_eq!(n_loads, 2, "expected 2 ld.global.b32 for FragmentB");
    }

    // The C store helper emits exactly four st.global.f32 instructions.
    #[test]
    fn store_fragment_c_emits_four_f32_stores() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        let base = alloc.alloc(PtxType::U64);
        let tid = alloc.alloc(PtxType::U32);
        let frag = alloc_c(&mut alloc);
        store_fragment_c_m16n8k16_global_row(&mut alloc, &mut kernel, base, tid, frag, 32);
        let n_stores = kernel
            .body
            .iter()
            .filter(|instr| {
                matches!(
                    instr,
                    PtxInstruction::Memory(MemoryOp::StGlobal {
                        ty: PtxType::F32,
                        ..
                    })
                )
            })
            .count();
        assert_eq!(n_stores, 4, "expected 4 st.global.f32 for FragmentC");
    }

    // The shared A loader emits ld.shared only — never ld.global.
    #[test]
    fn load_fragment_a_shared_emits_four_b32_shared_loads() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        // Shared-space base addresses are 32-bit.
        let base = alloc.alloc(PtxType::U32); let tid = alloc.alloc(PtxType::U32);
        let frag =
            load_fragment_a_m16n8k16_shared_row(&mut alloc, &mut kernel, base, tid, 32, None);
        assert_eq!(frag.regs.len(), 4);
        let n_loads = kernel
            .body
            .iter()
            .filter(|instr| {
                matches!(
                    instr,
                    PtxInstruction::Memory(MemoryOp::LdShared {
                        ty: PtxType::U32,
                        ..
                    })
                )
            })
            .count();
        assert_eq!(n_loads, 4, "expected 4 ld.shared.b32 for FragmentA");
        let n_global = kernel
            .body
            .iter()
            .filter(|instr| matches!(instr, PtxInstruction::Memory(MemoryOp::LdGlobal { .. })))
            .count();
        assert_eq!(n_global, 0, "shared helper should not emit any ld.global");
    }

    // The shared B loader emits exactly two ld.shared.b32 instructions.
    #[test]
    fn load_fragment_b_shared_emits_two_b32_shared_loads() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        let base = alloc.alloc(PtxType::U32);
        let tid = alloc.alloc(PtxType::U32);
        let frag =
            load_fragment_b_m16n8k16_shared_col(&mut alloc, &mut kernel, base, tid, 32, None);
        assert_eq!(frag.regs.len(), 2);
        let n_loads = kernel
            .body
            .iter()
            .filter(|instr| {
                matches!(
                    instr,
                    PtxInstruction::Memory(MemoryOp::LdShared {
                        ty: PtxType::U32,
                        ..
                    })
                )
            })
            .count();
        assert_eq!(n_loads, 2, "expected 2 ld.shared.b32 for FragmentB");
    }

    // A caller-supplied stride of 128 must show up as a *128 multiply and
    // a +1024 (= 8 * 128) add in the emitted instruction stream.
    #[test]
    fn load_fragment_a_shared_respects_stride_parameter() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        let base = alloc.alloc(PtxType::U32);
        let tid = alloc.alloc(PtxType::U32);
        let _ = load_fragment_a_m16n8k16_shared_row(&mut alloc, &mut kernel, base, tid, 128, None);
        let mut saw_stride_mul = false;
        let mut saw_eight_row_add = false;
        for instr in &kernel.body {
            if let PtxInstruction::Arith(ArithOp::Mul {
                rhs: Operand::ImmU32(128),
                ..
            }) = instr
            {
                saw_stride_mul = true;
            }
            if let PtxInstruction::Arith(ArithOp::Add {
                rhs: Operand::ImmU32(1024),
                ..
            }) = instr
            {
                saw_eight_row_add = true;
            }
        }
        assert!(
            saw_stride_mul,
            "shared A loader should multiply group_id by the caller-supplied row_stride_bytes"
        );
        assert!(
            saw_eight_row_add,
            "shared A loader should add 8*row_stride_bytes for the +8-rows address"
        );
    }

    // Same check for the B loader with a stride of 96.
    #[test]
    fn load_fragment_b_shared_respects_stride_parameter() {
        use crate::ir::{PtxKernel, RegisterAllocator};
        let mut alloc = RegisterAllocator::new();
        let mut kernel = PtxKernel::new("test");
        let base = alloc.alloc(PtxType::U32);
        let tid = alloc.alloc(PtxType::U32);
        let _ = load_fragment_b_m16n8k16_shared_col(&mut alloc, &mut kernel, base, tid, 96, None);
        let mut saw_stride_mul = false;
        for instr in &kernel.body {
            if let PtxInstruction::Arith(ArithOp::Mul {
                rhs: Operand::ImmU32(96),
                ..
            }) = instr
            {
                saw_stride_mul = true;
            }
        }
        assert!(
            saw_stride_mul,
            "shared B loader should multiply group_id by the caller-supplied col_stride_bytes"
        );
    }

    // R-kind (integer) and F-kind (float) registers draw from separate
    // index counters: fc restarts at 0 even after fa/fb consumed 0..=5.
    #[test]
    fn fragment_counters_independent() {
        let mut a = RegisterAllocator::new();
        let fa = alloc_a(&mut a);
        let fb = alloc_b(&mut a);
        let fc = alloc_c(&mut a);
        assert_eq!(fa.regs[0].index, 0);
        assert_eq!(fa.regs[3].index, 3);
        assert_eq!(fb.regs[0].index, 4);
        assert_eq!(fb.regs[1].index, 5);
        assert_eq!(fc.regs[0].index, 0);
        assert_eq!(fc.regs[3].index, 3);
    }
}