use std::fmt::Write as FmtWrite;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{GpuFloat, Transpose};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamKConfig {
pub tile_m: usize,
pub tile_n: usize,
pub tile_k: usize,
pub sm_count: usize,
}
impl StreamKConfig {
pub fn new(tile_m: usize, tile_n: usize, tile_k: usize, sm_count: usize) -> BlasResult<Self> {
if tile_m == 0 || tile_n == 0 || tile_k == 0 {
return Err(BlasError::InvalidArgument(
"Stream-K: tile dimensions must be non-zero".into(),
));
}
if sm_count == 0 {
return Err(BlasError::InvalidArgument(
"Stream-K: sm_count must be non-zero".into(),
));
}
Ok(Self {
tile_m,
tile_n,
tile_k,
sm_count,
})
}
pub fn total_tiles(&self, m: usize, n: usize) -> usize {
let tiles_m = m.div_ceil(self.tile_m);
let tiles_n = n.div_ceil(self.tile_n);
tiles_m * tiles_n
}
pub fn iters_per_tile(&self, k: usize) -> usize {
k.div_ceil(self.tile_k)
}
pub fn total_iters(&self, m: usize, n: usize, k: usize) -> usize {
self.total_tiles(m, n) * self.iters_per_tile(k)
}
pub fn cta_count(&self, m: usize, n: usize) -> usize {
let total = self.total_tiles(m, n);
total.min(self.sm_count)
}
pub fn iters_per_cta(&self, m: usize, n: usize, k: usize) -> (usize, usize) {
let total = self.total_iters(m, n, k);
let ctas = self.cta_count(m, n);
if ctas == 0 {
return (0, 0);
}
let base = total / ctas;
let remainder = total % ctas;
(base, remainder)
}
}
#[allow(clippy::too_many_arguments)]
pub fn stream_k_gemm<T: GpuFloat>(
handle: &BlasHandle,
config: &StreamKConfig,
transa: Transpose,
transb: Transpose,
m: usize,
n: usize,
k: usize,
alpha: T,
a: &DeviceBuffer<T>,
lda: usize,
b: &DeviceBuffer<T>,
ldb: usize,
beta: T,
c: &mut DeviceBuffer<T>,
ldc: usize,
) -> BlasResult<()> {
if m == 0 || n == 0 || k == 0 {
return Err(BlasError::InvalidDimension(
"Stream-K GEMM: all dimensions must be non-zero".into(),
));
}
validate_stream_k_ld(transa, m, k, lda, "A")?;
validate_stream_k_ld(transb, k, n, ldb, "B")?;
if ldc < m {
return Err(BlasError::InvalidDimension(format!(
"Stream-K GEMM: ldc ({ldc}) < m ({m})"
)));
}
let a_required = match transa {
Transpose::NoTrans => lda * k,
Transpose::Trans | Transpose::ConjTrans => lda * m,
};
if a.len() < a_required {
return Err(BlasError::BufferTooSmall {
expected: a_required,
actual: a.len(),
});
}
let b_required = match transb {
Transpose::NoTrans => ldb * n,
Transpose::Trans | Transpose::ConjTrans => ldb * k,
};
if b.len() < b_required {
return Err(BlasError::BufferTooSmall {
expected: b_required,
actual: b.len(),
});
}
let c_required = ldc * n;
if c.len() < c_required {
return Err(BlasError::BufferTooSmall {
expected: c_required,
actual: c.len(),
});
}
let _ptx = generate_stream_k_ptx::<T>(handle.sm_version(), config, transa, transb, m, n, k)?;
let _ = (alpha, beta, a, b, c, lda, ldb, ldc);
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn generate_stream_k_ptx<T: GpuFloat>(
sm: SmVersion,
config: &StreamKConfig,
transa: Transpose,
transb: Transpose,
m: usize,
n: usize,
k: usize,
) -> BlasResult<String> {
let byte_size = T::PTX_TYPE.size_bytes();
let is_f64 = byte_size == 8;
let (fr, ld_ty) = if is_f64 { ("fd", "f64") } else { ("f", "f32") };
let zero_lit = if is_f64 {
"0d0000000000000000"
} else {
"0f00000000"
};
let ta = trans_label(transa);
let tb = trans_label(transb);
let kernel_name = format!("stream_k_gemm_{}_{ta}_{tb}", T::NAME);
let total_tiles = config.total_tiles(m, n);
let iters_per_tile = config.iters_per_tile(k);
let total_iters = config.total_iters(m, n, k);
let cta_count = config.cta_count(m, n);
let tiles_n = n.div_ceil(config.tile_n);
let mut p = String::with_capacity(8192);
wl(&mut p, &format!(".version {}", sm.ptx_version()))?;
wl(&mut p, &format!(".target {}", sm.as_ptx_str()))?;
wl(&mut p, ".address_size 64")?;
wl(&mut p, "")?;
wl(&mut p, &format!(".visible .entry {kernel_name}("))?;
wl(&mut p, " .param .u64 %param_a,")?;
wl(&mut p, " .param .u64 %param_b,")?;
wl(&mut p, " .param .u64 %param_c,")?;
wl(&mut p, " .param .u32 %param_m,")?;
wl(&mut p, " .param .u32 %param_n,")?;
wl(&mut p, " .param .u32 %param_k,")?;
wl(&mut p, " .param .u32 %param_lda,")?;
wl(&mut p, " .param .u32 %param_ldb,")?;
wl(&mut p, " .param .u32 %param_ldc,")?;
wl(&mut p, &format!(" .param .{ld_ty} %param_alpha,"))?;
wl(&mut p, &format!(" .param .{ld_ty} %param_beta"))?;
wl(&mut p, ")")?;
wl(&mut p, "{")?;
wl(&mut p, " .reg .b32 %r<48>;")?;
wl(&mut p, " .reg .b64 %rd<24>;")?;
if is_f64 {
wl(&mut p, " .reg .f64 %fd<16>;")?;
} else {
wl(&mut p, " .reg .f32 %f<16>;")?;
}
wl(&mut p, " .reg .pred %p<8>;")?;
wl(&mut p, "")?;
wl(&mut p, " mov.u32 %r0, %ctaid.x; // cta_id")?;
wl(&mut p, "")?;
wl(&mut p, &format!(" // total_iters = {total_iters}"))?;
wl(&mut p, &format!(" // cta_count = {cta_count}"))?;
wl(&mut p, &format!(" // iters_per_tile = {iters_per_tile}"))?;
wl(&mut p, &format!(" // total_tiles = {total_tiles}"))?;
wl(&mut p, &format!(" // tiles_n = {tiles_n}"))?;
let (base_iters, remainder) = match (
total_iters.checked_div(cta_count),
total_iters.checked_rem(cta_count),
) {
(Some(b), Some(r)) => (b, r),
_ => (0, 0),
};
wl(
&mut p,
&format!(" mov.u32 %r1, {}; // base_iters", base_iters),
)?;
wl(
&mut p,
&format!(" mov.u32 %r2, {}; // remainder", remainder),
)?;
wl(&mut p, " setp.lt.u32 %p0, %r0, %r2;")?;
wl(&mut p, " @%p0 bra $SK_EXTRA;")?;
wl(
&mut p,
&format!(" mov.u32 %r3, {}; // base+1", base_iters + 1),
)?;
wl(
&mut p,
" mul.lo.u32 %r4, %r2, %r3; // remainder * (base+1)",
)?;
wl(&mut p, " sub.u32 %r5, %r0, %r2; // cta_id - remainder")?;
wl(&mut p, " mad.lo.u32 %r6, %r5, %r1, %r4; // iter_start")?;
wl(&mut p, " mov.u32 %r7, %r1; // iter_count = base")?;
wl(&mut p, " bra $SK_COMPUTE;")?;
wl(&mut p, "$SK_EXTRA:")?;
wl(&mut p, " add.u32 %r3, %r1, 1; // base + 1")?;
wl(&mut p, " mul.lo.u32 %r6, %r0, %r3; // iter_start")?;
wl(&mut p, " mov.u32 %r7, %r3; // iter_count")?;
wl(&mut p, "$SK_COMPUTE:")?;
wl(&mut p, "")?;
wl(&mut p, " ld.param.u64 %rd0, [%param_a];")?;
wl(&mut p, " ld.param.u64 %rd1, [%param_b];")?;
wl(&mut p, " ld.param.u64 %rd2, [%param_c];")?;
wl(&mut p, " ld.param.u32 %r8, [%param_m];")?;
wl(&mut p, " ld.param.u32 %r9, [%param_n];")?;
wl(&mut p, " ld.param.u32 %r10, [%param_k];")?;
wl(&mut p, " ld.param.u32 %r20, [%param_lda];")?;
wl(&mut p, " ld.param.u32 %r21, [%param_ldb];")?;
wl(&mut p, " ld.param.u32 %r22, [%param_ldc];")?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}8, [%param_alpha];"),
)?;
wl(
&mut p,
&format!(" ld.param.{ld_ty} %{fr}9, [%param_beta];"),
)?;
wl(&mut p, "")?;
wl(&mut p, " mov.u32 %r11, 0; // local_iter")?;
wl(&mut p, "$SK_ITER_LOOP:")?;
wl(&mut p, " setp.ge.u32 %p1, %r11, %r7;")?;
wl(&mut p, " @%p1 bra $SK_ITER_DONE;")?;
wl(&mut p, "")?;
wl(&mut p, " add.u32 %r12, %r6, %r11; // global_iter")?;
wl(
&mut p,
&format!(" mov.u32 %r13, {}; // iters_per_tile", iters_per_tile),
)?;
wl(&mut p, " div.u32 %r14, %r12, %r13; // tile_idx")?;
wl(&mut p, " rem.u32 %r15, %r12, %r13; // k_slice")?;
wl(
&mut p,
&format!(" mov.u32 %r16, {}; // tiles_n", tiles_n),
)?;
wl(&mut p, " div.u32 %r17, %r14, %r16; // tile_row")?;
wl(&mut p, " rem.u32 %r18, %r14, %r16; // tile_col")?;
wl(&mut p, " mov.u32 %r19, %tid.x; // thread within block")?;
wl(
&mut p,
&format!(" mov.{ld_ty} %{fr}0, {zero_lit}; // acc"),
)?;
wl(
&mut p,
&format!(
" rem.u32 %r23, %r19, {}; // thread_row = tid % tile_m",
config.tile_m
),
)?;
wl(
&mut p,
&format!(
" div.u32 %r24, %r19, {}; // thread_col = tid / tile_m",
config.tile_m
),
)?;
wl(
&mut p,
&format!(
" mad.lo.u32 %r25, %r17, {}, %r23; // row_m",
config.tile_m
),
)?;
wl(
&mut p,
&format!(
" mad.lo.u32 %r26, %r18, {}, %r24; // col_n",
config.tile_n
),
)?;
wl(&mut p, " setp.ge.u32 %p2, %r25, %r8;")?;
wl(&mut p, " setp.ge.u32 %p3, %r26, %r9;")?;
wl(&mut p, " or.pred %p4, %p2, %p3;")?;
wl(&mut p, " @%p4 bra $SK_ITER_STEP_DONE;")?;
wl(
&mut p,
&format!(
" mul.lo.u32 %r27, %r15, {}; // k_base = k_slice * tile_k",
config.tile_k
),
)?;
for kk in 0..config.tile_k {
wl(
&mut p,
&format!(" add.u32 %r28, %r27, {kk}; // k_global"),
)?;
wl(&mut p, " setp.ge.u32 %p5, %r28, %r10;")?;
wl(&mut p, &format!(" @%p5 bra $SK_SKIP_FMA_{kk};"))?;
wl(
&mut p,
" mad.lo.u32 %r29, %r25, %r20, %r28; // A linear index",
)?;
wl(
&mut p,
&format!(" mul.wide.u32 %rd10, %r29, {byte_size}; // A byte offset"),
)?;
wl(&mut p, " add.u64 %rd11, %rd0, %rd10; // A addr")?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}1, [%rd11]; // a_val"),
)?;
wl(
&mut p,
" mad.lo.u32 %r30, %r28, %r21, %r26; // B linear index",
)?;
wl(
&mut p,
&format!(" mul.wide.u32 %rd12, %r30, {byte_size}; // B byte offset"),
)?;
wl(&mut p, " add.u64 %rd13, %rd1, %rd12; // B addr")?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}2, [%rd13]; // b_val"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}0, %{fr}1, %{fr}2, %{fr}0; // acc += A * B"),
)?;
wl(&mut p, &format!("$SK_SKIP_FMA_{kk}:"))?;
}
wl(
&mut p,
" mad.lo.u32 %r31, %r25, %r22, %r26; // C linear index",
)?;
wl(
&mut p,
&format!(" mul.wide.u32 %rd14, %r31, {byte_size}; // C byte offset"),
)?;
wl(&mut p, " add.u64 %rd15, %rd2, %rd14; // C addr")?;
wl(
&mut p,
&format!(" mul.rn.{ld_ty} %{fr}3, %{fr}8, %{fr}0; // alpha * acc"),
)?;
wl(
&mut p,
&format!(" atom.add.{ld_ty} [%rd15], %{fr}3; // partial tile reduction"),
)?;
wl(&mut p, "$SK_ITER_STEP_DONE:")?;
wl(&mut p, "")?;
wl(&mut p, " add.u32 %r11, %r11, 1;")?;
wl(&mut p, " bra $SK_ITER_LOOP;")?;
wl(&mut p, "$SK_ITER_DONE:")?;
wl(&mut p, "")?;
wl(
&mut p,
&format!(
" rem.u32 %r23, %r19, {}; // thread_row = tid % tile_m",
config.tile_m
),
)?;
wl(
&mut p,
&format!(
" div.u32 %r24, %r19, {}; // thread_col = tid / tile_m",
config.tile_m
),
)?;
wl(
&mut p,
&format!(
" mad.lo.u32 %r25, %r17, {}, %r23; // row_m",
config.tile_m
),
)?;
wl(
&mut p,
&format!(
" mad.lo.u32 %r26, %r18, {}, %r24; // col_n",
config.tile_n
),
)?;
wl(&mut p, " setp.ge.u32 %p2, %r25, %r8;")?;
wl(&mut p, " setp.ge.u32 %p3, %r26, %r9;")?;
wl(&mut p, " or.pred %p4, %p2, %p3;")?;
wl(&mut p, " @%p4 bra $SK_RET;")?;
wl(
&mut p,
" mad.lo.u32 %r31, %r25, %r22, %r26; // C linear index",
)?;
wl(
&mut p,
&format!(" mul.wide.u32 %rd14, %r31, {byte_size}; // C byte offset"),
)?;
wl(&mut p, " add.u64 %rd15, %rd2, %rd14; // C addr")?;
wl(
&mut p,
&format!(" ld.global.{ld_ty} %{fr}4, [%rd15]; // old C"),
)?;
wl(
&mut p,
&format!(" mul.rn.{ld_ty} %{fr}3, %{fr}8, %{fr}0; // alpha * acc"),
)?;
wl(
&mut p,
&format!(" fma.rn.{ld_ty} %{fr}5, %{fr}9, %{fr}4, %{fr}3; // beta * C + alpha * acc"),
)?;
wl(
&mut p,
&format!(" st.global.{ld_ty} [%rd15], %{fr}5; // final C write"),
)?;
wl(&mut p, "")?;
wl(&mut p, "$SK_RET:")?;
wl(&mut p, " ret;")?;
wl(&mut p, "}")?;
Ok(p)
}
fn validate_stream_k_ld(
trans: Transpose,
rows: usize,
cols: usize,
ld: usize,
name: &str,
) -> BlasResult<()> {
let min_ld = match trans {
Transpose::NoTrans => rows,
Transpose::Trans | Transpose::ConjTrans => cols,
};
if ld < min_ld {
return Err(BlasError::InvalidDimension(format!(
"Stream-K GEMM: ld{name} ({ld}) < required ({min_ld})"
)));
}
Ok(())
}
fn trans_label(t: Transpose) -> &'static str {
match t {
Transpose::NoTrans => "n",
Transpose::Trans => "t",
Transpose::ConjTrans => "c",
}
}
fn wl(ptx: &mut String, line: &str) -> BlasResult<()> {
writeln!(ptx, "{line}").map_err(|e| BlasError::PtxGeneration(format!("fmt error: {e}")))
}
#[cfg(test)]
mod tests {
use super::*;
fn make_config(sm_count: usize) -> StreamKConfig {
StreamKConfig {
tile_m: 128,
tile_n: 128,
tile_k: 32,
sm_count,
}
}
#[test]
fn config_new_validates() {
assert!(StreamKConfig::new(128, 128, 32, 80).is_ok());
assert!(StreamKConfig::new(0, 128, 32, 80).is_err());
assert!(StreamKConfig::new(128, 0, 32, 80).is_err());
assert!(StreamKConfig::new(128, 128, 0, 80).is_err());
assert!(StreamKConfig::new(128, 128, 32, 0).is_err());
}
#[test]
fn total_tiles_basic() {
let cfg = make_config(80);
assert_eq!(cfg.total_tiles(256, 256), 4); assert_eq!(cfg.total_tiles(512, 256), 8);
assert_eq!(cfg.total_tiles(129, 128), 2); }
#[test]
fn iters_per_tile_basic() {
let cfg = make_config(80);
assert_eq!(cfg.iters_per_tile(256), 8); assert_eq!(cfg.iters_per_tile(33), 2); assert_eq!(cfg.iters_per_tile(32), 1);
}
#[test]
fn total_iters_consistency() {
let cfg = make_config(80);
let (m, n, k) = (512, 512, 1024);
let expected = cfg.total_tiles(m, n) * cfg.iters_per_tile(k);
assert_eq!(cfg.total_iters(m, n, k), expected);
}
#[test]
fn cta_count_limited_by_tiles() {
let cfg = make_config(80);
assert_eq!(cfg.cta_count(256, 256), 4);
}
#[test]
fn cta_count_limited_by_sm() {
let cfg = make_config(10);
assert_eq!(cfg.cta_count(1024, 1024), 10);
}
#[test]
fn iters_per_cta_balanced() {
let cfg = make_config(4);
let (m, n, k) = (512, 512, 128);
let total = cfg.total_iters(m, n, k);
let (base, rem) = cfg.iters_per_cta(m, n, k);
let ctas = cfg.cta_count(m, n);
assert_eq!(base * (ctas - rem) + (base + 1) * rem, total);
}
#[test]
fn generate_stream_k_ptx_f32() {
let cfg = make_config(80);
let ptx = generate_stream_k_ptx::<f32>(
SmVersion::Sm90,
&cfg,
Transpose::NoTrans,
Transpose::NoTrans,
512,
512,
256,
);
let ptx = ptx.expect("PTX gen should succeed");
assert!(ptx.contains(".entry stream_k_gemm_f32_n_n"));
assert!(ptx.contains(".target sm_90"));
assert!(ptx.contains("$SK_ITER_LOOP"));
assert!(ptx.contains("atom.add"));
}
#[test]
fn generate_stream_k_ptx_f64_trans() {
let cfg = make_config(40);
let ptx = generate_stream_k_ptx::<f64>(
SmVersion::Sm80,
&cfg,
Transpose::Trans,
Transpose::NoTrans,
256,
256,
128,
);
let ptx = ptx.expect("PTX gen should succeed");
assert!(ptx.contains("stream_k_gemm_f64_t_n"));
assert!(ptx.contains("div.u32"));
}
#[test]
fn validate_stream_k_ld_ok() {
assert!(validate_stream_k_ld(Transpose::NoTrans, 64, 32, 64, "A").is_ok());
assert!(validate_stream_k_ld(Transpose::Trans, 64, 32, 32, "A").is_ok());
}
#[test]
fn validate_stream_k_ld_err() {
assert!(validate_stream_k_ld(Transpose::NoTrans, 64, 32, 32, "A").is_err());
}
#[test]
fn stream_k_tile_distribution_balanced_for_large_square() {
let cfg = StreamKConfig::new(128, 128, 32, 108).expect("valid config");
let (m, n, k) = (4096usize, 4096usize, 4096usize);
let total_tiles = cfg.total_tiles(m, n);
assert_eq!(total_tiles, 1024, "4096/128 * 4096/128 = 32 * 32 = 1024");
let (base_iters, remainder) = cfg.iters_per_cta(m, n, k);
let cta_count = cfg.cta_count(m, n);
let max_iters_any_cta = if remainder > 0 {
base_iters + 1
} else {
base_iters
};
let min_iters_any_cta = base_iters;
assert!(
max_iters_any_cta - min_iters_any_cta <= 1,
"Stream-K must balance within 1 iteration: max={} min={}",
max_iters_any_cta,
min_iters_any_cta
);
assert_eq!(cta_count, 108, "108 SMs are all active for 1024 tiles");
}
#[test]
fn stream_k_superior_to_split_k_in_tail_wave_scenario() {
let cfg = StreamKConfig::new(128, 128, 32, 10).expect("valid config");
let (m, n, k) = (384usize, 512usize, 512usize);
let (base_iters, remainder) = cfg.iters_per_cta(m, n, k);
let total_iters = cfg.total_iters(m, n, k);
let cta_count = cfg.cta_count(m, n);
let stream_k_max_imbalance = if remainder > 0 { 1usize } else { 0 };
let total_tiles = cfg.total_tiles(m, n);
let tiles_per_sm_max = total_tiles.div_ceil(cfg.sm_count);
let tiles_per_sm_min = total_tiles / cfg.sm_count;
let iters_per_tile = cfg.iters_per_tile(k);
let split_k_max_imbalance = (tiles_per_sm_max - tiles_per_sm_min) * iters_per_tile;
assert!(
stream_k_max_imbalance <= split_k_max_imbalance,
"Stream-K imbalance ({}) must be <= Split-K imbalance ({})",
stream_k_max_imbalance,
split_k_max_imbalance
);
assert_eq!(
base_iters * (cta_count - remainder) + (base_iters + 1) * remainder,
total_iters,
"Stream-K must account for all iterations"
);
}
#[test]
fn stream_k_ptx_contains_atomic_k_reduction() {
let cfg = StreamKConfig::new(128, 128, 32, 80).expect("valid config");
let ptx = generate_stream_k_ptx::<f32>(
SmVersion::Sm90,
&cfg,
Transpose::NoTrans,
Transpose::NoTrans,
512,
512,
1024,
)
.expect("PTX generation must succeed");
assert!(
ptx.contains("atom.add"),
"Stream-K PTX must document atomic K-reduction via atom.add"
);
}
#[test]
fn stream_k_cta_count_and_remainder_property() {
let cfg = StreamKConfig::new(128, 128, 32, 108).expect("valid config");
let (m, n, k) = (256usize, 256usize, 256usize);
let cta_count = cfg.cta_count(m, n);
let total_tiles = cfg.total_tiles(m, n);
assert_eq!(cta_count, total_tiles.min(108));
let (base_iters, remainder) = cfg.iters_per_cta(m, n, k);
let total_iters = cfg.total_iters(m, n, k);
assert_eq!(
base_iters * (cta_count - remainder) + (base_iters + 1) * remainder,
total_iters,
"Work conservation must hold"
);
}
}