use crate::arch::SmVersion;
use crate::error::PtxGenError;
use crate::ir::PtxType;
/// Maximum tensor rank the generated broadcast kernel can handle.
///
/// The kernel signature emitted by `BroadcastTemplate` carries exactly this
/// many per-dimension slots for each of its three parameter families
/// (`%param_ds0..7`, `%param_dst_s0..7`, `%param_ss0..7`).
pub const MAX_BROADCAST_RANK: usize = 8;
/// Template that renders the PTX source of a `broadcast_axes_<type>` kernel.
///
/// The template is plain data; all work happens in `generate`, which validates
/// the precision and then formats the PTX text.
pub struct BroadcastTemplate {
    /// Element type copied by the kernel. `generate` accepts only
    /// `F16`, `BF16`, `F32` and `F64` and rejects everything else.
    pub precision: PtxType,
    /// Target architecture; supplies the `.version` and `.target` directives
    /// of the emitted PTX.
    pub target: SmVersion,
}
impl BroadcastTemplate {
    /// Creates a template for the given element precision and target.
    ///
    /// Nothing is validated here; an unsupported precision is reported by
    /// [`generate`](Self::generate).
    #[must_use]
    pub const fn new(precision: PtxType, target: SmVersion) -> Self {
        Self { precision, target }
    }

    /// Returns the kernel entry-point name, `broadcast_axes_<type>`.
    ///
    /// `<type>` is the PTX spelling of the precision with its leading `.`
    /// removed (for example `broadcast_axes_f32`).
    #[must_use]
    pub fn kernel_name(&self) -> String {
        let type_str = self.precision.as_ptx_str().trim_start_matches('.');
        format!("broadcast_axes_{type_str}")
    }

    /// Renders the complete PTX module text for this template.
    ///
    /// # Errors
    ///
    /// Returns [`PtxGenError::InvalidType`] when the precision is not one of
    /// `F16`, `BF16`, `F32` or `F64`.
    pub fn generate(&self) -> Result<String, PtxGenError> {
        self.validate_precision()?;
        Ok(self.generate_raw_ptx())
    }

    /// Rejects every precision other than the four supported float types.
    fn validate_precision(&self) -> Result<(), PtxGenError> {
        if !matches!(
            self.precision,
            PtxType::F16 | PtxType::BF16 | PtxType::F32 | PtxType::F64
        ) {
            return Err(PtxGenError::InvalidType(format!(
                "broadcast_axes requires F16, BF16, F32, or F64, got {}",
                self.precision.as_ptx_str()
            )));
        }
        Ok(())
    }

    /// Formats the PTX text without validating the precision.
    ///
    /// Each thread handles one destination element, identified by its global
    /// index `%gid = %ntid.x * %ctaid.x + %tid.x`; threads with
    /// `%gid >= n_dst` exit immediately. For every dimension `d < rank` the
    /// kernel adds `((%gid / dst_s[d]) % ds[d]) * ss[d]` to `%flat_src`, then
    /// copies `src[flat_src]` to `dst[gid]`.
    ///
    /// NOTE(review): from the parameter names, `ds*` look like destination
    /// extents, `dst_s*` destination strides and `ss*` source strides (zero on
    /// broadcast axes) — confirm against the launch code.
    ///
    /// The global index lives in `%gid` rather than `%tid`: `%tid` is a
    /// predefined PTX special register, so declaring a user register with that
    /// name clashes with the `%tid.x` read. Special registers are also copied
    /// into ordinary registers with `mov` before being used arithmetically.
    // The body is one long PTX literal; splitting it would only hurt readability.
    #[allow(clippy::too_many_lines)]
    fn generate_raw_ptx(&self) -> String {
        let kernel_name = self.kernel_name();
        let ptx_ver = self.target.ptx_version();
        let sm_target = self.target.as_ptx_str();
        let byte_size = self.precision.size_bytes();
        let (float_reg, ld_ty, st_ty) = float_reg_info(self.precision);
        let dim_body = Self::generate_dim_accumulation();
        // A trailing `\` swallows the newline and the next line's indentation,
        // so the indentation below never reaches the generated text.
        format!(
            ".version {ptx_ver}\n\
             .target {sm_target}\n\
             .address_size 64\n\
             \n\
             .visible .entry {kernel_name}(\n\
             .param .u64 %param_src_ptr,\n\
             .param .u64 %param_dst_ptr,\n\
             .param .u32 %param_rank,\n\
             .param .u32 %param_ds0,\n\
             .param .u32 %param_ds1,\n\
             .param .u32 %param_ds2,\n\
             .param .u32 %param_ds3,\n\
             .param .u32 %param_ds4,\n\
             .param .u32 %param_ds5,\n\
             .param .u32 %param_ds6,\n\
             .param .u32 %param_ds7,\n\
             .param .u32 %param_dst_s0,\n\
             .param .u32 %param_dst_s1,\n\
             .param .u32 %param_dst_s2,\n\
             .param .u32 %param_dst_s3,\n\
             .param .u32 %param_dst_s4,\n\
             .param .u32 %param_dst_s5,\n\
             .param .u32 %param_dst_s6,\n\
             .param .u32 %param_dst_s7,\n\
             .param .u32 %param_ss0,\n\
             .param .u32 %param_ss1,\n\
             .param .u32 %param_ss2,\n\
             .param .u32 %param_ss3,\n\
             .param .u32 %param_ss4,\n\
             .param .u32 %param_ss5,\n\
             .param .u32 %param_ss6,\n\
             .param .u32 %param_ss7,\n\
             .param .u32 %param_n_dst\n\
             )\n\
             {{\n\
             .reg .u32 %gid, %n_dst, %rank;\n\
             .reg .u32 %blk_idx, %blk_dim, %thr_idx;\n\
             .reg .u64 %src_ptr, %dst_ptr, %off;\n\
             .reg .pred %p_oob;\n\
             .reg .u32 %flat_src;\n\
             .reg .u32 %q, %r;\n\
             .reg .u32 %ds0, %ds1, %ds2, %ds3, %ds4, %ds5, %ds6, %ds7;\n\
             .reg .u32 %dst_s0, %dst_s1, %dst_s2, %dst_s3, %dst_s4, %dst_s5, %dst_s6, %dst_s7;\n\
             .reg .u32 %ss0, %ss1, %ss2, %ss3, %ss4, %ss5, %ss6, %ss7;\n\
             .reg .pred %p_d0, %p_d1, %p_d2, %p_d3, %p_d4, %p_d5, %p_d6, %p_d7;\n\
             .reg .u32 %contrib;\n\
             .reg .{float_reg} %val;\n\
             \n\
             // Compute global thread id\n\
             mov.u32 %blk_idx, %ctaid.x;\n\
             mov.u32 %blk_dim, %ntid.x;\n\
             mov.u32 %thr_idx, %tid.x;\n\
             mad.lo.u32 %gid, %blk_dim, %blk_idx, %thr_idx;\n\
             \n\
             // Bounds check\n\
             ld.param.u32 %n_dst, [%param_n_dst];\n\
             setp.ge.u32 %p_oob, %gid, %n_dst;\n\
             @%p_oob bra done;\n\
             \n\
             // Load rank and all shape/stride arrays\n\
             ld.param.u32 %rank, [%param_rank];\n\
             ld.param.u32 %ds0, [%param_ds0];\n\
             ld.param.u32 %ds1, [%param_ds1];\n\
             ld.param.u32 %ds2, [%param_ds2];\n\
             ld.param.u32 %ds3, [%param_ds3];\n\
             ld.param.u32 %ds4, [%param_ds4];\n\
             ld.param.u32 %ds5, [%param_ds5];\n\
             ld.param.u32 %ds6, [%param_ds6];\n\
             ld.param.u32 %ds7, [%param_ds7];\n\
             ld.param.u32 %dst_s0, [%param_dst_s0];\n\
             ld.param.u32 %dst_s1, [%param_dst_s1];\n\
             ld.param.u32 %dst_s2, [%param_dst_s2];\n\
             ld.param.u32 %dst_s3, [%param_dst_s3];\n\
             ld.param.u32 %dst_s4, [%param_dst_s4];\n\
             ld.param.u32 %dst_s5, [%param_dst_s5];\n\
             ld.param.u32 %dst_s6, [%param_dst_s6];\n\
             ld.param.u32 %dst_s7, [%param_dst_s7];\n\
             ld.param.u32 %ss0, [%param_ss0];\n\
             ld.param.u32 %ss1, [%param_ss1];\n\
             ld.param.u32 %ss2, [%param_ss2];\n\
             ld.param.u32 %ss3, [%param_ss3];\n\
             ld.param.u32 %ss4, [%param_ss4];\n\
             ld.param.u32 %ss5, [%param_ss5];\n\
             ld.param.u32 %ss6, [%param_ss6];\n\
             ld.param.u32 %ss7, [%param_ss7];\n\
             \n\
             // Compute flat_src via stride-zero trick (unrolled 8 dims)\n\
             mov.u32 %flat_src, 0;\n\
             {dim_body}\n\
             // Load src[flat_src]\n\
             ld.param.u64 %src_ptr, [%param_src_ptr];\n\
             cvt.u64.u32 %off, %flat_src;\n\
             mul.lo.u64 %off, %off, {byte_size};\n\
             add.u64 %src_ptr, %src_ptr, %off;\n\
             ld.global{ld_ty} %val, [%src_ptr];\n\
             \n\
             // Store dst[gid]\n\
             ld.param.u64 %dst_ptr, [%param_dst_ptr];\n\
             cvt.u64.u32 %off, %gid;\n\
             mul.lo.u64 %off, %off, {byte_size};\n\
             add.u64 %dst_ptr, %dst_ptr, %off;\n\
             st.global{st_ty} [%dst_ptr], %val;\n\
             \n\
             done:\n\
             ret;\n\
             }}\n"
        )
    }

    /// Emits the unrolled per-dimension accumulation of `%flat_src`.
    ///
    /// For each of the 8 dimension slots declared by the kernel template, the
    /// emitted instructions are predicated on `rank > d`, so slots beyond the
    /// runtime rank are skipped (including their division). The register
    /// names used here (`%gid`, `%q`, `%r`, `%contrib`, `%p_d*`, `%ds*`,
    /// `%dst_s*`, `%ss*`) must match the declarations in `generate_raw_ptx`.
    fn generate_dim_accumulation() -> String {
        use std::fmt::Write as _;
        let mut body = String::with_capacity(2048);
        // 8 matches the number of per-dimension registers declared in the template.
        for d in 0..8 {
            // Writing into a `String` cannot fail, so the results are ignored.
            let _ = writeln!(body, " setp.gt.u32 %p_d{d}, %rank, {d};");
            let _ = writeln!(body, " @%p_d{d} div.u32 %q, %gid, %dst_s{d};");
            let _ = writeln!(body, " @%p_d{d} rem.u32 %r, %q, %ds{d};");
            let _ = writeln!(body, " @%p_d{d} mul.lo.u32 %contrib, %r, %ss{d};");
            let _ = writeln!(
                body,
                " @%p_d{d} add.u32 %flat_src, %flat_src, %contrib;"
            );
        }
        body
    }
}
/// Maps a precision to `(register type, load suffix, store suffix)`.
///
/// The load and store suffixes are always the register type prefixed with a
/// dot. Both 16-bit formats are moved as untyped `b16`; any type other than
/// the four float precisions falls back to `b32`.
const fn float_reg_info(precision: PtxType) -> (&'static str, &'static str, &'static str) {
    let (reg, suffix) = match precision {
        PtxType::F64 => ("f64", ".f64"),
        PtxType::F32 => ("f32", ".f32"),
        PtxType::F16 | PtxType::BF16 => ("b16", ".b16"),
        _ => ("b32", ".b32"),
    };
    (reg, suffix, suffix)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arch::SmVersion;

    /// Renders the f32 kernel for `sm_80`, which most tests below inspect.
    fn f32_sm80_ptx() -> String {
        BroadcastTemplate::new(PtxType::F32, SmVersion::Sm80)
            .generate()
            .expect("broadcast_axes_f32 generation failed")
    }

    /// Asserts that `ptx` contains every needle, failing with the paired message.
    #[track_caller]
    fn assert_contains_all(ptx: &str, expectations: &[(&str, &str)]) {
        for &(needle, why) in expectations {
            assert!(ptx.contains(needle), "{why}");
        }
    }

    #[test]
    fn broadcast_kernel_name_f32() {
        let name = BroadcastTemplate::new(PtxType::F32, SmVersion::Sm80).kernel_name();
        assert_eq!(name, "broadcast_axes_f32");
    }

    #[test]
    fn broadcast_kernel_name_f16() {
        let name = BroadcastTemplate::new(PtxType::F16, SmVersion::Sm80).kernel_name();
        assert_eq!(name, "broadcast_axes_f16");
    }

    #[test]
    fn broadcast_invalid_precision_rejected() {
        let outcome = BroadcastTemplate::new(PtxType::U32, SmVersion::Sm80).generate();
        assert!(outcome.is_err());
    }

    #[test]
    fn broadcast_generates_valid_ptx_headers_f32() {
        assert_contains_all(
            &f32_sm80_ptx(),
            &[
                (".version 7.0", "must have correct PTX version"),
                (".target sm_80", "must have correct target"),
                (".address_size 64", "must have 64-bit addressing"),
                (".entry broadcast_axes_f32", "must have correct kernel name"),
            ],
        );
    }

    #[test]
    fn broadcast_ptx_contains_stride_zero_trick_f32() {
        assert_contains_all(
            &f32_sm80_ptx(),
            &[
                ("ld.global.f32", "must have global load"),
                ("st.global.f32", "must have global store"),
                ("div.u32", "must divide for coord extraction"),
                ("rem.u32", "must mod for coord extraction"),
                ("setp.gt.u32 %p_d0", "must guard dim 0 with predicate"),
                ("setp.gt.u32 %p_d7", "must guard dim 7 with predicate"),
                ("add.u32 %flat_src", "must accumulate flat_src"),
            ],
        );
    }

    #[test]
    fn broadcast_ptx_contains_all_param_slots() {
        let ptx = f32_sm80_ptx();
        assert_contains_all(
            &ptx,
            &[
                ("%param_src_ptr", "must have src_ptr param"),
                ("%param_dst_ptr", "must have dst_ptr param"),
                ("%param_rank", "must have rank param"),
            ],
        );
        // Every dimension slot must exist for each of the three parameter families.
        for d in 0..MAX_BROADCAST_RANK {
            for family in ["ds", "dst_s", "ss"] {
                assert!(
                    ptx.contains(&format!("%param_{family}{d}")),
                    "must have {family}{d} param"
                );
            }
        }
        assert_contains_all(&ptx, &[("%param_n_dst", "must have n_dst param")]);
    }

    #[test]
    fn broadcast_generates_for_f64() {
        let ptx = BroadcastTemplate::new(PtxType::F64, SmVersion::Sm90)
            .generate()
            .expect("broadcast_axes_f64 generation failed");
        assert_contains_all(
            &ptx,
            &[
                ("broadcast_axes_f64", "must have f64 kernel name"),
                ("ld.global.f64", "must have f64 load"),
                ("st.global.f64", "must have f64 store"),
            ],
        );
    }

    #[test]
    fn broadcast_max_rank_is_eight() {
        assert_eq!(MAX_BROADCAST_RANK, 8);
    }
}