funspace/chebyshev/linalg/
pdma.rs1use num_traits::Zero;
5use std::clone::Clone;
6use std::ops::{Add, Div, Mul, Sub};
7
8#[allow(clippy::many_single_char_names)]
18#[inline]
19pub fn pdma<T1, T2>(l2: &[T1], l1: &[T1], d0: &[T1], u1: &[T1], u2: &[T1], rhs: &mut [T2])
20where
21    T1: Mul<Output = T1>
22        + Sub<Output = T1>
23        + Div<Output = T1>
24        + Add<Output = T1>
25        + Zero
26        + Clone
27        + Copy,
28    T2: Mul<Output = T2>
29        + Sub<Output = T2>
30        + Div<Output = T2>
31        + Add<Output = T2>
32        + Add<T1, Output = T2>
33        + Mul<T1, Output = T2>
34        + Div<T1, Output = T2>
35        + Sub<T1, Output = T2>
36        + Zero
37        + Clone
38        + Copy,
39{
40    let n = rhs.len();
41    assert!(n > 3, "Error in pdma: Size = {} too small!", n);
42
43    let mut al = vec![T1::zero(); n - 1];
44    let mut be = vec![T1::zero(); n - 2];
45    let mut ze = vec![T2::zero(); n];
46    let mut ga = vec![T1::zero(); n];
47    let mut mu = vec![T1::zero(); n];
48
49    unsafe {
50        *mu.get_unchecked_mut(0) = *d0.get_unchecked(0);
51        *al.get_unchecked_mut(0) = *u1.get_unchecked(0) / *mu.get_unchecked(0);
52        *be.get_unchecked_mut(0) = *u2.get_unchecked(0) / *mu.get_unchecked(0);
53        *ze.get_unchecked_mut(0) = *rhs.get_unchecked(0) / *mu.get_unchecked(0);
54
55        *ga.get_unchecked_mut(1) = *l1.get_unchecked(0);
56        *mu.get_unchecked_mut(1) =
57            *d0.get_unchecked(1) - *al.get_unchecked(0) * *ga.get_unchecked(1);
58        *al.get_unchecked_mut(1) = (*u1.get_unchecked(1)
59            - *be.get_unchecked(0) * *ga.get_unchecked(1))
60            / *mu.get_unchecked(1);
61        *be.get_unchecked_mut(1) = *u2.get_unchecked(1) / *mu.get_unchecked(1);
62        *ze.get_unchecked_mut(1) = (*rhs.get_unchecked(1)
63            - *ze.get_unchecked(0) * *ga.get_unchecked(1))
64            / *mu.get_unchecked(1);
65
66        for i in 2..n - 2 {
67            *ga.get_unchecked_mut(i) =
68                *l1.get_unchecked(i - 1) - *al.get_unchecked(i - 2) * *l2.get_unchecked(i - 2);
69            *mu.get_unchecked_mut(i) = *d0.get_unchecked(i)
70                - *be.get_unchecked(i - 2) * *l2.get_unchecked(i - 2)
71                - *al.get_unchecked(i - 1) * *ga.get_unchecked(i);
72            *al.get_unchecked_mut(i) = (*u1.get_unchecked(i)
73                - *be.get_unchecked(i - 1) * *ga.get_unchecked(i))
74                / *mu.get_unchecked(i);
75            *be.get_unchecked_mut(i) = *u2.get_unchecked(i) / *mu.get_unchecked(i);
76            *ze.get_unchecked_mut(i) = (*rhs.get_unchecked(i)
77                - *ze.get_unchecked(i - 2) * *l2.get_unchecked(i - 2)
78                - *ze.get_unchecked(i - 1) * *ga.get_unchecked(i))
79                / *mu.get_unchecked(i);
80        }
81
82        *ga.get_unchecked_mut(n - 2) =
83            *l1.get_unchecked(n - 3) - *al.get_unchecked(n - 4) * *l2.get_unchecked(n - 4);
84        *mu.get_unchecked_mut(n - 2) = *d0.get_unchecked(n - 2)
85            - *be.get_unchecked(n - 4) * *l2.get_unchecked(n - 4)
86            - *al.get_unchecked(n - 3) * *ga.get_unchecked(n - 2);
87        *al.get_unchecked_mut(n - 2) = (*u1.get_unchecked(n - 2)
88            - *be.get_unchecked(n - 3) * *ga.get_unchecked(n - 2))
89            / *mu.get_unchecked(n - 2);
90        *ze.get_unchecked_mut(n - 2) = (*rhs.get_unchecked(n - 2)
91            - *ze.get_unchecked(n - 4) * *l2.get_unchecked(n - 4)
92            - *ze.get_unchecked(n - 3) * *ga.get_unchecked(n - 2))
93            / *mu.get_unchecked(n - 2);
94
95        *ga.get_unchecked_mut(n - 1) =
96            *l1.get_unchecked(n - 2) - *al.get_unchecked(n - 3) * *l2.get_unchecked(n - 3);
97        *mu.get_unchecked_mut(n - 1) = *d0.get_unchecked(n - 1)
98            - *be.get_unchecked(n - 3) * *l2.get_unchecked(n - 3)
99            - *al.get_unchecked(n - 2) * *ga.get_unchecked(n - 1);
100        *ze.get_unchecked_mut(n - 1) = (*rhs.get_unchecked(n - 1)
101            - *ze.get_unchecked(n - 3) * *l2.get_unchecked(n - 3)
102            - *ze.get_unchecked(n - 2) * *ga.get_unchecked(n - 1))
103            / *mu.get_unchecked(n - 1);
104
105        *rhs.get_unchecked_mut(n - 1) = *ze.get_unchecked(n - 1);
107        *rhs.get_unchecked_mut(n - 2) =
108            *ze.get_unchecked(n - 2) - *rhs.get_unchecked(n - 1) * *al.get_unchecked(n - 2);
109
110        for i in (0..n - 2).rev() {
111            *rhs.get_unchecked_mut(i) = *ze.get_unchecked(i)
112                - *rhs.get_unchecked(i + 1) * *al.get_unchecked(i)
113                - *rhs.get_unchecked(i + 2) * *be.get_unchecked(i);
114        }
115    }
116}
117
118#[allow(clippy::many_single_char_names, dead_code)]
128#[inline]
129pub fn pdma_checked<T1, T2>(l2: &[T1], l1: &[T1], d0: &[T1], u1: &[T1], u2: &[T1], rhs: &mut [T2])
130where
131    T1: Mul<Output = T1>
132        + Sub<Output = T1>
133        + Div<Output = T1>
134        + Add<Output = T1>
135        + Zero
136        + Clone
137        + Copy,
138    T2: Mul<Output = T2>
139        + Sub<Output = T2>
140        + Div<Output = T2>
141        + Add<Output = T2>
142        + Add<T1, Output = T2>
143        + Mul<T1, Output = T2>
144        + Div<T1, Output = T2>
145        + Sub<T1, Output = T2>
146        + Zero
147        + Clone
148        + Copy,
149{
150    let n = rhs.len();
151    assert!(n > 3, "Error in pdma: Size = {} too small!", n);
152
153    let mut al = vec![T1::zero(); n - 1];
154    let mut be = vec![T1::zero(); n - 2];
155    let mut ze = vec![T2::zero(); n];
156    let mut ga = vec![T1::zero(); n];
157    let mut mu = vec![T1::zero(); n];
158
159    mu[0] = d0[0];
160    al[0] = u1[0] / mu[0];
161    be[0] = u2[0] / mu[0];
162    ze[0] = rhs[0] / mu[0];
163
164    ga[1] = l1[0];
165    mu[1] = d0[1] - al[0] * ga[1];
166    al[1] = (u1[1] - be[0] * ga[1]) / mu[1];
167    be[1] = u2[1] / mu[1];
168    ze[1] = (rhs[1] - ze[0] * ga[1]) / mu[1];
169
170    for i in 2..n - 2 {
171        ga[i] = l1[i - 1] - al[i - 2] * l2[i - 2];
172        mu[i] = d0[i] - be[i - 2] * l2[i - 2] - al[i - 1] * ga[i];
173        al[i] = (u1[i] - be[i - 1] * ga[i]) / mu[i];
174        be[i] = u2[i] / mu[i];
175        ze[i] = (rhs[i] - ze[i - 2] * l2[i - 2] - ze[i - 1] * ga[i]) / mu[i];
176    }
177
178    ga[n - 2] = l1[n - 3] - al[n - 4] * l2[n - 4];
179    mu[n - 2] = d0[n - 2] - be[n - 4] * l2[n - 4] - al[n - 3] * ga[n - 2];
180    al[n - 2] = (u1[n - 2] - be[n - 3] * ga[n - 2]) / mu[n - 2];
181    ze[n - 2] = (rhs[n - 2] - ze[n - 4] * l2[n - 4] - ze[n - 3] * ga[n - 2]) / mu[n - 2];
182
183    ga[n - 1] = l1[n - 2] - al[n - 3] * l2[n - 3];
184    mu[n - 1] = d0[n - 1] - be[n - 3] * l2[n - 3] - al[n - 2] * ga[n - 1];
185    ze[n - 1] = (rhs[n - 1] - ze[n - 3] * l2[n - 3] - ze[n - 2] * ga[n - 1]) / mu[n - 1];
186
187    rhs[n - 1] = ze[n - 1];
189    rhs[n - 2] = ze[n - 2] - rhs[n - 1] * al[n - 2];
190
191    for i in (0..n - 2).rev() {
192        rhs[i] = ze[i] - rhs[i + 1] * al[i] - rhs[i + 2] * be[i];
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    /// Signature shared by `pdma` and `pdma_checked` when instantiated for `f64`.
    type Solver = fn(&[f64], &[f64], &[f64], &[f64], &[f64], &mut [f64]);

    /// Diagonals `(l2, l1, d0, u1, u2)` of the `n x n` test matrix.
    fn diagonals(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
        let l2 = (1..n - 1).map(|x| 1.5 * x as f64).collect();
        let l1 = (1..n).map(|x| -2.5 * x as f64).collect();
        let d0 = (1..n + 1).map(|x| 1.0 * x as f64).collect();
        let u1 = (1..n).map(|x| 3.5 * x as f64).collect();
        let u2 = (1..n - 1).map(|x| -0.5 * x as f64).collect();
        (l2, l1, d0, u1, u2)
    }

    /// Solve an `n x n` system with `solver`, then multiply the result by the
    /// dense matrix and assert it reproduces the original right-hand side.
    fn assert_solves(n: usize, solver: Solver) {
        let (l2, l1, d0, u1, u2) = diagonals(n);
        let b = (1..n + 1).map(|x| x as f64).collect::<Vec<f64>>();

        let mut mat = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            mat[[i, i]] = d0[i];
        }
        for i in 0..n - 1 {
            mat[[i + 1, i]] = l1[i];
            mat[[i, i + 1]] = u1[i];
        }
        for i in 0..n - 2 {
            mat[[i + 2, i]] = l2[i];
            mat[[i, i + 2]] = u2[i];
        }

        let mut rhs = b.clone();
        solver(&l2, &l1, &d0, &u1, &u2, &mut rhs);
        let b2 = mat.dot(&Array1::from_vec(rhs));

        for (v1, v2) in b.iter().zip(b2.iter()) {
            assert!((v1 - v2).abs() < 1e-6, "PDMA failed, {} /= {}.", v1, v2);
        }
    }

    #[test]
    fn test_pdma() {
        assert_solves(8, pdma::<f64, f64>);
        assert_solves(8, pdma_checked::<f64, f64>);
    }

    #[test]
    fn test_pdma_smallest_size() {
        // n = 4 is the smallest accepted size; the interior loop `2..n - 2`
        // is empty there, so only the explicitly unrolled rows run.
        assert_solves(4, pdma::<f64, f64>);
        assert_solves(4, pdma_checked::<f64, f64>);
    }

    #[test]
    #[should_panic(expected = "too small")]
    fn test_pdma_rejects_size_three() {
        let mut rhs = vec![1.0; 3];
        pdma(&[1.0], &[1.0; 2], &[1.0; 3], &[1.0; 2], &[1.0], &mut rhs);
    }

    #[test]
    #[should_panic(expected = "too small")]
    fn test_pdma_checked_rejects_size_three() {
        let mut rhs = vec![1.0; 3];
        pdma_checked(&[1.0], &[1.0; 2], &[1.0; 3], &[1.0; 2], &[1.0], &mut rhs);
    }
}