use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::types::{TensorDesc, TensorDescMut};
use super::super::descriptor::ConvProblem;
/// Convolution engine specialised for problems that `ConvProblem::is_1x1`
/// accepts.
///
/// The engine derives GEMM dimensions from the problem via
/// `ConvProblem::conv_to_gemm_dims`; see [`Conv1x1::execute`] for the current
/// state of the implementation.
pub struct Conv1x1 {
    /// Convolution problem this engine was constructed for.
    problem: ConvProblem,
    /// Target architecture. Stored at construction but not read by any method
    /// of this type, which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    sm_version: SmVersion,
}

impl Conv1x1 {
    /// Builds a 1x1 convolution engine for `problem` targeting `sm_version`.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidArgument`] when `problem.is_1x1()` is false.
    pub fn new(problem: ConvProblem, sm_version: SmVersion) -> DnnResult<Self> {
        if problem.is_1x1() {
            Ok(Self {
                problem,
                sm_version,
            })
        } else {
            Err(DnnError::InvalidArgument(
                "Conv1x1 requires 1x1 filter with unit stride/dilation".into(),
            ))
        }
    }

    /// Runs the convolution described by this engine.
    ///
    /// The GEMM dimensions are derived from the problem so that an invalid
    /// problem is reported to the caller. Apart from that derivation, the
    /// arguments are accepted and discarded.
    ///
    /// NOTE(review): no GEMM or kernel is dispatched by this method and
    /// `output` is not written — this looks like a placeholder; confirm before
    /// relying on results.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ConvProblem::conv_to_gemm_dims`.
    pub fn execute<T: GpuFloat>(
        &self,
        handle: &DnnHandle,
        input: &TensorDesc<T>,
        filter: &TensorDesc<T>,
        output: &mut TensorDescMut<T>,
    ) -> DnnResult<()> {
        // Computed only for validation; the values themselves are unused.
        let (_gemm_m, _gemm_n, _gemm_k) = self.problem.conv_to_gemm_dims()?;

        // Explicitly discard the remaining arguments so they do not trigger
        // unused-variable warnings.
        let _ = (handle, input, filter, output);

        Ok(())
    }

    /// Scratch memory required by this engine, in bytes. Always `0`.
    #[must_use]
    pub fn workspace_bytes(&self) -> usize {
        0
    }
}
/// Depthwise convolution engine that generates a PTX kernel for its problem
/// and launches it on the handle's stream.
pub struct DepthwiseConv {
    /// Convolution problem this engine was constructed for.
    problem: ConvProblem,
    /// Target architecture handed to the PTX kernel builder.
    sm_version: SmVersion,
}

impl DepthwiseConv {
    /// Builds a depthwise convolution engine for `problem` targeting
    /// `sm_version`.
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::InvalidArgument`] when `problem.is_depthwise()` is
    /// false.
    pub fn new(problem: ConvProblem, sm_version: SmVersion) -> DnnResult<Self> {
        if !problem.is_depthwise() {
            return Err(DnnError::InvalidArgument(
                "DepthwiseConv requires groups == in_channels == out_channels".into(),
            ));
        }
        Ok(Self {
            problem,
            sm_version,
        })
    }

    /// Name of the generated kernel: `depthwise_conv_{R}x{S}_{prec}`.
    ///
    /// `R` and `S` are the first and second entries of `filter_dims` (`0` when
    /// an entry is missing). `prec` is the input type's PTX string with any
    /// leading `.` removed.
    #[must_use]
    pub fn kernel_name(&self) -> String {
        let prec = self.problem.input_type.as_ptx_str().trim_start_matches('.');
        let r = self.problem.filter_dims.first().copied().unwrap_or(0);
        let s = self.problem.filter_dims.get(1).copied().unwrap_or(0);
        format!("depthwise_conv_{r}x{s}_{prec}")
    }

    /// Generates the PTX text for this engine's kernel.
    ///
    /// The parameter list declared here must stay in the same order as the
    /// argument tuple assembled in [`DepthwiseConv::execute`].
    ///
    /// # Errors
    ///
    /// Returns [`DnnError::PtxGeneration`] carrying the builder's error message
    /// when the kernel cannot be built.
    pub fn generate_ptx(&self) -> DnnResult<String> {
        KernelBuilder::new(&self.kernel_name())
            .target(self.sm_version)
            .param("input", PtxType::U64)
            .param("filter", PtxType::U64)
            .param("output", PtxType::U64)
            .param("bias", PtxType::U64)
            .param("batch_size", PtxType::U32)
            .param("channels", PtxType::U32)
            .param("in_h", PtxType::U32)
            .param("in_w", PtxType::U32)
            .param("filter_h", PtxType::U32)
            .param("filter_w", PtxType::U32)
            .param("out_h", PtxType::U32)
            .param("out_w", PtxType::U32)
            .param("pad_h", PtxType::U32)
            .param("pad_w", PtxType::U32)
            .param("stride_h", PtxType::U32)
            .param("stride_w", PtxType::U32)
            .param("dilation_h", PtxType::U32)
            .param("dilation_w", PtxType::U32)
            .param("total_outputs", PtxType::U32)
            .body(|b| emit_depthwise_body(b))
            .build()
            .map_err(|e| DnnError::PtxGeneration(e.to_string()))
    }

    /// Generates the kernel, loads it, and launches it with one thread per
    /// output element (`batch * in_channels * out_h * out_w`) in blocks of 256.
    ///
    /// PTX is regenerated and a module is loaded on every call; nothing is
    /// cached by this method. No bias buffer is supplied: `0` is passed for the
    /// kernel's `bias` parameter.
    ///
    /// For each per-dimension vector, the first entry is required. A missing
    /// second entry falls back to `1` (`0` for padding).
    ///
    /// # Errors
    ///
    /// * Any error from [`DepthwiseConv::generate_ptx`], `Module::from_ptx`,
    ///   `Kernel::from_module`, or `ConvProblem::output_dims` is propagated.
    /// * [`DnnError::InvalidArgument`] when a per-dimension vector is empty or
    ///   when the output element count overflows.
    /// * [`DnnError::LaunchFailed`] when the kernel launch fails.
    pub fn execute<T: GpuFloat>(
        &self,
        handle: &DnnHandle,
        input: &TensorDesc<T>,
        filter: &TensorDesc<T>,
        output: &mut TensorDescMut<T>,
    ) -> DnnResult<()> {
        let problem = &self.problem;

        let ptx = self.generate_ptx()?;
        let module = Arc::new(Module::from_ptx(&ptx)?);
        let kernel = Kernel::from_module(module, &self.kernel_name())?;

        let out_dims = problem.output_dims()?;
        let out_h = out_dims.first().copied().unwrap_or(1);
        let out_w = out_dims.get(1).copied().unwrap_or(1);

        // Checked arithmetic: a wrapped product would silently launch too few
        // threads and leave part of the output unwritten.
        let total_outputs = problem
            .batch
            .checked_mul(problem.in_channels)
            .and_then(|n| n.checked_mul(out_h))
            .and_then(|n| n.checked_mul(out_w))
            .ok_or_else(|| {
                DnnError::InvalidArgument(
                    "DepthwiseConv output element count overflows".into(),
                )
            })?;

        // Report malformed descriptors as errors instead of panicking on `[0]`.
        let in_h = Self::first_dim(
            &problem.in_dims,
            "DepthwiseConv requires at least one input spatial dimension",
        )?;
        let filter_h = Self::first_dim(
            &problem.filter_dims,
            "DepthwiseConv requires at least one filter dimension",
        )?;
        let pad_h = Self::first_dim(
            &problem.padding,
            "DepthwiseConv requires at least one padding entry",
        )?;
        let stride_h = Self::first_dim(
            &problem.stride,
            "DepthwiseConv requires at least one stride entry",
        )?;
        let dilation_h = Self::first_dim(
            &problem.dilation,
            "DepthwiseConv requires at least one dilation entry",
        )?;

        let block_size = 256u32;
        let grid = grid_size_for(total_outputs, block_size);
        let params = LaunchParams::new(grid, block_size);

        // Order mirrors the `.param(..)` declarations in `generate_ptx`.
        let args = (
            input.ptr,
            filter.ptr,
            output.ptr,
            0u64, // bias: none supplied
            problem.batch,
            problem.in_channels,
            in_h,
            problem.in_dims.get(1).copied().unwrap_or(1),
            filter_h,
            problem.filter_dims.get(1).copied().unwrap_or(1),
            out_h,
            out_w,
            pad_h,
            problem.padding.get(1).copied().unwrap_or(0),
            stride_h,
            problem.stride.get(1).copied().unwrap_or(1),
            dilation_h,
            problem.dilation.get(1).copied().unwrap_or(1),
            total_outputs,
        );

        kernel
            .launch(&params, handle.stream(), &args)
            .map_err(|e| DnnError::LaunchFailed(e.to_string()))?;
        Ok(())
    }

    /// Scratch memory required by this engine, in bytes. Always `0`.
    #[must_use]
    pub fn workspace_bytes(&self) -> usize {
        0
    }

    /// Returns the first entry of `dims`, or `InvalidArgument(missing)` when
    /// `dims` is empty.
    fn first_dim<D: Copy>(dims: &[D], missing: &'static str) -> DnnResult<D> {
        dims.first()
            .copied()
            .ok_or_else(|| DnnError::InvalidArgument(missing.into()))
    }
}
/// Emits the body of the depthwise convolution kernel into `b`.
///
/// The emitted sequence is: two header comments, a read of the global thread
/// id and of the `total_outputs` parameter, a region guarded by
/// `gid < total_outputs`, and a final return.
///
/// Inside the guarded region this function emits only PTX comments outlining
/// the intended algorithm; it does not itself emit any load, multiply-add, or
/// store instructions.
fn emit_depthwise_body(b: &mut oxicuda_ptx::builder::BodyBuilder<'_>) {
    // Outline written (as comments) inside the bounds guard, in order.
    const OUTLINE: [&str; 10] = [
        "Decompose linear index -> (batch, channel, oh, ow)",
        "Load filter weights into registers (small kernel)",
        "Nested loop over filter dimensions:",
        " for r in 0..R:",
        " for s in 0..S:",
        " ih = oh * stride_h - pad_h + r * dilation_h",
        " iw = ow * stride_w - pad_w + s * dilation_w",
        " if 0 <= ih < H && 0 <= iw < W:",
        " acc += input[batch, channel, ih, iw] * filter[channel, r, s]",
        "Store acc + bias to output[batch, channel, oh, ow]",
    ];

    b.comment("=== Depthwise Convolution ===");
    b.comment("Each thread computes one output pixel for one channel.");

    let thread_index = b.global_thread_id_x();
    let output_count = b.load_param_u32("total_outputs");

    b.if_lt_u32(thread_index, output_count, |guarded| {
        for line in OUTLINE {
            guarded.comment(line);
        }
    });

    b.ret();
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TensorLayout;

    /// Fixture: a 1x1-filter problem with unit stride and dilation, no
    /// padding, and a single group.
    fn make_1x1_problem() -> ConvProblem {
        ConvProblem {
            batch: 2,
            in_channels: 256,
            in_dims: vec![16, 16],
            out_channels: 512,
            filter_dims: vec![1, 1],
            padding: vec![0, 0],
            stride: vec![1, 1],
            dilation: vec![1, 1],
            groups: 1,
            input_type: PtxType::F32,
            output_type: PtxType::F32,
            layout: TensorLayout::Nchw,
        }
    }

    /// Fixture: a 3x3 problem where `groups`, `in_channels`, and
    /// `out_channels` are all 64.
    fn make_depthwise_problem() -> ConvProblem {
        ConvProblem {
            batch: 1,
            in_channels: 64,
            in_dims: vec![32, 32],
            out_channels: 64,
            filter_dims: vec![3, 3],
            padding: vec![1, 1],
            stride: vec![1, 1],
            dilation: vec![1, 1],
            groups: 64,
            input_type: PtxType::F32,
            output_type: PtxType::F32,
            layout: TensorLayout::Nchw,
        }
    }

    /// A 3x3 filter must be refused by the 1x1 engine.
    #[test]
    fn conv1x1_rejects_non_1x1() {
        let mut problem = make_1x1_problem();
        problem.filter_dims = vec![3, 3];
        let built = Conv1x1::new(problem, SmVersion::Sm80);
        assert!(built.is_err());
    }

    /// Construction succeeds and the engine reports no workspace.
    #[test]
    fn conv1x1_workspace_zero() {
        let bytes = Conv1x1::new(make_1x1_problem(), SmVersion::Sm80)
            .map(|conv| conv.workspace_bytes())
            .ok();
        assert_eq!(bytes, Some(0));
    }

    /// With `groups` reset to 1 the problem is no longer depthwise.
    #[test]
    fn depthwise_rejects_non_depthwise() {
        let mut problem = make_depthwise_problem();
        problem.groups = 1;
        let built = DepthwiseConv::new(problem, SmVersion::Sm80);
        assert!(built.is_err());
    }

    /// The kernel name encodes filter extent and precision.
    #[test]
    fn depthwise_kernel_name() {
        let name = DepthwiseConv::new(make_depthwise_problem(), SmVersion::Sm80)
            .map(|conv| conv.kernel_name())
            .ok();
        assert_eq!(name.as_deref(), Some("depthwise_conv_3x3_f32"));
    }

    /// Construction succeeds and the engine reports no workspace.
    #[test]
    fn depthwise_workspace_zero() {
        let bytes = DepthwiseConv::new(make_depthwise_problem(), SmVersion::Sm80)
            .map(|conv| conv.workspace_bytes())
            .ok();
        assert_eq!(bytes, Some(0));
    }

    /// PTX generation succeeds and the text mentions the kernel name prefix.
    #[test]
    fn depthwise_ptx_generation() {
        let engine = DepthwiseConv::new(make_depthwise_problem(), SmVersion::Sm80);
        assert!(engine.is_ok());

        let ptx = engine.and_then(|conv| conv.generate_ptx());
        assert!(ptx.is_ok());

        let text = ptx.unwrap_or_default();
        assert!(text.contains("depthwise_conv"));
    }
}