Skip to main content

lox_core/math/linear_algebra/
tridiagonal.rs

1// SPDX-FileCopyrightText: 2024 Helge Eichhorn <git@helgeeichhorn.de>
2//
3// SPDX-License-Identifier: MPL-2.0
4
5//! Tridiagonal matrix representation and solver.
6
7use std::ops::Index;
8
9use thiserror::Error;
10
11/// Error returned when the diagonal dimensions of a [`Tridiagonal`] matrix are inconsistent.
12#[derive(Clone, Debug, Error, Eq, PartialEq)]
13#[error("lengths of `dl` and `du` must be `d.len() - 1 = {0}` but was {1} and {2}")]
14pub struct LoxTridiagonalError(usize, usize, usize);
15
/// Index pair `(row, column)` accepted by the [`Index`] implementation of [`Tridiagonal`].
type Idx = (usize, usize);

/// A borrowed view of a square tridiagonal matrix.
///
/// Only the three diagonals are stored; every other entry reads as zero. When built via
/// [`Tridiagonal::new`], `d` has length `n` and `dl`/`du` have length `n - 1` (all three
/// are empty for an empty matrix). The slices are borrowed, not copied.
#[derive(Debug, Clone, PartialEq)]
pub struct Tridiagonal<'a> {
    /// Sub-diagonal: `dl[j]` is the entry at row `j + 1`, column `j`.
    dl: &'a [f64],
    /// Main diagonal: `d[i]` is the entry at row `i`, column `i`.
    d: &'a [f64],
    /// Super-diagonal: `du[i]` is the entry at row `i`, column `i + 1`.
    du: &'a [f64],
}
25
26impl<'a> Tridiagonal<'a> {
27    /// Creates a new tridiagonal matrix from lower, main, and upper diagonals.
28    pub fn new(dl: &'a [f64], d: &'a [f64], du: &'a [f64]) -> Result<Self, LoxTridiagonalError> {
29        let n = d.len();
30        if (dl.len() != n - 1 || du.len() != n - 1)
31            && !(d.is_empty() && dl.is_empty() && du.is_empty())
32        {
33            return Err(LoxTridiagonalError(n - 1, dl.len(), du.len()));
34        }
35        Ok(Self { dl, d, du })
36    }
37
38    /// Returns the shape `(n, n)` of the matrix.
39    pub fn shape(&self) -> (usize, usize) {
40        (self.d.len(), self.d.len())
41    }
42
43    /// Solves the tridiagonal system `Ax = d` using the Thomas algorithm.
44    pub fn solve(&self, d: &[f64]) -> Vec<f64> {
45        let n = self.d.len();
46        let a = self.dl;
47        let b = self.d;
48        let c = self.du;
49
50        let mut w = vec![0.0; n - 1];
51        let mut g = vec![0.0; n];
52        let mut p = vec![0.0; n];
53
54        w[0] = c[0] / b[0];
55        g[0] = d[0] / b[0];
56
57        for i in 1..n - 1 {
58            w[i] = c[i] / (b[i] - a[i - 1] * w[i - 1]);
59        }
60        for i in 1..n {
61            g[i] = (d[i] - a[i - 1] * g[i - 1]) / (b[i] - a[i - 1] * w[i - 1]);
62        }
63        p[n - 1] = g[n - 1];
64        for i in (1..n).rev() {
65            p[i - 1] = g[i - 1] - w[i - 1] * p[i];
66        }
67
68        p
69    }
70}
71
72impl Index<Idx> for Tridiagonal<'_> {
73    type Output = f64;
74
75    fn index(&self, (i, j): Idx) -> &Self::Output {
76        let n = self.d.len();
77        if i >= n {
78            panic!("row index out of bounds: the number of rows is {n} but the index is {i}")
79        }
80        if j >= n {
81            panic!("column index out of bounds: the number of columns is {n} but the index is {j}")
82        }
83        if i == j {
84            &self.d[i]
85        } else if i == j + 1 {
86            &self.dl[j]
87        } else if i + 1 == j {
88            &self.du[i]
89        } else {
90            &0.0
91        }
92    }
93}
94
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;

    use super::*;

    // Shared 3x3 fixture:
    //
    //     | 3 1 0 |
    //     | 6 4 2 |
    //     | 0 7 5 |
    static DL: [f64; 2] = [6.0, 7.0];
    static D: [f64; 3] = [3.0, 4.0, 5.0];
    static DU: [f64; 2] = [1.0, 2.0];

    /// Builds the shared fixture matrix, panicking if construction is rejected.
    fn fixture() -> Tridiagonal<'static> {
        Tridiagonal::new(&DL, &D, &DU).expect("should be valid")
    }

    #[test]
    fn test_tridiagonal() {
        let tri = fixture();
        assert_eq!(tri.shape(), (3, 3));

        let dense = [[3.0, 1.0, 0.0], [6.0, 4.0, 2.0], [0.0, 7.0, 5.0]];
        for (i, row) in dense.iter().enumerate() {
            for (j, expected) in row.iter().enumerate() {
                assert_eq!(&tri[(i, j)], expected);
            }
        }
    }

    #[test]
    fn test_tridiagonal_error() {
        // Keep only the first sub-diagonal entry so `dl` is one element short.
        let result = Tridiagonal::new(&DL[..1], &D, &DU);

        assert_eq!(result, Err(LoxTridiagonalError(2, 1, 2)));
    }

    #[test]
    fn test_tridiagonal_solve() {
        let x = fixture().solve(&[1.0, 2.0, 3.0]);
        let expected = [-0.1666666666666666, 1.5, -1.5];

        for (i, &e) in expected.iter().enumerate() {
            assert_approx_eq!(x[i], e, rtol <= 1e-14);
        }
    }

    #[test]
    #[should_panic(expected = "row index out of bounds")]
    fn test_tridiagonal_invalid_row() {
        let _x = fixture()[(3, 0)];
    }

    #[test]
    #[should_panic(expected = "column index out of bounds")]
    fn test_tridiagonal_invalid_column() {
        let _x = fixture()[(0, 3)];
    }
}