use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
use super::helpers::{emit_activation, emit_bn_epilogue, parse_precision};
use super::types::{ActivationType, DepthwiseSeparableConfig};
/// Code-generation settings consumed by `emit_pointwise_conv_body`.
///
/// The struct is `Copy` so it can be moved into the kernel-body closure
/// without borrowing from the caller's configuration.
#[derive(Clone, Copy, Debug)]
struct PointwiseBodyParams {
    /// PTX type used for the accumulator and for every loaded/stored element.
    float_type: PtxType,
    /// Multiplier applied to element indices to obtain byte offsets.
    elem_bytes: u32,
    /// Activation handed to `emit_activation` once accumulation is done.
    activation: ActivationType,
    /// When `true`, `emit_bn_epilogue` is emitted before the activation.
    has_bn: bool,
}
/// Generates PTX for the pointwise (1x1) convolution stage of a
/// depthwise-separable convolution.
///
/// The kernel is named `depthwise_separable_pw_{suffix}_{precision}`, where
/// `suffix` comes from `config.pointwise_activation.kernel_suffix()`.
///
/// Kernel parameters, in declaration order:
/// `input`, `weight`, `bias`, `output` (all `u64`), then `in_channels`,
/// `out_channels`, `spatial_size`, `total_outputs` (all `u32`). When
/// `config.pointwise_bn` is set, four further `u64` parameters follow:
/// `bn_gamma`, `bn_beta`, `bn_mean`, `bn_var`.
///
/// # Errors
///
/// Propagates the error returned by `parse_precision` for `precision`, and
/// returns [`DnnError::PtxGeneration`] when the kernel builder fails.
pub fn generate_pointwise_conv_ptx(
    config: &DepthwiseSeparableConfig,
    precision: &str,
    sm_version: SmVersion,
) -> DnnResult<String> {
    let (float_type, elem_bytes) = parse_precision(precision)?;
    let act_suffix = config.pointwise_activation.kernel_suffix();
    // NOTE(review): the name does not encode `pointwise_bn`, although the
    // parameter list differs with it — confirm callers do not key on the
    // name alone.
    let kernel_name = format!("depthwise_separable_pw_{act_suffix}_{precision}");
    let in_channels = config.depthwise_out_channels() as u32;
    let params = PointwiseBodyParams {
        float_type,
        elem_bytes,
        activation: config.pointwise_activation,
        has_bn: config.pointwise_bn,
    };

    let mut builder = KernelBuilder::new(&kernel_name)
        .target(sm_version)
        .param("input", PtxType::U64)
        .param("weight", PtxType::U64)
        .param("bias", PtxType::U64)
        .param("output", PtxType::U64)
        .param("in_channels", PtxType::U32)
        .param("out_channels", PtxType::U32)
        .param("spatial_size", PtxType::U32)
        .param("total_outputs", PtxType::U32);

    // Batch-norm tensors are only part of the signature when BN is fused in.
    if config.pointwise_bn {
        builder = builder
            .param("bn_gamma", PtxType::U64)
            .param("bn_beta", PtxType::U64)
            .param("bn_mean", PtxType::U64)
            .param("bn_var", PtxType::U64);
    }

    let ptx = builder
        .body(move |b| {
            // `params` is `Copy`, so the closure owns its own copy and lends
            // it to the emitter by reference.
            emit_pointwise_conv_body(b, in_channels, &params);
        })
        .build()
        .map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
    Ok(ptx)
}
fn emit_pointwise_conv_body(
b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
_compile_time_in_ch: u32,
p: &PointwiseBodyParams,
) {
let float_type = p.float_type;
let elem_bytes = p.elem_bytes;
b.comment("=== Pointwise (1x1) Convolution Stage ===");
b.comment("Each thread: one output element (batch*h*w, oc).");
let gid = b.global_thread_id_x();
let total = b.load_param_u32("total_outputs");
let pred = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lo.u32 {pred}, {gid}, {total};"));
let exit_label = b.fresh_label("pw_exit");
b.raw_ptx(&format!("@!{pred} bra {exit_label};"));
let input_ptr = b.load_param_u64("input");
let weight_ptr = b.load_param_u64("weight");
let output_ptr = b.load_param_u64("output");
let in_channels = b.load_param_u32("in_channels");
let out_channels = b.load_param_u32("out_channels");
let spatial_size = b.load_param_u32("spatial_size");
b.comment("Decompose gid -> (spatial_idx, oc) where linear = spatial_idx * out_channels + oc");
let oc = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {oc}, {gid}, {out_channels};"));
let spatial_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {spatial_idx}, {gid}, {out_channels};"));
let _ = spatial_size;
let acc = b.alloc_reg(float_type);
if float_type == PtxType::F32 {
let zero = 0f32.to_bits();
b.raw_ptx(&format!("mov.b32 {acc}, 0F{zero:08X};"));
} else {
let zero = 0f64.to_bits();
b.raw_ptx(&format!("mov.b64 {acc}, 0D{zero:016X};"));
}
let ic = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {ic}, 0;"));
let loop_label = b.fresh_label("pw_loop");
let loop_end = b.fresh_label("pw_loop_end");
b.label(&loop_label);
let pred_loop = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lo.u32 {pred_loop}, {ic}, {in_channels};"));
b.raw_ptx(&format!("@!{pred_loop} bra {loop_end};"));
let in_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"mul.lo.u32 {in_idx}, {spatial_idx}, {in_channels};"
));
b.raw_ptx(&format!("add.u32 {in_idx}, {in_idx}, {ic};"));
let idx64 = b.alloc_reg(PtxType::U64);
let off64 = b.alloc_reg(PtxType::U64);
let addr = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {in_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {elem_bytes};"));
b.raw_ptx(&format!("add.u64 {addr}, {input_ptr}, {off64};"));
let in_val = b.alloc_reg(float_type);
if float_type == PtxType::F32 {
b.raw_ptx(&format!("ld.global.f32 {in_val}, [{addr}];"));
} else {
b.raw_ptx(&format!("ld.global.f64 {in_val}, [{addr}];"));
}
let w_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {w_idx}, {oc}, {in_channels};"));
b.raw_ptx(&format!("add.u32 {w_idx}, {w_idx}, {ic};"));
let w_idx64 = b.alloc_reg(PtxType::U64);
let w_off64 = b.alloc_reg(PtxType::U64);
let w_addr = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {w_idx64}, {w_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {w_off64}, {w_idx64}, {elem_bytes};"));
b.raw_ptx(&format!("add.u64 {w_addr}, {weight_ptr}, {w_off64};"));
let w_val = b.alloc_reg(float_type);
if float_type == PtxType::F32 {
b.raw_ptx(&format!("ld.global.f32 {w_val}, [{w_addr}];"));
b.raw_ptx(&format!("fma.rn.f32 {acc}, {in_val}, {w_val}, {acc};"));
} else {
b.raw_ptx(&format!("ld.global.f64 {w_val}, [{w_addr}];"));
b.raw_ptx(&format!("fma.rn.f64 {acc}, {in_val}, {w_val}, {acc};"));
}
b.raw_ptx(&format!("add.u32 {ic}, {ic}, 1;"));
b.raw_ptx(&format!("bra {loop_label};"));
b.label(&loop_end);
if p.has_bn {
emit_bn_epilogue(b, float_type, elem_bytes, &acc, oc);
}
let activated = emit_activation(b, float_type, acc, p.activation);
b.comment("Store pointwise output");
let out_idx64 = b.alloc_reg(PtxType::U64);
let out_off64 = b.alloc_reg(PtxType::U64);
let out_addr = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {out_idx64}, {gid};"));
b.raw_ptx(&format!(
"mul.lo.u64 {out_off64}, {out_idx64}, {elem_bytes};"
));
b.raw_ptx(&format!("add.u64 {out_addr}, {output_ptr}, {out_off64};"));
if float_type == PtxType::F32 {
b.raw_ptx(&format!("st.global.f32 [{out_addr}], {activated};"));
} else {
b.raw_ptx(&format!("st.global.f64 [{out_addr}], {activated};"));
}
b.label(&exit_label);
b.ret();
}