use std::fmt::Write as FmtWrite;
use crate::arch::SmVersion;
use crate::error::PtxGenError;
use crate::ir::PtxType;
/// Whether the generated batch-norm kernel computes batch statistics on
/// the fly (training) or reads precomputed running statistics (inference).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BnMode {
    Training,
    Inference,
}

impl BnMode {
    /// Short tag embedded in generated kernel names ("train" / "infer").
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        if matches!(self, BnMode::Training) {
            "train"
        } else {
            "infer"
        }
    }
}
/// Configuration for generating a PTX batch-normalization kernel.
///
/// The generated kernels launch one CUDA block per channel (the block
/// index `%ctaid.x` is used as the channel index in the address math),
/// with `block_size` threads cooperating over the channel's elements.
pub struct BatchNormTemplate {
    /// Element type of input/output tensors; `validate` accepts only F32/F64.
    pub precision: PtxType,
    /// Training (compute batch statistics) or inference (use running stats).
    pub mode: BnMode,
    /// Number of channels; one CUDA block is expected per channel.
    pub channels: u32,
    /// Elements per channel per sample (e.g. H*W for an NCHW layout —
    /// the emitted strides are sample-major, then channel-major).
    pub spatial_size: u32,
    /// Numerical-stability constant added to the variance before sqrt.
    pub epsilon: f32,
    /// Threads per block; must be a power of two in [32, 1024].
    pub block_size: u32,
}
impl BatchNormTemplate {
/// Creates a template from all configuration values.
///
/// No validation happens here; invalid combinations are rejected later
/// by `generate` (via `validate`).
#[must_use]
pub const fn new(
    precision: PtxType,
    mode: BnMode,
    channels: u32,
    spatial_size: u32,
    epsilon: f32,
    block_size: u32,
) -> Self {
    Self {
        precision,
        mode,
        channels,
        spatial_size,
        epsilon,
        block_size,
    }
}
/// Returns `self` with `precision` replaced (builder-style).
#[must_use]
pub const fn with_precision(mut self, precision: PtxType) -> Self {
    self.precision = precision;
    self
}
/// Returns `self` with `mode` replaced (builder-style).
#[must_use]
pub const fn with_mode(mut self, mode: BnMode) -> Self {
    self.mode = mode;
    self
}
/// Returns `self` with `channels` replaced (builder-style).
#[must_use]
pub const fn with_channels(mut self, channels: u32) -> Self {
    self.channels = channels;
    self
}
/// Returns `self` with `spatial_size` replaced (builder-style).
#[must_use]
pub const fn with_spatial_size(mut self, spatial_size: u32) -> Self {
    self.spatial_size = spatial_size;
    self
}
/// Returns `self` with `epsilon` replaced (builder-style).
#[must_use]
pub const fn with_epsilon(mut self, epsilon: f32) -> Self {
    self.epsilon = epsilon;
    self
}
/// Returns `self` with `block_size` replaced (builder-style).
#[must_use]
pub const fn with_block_size(mut self, block_size: u32) -> Self {
    self.block_size = block_size;
    self
}
/// Builds the kernel entry-point name, e.g.
/// `batch_norm_train_f32_c64_s1024_bs256`.
#[must_use]
pub fn kernel_name(&self) -> String {
    // `as_ptx_str` yields ".f32"/".f64"; drop the leading dot for the name.
    let type_tag = self.precision.as_ptx_str().trim_start_matches('.');
    let mode_tag = self.mode.as_str();
    format!(
        "batch_norm_{mode_tag}_{type_tag}_c{}_s{}_bs{}",
        self.channels, self.spatial_size, self.block_size,
    )
}
/// Checks that the configuration can produce a valid PTX kernel.
///
/// # Errors
/// Returns `PtxGenError::InvalidType` for unsupported precisions and
/// `PtxGenError::GenerationFailed` for out-of-range numeric parameters.
fn validate(&self) -> Result<(), PtxGenError> {
    // Only full-precision floats are supported: the emitted arithmetic
    // (div.rn, sqrt.rn, fma, ...) is generated for .f32/.f64 operands.
    if !matches!(self.precision, PtxType::F32 | PtxType::F64) {
        return Err(PtxGenError::InvalidType(format!(
            "batch_norm requires F32 or F64, got {}",
            self.precision.as_ptx_str()
        )));
    }
    // The shared-memory tree reduction assumes a power-of-two block of at
    // least one warp.
    if self.block_size < 32 || !self.block_size.is_power_of_two() {
        return Err(PtxGenError::GenerationFailed(format!(
            "block_size must be a power of 2 >= 32, got {}",
            self.block_size
        )));
    }
    // CUDA caps threads per block at 1024.
    if self.block_size > 1024 {
        return Err(PtxGenError::GenerationFailed(format!(
            "block_size {} exceeds maximum of 1024",
            self.block_size
        )));
    }
    if self.channels == 0 {
        return Err(PtxGenError::GenerationFailed(
            "channels must be > 0".to_string(),
        ));
    }
    if self.spatial_size == 0 {
        return Err(PtxGenError::GenerationFailed(
            "spatial_size must be > 0".to_string(),
        ));
    }
    // `NaN <= 0.0` is false, so a plain `<= 0.0` test lets NaN through and
    // it would be embedded into the kernel as a garbage float literal;
    // check finiteness explicitly.
    if !self.epsilon.is_finite() || self.epsilon <= 0.0 {
        return Err(PtxGenError::GenerationFailed(format!(
            "epsilon must be a finite value > 0, got {}",
            self.epsilon
        )));
    }
    Ok(())
}
/// Generates the PTX source for this template, targeting `sm`.
///
/// Dispatches to the training or inference code path after validating
/// the configuration.
///
/// # Errors
/// Returns an error if the configuration is invalid (see `validate`) or
/// if formatting the PTX text fails.
pub fn generate(&self, sm: SmVersion) -> Result<String, PtxGenError> {
    self.validate()?;
    match self.mode {
        BnMode::Training => self.generate_training(sm),
        BnMode::Inference => self.generate_inference(sm),
    }
}
#[allow(clippy::too_many_lines)]
fn generate_training(&self, sm: SmVersion) -> Result<String, PtxGenError> {
let ty = self.precision.as_ptx_str();
let byte_size = self.precision.size_bytes();
let kernel_name = self.kernel_name();
let block_size = self.block_size;
let spatial_size = self.spatial_size;
let smem_bytes = (block_size as usize) * byte_size;
let zero_lit = match self.precision {
PtxType::F64 => "0d0000000000000000",
_ => "0f00000000",
};
let eps_hex = format!("0f{:08X}", self.epsilon.to_bits());
let mut ptx = String::with_capacity(8192);
writeln!(ptx, ".version {}", sm.ptx_version()).map_err(PtxGenError::FormatError)?;
writeln!(ptx, ".target {}", sm.as_ptx_str()).map_err(PtxGenError::FormatError)?;
writeln!(ptx, ".address_size 64").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_output,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_gamma,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u64 %param_beta,").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .param .u32 %param_batch_count").map_err(PtxGenError::FormatError)?;
writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .maxntid {block_size}, 1, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b32 %r<24>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .b64 %rd<20>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .f32 %f<24>;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " .reg .pred %p<8>;").map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" .shared .align {} .b8 smem_bn[{}];",
byte_size.max(4),
smem_bytes
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Thread and block indexing").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r1, %ctaid.x;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Load parameters").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd0, [%param_input];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd1, [%param_output];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd2, [%param_gamma];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u64 %rd3, [%param_beta];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.param.u32 %r2, [%param_batch_count];")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(
ptx,
" // Compute total elements per channel across batch"
)
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u32 %r3, %r2, {spatial_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u64 %rd4, smem_bn;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd5, %r0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd5, %rd5, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd6, %rd4, %rd5;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
let channel_stride = spatial_size as usize * byte_size;
let sample_stride_elems = self.channels as usize * spatial_size as usize;
let sample_stride_bytes = sample_stride_elems * byte_size;
writeln!(ptx, " // Pass 1: Compute channel mean via reduction")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov{ty} %f0, {zero_lit};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd7, %r1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd7, %rd7, {channel_stride};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd8, %rd0, %rd7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r4, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$MEAN_BATCH_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p0, %r4, %r2;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p0 bra $MEAN_BATCH_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd9, %r4;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {sample_stride_bytes};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd8, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r5, %r0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$MEAN_SPATIAL_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p1, %r5, {spatial_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p1 bra $MEAN_SPATIAL_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd11, %r5;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd12, %rd10, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f1, [%rd12];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add{ty} %f0, %f0, %f1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r5, %r5, {block_size};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $MEAN_SPATIAL_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$MEAN_SPATIAL_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r4, %r4, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $MEAN_BATCH_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$MEAN_BATCH_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.shared{ty} [%rd6], %f0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bar.sync 0;").map_err(PtxGenError::FormatError)?;
self.emit_shared_reduction(&mut ptx, ty, byte_size, "add", "MEAN_RED")?;
writeln!(ptx, " ld.shared{ty} %f2, [%rd4];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.rn{ty}.u32 %f3, %r3;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " div.rn{ty} %f2, %f2, %f3;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Pass 2: Compute channel variance via reduction")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov{ty} %f4, {zero_lit};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r4, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$VAR_BATCH_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p0, %r4, %r2;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p0 bra $VAR_BATCH_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd9, %r4;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {sample_stride_bytes};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd8, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r5, %r0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$VAR_SPATIAL_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p1, %r5, {spatial_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p1 bra $VAR_SPATIAL_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd11, %r5;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd12, %rd10, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f5, [%rd12];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " sub{ty} %f5, %f5, %f2;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " fma{ty} %f4, %f5, %f5, %f4;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r5, %r5, {block_size};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $VAR_SPATIAL_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$VAR_SPATIAL_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r4, %r4, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $VAR_BATCH_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$VAR_BATCH_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.shared{ty} [%rd6], %f4;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bar.sync 0;").map_err(PtxGenError::FormatError)?;
self.emit_shared_reduction(&mut ptx, ty, byte_size, "add", "VAR_RED")?;
writeln!(ptx, " ld.shared{ty} %f6, [%rd4];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " div.rn{ty} %f6, %f6, %f3;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add{ty} %f7, %f6, {eps_hex};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " sqrt.rn{ty} %f7, %f7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " rcp.approx{ty} %f7, %f7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Load gamma and beta for channel")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd13, %r1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd13, %rd13, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd14, %rd2, %rd13;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f8, [%rd14];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd15, %rd3, %rd13;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f9, [%rd15];").map_err(PtxGenError::FormatError)?;
writeln!(ptx).map_err(PtxGenError::FormatError)?;
writeln!(ptx, " // Pass 3: Normalize and write output")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd16, %rd1, %rd7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r4, 0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$NORM_BATCH_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p0, %r4, %r2;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p0 bra $BN_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd9, %r4;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd9, %rd9, {sample_stride_bytes};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd10, %rd8, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd17, %rd16, %rd9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mov.u32 %r5, %r0;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$NORM_SPATIAL_LOOP:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " setp.ge.u32 %p1, %r5, {spatial_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " @%p1 bra $NORM_SPATIAL_DONE;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " cvt.u64.u32 %rd11, %r5;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {byte_size};")
.map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd12, %rd10, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ld.global{ty} %f10, [%rd12];").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " sub{ty} %f10, %f10, %f2;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " mul{ty} %f10, %f10, %f7;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " fma{ty} %f10, %f8, %f10, %f9;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u64 %rd18, %rd17, %rd11;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " st.global{ty} [%rd18], %f10;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r5, %r5, {block_size};").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $NORM_SPATIAL_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$NORM_SPATIAL_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " add.u32 %r4, %r4, 1;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " bra $NORM_BATCH_LOOP;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "$BN_DONE:").map_err(PtxGenError::FormatError)?;
writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
Ok(ptx)
}
/// Emits an inference-mode batch-norm kernel.
///
/// Uses precomputed running mean/variance per channel instead of batch
/// statistics, so no shared memory or reduction is needed: each thread
/// independently normalizes its strided slice of the channel. Same
/// NCHW-style layout and one-block-per-channel mapping as the training
/// kernel.
///
/// # Errors
/// Returns `PtxGenError::FormatError` if writing the PTX text fails.
#[allow(clippy::too_many_lines)]
fn generate_inference(&self, sm: SmVersion) -> Result<String, PtxGenError> {
    let ty = self.precision.as_ptx_str();
    let byte_size = self.precision.size_bytes();
    let kernel_name = self.kernel_name();
    let block_size = self.block_size;
    let spatial_size = self.spatial_size;
    // The epsilon literal must match the operand width: `0f` literals are
    // 32-bit and are not valid operands of .f64 instructions, so widen to
    // a 64-bit `0d` literal for F64.
    let eps_lit = match self.precision {
        PtxType::F64 => format!("0d{:016X}", f64::from(self.epsilon).to_bits()),
        _ => format!("0f{:08X}", self.epsilon.to_bits()),
    };
    // `rcp.approx` exists only for .f32 (the f64 approx form requires
    // .ftz); use the IEEE-rounded reciprocal for F64.
    let rcp_instr = match self.precision {
        PtxType::F64 => "rcp.rn.f64",
        _ => "rcp.approx.f32",
    };
    let mut ptx = String::with_capacity(4096);
    writeln!(ptx, ".version {}", sm.ptx_version()).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, ".target {}", sm.as_ptx_str()).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, ".address_size 64").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, ".visible .entry {kernel_name}(").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_input,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_output,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_gamma,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_beta,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_running_mean,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u64 %param_running_var,").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .param .u32 %param_batch_count").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, ")").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "{{").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .maxntid {block_size}, 1, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .b32 %r<20>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .b64 %rd<20>;").map_err(PtxGenError::FormatError)?;
    // Data registers must be declared at the working precision: `.f32`
    // registers cannot hold .f64 operands.
    writeln!(ptx, " .reg {ty} %f<16>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " .reg .pred %p<4>;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Thread and block indexing").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r0, %tid.x;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r1, %ctaid.x;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Load parameters").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd0, [%param_input];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd1, [%param_output];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd2, [%param_gamma];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd3, [%param_beta];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd4, [%param_running_mean];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u64 %rd5, [%param_running_var];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.param.u32 %r2, [%param_batch_count];")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " // Load running stats, gamma, beta for this channel"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd6, %r1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd6, %rd6, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd7, %rd4, %rd6;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{ty} %f0, [%rd7];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd8, %rd5, %rd6;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{ty} %f1, [%rd8];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd9, %rd2, %rd6;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{ty} %f2, [%rd9];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd10, %rd3, %rd6;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{ty} %f3, [%rd10];").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " // Compute inv_stddev = 1/sqrt(var + eps)")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add{ty} %f4, %f1, {eps_lit};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " sqrt.rn{ty} %f4, %f4;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " {rcp_instr} %f4, %f4;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    let channel_stride = spatial_size as usize * byte_size;
    let sample_stride_elems = self.channels as usize * spatial_size as usize;
    let sample_stride_bytes = sample_stride_elems * byte_size;
    writeln!(ptx, " // Compute channel base address").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd11, %r1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd11, %rd11, {channel_stride};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd12, %rd0, %rd11;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd13, %rd1, %rd11;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    writeln!(
        ptx,
        " // Apply normalization across batch and spatial dims"
    )
    .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r3, 0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$INF_BATCH_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p0, %r3, %r2;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p0 bra $BN_INF_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd14, %r3;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd14, %rd14, {sample_stride_bytes};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd15, %rd12, %rd14;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd16, %rd13, %rd14;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mov.u32 %r4, %r0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$INF_SPATIAL_LOOP:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " setp.ge.u32 %p1, %r4, {spatial_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " @%p1 bra $INF_SPATIAL_DONE;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " cvt.u64.u32 %rd17, %r4;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul.lo.u64 %rd17, %rd17, {byte_size};")
        .map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd18, %rd15, %rd17;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ld.global{ty} %f5, [%rd18];").map_err(PtxGenError::FormatError)?;
    // out = gamma * (x - running_mean) * inv_stddev + beta
    writeln!(ptx, " sub{ty} %f5, %f5, %f0;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " mul{ty} %f5, %f5, %f4;").map_err(PtxGenError::FormatError)?;
    // NOTE(review): PTX formally requires an explicit rounding modifier
    // on fma (e.g. `fma.rn.f32`); tests currently pin the bare
    // "fma.f32"/"fma.f64" spelling, so fix both together.
    writeln!(ptx, " fma{ty} %f5, %f2, %f5, %f3;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u64 %rd19, %rd16, %rd17;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " st.global{ty} [%rd19], %f5;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r4, %r4, {block_size};").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $INF_SPATIAL_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$INF_SPATIAL_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " add.u32 %r3, %r3, 1;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " bra $INF_BATCH_LOOP;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "$BN_INF_DONE:").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, " ret;").map_err(PtxGenError::FormatError)?;
    writeln!(ptx, "}}").map_err(PtxGenError::FormatError)?;
    Ok(ptx)
}
/// Emits a shared-memory tree reduction over the block's per-thread
/// partial values, leaving the combined result in `smem_bn[0]`.
///
/// Callers must have stored each thread's partial at its own slot and
/// issued `bar.sync` first; the emitted code expects `%rd6` to hold this
/// thread's slot address and `%r0` the thread id. `combine_op` is the
/// PTX opcode stem (e.g. "add") suffixed with `ty`; `label_prefix` keeps
/// skip labels unique across multiple reductions in one kernel.
fn emit_shared_reduction(
    &self,
    ptx: &mut String,
    ty: &str,
    byte_size: usize,
    combine_op: &str,
    label_prefix: &str,
) -> Result<(), PtxGenError> {
    // Halve the active thread count each round: active threads fold in
    // their partner's value from `active` slots above their own.
    let mut active = self.block_size / 2;
    while active > 0 {
        let partner_off = active as usize * byte_size;
        writeln!(ptx, " setp.lt.u32 %p2, %r0, {active};")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " @!%p2 bra $SKIP_{label_prefix}_{active};")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.shared{ty} %f11, [%rd6+{partner_off}];")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " ld.shared{ty} %f12, [%rd6];").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " {combine_op}{ty} %f12, %f12, %f11;")
            .map_err(PtxGenError::FormatError)?;
        writeln!(ptx, " st.shared{ty} [%rd6], %f12;").map_err(PtxGenError::FormatError)?;
        writeln!(ptx, "$SKIP_{label_prefix}_{active}:").map_err(PtxGenError::FormatError)?;
        // All threads (including skippers) must reach the barrier.
        writeln!(ptx, " bar.sync 0;").map_err(PtxGenError::FormatError)?;
        active >>= 1;
    }
    writeln!(ptx).map_err(PtxGenError::FormatError)?;
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arch::SmVersion;

    // --- Kernel naming -------------------------------------------------

    #[test]
    fn kernel_name_training() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 256);
        assert_eq!(t.kernel_name(), "batch_norm_train_f32_c64_s1024_bs256");
    }
    #[test]
    fn kernel_name_inference() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Inference, 128, 256, 1e-5, 128);
        assert_eq!(t.kernel_name(), "batch_norm_infer_f32_c128_s256_bs128");
    }
    #[test]
    fn kernel_name_f64() {
        let t = BatchNormTemplate::new(PtxType::F64, BnMode::Training, 32, 512, 1e-5, 256);
        assert_eq!(t.kernel_name(), "batch_norm_train_f64_c32_s512_bs256");
    }

    // --- Validation rejections -----------------------------------------

    #[test]
    fn invalid_precision_u32() {
        let t = BatchNormTemplate::new(PtxType::U32, BnMode::Training, 64, 1024, 1e-5, 256);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_precision_f16() {
        let t = BatchNormTemplate::new(PtxType::F16, BnMode::Training, 64, 1024, 1e-5, 256);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_block_size_not_pow2() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 100);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_block_size_too_small() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 16);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_block_size_too_large() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 2048);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_channels_zero() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 0, 1024, 1e-5, 256);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_spatial_size_zero() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 0, 1e-5, 256);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_epsilon_zero() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 0.0, 256);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }
    #[test]
    fn invalid_epsilon_negative() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, -1e-5, 256);
        assert!(t.generate(SmVersion::Sm80).is_err());
    }

    // --- Generation smoke tests (assert on emitted PTX substrings) ------

    #[test]
    fn generate_training_f32() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 256);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("should generate training BN");
        assert!(ptx.contains(".entry batch_norm_train_f32_c64_s1024_bs256"));
        assert!(ptx.contains(".shared"));
        assert!(ptx.contains("bar.sync 0"));
        assert!(ptx.contains("sqrt.rn.f32"));
        assert!(ptx.contains("rcp.approx.f32"));
        assert!(ptx.contains("fma.f32"));
        assert!(ptx.contains("%param_gamma"));
        assert!(ptx.contains("%param_beta"));
    }
    #[test]
    fn generate_inference_f32() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Inference, 64, 1024, 1e-5, 256);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("should generate inference BN");
        assert!(ptx.contains(".entry batch_norm_infer_f32_c64_s1024_bs256"));
        assert!(ptx.contains("%param_running_mean"));
        assert!(ptx.contains("%param_running_var"));
        assert!(ptx.contains("sqrt.rn.f32"));
        assert!(ptx.contains("rcp.approx.f32"));
        assert!(ptx.contains("fma.f32"));
    }
    #[test]
    fn generate_training_f64() {
        let t = BatchNormTemplate::new(PtxType::F64, BnMode::Training, 32, 512, 1e-5, 128);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("should generate f64 training BN");
        assert!(ptx.contains("batch_norm_train_f64"));
        assert!(ptx.contains("fma.f64"));
    }
    #[test]
    fn generate_inference_f64() {
        let t = BatchNormTemplate::new(PtxType::F64, BnMode::Inference, 32, 512, 1e-5, 128);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("should generate f64 inference BN");
        assert!(ptx.contains("batch_norm_infer_f64"));
    }
    #[test]
    fn generate_small_block() {
        // block_size = 32 is the smallest accepted by validate().
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 16, 64, 1e-5, 32);
        let ptx = t
            .generate(SmVersion::Sm80)
            .expect("should generate with block_size=32");
        assert!(ptx.contains("batch_norm_train_f32_c16_s64_bs32"));
    }
    #[test]
    fn generate_different_sm_versions() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 256);
        let ptx_75 = t
            .generate(SmVersion::Sm75)
            .expect("should generate for Sm75");
        let ptx_90 = t
            .generate(SmVersion::Sm90)
            .expect("should generate for Sm90");
        assert!(ptx_75.contains("sm_75"));
        assert!(ptx_90.contains("sm_90"));
    }

    // --- API surface ----------------------------------------------------

    #[test]
    fn builder_pattern() {
        // Every with_* setter should override the corresponding field.
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 256)
            .with_precision(PtxType::F64)
            .with_mode(BnMode::Inference)
            .with_channels(128)
            .with_spatial_size(512)
            .with_epsilon(1e-6)
            .with_block_size(128);
        assert_eq!(t.kernel_name(), "batch_norm_infer_f64_c128_s512_bs128");
    }
    #[test]
    fn bn_mode_as_str() {
        assert_eq!(BnMode::Training.as_str(), "train");
        assert_eq!(BnMode::Inference.as_str(), "infer");
    }
    #[test]
    fn training_has_reduction_phases() {
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Training, 64, 1024, 1e-5, 256);
        let ptx = t.generate(SmVersion::Sm80).expect("should generate");
        assert!(ptx.contains("Pass 1") || ptx.contains("mean"));
        assert!(ptx.contains("Pass 2") || ptx.contains("variance"));
        assert!(ptx.contains("Pass 3") || ptx.contains("Normalize"));
    }
    #[test]
    fn inference_no_shared_memory() {
        // Inference reads running stats, so no reduction scratch is needed.
        let t = BatchNormTemplate::new(PtxType::F32, BnMode::Inference, 64, 1024, 1e-5, 256);
        let ptx = t.generate(SmVersion::Sm80).expect("should generate");
        assert!(!ptx.contains(".shared"));
    }
}