use crate::error::SolverError;
/// Which triangle of the coefficient matrix [`trsm`] reads.
///
/// Note: despite the parameter name `side` in [`trsm`], this selects the
/// triangle (what BLAS calls `UPLO`), not whether `A` multiplies from the left
/// or the right.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TriangularSide {
    /// Lower triangle; the system is solved by forward substitution.
    Lower,
    /// Upper triangle; the system is solved by back substitution.
    Upper,
}
/// How [`trsm`] treats the diagonal of the triangular matrix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagonalType {
    /// The diagonal is never read; every diagonal entry is implicitly 1.
    Unit,
    /// The diagonal is read from the matrix and each solved entry is divided by it.
    NonUnit,
}
/// Solution of a triangular solve produced by [`trsm`].
#[derive(Debug, Clone, PartialEq)]
pub struct TrsmResult {
    /// The solution `X`, row-major with `n` rows and `nrhs` columns:
    /// `x[i * nrhs + k]` is row `i` of right-hand side `k`.
    pub x: Vec<f32>,
    /// Order of the (square) coefficient matrix, i.e. the number of rows of `x`.
    pub n: usize,
    /// Number of right-hand sides, i.e. the number of columns of `x`.
    pub nrhs: usize,
}
/// Solves the triangular system `A · X = B` for `X`, with `nrhs` right-hand sides.
///
/// * `a` — the `n × n` coefficient matrix, row-major (`a[i * n + j]` is row `i`,
///   column `j`). Only the triangle selected by `side` is read. With
///   [`DiagonalType::Unit`] the diagonal is not read either.
/// * `b` — the right-hand sides, row-major `n × nrhs` (`b[i * nrhs + k]` is row
///   `i` of right-hand side `k`). `b` itself is not modified; the solve works on a
///   copy.
/// * `side` — [`TriangularSide::Lower`] uses forward substitution and
///   [`TriangularSide::Upper`] uses back substitution.
/// * `diag` — whether the diagonal is implicitly 1 or read from `a`.
///
/// # Errors
///
/// * `SolverError::DimensionMismatch` if `a.len() != n * n` or
///   `b.len() != n * nrhs`, including when either product overflows `usize`.
///   For a malformed `a` the error's `rhs_len` field carries `a.len()`.
/// * `SolverError::SingularMatrix(i)` when `diag` is [`DiagonalType::NonUnit`] and
///   the diagonal entry of row `i` is rejected by `apply_diagonal`. No diagonal is
///   inspected when `nrhs == 0`.
pub fn trsm(
    a: &[f32],
    b: &[f32],
    n: usize,
    nrhs: usize,
    side: TriangularSide,
    diag: DiagonalType,
) -> Result<TrsmResult, SolverError> {
    // `checked_mul` guards against `usize` overflow. A wrapped product could
    // otherwise match a short slice and send the substitution loops out of bounds.
    if n.checked_mul(n) != Some(a.len()) {
        return Err(SolverError::DimensionMismatch {
            matrix_n: n,
            rhs_len: a.len(),
        });
    }
    if n.checked_mul(nrhs) != Some(b.len()) {
        return Err(SolverError::DimensionMismatch {
            matrix_n: n,
            rhs_len: b.len(),
        });
    }
    // The substitution routines overwrite their buffer in place, so start from B.
    let mut x = b.to_vec();
    match side {
        TriangularSide::Lower => forward_substitution(a, &mut x, n, nrhs, diag)?,
        TriangularSide::Upper => back_substitution(a, &mut x, n, nrhs, diag)?,
    }
    Ok(TrsmResult { x, n, nrhs })
}
/// Finishes solving one entry of row `i` by dividing the substituted value `sum`
/// by that row's diagonal entry.
///
/// With [`DiagonalType::Unit`] the diagonal is never read and `sum` is returned
/// unchanged. With [`DiagonalType::NonUnit`] the pivot is `a[i * n + i]`
/// (row-major `n × n`).
///
/// # Errors
///
/// Returns `SolverError::SingularMatrix(i)` when `|pivot| < f32::EPSILON`.
///
/// NOTE(review): the threshold is absolute, not scaled to the matrix. A
/// well-conditioned but tiny-magnitude diagonal is rejected. A NaN pivot is *not*
/// rejected (the comparison is false), so NaN propagates into the result.
fn apply_diagonal(
    a: &[f32],
    n: usize,
    i: usize,
    sum: f32,
    diag: DiagonalType,
) -> Result<f32, SolverError> {
    let pivot = match diag {
        DiagonalType::Unit => return Ok(sum),
        DiagonalType::NonUnit => a[i * n + i],
    };
    // Kept as `<` (not `>=` with swapped branches) so NaN still takes the Ok path.
    if pivot.abs() < f32::EPSILON {
        Err(SolverError::SingularMatrix(i))
    } else {
        Ok(sum / pivot)
    }
}
/// Solves `L · X = B` in place by forward substitution. `L` is the lower triangle
/// of the row-major `n × n` matrix `a`.
///
/// On entry `x` holds `B` (row-major, `n × nrhs`); on success it holds `X`. Only
/// entries `a[i * n + j]` with `j < i` are read, plus the diagonal via
/// `apply_diagonal`. The strict upper triangle is ignored.
///
/// Rows are solved top to bottom, and all `nrhs` right-hand sides of a row are
/// updated together. Because `x` is row-major, every inner loop walks contiguous
/// memory rather than striding by `nrhs`. For each individual entry the
/// floating-point operations happen in the same order (`j` ascending), so results
/// are bitwise identical to a column-by-column sweep.
///
/// # Errors
///
/// Propagates the error from `apply_diagonal` for the first (top-most) row whose
/// diagonal is rejected. On error `x` is left partially updated.
///
/// # Panics
///
/// Panics on out-of-bounds slicing if `a.len() < n * n` or `x.len() < n * nrhs`.
/// `trsm` validates both before calling.
fn forward_substitution(
    a: &[f32],
    x: &mut [f32],
    n: usize,
    nrhs: usize,
    diag: DiagonalType,
) -> Result<(), SolverError> {
    // No right-hand sides: nothing to solve and no diagonal to inspect. This also
    // keeps `chunks_exact` below from being called with a zero chunk size.
    if nrhs == 0 {
        return Ok(());
    }
    for i in 0..n {
        // Rows `0..i` are already solved; row `i` is the one being updated.
        let (solved, rest) = x.split_at_mut(i * nrhs);
        let row = &mut rest[..nrhs];
        // Row `i` of `a` left of the diagonal, i.e. columns `0..i`.
        let a_left = &a[i * n..i * n + i];
        for (&a_ij, x_j) in a_left.iter().zip(solved.chunks_exact(nrhs)) {
            for (r, &v) in row.iter_mut().zip(x_j) {
                *r -= a_ij * v;
            }
        }
        for r in row.iter_mut() {
            *r = apply_diagonal(a, n, i, *r, diag)?;
        }
    }
    Ok(())
}
/// Solves `U · X = B` in place by back substitution. `U` is the upper triangle
/// of the row-major `n × n` matrix `a`.
///
/// On entry `x` holds `B` (row-major, `n × nrhs`); on success it holds `X`. Only
/// entries `a[i * n + j]` with `j > i` are read, plus the diagonal via
/// `apply_diagonal`. The strict lower triangle is ignored.
///
/// Rows are solved bottom to top, and all `nrhs` right-hand sides of a row are
/// updated together. Because `x` is row-major, every inner loop walks contiguous
/// memory rather than striding by `nrhs`. For each individual entry the
/// floating-point operations happen in the same order (`j` ascending from
/// `i + 1`), so results are bitwise identical to a column-by-column sweep.
///
/// # Errors
///
/// Propagates the error from `apply_diagonal` for the first (bottom-most) row
/// whose diagonal is rejected. On error `x` is left partially updated.
///
/// # Panics
///
/// Panics on out-of-bounds slicing if `a.len() < n * n` or `x.len() < n * nrhs`.
/// `trsm` validates both before calling.
fn back_substitution(
    a: &[f32],
    x: &mut [f32],
    n: usize,
    nrhs: usize,
    diag: DiagonalType,
) -> Result<(), SolverError> {
    // No right-hand sides: nothing to solve and no diagonal to inspect. This also
    // keeps `chunks_exact` below from being called with a zero chunk size.
    if nrhs == 0 {
        return Ok(());
    }
    for i in (0..n).rev() {
        // Rows `i + 1..n` are already solved; split them off from row `i`.
        let (head, solved) = x.split_at_mut((i + 1) * nrhs);
        let row = &mut head[i * nrhs..];
        // Row `i` of `a` right of the diagonal, i.e. columns `i + 1..n`.
        let a_right = &a[i * n + i + 1..(i + 1) * n];
        for (&a_ij, x_j) in a_right.iter().zip(solved.chunks_exact(nrhs)) {
            for (r, &v) in row.iter_mut().zip(x_j) {
                *r -= a_ij * v;
            }
        }
        for r in row.iter_mut() {
            *r = apply_diagonal(a, n, i, *r, diag)?;
        }
    }
    Ok(())
}