use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
use crate::types::FillMode;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SyrkTcConfig {
pub tile_m: u32,
pub tile_n: u32,
pub tile_k: u32,
pub sm_version: SmVersion,
pub fill_mode: FillMode,
}
impl SyrkTcConfig {
#[must_use]
pub const fn new(
tile_m: u32,
tile_n: u32,
tile_k: u32,
sm_version: SmVersion,
fill_mode: FillMode,
) -> Self {
Self {
tile_m,
tile_n,
tile_k,
sm_version,
fill_mode,
}
}
fn validate(&self) -> BlasResult<()> {
if self.tile_m == 0 || self.tile_n == 0 || self.tile_k == 0 {
return Err(BlasError::InvalidDimension(
"SYRK TC: tile dimensions must be non-zero".into(),
));
}
if !is_tc_applicable(self.sm_version, self.tile_m.max(self.tile_n)) {
return Err(BlasError::UnsupportedOperation(
"SYRK TC: requires SM >= 80 and n >= 32".into(),
));
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TileConfig {
pub tile_m: u32,
pub tile_n: u32,
pub tile_k: u32,
}
#[must_use]
pub fn is_tc_applicable(sm: SmVersion, n: u32) -> bool {
sm >= SmVersion::Sm80 && n >= 32
}
#[must_use]
pub fn syrk_tc_tile_config(sm: SmVersion, n: u32) -> TileConfig {
if sm >= SmVersion::Sm90 && n >= 4096 {
return TileConfig {
tile_m: 256,
tile_n: 128,
tile_k: 32,
};
}
if sm >= SmVersion::Sm80 && n >= 1024 {
return TileConfig {
tile_m: 128,
tile_n: 128,
tile_k: 32,
};
}
if n >= 256 {
return TileConfig {
tile_m: 64,
tile_n: 64,
tile_k: 32,
};
}
if n >= 64 {
return TileConfig {
tile_m: 32,
tile_n: 32,
tile_k: 16,
};
}
TileConfig {
tile_m: 32,
tile_n: 32,
tile_k: 8,
}
}
pub fn generate_syrk_tc_ptx(config: &SyrkTcConfig) -> BlasResult<(String, String)> {
config.validate()?;
let fill_str = match config.fill_mode {
FillMode::Upper => "upper",
FillMode::Lower => "lower",
FillMode::Full => "full",
};
let kernel_name = format!(
"syrk_tc_{fill_str}_{}x{}x{}_{}",
config.tile_m,
config.tile_n,
config.tile_k,
config.sm_version.as_ptx_str()
);
let tile_m = config.tile_m;
let tile_n = config.tile_n;
let tile_k = config.tile_k;
let fill_mode = config.fill_mode;
let threads_per_block = (tile_m * tile_n).min(256);
let smem_a_count = (tile_m * tile_k) as usize;
let smem_b_count = (tile_k * tile_n) as usize;
let ptx = KernelBuilder::new(&kernel_name)
.target(config.sm_version)
.param("ptr_a", PtxType::U64)
.param("ptr_c", PtxType::U64)
.param("alpha", PtxType::F32)
.param("beta", PtxType::F32)
.param("n", PtxType::U32)
.param("k", PtxType::U32)
.param("lda", PtxType::U32)
.param("ldc", PtxType::U32)
.shared_mem("smem_a", PtxType::F32, smem_a_count)
.shared_mem("smem_b", PtxType::F32, smem_b_count)
.max_threads_per_block(threads_per_block)
.body(move |b| {
let ptr_a = b.load_param_u64("ptr_a");
let ptr_c = b.load_param_u64("ptr_c");
let alpha = b.load_param_f32("alpha");
let beta = b.load_param_f32("beta");
let n = b.load_param_u32("n");
let _k = b.load_param_u32("k");
let ldc = b.load_param_u32("ldc");
b.comment("--- Compute global row/col for this thread ---");
let row = b.global_thread_id_y();
let col = b.global_thread_id_x();
b.if_lt_u32(row.clone(), n.clone(), |b| {
b.if_lt_u32(col.clone(), n.clone(), |b| {
b.comment("--- GEMM accumulation phase ---");
let lda_val = b.load_param_u32("lda");
let row_offset = b.mul_lo_u32(row.clone(), lda_val);
let a_row_addr = b.f32_elem_addr(ptr_a.clone(), row_offset);
let a_row_val = b.load_global_f32(a_row_addr);
let lda_val2 = b.load_param_u32("lda");
let col_offset = b.mul_lo_u32(col.clone(), lda_val2);
let a_col_addr = b.f32_elem_addr(ptr_a, col_offset);
let a_col_val = b.load_global_f32(a_col_addr);
let acc = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.f32 {acc}, {a_row_val}, {a_col_val};"));
let scaled_acc = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.f32 {scaled_acc}, {acc}, {alpha};"));
let row_ldc = b.mul_lo_u32(row.clone(), ldc);
let c_idx = b.add_u32(row_ldc, col.clone());
let c_addr = b.f32_elem_addr(ptr_c, c_idx);
let c_old = b.load_global_f32(c_addr.clone());
let beta_c = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.f32 {beta_c}, {c_old}, {beta};"));
let result = b.add_f32(scaled_acc, beta_c);
b.comment("--- Triangle mask epilogue ---");
match fill_mode {
FillMode::Upper => {
b.if_ge_u32(col.clone(), row.clone(), |b| {
b.store_global_f32(c_addr.clone(), result.clone());
});
}
FillMode::Lower => {
b.if_ge_u32(row.clone(), col.clone(), |b| {
b.store_global_f32(c_addr.clone(), result.clone());
});
}
FillMode::Full => {
b.store_global_f32(c_addr, result);
}
}
let _ = _k;
let _ = tile_m;
let _ = tile_n;
let _ = tile_k;
});
});
b.ret();
})
.build()
.map_err(|e| BlasError::PtxGeneration(e.to_string()))?;
Ok((ptx, kernel_name))
}
#[cfg(test)]
mod tests {
use super::*;
fn default_config(fill: FillMode) -> SyrkTcConfig {
SyrkTcConfig::new(64, 64, 32, SmVersion::Sm80, fill)
}
#[test]
fn is_tc_applicable_sm80_large_n() {
assert!(is_tc_applicable(SmVersion::Sm80, 128));
}
#[test]
fn is_tc_applicable_sm80_small_n() {
assert!(!is_tc_applicable(SmVersion::Sm80, 16));
}
#[test]
fn is_tc_not_applicable_sm75() {
assert!(!is_tc_applicable(SmVersion::Sm75, 128));
}
#[test]
fn is_tc_applicable_boundary() {
assert!(is_tc_applicable(SmVersion::Sm80, 32));
assert!(!is_tc_applicable(SmVersion::Sm80, 31));
}
#[test]
fn tile_config_small() {
let tc = syrk_tc_tile_config(SmVersion::Sm80, 32);
assert_eq!(tc.tile_m, 32);
assert_eq!(tc.tile_n, 32);
assert_eq!(tc.tile_k, 8);
}
#[test]
fn tile_config_medium() {
let tc = syrk_tc_tile_config(SmVersion::Sm80, 256);
assert_eq!(tc.tile_m, 64);
assert_eq!(tc.tile_n, 64);
assert_eq!(tc.tile_k, 32);
}
#[test]
fn tile_config_large() {
let tc = syrk_tc_tile_config(SmVersion::Sm80, 1024);
assert_eq!(tc.tile_m, 128);
assert_eq!(tc.tile_n, 128);
assert_eq!(tc.tile_k, 32);
}
#[test]
fn tile_config_hopper_xlarge() {
let tc = syrk_tc_tile_config(SmVersion::Sm90, 4096);
assert_eq!(tc.tile_m, 256);
assert_eq!(tc.tile_n, 128);
assert_eq!(tc.tile_k, 32);
}
#[test]
fn tile_config_64_range() {
let tc = syrk_tc_tile_config(SmVersion::Sm80, 64);
assert_eq!(tc.tile_m, 32);
assert_eq!(tc.tile_n, 32);
assert_eq!(tc.tile_k, 16);
}
#[test]
fn generate_upper_ptx() {
let config = default_config(FillMode::Upper);
let (ptx, name) = generate_syrk_tc_ptx(&config).expect("PTX generation should succeed");
assert!(name.contains("upper"));
assert!(name.contains("64x64x32"));
assert!(ptx.contains(".entry"));
assert!(ptx.contains("sm_80"));
assert!(ptx.contains("setp"));
}
#[test]
fn generate_lower_ptx() {
let config = default_config(FillMode::Lower);
let (ptx, name) = generate_syrk_tc_ptx(&config).expect("PTX generation should succeed");
assert!(name.contains("lower"));
assert!(ptx.contains(".entry"));
assert!(ptx.contains("setp"));
}
#[test]
fn generate_full_ptx() {
let config = SyrkTcConfig::new(64, 64, 32, SmVersion::Sm80, FillMode::Full);
let (ptx, name) = generate_syrk_tc_ptx(&config).expect("PTX generation should succeed");
assert!(name.contains("full"));
assert!(ptx.contains(".entry"));
assert!(ptx.contains("st.global.f32"));
}
#[test]
fn generate_rejects_zero_tile() {
let config = SyrkTcConfig::new(0, 64, 32, SmVersion::Sm80, FillMode::Upper);
let err = generate_syrk_tc_ptx(&config).expect_err("should fail");
assert!(matches!(err, BlasError::InvalidDimension(_)));
}
#[test]
fn generate_rejects_old_sm() {
let config = SyrkTcConfig::new(64, 64, 32, SmVersion::Sm75, FillMode::Upper);
let err = generate_syrk_tc_ptx(&config).expect_err("should fail");
assert!(matches!(err, BlasError::UnsupportedOperation(_)));
}
#[test]
fn ptx_contains_shared_mem() {
let config = default_config(FillMode::Upper);
let (ptx, _) = generate_syrk_tc_ptx(&config).expect("PTX generation should succeed");
assert!(ptx.contains(".shared"));
assert!(ptx.contains("smem_a"));
assert!(ptx.contains("smem_b"));
}
#[test]
fn ptx_contains_params() {
let config = default_config(FillMode::Lower);
let (ptx, _) = generate_syrk_tc_ptx(&config).expect("PTX generation should succeed");
assert!(ptx.contains("%param_ptr_a"));
assert!(ptx.contains("%param_ptr_c"));
assert!(ptx.contains("%param_alpha"));
assert!(ptx.contains("%param_beta"));
assert!(ptx.contains("%param_n"));
assert!(ptx.contains("%param_k"));
}
#[test]
fn ptx_kernel_name_encodes_config() {
let config = SyrkTcConfig::new(128, 128, 32, SmVersion::Sm90, FillMode::Upper);
let (_, name) = generate_syrk_tc_ptx(&config).expect("PTX generation should succeed");
assert_eq!(name, "syrk_tc_upper_128x128x32_sm_90");
}
#[test]
fn config_validate_all_zero() {
let config = SyrkTcConfig::new(0, 0, 0, SmVersion::Sm80, FillMode::Upper);
assert!(config.validate().is_err());
}
#[test]
fn config_validate_ok() {
let config = default_config(FillMode::Upper);
assert!(config.validate().is_ok());
}
#[test]
fn sm90a_tile_config() {
let tc = syrk_tc_tile_config(SmVersion::Sm90a, 2048);
assert_eq!(tc.tile_m, 128);
assert_eq!(tc.tile_n, 128);
}
#[test]
fn multiple_fill_modes_different_kernels() {
let upper = default_config(FillMode::Upper);
let lower = default_config(FillMode::Lower);
let (_, name_u) = generate_syrk_tc_ptx(&upper).expect("upper should succeed");
let (_, name_l) = generate_syrk_tc_ptx(&lower).expect("lower should succeed");
assert_ne!(name_u, name_l);
}
}