1use crate::error::SolverError;
7
/// Selects which triangle of the coefficient matrix `trsm` reads.
///
/// Despite the name, this plays the role of BLAS `uplo` (lower vs. upper triangle),
/// not BLAS `side` (left vs. right multiplication).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TriangularSide {
    /// Lower-triangular: entries on and below the diagonal, solved by forward substitution.
    Lower,
    /// Upper-triangular: entries on and above the diagonal, solved by back substitution.
    Upper,
}
16
/// Whether the diagonal of the triangular matrix is read or assumed to be all ones.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagonalType {
    /// Diagonal taken to be all ones; the stored diagonal entries are never read.
    Unit,
    /// Diagonal read from the matrix and divided by; entries whose magnitude is below
    /// `f32::EPSILON` are reported as singular.
    NonUnit,
}
25
/// Solution produced by `trsm`.
#[derive(Debug)]
pub struct TrsmResult {
    /// Solution matrix `X` in row-major `n x nrhs` order: entry `(row, col)` is stored at
    /// `x[row * nrhs + col]`.
    pub x: Vec<f32>,
    /// Number of rows of `X` (the order of the triangular matrix).
    pub n: usize,
    /// Number of columns of `X` (the number of right-hand sides).
    pub nrhs: usize,
}
36
37pub fn trsm(
43 a: &[f32],
44 b: &[f32],
45 n: usize,
46 nrhs: usize,
47 side: TriangularSide,
48 diag: DiagonalType,
49) -> Result<TrsmResult, SolverError> {
50 if a.len() != n * n {
51 return Err(SolverError::DimensionMismatch {
52 matrix_n: n,
53 rhs_len: a.len(),
54 });
55 }
56 if b.len() != n * nrhs {
57 return Err(SolverError::DimensionMismatch {
58 matrix_n: n,
59 rhs_len: b.len(),
60 });
61 }
62
63 let mut x = b.to_vec();
64
65 match side {
66 TriangularSide::Lower => forward_substitution(a, &mut x, n, nrhs, diag)?,
67 TriangularSide::Upper => back_substitution(a, &mut x, n, nrhs, diag)?,
68 }
69
70 Ok(TrsmResult { x, n, nrhs })
71}
72
73fn apply_diagonal(
75 a: &[f32],
76 n: usize,
77 i: usize,
78 sum: f32,
79 diag: DiagonalType,
80) -> Result<f32, SolverError> {
81 match diag {
82 DiagonalType::Unit => Ok(sum),
83 DiagonalType::NonUnit => {
84 let d = a[i * n + i];
85 if d.abs() < f32::EPSILON {
86 return Err(SolverError::SingularMatrix(i));
87 }
88 Ok(sum / d)
89 }
90 }
91}
92
93fn forward_substitution(
95 a: &[f32],
96 x: &mut [f32],
97 n: usize,
98 nrhs: usize,
99 diag: DiagonalType,
100) -> Result<(), SolverError> {
101 for col in 0..nrhs {
102 for i in 0..n {
103 let mut sum = x[i * nrhs + col];
104 for j in 0..i {
105 sum -= a[i * n + j] * x[j * nrhs + col];
106 }
107 x[i * nrhs + col] = apply_diagonal(a, n, i, sum, diag)?;
108 }
109 }
110 Ok(())
111}
112
113fn back_substitution(
115 a: &[f32],
116 x: &mut [f32],
117 n: usize,
118 nrhs: usize,
119 diag: DiagonalType,
120) -> Result<(), SolverError> {
121 for col in 0..nrhs {
122 for i in (0..n).rev() {
123 let mut sum = x[i * nrhs + col];
124 for j in (i + 1)..n {
125 sum -= a[i * n + j] * x[j * nrhs + col];
126 }
127 x[i * nrhs + col] = apply_diagonal(a, n, i, sum, diag)?;
128 }
129 }
130 Ok(())
131}