use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{DiagType, FillMode, GpuFloat, MatrixDesc, MatrixDescMut, Side, Transpose};
/// Maximum number of rows/columns of the triangular matrix covered by one
/// diagonal block in `blocked_trsm`; the effective block size is the smaller
/// of this value and the matrix order.
const TRSM_BLOCK_SIZE: u32 = 64;
/// TRSM entry point: validates the shapes of `a` and `b`, then hands the
/// request to `blocked_trsm` with all arguments forwarded unchanged.
///
/// NOTE(review): by BLAS convention TRSM solves `op(A) * X = alpha * B`
/// (left) or `X * op(A) = alpha * B` (right) in place in `b` — confirm
/// against the kernel implementation, which is not visible here.
///
/// # Errors
///
/// * [`BlasError::InvalidDimension`] if `a` is not square, or if `b` has
///   zero rows or zero columns.
/// * [`BlasError::DimensionMismatch`] if the order of `a` differs from the
///   row count of `b` (`Side::Left`) or the column count of `b`
///   (`Side::Right`).
///
/// Any error produced by `blocked_trsm` is returned as-is.
// The argument list mirrors the classic BLAS TRSM signature, hence the allow.
#[allow(clippy::too_many_arguments)]
pub fn trsm<T: GpuFloat>(
    handle: &BlasHandle,
    side: Side,
    fill_mode: FillMode,
    trans_a: Transpose,
    diag: DiagType,
    alpha: T,
    a: &MatrixDesc<T>,
    b: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
    // The triangular operand has to be square before anything else is checked.
    let (a_rows, a_cols) = (a.rows, a.cols);
    if a_rows != a_cols {
        return Err(BlasError::InvalidDimension(format!(
            "TRSM: triangular matrix A must be square, got {}x{}",
            a_rows, a_cols
        )));
    }
    let order = a_rows;

    // Empty right-hand sides are rejected rather than treated as a no-op.
    let (b_rows, b_cols) = (b.rows, b.cols);
    if b_rows == 0 || b_cols == 0 {
        return Err(BlasError::InvalidDimension(
            "TRSM: B dimensions must be non-zero".into(),
        ));
    }

    // A multiplies from the left => its order must match B's rows;
    // from the right => B's columns.
    let mismatch = match side {
        Side::Left if order != b_rows => Some(format!(
            "TRSM left: A is {t}x{t} but B has {m} rows",
            t = order,
            m = b_rows
        )),
        Side::Right if order != b_cols => Some(format!(
            "TRSM right: A is {t}x{t} but B has {n} cols",
            t = order,
            n = b_cols
        )),
        Side::Left | Side::Right => None,
    };
    if let Some(message) = mismatch {
        return Err(BlasError::DimensionMismatch(message));
    }

    blocked_trsm(handle, side, fill_mode, trans_a, diag, alpha, a, b)
}
/// Walks the diagonal blocks of the triangular matrix `a` in the order a
/// blocked solve would visit them and derives the per-block parameters.
///
/// The matrix order `a.rows` is split into blocks of at most
/// `TRSM_BLOCK_SIZE` rows/columns. Blocks are visited front-to-back when the
/// effective system is lower triangular and back-to-front otherwise.
///
/// As written, this function only *plans* the work: the per-block
/// `DiagBlockParams` and the trailing-update extents are computed and then
/// discarded, and `handle`, `a` and `b` are not otherwise touched. It
/// therefore always returns `Ok(())`.
// The argument list mirrors `trsm`, hence the allow.
#[allow(clippy::too_many_arguments)]
fn blocked_trsm<T: GpuFloat>(
    handle: &BlasHandle,
    side: Side,
    fill_mode: FillMode,
    trans_a: Transpose,
    diag: DiagType,
    alpha: T,
    a: &MatrixDesc<T>,
    b: &mut MatrixDescMut<T>,
) -> BlasResult<()> {
    let tri_n = a.rows;

    // An order-0 matrix has no blocks. Returning early also keeps the
    // `div_ceil` below from dividing by a zero block size.
    if tri_n == 0 {
        return Ok(());
    }

    let nb = TRSM_BLOCK_SIZE.min(tri_n);

    // `true`  => visit blocks from index 0 upwards,
    // `false` => visit blocks from the last index downwards.
    // A transposed operand swaps the direction relative to its stored triangle.
    let lower_solve = match (fill_mode, trans_a) {
        (FillMode::Lower, Transpose::NoTrans) => true,
        (FillMode::Upper, Transpose::NoTrans) => false,
        (FillMode::Lower, Transpose::Trans | Transpose::ConjTrans) => false,
        (FillMode::Upper, Transpose::Trans | Transpose::ConjTrans) => true,
        // NOTE(review): `Full` is not a triangular layout; it is currently
        // handled like the forward case — confirm whether callers should
        // reject it instead.
        (FillMode::Full, _) => true,
    };

    let num_blocks = tri_n.div_ceil(nb);
    for block_idx in 0..num_blocks {
        // Map the visiting order onto the actual block position.
        let idx = if lower_solve {
            block_idx
        } else {
            num_blocks - 1 - block_idx
        };

        // `idx < num_blocks` guarantees `block_start < tri_n`, so the
        // subtraction cannot underflow. Sizing the block this way (instead of
        // clamping `block_start + nb`) avoids a `u32` overflow when `tri_n`
        // is within `nb` of `u32::MAX`.
        let block_start = idx * nb;
        let block_size = nb.min(tri_n - block_start);
        let block_end = block_start + block_size;

        let _diag_params = DiagBlockParams {
            side,
            fill_mode,
            trans_a,
            diag,
            // Only the first block visited carries the caller's `alpha`;
            // every later block uses one. Presumably this is so the scaling
            // is applied a single time — TODO confirm once kernels exist.
            alpha: if block_idx == 0 { alpha } else { T::gpu_one() },
            block_start,
            block_size,
        };

        // Rows/columns of the triangle not yet visited: everything after this
        // block when moving forward, everything before it when moving backward.
        let remaining = if lower_solve {
            tri_n - block_end
        } else {
            block_start
        };
        if remaining > 0 {
            let _gemm_update_size = (remaining, block_size);
        }
    }

    // Nothing is launched yet; mark the operands as intentionally unused.
    let _ = (handle, a, b);
    Ok(())
}
/// Parameters describing one diagonal block visited by `blocked_trsm`.
// No code reads these fields yet (the value is built and dropped in
// `blocked_trsm`), hence the `dead_code` allow.
#[allow(dead_code)]
struct DiagBlockParams<T: GpuFloat> {
/// Side argument forwarded unchanged from the caller.
side: Side,
/// Fill mode forwarded unchanged from the caller.
fill_mode: FillMode,
/// Transpose mode forwarded unchanged from the caller.
trans_a: Transpose,
/// Diagonal type forwarded unchanged from the caller.
diag: DiagType,
/// Scalar for this block: the caller's `alpha` for the first block visited,
/// `T::gpu_one()` for all later blocks.
alpha: T,
/// Index of the block's first row/column within the triangular matrix.
block_start: u32,
/// Number of rows/columns in the block (the last block may be smaller than
/// `TRSM_BLOCK_SIZE`).
block_size: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of diagonal blocks for a matrix of order `tri_n`, using the
    /// same arithmetic as `blocked_trsm`.
    fn block_count(tri_n: u32) -> u32 {
        let nb = TRSM_BLOCK_SIZE.min(tri_n);
        tri_n.div_ceil(nb)
    }

    /// The block size must be non-zero; this is checked at compile time.
    #[test]
    fn trsm_block_size_positive() {
        const { assert!(TRSM_BLOCK_SIZE > 0) };
    }

    /// The rendered non-square error mentions the word "square".
    #[test]
    fn validate_non_square_error_message() {
        let rendered =
            BlasError::InvalidDimension("TRSM: triangular matrix A must be square".into())
                .to_string();
        assert!(rendered.contains("square"));
    }

    /// An order that is an exact multiple of the block size.
    #[test]
    fn blocked_iteration_count() {
        assert_eq!(block_count(256u32), 4);
    }

    /// An order that leaves a partial trailing block.
    #[test]
    fn blocked_iteration_count_non_divisible() {
        assert_eq!(block_count(300u32), 5);
    }

    /// The two diagonal kinds compare as distinct values.
    #[test]
    fn diag_type_values() {
        assert_ne!(DiagType::Unit, DiagType::NonUnit);
    }
}