lox_core/math/linear_algebra/
tridiagonal.rs1use std::ops::Index;
8
9use thiserror::Error;
10
11#[derive(Clone, Debug, Error, Eq, PartialEq)]
13#[error("lengths of `dl` and `du` must be `d.len() - 1 = {0}` but was {1} and {2}")]
14pub struct LoxTridiagonalError(usize, usize, usize);
15
/// A `(row, column)` pair used to index into a matrix.
type Idx = (usize, usize);
17
/// A borrowed view of a square tridiagonal matrix, stored as its three
/// diagonals.
///
/// Only the diagonals are stored; every other entry is implicitly zero.
#[derive(Clone, Debug, PartialEq)]
pub struct Tridiagonal<'a> {
    /// Sub-diagonal: entries at `(i + 1, i)`.
    dl: &'a [f64],
    /// Main diagonal: entries at `(i, i)`.
    d: &'a [f64],
    /// Super-diagonal: entries at `(i, i + 1)`.
    du: &'a [f64],
}
25
26impl<'a> Tridiagonal<'a> {
27 pub fn new(dl: &'a [f64], d: &'a [f64], du: &'a [f64]) -> Result<Self, LoxTridiagonalError> {
29 let n = d.len();
30 if (dl.len() != n - 1 || du.len() != n - 1)
31 && !(d.is_empty() && dl.is_empty() && du.is_empty())
32 {
33 return Err(LoxTridiagonalError(n - 1, dl.len(), du.len()));
34 }
35 Ok(Self { dl, d, du })
36 }
37
38 pub fn shape(&self) -> (usize, usize) {
40 (self.d.len(), self.d.len())
41 }
42
43 pub fn solve(&self, d: &[f64]) -> Vec<f64> {
45 let n = self.d.len();
46 let a = self.dl;
47 let b = self.d;
48 let c = self.du;
49
50 let mut w = vec![0.0; n - 1];
51 let mut g = vec![0.0; n];
52 let mut p = vec![0.0; n];
53
54 w[0] = c[0] / b[0];
55 g[0] = d[0] / b[0];
56
57 for i in 1..n - 1 {
58 w[i] = c[i] / (b[i] - a[i - 1] * w[i - 1]);
59 }
60 for i in 1..n {
61 g[i] = (d[i] - a[i - 1] * g[i - 1]) / (b[i] - a[i - 1] * w[i - 1]);
62 }
63 p[n - 1] = g[n - 1];
64 for i in (1..n).rev() {
65 p[i - 1] = g[i - 1] - w[i - 1] * p[i];
66 }
67
68 p
69 }
70}
71
72impl Index<Idx> for Tridiagonal<'_> {
73 type Output = f64;
74
75 fn index(&self, (i, j): Idx) -> &Self::Output {
76 let n = self.d.len();
77 if i >= n {
78 panic!("row index out of bounds: the number of rows is {n} but the index is {i}")
79 }
80 if j >= n {
81 panic!("column index out of bounds: the number of columns is {n} but the index is {j}")
82 }
83 if i == j {
84 &self.d[i]
85 } else if i == j + 1 {
86 &self.dl[j]
87 } else if i + 1 == j {
88 &self.du[i]
89 } else {
90 &0.0
91 }
92 }
93}
94
#[cfg(test)]
mod tests {
    use lox_test_utils::assert_approx_eq;

    use super::*;

    // Shared fixture describing the 3×3 matrix
    //
    //     | 3 1 0 |
    //     | 6 4 2 |
    //     | 0 7 5 |
    static LOWER: [f64; 2] = [6.0, 7.0];
    static MAIN: [f64; 3] = [3.0, 4.0, 5.0];
    static UPPER: [f64; 2] = [1.0, 2.0];

    /// Builds the fixture matrix from the shared diagonals.
    fn sample() -> Tridiagonal<'static> {
        Tridiagonal::new(&LOWER, &MAIN, &UPPER).expect("should be valid")
    }

    #[test]
    fn test_tridiagonal() {
        let tri = sample();

        assert_eq!(tri.shape(), (3, 3));

        // Every entry of the matrix, listed column by column.
        let entries: [(Idx, f64); 9] = [
            ((0, 0), 3.0),
            ((1, 0), 6.0),
            ((2, 0), 0.0),
            ((0, 1), 1.0),
            ((1, 1), 4.0),
            ((2, 1), 7.0),
            ((0, 2), 0.0),
            ((1, 2), 2.0),
            ((2, 2), 5.0),
        ];
        for &(position, value) in &entries {
            assert_eq!(&tri[position], &value);
        }
    }

    #[test]
    fn test_tridiagonal_error() {
        // The sub-diagonal is one element too short for a 3×3 matrix.
        let short_lower = [6.0];
        let result = Tridiagonal::new(&short_lower, &MAIN, &UPPER);

        assert_eq!(result, Err(LoxTridiagonalError(2, 1, 2)));
    }

    #[test]
    fn test_tridiagonal_solve() {
        let tri = sample();

        let rhs = [1.0, 2.0, 3.0];
        let x = tri.solve(&rhs);
        let exp = [-0.1666666666666666, 1.5, -1.5];

        for (i, expected) in exp.iter().enumerate() {
            assert_approx_eq!(x[i], *expected, rtol <= 1e-14);
        }
    }

    #[test]
    #[should_panic(expected = "row index out of bounds")]
    fn test_tridiagonal_invalid_row() {
        let tri = sample();
        let _x = tri[(3, 0)];
    }

    #[test]
    #[should_panic(expected = "column index out of bounds")]
    fn test_tridiagonal_invalid_column() {
        let tri = sample();
        let _x = tri[(0, 3)];
    }
}