#![allow(clippy::similar_names)]
use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};
/// Generates a PTX kernel that transposes a row-major `rows` x `cols`
/// matrix of f32 values.
///
/// Both dimensions are baked into the emitted PTX as immediates at build
/// time, so one kernel instance is tied to a single matrix shape.
#[derive(Debug, Clone)]
pub struct TransposeKernel {
// Row count of the input matrix (column count of the output).
pub rows: u32,
// Column count of the input matrix (row count of the output).
pub cols: u32,
}
impl TransposeKernel {
/// Creates a transpose kernel for a `rows` x `cols` row-major input matrix.
#[must_use]
pub const fn new(rows: u32, cols: u32) -> Self {
Self { rows, cols }
}
}
impl Kernel for TransposeKernel {
    fn name(&self) -> &str {
        "transpose"
    }

    /// Builds PTX that performs `output[c * rows + r] = input[r * cols + c]`
    /// for every element of the row-major input, one element per thread.
    fn build_ptx(&self) -> PtxKernel {
        let n_rows = self.rows;
        let n_cols = self.cols;
        // Element count of the whole matrix. NOTE(review): `rows * cols` is a
        // host-side u32 multiply and can overflow for very large shapes —
        // presumably callers keep shapes in range; confirm.
        let elem_count = n_rows * n_cols;
        PtxKernel::new("transpose")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            // NOTE(review): `rows`/`cols` params are declared but never
            // loaded — the dimensions are emitted as immediates instead.
            // Presumably retained for a fixed launcher ABI; verify callers.
            .param(PtxType::U32, "rows")
            .param(PtxType::U32, "cols")
            .build(move |ctx| {
                // Flat global id: blockIdx.x * blockDim.x + threadIdx.x.
                let thread_id = ctx.special_reg(PtxReg::TidX);
                let block_id = ctx.special_reg(PtxReg::CtaIdX);
                let block_dim = ctx.special_reg(PtxReg::NtidX);
                let global_id = ctx.mad_lo_u32(block_id, block_dim, thread_id);
                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");
                // Skip threads past the last element.
                let limit = ctx.mov_u32_imm(elem_count);
                let active = ctx.setp_lt_u32(global_id, limit);
                ctx.branch_if_not(active, "exit");
                // Decompose the linear index into (row, col) of the input.
                let src_row = ctx.div_u32(global_id, n_cols);
                let src_col = ctx.rem_u32(global_id, n_cols);
                // Byte offsets: 4 bytes per f32 element.
                let elem_size = ctx.mov_u32_imm(4);
                let src_off = ctx.mul_wide_u32_reg(global_id, elem_size);
                let src_addr = ctx.add_u64(src_base, src_off);
                // Transposed linear index: col * rows + row.
                let row_count = ctx.mov_u32_imm(n_rows);
                let dst_idx = ctx.mad_lo_u32(src_col, row_count, src_row);
                let dst_off = ctx.mul_wide_u32_reg(dst_idx, elem_size);
                let dst_addr = ctx.add_u64(dst_base, dst_off);
                // Copy one element across.
                let value = ctx.ld_global_f32(src_addr);
                ctx.st_global_f32(dst_addr, value);
                ctx.label("exit");
                ctx.ret();
            })
    }
}
/// Generates a PTX kernel that independently transposes each of `batch`
/// row-major `rows` x `cols` f32 matrices laid out contiguously in memory.
///
/// `rows` and `cols` are baked into the PTX as immediates; the batch bound
/// is read from a kernel parameter at run time.
#[derive(Debug, Clone)]
pub struct BatchedTransposeKernel {
// Number of matrices in the batch.
pub batch: u32,
// Row count of each input matrix.
pub rows: u32,
// Column count of each input matrix.
pub cols: u32,
}
impl BatchedTransposeKernel {
/// Creates a batched transpose kernel for `batch` matrices of shape
/// `rows` x `cols`.
#[must_use]
pub const fn new(batch: u32, rows: u32, cols: u32) -> Self {
Self { batch, rows, cols }
}
}
impl Kernel for BatchedTransposeKernel {
    fn name(&self) -> &str {
        "batched_transpose"
    }

    /// Builds PTX that transposes each matrix of the batch independently:
    /// the x grid dimension covers elements of one matrix, and blockIdx.z
    /// selects the matrix.
    ///
    /// `rows`/`cols` are emitted as immediates, while the batch count is
    /// read from the `batch` kernel parameter, so the same PTX can be
    /// launched over a smaller batch than `self.batch`.
    fn build_ptx(&self) -> PtxKernel {
        let rows = self.rows;
        let cols = self.cols;
        // Elements per matrix. NOTE(review): host-side u32 multiply — can
        // overflow for very large shapes; presumably callers keep this in
        // range, confirm.
        let total_per_batch = rows * cols;
        PtxKernel::new("batched_transpose")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "batch")
            // NOTE(review): `rows`/`cols` params are declared but never
            // loaded (dimensions are immediates); presumably retained for a
            // fixed launcher ABI — verify against callers.
            .param(PtxType::U32, "rows")
            .param(PtxType::U32, "cols")
            .build(move |ctx| {
                // blockIdx.z picks the matrix; the x grid covers its elements.
                let batch_idx = ctx.special_reg(PtxReg::CtaIdZ);
                let tid = ctx.special_reg(PtxReg::TidX);
                let ctaid = ctx.special_reg(PtxReg::CtaIdX);
                let ntid = ctx.special_reg(PtxReg::NtidX);
                let gid = ctx.mad_lo_u32(ctaid, ntid, tid);
                // Guard both the element index and the batch index.
                let total = ctx.mov_u32_imm(total_per_batch);
                let in_bounds = ctx.setp_lt_u32(gid, total);
                let batch_param = ctx.load_param_u32("batch");
                let batch_valid = ctx.setp_lt_u32(batch_idx, batch_param);
                let valid = ctx.and_pred(in_bounds, batch_valid);
                ctx.branch_if_not(valid, "exit");
                let input_ptr = ctx.load_param_u64("input_ptr");
                let output_ptr = ctx.load_param_u64("output_ptr");
                // Decompose the linear index into (row, col) of the input.
                let row = ctx.div_u32(gid, cols);
                let col = ctx.rem_u32(gid, cols);
                // Byte offset of this matrix within the batch (4 bytes/f32).
                // NOTE(review): `total_per_batch * 4` is a host-side u32
                // multiply and can overflow for shapes near u32::MAX / 4.
                let batch_offset = ctx.mul_wide_u32(batch_idx, total_per_batch * 4);
                let in_batch_ptr = ctx.add_u64(input_ptr, batch_offset);
                let out_batch_ptr = ctx.add_u64(output_ptr, batch_offset);
                // Transposed linear index: col * rows + row.
                let rows_reg = ctx.mov_u32_imm(rows);
                let out_idx = ctx.mad_lo_u32(col, rows_reg, row);
                let four = ctx.mov_u32_imm(4);
                // The input linear index is just `gid` (row * cols + col by
                // construction inside the guarded region), so the previous
                // `mov` + `mad` recomputing it were redundant and dropped.
                let in_offset = ctx.mul_wide_u32_reg(gid, four);
                let out_offset = ctx.mul_wide_u32_reg(out_idx, four);
                let in_addr = ctx.add_u64(in_batch_ptr, in_offset);
                let out_addr = ctx.add_u64(out_batch_ptr, out_offset);
                let val = ctx.ld_global_f32(in_addr);
                ctx.st_global_f32(out_addr, val);
                ctx.label("exit");
                ctx.ret();
            })
    }
}