use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{FillMode, GpuFloat, MatrixDesc, MatrixDescMut, Transpose};
use super::syrk_tc;
/// Symmetric rank-2k update (SYR2K).
///
/// Computes `C = alpha * op(A) * op(B)^T + alpha * op(B) * op(A)^T + beta * C`, where
/// `op(X) = X` for [`Transpose::NoTrans`] and `op(X) = X^T` for [`Transpose::Trans`].
/// `op(A)` and `op(B)` must both be `n x k`, and `C` must be `n x n`.
///
/// The update is carried out as two GEMM calls. The first applies `beta` to `C`. The
/// second accumulates the mirrored product into `C` with a beta of one.
///
/// NOTE(review): `fill_mode` is not forwarded to the GEMM calls. Presumably every element
/// of `C` is written, including the triangle opposite `fill_mode`, whereas reference BLAS
/// leaves that triangle untouched. Confirm callers do not rely on it being preserved.
///
/// # Errors
///
/// - [`BlasError::InvalidArgument`] if `trans` is [`Transpose::ConjTrans`] (use HER2K).
/// - [`BlasError::InvalidDimension`] if `C` is not square.
/// - [`BlasError::DimensionMismatch`] if `op(A)` or `op(B)` does not have `n` rows, or
///   their inner dimensions `k` differ.
/// - Any error returned by the underlying GEMM calls.
///
/// When `n == 0`, the call returns `Ok(())` after validation without launching any work.
#[allow(clippy::too_many_arguments)] // mirrors the BLAS SYR2K parameter list
pub fn syr2k<T: GpuFloat>(
    handle: &BlasHandle,
    fill_mode: FillMode,
    trans: Transpose,
    alpha: T,
    a: &MatrixDesc<T>,
    b: &MatrixDesc<T>,
    beta: T,
    c: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
    // Resolve the GEMM transposes up front. This also rejects ConjTrans, which belongs to
    // the Hermitian variant, so no later code has to treat it as impossible.
    let (trans_left, trans_right) = match trans {
        Transpose::NoTrans => (Transpose::NoTrans, Transpose::Trans),
        Transpose::Trans => (Transpose::Trans, Transpose::NoTrans),
        Transpose::ConjTrans => {
            return Err(BlasError::InvalidArgument(
                "SYR2K: use HER2K for conjugate-transpose".into(),
            ));
        }
    };

    if c.rows != c.cols {
        return Err(BlasError::InvalidDimension(format!(
            "SYR2K: output C must be square, got {}x{}",
            c.rows, c.cols
        )));
    }
    let n = c.rows;

    // Shape of op(X) as (rows, cols) == (n, k). `Trans` swaps rows and columns.
    let transposed = trans == Transpose::Trans;
    let op_shape = |rows, cols| if transposed { (cols, rows) } else { (rows, cols) };
    let (a_n, a_k) = op_shape(a.rows, a.cols);
    let (b_n, b_k) = op_shape(b.rows, b.cols);

    if a_n != n {
        return Err(BlasError::DimensionMismatch(format!(
            "SYR2K: op(A) has {a_n} rows but C is {n}x{n}"
        )));
    }
    if b_n != n {
        return Err(BlasError::DimensionMismatch(format!(
            "SYR2K: op(B) has {b_n} rows but C is {n}x{n}"
        )));
    }
    if a_k != b_k {
        return Err(BlasError::DimensionMismatch(format!(
            "SYR2K: op(A) has K={a_k} but op(B) has K={b_k}"
        )));
    }
    if n == 0 {
        return Ok(());
    }

    // Tensor-core probe.
    // NOTE(review): the generated kernel is bound to `_tc_kernel` and never used, so
    // execution always falls through to the GEMM path below. This looks like a placeholder
    // for a future TC dispatch; verify whether paying the PTX generation cost on every
    // call is intended.
    {
        let sm = handle.sm_version();
        if syrk_tc::is_tc_applicable(sm, n) && fill_mode != FillMode::Full {
            let tile = syrk_tc::syrk_tc_tile_config(sm, n);
            let config =
                syrk_tc::SyrkTcConfig::new(tile.tile_m, tile.tile_n, tile.tile_k, sm, fill_mode);
            let _tc_kernel = syrk_tc::generate_syrk_tc_ptx(&config);
        }
    }

    // C = alpha * op(A) * op(B)^T + beta * C
    super::gemm_api::gemm(handle, trans_left, trans_right, alpha, a, b, beta, c)?;
    // C += alpha * op(B) * op(A)^T  (beta = 1 keeps the first product)
    let one = T::gpu_one();
    super::gemm_api::gemm(handle, trans_left, trans_right, alpha, b, a, one, c)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Message-level checks for SYR2K errors.
    //!
    //! NOTE(review): these tests only format hand-built `BlasError` values. They never call
    //! `syr2k` itself, so they do not exercise its validation logic.
    use super::*;

    /// The conjugate-transpose rejection text points callers at HER2K.
    #[test]
    fn syr2k_rejects_conj_trans() {
        let rendered = BlasError::InvalidArgument("SYR2K: use HER2K".into()).to_string();
        assert!(rendered.contains("HER2K"));
    }

    /// The non-square-output error text mentions the squareness requirement.
    #[test]
    fn syr2k_validates_square_c() {
        let rendered =
            BlasError::InvalidDimension("SYR2K: output C must be square, got 4x6".into())
                .to_string();
        assert!(rendered.contains("square"));
    }
}