Skip to main content

trueno_solve/
trsm.rs

//! Triangular solve: AX = B (TRSM — cuBLAS parity).
//!
//! Solves for X, where A is an n×n triangular matrix and B is an n×nrhs matrix.
//! All matrices are stored row-major.
//! Supports lower/upper triangular A with a unit or non-unit diagonal.
5
6use crate::error::SolverError;
7
/// Selects which triangle of the coefficient matrix holds the data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TriangularSide {
    /// Lower triangular: `L[i][j] = 0` whenever `j > i`.
    Lower,
    /// Upper triangular: `U[i][j] = 0` whenever `j < i`.
    Upper,
}
16
/// Describes how the diagonal of the triangular matrix is interpreted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DiagonalType {
    /// Every diagonal element is taken to be 1; the stored values are not used.
    Unit,
    /// Diagonal elements are arbitrary and are read from the matrix.
    NonUnit,
}
25
/// Triangular solve result.
///
/// Derives `Clone` and `PartialEq` alongside `Debug` so a result can be
/// duplicated and compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct TrsmResult {
    /// Solution matrix X, stored row-major n×nrhs.
    pub x: Vec<f32>,
    /// Number of rows/columns of A.
    pub n: usize,
    /// Number of right-hand sides (columns of B).
    pub nrhs: usize,
}
36
37/// Solve AX = B where A is triangular.
38///
39/// # Errors
40///
41/// Returns error if dimensions don't match or A has zero diagonal.
42pub fn trsm(
43    a: &[f32],
44    b: &[f32],
45    n: usize,
46    nrhs: usize,
47    side: TriangularSide,
48    diag: DiagonalType,
49) -> Result<TrsmResult, SolverError> {
50    if a.len() != n * n {
51        return Err(SolverError::DimensionMismatch {
52            matrix_n: n,
53            rhs_len: a.len(),
54        });
55    }
56    if b.len() != n * nrhs {
57        return Err(SolverError::DimensionMismatch {
58            matrix_n: n,
59            rhs_len: b.len(),
60        });
61    }
62
63    let mut x = b.to_vec();
64
65    match side {
66        TriangularSide::Lower => forward_substitution(a, &mut x, n, nrhs, diag)?,
67        TriangularSide::Upper => back_substitution(a, &mut x, n, nrhs, diag)?,
68    }
69
70    Ok(TrsmResult { x, n, nrhs })
71}
72
73/// Apply diagonal scaling (unit or non-unit) to a substitution result.
74fn apply_diagonal(
75    a: &[f32],
76    n: usize,
77    i: usize,
78    sum: f32,
79    diag: DiagonalType,
80) -> Result<f32, SolverError> {
81    match diag {
82        DiagonalType::Unit => Ok(sum),
83        DiagonalType::NonUnit => {
84            let d = a[i * n + i];
85            if d.abs() < f32::EPSILON {
86                return Err(SolverError::SingularMatrix(i));
87            }
88            Ok(sum / d)
89        }
90    }
91}
92
93/// Forward substitution for lower triangular systems.
94fn forward_substitution(
95    a: &[f32],
96    x: &mut [f32],
97    n: usize,
98    nrhs: usize,
99    diag: DiagonalType,
100) -> Result<(), SolverError> {
101    for col in 0..nrhs {
102        for i in 0..n {
103            let mut sum = x[i * nrhs + col];
104            for j in 0..i {
105                sum -= a[i * n + j] * x[j * nrhs + col];
106            }
107            x[i * nrhs + col] = apply_diagonal(a, n, i, sum, diag)?;
108        }
109    }
110    Ok(())
111}
112
113/// Back substitution for upper triangular systems.
114fn back_substitution(
115    a: &[f32],
116    x: &mut [f32],
117    n: usize,
118    nrhs: usize,
119    diag: DiagonalType,
120) -> Result<(), SolverError> {
121    for col in 0..nrhs {
122        for i in (0..n).rev() {
123            let mut sum = x[i * nrhs + col];
124            for j in (i + 1)..n {
125                sum -= a[i * n + j] * x[j * nrhs + col];
126            }
127            x[i * nrhs + col] = apply_diagonal(a, n, i, sum, diag)?;
128        }
129    }
130    Ok(())
131}