use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{DiagType, FillMode, GpuFloat, MatrixDesc, MatrixDescMut, Transpose};
/// Validates the operand shapes for a triangular matrix-matrix multiply (TRMM).
///
/// The body visible here performs argument validation only: once the checks
/// pass, every operand is discarded and `Ok(())` is returned without touching
/// `b`. NOTE(review): presumably the actual kernel dispatch is still to be
/// wired in — confirm.
///
/// # Arguments
///
/// * `handle` - BLAS handle (not used by the current body).
/// * `side` - Selects whether `a` is checked against the rows
///   ([`Side::Left`]) or the columns ([`Side::Right`]) of `b`.
/// * `fill_mode` - Accepted but not consumed by the current body.
/// * `trans_a` - Transpose mode for `a`; it is paired with
///   `Transpose::NoTrans` according to `side`, and the pair is then discarded.
/// * `diag` - Accepted but not consumed by the current body.
/// * `alpha` - Scalar factor (not used by the current body).
/// * `a` - Descriptor of the triangular matrix; must be square.
/// * `b` - Descriptor of the matrix operand; must have non-zero dimensions.
///
/// # Errors
///
/// * [`BlasError::InvalidDimension`] if `a` is not square.
/// * [`BlasError::InvalidDimension`] if `b` has zero rows or zero columns.
/// * [`BlasError::DimensionMismatch`] if `side` is [`Side::Left`] and the
///   order of `a` differs from the row count of `b`, or if `side` is
///   [`Side::Right`] and it differs from the column count of `b`.
// The argument list is long enough to trip clippy's `too_many_arguments`
// lint, so the lint is silenced for this function.
#[allow(clippy::too_many_arguments)]
pub fn trmm<T: GpuFloat>(
    handle: &BlasHandle,
    side: Side,
    fill_mode: FillMode,
    trans_a: Transpose,
    diag: DiagType,
    alpha: T,
    a: &MatrixDesc<T>,
    b: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
    // The triangular operand has to be square before anything else is checked.
    let (a_rows, a_cols) = (a.rows, a.cols);
    if a_rows != a_cols {
        return Err(BlasError::InvalidDimension(format!(
            "TRMM: triangular matrix A must be square, got {}x{}",
            a_rows, a_cols
        )));
    }
    let order = a_rows;

    // Zero-sized B is rejected as an error rather than treated as a no-op.
    let (b_rows, b_cols) = (b.rows, b.cols);
    if b_rows == 0 || b_cols == 0 {
        return Err(BlasError::InvalidDimension(
            "TRMM: B dimensions must be non-zero".into(),
        ));
    }

    // A's order must match the dimension of B selected by `side`.
    match side {
        Side::Left if order != b_rows => {
            return Err(BlasError::DimensionMismatch(format!(
                "TRMM left: A is {t}x{t} but B has {m} rows",
                t = order,
                m = b_rows
            )));
        }
        Side::Right if order != b_cols => {
            return Err(BlasError::DimensionMismatch(format!(
                "TRMM right: A is {t}x{t} but B has {n} cols",
                t = order,
                n = b_cols
            )));
        }
        Side::Left | Side::Right => {}
    }

    // These two are not consumed yet; discard them explicitly so they do not
    // trigger unused-variable warnings.
    let _ = (fill_mode, diag);

    // A's transpose flag goes in the slot matching the side A is on; the
    // other slot is always `NoTrans`.
    let (trans_left, trans_right) = match side {
        Side::Right => (Transpose::NoTrans, trans_a),
        Side::Left => (trans_a, Transpose::NoTrans),
    };

    // Nothing is launched from here: the remaining operands are discarded.
    let _ = (handle, a, b, alpha, trans_left, trans_right);
    Ok(())
}
use crate::types::Side;
// NOTE(review): none of these tests call `trmm`. They only check error
// message text and compare literals, so the validation paths in `trmm` are
// not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    /// The "must be square" error text survives `to_string()`.
    #[test]
    fn trmm_validates_square() {
        let rendered =
            BlasError::InvalidDimension("TRMM: triangular matrix A must be square".into())
                .to_string();
        assert!(rendered.contains("square"));
    }

    /// The "non-zero" error text survives `to_string()`.
    #[test]
    fn trmm_validates_zero_dims() {
        let rendered =
            BlasError::InvalidDimension("TRMM: B dimensions must be non-zero".into()).to_string();
        assert!(rendered.contains("non-zero"));
    }

    /// Left-side shape condition, checked on literals: order of A equals rows of B.
    #[test]
    fn side_left_dimension_check() {
        let (order, b_rows): (u32, u32) = (64, 64);
        assert_eq!(order, b_rows);
    }

    /// Right-side shape condition, checked on literals: order of A equals cols of B.
    #[test]
    fn side_right_dimension_check() {
        let (order, b_cols): (u32, u32) = (128, 128);
        assert_eq!(order, b_cols);
    }
}