use std::fmt::Write as FmtWrite;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{DiagType, FillMode, GpuFloat, Side, Transpose};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
/// Largest triangular dimension handled by the warp-level strategy
/// (inclusive); see `classify_batched_trsm`.
const WARP_SOLVE_LIMIT: usize = 32;
/// Largest triangular dimension handled by the shared-memory strategy
/// (inclusive); anything larger falls back to the blocked strategy.
const SHMEM_SOLVE_LIMIT: usize = 64;
/// Solves a batch of triangular systems: `op(A) * X = alpha * B` for
/// `Side::Left`, or `X * op(A) = alpha * B` for `Side::Right`.
///
/// `a_matrices` holds `batch_count` densely packed `tri_dim x tri_dim`
/// triangular factors, where `tri_dim` is `m` (Left) or `n` (Right).
/// `b_matrices` holds `batch_count` densely packed `m x n` right-hand sides
/// and receives the solutions in place.
///
/// # Errors
/// - `InvalidDimension` if `m` or `n` is zero, or if the required element
///   counts overflow `usize`.
/// - `BufferTooSmall` if either device buffer holds fewer elements than the
///   batch requires.
///
/// NOTE(review): kernel launch is not wired up yet — this validates inputs
/// and generates the PTX, then discards it.
#[allow(clippy::too_many_arguments)]
pub fn batched_trsm<T: GpuFloat>(
    handle: &BlasHandle,
    side: Side,
    uplo: FillMode,
    trans: Transpose,
    diag: DiagType,
    m: usize,
    n: usize,
    alpha: T,
    a_matrices: &DeviceBuffer<T>,
    b_matrices: &mut DeviceBuffer<T>,
    batch_count: usize,
) -> BlasResult<()> {
    if m == 0 || n == 0 {
        return Err(BlasError::InvalidDimension(
            "batched TRSM: m and n must be non-zero".into(),
        ));
    }
    if batch_count == 0 {
        // Empty batch is a quick no-op, not an error.
        return Ok(());
    }
    // The triangular factor's dimension depends on which side it multiplies.
    let tri_dim = match side {
        Side::Left => m,
        Side::Right => n,
    };
    // Checked arithmetic: `batch_count * tri_dim^2` can wrap around `usize`
    // in release builds, which would defeat the size validation below.
    let a_required = batch_count
        .checked_mul(tri_dim)
        .and_then(|v| v.checked_mul(tri_dim))
        .ok_or_else(|| {
            BlasError::InvalidDimension("batched TRSM: A element count overflows usize".into())
        })?;
    if a_matrices.len() < a_required {
        return Err(BlasError::BufferTooSmall {
            expected: a_required,
            actual: a_matrices.len(),
        });
    }
    let b_required = batch_count
        .checked_mul(m)
        .and_then(|v| v.checked_mul(n))
        .ok_or_else(|| {
            BlasError::InvalidDimension("batched TRSM: B element count overflows usize".into())
        })?;
    if b_matrices.len() < b_required {
        return Err(BlasError::BufferTooSmall {
            expected: b_required,
            actual: b_matrices.len(),
        });
    }
    let strategy = classify_batched_trsm(tri_dim);
    let _ptx = generate_batched_trsm_ptx::<T>(
        handle.sm_version(),
        side,
        uplo,
        trans,
        diag,
        m,
        n,
        strategy,
    )?;
    // TODO: compile and launch; suppress unused-argument warnings until then.
    let _ = (alpha, a_matrices, b_matrices, batch_count);
    Ok(())
}
/// Kernel strategy selected from the triangular dimension by
/// `classify_batched_trsm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SolveStrategy {
    // tri_dim <= WARP_SOLVE_LIMIT (32).
    WarpLevel,
    // WARP_SOLVE_LIMIT < tri_dim <= SHMEM_SOLVE_LIMIT (64); the generator
    // declares a `.shared` staging array for this strategy.
    SharedMemory,
    // tri_dim > SHMEM_SOLVE_LIMIT.
    BlockedFallback,
}
/// Maps the triangular dimension to a solve strategy: warp-level up to
/// `WARP_SOLVE_LIMIT`, shared-memory up to `SHMEM_SOLVE_LIMIT`, and the
/// blocked fallback beyond that (both limits inclusive).
fn classify_batched_trsm(tri_dim: usize) -> SolveStrategy {
    match tri_dim {
        d if d <= WARP_SOLVE_LIMIT => SolveStrategy::WarpLevel,
        d if d <= SHMEM_SOLVE_LIMIT => SolveStrategy::SharedMemory,
        _ => SolveStrategy::BlockedFallback,
    }
}
/// Emits the PTX source for one batched-TRSM kernel specialization.
///
/// The kernel name encodes element type, strategy, and the
/// side/uplo/trans/diag flags (e.g. `batched_trsm_f32_warp_lln_n`) so
/// distinct specializations can coexist. The kernel maps `%ctaid.z` to the
/// batch index and `%tid.x` to the right-hand-side column.
///
/// NOTE(review): the emitted row loop scales by alpha and divides by the
/// diagonal element, but the off-diagonal update of a full triangular solve
/// is not emitted yet — confirm before enabling launches.
///
/// # Errors
/// Returns `BlasError::PtxGeneration` if appending to the output fails.
#[allow(clippy::too_many_arguments)]
fn generate_batched_trsm_ptx<T: GpuFloat>(
    sm: SmVersion,
    side: Side,
    uplo: FillMode,
    trans: Transpose,
    diag: DiagType,
    m: usize,
    n: usize,
    strategy: SolveStrategy,
) -> BlasResult<String> {
    let byte_size = T::PTX_TYPE.size_bytes();
    let is_f64 = byte_size == 8;
    // Register prefix and PTX scalar type: %fdN/f64 for double, %fN/f32 for single.
    let (fr, ld_ty) = if is_f64 { ("fd", "f64") } else { ("f", "f32") };
    // Single-letter mnemonics folded into the kernel name.
    let side_label = match side {
        Side::Left => "l",
        Side::Right => "r",
    };
    let uplo_label = match uplo {
        FillMode::Upper => "u",
        FillMode::Lower => "l",
        FillMode::Full => "f",
    };
    let trans_label = match trans {
        Transpose::NoTrans => "n",
        Transpose::Trans => "t",
        Transpose::ConjTrans => "c",
    };
    let diag_label = match diag {
        DiagType::Unit => "u",
        DiagType::NonUnit => "n",
    };
    let strategy_label = match strategy {
        SolveStrategy::WarpLevel => "warp",
        SolveStrategy::SharedMemory => "shmem",
        SolveStrategy::BlockedFallback => "blocked",
    };
    let kernel_name = format!(
        "batched_trsm_{}_{}_{}{}{}_{}",
        T::NAME,
        strategy_label,
        side_label,
        uplo_label,
        trans_label,
        diag_label
    );
    // Dimension of the triangular factor: rows for a left-side solve, columns
    // for a right-side solve. Computed once and shared by the shared-memory
    // declaration, the batch strides, and the row loop below.
    let tri_dim = match side {
        Side::Left => m,
        Side::Right => n,
    };
    let mut p = String::with_capacity(4096);
    wl(&mut p, &format!(".version {}", sm.ptx_version()))?;
    wl(&mut p, &format!(".target {}", sm.as_ptx_str()))?;
    wl(&mut p, ".address_size 64")?;
    wl(&mut p, "")?;
    if strategy == SolveStrategy::SharedMemory {
        // Stage the whole tri_dim x tri_dim triangular matrix in shared memory.
        let shmem_elems = tri_dim * tri_dim;
        wl(
            &mut p,
            &format!(".shared .align 16 .{ld_ty} shmem_tri[{shmem_elems}];"),
        )?;
        wl(&mut p, "")?;
    }
    wl(&mut p, &format!(".visible .entry {kernel_name}("))?;
    wl(&mut p, " .param .u64 %param_a,")?;
    wl(&mut p, " .param .u64 %param_b,")?;
    wl(&mut p, " .param .u32 %param_m,")?;
    wl(&mut p, " .param .u32 %param_n,")?;
    wl(&mut p, " .param .u32 %param_batch,")?;
    wl(&mut p, &format!(" .param .{ld_ty} %param_alpha"))?;
    wl(&mut p, ")")?;
    wl(&mut p, "{")?;
    wl(&mut p, " .reg .b32 %r<32>;")?;
    wl(&mut p, " .reg .b64 %rd<16>;")?;
    if is_f64 {
        wl(&mut p, " .reg .f64 %fd<16>;")?;
    } else {
        wl(&mut p, " .reg .f32 %f<16>;")?;
    }
    wl(&mut p, " .reg .pred %p<4>;")?;
    wl(&mut p, "")?;
    // Guard against over-provisioned grids: CTAs past the batch exit early.
    wl(&mut p, " mov.u32 %r0, %ctaid.z; // batch_idx")?;
    wl(&mut p, " ld.param.u32 %r1, [%param_batch];")?;
    wl(&mut p, " setp.ge.u32 %p0, %r0, %r1;")?;
    wl(&mut p, " @%p0 bra $BTRSM_DONE;")?;
    wl(&mut p, "")?;
    wl(&mut p, " ld.param.u64 %rd0, [%param_a];")?;
    wl(&mut p, " ld.param.u64 %rd1, [%param_b];")?;
    wl(&mut p, " ld.param.u32 %r2, [%param_m];")?;
    wl(&mut p, " ld.param.u32 %r3, [%param_n];")?;
    wl(
        &mut p,
        &format!(" ld.param.{ld_ty} %{fr}0, [%param_alpha];"),
    )?;
    // Per-batch byte strides; matrices are densely packed.
    let a_stride = tri_dim * tri_dim * byte_size;
    let b_stride = m * n * byte_size;
    wl(&mut p, " cvt.u64.u32 %rd3, %r0;")?;
    wl(
        &mut p,
        &format!(" mul.lo.u64 %rd4, %rd3, {a_stride}; // a_offset"),
    )?;
    wl(
        &mut p,
        " add.u64 %rd0, %rd0, %rd4; // a_ptr for this batch",
    )?;
    wl(
        &mut p,
        &format!(" mul.lo.u64 %rd5, %rd3, {b_stride}; // b_offset"),
    )?;
    wl(
        &mut p,
        " add.u64 %rd1, %rd1, %rd5; // b_ptr for this batch",
    )?;
    wl(&mut p, "")?;
    // Substitution order: Lower/NoTrans and Upper/Trans solve forward
    // (row 0 upward); the other combinations solve backward.
    let forward = match (uplo, trans) {
        (FillMode::Lower, Transpose::NoTrans) => true,
        (FillMode::Upper, Transpose::NoTrans) => false,
        (FillMode::Lower, Transpose::Trans | Transpose::ConjTrans) => false,
        (FillMode::Upper, Transpose::Trans | Transpose::ConjTrans) => true,
        (FillMode::Full, _) => true,
    };
    wl(
        &mut p,
        " mov.u32 %r4, %tid.x; // thread in warp => column idx",
    )?;
    wl(&mut p, " setp.ge.u32 %p1, %r4, %r3; // col < n?")?;
    wl(&mut p, " @%p1 bra $BTRSM_DONE;")?;
    wl(&mut p, "")?;
    wl(
        &mut p,
        &format!(" // Strategy: {strategy_label}, forward={forward}"),
    )?;
    wl(
        &mut p,
        &format!(" // Side: {side_label}, Uplo: {uplo_label}"),
    )?;
    wl(&mut p, &format!(" // Diag: {diag_label}"))?;
    if forward {
        wl(&mut p, " mov.u32 %r5, 0; // row = 0")?;
    } else {
        wl(
            &mut p,
            &format!(" mov.u32 %r5, {}; // row = m - 1", tri_dim - 1),
        )?;
    }
    wl(&mut p, "$BTRSM_ROW_LOOP:")?;
    // Loop exit test. A single `setp.ge` covers both directions: the forward
    // loop exits once row == tri_dim, and the backward loop's `sub.u32` wraps
    // row 0 to 0xFFFF_FFFF, which is also >= tri_dim. (This replaces the old
    // `tri_dim + 1000` magic bound, which relied on the same wraparound but
    // obscured it.)
    wl(&mut p, &format!(" setp.ge.u32 %p2, %r5, {};", tri_dim))?;
    wl(&mut p, " @%p2 bra $BTRSM_ROW_DONE;")?;
    // b element address: b_ptr + (row * n + col) * byte_size.
    wl(&mut p, " mad.lo.u32 %r6, %r5, %r3, %r4;")?;
    wl(&mut p, " cvt.u64.u32 %rd6, %r6;")?;
    wl(&mut p, &format!(" mul.lo.u64 %rd6, %rd6, {byte_size};"))?;
    wl(&mut p, " add.u64 %rd7, %rd1, %rd6;")?;
    wl(
        &mut p,
        &format!(" ld.global.{ld_ty} %{fr}1, [%rd7]; // b_val"),
    )?;
    wl(
        &mut p,
        &format!(" mul.rn.{ld_ty} %{fr}1, %{fr}1, %{fr}0;"),
    )?;
    if diag == DiagType::NonUnit {
        // Non-unit diagonal: divide by A[row][row] at offset row * (tri_dim + 1).
        let diag_stride = tri_dim + 1;
        wl(&mut p, &format!(" mul.lo.u32 %r7, %r5, {diag_stride};"))?;
        wl(&mut p, " cvt.u64.u32 %rd8, %r7;")?;
        wl(&mut p, &format!(" mul.lo.u64 %rd8, %rd8, {byte_size};"))?;
        wl(&mut p, " add.u64 %rd9, %rd0, %rd8;")?;
        wl(&mut p, &format!(" ld.global.{ld_ty} %{fr}2, [%rd9];"))?;
        wl(
            &mut p,
            &format!(" div.rn.{ld_ty} %{fr}1, %{fr}1, %{fr}2;"),
        )?;
    }
    wl(&mut p, &format!(" st.global.{ld_ty} [%rd7], %{fr}1;"))?;
    wl(&mut p, " bar.sync 0;")?;
    if forward {
        wl(&mut p, " add.u32 %r5, %r5, 1;")?;
    } else {
        wl(&mut p, " sub.u32 %r5, %r5, 1;")?;
    }
    wl(&mut p, " bra $BTRSM_ROW_LOOP;")?;
    wl(&mut p, "$BTRSM_ROW_DONE:")?;
    wl(&mut p, "")?;
    wl(&mut p, "$BTRSM_DONE:")?;
    wl(&mut p, " ret;")?;
    wl(&mut p, "}")?;
    Ok(p)
}
fn wl(ptx: &mut String, line: &str) -> BlasResult<()> {
writeln!(ptx, "{line}").map_err(|e| BlasError::PtxGeneration(format!("fmt error: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    // Strategy boundaries: WARP_SOLVE_LIMIT (32) and SHMEM_SOLVE_LIMIT (64)
    // are both inclusive on the cheaper strategy.
    #[test]
    fn classify_warp_level() {
        assert_eq!(classify_batched_trsm(16), SolveStrategy::WarpLevel);
        assert_eq!(classify_batched_trsm(32), SolveStrategy::WarpLevel);
    }
    #[test]
    fn classify_shared_memory() {
        assert_eq!(classify_batched_trsm(33), SolveStrategy::SharedMemory);
        assert_eq!(classify_batched_trsm(64), SolveStrategy::SharedMemory);
    }
    #[test]
    fn classify_blocked_fallback() {
        assert_eq!(classify_batched_trsm(65), SolveStrategy::BlockedFallback);
        assert_eq!(classify_batched_trsm(128), SolveStrategy::BlockedFallback);
    }
    // NonUnit diagonal must emit a division by the diagonal element.
    #[test]
    fn generate_ptx_warp_f32() {
        let ptx = generate_batched_trsm_ptx::<f32>(
            SmVersion::Sm80,
            Side::Left,
            FillMode::Lower,
            Transpose::NoTrans,
            DiagType::NonUnit,
            16,
            8,
            SolveStrategy::WarpLevel,
        );
        let ptx = ptx.expect("PTX gen should succeed");
        // Kernel name encodes type_strategy_side/uplo/trans_diag.
        assert!(ptx.contains("batched_trsm_f32_warp_lln_n"));
        assert!(ptx.contains(".target sm_80"));
        assert!(ptx.contains("div.rn.f32")); }
    // Shared-memory strategy declares a .shared staging array; Unit diagonal
    // must NOT emit a division.
    #[test]
    fn generate_ptx_shmem_f64() {
        let ptx = generate_batched_trsm_ptx::<f64>(
            SmVersion::Sm80,
            Side::Right,
            FillMode::Upper,
            Transpose::Trans,
            DiagType::Unit,
            48,
            32,
            SolveStrategy::SharedMemory,
        );
        let ptx = ptx.expect("PTX gen should succeed");
        assert!(ptx.contains("batched_trsm_f64_shmem"));
        assert!(ptx.contains(".shared"));
        assert!(!ptx.contains("div.rn.f64"));
    }
    #[test]
    fn generate_ptx_blocked_fallback() {
        let ptx = generate_batched_trsm_ptx::<f32>(
            SmVersion::Sm80,
            Side::Left,
            FillMode::Upper,
            Transpose::NoTrans,
            DiagType::NonUnit,
            128,
            64,
            SolveStrategy::BlockedFallback,
        );
        let ptx = ptx.expect("PTX gen should succeed");
        assert!(ptx.contains("batched_trsm_f32_blocked"));
    }
    // Only checks the error's Display text; batched_trsm itself needs device
    // buffers and so is not exercised here.
    #[test]
    fn batched_trsm_zero_dim_error() {
        let err = BlasError::InvalidDimension("batched TRSM: m and n must be non-zero".into());
        assert!(err.to_string().contains("non-zero"));
    }
    // Mirrors the size arithmetic used in batched_trsm's buffer validation.
    #[test]
    fn batched_trsm_buffer_size_check() {
        let batch_count = 10usize;
        let m = 32usize;
        let n = 16usize;
        let tri_dim = m;
        let a_required = batch_count * tri_dim * tri_dim;
        assert_eq!(a_required, 10 * 32 * 32);
        let b_required = batch_count * m * n;
        assert_eq!(b_required, 10 * 32 * 16);
    }
    #[test]
    fn solve_strategy_debug() {
        let s = SolveStrategy::WarpLevel;
        let dbg = format!("{s:?}");
        assert_eq!(dbg, "WarpLevel");
    }
}