use std::fmt::Write as FmtWrite;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
/// Partitioning of a K dimension of length `k_total` into `split_factor`
/// contiguous slices of (at most) `k_per_split` elements each.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SplitKConfig {
    /// Number of partitions along K, as clamped by [`SplitKConfig::new`].
    pub split_factor: u32,
    /// Number of K elements per partition (`ceil(k_total / split_factor)`
    /// when built through [`SplitKConfig::new`]).
    pub k_per_split: u32,
    /// Total length of the K dimension.
    pub k_total: u32,
}

impl SplitKConfig {
    /// Builds a configuration for a K dimension of length `k`.
    ///
    /// The requested `split_factor` is clamped to the range `[2, max(k, 2)]`,
    /// and `k_per_split` is the ceiling of `k / split_factor` using the
    /// clamped factor.
    pub fn new(k: u32, split_factor: u32) -> Self {
        // `k.max(2)` keeps the upper bound >= the lower bound, so `clamp`
        // cannot panic even for k == 0 or k == 1.
        let factor = split_factor.clamp(2, k.max(2));
        let k_per_split = k.div_ceil(factor);
        Self {
            split_factor: factor,
            k_per_split,
            k_total: k,
        }
    }

    /// Returns the half-open range `(start, end)` of K indices covered by
    /// partition `partition_idx`.
    ///
    /// Both bounds are clamped to `k_total`, so `start <= end` always holds.
    /// A partition that lies entirely past the end of K (possible because
    /// `k_per_split` is rounded up, or because `partition_idx` is out of
    /// range) yields the empty range `(k_total, k_total)`.
    pub fn partition_range(&self, partition_idx: u32) -> (u32, u32) {
        // Saturating arithmetic: the products/sums can exceed `u32::MAX` for
        // large K, and anything that saturates is clamped to `k_total` anyway.
        let start = partition_idx
            .saturating_mul(self.k_per_split)
            .min(self.k_total);
        let end = start.saturating_add(self.k_per_split).min(self.k_total);
        (start, end)
    }

    /// Returns `m * n * split_factor`, computed in 64-bit arithmetic: the
    /// number of elements needed to hold one `m x n` partial result per
    /// partition.
    pub fn workspace_elements(&self, m: u32, n: u32) -> u64 {
        u64::from(m) * u64::from(n) * u64::from(self.split_factor)
    }
}
/// Generates the PTX text of a split-K reduction kernel.
///
/// The emitted kernel takes five parameters: a workspace pointer
/// (`%param_ws`), an output pointer (`%param_c`), an element count
/// (`%param_mn`) and the scalars `%param_alpha` / `%param_beta` of type
/// `acc_type`. Each thread whose global index `i` is below `mn` sums
/// `split_factor` workspace values located `mn` elements apart starting at
/// element `i`, then stores `alpha * sum + beta * c[i]` back to `c[i]`.
/// The partition loop is fully unrolled at generation time, so a
/// `split_factor` of 0 emits no workspace loads at all.
///
/// Returns `(kernel_name, ptx_source)`, where the name has the form
/// `splitk_reduce_<type>_x<split_factor>`.
///
/// # Errors
///
/// Returns [`BlasError::PtxGeneration`] when `acc_type` is neither
/// `PtxType::F32` nor `PtxType::F64`, or when writing to the output string
/// fails.
pub fn generate_splitk_reduction_kernel(
    target: SmVersion,
    acc_type: PtxType,
    split_factor: u32,
) -> BlasResult<(String, String)> {
    if !matches!(acc_type, PtxType::F32 | PtxType::F64) {
        return Err(BlasError::PtxGeneration(format!(
            "split-K reduction requires F32 or F64 accumulator, got {}",
            acc_type.as_ptx_str()
        )));
    }

    // `ty` carries its leading dot (e.g. ".f32"), which is why it is glued
    // directly onto opcodes below and trimmed for the kernel name.
    let ty = acc_type.as_ptx_str();
    let byte_size = acc_type.size_bytes();
    // The zero literal must match the accumulator width: `0f` + 8 hex digits
    // is a single-precision constant, `0d` + 16 hex digits a double-precision
    // one.
    let zero_literal = if matches!(acc_type, PtxType::F64) {
        "0d0000000000000000"
    } else {
        "0f00000000"
    };
    let kernel_name = format!(
        "splitk_reduce_{}_x{}",
        ty.trim_start_matches('.'),
        split_factor
    );

    let mut ptx = String::with_capacity(4096);

    // Module header.
    write_line(&mut ptx, &format!(".version {}", target.ptx_version()))?;
    write_line(&mut ptx, &format!(".target {}", target.as_ptx_str()))?;
    write_line(&mut ptx, ".address_size 64")?;
    write_line(&mut ptx, "")?;

    // Entry point signature.
    write_line(&mut ptx, &format!(".visible .entry {kernel_name}("))?;
    write_line(&mut ptx, " .param .u64 %param_ws,")?;
    write_line(&mut ptx, " .param .u64 %param_c,")?;
    write_line(&mut ptx, " .param .u32 %param_mn,")?;
    write_line(&mut ptx, &format!(" .param {ty} %param_alpha,"))?;
    write_line(&mut ptx, &format!(" .param {ty} %param_beta"))?;
    write_line(&mut ptx, ")")?;
    write_line(&mut ptx, "{")?;

    // Register declarations. The floating-point bank is declared with the
    // accumulator type so that every `{ty}`-suffixed instruction operating on
    // `%f*` is well-typed for F64 as well as F32.
    write_line(&mut ptx, " .reg .b32 %r<16>;")?;
    write_line(&mut ptx, " .reg .b64 %rd<16>;")?;
    write_line(&mut ptx, &format!(" .reg {ty} %f<16>;"))?;
    write_line(&mut ptx, " .reg .pred %p<4>;")?;
    write_line(&mut ptx, "")?;

    // Global thread index: ctaid.x * ntid.x + tid.x.
    write_line(&mut ptx, " mov.u32 %r0, %tid.x;")?;
    write_line(&mut ptx, " mov.u32 %r1, %ctaid.x;")?;
    write_line(&mut ptx, " mov.u32 %r2, %ntid.x;")?;
    write_line(
        &mut ptx,
        " mad.lo.u32 %r3, %r1, %r2, %r0; // global idx",
    )?;
    write_line(&mut ptx, "")?;

    // Threads with index >= mn jump straight to the exit label.
    write_line(&mut ptx, " ld.param.u32 %r4, [%param_mn];")?;
    write_line(&mut ptx, " setp.ge.u32 %p0, %r3, %r4;")?;
    write_line(&mut ptx, " @%p0 bra $REDUCE_DONE;")?;
    write_line(&mut ptx, "")?;

    // Load the remaining parameters.
    write_line(&mut ptx, " ld.param.u64 %rd0, [%param_ws];")?;
    write_line(&mut ptx, " ld.param.u64 %rd1, [%param_c];")?;
    write_line(&mut ptx, &format!(" ld.param{ty} %f8, [%param_alpha];"))?;
    write_line(&mut ptx, &format!(" ld.param{ty} %f9, [%param_beta];"))?;
    write_line(&mut ptx, "")?;

    // %rd2 = byte offset of this thread's element,
    // %rd3 = byte distance between consecutive partitions (mn elements).
    write_line(&mut ptx, " cvt.u64.u32 %rd2, %r3;")?;
    write_line(
        &mut ptx,
        &format!(" mul.lo.u64 %rd2, %rd2, {byte_size};"),
    )?;
    write_line(&mut ptx, " cvt.u64.u32 %rd3, %r4;")?;
    write_line(
        &mut ptx,
        &format!(" mul.lo.u64 %rd3, %rd3, {byte_size}; // partition stride"),
    )?;
    write_line(&mut ptx, "")?;

    // Accumulate one value per partition; unrolled at generation time.
    write_line(&mut ptx, &format!(" mov{ty} %f0, {zero_literal}; // acc"))?;
    write_line(&mut ptx, " add.u64 %rd4, %rd0, %rd2; // ws + offset")?;
    for _ in 0..split_factor {
        write_line(&mut ptx, &format!(" ld.global{ty} %f1, [%rd4];"))?;
        write_line(&mut ptx, &format!(" add{ty} %f0, %f0, %f1;"))?;
        write_line(&mut ptx, " add.u64 %rd4, %rd4, %rd3;")?;
    }
    write_line(&mut ptx, "")?;

    // Epilogue: c[i] = beta * c[i] + alpha * acc.
    write_line(&mut ptx, " add.u64 %rd5, %rd1, %rd2; // c + offset")?;
    write_line(&mut ptx, &format!(" ld.global{ty} %f2, [%rd5];"))?;
    write_line(&mut ptx, &format!(" mul{ty} %f0, %f0, %f8;"))?;
    write_line(&mut ptx, &format!(" fma.rn{ty} %f0, %f9, %f2, %f0;"))?;
    write_line(&mut ptx, &format!(" st.global{ty} [%rd5], %f0;"))?;
    write_line(&mut ptx, "")?;

    write_line(&mut ptx, "$REDUCE_DONE:")?;
    write_line(&mut ptx, " ret;")?;
    write_line(&mut ptx, "}")?;

    Ok((kernel_name, ptx))
}
fn write_line(ptx: &mut String, line: &str) -> BlasResult<()> {
writeln!(ptx, "{line}").map_err(|e| BlasError::PtxGeneration(format!("fmt error: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// K divisible by the split factor: every partition has the same length.
    #[test]
    fn splitk_config_basic() {
        let config = SplitKConfig::new(1024, 4);
        assert_eq!((config.split_factor, config.k_per_split), (4, 256));

        let first = config.partition_range(0);
        let last = config.partition_range(3);
        assert_eq!(first, (0, 256));
        assert_eq!(last, (768, 1024));
    }

    /// K not divisible by the split factor: the slice size rounds up and the
    /// final partition is cut off at K.
    #[test]
    fn splitk_config_uneven() {
        let config = SplitKConfig::new(1000, 3);
        assert_eq!((config.split_factor, config.k_per_split), (3, 334));

        let last = config.partition_range(2);
        assert_eq!(last, (668, 1000));
    }

    /// A requested factor below 2 is raised to 2.
    #[test]
    fn splitk_config_clamp_low() {
        let SplitKConfig { split_factor, .. } = SplitKConfig::new(100, 1);
        assert_eq!(split_factor, 2);
    }

    /// The workspace holds one m x n tile per partition.
    #[test]
    fn splitk_workspace_size() {
        let config = SplitKConfig::new(1024, 4);
        let expected: u64 = 64 * 64 * 4;
        assert_eq!(config.workspace_elements(64, 64), expected);
    }

    /// F32 generation succeeds and produces the expected name, entry point
    /// and exit label.
    #[test]
    fn generate_reduction_f32() {
        let generated = generate_splitk_reduction_kernel(SmVersion::Sm80, PtxType::F32, 4);
        let (name, ptx) = generated.expect("reduction kernel generation failed");

        assert_eq!(name, "splitk_reduce_f32_x4");
        for needle in [".entry splitk_reduce_f32_x4", "$REDUCE_DONE"] {
            assert!(ptx.contains(needle));
        }
    }

    /// A non-floating-point accumulator type is rejected.
    #[test]
    fn generate_reduction_invalid_type() {
        assert!(generate_splitk_reduction_kernel(SmVersion::Sm80, PtxType::U32, 4).is_err());
    }
}