use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{FillMode, GpuFloat, MatrixDesc, MatrixDescMut, Side, Transpose};
/// Symmetric matrix-matrix multiply (SYMM), implemented on top of GEMM.
///
/// Computes
/// - `C = alpha * A * B + beta * C` when `side` is [`Side::Left`], or
/// - `C = alpha * B * A + beta * C` when `side` is [`Side::Right`],
///
/// where `A` is the square (symmetric) operand and `C` is overwritten in place.
///
/// # Fill mode
///
/// `fill_mode` is accepted to mirror the BLAS `?symm` signature but is currently
/// **not used**. `A` is forwarded to GEMM unchanged, so both of its triangles are read.
/// NOTE(review): callers that populate only the `fill_mode` triangle of `A` (the usual
/// BLAS convention) would get incorrect results unless `A` is fully stored — TODO confirm
/// the intended contract or symmetrize `A` before the GEMM call.
///
/// # Errors
///
/// - [`BlasError::InvalidDimension`] if `A` is not square.
/// - [`BlasError::DimensionMismatch`] if `B` or `C` do not conform to `A` for `side`.
/// - Any error returned by the underlying GEMM call.
// Eight parameters deliberately mirror the BLAS SYMM argument list.
#[allow(clippy::too_many_arguments)]
pub fn symm<T: GpuFloat>(
    handle: &BlasHandle,
    side: Side,
    fill_mode: FillMode,
    alpha: T,
    a: &MatrixDesc<T>,
    b: &MatrixDesc<T>,
    beta: T,
    c: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
    if a.rows != a.cols {
        return Err(BlasError::InvalidDimension(format!(
            "symmetric matrix A must be square, got {}x{}",
            a.rows, a.cols
        )));
    }
    let sym_n = a.rows;

    // Validate B/C against A and choose the GEMM operand order in a single pass:
    // Left => A * B, Right => B * A.
    let (left, right) = match side {
        Side::Left => {
            if sym_n != b.rows {
                return Err(BlasError::DimensionMismatch(format!(
                    "SYMM left: A is {s}x{s} but B has {} rows",
                    b.rows,
                    s = sym_n
                )));
            }
            if c.rows != sym_n || c.cols != b.cols {
                return Err(BlasError::DimensionMismatch(format!(
                    "SYMM left: C should be {}x{}, got {}x{}",
                    sym_n, b.cols, c.rows, c.cols
                )));
            }
            (a, b)
        }
        Side::Right => {
            if sym_n != b.cols {
                return Err(BlasError::DimensionMismatch(format!(
                    "SYMM right: A is {s}x{s} but B has {} cols",
                    b.cols,
                    s = sym_n
                )));
            }
            if c.rows != b.rows || c.cols != sym_n {
                return Err(BlasError::DimensionMismatch(format!(
                    "SYMM right: C should be {}x{}, got {}x{}",
                    b.rows, sym_n, c.rows, c.cols
                )));
            }
            (b, a)
        }
    };

    // The stored triangle is not consulted (see "Fill mode" above); this binding only
    // silences the unused-parameter warning.
    let _ = fill_mode;

    // Neither operand is transposed for either side.
    super::gemm_api::gemm(
        handle,
        Transpose::NoTrans,
        Transpose::NoTrans,
        alpha,
        left,
        right,
        beta,
        c,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The non-square error for A should mention "square" in its display text.
    #[test]
    fn symm_validates_square() {
        let message =
            BlasError::InvalidDimension("symmetric matrix A must be square, got 3x5".into())
                .to_string();
        assert!(
            message.contains("square"),
            "unexpected error text: {}",
            message
        );
    }

    /// `Side` variants must compare unequal.
    #[test]
    fn side_enum_values() {
        let (left, right) = (Side::Left, Side::Right);
        assert_ne!(left, right);
    }

    /// `FillMode` variants must compare unequal.
    #[test]
    fn fill_mode_enum_values() {
        let (upper, lower) = (FillMode::Upper, FillMode::Lower);
        assert_ne!(upper, lower);
    }
}