use std::fmt::Write as FmtWrite;
use crate::arch::SmVersion;
use crate::error::PtxGenError;
use crate::ir::PtxType;
/// Configuration for generating PTX 2-D convolution kernels (im2col, direct, 1x1).
///
/// Channel counts, kernel geometry, stride, padding, dilation and groups are baked into the
/// generated PTX as immediates, so one template describes one fixed convolution shape. The
/// generated kernels receive batch size and spatial extents as launch parameters instead.
#[derive(Debug, Clone)]
pub struct ConvolutionTemplate {
/// Total number of input channels; must be > 0 and divisible by `groups`.
pub in_channels: u32,
/// Total number of output channels; must be > 0 and divisible by `groups`.
pub out_channels: u32,
/// Filter height (must be > 0).
pub kernel_h: u32,
/// Filter width (must be > 0).
pub kernel_w: u32,
/// Vertical stride (must be > 0).
pub stride_h: u32,
/// Horizontal stride (must be > 0).
pub stride_w: u32,
/// Zero padding applied on each of the top and bottom edges (`output_size` adds `2 * pad_h`).
pub pad_h: u32,
/// Zero padding applied on each of the left and right edges (`output_size` adds `2 * pad_w`).
pub pad_w: u32,
/// Vertical spacing between filter taps (1 = dense; must be > 0).
pub dilation_h: u32,
/// Horizontal spacing between filter taps (1 = dense; must be > 0).
pub dilation_w: u32,
/// Number of channel groups; `groups == in_channels` is named as depthwise by `kernel_name`.
pub groups: u32,
/// Target architecture; supplies the `.version` and `.target` lines of the PTX header.
pub sm_version: SmVersion,
/// Element type of input, weights and output; `validate` accepts only F16, F32 or F64.
pub float_type: PtxType,
}
impl ConvolutionTemplate {
/// Short type tag embedded in generated kernel names.
///
/// F16 and F64 get their own tag; every other type falls back to `"f32"`.
const fn type_suffix(&self) -> &'static str {
    if matches!(self.float_type, PtxType::F16) {
        "f16"
    } else if matches!(self.float_type, PtxType::F64) {
        "f64"
    } else {
        "f32"
    }
}
/// PTX type token for the element type, exactly as produced by `PtxType::as_ptx_str`.
///
/// The generators splice it directly after instruction names (e.g. `ld.global{ty}`), so it is
/// expected to carry its leading dot.
const fn ty(&self) -> &'static str {
    let token = self.float_type.as_ptx_str();
    token
}
/// Size in bytes of one element of the configured float type (used for address arithmetic).
const fn byte_size(&self) -> usize {
    let bytes = self.float_type.size_bytes();
    bytes
}
/// Zero immediate in PTX hex-float syntax.
///
/// F64 uses the 64-bit `0d…` form; every other type uses the 32-bit `0f…` form.
/// NOTE(review): F16 also receives the 32-bit `0f00000000` literal here; confirm that the
/// half-precision code paths using this value are accepted by `ptxas`.
const fn zero_lit(&self) -> &'static str {
    if let PtxType::F64 = self.float_type {
        "0d0000000000000000"
    } else {
        "0f00000000"
    }
}
/// Input channels handled by each group (`in_channels / groups`).
///
/// Integer division: panics if `groups == 0`, which `validate` rejects before the generators
/// call this.
const fn channels_per_group(&self) -> u32 {
    let group_count = self.groups;
    self.in_channels / group_count
}
/// Output channels produced by each group (`out_channels / groups`).
///
/// Integer division: panics if `groups == 0`, which `validate` rejects before the generators
/// call this.
const fn out_channels_per_group(&self) -> u32 {
    let group_count = self.groups;
    self.out_channels / group_count
}
/// Spatial output size `(out_h, out_w)` of this convolution for an `in_h x in_w` input.
///
/// Per axis this is the standard formula
/// `(in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1`.
///
/// The arithmetic is done in `u64`, so large paddings or dilations cannot overflow. A
/// dimension is reported as `0` when no filter window fits: either the dilated kernel extent
/// exceeds the padded input, or the kernel size or stride on that axis is `0`. Results above
/// `u32::MAX` saturate. This method does not require `validate` to have been called.
#[must_use]
pub const fn output_size(&self, in_h: u32, in_w: u32) -> (u32, u32) {
    /// Output length along one axis; see the method docs for the degenerate cases.
    const fn axis(input: u32, pad: u32, kernel: u32, dilation: u32, stride: u32) -> u32 {
        if kernel == 0 || stride == 0 {
            return 0;
        }
        let padded = input as u64 + 2 * pad as u64;
        // Span covered by one dilated filter window; cannot overflow u64 for u32 inputs.
        let extent = dilation as u64 * (kernel as u64 - 1) + 1;
        if extent > padded {
            return 0;
        }
        let out = (padded - extent) / stride as u64 + 1;
        if out > u32::MAX as u64 {
            u32::MAX
        } else {
            out as u32
        }
    }
    (
        axis(in_h, self.pad_h, self.kernel_h, self.dilation_h, self.stride_h),
        axis(in_w, self.pad_w, self.kernel_w, self.dilation_w, self.stride_w),
    )
}
/// Builds the `.entry` name for a generated kernel, e.g. `conv2d_direct_f32_ic64_oc128_k3x3`.
///
/// `suffix` identifies the kernel variant (`"im2col"`, `"direct"`, `"1x1"`). The name encodes
/// element type, channel counts, kernel size and groups:
/// - `groups == 1`: `conv2d_{suffix}_{ty}_ic{ic}_oc{oc}_k{kh}x{kw}`
/// - plain depthwise (`groups == in_channels == out_channels`):
///   `conv2d_{suffix}_{ty}_dw{ic}_k{kh}x{kw}`
/// - everything else, including depthwise with a channel multiplier:
///   `conv2d_{suffix}_{ty}_ic{ic}_oc{oc}_k{kh}x{kw}_g{g}`
///
/// The multiplier case uses the grouped form so it cannot collide with the plain depthwise
/// kernel of the same `in_channels`, whose generated code differs.
///
/// NOTE(review): stride, padding and dilation are baked into the generated PTX but are not
/// part of the name, so two templates differing only in those fields share a name.
#[must_use]
pub fn kernel_name(&self, suffix: &str) -> String {
    let ts = self.type_suffix();
    let ic = self.in_channels;
    let oc = self.out_channels;
    let kh = self.kernel_h;
    let kw = self.kernel_w;
    let g = self.groups;
    if g == 1 {
        format!("conv2d_{suffix}_{ts}_ic{ic}_oc{oc}_k{kh}x{kw}")
    } else if g == ic && oc == ic {
        format!("conv2d_{suffix}_{ts}_dw{ic}_k{kh}x{kw}")
    } else {
        format!("conv2d_{suffix}_{ts}_ic{ic}_oc{oc}_k{kh}x{kw}_g{g}")
    }
}
/// Checks that the template describes a convolution the generators can emit.
///
/// Checks run in this order, and the first failure is returned:
/// 1. The element type is F16, F32 or F64 (`PtxGenError::InvalidType`).
/// 2. Channel counts, kernel size, stride, dilation and groups are all non-zero
///    (`PtxGenError::GenerationFailed`).
/// 3. Both channel counts are divisible by `groups` (`PtxGenError::GenerationFailed`).
fn validate(&self) -> Result<(), PtxGenError> {
    let supported = matches!(self.float_type, PtxType::F16 | PtxType::F32 | PtxType::F64);
    if !supported {
        let got = self.float_type.as_ptx_str();
        return Err(PtxGenError::InvalidType(format!(
            "convolution requires F16, F32, or F64, got {got}"
        )));
    }

    // Evaluated eagerly; the first violated entry (in table order) is reported.
    let zero_checks: [(bool, &str); 6] = [
        (self.in_channels == 0, "in_channels must be > 0"),
        (self.out_channels == 0, "out_channels must be > 0"),
        (
            self.kernel_h == 0 || self.kernel_w == 0,
            "kernel dimensions must be > 0",
        ),
        (self.stride_h == 0 || self.stride_w == 0, "stride must be > 0"),
        (
            self.dilation_h == 0 || self.dilation_w == 0,
            "dilation must be > 0",
        ),
        (self.groups == 0, "groups must be > 0"),
    ];
    if let Some(&(_, message)) = zero_checks.iter().find(|(violated, _)| *violated) {
        return Err(PtxGenError::GenerationFailed(message.to_string()));
    }

    // `groups` is known to be non-zero here, so the remainders are safe.
    let groups = self.groups;
    for (label, count) in [
        ("in_channels", self.in_channels),
        ("out_channels", self.out_channels),
    ] {
        if count % groups != 0 {
            return Err(PtxGenError::GenerationFailed(format!(
                "{label} ({count}) must be divisible by groups ({groups})"
            )));
        }
    }
    Ok(())
}
/// Appends the PTX module header to `ptx`.
///
/// The header is the `.version` and `.target` lines taken from `sm_version`, then
/// `.address_size 64`, then one blank line.
fn write_header(&self, ptx: &mut String) -> Result<(), PtxGenError> {
    let version = self.sm_version.ptx_version();
    let target = self.sm_version.as_ptx_str();
    writeln!(
        ptx,
        ".version {version}\n.target {target}\n.address_size 64\n"
    )
    .map_err(PtxGenError::FormatError)
}
/// Generates a PTX `im2col` kernel that unfolds input patches into a column matrix.
///
/// Kernel parameters: `input` and `col` (u64 pointers), then `batch_size`, `in_h`, `in_w`,
/// `out_h` and `out_w` (u32).
///
/// Thread mapping: `%ctaid.x * %ntid.x + %tid.x` is the output pixel index
/// `out_y * out_w + out_x`, and `%ctaid.y` is the batch index. A thread exits immediately if
/// its pixel index is `>= out_h * out_w` or its batch index is `>= batch_size`.
///
/// The input is read as `[batch][in_channels][in_h][in_w]`. For each batch element the column
/// buffer is `[channels_per_group * kernel_h * kernel_w][out_h * out_w]`. Row
/// `(c * kernel_h + ky) * kernel_w + kx` holds tap `(c, ky, kx)`, and taps that land in the
/// padding are written as zero. Stride, padding and dilation are emitted as immediates.
///
/// F16 values are loaded, stored and zeroed through `.b16`, because PTX `ld`/`st`/`mov` have
/// no `.f16` form. F32 and F64 use their float type directly.
///
/// NOTE(review): only channels `0..in_channels / groups` are extracted, so with `groups > 1`
/// this produces the column matrix of the first group only.
///
/// # Errors
///
/// Returns the validation error if the template is invalid, or `PtxGenError::FormatError`
/// if appending to the output string fails.
#[allow(clippy::too_many_lines)]
pub fn generate_im2col_kernel(&self) -> Result<String, PtxGenError> {
    self.validate()?;
    let ty = self.ty();
    let byte_size = self.byte_size();
    // PTX has no `.f16` variant of ld/st/mov; `.b16` is bit-compatible with `.f16` registers.
    let is_f16 = matches!(self.float_type, PtxType::F16);
    let mem_ty = if is_f16 { ".b16" } else { ty };
    let zero_lit = if is_f16 { "0x0000" } else { self.zero_lit() };
    let kernel_name = self.kernel_name("im2col");
    let cpg = self.channels_per_group();
    let kh = self.kernel_h;
    let kw = self.kernel_w;
    let sh = self.stride_h;
    let sw = self.stride_w;
    let ph = self.pad_h;
    let pw = self.pad_w;
    let dh = self.dilation_h;
    let dw = self.dilation_w;
    // Host-side products in u64 so large configurations cannot overflow (a debug panic).
    let col_row_len = u64::from(cpg) * u64::from(kh) * u64::from(kw);
    let in_ch_bytes = u64::from(self.in_channels) * byte_size as u64;
    let mut ptx = String::with_capacity(8192);
    self.write_header(&mut ptx)?;
    writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_col,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_batch_size,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_in_h,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_in_w,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_out_h,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_out_w").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .b32 %r<48>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .b64 %rd<24>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg {ty} %val<4>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Thread and block indexing").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r1, %ctaid.x;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r2, %ntid.x;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mad.lo.u32 %r3, %r1, %r2, %r0; // out_pixel_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r4, %ctaid.y; // batch_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Load parameters").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd0, [%param_input];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd1, [%param_col];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r5, [%param_batch_size];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r6, [%param_in_h];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r7, [%param_in_w];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r8, [%param_out_h];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r9, [%param_out_w];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Bounds check").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u32 %r10, %r8, %r9; // total_out_pixels = out_h * out_w"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p0, %r3, %r10;").map_err(PtxGenError::FormatError)?;
    // Also guard the batch index so an oversized grid cannot write past the col buffer.
    writeln!(ptx, " setp.ge.u32 %p1, %r4, %r5; // batch_idx >= batch_size")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p0, %p0, %p1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p0 bra $IM2COL_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " // Decompose output pixel index into (out_y, out_x)"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " div.u32 %r11, %r3, %r9; // out_y").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " rem.u32 %r12, %r3, %r9; // out_x").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Compute base input coordinates (signed)")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.s32 %r13, %r11, {sh};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " sub.s32 %r13, %r13, {ph}; // in_y_base")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.s32 %r14, %r12, {sw};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " sub.s32 %r14, %r14, {pw}; // in_x_base")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Input base address for batch element")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd2, %r4;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u32 %r15, %r6, %r7; // in_h * in_w = spatial_size"
    )
    .map_err(PtxGenError::FormatError)?;
    // Per-image byte stride computed in 64 bits: channels * spatial can exceed u32.
    writeln!(ptx, " cvt.u64.u32 %rd3, %r15;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u64 %rd3, %rd3, {in_ch_bytes}; // in_channels * spatial_size * elem_bytes"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mad.lo.u64 %rd4, %rd2, %rd3, %rd0; // input_batch_ptr"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Col output base address").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " cvt.u64.u32 %rd5, %r10; // total_out_pixels as u64"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u64 %rd6, %rd5, {col_row_len}; // total_out_pixels * col_row_len"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u64 %rd6, %rd6, {byte_size}; // batch stride in col"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mad.lo.u64 %rd7, %rd2, %rd6, %rd1; // col + batch_offset"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd8, %r3;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd8, %rd8, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd7, %rd7, %rd8; // col_pixel_ptr")
        .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u64 %rd9, %rd5, {byte_size}; // col_row_stride"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Im2col extraction loop: c, ky, kx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r17, 0; // c").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u64 %rd10, %rd7; // running col pointer")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_C_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p1, %r17, {cpg};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p1 bra $IM2COL_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd11, %r17;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd12, %r15; // spatial_size")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd11, %rd11, %rd12;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd13, %rd4, %rd11; // channel_base_ptr")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r18, 0; // ky").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_KY_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p2, %r18, {kh};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p2 bra $IM2COL_KY_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mad.lo.s32 %r19, %r18, {dh}, %r13; // in_y")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r20, 0; // kx").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_KX_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p3, %r20, {kw};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p3 bra $IM2COL_KX_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mad.lo.s32 %r21, %r20, {dw}, %r14; // in_x")
        .map_err(PtxGenError::FormatError)?;
    // Padding test: in_y or in_x outside [0, in_h) x [0, in_w).
    writeln!(ptx, " setp.lt.s32 %p4, %r19, 0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.s32 %p5, %r19, %r6;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p4, %p4, %p5;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.lt.s32 %p5, %r21, 0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p4, %p4, %p5;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.s32 %p5, %r21, %r7;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p4, %p4, %p5;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p4 bra $IM2COL_PAD;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mad.lo.s32 %r22, %r19, %r7, %r21; // in_y * in_w + in_x"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd14, %r22;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd14, %rd14, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd15, %rd13, %rd14;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{mem_ty} %val0, [%rd15];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $IM2COL_STORE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_PAD:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov{mem_ty} %val0, {zero_lit};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_STORE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " st.global{mem_ty} [%rd10], %val0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd10, %rd10, %rd9;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r20, %r20, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $IM2COL_KX_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_KX_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r18, %r18, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $IM2COL_KY_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_KY_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r17, %r17, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $IM2COL_C_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$IM2COL_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
    Ok(ptx)
}
/// Generates a direct (non-im2col) 2-D convolution kernel with one thread per output element.
///
/// Kernel parameters: `input`, `weight`, `output` and `bias` (u64 pointers), then
/// `batch_size`, `in_h`, `in_w`, `out_h` and `out_w` (u32).
///
/// Thread mapping: `%ctaid.x * %ntid.x + %tid.x` is `out_x`, `%ctaid.y` is `out_y`, and
/// `%ctaid.z` is `batch_idx * out_channels + oc_idx`. A thread exits immediately if it falls
/// outside `out_w`, `out_h` or `batch_size`.
///
/// Layouts addressed by the generated code:
/// - input `[batch][in_channels][in_h][in_w]`
/// - weight `[out_channels][in_channels / groups][kernel_h][kernel_w]`
/// - bias `[out_channels]`; a null `bias` pointer skips the bias add
/// - output `[batch][out_channels][out_h][out_w]`
///
/// Taps that fall into the padding contribute nothing to the sum. Byte offsets are computed
/// in 64 bits.
///
/// The kernel uses no shared memory. Earlier versions declared an unused `.shared` buffer
/// that grew with `kh * kw * channels_per_group` and could exceed the static shared-memory
/// limit. F16 values are loaded, stored and zeroed through `.b16`, because PTX `ld`/`st`/`mov`
/// have no `.f16` form.
///
/// # Errors
///
/// Returns the validation error if the template is invalid, or `PtxGenError::FormatError`
/// if appending to the output string fails.
#[allow(clippy::too_many_lines)]
pub fn generate_direct_conv_kernel(&self) -> Result<String, PtxGenError> {
    self.validate()?;
    let ty = self.ty();
    let byte_size = self.byte_size();
    // PTX has no `.f16` variant of ld/st/mov; `.b16` is bit-compatible with `.f16` registers.
    let is_f16 = matches!(self.float_type, PtxType::F16);
    let mem_ty = if is_f16 { ".b16" } else { ty };
    let zero_lit = if is_f16 { "0x0000" } else { self.zero_lit() };
    let kernel_name = self.kernel_name("direct");
    let cpg = self.channels_per_group();
    let kh = self.kernel_h;
    let kw = self.kernel_w;
    let sh = self.stride_h;
    let sw = self.stride_w;
    let ph = self.pad_h;
    let pw = self.pad_w;
    let dh = self.dilation_h;
    let dw = self.dilation_w;
    let groups = self.groups;
    let oc = self.out_channels;
    let ocpg = self.out_channels_per_group();
    // Host-side products in u64 so large configurations cannot overflow (a debug panic).
    let in_ch_bytes = u64::from(self.in_channels) * byte_size as u64;
    let weight_per_oc = u64::from(cpg) * u64::from(kh) * u64::from(kw) * byte_size as u64;
    let mut ptx = String::with_capacity(8192);
    self.write_header(&mut ptx)?;
    writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_weight,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_output,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_bias,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_batch_size,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_in_h,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_in_w,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_out_h,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_out_w").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .b32 %r<64>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .b64 %rd<32>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg {ty} %f<16>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Thread and block indexing").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r0, %tid.x; // thread_x").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r1, %ctaid.x; // block_x")
        .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mov.u32 %r2, %ctaid.y; // block_y (maps to out_y)"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mov.u32 %r3, %ctaid.z; // combined batch*oc index"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r4, %ntid.x;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mad.lo.u32 %r5, %r1, %r4, %r0; // out_x")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " div.u32 %r6, %r3, {oc}; // batch_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " rem.u32 %r7, %r3, {oc}; // oc_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Load parameters").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd0, [%param_input];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd1, [%param_weight];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd2, [%param_output];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd3, [%param_bias];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r8, [%param_batch_size];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r9, [%param_in_h];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r10, [%param_in_w];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r11, [%param_out_h];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r12, [%param_out_w];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Bounds check").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p0, %r5, %r12;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p1, %r2, %r11;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p0, %p0, %p1;").map_err(PtxGenError::FormatError)?;
    // Guard the batch index too, so an oversized grid.z cannot write past the output.
    writeln!(ptx, " setp.ge.u32 %p1, %r6, %r8; // batch_idx >= batch_size")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p0, %p0, %p1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p0 bra $DIRECT_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Determine group and channel range")
        .map_err(PtxGenError::FormatError)?;
    if groups > 1 {
        writeln!(ptx, " div.u32 %r13, %r7, {ocpg}; // group_idx")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mul.lo.u32 %r14, %r13, {cpg}; // c_start")
            .map_err(PtxGenError::FormatError)?;
    } else {
        writeln!(ptx, " mov.u32 %r13, 0; // group_idx")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " mov.u32 %r14, 0; // c_start").map_err(PtxGenError::FormatError)?;
    }
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Input base for batch element").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u32 %r15, %r9, %r10; // in_h * in_w")
        .map_err(PtxGenError::FormatError)?;
    // Per-image byte stride computed in 64 bits: C * H * W can exceed u32.
    writeln!(ptx, " cvt.u64.u32 %rd4, %r15;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.lo.u64 %rd4, %rd4, {in_ch_bytes}; // C * H * W * elem_bytes"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd5, %r6;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mad.lo.u64 %rd6, %rd5, %rd4, %rd0; // input_batch_ptr"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Weight base for output channel").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd7, %r7;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd7, %rd7, {weight_per_oc};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd8, %rd1, %rd7; // weight_oc_ptr")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Initialize accumulator").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov{mem_ty} %f0, {zero_lit}; // acc").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Compute input base coordinates (signed)")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.s32 %r17, %r2, {sh};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " sub.s32 %r17, %r17, {ph}; // in_y_base")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.s32 %r18, %r5, {sw};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " sub.s32 %r18, %r18, {pw}; // in_x_base")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Convolution accumulation loop").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r19, 0; // c_local (relative to group)")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r20, 0; // weight_offset_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_C_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p2, %r19, {cpg};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p2 bra $DIRECT_BIAS;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r21, %r14, %r19; // c_global")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd9, %r21;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd10, %r15; // spatial_size")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd9, %rd9, %rd10;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd11, %rd6, %rd9; // input_channel_ptr")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r22, 0; // ky").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_KY_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p3, %r22, {kh};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p3 bra $DIRECT_KY_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mad.lo.s32 %r23, %r22, {dh}, %r17; // in_y")
        .map_err(PtxGenError::FormatError)?;
    // %p4: the whole kernel row is in the vertical padding.
    writeln!(ptx, " setp.lt.s32 %p4, %r23, 0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.s32 %p5, %r23, %r9;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " or.pred %p4, %p4, %p5;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r24, 0; // kx").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_KX_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p6, %r24, {kw};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p6 bra $DIRECT_KX_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p4 bra $DIRECT_SKIP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mad.lo.s32 %r25, %r24, {dw}, %r18; // in_x")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.lt.s32 %p7, %r25, 0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p7 bra $DIRECT_SKIP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.s32 %p7, %r25, %r10;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p7 bra $DIRECT_SKIP;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mad.lo.s32 %r26, %r23, %r10, %r25; // in_y * in_w + in_x"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd12, %r26;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd12, %rd12, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd13, %rd11, %rd12;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{mem_ty} %f1, [%rd13]; // input_val")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd14, %r20; // weight_offset_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd14, %rd14, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd15, %rd8, %rd14;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{mem_ty} %f2, [%rd15]; // weight_val")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " fma.rn{ty} %f0, %f1, %f2, %f0;").map_err(PtxGenError::FormatError)?;
    // The weight index advances for every (c, ky, kx), including skipped padding taps.
    writeln!(ptx, "$DIRECT_SKIP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r20, %r20, 1; // weight_offset_idx++")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r24, %r24, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $DIRECT_KX_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_KX_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r22, %r22, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $DIRECT_KY_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_KY_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r19, %r19, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $DIRECT_C_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_BIAS:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.eq.u64 %p2, %rd3, 0; // bias == null?")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p2 bra $DIRECT_STORE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd16, %r7; // oc_idx")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd16, %rd16, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd17, %rd3, %rd16;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{mem_ty} %f3, [%rd17]; // bias_val")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add{ty} %f0, %f0, %f3;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_STORE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Store output value").map_err(PtxGenError::FormatError)?;
    // Output element index = (batch_idx * oc + oc_idx) * out_plane + out_y * out_w + out_x.
    // %r3 (= %ctaid.z) already equals batch_idx * oc + oc_idx, and the plane offset is
    // widened to 64 bits so large batch * oc * H * W products cannot wrap.
    writeln!(ptx, " mul.lo.u32 %r27, %r11, %r12; // out_h * out_w")
        .map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mad.lo.u32 %r28, %r2, %r12, %r5; // out_y * out_w + out_x"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd18, %r28;").map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " mul.wide.u32 %rd20, %r3, %r27; // (batch * oc + oc_idx) * out_plane"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd18, %rd18, %rd20;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd18, %rd18, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd19, %rd2, %rd18;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " st.global{mem_ty} [%rd19], %f0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$DIRECT_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
    Ok(ptx)
}
#[allow(clippy::too_many_lines)]
pub fn generate_1x1_conv_kernel(&self) -> Result<String, PtxGenError> {
self.validate()?;
let ty = self.ty();
let byte_size = self.byte_size();
let kernel_name = self.kernel_name("1x1");
let cpg = self.channels_per_group();
let zero_lit = self.zero_lit();
let groups = self.groups;
let oc = self.out_channels;
let ocpg = self.out_channels_per_group();
let in_ch = self.in_channels;
let mut ptx = String::with_capacity(4096);
self.write_header(&mut ptx)?;
writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_weight,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_output,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_bias,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_batch_size,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_spatial_size").map_err(PtxGenError::FormatError)?;
writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b32 %r<48>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b64 %rd<24>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg {ty} %f<8>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Thread indexing").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r1, %ctaid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r2, %ntid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.u32 %r3, %r1, %r2, %r0; // spatial_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r4, %ctaid.y; // combined (batch, oc)")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " div.u32 %r5, %r4, {oc}; // batch_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " rem.u32 %r6, %r4, {oc}; // oc_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd0, [%param_input];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd1, [%param_weight];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd2, [%param_output];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd3, [%param_bias];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r7, [%param_batch_size];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r8, [%param_spatial_size];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p0, %r3, %r8;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p0 bra $CONV1X1_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
if groups > 1 {
writeln!(ptx, " div.u32 %r9, %r6, {ocpg}; // group_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r10, %r9, {cpg}; // c_start")
.map_err(PtxGenError::FormatError)?;
} else {
writeln!(ptx, " mov.u32 %r10, 0; // c_start").map_err(PtxGenError::FormatError)?;
}
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Input base for batch").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r11, %r8, {in_ch}; // C * spatial")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd4, %r11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd4, %rd4, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd5, %r5;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u64 %rd6, %rd5, %rd4, %rd0; // input_batch_ptr"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
let weight_per_oc = (cpg as usize) * byte_size;
writeln!(ptx, " // Weight base for oc").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd7, %r6;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd7, %rd7, {weight_per_oc};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd8, %rd1, %rd7; // weight_oc_ptr")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Dot product over channels_per_group")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov{ty} %f0, {zero_lit}; // acc").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r12, 0; // c_local").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$CONV1X1_C_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p1, %r12, {cpg};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p1 bra $CONV1X1_BIAS;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r13, %r10, %r12; // c_global")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r14, %r13, %r8, %r3; // c_global * spatial + idx"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd9, %r14;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd6, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f1, [%rd10]; // input_val")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd11, %r12;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd12, %rd8, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f2, [%rd12]; // weight_val")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " fma.rn{ty} %f0, %f1, %f2, %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r12, %r12, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $CONV1X1_C_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$CONV1X1_BIAS:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.eq.u64 %p2, %rd3, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p2 bra $CONV1X1_STORE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd13, %r6;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd13, %rd13, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd14, %rd3, %rd13;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f3, [%rd14]; // bias_val")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add{ty} %f0, %f0, %f3;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$CONV1X1_STORE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r15, %r8, {oc}; // oc * spatial")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r16, %r5, %r15, 0; // batch * (oc * spatial)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r16, %r6, %r8, %r16; // + oc_idx * spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r16, %r16, %r3; // + spatial_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd15, %r16;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd15, %rd15, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd16, %rd2, %rd15;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.global{ty} [%rd16], %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$CONV1X1_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
Ok(ptx)
}
#[allow(clippy::too_many_lines)]
pub fn generate_backward_data_kernel(&self) -> Result<String, PtxGenError> {
self.validate()?;
let ty = self.ty();
let byte_size = self.byte_size();
let kernel_name = self.kernel_name("bwd_data");
let cpg = self.channels_per_group();
let kh = self.kernel_h;
let kw = self.kernel_w;
let sh = self.stride_h;
let sw = self.stride_w;
let ph = self.pad_h;
let pw = self.pad_w;
let dh = self.dilation_h;
let dw = self.dilation_w;
let zero_lit = self.zero_lit();
let groups = self.groups;
let oc = self.out_channels;
let ocpg = self.out_channels_per_group();
let in_ch = self.in_channels;
let mut ptx = String::with_capacity(8192);
self.write_header(&mut ptx)?;
writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_grad_output,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_weight,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_grad_input,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_batch_size,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_in_h,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_in_w,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_out_h,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_out_w").map_err(PtxGenError::FormatError)?;
writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b32 %r<64>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b64 %rd<32>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg {ty} %f<8>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Thread indexing").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r1, %ctaid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r2, %ntid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.u32 %r3, %r1, %r2, %r0; // in_x")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r4, %ctaid.y; // in_y").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mov.u32 %r5, %ctaid.z; // combined (batch, channel)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " div.u32 %r6, %r5, {in_ch}; // batch_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" rem.u32 %r7, %r5, {in_ch}; // c_idx (input channel)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd0, [%param_grad_output];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd1, [%param_weight];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd2, [%param_grad_input];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r8, [%param_batch_size];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r9, [%param_in_h];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r10, [%param_in_w];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r11, [%param_out_h];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r12, [%param_out_w];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p0, %r3, %r10; // in_x >= in_w?")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p1, %r4, %r9; // in_y >= in_h?")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " or.pred %p0, %p0, %p1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p0 bra $BWD_DATA_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
if groups > 1 {
writeln!(ptx, " div.u32 %r13, %r7, {cpg}; // group_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" rem.u32 %r14, %r7, {cpg}; // c_local (within group)"
)
.map_err(PtxGenError::FormatError)?;
} else {
writeln!(ptx, " mov.u32 %r13, 0; // group_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r14, %r7; // c_local = c_idx")
.map_err(PtxGenError::FormatError)?;
}
writeln!(ptx, " mul.lo.u32 %r15, %r13, {ocpg}; // oc_start")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Grad output batch base").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u32 %r16, %r11, %r12; // out_spatial = out_h * out_w"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r17, %r16, {oc}; // oc * out_spatial")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd3, %r17;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd3, %rd3, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd4, %r6;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u64 %rd5, %rd4, %rd3, %rd0; // go_batch_ptr"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov{ty} %f0, {zero_lit}; // grad_acc")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Loop over output channels and kernel positions")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r18, 0; // oc_local").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_OC_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p2, %r18, {ocpg};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p2 bra $BWD_DATA_STORE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r19, %r15, %r18; // oc_global")
.map_err(PtxGenError::FormatError)?;
let kernel_spatial = (kh * kw) as usize * byte_size;
writeln!(
ptx,
" mad.lo.u32 %r20, %r19, {cpg}, %r14; // oc_global * cpg + c_local"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd6, %r20;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd6, %rd6, {kernel_spatial};")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" add.u64 %rd7, %rd1, %rd6; // weight_ptr for this (oc, c)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd8, %r19;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd9, %r16; // out_spatial")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd8, %rd8, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd8, %rd8, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd5, %rd8; // go_channel_ptr")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r21, 0; // ky").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_KY_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p3, %r21, {kh};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p3 bra $BWD_DATA_KY_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r22, 0; // kx").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_KX_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p4, %r22, {kw};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p4 bra $BWD_DATA_KX_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Compute corresponding output position")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.s32 %r23, %r21, {dh}, 0; // ky * dh")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.s32 %r24, %r4, {ph}; // in_y + pad_h")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" sub.s32 %r24, %r24, %r23; // in_y + pad_h - ky * dh"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.lt.s32 %p5, %r24, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p5 bra $BWD_DATA_KX_NEXT;").map_err(PtxGenError::FormatError)?;
if sh > 1 {
writeln!(ptx, " rem.u32 %r25, %r24, {sh};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ne.u32 %p5, %r25, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p5 bra $BWD_DATA_KX_NEXT;").map_err(PtxGenError::FormatError)?;
}
writeln!(ptx, " div.u32 %r26, %r24, {sh}; // out_y")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p5, %r26, %r11; // out_y >= out_h?")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p5 bra $BWD_DATA_KX_NEXT;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.s32 %r27, %r22, {dw}, 0; // kx * dw")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.s32 %r28, %r3, {pw}; // in_x + pad_w")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" sub.s32 %r28, %r28, %r27; // in_x + pad_w - kx * dw"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.lt.s32 %p6, %r28, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p6 bra $BWD_DATA_KX_NEXT;").map_err(PtxGenError::FormatError)?;
if sw > 1 {
writeln!(ptx, " rem.u32 %r29, %r28, {sw};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ne.u32 %p6, %r29, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p6 bra $BWD_DATA_KX_NEXT;").map_err(PtxGenError::FormatError)?;
}
writeln!(ptx, " div.u32 %r30, %r28, {sw}; // out_x")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p6, %r30, %r12; // out_x >= out_w?")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p6 bra $BWD_DATA_KX_NEXT;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r31, %r26, %r12, %r30; // out_y * out_w + out_x"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd11, %r31;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd12, %rd10, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f1, [%rd12]; // go_val")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Load weight at (ky, kx)").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r32, %r21, {kw}, %r22; // ky * kw + kx"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd13, %r32;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd13, %rd13, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd14, %rd7, %rd13;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f2, [%rd14]; // w_val")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " fma.rn{ty} %f0, %f1, %f2, %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_KX_NEXT:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r22, %r22, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $BWD_DATA_KX_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_KX_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r21, %r21, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $BWD_DATA_KY_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_KY_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r18, %r18, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $BWD_DATA_OC_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_STORE:").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u32 %r33, %r9, %r10; // in_spatial = in_h * in_w"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u32 %r34, %r33, {in_ch}; // C * in_spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.u32 %r35, %r6, %r34, 0; // batch * plane")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r35, %r7, %r33, %r35; // + c * spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r35, %r4, %r10, %r35; // + in_y * in_w"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r35, %r35, %r3; // + in_x")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd15, %r35;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd15, %rd15, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd16, %rd2, %rd15;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.global{ty} [%rd16], %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_DATA_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
Ok(ptx)
}
#[allow(clippy::too_many_lines)]
pub fn generate_backward_filter_kernel(&self) -> Result<String, PtxGenError> {
self.validate()?;
let ty = self.ty();
let byte_size = self.byte_size();
let kernel_name = self.kernel_name("bwd_filter");
let cpg = self.channels_per_group();
let kh = self.kernel_h;
let kw = self.kernel_w;
let sh = self.stride_h;
let sw = self.stride_w;
let ph = self.pad_h;
let pw = self.pad_w;
let dh = self.dilation_h;
let dw = self.dilation_w;
let zero_lit = self.zero_lit();
let groups = self.groups;
let oc = self.out_channels;
let ocpg = self.out_channels_per_group();
let in_ch = self.in_channels;
let mut ptx = String::with_capacity(8192);
self.write_header(&mut ptx)?;
writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_grad_output,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_grad_weight,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_batch_size,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_in_h,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_in_w,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_out_h,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_out_w").map_err(PtxGenError::FormatError)?;
writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b32 %r<64>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b64 %rd<32>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg {ty} %f<8>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
let weight_spatial = kh * kw;
writeln!(
ptx,
" // Thread indexing (one thread per weight element)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r1, %ctaid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r2, %ntid.x;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r3, %r1, %r2, %r0; // spatial_idx (ky*kw flat)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r4, %ctaid.y; // combined (oc, c_local)")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p0, %r3, {weight_spatial};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p0 bra $BWD_FILTER_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " div.u32 %r5, %r3, {kw}; // ky").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " rem.u32 %r6, %r3, {kw}; // kx").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " div.u32 %r7, %r4, {cpg}; // oc_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " rem.u32 %r8, %r4, {cpg}; // c_local")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p1, %r7, {oc};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p1 bra $BWD_FILTER_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
if groups > 1 {
writeln!(ptx, " div.u32 %r9, %r7, {ocpg}; // group_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.u32 %r10, %r9, {cpg}, %r8; // c_global")
.map_err(PtxGenError::FormatError)?;
} else {
writeln!(ptx, " mov.u32 %r10, %r8; // c_global = c_local")
.map_err(PtxGenError::FormatError)?;
}
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd0, [%param_input];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd1, [%param_grad_output];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd2, [%param_grad_weight];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r11, [%param_batch_size];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r12, [%param_in_h];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r13, [%param_in_w];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r14, [%param_out_h];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r15, [%param_out_w];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u32 %r16, %r12, %r13; // in_spatial = in_h * in_w"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u32 %r17, %r14, %r15; // out_spatial = out_h * out_w"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u32 %r18, %r16, {in_ch}; // C * in_spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r19, %r17, {oc}; // oc * out_spatial")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov{ty} %f0, {zero_lit}; // grad_w_acc")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r20, 0; // batch").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_BATCH_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p2, %r20, %r11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p2 bra $BWD_FILTER_STORE;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r21, %r20, %r18, 0; // batch * C * in_spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r21, %r10, %r16, %r21; // + c_global * in_spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd3, %r21;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd3, %rd3, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd4, %rd0, %rd3; // input_ch_ptr")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r22, %r20, %r19, 0; // batch * oc * out_spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r22, %r7, %r17, %r22; // + oc_idx * out_spatial"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd5, %r22;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd5, %rd5, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd6, %rd1, %rd5; // go_oc_ptr")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r23, 0; // out_y").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_OY_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p3, %r23, %r14;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p3 bra $BWD_FILTER_OY_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r24, 0; // out_x").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_OX_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p4, %r24, %r15;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p4 bra $BWD_FILTER_OX_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.s32 %r25, %r23, {sh}, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " sub.s32 %r25, %r25, {ph};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.s32 %r25, %r5, {dh}, %r25; // in_y")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.s32 %r26, %r24, {sw}, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " sub.s32 %r26, %r26, {pw};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mad.lo.s32 %r26, %r6, {dw}, %r26; // in_x")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.lt.s32 %p5, %r25, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.s32 %p6, %r25, %r12;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " or.pred %p5, %p5, %p6;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.lt.s32 %p6, %r26, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " or.pred %p5, %p5, %p6;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.s32 %p6, %r26, %r13;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " or.pred %p5, %p5, %p6;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p5 bra $BWD_FILTER_OX_NEXT;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.s32 %r27, %r25, %r13, %r26; // in_y * in_w + in_x"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd7, %r27;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd7, %rd7, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd8, %rd4, %rd7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f1, [%rd8]; // input_val")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mad.lo.u32 %r28, %r23, %r15, %r24; // out_y * out_w + out_x"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd9, %r28;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd6, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f2, [%rd10]; // go_val")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " fma.rn{ty} %f0, %f1, %f2, %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_OX_NEXT:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r24, %r24, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $BWD_FILTER_OX_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_OX_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r23, %r23, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $BWD_FILTER_OY_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_OY_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r20, %r20, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $BWD_FILTER_BATCH_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_STORE:").map_err(PtxGenError::FormatError)?;
let weight_per_oc_total = (cpg * kh * kw) as usize * byte_size;
writeln!(ptx, " // Store weight gradient").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd11, %r7; // oc_idx")
.map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" mul.lo.u64 %rd11, %rd11, {weight_per_oc_total}; // oc_idx * (cpg*kh*kw*bs)"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd12, %r8; // c_local")
.map_err(PtxGenError::FormatError)?;
let kernel_spatial_bytes = (kh * kw) as usize * byte_size;
writeln!(
ptx,
" mul.lo.u64 %rd12, %rd12, {kernel_spatial_bytes}; // c_local * kh*kw*bs"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd11, %rd11, %rd12;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd13, %r3;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd13, %rd13, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd11, %rd11, %rd13;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd14, %rd2, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.global{ty} [%rd14], %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BWD_FILTER_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
Ok(ptx)
}
}
// Unit tests for this module, compiled only under `cfg(test)` and loaded from
// `convolution_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "convolution_tests.rs"]
mod tests;