use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams, grid_size_for};
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::types::{TensorDesc, TensorDescMut, TileConfig};
use super::super::descriptor::ConvProblem;
/// Forward convolution expressed as an implicit GEMM.
///
/// Bundles everything needed to name, generate, and launch the kernel: the
/// convolution problem, the tiling configuration, and the target SM version.
pub struct ImplicitGemmConv {
/// Convolution being solved; supplies the element type, the dimension
/// vectors, and the derived GEMM dimensions.
problem: ConvProblem,
/// Tile sizes (`tile_m`/`tile_n`/`tile_k`), warp tile sizes, and pipeline
/// stage count used for kernel naming, shared-memory sizing, and launch geometry.
tile_config: TileConfig,
/// PTX target; also selects between the async-pipeline and the standard
/// mainloop when the kernel body is emitted.
sm_version: SmVersion,
}
impl ImplicitGemmConv {
#[must_use]
pub fn new(problem: ConvProblem, sm_version: SmVersion) -> Self {
let tile_config = TileConfig::default_conv(sm_version);
Self {
problem,
tile_config,
sm_version,
}
}
/// Creates a convolution with a caller-chosen tile configuration.
///
/// All three arguments are stored as given; nothing is validated here.
#[must_use]
pub fn with_tile_config(
    problem: ConvProblem,
    tile_config: TileConfig,
    sm_version: SmVersion,
) -> Self {
    Self {
        sm_version,
        tile_config,
        problem,
    }
}
#[must_use]
pub fn kernel_name(&self) -> String {
let prec = self.problem.input_type.as_ptx_str().trim_start_matches('.');
format!(
"implicit_gemm_conv_{}x{}x{}_{}",
self.tile_config.tile_m, self.tile_config.tile_n, self.tile_config.tile_k, prec,
)
}
/// Builds the PTX text for this convolution's kernel.
///
/// The kernel declares four 64-bit pointer parameters (`input`, `filter`,
/// `output`, `bias`) followed by eighteen 32-bit scalars describing the
/// convolution and its GEMM dimensions, plus two shared-memory arrays sized by
/// `smem_input_elements` and `smem_filter_elements`.
///
/// # Errors
///
/// Propagates the error from `ConvProblem::conv_to_gemm_dims` when the
/// problem cannot be mapped to GEMM dimensions, and returns
/// `DnnError::PtxGeneration` when the kernel builder fails.
pub fn generate_ptx(&self) -> DnnResult<String> {
    // Only called for validation: the dimensions themselves are runtime
    // kernel parameters, not baked into the PTX.
    let _validated_dims = self.problem.conv_to_gemm_dims()?;

    let target = self.sm_version;
    let pipeline_depth = self.tile_config.stages;
    let element = self.problem.input_type;
    let entry_name = self.kernel_name();

    KernelBuilder::new(&entry_name)
        .target(target)
        // Device pointers.
        .param("input", PtxType::U64)
        .param("filter", PtxType::U64)
        .param("output", PtxType::U64)
        .param("bias", PtxType::U64)
        // Tensor shape.
        .param("batch_size", PtxType::U32)
        .param("in_channels", PtxType::U32)
        .param("in_h", PtxType::U32)
        .param("in_w", PtxType::U32)
        .param("out_channels", PtxType::U32)
        .param("filter_h", PtxType::U32)
        .param("filter_w", PtxType::U32)
        .param("out_h", PtxType::U32)
        .param("out_w", PtxType::U32)
        // Convolution hyper-parameters.
        .param("pad_h", PtxType::U32)
        .param("pad_w", PtxType::U32)
        .param("stride_h", PtxType::U32)
        .param("stride_w", PtxType::U32)
        .param("dilation_h", PtxType::U32)
        .param("dilation_w", PtxType::U32)
        // Derived GEMM dimensions.
        .param("gemm_m", PtxType::U32)
        .param("gemm_n", PtxType::U32)
        .param("gemm_k", PtxType::U32)
        .shared_mem("smem_input", element, self.smem_input_elements())
        .shared_mem("smem_filter", element, self.smem_filter_elements())
        .body(move |b| emit_implicit_gemm_body(b, target, pipeline_depth))
        .build()
        .map_err(|e| DnnError::PtxGeneration(e.to_string()))
}
/// Generates the kernel, loads it, and launches it on the handle's stream.
///
/// The launch uses one CTA per `tile_m x tile_n` output tile and
/// `(tile_m / warp_m) * (tile_n / warp_n)` warps per CTA, capped at 1024
/// threads. A missing `bias` is passed to the kernel as a null pointer, which
/// the epilogue tests before adding.
///
/// The argument tuple mirrors, in order, every parameter declared by
/// `generate_ptx`, including the trailing `gemm_m`, `gemm_n`, `gemm_k`.
///
/// # Errors
///
/// Propagates errors from PTX generation, module loading, kernel lookup, and
/// `ConvProblem::conv_to_gemm_dims` / `ConvProblem::output_dims`. Returns
/// `DnnError::LaunchFailed` when the tile configuration has a zero warp tile,
/// when the shared-memory request does not fit in `u32`, or when the launch
/// itself fails.
pub fn execute<T: GpuFloat>(
    &self,
    handle: &DnnHandle,
    input: &TensorDesc<T>,
    filter: &TensorDesc<T>,
    bias: Option<&TensorDesc<T>>,
    output: &mut TensorDescMut<T>,
) -> DnnResult<()> {
    // NOTE(review): the module is rebuilt from PTX on every call; consider
    // caching if this shows up in profiles.
    let ptx = self.generate_ptx()?;
    let module = Arc::new(Module::from_ptx(&ptx)?);
    let kernel = Kernel::from_module(module, &self.kernel_name())?;

    let (gemm_m, gemm_n, gemm_k) = self.problem.conv_to_gemm_dims()?;
    let out_dims = self.problem.output_dims()?;
    let out_h = out_dims.first().copied().unwrap_or(1);
    let out_w = out_dims.get(1).copied().unwrap_or(1);

    let tile = &self.tile_config;
    let grid = Dim3::xy(
        grid_size_for(gemm_m, tile.tile_m),
        grid_size_for(gemm_n, tile.tile_n),
    );

    // A zero warp tile would otherwise panic in the divisions below.
    if tile.warp_m == 0 || tile.warp_n == 0 {
        return Err(DnnError::LaunchFailed(format!(
            "invalid tile configuration: warp tile {}x{} must be non-zero",
            tile.warp_m, tile.warp_n
        )));
    }
    let warps_m = tile.tile_m / tile.warp_m;
    let warps_n = tile.tile_n / tile.warp_n;
    let threads = warps_m * warps_n * 32;
    // 1024 is the cap the original code applied to the block size.
    let block = Dim3::x(threads.min(1024));

    // NOTE(review): generate_ptx also declares smem_input/smem_filter via the
    // builder; confirm this dynamic request is not doubling the allocation.
    let shared_bytes = (self.smem_input_elements() + self.smem_filter_elements())
        * self.problem.input_type.size_bytes();
    if shared_bytes > u32::MAX as usize {
        return Err(DnnError::LaunchFailed(format!(
            "shared memory request of {shared_bytes} bytes does not fit in u32"
        )));
    }
    // Lossless: range-checked just above.
    let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes as u32);

    let bias_ptr = bias.map_or(0u64, |b| b.ptr);

    // Order must match the `.param(...)` sequence in `generate_ptx`.
    // Missing axes fall back to neutral values instead of panicking.
    let args = (
        input.ptr,
        filter.ptr,
        output.ptr,
        bias_ptr,
        self.problem.batch,
        self.problem.in_channels,
        self.problem.in_dims.first().copied().unwrap_or(1),
        self.problem.in_dims.get(1).copied().unwrap_or(1),
        self.problem.out_channels,
        self.problem.filter_dims.first().copied().unwrap_or(1),
        self.problem.filter_dims.get(1).copied().unwrap_or(1),
        out_h,
        out_w,
        self.problem.padding.first().copied().unwrap_or(0),
        self.problem.padding.get(1).copied().unwrap_or(0),
        self.problem.stride.first().copied().unwrap_or(1),
        self.problem.stride.get(1).copied().unwrap_or(1),
        self.problem.dilation.first().copied().unwrap_or(1),
        self.problem.dilation.get(1).copied().unwrap_or(1),
        // The kernel declares these three and its epilogue bounds-checks
        // against gemm_m / gemm_n, so they must be supplied.
        gemm_m,
        gemm_n,
        gemm_k,
    );

    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(|e| DnnError::LaunchFailed(e.to_string()))?;
    Ok(())
}
/// Number of elements in the input shared-memory buffer:
/// `tile_m * tile_k * stages`.
fn smem_input_elements(&self) -> usize {
    let cfg = &self.tile_config;
    [
        cfg.tile_m as usize,
        cfg.tile_k as usize,
        cfg.stages as usize,
    ]
    .into_iter()
    .product()
}
/// Number of elements in the filter shared-memory buffer:
/// `tile_n * tile_k * stages`.
fn smem_filter_elements(&self) -> usize {
    let cfg = &self.tile_config;
    [
        cfg.tile_n as usize,
        cfg.tile_k as usize,
        cfg.stages as usize,
    ]
    .into_iter()
    .product()
}
/// Size in bytes of the scratch workspace this algorithm requires.
///
/// Always zero.
#[must_use]
pub fn workspace_bytes(&self) -> usize {
    const NO_WORKSPACE: usize = 0;
    NO_WORKSPACE
}
}
/// Emits the kernel body for the implicit-GEMM forward convolution.
///
/// The tile mapping and mainloop are currently emitted as PTX comments that
/// describe the intended algorithm; the only instructions come from the two
/// global-thread-id queries, the bias epilogue, and the final `ret`.
///
/// * `sm` - selects the async-pipeline description (`Sm80` and newer) or the
///   standard mainloop description.
/// * `stages` - pipeline depth; `stages - 1` prologue and drain lines are
///   emitted on the async path.
fn emit_implicit_gemm_body(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    sm: SmVersion,
    stages: u32,
) {
    const TILE_MAPPING: [&str; 4] = [
        "=== Implicit GEMM Convolution (forward) ===",
        "Step 1: Map CTA to GEMM tile coordinates",
        " blockIdx.x -> M-tile (batch * out_h * out_w)",
        " blockIdx.y -> N-tile (out_channels)",
    ];
    const MAINLOOP_OVERVIEW: [&str; 8] = [
        "Step 2: Mainloop over filter volume (C x R x S)",
        " For each k-iteration:",
        " channel_idx = k / (R * S)",
        " r_idx = (k / S) % R",
        " s_idx = k % S",
        " input_h = out_h * stride_h - pad_h + r_idx * dilation_h",
        " input_w = out_w * stride_w - pad_w + s_idx * dilation_w",
        " Boundary check: load 0 if out of bounds (zero-padding)",
    ];
    const ASYNC_STEADY_STATE: [&str; 4] = [
        " Mainloop: for each K-tile",
        " 1. Wait for oldest async load to complete",
        " 2. Compute GEMM tile (MMA or FMA)",
        " 3. Issue next async load",
    ];
    const SYNC_MAINLOOP: [&str; 7] = [
        "--- Standard mainloop (Turing / pre-Ampere) ---",
        "For each K-tile:",
        " 1. Load input tile to smem with boundary predicates",
        " 2. Load filter tile to smem",
        " 3. __syncthreads()",
        " 4. Compute tile GEMM (FMA loop or WMMA)",
        " 5. __syncthreads()",
    ];

    for line in TILE_MAPPING {
        b.comment(line);
    }
    // Results are unused, but the calls are kept exactly where the original had them.
    let _gid_x = b.global_thread_id_x();
    let _gid_y = b.global_thread_id_y();

    for line in MAINLOOP_OVERVIEW {
        b.comment(line);
    }

    if sm >= SmVersion::Sm80 {
        // One stage is always being consumed, so only `stages - 1` are in flight.
        let in_flight = stages.saturating_sub(1);

        b.comment("--- Async pipeline (cp.async) for Ampere+ ---");
        b.comment(&format!("Pipeline depth: {stages} stages"));
        for stage in 0..in_flight {
            b.comment(&format!(" Prologue: async load stage {stage}"));
        }
        for line in ASYNC_STEADY_STATE {
            b.comment(line);
        }
        for stage in 0..in_flight {
            b.comment(&format!(" Drain: compute stage {stage}"));
        }
    } else {
        for line in SYNC_MAINLOOP {
            b.comment(line);
        }
    }

    b.comment("Step 3: Epilogue -- write accumulator to global output");
    emit_bias_epilogue(b);
    b.ret();
}
/// Emits the epilogue: a bounds-guarded read-modify-write of one output
/// element per thread, adding a per-`n` bias value when the bias pointer is
/// non-null.
///
/// Emitted control flow:
/// 1. Skip everything unless `m_coord < gemm_m` and `n_coord < gemm_n`.
/// 2. Load the element at index `n_coord * gemm_m + m_coord` from `output`.
/// 3. If `bias != 0`, load `bias[n_coord]` and add it.
/// 4. Store the value back to the same output address.
///
/// NOTE(review): loads, stores, and the `* 4` byte offsets are hard-coded for
/// 4-byte `f32` elements regardless of the problem's element type — confirm
/// this is intended for non-f32 problems.
fn emit_bias_epilogue(b: &mut oxicuda_ptx::builder::BodyBuilder<'_>) {
b.comment("--- Bias epilogue (guarded per-output-channel add) ---");
// Kernel parameters read by the epilogue.
let bias_ptr = b.load_param_u64("bias");
let output_ptr = b.load_param_u64("output");
let gemm_m = b.load_param_u32("gemm_m");
let gemm_n = b.load_param_u32("gemm_n");
// The global thread ids are used directly as the (m, n) output coordinate.
let m_coord = b.global_thread_id_x();
let n_coord = b.global_thread_id_y();
let skip_epilogue = b.fresh_label("ig_epilogue_skip");
let p_m = b.alloc_reg(PtxType::Pred);
let p_n = b.alloc_reg(PtxType::Pred);
let p_in = b.alloc_reg(PtxType::Pred);
// Bounds guard: `lo` is the unsigned less-than comparison; both must hold.
b.raw_ptx(&format!("setp.lo.u32 {p_m}, {m_coord}, {gemm_m};"));
b.raw_ptx(&format!("setp.lo.u32 {p_n}, {n_coord}, {gemm_n};"));
b.raw_ptx(&format!("and.pred {p_in}, {p_m}, {p_n};"));
b.raw_ptx(&format!("@!{p_in} bra {skip_epilogue};"));
// Linear output index = n_coord * gemm_m + m_coord, computed in 32 bits.
// NOTE(review): the 32-bit multiply can wrap for very large outputs.
let out_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {out_idx}, {n_coord}, {gemm_m};"));
b.raw_ptx(&format!("add.u32 {out_idx}, {out_idx}, {m_coord};"));
let out_idx64 = b.alloc_reg(PtxType::U64);
let out_off = b.alloc_reg(PtxType::U64);
let out_addr = b.alloc_reg(PtxType::U64);
// Widen the index, scale by 4 bytes per element, and add the base pointer.
b.raw_ptx(&format!("cvt.u64.u32 {out_idx64}, {out_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {out_off}, {out_idx64}, 4;"));
b.raw_ptx(&format!("add.u64 {out_addr}, {output_ptr}, {out_off};"));
// The value currently in global output is loaded as the accumulator.
let acc = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("ld.global.f32 {acc}, [{out_addr}];"));
let no_bias = b.fresh_label("ig_no_bias");
let p_has_bias = b.alloc_reg(PtxType::Pred);
// A null bias pointer means "no bias": jump straight to the store.
b.raw_ptx(&format!("setp.ne.u64 {p_has_bias}, {bias_ptr}, 0;"));
b.raw_ptx(&format!("@!{p_has_bias} bra {no_bias};"));
let bias_idx64 = b.alloc_reg(PtxType::U64);
let bias_off = b.alloc_reg(PtxType::U64);
let bias_addr = b.alloc_reg(PtxType::U64);
// bias is indexed by n_coord alone, again at 4 bytes per element.
b.raw_ptx(&format!("cvt.u64.u32 {bias_idx64}, {n_coord};"));
b.raw_ptx(&format!("mul.lo.u64 {bias_off}, {bias_idx64}, 4;"));
b.raw_ptx(&format!("add.u64 {bias_addr}, {bias_ptr}, {bias_off};"));
let bias_val = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("ld.global.f32 {bias_val}, [{bias_addr}];"));
b.raw_ptx(&format!("add.rn.f32 {acc}, {acc}, {bias_val};"));
// Both the bias and no-bias paths converge here and store the result.
b.raw_ptx(&format!("{no_bias}:"));
b.raw_ptx(&format!("st.global.f32 [{out_addr}], {acc};"));
// Target for threads outside the gemm_m x gemm_n range.
b.raw_ptx(&format!("{skip_epilogue}:"));
}
/// Decomposes a GEMM row index `m` into `(batch, out_row, out_col)`.
///
/// `m` is interpreted as `batch * (out_h * out_w) + out_row * out_w + out_col`.
///
/// # Panics
///
/// Panics on division by zero when `out_h * out_w` is zero.
#[inline]
pub fn gemm_m_to_conv_coords(m: u32, out_h: u32, out_w: u32) -> (u32, u32, u32) {
    let plane = out_h * out_w;
    let (batch_idx, within_plane) = (m / plane, m % plane);
    (batch_idx, within_plane / out_w, within_plane % out_w)
}
/// Decomposes a GEMM reduction index `k` into `(channel, filter_row, filter_col)`.
///
/// `k` is interpreted as
/// `channel * (filter_h * filter_w) + filter_row * filter_w + filter_col`.
///
/// # Panics
///
/// Panics on division by zero when `filter_h * filter_w` is zero.
#[inline]
pub fn gemm_k_to_filter_coords(k: u32, filter_h: u32, filter_w: u32) -> (u32, u32, u32) {
    let taps_per_channel = filter_h * filter_w;
    let (channel, tap) = (k / taps_per_channel, k % taps_per_channel);
    (channel, tap / filter_w, tap % filter_w)
}
/// Maps an output position and a filter tap to an input coordinate along one
/// spatial axis.
///
/// Evaluates `out_pos * stride - pad + filter_pos * dilation` and returns it
/// when it lies in `0..input_size`. `None` means the position is outside the
/// input on either side.
///
/// The arithmetic is carried out in a wide signed type, so extreme arguments
/// can neither overflow nor wrap around into a seemingly valid coordinate.
#[inline]
pub fn input_coord(
    out_pos: u32,
    filter_pos: u32,
    pad: u32,
    stride: u32,
    dilation: u32,
    input_size: u32,
) -> Option<u32> {
    // Each product can approach 2^64 and their sum can exceed i64, hence i128.
    let pos = i128::from(out_pos) * i128::from(stride) - i128::from(pad)
        + i128::from(filter_pos) * i128::from(dilation);

    if (0..i128::from(input_size)).contains(&pos) {
        // Lossless: 0 <= pos < input_size <= u32::MAX.
        Some(pos as u32)
    } else {
        None
    }
}
// Unit tests for kernel naming, shared-memory sizing, the CPU-side index
// helpers, and the textual content of the generated PTX. None of them launch
// a kernel.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::TensorLayout;
// Shared fixture: batch 2, 64 -> 128 channels, 32x32 input, 3x3 filter,
// padding 1, stride 1, dilation 1, f32 in and out, NCHW layout.
fn make_problem() -> ConvProblem {
ConvProblem {
batch: 2,
in_channels: 64,
in_dims: vec![32, 32],
out_channels: 128,
filter_dims: vec![3, 3],
padding: vec![1, 1],
stride: vec![1, 1],
dilation: vec![1, 1],
groups: 1,
input_type: PtxType::F32,
output_type: PtxType::F32,
layout: TensorLayout::Nchw,
}
}
// The kernel name carries the fixed prefix and the element type.
#[test]
fn kernel_name_format() {
let conv = ImplicitGemmConv::new(make_problem(), SmVersion::Sm80);
let name = conv.kernel_name();
assert!(name.contains("implicit_gemm_conv"));
assert!(name.contains("f32"));
}
// Implicit GEMM reports no scratch workspace.
#[test]
fn workspace_is_zero() {
let conv = ImplicitGemmConv::new(make_problem(), SmVersion::Sm80);
assert_eq!(conv.workspace_bytes(), 0);
}
// The default tile configuration yields non-empty shared-memory buffers.
#[test]
fn smem_sizes_positive() {
let conv = ImplicitGemmConv::new(make_problem(), SmVersion::Sm80);
assert!(conv.smem_input_elements() > 0);
assert!(conv.smem_filter_elements() > 0);
}
// With a 4x4 output plane: index 5 is row 1, col 1; index 16 starts batch 1.
#[test]
fn gemm_m_to_conv_coords_basic() {
assert_eq!(gemm_m_to_conv_coords(0, 4, 4), (0, 0, 0));
assert_eq!(gemm_m_to_conv_coords(5, 4, 4), (0, 1, 1));
assert_eq!(gemm_m_to_conv_coords(16, 4, 4), (1, 0, 0));
}
// With a 3x3 filter: index 4 is the centre tap; index 9 starts channel 1.
#[test]
fn gemm_k_to_filter_coords_basic() {
assert_eq!(gemm_k_to_filter_coords(0, 3, 3), (0, 0, 0));
assert_eq!(gemm_k_to_filter_coords(4, 3, 3), (0, 1, 1));
assert_eq!(gemm_k_to_filter_coords(9, 3, 3), (1, 0, 0));
}
// 1 * 1 - 1 + 0 * 1 = 0, which is inside the input.
#[test]
fn input_coord_valid() {
assert_eq!(input_coord(1, 0, 1, 1, 1, 32), Some(0));
}
// 0 * 1 - 1 + 0 * 1 = -1, which falls into the leading padding.
#[test]
fn input_coord_padded() {
assert_eq!(input_coord(0, 0, 1, 1, 1, 32), None);
}
// 31 * 1 - 1 + 2 * 1 = 32, which is one past the last input element.
#[test]
fn input_coord_beyond_input() {
assert_eq!(input_coord(31, 2, 1, 1, 1, 32), None);
}
// PTX generation succeeds and produces an entry point with the expected name.
#[test]
fn ptx_generation_produces_output() {
let conv = ImplicitGemmConv::new(make_problem(), SmVersion::Sm80);
let ptx = conv.generate_ptx();
assert!(ptx.is_ok());
let ptx_text = ptx.unwrap_or_default();
assert!(ptx_text.contains("implicit_gemm_conv"));
assert!(ptx_text.contains(".entry"));
}
// Checks that the epilogue's instruction mnemonics appear in the PTX text.
// NOTE(review): the epilogue also emits an `ld.global.f32` for the output
// element, so that substring check alone does not prove the bias is loaded.
#[test]
fn ptx_epilogue_has_guarded_bias_add() {
let conv = ImplicitGemmConv::new(make_problem(), SmVersion::Sm80);
let ptx = conv.generate_ptx().expect("ptx generation");
assert!(
ptx.contains("setp.ne.u64"),
"epilogue must test the bias pointer for null"
);
assert!(
ptx.contains("ld.global.f32"),
"epilogue must load the bias value"
);
assert!(
ptx.contains("add.rn.f32"),
"epilogue must add the bias to the accumulator"
);
assert!(
ptx.contains("st.global.f32"),
"epilogue must store the result"
);
}
// The substring "bias" must occur somewhere in the generated PTX.
#[test]
fn ptx_declares_bias_param() {
let conv = ImplicitGemmConv::new(make_problem(), SmVersion::Sm80);
let ptx = conv.generate_ptx().expect("ptx generation");
assert!(ptx.contains("bias"), "kernel must declare a bias parameter");
}
// CPU-only reference for the bias add, using the same `n * m + mi` indexing
// as the epilogue. It checks the host arithmetic against itself and does not
// execute the kernel.
#[test]
fn bias_epilogue_cpu_reference() {
let out_channels = 4usize;
let m = 6usize; let mut acc: Vec<f32> = (0..out_channels * m)
.map(|i| (i as f32) * 0.25 - 1.0)
.collect();
// Snapshot of the accumulator before the bias is applied.
let pre = acc.clone();
let bias: Vec<f32> = (0..out_channels).map(|c| (c as f32) * 0.5 + 0.1).collect();
// Apply one bias value per output channel across all m positions.
for (n, &bias_n) in bias.iter().enumerate() {
for mi in 0..m {
acc[n * m + mi] += bias_n;
}
}
// Every element must equal its pre-bias value plus its channel's bias.
for (n, &bias_n) in bias.iter().enumerate() {
for mi in 0..m {
let idx = n * m + mi;
let expected = pre[idx] + bias_n;
assert!(
(acc[idx] - expected).abs() < 1e-6,
"bias add mismatch at (n={n}, m={mi})"
);
}
}
}
// CPU-only model of the null-bias path: with no bias, values pass through
// unchanged. Does not execute the kernel.
#[test]
fn no_bias_leaves_accumulator_unchanged() {
let acc: Vec<f32> = vec![1.5, -2.0, 0.0, 3.25];
let null_bias: Option<&[f32]> = None;
let result: Vec<f32> = acc
.iter()
.enumerate()
.map(|(i, &v)| v + null_bias.map_or(0.0, |b| b[i]))
.collect();
assert_eq!(result, acc);
}
}