use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_ptx::ir::PtxType;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{FillMode, GpuFloat, MatrixDesc, MatrixDescMut, Transpose};
use super::syrk_tc;
pub fn syrk<T: GpuFloat>(
handle: &BlasHandle,
fill_mode: FillMode,
trans: Transpose,
alpha: T,
a: &MatrixDesc<T>,
beta: T,
c: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
if trans == Transpose::ConjTrans {
return Err(BlasError::InvalidArgument(
"SYRK: use HERK for conjugate-transpose; ConjTrans is not valid here".into(),
));
}
if c.rows != c.cols {
return Err(BlasError::InvalidDimension(format!(
"SYRK: output C must be square, got {}x{}",
c.rows, c.cols
)));
}
let n = c.rows;
let (a_n, a_k) = match trans {
Transpose::NoTrans => (a.rows, a.cols),
Transpose::Trans | Transpose::ConjTrans => (a.cols, a.rows),
};
if a_n != n {
return Err(BlasError::DimensionMismatch(format!(
"SYRK: op(A) has {a_n} rows but C is {n}x{n}"
)));
}
if n == 0 {
return Ok(()); }
{
let sm = handle.sm_version();
let tc_eligible = syrk_tc::is_tc_applicable(sm, n)
&& fill_mode != FillMode::Full
&& T::PTX_TYPE == PtxType::F32;
if tc_eligible {
let tile = syrk_tc::syrk_tc_tile_config(sm, n);
let config =
syrk_tc::SyrkTcConfig::new(tile.tile_m, tile.tile_n, tile.tile_k, sm, fill_mode);
if let Ok((ptx, kernel_name)) = syrk_tc::generate_syrk_tc_ptx(&config) {
if let Ok(module) = Module::from_ptx(&ptx) {
let module = Arc::new(module);
let kernel =
Kernel::from_module(Arc::clone(&module), &kernel_name).map_err(|e| {
BlasError::LaunchFailed(format!("SYRK TC: kernel lookup failed: {e}"))
})?;
let grid_x = n.div_ceil(tile.tile_n);
let grid_y = n.div_ceil(tile.tile_m);
let threads_per_block = (tile.tile_m * tile.tile_n).min(256);
let params = LaunchParams::new(
Dim3::new(grid_x, grid_y, 1),
Dim3::new(threads_per_block, 1, 1),
);
let alpha_f32 = f32::from_bits(alpha.to_bits_u64() as u32);
let beta_f32 = f32::from_bits(beta.to_bits_u64() as u32);
let args = (a.ptr, c.ptr, alpha_f32, beta_f32, n, a_k, a.ld, c.ld);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| {
BlasError::LaunchFailed(format!("SYRK TC: launch failed: {e}"))
})?;
return Ok(());
}
}
}
}
let (trans_left, trans_right) = match trans {
Transpose::NoTrans => (Transpose::NoTrans, Transpose::Trans),
Transpose::Trans => (Transpose::Trans, Transpose::NoTrans),
Transpose::ConjTrans => unreachable!(), };
super::gemm_api::gemm(handle, trans_left, trans_right, alpha, a, a, beta, c)
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests check error `Display` text and the transpose-mapping rule in
    // isolation; they do not call `syrk` itself, which would need a `BlasHandle`.

    /// The ConjTrans rejection message must steer callers toward HERK.
    #[test]
    fn syrk_rejects_conj_trans() {
        let message = BlasError::InvalidArgument("SYRK: use HERK".into()).to_string();
        assert!(message.contains("HERK"));
    }

    /// A non-square output error must mention that C has to be square.
    #[test]
    fn syrk_validates_square_c() {
        let message =
            BlasError::InvalidDimension("SYRK: output C must be square, got 3x5".into())
                .to_string();
        assert!(message.contains("square"));
    }

    /// NoTrans maps to the GEMM operand pair (A, A^T).
    #[test]
    fn trans_choices() {
        let op = Transpose::NoTrans;
        let (left, right) = if matches!(op, Transpose::NoTrans) {
            (Transpose::NoTrans, Transpose::Trans)
        } else {
            (Transpose::Trans, Transpose::NoTrans)
        };
        assert_eq!(left, Transpose::NoTrans);
        assert_eq!(right, Transpose::Trans);
    }
}