use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{DnnError, DnnResult};
/// Configuration of a deformable convolution and of the PTX kernels generated for it.
///
/// Use `validate` to check the fields before generating code.
#[derive(Debug, Clone)]
pub struct DeformableConvConfig {
/// Number of input channels. `validate` requires it to be non-zero and
/// divisible by `offset_groups`.
pub in_channels: u32,
/// Number of output channels. `validate` requires it to be non-zero.
pub out_channels: u32,
/// Kernel height in taps. `validate` requires it to be non-zero.
pub kernel_h: u32,
/// Kernel width in taps. `validate` requires it to be non-zero.
pub kernel_w: u32,
/// Vertical stride. `validate` requires it to be non-zero.
pub stride_h: u32,
/// Horizontal stride. `validate` requires it to be non-zero.
pub stride_w: u32,
/// Vertical padding; `output_size` counts it twice (`in_h + 2 * pad_h`).
pub pad_h: u32,
/// Horizontal padding; `output_size` counts it twice (`in_w + 2 * pad_w`).
pub pad_w: u32,
/// Vertical dilation (spacing between kernel rows). `validate` requires it to be non-zero.
pub dilation_h: u32,
/// Horizontal dilation (spacing between kernel columns). `validate` requires it to be non-zero.
pub dilation_w: u32,
/// Number of offset groups; `in_channels / offset_groups` consecutive input
/// channels share one group index. `validate` requires it to be non-zero.
pub offset_groups: u32,
/// When `true`, the generated forward kernel multiplies every contribution by
/// a value loaded from the `mask` buffer.
pub use_modulation: bool,
/// Target architecture handed to `KernelBuilder::target`.
pub sm_version: SmVersion,
/// Element type of the tensors; `validate` accepts only `PtxType::F32` and `PtxType::F16`.
pub float_type: PtxType,
}
impl DeformableConvConfig {
    /// Returns the `(out_h, out_w)` spatial size produced for an input of
    /// `in_h` x `in_w`.
    ///
    /// Per axis this is `(in + 2 * pad - effective_kernel) / stride + 1`, or
    /// `0` when the padded input is smaller than the effective (dilated)
    /// kernel. A zero stride is treated as `1`. The computation is carried out
    /// in 64-bit arithmetic so that extreme field values can neither panic nor
    /// wrap; a result that does not fit in `u32` is clamped to `u32::MAX`.
    #[must_use]
    pub fn output_size(&self, in_h: u32, in_w: u32) -> (u32, u32) {
        let out_h = Self::output_extent(
            in_h,
            self.pad_h,
            self.kernel_h,
            self.dilation_h,
            self.stride_h,
        );
        let out_w = Self::output_extent(
            in_w,
            self.pad_w,
            self.kernel_w,
            self.dilation_w,
            self.stride_w,
        );
        (out_h, out_w)
    }

    /// Output extent along one axis; shared by both axes of `output_size`.
    fn output_extent(input: u32, pad: u32, kernel: u32, dilation: u32, stride: u32) -> u32 {
        // Every intermediate fits in u64: the largest product is below 2^64
        // and the padded size is below 3 * 2^32.
        let effective_kernel = u64::from(dilation) * u64::from(kernel.saturating_sub(1)) + 1;
        let padded = u64::from(input) + 2 * u64::from(pad);
        match padded.checked_sub(effective_kernel) {
            Some(span) => {
                let extent = span / u64::from(stride.max(1)) + 1;
                // Clamped first, so the narrowing cast cannot truncate.
                extent.min(u64::from(u32::MAX)) as u32
            }
            None => 0,
        }
    }

    /// Checks that the configuration describes a kernel this module can generate.
    ///
    /// # Errors
    ///
    /// Returns `DnnError::InvalidArgument` when a kernel dimension, stride,
    /// dilation, channel count or `offset_groups` is zero, when `in_channels`
    /// is not divisible by `offset_groups`, or when `float_type` is neither
    /// `PtxType::F32` nor `PtxType::F16`.
    pub fn validate(&self) -> DnnResult<()> {
        if self.kernel_h == 0 || self.kernel_w == 0 {
            return Err(DnnError::InvalidArgument(
                "deformable conv: kernel dimensions must be > 0".into(),
            ));
        }
        if self.stride_h == 0 || self.stride_w == 0 {
            return Err(DnnError::InvalidArgument(
                "deformable conv: stride must be > 0".into(),
            ));
        }
        if self.dilation_h == 0 || self.dilation_w == 0 {
            return Err(DnnError::InvalidArgument(
                "deformable conv: dilation must be > 0".into(),
            ));
        }
        if self.in_channels == 0 || self.out_channels == 0 {
            return Err(DnnError::InvalidArgument(
                "deformable conv: channel counts must be > 0".into(),
            ));
        }
        // The zero check must precede the remainder below (division by zero).
        if self.offset_groups == 0 {
            return Err(DnnError::InvalidArgument(
                "deformable conv: offset_groups must be > 0".into(),
            ));
        }
        if self.in_channels % self.offset_groups != 0 {
            return Err(DnnError::InvalidArgument(format!(
                "deformable conv: in_channels ({}) not divisible by offset_groups ({})",
                self.in_channels, self.offset_groups
            )));
        }
        if !matches!(self.float_type, PtxType::F32 | PtxType::F16) {
            return Err(DnnError::InvalidArgument(format!(
                "deformable conv: unsupported float_type {:?}, expected F32 or F16",
                self.float_type
            )));
        }
        Ok(())
    }

    /// Number of input channels that share one offset group
    /// (`in_channels / offset_groups`), or `0` when `offset_groups` is zero.
    #[must_use]
    pub fn channels_per_offset_group(&self) -> u32 {
        self.in_channels
            .checked_div(self.offset_groups)
            .unwrap_or(0)
    }

    /// Number of channels in the offset tensor:
    /// `2 * kernel_h * kernel_w * offset_groups` (one h and one w offset per
    /// kernel tap and offset group). Saturates at `u32::MAX` instead of
    /// overflowing.
    #[must_use]
    pub fn offset_channels(&self) -> u32 {
        self.mask_channels().saturating_mul(2)
    }

    /// Number of channels in the modulation-mask tensor:
    /// `kernel_h * kernel_w * offset_groups`. Saturates at `u32::MAX` instead
    /// of overflowing.
    #[must_use]
    pub fn mask_channels(&self) -> u32 {
        self.kernel_h
            .saturating_mul(self.kernel_w)
            .saturating_mul(self.offset_groups)
    }

    /// Height covered by the dilated kernel: `dilation_h * (kernel_h - 1) + 1`.
    /// Saturates at `u32::MAX` instead of overflowing.
    #[must_use]
    pub fn effective_kernel_h(&self) -> u32 {
        self.dilation_h
            .saturating_mul(self.kernel_h.saturating_sub(1))
            .saturating_add(1)
    }

    /// Width covered by the dilated kernel: `dilation_w * (kernel_w - 1) + 1`.
    /// Saturates at `u32::MAX` instead of overflowing.
    #[must_use]
    pub fn effective_kernel_w(&self) -> u32 {
        self.dilation_w
            .saturating_mul(self.kernel_w.saturating_sub(1))
            .saturating_add(1)
    }
}
/// A deformable-convolution configuration that passed validation, from which
/// the PTX kernels are generated.
#[derive(Debug)]
pub struct DeformableConvPlan {
// Configuration accepted by `DeformableConvConfig::validate` in `new`.
config: DeformableConvConfig,
}
impl DeformableConvPlan {
pub fn new(config: DeformableConvConfig) -> DnnResult<Self> {
config.validate()?;
Ok(Self { config })
}
#[must_use]
pub fn config(&self) -> &DeformableConvConfig {
&self.config
}
#[must_use]
pub fn output_size(&self, in_h: u32, in_w: u32) -> (u32, u32) {
self.config.output_size(in_h, in_w)
}
pub fn generate_forward(&self) -> DnnResult<String> {
let c = &self.config;
let (float_type, elem_bytes, precision) = float_type_info(c.float_type)?;
let kernel_name = format!(
"deformable_conv_forward_{precision}_{}x{}",
c.kernel_h, c.kernel_w
);
let sm = c.sm_version;
let params = DeformableBodyParams {
float_type,
elem_bytes,
kernel_h: c.kernel_h,
kernel_w: c.kernel_w,
stride_h: c.stride_h,
stride_w: c.stride_w,
pad_h: c.pad_h,
pad_w: c.pad_w,
dilation_h: c.dilation_h,
dilation_w: c.dilation_w,
offset_groups: c.offset_groups,
use_modulation: c.use_modulation,
channels_per_offset_group: c.channels_per_offset_group(),
};
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("input", PtxType::U64)
.param("offset", PtxType::U64)
.param("mask", PtxType::U64)
.param("weight", PtxType::U64)
.param("bias", PtxType::U64)
.param("output", PtxType::U64)
.param("batch_size", PtxType::U32)
.param("in_channels", PtxType::U32)
.param("in_h", PtxType::U32)
.param("in_w", PtxType::U32)
.param("out_channels", PtxType::U32)
.param("out_h", PtxType::U32)
.param("out_w", PtxType::U32)
.param("total_outputs", PtxType::U32)
.body(move |b| {
emit_forward_body(b, ¶ms);
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn generate_backward_input(&self) -> DnnResult<String> {
let c = &self.config;
let (float_type, elem_bytes, precision) = float_type_info(c.float_type)?;
let kernel_name = format!(
"deformable_conv_backward_input_{precision}_{}x{}",
c.kernel_h, c.kernel_w
);
let sm = c.sm_version;
let params = DeformableBodyParams {
float_type,
elem_bytes,
kernel_h: c.kernel_h,
kernel_w: c.kernel_w,
stride_h: c.stride_h,
stride_w: c.stride_w,
pad_h: c.pad_h,
pad_w: c.pad_w,
dilation_h: c.dilation_h,
dilation_w: c.dilation_w,
offset_groups: c.offset_groups,
use_modulation: c.use_modulation,
channels_per_offset_group: c.channels_per_offset_group(),
};
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("grad_output", PtxType::U64)
.param("offset", PtxType::U64)
.param("mask", PtxType::U64)
.param("weight", PtxType::U64)
.param("grad_input", PtxType::U64)
.param("batch_size", PtxType::U32)
.param("in_channels", PtxType::U32)
.param("in_h", PtxType::U32)
.param("in_w", PtxType::U32)
.param("out_channels", PtxType::U32)
.param("out_h", PtxType::U32)
.param("out_w", PtxType::U32)
.param("total_outputs", PtxType::U32)
.body(move |b| {
emit_backward_input_body(b, ¶ms);
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn generate_backward_offset(&self) -> DnnResult<String> {
let c = &self.config;
let (float_type, elem_bytes, precision) = float_type_info(c.float_type)?;
let kernel_name = format!(
"deformable_conv_backward_offset_{precision}_{}x{}",
c.kernel_h, c.kernel_w
);
let sm = c.sm_version;
let params = DeformableBodyParams {
float_type,
elem_bytes,
kernel_h: c.kernel_h,
kernel_w: c.kernel_w,
stride_h: c.stride_h,
stride_w: c.stride_w,
pad_h: c.pad_h,
pad_w: c.pad_w,
dilation_h: c.dilation_h,
dilation_w: c.dilation_w,
offset_groups: c.offset_groups,
use_modulation: c.use_modulation,
channels_per_offset_group: c.channels_per_offset_group(),
};
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("grad_output", PtxType::U64)
.param("input", PtxType::U64)
.param("offset", PtxType::U64)
.param("mask", PtxType::U64)
.param("weight", PtxType::U64)
.param("grad_offset", PtxType::U64)
.param("grad_mask", PtxType::U64)
.param("batch_size", PtxType::U32)
.param("in_channels", PtxType::U32)
.param("in_h", PtxType::U32)
.param("in_w", PtxType::U32)
.param("out_channels", PtxType::U32)
.param("out_h", PtxType::U32)
.param("out_w", PtxType::U32)
.param("total_outputs", PtxType::U32)
.body(move |b| {
emit_backward_offset_body(b, ¶ms);
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn generate_backward_weight(&self) -> DnnResult<String> {
let c = &self.config;
let (float_type, elem_bytes, precision) = float_type_info(c.float_type)?;
let kernel_name = format!(
"deformable_conv_backward_weight_{precision}_{}x{}",
c.kernel_h, c.kernel_w
);
let sm = c.sm_version;
let params = DeformableBodyParams {
float_type,
elem_bytes,
kernel_h: c.kernel_h,
kernel_w: c.kernel_w,
stride_h: c.stride_h,
stride_w: c.stride_w,
pad_h: c.pad_h,
pad_w: c.pad_w,
dilation_h: c.dilation_h,
dilation_w: c.dilation_w,
offset_groups: c.offset_groups,
use_modulation: c.use_modulation,
channels_per_offset_group: c.channels_per_offset_group(),
};
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("grad_output", PtxType::U64)
.param("input", PtxType::U64)
.param("offset", PtxType::U64)
.param("mask", PtxType::U64)
.param("grad_weight", PtxType::U64)
.param("batch_size", PtxType::U32)
.param("in_channels", PtxType::U32)
.param("in_h", PtxType::U32)
.param("in_w", PtxType::U32)
.param("out_channels", PtxType::U32)
.param("out_h", PtxType::U32)
.param("out_w", PtxType::U32)
.param("total_weight_elements", PtxType::U32)
.body(move |b| {
emit_backward_weight_body(b, ¶ms);
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
}
/// Copyable snapshot of the configuration values read by the PTX body
/// emitters; it is moved into the `KernelBuilder::body` closures.
#[derive(Debug, Clone, Copy)]
struct DeformableBodyParams {
// Element type of the tensors (F32 or F16 after `float_type_info`).
float_type: PtxType,
// Size of one element in bytes as reported by `float_type_info`
// (4 for F32, 2 for F16); used to scale element indices to byte offsets.
elem_bytes: u32,
// Kernel extent; the emitters unroll over `kernel_h * kernel_w` taps.
kernel_h: u32,
kernel_w: u32,
// Stride, padding and dilation are baked into the PTX as float immediates.
stride_h: u32,
stride_w: u32,
pad_h: u32,
pad_w: u32,
dilation_h: u32,
dilation_w: u32,
// Number of offset groups; enters the offset/mask channel strides.
offset_groups: u32,
// Whether the emitted code multiplies by the modulation mask.
use_modulation: bool,
// `in_channels / offset_groups`; divisor that maps an input channel to its group.
channels_per_offset_group: u32,
}
/// Describes a supported element type as `(type, size in bytes, name suffix)`.
///
/// Only `PtxType::F32` (4 bytes, `"f32"`) and `PtxType::F16` (2 bytes,
/// `"f16"`) are accepted; any other type yields `DnnError::InvalidArgument`.
fn float_type_info(float_type: PtxType) -> DnnResult<(PtxType, u32, &'static str)> {
    let (byte_width, suffix) = match float_type {
        PtxType::F32 => (4, "f32"),
        PtxType::F16 => (2, "f16"),
        other => {
            return Err(DnnError::InvalidArgument(format!(
                "deformable conv: unsupported float_type {other:?}"
            )));
        }
    };
    Ok((float_type, byte_width, suffix))
}
/// Emits a `mov` that loads the constant 0.0 into register `dst`.
///
/// `PtxType::F16` gets a 16-bit move of `0x0000`; every other type is treated
/// as f32 and gets a 32-bit move of the f32 bit pattern of zero.
fn emit_zero(b: &mut oxicuda_ptx::builder::BodyBuilder<'_>, float_type: PtxType, dst: &str) {
    let instruction = if matches!(float_type, PtxType::F16) {
        format!("mov.b16 {dst}, 0x0000;")
    } else {
        format!("mov.b32 {dst}, 0F{:08X};", 0f32.to_bits())
    };
    b.raw_ptx(&instruction);
}
/// Emits a `mov` that loads the constant 1.0 into register `dst`.
///
/// `PtxType::F16` gets a 16-bit move of `0x3C00` (1.0 in half precision);
/// every other type is treated as f32 and gets a 32-bit move of the f32 bit
/// pattern of one.
fn emit_one(b: &mut oxicuda_ptx::builder::BodyBuilder<'_>, float_type: PtxType, dst: &str) {
    let instruction = if matches!(float_type, PtxType::F16) {
        format!("mov.b16 {dst}, 0x3C00;")
    } else {
        format!("mov.b32 {dst}, 0F{:08X};", 1.0f32.to_bits())
    };
    b.raw_ptx(&instruction);
}
/// Type suffix used on arithmetic instructions: `"f16"` for `PtxType::F16`,
/// `"f32"` for everything else.
fn ptx_float_suffix(float_type: PtxType) -> &'static str {
    if matches!(float_type, PtxType::F16) {
        "f16"
    } else {
        "f32"
    }
}
/// Type suffix used on load/store instructions: `"b16"` for `PtxType::F16`
/// (untyped 16-bit access), `"f32"` for everything else.
fn ptx_load_type(float_type: PtxType) -> &'static str {
    if matches!(float_type, PtxType::F16) {
        "b16"
    } else {
        "f32"
    }
}
/// Emits the PTX body of the deformable-convolution forward kernel.
///
/// The emitted code lets each thread handle one output element: the global
/// thread id `gid` is bounds-checked against `total_outputs` and split into
/// `(n, c_out, oh, ow)`. A runtime loop walks the input channels; inside it the
/// `p.kernel_h * p.kernel_w` kernel taps are unrolled at generation time. For
/// every tap the code
///
/// 1. computes the base sampling position
///    `oh * stride_h - pad_h + kh * dilation_h` (and the same for the width),
/// 2. loads the h/w offsets and adds them to the base position,
/// 3. bilinearly samples the input at that position, where corners outside
///    `[0, in_h) x [0, in_w)` contribute nothing,
/// 4. multiplies by the weight and, when `p.use_modulation` is set, by the mask,
/// 5. adds the result to the accumulator.
///
/// After the loop the bias of `c_out` is added and the sum is stored at
/// element index `gid` of `output`.
///
/// For F16 the position arithmetic and the floor are carried out in f32 and
/// converted back to f16.
///
/// NOTE(review): element indices are computed in 32-bit registers and only
/// then widened to 64 bits, so they wrap for tensors with more than 2^32
/// elements — confirm callers stay below that.
/// NOTE(review): the bias is loaded unconditionally; presumably `bias` must
/// always be a valid pointer — confirm against the launch code.
fn emit_forward_body(b: &mut oxicuda_ptx::builder::BodyBuilder<'_>, p: &DeformableBodyParams) {
let ft = p.float_type;
let fs = ptx_float_suffix(ft);
let lt = ptx_load_type(ft);
let eb = p.elem_bytes;
b.comment("=== Deformable Conv Forward (DCNv2) ===");
b.comment("Each thread computes one output element (n, c_out, oh, ow).");
// Threads with gid >= total_outputs jump straight to the exit label.
let gid = b.global_thread_id_x();
let total = b.load_param_u32("total_outputs");
let pred_bounds = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lo.u32 {pred_bounds}, {gid}, {total};"));
let exit_label = b.fresh_label("dcn_fwd_exit");
b.raw_ptx(&format!("@!{pred_bounds} bra {exit_label};"));
// Kernel parameters. `batch_size` is loaded but not used by the emitted code.
let input_ptr = b.load_param_u64("input");
let offset_ptr = b.load_param_u64("offset");
let mask_ptr = b.load_param_u64("mask");
let weight_ptr = b.load_param_u64("weight");
let bias_ptr = b.load_param_u64("bias");
let output_ptr = b.load_param_u64("output");
let _batch_size = b.load_param_u32("batch_size");
let in_channels = b.load_param_u32("in_channels");
let in_h = b.load_param_u32("in_h");
let in_w = b.load_param_u32("in_w");
let out_channels = b.load_param_u32("out_channels");
let out_h = b.load_param_u32("out_h");
let out_w = b.load_param_u32("out_w");
b.comment("Decompose gid -> (n, c_out, oh, ow)");
// gid = ((n * out_channels + c_out) * out_h + oh) * out_w + ow
let out_hw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {out_hw}, {out_h}, {out_w};"));
let c_out_hw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {c_out_hw}, {out_channels}, {out_hw};"));
let n_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {n_idx}, {gid}, {c_out_hw};"));
let rem1 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {rem1}, {gid}, {c_out_hw};"));
let c_out_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {c_out_idx}, {rem1}, {out_hw};"));
let rem2 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {rem2}, {rem1}, {out_hw};"));
let oh = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {oh}, {rem2}, {out_w};"));
let ow = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {ow}, {rem2}, {out_w};"));
// Accumulator for the output element, initialised to zero.
let acc = b.alloc_reg(ft);
emit_zero(b, ft, &acc.to_string());
let in_hw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {in_hw}, {in_h}, {in_w};"));
// Working registers, allocated once and reused by every unrolled kernel tap.
let h_sample_f = b.alloc_reg(ft);
let w_sample_f = b.alloc_reg(ft);
let h_base_f = b.alloc_reg(ft);
let w_base_f = b.alloc_reg(ft);
let offset_h_val = b.alloc_reg(ft);
let offset_w_val = b.alloc_reg(ft);
let h_floor_f = b.alloc_reg(ft);
let w_floor_f = b.alloc_reg(ft);
let h_ceil_f = b.alloc_reg(ft);
let w_ceil_f = b.alloc_reg(ft);
let h_frac = b.alloc_reg(ft);
let w_frac = b.alloc_reg(ft);
let one_minus_hf = b.alloc_reg(ft);
let one_minus_wf = b.alloc_reg(ft);
let one_const = b.alloc_reg(ft);
emit_one(b, ft, &one_const.to_string());
let h_floor_i = b.alloc_reg(PtxType::S32);
let w_floor_i = b.alloc_reg(PtxType::S32);
let h_ceil_i = b.alloc_reg(PtxType::S32);
let w_ceil_i = b.alloc_reg(PtxType::S32);
let pred_h0 = b.alloc_reg(PtxType::Pred);
let pred_h1 = b.alloc_reg(PtxType::Pred);
let pred_w0 = b.alloc_reg(PtxType::Pred);
let pred_w1 = b.alloc_reg(PtxType::Pred);
let pred_valid = b.alloc_reg(PtxType::Pred);
// Bilinear weights (w_*) and loaded pixel values (v_*) of the four corners.
let w_tl = b.alloc_reg(ft);
let w_tr = b.alloc_reg(ft);
let w_bl = b.alloc_reg(ft);
let w_br = b.alloc_reg(ft);
let v_tl = b.alloc_reg(ft);
let v_tr = b.alloc_reg(ft);
let v_bl = b.alloc_reg(ft);
let v_br = b.alloc_reg(ft);
let interp_val = b.alloc_reg(ft);
let weight_val = b.alloc_reg(ft);
let mask_val = b.alloc_reg(ft);
let contrib = b.alloc_reg(ft);
let tmp_f = b.alloc_reg(ft);
// NOTE(review): `_tmp_f2` and `_tmp32b` are allocated but never referenced below.
let _tmp_f2 = b.alloc_reg(ft);
let addr64 = b.alloc_reg(PtxType::U64);
let off64 = b.alloc_reg(PtxType::U64);
let idx32 = b.alloc_reg(PtxType::U32);
let idx64 = b.alloc_reg(PtxType::U64);
let tmp32 = b.alloc_reg(PtxType::U32);
let _tmp32b = b.alloc_reg(PtxType::U32);
// Signed copies of the input extent for the signed bounds comparisons.
let in_h_s32 = b.alloc_reg(PtxType::S32);
let in_w_s32 = b.alloc_reg(PtxType::S32);
b.raw_ptx(&format!("mov.u32 {in_h_s32}, {in_h};"));
b.raw_ptx(&format!("mov.u32 {in_w_s32}, {in_w};"));
let zero_s32 = b.alloc_reg(PtxType::S32);
b.raw_ptx(&format!("mov.s32 {zero_s32}, 0;"));
let kh_kw = p.kernel_h * p.kernel_w;
let channels_per_og = p.channels_per_offset_group;
let offset_groups = p.offset_groups;
b.comment("Loop over input channels and kernel positions");
// Output coordinates as floats; the F16 variant converts through f32.
let oh_f = b.alloc_reg(ft);
let ow_f = b.alloc_reg(ft);
if ft == PtxType::F32 {
b.raw_ptx(&format!("cvt.rn.f32.u32 {oh_f}, {oh};"));
b.raw_ptx(&format!("cvt.rn.f32.u32 {ow_f}, {ow};"));
} else {
let tmp_oh = b.alloc_reg(PtxType::F32);
let tmp_ow = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {tmp_oh}, {oh};"));
b.raw_ptx(&format!("cvt.rn.f32.u32 {tmp_ow}, {ow};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {oh_f}, {tmp_oh};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {ow_f}, {tmp_ow};"));
}
// Runtime loop: for c_in in 0..in_channels.
let c_in = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {c_in}, 0;"));
let cin_loop = b.fresh_label("cin_loop");
let cin_done = b.fresh_label("cin_done");
b.raw_ptx(&format!("{cin_loop}:"));
let pred_cin = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lo.u32 {pred_cin}, {c_in}, {in_channels};"));
b.raw_ptx(&format!("@!{pred_cin} bra {cin_done};"));
// Offset group of this input channel: c_in / channels_per_offset_group.
let og_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {og_idx}, {c_in}, {channels_per_og};"));
// The kernel taps are unrolled here, at generation time.
for kh_val in 0..p.kernel_h {
for kw_val in 0..p.kernel_w {
let kpos = kh_val * p.kernel_w + kw_val;
// NOTE(review): `skip_label` is emitted at the end of the tap but nothing branches to it.
let skip_label = b.fresh_label(&format!("fwd_skip_k{kh_val}_{kw_val}"));
b.comment(&format!("Kernel position kh={kh_val}, kw={kw_val}"));
let h_base_val = kh_val * p.dilation_h;
let w_base_val = kw_val * p.dilation_w;
// Base position: h_base = oh * stride_h - pad_h + kh * dilation_h (same for w).
// Stride, padding and tap offset are embedded as f32 hex immediates.
if ft == PtxType::F32 {
let stride_h_bits = (p.stride_h as f32).to_bits();
let stride_w_bits = (p.stride_w as f32).to_bits();
let pad_h_bits = (p.pad_h as f32).to_bits();
let pad_w_bits = (p.pad_w as f32).to_bits();
let h_base_bits = (h_base_val as f32).to_bits();
let w_base_bits = (w_base_val as f32).to_bits();
let stride_h_reg = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {stride_h_reg}, 0F{stride_h_bits:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {h_base_f}, {oh_f}, {stride_h_reg};"));
let pad_h_reg = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {pad_h_reg}, 0F{pad_h_bits:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {h_base_f}, {h_base_f}, {pad_h_reg};"));
let h_off_reg = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {h_off_reg}, 0F{h_base_bits:08X};"));
b.raw_ptx(&format!("add.rn.f32 {h_base_f}, {h_base_f}, {h_off_reg};"));
let stride_w_reg = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {stride_w_reg}, 0F{stride_w_bits:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {w_base_f}, {ow_f}, {stride_w_reg};"));
let pad_w_reg = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {pad_w_reg}, 0F{pad_w_bits:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {w_base_f}, {w_base_f}, {pad_w_reg};"));
let w_off_reg = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {w_off_reg}, 0F{w_base_bits:08X};"));
b.raw_ptx(&format!("add.rn.f32 {w_base_f}, {w_base_f}, {w_off_reg};"));
} else {
// F16: same formula evaluated in f32, result converted to f16.
let tmp_f32a = b.alloc_reg(PtxType::F32);
let tmp_f32b = b.alloc_reg(PtxType::F32);
let oh_f32 = b.alloc_reg(PtxType::F32);
let ow_f32 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.f32.f16 {oh_f32}, {oh_f};"));
b.raw_ptx(&format!("cvt.f32.f16 {ow_f32}, {ow_f};"));
let stride_h_bits = (p.stride_h as f32).to_bits();
let stride_w_bits = (p.stride_w as f32).to_bits();
let pad_h_bits = (p.pad_h as f32).to_bits();
let pad_w_bits = (p.pad_w as f32).to_bits();
let h_base_bits = (h_base_val as f32).to_bits();
let w_base_bits = (w_base_val as f32).to_bits();
let sr = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {sr}, 0F{stride_h_bits:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {tmp_f32a}, {oh_f32}, {sr};"));
let pr = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {pr}, 0F{pad_h_bits:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {tmp_f32a}, {tmp_f32a}, {pr};"));
let hr = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {hr}, 0F{h_base_bits:08X};"));
b.raw_ptx(&format!("add.rn.f32 {tmp_f32a}, {tmp_f32a}, {hr};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {h_base_f}, {tmp_f32a};"));
let sw = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {sw}, 0F{stride_w_bits:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {tmp_f32b}, {ow_f32}, {sw};"));
let pw = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {pw}, 0F{pad_w_bits:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {tmp_f32b}, {tmp_f32b}, {pw};"));
let wr = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {wr}, 0F{w_base_bits:08X};"));
b.raw_ptx(&format!("add.rn.f32 {tmp_f32b}, {tmp_f32b}, {wr};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {w_base_f}, {tmp_f32b};"));
}
// Offset element index:
//   (n * 2*KH*KW*G + 2 * (og * KH*KW + kpos)) * out_hw + oh * out_w + ow
// for the h offset; the w offset sits one out_hw plane further.
let offset_chan_stride = 2 * kh_kw * offset_groups;
b.comment("Load offset_h and offset_w");
let spatial_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {spatial_idx}, {oh}, {out_w};"));
b.raw_ptx(&format!("add.u32 {spatial_idx}, {spatial_idx}, {ow};"));
let og_kpos_2 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {og_kpos_2}, {og_idx}, {kh_kw};"));
b.raw_ptx(&format!("add.u32 {og_kpos_2}, {og_kpos_2}, {kpos};"));
b.raw_ptx(&format!("mul.lo.u32 {og_kpos_2}, {og_kpos_2}, 2;"));
let off_base = b.alloc_reg(PtxType::U32);
let n_offset_stride = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"mul.lo.u32 {n_offset_stride}, {n_idx}, {offset_chan_stride};"
));
b.raw_ptx(&format!(
"add.u32 {off_base}, {n_offset_stride}, {og_kpos_2};"
));
b.raw_ptx(&format!("mul.lo.u32 {off_base}, {off_base}, {out_hw};"));
b.raw_ptx(&format!("add.u32 {off_base}, {off_base}, {spatial_idx};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {off_base};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {offset_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {offset_h_val}, [{addr64}];"));
let off_w_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("add.u32 {off_w_idx}, {off_base}, {out_hw};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {off_w_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {offset_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {offset_w_val}, [{addr64}];"));
// Sampling position = base position + learned offset.
b.raw_ptx(&format!(
"add.rn.{fs} {h_sample_f}, {h_base_f}, {offset_h_val};"
));
b.raw_ptx(&format!(
"add.rn.{fs} {w_sample_f}, {w_base_f}, {offset_w_val};"
));
// Floor of the sampling position (`cvt.rmi` rounds towards minus infinity).
if ft == PtxType::F32 {
b.raw_ptx(&format!("cvt.rmi.f32.f32 {h_floor_f}, {h_sample_f};"));
b.raw_ptx(&format!("cvt.rmi.f32.f32 {w_floor_f}, {w_sample_f};"));
} else {
let tmp_h32 = b.alloc_reg(PtxType::F32);
let tmp_w32 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.f32.f16 {tmp_h32}, {h_sample_f};"));
b.raw_ptx(&format!("cvt.rmi.f32.f32 {tmp_h32}, {tmp_h32};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {h_floor_f}, {tmp_h32};"));
b.raw_ptx(&format!("cvt.f32.f16 {tmp_w32}, {w_sample_f};"));
b.raw_ptx(&format!("cvt.rmi.f32.f32 {tmp_w32}, {tmp_w32};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {w_floor_f}, {tmp_w32};"));
}
// "ceil" is floor + 1, i.e. the neighbouring grid line, not a true ceiling.
b.raw_ptx(&format!(
"add.rn.{fs} {h_ceil_f}, {h_floor_f}, {one_const};"
));
b.raw_ptx(&format!(
"add.rn.{fs} {w_ceil_f}, {w_floor_f}, {one_const};"
));
// Fractional parts and the four bilinear weights:
//   tl = (1-hf)(1-wf), tr = (1-hf)wf, bl = hf(1-wf), br = hf*wf
b.raw_ptx(&format!("sub.rn.{fs} {h_frac}, {h_sample_f}, {h_floor_f};"));
b.raw_ptx(&format!("sub.rn.{fs} {w_frac}, {w_sample_f}, {w_floor_f};"));
b.raw_ptx(&format!(
"sub.rn.{fs} {one_minus_hf}, {one_const}, {h_frac};"
));
b.raw_ptx(&format!(
"sub.rn.{fs} {one_minus_wf}, {one_const}, {w_frac};"
));
b.raw_ptx(&format!(
"mul.rn.{fs} {w_tl}, {one_minus_hf}, {one_minus_wf};"
));
b.raw_ptx(&format!("mul.rn.{fs} {w_tr}, {one_minus_hf}, {w_frac};"));
b.raw_ptx(&format!("mul.rn.{fs} {w_bl}, {h_frac}, {one_minus_wf};"));
b.raw_ptx(&format!("mul.rn.{fs} {w_br}, {h_frac}, {w_frac};"));
// Integer corner coordinates. The F16 path derives the upper corner as
// floor + 1 in integer arithmetic instead of converting h_ceil_f / w_ceil_f.
if ft == PtxType::F32 {
b.raw_ptx(&format!("cvt.rzi.s32.f32 {h_floor_i}, {h_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {w_floor_i}, {w_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {h_ceil_i}, {h_ceil_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {w_ceil_i}, {w_ceil_f};"));
} else {
let tmp_h32c = b.alloc_reg(PtxType::F32);
let tmp_w32c = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.f32.f16 {tmp_h32c}, {h_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {h_floor_i}, {tmp_h32c};"));
b.raw_ptx(&format!("cvt.f32.f16 {tmp_w32c}, {w_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {w_floor_i}, {tmp_w32c};"));
b.raw_ptx(&format!("add.s32 {h_ceil_i}, {h_floor_i}, 1;"));
b.raw_ptx(&format!("add.s32 {w_ceil_i}, {w_floor_i}, 1;"));
}
// Accumulate the interpolated value over the four corners; a corner outside
// [0, in_h) x [0, in_w) is skipped and therefore contributes zero.
emit_zero(b, ft, &interp_val.to_string());
for (corner_name, h_reg, w_reg, bw_reg) in [
("tl", &h_floor_i, &w_floor_i, &w_tl),
("tr", &h_floor_i, &w_ceil_i, &w_tr),
("bl", &h_ceil_i, &w_floor_i, &w_bl),
("br", &h_ceil_i, &w_ceil_i, &w_br),
] {
let corner_skip = b.fresh_label(&format!("skip_{corner_name}_k{kh_val}_{kw_val}"));
let corner_end = b.fresh_label(&format!("end_{corner_name}_k{kh_val}_{kw_val}"));
b.raw_ptx(&format!("setp.ge.s32 {pred_h0}, {h_reg}, {zero_s32};"));
b.raw_ptx(&format!("setp.lt.s32 {pred_h1}, {h_reg}, {in_h_s32};"));
b.raw_ptx(&format!("setp.ge.s32 {pred_w0}, {w_reg}, {zero_s32};"));
b.raw_ptx(&format!("setp.lt.s32 {pred_w1}, {w_reg}, {in_w_s32};"));
b.raw_ptx(&format!("and.pred {pred_valid}, {pred_h0}, {pred_h1};"));
b.raw_ptx(&format!("and.pred {pred_valid}, {pred_valid}, {pred_w0};"));
b.raw_ptx(&format!("and.pred {pred_valid}, {pred_valid}, {pred_w1};"));
b.raw_ptx(&format!("@!{pred_valid} bra {corner_skip};"));
// Input element index: (n * in_channels + c_in) * in_hw + h * in_w + w
let pixel_idx = b.alloc_reg(PtxType::U32);
let h_u32 = b.alloc_reg(PtxType::U32);
let w_u32 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {h_u32}, {h_reg};"));
b.raw_ptx(&format!("mov.u32 {w_u32}, {w_reg};"));
b.raw_ptx(&format!("mul.lo.u32 {pixel_idx}, {n_idx}, {in_channels};"));
b.raw_ptx(&format!("add.u32 {pixel_idx}, {pixel_idx}, {c_in};"));
b.raw_ptx(&format!("mul.lo.u32 {pixel_idx}, {pixel_idx}, {in_hw};"));
b.raw_ptx(&format!("mul.lo.u32 {tmp32}, {h_u32}, {in_w};"));
b.raw_ptx(&format!("add.u32 {pixel_idx}, {pixel_idx}, {tmp32};"));
b.raw_ptx(&format!("add.u32 {pixel_idx}, {pixel_idx}, {w_u32};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {pixel_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {input_ptr}, {off64};"));
let corner_val = match corner_name {
"tl" => &v_tl,
"tr" => &v_tr,
"bl" => &v_bl,
_ => &v_br,
};
b.raw_ptx(&format!("ld.global.{lt} {corner_val}, [{addr64}];"));
b.raw_ptx(&format!("mul.rn.{fs} {tmp_f}, {bw_reg}, {corner_val};"));
b.raw_ptx(&format!("add.rn.{fs} {interp_val}, {interp_val}, {tmp_f};"));
// The two labels are adjacent, so this branch only jumps over `corner_skip`.
b.raw_ptx(&format!("bra {corner_end};"));
b.raw_ptx(&format!("{corner_skip}:"));
b.raw_ptx(&format!("{corner_end}:"));
}
// Weight element index: (c_out * in_channels + c_in) * KH*KW + kpos
b.raw_ptx(&format!("mul.lo.u32 {idx32}, {c_out_idx}, {in_channels};"));
b.raw_ptx(&format!("add.u32 {idx32}, {idx32}, {c_in};"));
b.raw_ptx(&format!("mul.lo.u32 {idx32}, {idx32}, {kh_kw};"));
b.raw_ptx(&format!("add.u32 {idx32}, {idx32}, {kpos};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {idx32};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {weight_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {weight_val}, [{addr64}];"));
b.raw_ptx(&format!(
"mul.rn.{fs} {contrib}, {interp_val}, {weight_val};"
));
if p.use_modulation {
b.comment("Apply modulation mask (DCNv2)");
// Mask element index: (n * KH*KW*G + og * KH*KW + kpos) * out_hw + oh * out_w + ow
let mask_base = b.alloc_reg(PtxType::U32);
let mask_chan_stride = kh_kw * offset_groups;
b.raw_ptx(&format!(
"mul.lo.u32 {mask_base}, {n_idx}, {mask_chan_stride};"
));
let og_kpos = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {og_kpos}, {og_idx}, {kh_kw};"));
b.raw_ptx(&format!("add.u32 {og_kpos}, {og_kpos}, {kpos};"));
b.raw_ptx(&format!("add.u32 {mask_base}, {mask_base}, {og_kpos};"));
b.raw_ptx(&format!("mul.lo.u32 {mask_base}, {mask_base}, {out_hw};"));
b.raw_ptx(&format!("add.u32 {mask_base}, {mask_base}, {spatial_idx};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {mask_base};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {mask_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {mask_val}, [{addr64}];"));
b.raw_ptx(&format!("mul.rn.{fs} {contrib}, {contrib}, {mask_val};"));
}
b.raw_ptx(&format!("add.rn.{fs} {acc}, {acc}, {contrib};"));
b.raw_ptx(&format!("{skip_label}:"));
}
}
// Advance to the next input channel and jump back to the loop head.
b.raw_ptx(&format!("add.u32 {c_in}, {c_in}, 1;"));
b.raw_ptx(&format!("bra {cin_loop};"));
b.raw_ptx(&format!("{cin_done}:"));
b.comment("Add bias");
// The bias is indexed by the output channel only.
let bias_val = b.alloc_reg(ft);
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {c_out_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {bias_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {bias_val}, [{addr64}];"));
b.raw_ptx(&format!("add.rn.{fs} {acc}, {acc}, {bias_val};"));
b.comment("Store output");
// The output element index equals the thread's gid.
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {gid};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {output_ptr}, {off64};"));
b.raw_ptx(&format!("st.global.{lt} [{addr64}], {acc};"));
b.raw_ptx(&format!("{exit_label}:"));
b.raw_ptx("ret;");
}
fn emit_backward_input_body(
b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
p: &DeformableBodyParams,
) {
let ft = p.float_type;
let fs = ptx_float_suffix(ft);
let lt = ptx_load_type(ft);
let eb = p.elem_bytes;
b.comment("=== Deformable Conv Backward Input ===");
b.comment("Each thread processes one output element and atomically");
b.comment("scatters gradient to the 4 bilinear-interpolated input positions.");
let gid = b.global_thread_id_x();
let total = b.load_param_u32("total_outputs");
let pred_bounds = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lo.u32 {pred_bounds}, {gid}, {total};"));
let exit_label = b.fresh_label("dcn_bwd_input_exit");
b.raw_ptx(&format!("@!{pred_bounds} bra {exit_label};"));
let grad_out_ptr = b.load_param_u64("grad_output");
let offset_ptr = b.load_param_u64("offset");
let mask_ptr = b.load_param_u64("mask");
let weight_ptr = b.load_param_u64("weight");
let grad_input_ptr = b.load_param_u64("grad_input");
let _batch_size = b.load_param_u32("batch_size");
let in_channels = b.load_param_u32("in_channels");
let in_h = b.load_param_u32("in_h");
let in_w = b.load_param_u32("in_w");
let out_channels = b.load_param_u32("out_channels");
let out_h = b.load_param_u32("out_h");
let out_w = b.load_param_u32("out_w");
let out_hw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {out_hw}, {out_h}, {out_w};"));
let c_out_hw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {c_out_hw}, {out_channels}, {out_hw};"));
let n_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {n_idx}, {gid}, {c_out_hw};"));
let rem1 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {rem1}, {gid}, {c_out_hw};"));
let c_out_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {c_out_idx}, {rem1}, {out_hw};"));
let rem2 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {rem2}, {rem1}, {out_hw};"));
let oh = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {oh}, {rem2}, {out_w};"));
let ow = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {ow}, {rem2}, {out_w};"));
let grad_out_val = b.alloc_reg(ft);
let idx64 = b.alloc_reg(PtxType::U64);
let off64 = b.alloc_reg(PtxType::U64);
let addr64 = b.alloc_reg(PtxType::U64);
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {gid};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {grad_out_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {grad_out_val}, [{addr64}];"));
let in_hw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {in_hw}, {in_h}, {in_w};"));
let oh_f = b.alloc_reg(ft);
let ow_f = b.alloc_reg(ft);
if ft == PtxType::F32 {
b.raw_ptx(&format!("cvt.rn.f32.u32 {oh_f}, {oh};"));
b.raw_ptx(&format!("cvt.rn.f32.u32 {ow_f}, {ow};"));
} else {
let t1 = b.alloc_reg(PtxType::F32);
let t2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {t1}, {oh};"));
b.raw_ptx(&format!("cvt.rn.f32.u32 {t2}, {ow};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {oh_f}, {t1};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {ow_f}, {t2};"));
}
let one_const = b.alloc_reg(ft);
emit_one(b, ft, &one_const.to_string());
let in_h_s32 = b.alloc_reg(PtxType::S32);
let in_w_s32 = b.alloc_reg(PtxType::S32);
b.raw_ptx(&format!("mov.u32 {in_h_s32}, {in_h};"));
b.raw_ptx(&format!("mov.u32 {in_w_s32}, {in_w};"));
let zero_s32 = b.alloc_reg(PtxType::S32);
b.raw_ptx(&format!("mov.s32 {zero_s32}, 0;"));
let kh_kw = p.kernel_h * p.kernel_w;
let channels_per_og = p.channels_per_offset_group;
let offset_groups = p.offset_groups;
let h_base_f = b.alloc_reg(ft);
let w_base_f = b.alloc_reg(ft);
let h_sample_f = b.alloc_reg(ft);
let w_sample_f = b.alloc_reg(ft);
let offset_h_val = b.alloc_reg(ft);
let offset_w_val = b.alloc_reg(ft);
let h_floor_f = b.alloc_reg(ft);
let w_floor_f = b.alloc_reg(ft);
let h_frac = b.alloc_reg(ft);
let w_frac = b.alloc_reg(ft);
let one_minus_hf = b.alloc_reg(ft);
let one_minus_wf = b.alloc_reg(ft);
let h_floor_i = b.alloc_reg(PtxType::S32);
let w_floor_i = b.alloc_reg(PtxType::S32);
let h_ceil_i = b.alloc_reg(PtxType::S32);
let w_ceil_i = b.alloc_reg(PtxType::S32);
let weight_val = b.alloc_reg(ft);
let mask_val = b.alloc_reg(ft);
let grad_scaled = b.alloc_reg(ft);
let tmp_f = b.alloc_reg(ft);
let idx32 = b.alloc_reg(PtxType::U32);
let tmp32 = b.alloc_reg(PtxType::U32);
let spatial_idx = b.alloc_reg(PtxType::U32);
let pred_h0 = b.alloc_reg(PtxType::Pred);
let pred_h1 = b.alloc_reg(PtxType::Pred);
let pred_w0 = b.alloc_reg(PtxType::Pred);
let pred_w1 = b.alloc_reg(PtxType::Pred);
let pred_valid = b.alloc_reg(PtxType::Pred);
let c_in = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {c_in}, 0;"));
let cin_loop = b.fresh_label("bwd_cin_loop");
let cin_done = b.fresh_label("bwd_cin_done");
b.raw_ptx(&format!("{cin_loop}:"));
let pred_cin = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.lo.u32 {pred_cin}, {c_in}, {in_channels};"));
b.raw_ptx(&format!("@!{pred_cin} bra {cin_done};"));
let og_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {og_idx}, {c_in}, {channels_per_og};"));
for kh_val in 0..p.kernel_h {
for kw_val in 0..p.kernel_w {
let kpos = kh_val * p.kernel_w + kw_val;
let skip_label = b.fresh_label(&format!("bwd_skip_k{kh_val}_{kw_val}"));
let h_base_val = kh_val * p.dilation_h;
let w_base_val = kw_val * p.dilation_w;
if ft == PtxType::F32 {
let stride_h_bits = (p.stride_h as f32).to_bits();
let pad_h_bits = (p.pad_h as f32).to_bits();
let h_off_bits = (h_base_val as f32).to_bits();
let stride_w_bits = (p.stride_w as f32).to_bits();
let pad_w_bits = (p.pad_w as f32).to_bits();
let w_off_bits = (w_base_val as f32).to_bits();
let sr = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {sr}, 0F{stride_h_bits:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {h_base_f}, {oh_f}, {sr};"));
let pr = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {pr}, 0F{pad_h_bits:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {h_base_f}, {h_base_f}, {pr};"));
let hr = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {hr}, 0F{h_off_bits:08X};"));
b.raw_ptx(&format!("add.rn.f32 {h_base_f}, {h_base_f}, {hr};"));
let sw = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {sw}, 0F{stride_w_bits:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {w_base_f}, {ow_f}, {sw};"));
let pw = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {pw}, 0F{pad_w_bits:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {w_base_f}, {w_base_f}, {pw};"));
let wr = b.alloc_reg(ft);
b.raw_ptx(&format!("mov.b32 {wr}, 0F{w_off_bits:08X};"));
b.raw_ptx(&format!("add.rn.f32 {w_base_f}, {w_base_f}, {wr};"));
} else {
let t1 = b.alloc_reg(PtxType::F32);
let t2 = b.alloc_reg(PtxType::F32);
let oh_f32 = b.alloc_reg(PtxType::F32);
let ow_f32 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.f32.f16 {oh_f32}, {oh_f};"));
b.raw_ptx(&format!("cvt.f32.f16 {ow_f32}, {ow_f};"));
let sh = (p.stride_h as f32).to_bits();
let ph = (p.pad_h as f32).to_bits();
let hb = (h_base_val as f32).to_bits();
let sw_b = (p.stride_w as f32).to_bits();
let pw = (p.pad_w as f32).to_bits();
let wb = (w_base_val as f32).to_bits();
let r1 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {r1}, 0F{sh:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {t1}, {oh_f32}, {r1};"));
let r2 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {r2}, 0F{ph:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {t1}, {t1}, {r2};"));
let r3 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {r3}, 0F{hb:08X};"));
b.raw_ptx(&format!("add.rn.f32 {t1}, {t1}, {r3};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {h_base_f}, {t1};"));
let r4 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {r4}, 0F{sw_b:08X};"));
b.raw_ptx(&format!("mul.rn.f32 {t2}, {ow_f32}, {r4};"));
let r5 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {r5}, 0F{pw:08X};"));
b.raw_ptx(&format!("sub.rn.f32 {t2}, {t2}, {r5};"));
let r6 = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.b32 {r6}, 0F{wb:08X};"));
b.raw_ptx(&format!("add.rn.f32 {t2}, {t2}, {r6};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {w_base_f}, {t2};"));
}
b.raw_ptx(&format!("mul.lo.u32 {spatial_idx}, {oh}, {out_w};"));
b.raw_ptx(&format!("add.u32 {spatial_idx}, {spatial_idx}, {ow};"));
let offset_chan_stride = 2 * kh_kw * offset_groups;
let og_kpos_2 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {og_kpos_2}, {og_idx}, {kh_kw};"));
b.raw_ptx(&format!("add.u32 {og_kpos_2}, {og_kpos_2}, {kpos};"));
b.raw_ptx(&format!("mul.lo.u32 {og_kpos_2}, {og_kpos_2}, 2;"));
let off_base = b.alloc_reg(PtxType::U32);
let n_off_stride = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"mul.lo.u32 {n_off_stride}, {n_idx}, {offset_chan_stride};"
));
b.raw_ptx(&format!("add.u32 {off_base}, {n_off_stride}, {og_kpos_2};"));
b.raw_ptx(&format!("mul.lo.u32 {off_base}, {off_base}, {out_hw};"));
b.raw_ptx(&format!("add.u32 {off_base}, {off_base}, {spatial_idx};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {off_base};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {offset_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {offset_h_val}, [{addr64}];"));
let off_w_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("add.u32 {off_w_idx}, {off_base}, {out_hw};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {off_w_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {offset_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {offset_w_val}, [{addr64}];"));
b.raw_ptx(&format!(
"add.rn.{fs} {h_sample_f}, {h_base_f}, {offset_h_val};"
));
b.raw_ptx(&format!(
"add.rn.{fs} {w_sample_f}, {w_base_f}, {offset_w_val};"
));
if ft == PtxType::F32 {
b.raw_ptx(&format!("cvt.rmi.f32.f32 {h_floor_f}, {h_sample_f};"));
b.raw_ptx(&format!("cvt.rmi.f32.f32 {w_floor_f}, {w_sample_f};"));
} else {
let th = b.alloc_reg(PtxType::F32);
let tw = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.f32.f16 {th}, {h_sample_f};"));
b.raw_ptx(&format!("cvt.rmi.f32.f32 {th}, {th};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {h_floor_f}, {th};"));
b.raw_ptx(&format!("cvt.f32.f16 {tw}, {w_sample_f};"));
b.raw_ptx(&format!("cvt.rmi.f32.f32 {tw}, {tw};"));
b.raw_ptx(&format!("cvt.rn.f16.f32 {w_floor_f}, {tw};"));
}
b.raw_ptx(&format!("sub.rn.{fs} {h_frac}, {h_sample_f}, {h_floor_f};"));
b.raw_ptx(&format!("sub.rn.{fs} {w_frac}, {w_sample_f}, {w_floor_f};"));
b.raw_ptx(&format!(
"sub.rn.{fs} {one_minus_hf}, {one_const}, {h_frac};"
));
b.raw_ptx(&format!(
"sub.rn.{fs} {one_minus_wf}, {one_const}, {w_frac};"
));
if ft == PtxType::F32 {
b.raw_ptx(&format!("cvt.rzi.s32.f32 {h_floor_i}, {h_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {w_floor_i}, {w_floor_f};"));
} else {
let th = b.alloc_reg(PtxType::F32);
let tw = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.f32.f16 {th}, {h_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {h_floor_i}, {th};"));
b.raw_ptx(&format!("cvt.f32.f16 {tw}, {w_floor_f};"));
b.raw_ptx(&format!("cvt.rzi.s32.f32 {w_floor_i}, {tw};"));
}
b.raw_ptx(&format!("add.s32 {h_ceil_i}, {h_floor_i}, 1;"));
b.raw_ptx(&format!("add.s32 {w_ceil_i}, {w_floor_i}, 1;"));
b.raw_ptx(&format!("mul.lo.u32 {idx32}, {c_out_idx}, {in_channels};"));
b.raw_ptx(&format!("add.u32 {idx32}, {idx32}, {c_in};"));
b.raw_ptx(&format!("mul.lo.u32 {idx32}, {idx32}, {kh_kw};"));
b.raw_ptx(&format!("add.u32 {idx32}, {idx32}, {kpos};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {idx32};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {weight_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {weight_val}, [{addr64}];"));
b.raw_ptx(&format!(
"mul.rn.{fs} {grad_scaled}, {grad_out_val}, {weight_val};"
));
if p.use_modulation {
let mask_chan_stride = kh_kw * offset_groups;
let mask_base = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"mul.lo.u32 {mask_base}, {n_idx}, {mask_chan_stride};"
));
let og_kpos_m = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {og_kpos_m}, {og_idx}, {kh_kw};"));
b.raw_ptx(&format!("add.u32 {og_kpos_m}, {og_kpos_m}, {kpos};"));
b.raw_ptx(&format!("add.u32 {mask_base}, {mask_base}, {og_kpos_m};"));
b.raw_ptx(&format!("mul.lo.u32 {mask_base}, {mask_base}, {out_hw};"));
b.raw_ptx(&format!("add.u32 {mask_base}, {mask_base}, {spatial_idx};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {mask_base};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {mask_ptr}, {off64};"));
b.raw_ptx(&format!("ld.global.{lt} {mask_val}, [{addr64}];"));
b.raw_ptx(&format!(
"mul.rn.{fs} {grad_scaled}, {grad_scaled}, {mask_val};"
));
}
for (corner_name, h_reg, w_reg, hw_str, ww_str) in [
("tl", &h_floor_i, &w_floor_i, "one_minus_hf", "one_minus_wf"),
("tr", &h_floor_i, &w_ceil_i, "one_minus_hf", "w_frac"),
("bl", &h_ceil_i, &w_floor_i, "h_frac", "one_minus_wf"),
("br", &h_ceil_i, &w_ceil_i, "h_frac", "w_frac"),
] {
let corner_skip =
b.fresh_label(&format!("bwd_in_{corner_name}_k{kh_val}_{kw_val}"));
b.raw_ptx(&format!("setp.ge.s32 {pred_h0}, {h_reg}, {zero_s32};"));
b.raw_ptx(&format!("setp.lt.s32 {pred_h1}, {h_reg}, {in_h_s32};"));
b.raw_ptx(&format!("setp.ge.s32 {pred_w0}, {w_reg}, {zero_s32};"));
b.raw_ptx(&format!("setp.lt.s32 {pred_w1}, {w_reg}, {in_w_s32};"));
b.raw_ptx(&format!("and.pred {pred_valid}, {pred_h0}, {pred_h1};"));
b.raw_ptx(&format!("and.pred {pred_valid}, {pred_valid}, {pred_w0};"));
b.raw_ptx(&format!("and.pred {pred_valid}, {pred_valid}, {pred_w1};"));
b.raw_ptx(&format!("@!{pred_valid} bra {corner_skip};"));
let hw_reg = match hw_str {
"one_minus_hf" => &one_minus_hf,
_ => &h_frac,
};
let ww_reg = match ww_str {
"one_minus_wf" => &one_minus_wf,
_ => &w_frac,
};
let bw = b.alloc_reg(ft);
b.raw_ptx(&format!("mul.rn.{fs} {bw}, {hw_reg}, {ww_reg};"));
let scatter_val = b.alloc_reg(ft);
b.raw_ptx(&format!("mul.rn.{fs} {scatter_val}, {grad_scaled}, {bw};"));
let pixel_idx = b.alloc_reg(PtxType::U32);
let h_u32 = b.alloc_reg(PtxType::U32);
let w_u32 = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {h_u32}, {h_reg};"));
b.raw_ptx(&format!("mov.u32 {w_u32}, {w_reg};"));
b.raw_ptx(&format!("mul.lo.u32 {pixel_idx}, {n_idx}, {in_channels};"));
b.raw_ptx(&format!("add.u32 {pixel_idx}, {pixel_idx}, {c_in};"));
b.raw_ptx(&format!("mul.lo.u32 {pixel_idx}, {pixel_idx}, {in_hw};"));
b.raw_ptx(&format!("mul.lo.u32 {tmp32}, {h_u32}, {in_w};"));
b.raw_ptx(&format!("add.u32 {pixel_idx}, {pixel_idx}, {tmp32};"));
b.raw_ptx(&format!("add.u32 {pixel_idx}, {pixel_idx}, {w_u32};"));
b.raw_ptx(&format!("cvt.u64.u32 {idx64}, {pixel_idx};"));
b.raw_ptx(&format!("mul.lo.u64 {off64}, {idx64}, {eb};"));
b.raw_ptx(&format!("add.u64 {addr64}, {grad_input_ptr}, {off64};"));
if ft == PtxType::F32 {
b.raw_ptx(&format!(
"atom.global.add.f32 {tmp_f}, [{addr64}], {scatter_val};"
));
} else {
b.raw_ptx(&format!("st.global.{lt} [{addr64}], {scatter_val};"));
}
b.raw_ptx(&format!("{corner_skip}:"));
}
b.raw_ptx(&format!("{skip_label}:"));
}
}
b.raw_ptx(&format!("add.u32 {c_in}, {c_in}, 1;"));
b.raw_ptx(&format!("bra {cin_loop};"));
b.raw_ptx(&format!("{cin_done}:"));
b.raw_ptx(&format!("{exit_label}:"));
b.raw_ptx("ret;");
}
/// Emits the PTX body for the deformable-convolution backward-offset kernel.
///
/// The emitted code compares the global thread id against the `total_outputs`
/// parameter, loads the kernel parameters, and stores a zero-initialised
/// accumulator to `grad_offset` at the thread's index. No gradient terms are
/// accumulated in this body.
///
/// * `b` - body builder that receives the generated instructions.
/// * `p` - body parameters; only `float_type` and `elem_bytes` are read here.
fn emit_backward_offset_body(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    p: &DeformableBodyParams,
) {
    let float_ty = p.float_type;
    let _float_suffix = ptx_float_suffix(float_ty);
    let mem_ty = ptx_load_type(float_ty);
    let elem_size = p.elem_bytes;

    b.comment("=== Deformable Conv Backward Offset ===");
    b.comment("Each thread computes gradient for one offset element.");

    // Threads whose id is not below `total_outputs` branch straight to the exit label.
    let thread_id = b.global_thread_id_x();
    let elem_count = b.load_param_u32("total_outputs");
    let in_range = b.alloc_reg(PtxType::Pred);
    let bounds_cmp = format!("setp.lo.u32 {in_range}, {thread_id}, {elem_count};");
    b.raw_ptx(&bounds_cmp);
    let done = b.fresh_label("dcn_bwd_offset_exit");
    let bounds_branch = format!("@!{in_range} bra {done};");
    b.raw_ptx(&bounds_branch);

    // Pointer parameters. Only `grad_offset` is referenced again below; the
    // other loads are still emitted.
    let _grad_output_addr = b.load_param_u64("grad_output");
    let _input_addr = b.load_param_u64("input");
    let _offset_addr = b.load_param_u64("offset");
    let _mask_addr = b.load_param_u64("mask");
    let _weight_addr = b.load_param_u64("weight");
    let grad_offset_addr = b.load_param_u64("grad_offset");
    let _grad_mask_addr = b.load_param_u64("grad_mask");

    // Scalar shape parameters.
    let _batch = b.load_param_u32("batch_size");
    let _channels_in = b.load_param_u32("in_channels");
    let in_rows = b.load_param_u32("in_h");
    let in_cols = b.load_param_u32("in_w");
    let _channels_out = b.load_param_u32("out_channels");
    let out_rows = b.load_param_u32("out_h");
    let out_cols = b.load_param_u32("out_w");

    // Products out_h * out_w and in_h * in_w; neither result is read again in this body.
    let out_plane = b.alloc_reg(PtxType::U32);
    let out_plane_mul = format!("mul.lo.u32 {out_plane}, {out_rows}, {out_cols};");
    b.raw_ptx(&out_plane_mul);
    let in_plane = b.alloc_reg(PtxType::U32);
    let in_plane_mul = format!("mul.lo.u32 {in_plane}, {in_rows}, {in_cols};");
    b.raw_ptx(&in_plane_mul);

    b.comment("Initialize gradient accumulator to zero");
    let accumulator = b.alloc_reg(float_ty);
    let accumulator_name = accumulator.to_string();
    emit_zero(b, float_ty, &accumulator_name);

    // Destination address = grad_offset + thread_id * elem_bytes.
    b.comment("Store gradient for this offset element");
    let wide_index = b.alloc_reg(PtxType::U64);
    let byte_offset = b.alloc_reg(PtxType::U64);
    let dest_addr = b.alloc_reg(PtxType::U64);
    let widen = format!("cvt.u64.u32 {wide_index}, {thread_id};");
    b.raw_ptx(&widen);
    let scale = format!("mul.lo.u64 {byte_offset}, {wide_index}, {elem_size};");
    b.raw_ptx(&scale);
    let rebase = format!("add.u64 {dest_addr}, {grad_offset_addr}, {byte_offset};");
    b.raw_ptx(&rebase);
    let store = format!("st.global.{mem_ty} [{dest_addr}], {accumulator};");
    b.raw_ptx(&store);

    let exit_marker = format!("{done}:");
    b.raw_ptx(&exit_marker);
    b.raw_ptx("ret;");
}
/// Emits the PTX body for the deformable-convolution backward-weight kernel.
///
/// The emitted code compares the global thread id against the
/// `total_weight_elements` parameter, loads the kernel parameters, splits the
/// thread id into `(c_out, c_in, kh, kw)` plus an offset-group index, and then
/// stores a zero-initialised accumulator to `grad_weight` at the thread's
/// index. The decomposed indices are computed but do not feed the stored value.
///
/// * `b` - body builder that receives the generated instructions.
/// * `p` - body parameters; `float_type`, `elem_bytes`, `kernel_h`, `kernel_w`,
///   `channels_per_offset_group` and `offset_groups` are read here.
fn emit_backward_weight_body(
    b: &mut oxicuda_ptx::builder::BodyBuilder<'_>,
    p: &DeformableBodyParams,
) {
    let float_ty = p.float_type;
    let _float_suffix = ptx_float_suffix(float_ty);
    let mem_ty = ptx_load_type(float_ty);
    let elem_size = p.elem_bytes;

    b.comment("=== Deformable Conv Backward Weight ===");
    b.comment("Each thread computes gradient for one weight element.");

    // Threads whose id is not below `total_weight_elements` branch straight to the exit label.
    let thread_id = b.global_thread_id_x();
    let elem_count = b.load_param_u32("total_weight_elements");
    let in_range = b.alloc_reg(PtxType::Pred);
    let bounds_cmp = format!("setp.lo.u32 {in_range}, {thread_id}, {elem_count};");
    b.raw_ptx(&bounds_cmp);
    let done = b.fresh_label("dcn_bwd_wgt_exit");
    let bounds_branch = format!("@!{in_range} bra {done};");
    b.raw_ptx(&bounds_branch);

    // Pointer parameters. Only `grad_weight` is referenced again below; the
    // other loads are still emitted.
    let _grad_output_addr = b.load_param_u64("grad_output");
    let _input_addr = b.load_param_u64("input");
    let _offset_addr = b.load_param_u64("offset");
    let _mask_addr = b.load_param_u64("mask");
    let grad_weight_addr = b.load_param_u64("grad_weight");

    // Scalar shape parameters.
    let _batch = b.load_param_u32("batch_size");
    let channels_in = b.load_param_u32("in_channels");
    let in_rows = b.load_param_u32("in_h");
    let in_cols = b.load_param_u32("in_w");
    let _channels_out = b.load_param_u32("out_channels");
    let out_rows = b.load_param_u32("out_h");
    let out_cols = b.load_param_u32("out_w");

    // Products out_h * out_w and in_h * in_w; neither result is read again in this body.
    let out_plane = b.alloc_reg(PtxType::U32);
    let out_plane_mul = format!("mul.lo.u32 {out_plane}, {out_rows}, {out_cols};");
    b.raw_ptx(&out_plane_mul);
    let in_plane = b.alloc_reg(PtxType::U32);
    let in_plane_mul = format!("mul.lo.u32 {in_plane}, {in_rows}, {in_cols};");
    b.raw_ptx(&in_plane_mul);

    // Values known at generation time are baked into the PTX as immediates.
    let kernel_w = p.kernel_w;
    let kernel_area = p.kernel_h * kernel_w;
    let group_width = p.channels_per_offset_group;
    let _group_count = p.offset_groups;

    b.comment("Decompose gid -> (c_out, c_in, kh, kw)");
    let per_out_channel = b.alloc_reg(PtxType::U32);
    let per_out_mul = format!("mul.lo.u32 {per_out_channel}, {channels_in}, {kernel_area};");
    b.raw_ptx(&per_out_mul);

    let out_channel = b.alloc_reg(PtxType::U32);
    let out_channel_div = format!("div.u32 {out_channel}, {thread_id}, {per_out_channel};");
    b.raw_ptx(&out_channel_div);
    let within_out = b.alloc_reg(PtxType::U32);
    let within_out_rem = format!("rem.u32 {within_out}, {thread_id}, {per_out_channel};");
    b.raw_ptx(&within_out_rem);

    let in_channel = b.alloc_reg(PtxType::U32);
    let in_channel_div = format!("div.u32 {in_channel}, {within_out}, {kernel_area};");
    b.raw_ptx(&in_channel_div);
    let kernel_pos = b.alloc_reg(PtxType::U32);
    let kernel_pos_rem = format!("rem.u32 {kernel_pos}, {within_out}, {kernel_area};");
    b.raw_ptx(&kernel_pos_rem);

    let kernel_row = b.alloc_reg(PtxType::U32);
    let kernel_col = b.alloc_reg(PtxType::U32);
    let kernel_row_div = format!("div.u32 {kernel_row}, {kernel_pos}, {kernel_w};");
    b.raw_ptx(&kernel_row_div);
    let kernel_col_rem = format!("rem.u32 {kernel_col}, {kernel_pos}, {kernel_w};");
    b.raw_ptx(&kernel_col_rem);

    let group_index = b.alloc_reg(PtxType::U32);
    let group_div = format!("div.u32 {group_index}, {in_channel}, {group_width};");
    b.raw_ptx(&group_div);

    let accumulator = b.alloc_reg(float_ty);
    let accumulator_name = accumulator.to_string();
    emit_zero(b, float_ty, &accumulator_name);

    let wide_index = b.alloc_reg(PtxType::U64);
    let byte_offset = b.alloc_reg(PtxType::U64);
    let dest_addr = b.alloc_reg(PtxType::U64);
    // Two further u32 registers are allocated here but never referenced by any instruction.
    let _scratch_index = b.alloc_reg(PtxType::U32);
    let _scratch_tmp = b.alloc_reg(PtxType::U32);

    // Destination address = grad_weight + thread_id * elem_bytes.
    b.comment("Store weight gradient");
    let widen = format!("cvt.u64.u32 {wide_index}, {thread_id};");
    b.raw_ptx(&widen);
    let scale = format!("mul.lo.u64 {byte_offset}, {wide_index}, {elem_size};");
    b.raw_ptx(&scale);
    let rebase = format!("add.u64 {dest_addr}, {grad_weight_addr}, {byte_offset};");
    b.raw_ptx(&rebase);
    let store = format!("st.global.{mem_ty} [{dest_addr}], {accumulator};");
    b.raw_ptx(&store);

    let exit_marker = format!("{done}:");
    b.raw_ptx(&exit_marker);
    b.raw_ptx("ret;");
}
/// Generates forward-pass PTX for the given deformable convolution configuration.
///
/// A [`DeformableConvPlan`] is built from a clone of `config` and its
/// `generate_forward` result is returned.
///
/// # Errors
///
/// Returns any error produced while constructing the plan or generating the PTX.
pub fn generate_deformable_conv_forward_ptx(config: &DeformableConvConfig) -> DnnResult<String> {
    DeformableConvPlan::new(config.clone())?.generate_forward()
}
/// Generates backward-input PTX for the given deformable convolution configuration.
///
/// A [`DeformableConvPlan`] is built from a clone of `config` and its
/// `generate_backward_input` result is returned.
///
/// # Errors
///
/// Returns any error produced while constructing the plan or generating the PTX.
pub fn generate_deformable_conv_backward_input_ptx(
    config: &DeformableConvConfig,
) -> DnnResult<String> {
    DeformableConvPlan::new(config.clone())?.generate_backward_input()
}
/// Generates backward-offset PTX for the given deformable convolution configuration.
///
/// A [`DeformableConvPlan`] is built from a clone of `config` and its
/// `generate_backward_offset` result is returned.
///
/// # Errors
///
/// Returns any error produced while constructing the plan or generating the PTX.
pub fn generate_deformable_conv_backward_offset_ptx(
    config: &DeformableConvConfig,
) -> DnnResult<String> {
    DeformableConvPlan::new(config.clone())?.generate_backward_offset()
}
/// Generates backward-weight PTX for the given deformable convolution configuration.
///
/// A [`DeformableConvPlan`] is built from a clone of `config` and its
/// `generate_backward_weight` result is returned.
///
/// # Errors
///
/// Returns any error produced while constructing the plan or generating the PTX.
pub fn generate_deformable_conv_backward_weight_ptx(
    config: &DeformableConvConfig,
) -> DnnResult<String> {
    DeformableConvPlan::new(config.clone())?.generate_backward_weight()
}
// Unit tests for this module live in `deformable_tests.rs`; the module is only
// compiled when building with `cfg(test)`.
#[cfg(test)]
#[path = "deformable_tests.rs"]
mod tests;