num_dual/
linalg.rs

//! Basic linear algebra functionality (LU decomposition with linear solve, determinant and inverse; eigenvalues and eigenvectors of symmetric matrices) for matrices containing dual numbers.
2use crate::DualNum;
3use nalgebra::allocator::Allocator;
4use nalgebra::{DefaultAllocator, Dim, OMatrix, OVector, U1};
5use num_traits::Float;
6use std::fmt;
7use std::iter::Product;
8use std::marker::PhantomData;
9
/// Error type for fallible linear algebra operations.
///
/// In this module it is returned by [`LU::new`] when no usable (non-zero) pivot can be
/// found, i.e. when the matrix appears to be singular.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LinAlgError();

impl fmt::Display for LinAlgError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "The matrix appears to be singular.")
    }
}

// No underlying cause is wrapped, so the default `source()` (returning `None`) is kept.
impl std::error::Error for LinAlgError {}
21
22/// LU decomposition for symmetric matrices with dual numbers as elements.
23pub struct LU<T: DualNum<F>, F, D: Dim>
24where
25    DefaultAllocator: Allocator<D, D> + Allocator<D>,
26{
27    a: OMatrix<T, D, D>,
28    p: OVector<usize, D>,
29    p_count: usize,
30    f: PhantomData<F>,
31}
32
33impl<T: DualNum<F> + Copy, F: Float, D: Dim> LU<T, F, D>
34where
35    DefaultAllocator: Allocator<D, D> + Allocator<D>,
36{
37    pub fn new(mut a: OMatrix<T, D, D>) -> Result<Self, LinAlgError> {
38        let (n, _) = a.shape_generic();
39        let mut p = OVector::zeros_generic(n, U1);
40        let n = n.value();
41        let mut p_count = n;
42
43        for i in 0..n {
44            p[i] = i;
45        }
46
47        for i in 0..n {
48            let mut max_a = F::zero();
49            let mut imax = i;
50
51            for k in i..n {
52                let abs_a = a[(k, i)].abs();
53                if abs_a.re() > max_a {
54                    max_a = abs_a.re();
55                    imax = k;
56                }
57            }
58
59            if max_a.is_zero() {
60                return Err(LinAlgError());
61            }
62
63            if imax != i {
64                let j = p[i];
65                p[i] = p[imax];
66                p[imax] = j;
67
68                for j in 0..n {
69                    let ptr = a[(i, j)];
70                    a[(i, j)] = a[(imax, j)];
71                    a[(imax, j)] = ptr;
72                }
73
74                p_count += 1;
75            }
76
77            for j in i + 1..n {
78                a[(j, i)] = a[(j, i)] / a[(i, i)];
79
80                for k in i + 1..n {
81                    a[(j, k)] = a[(j, k)] - a[(j, i)] * a[(i, k)];
82                }
83            }
84        }
85        Ok(LU {
86            a,
87            p,
88            p_count,
89            f: PhantomData,
90        })
91    }
92
93    pub fn solve(&self, b: &OVector<T, D>) -> OVector<T, D> {
94        let (n, _) = b.shape_generic();
95        let mut x = OVector::zeros_generic(n, U1);
96        let n = n.value();
97
98        for i in 0..n {
99            x[i] = b[self.p[i]];
100
101            for k in 0..i {
102                x[i] = x[i] - self.a[(i, k)] * x[k];
103            }
104        }
105
106        for i in (0..n).rev() {
107            for k in i + 1..n {
108                x[i] = x[i] - self.a[(i, k)] * x[k];
109            }
110
111            x[i] /= self.a[(i, i)];
112        }
113
114        x
115    }
116
117    pub fn determinant(&self) -> T
118    where
119        T: Product,
120    {
121        let n = self.p.len();
122        let det = (0..n).map(|i| self.a[(i, i)]).product();
123
124        if (self.p_count - n).is_multiple_of(2) {
125            det
126        } else {
127            -det
128        }
129    }
130
131    pub fn inverse(&self) -> OMatrix<T, D, D> {
132        let (r, c) = self.a.shape_generic();
133        let n = self.p.len();
134        let mut ia = OMatrix::zeros_generic(r, c);
135
136        for j in 0..n {
137            for i in 0..n {
138                ia[(i, j)] = if self.p[i] == j { T::one() } else { T::zero() };
139
140                for k in 0..i {
141                    ia[(i, j)] = ia[(i, j)] - self.a[(i, k)] * ia[(k, j)];
142                }
143            }
144
145            for i in (0..n).rev() {
146                for k in i + 1..n {
147                    ia[(i, j)] = ia[(i, j)] - self.a[(i, k)] * ia[(k, j)];
148                }
149                ia[(i, j)] /= self.a[(i, i)];
150            }
151        }
152
153        ia
154    }
155}
156
157/// Smallest eigenvalue and corresponding eigenvector calculated using the full Jacobi
158/// eigenvalue algorithm ([`jacobi_eigenvalue`]).
159pub fn smallest_ev<T: DualNum<F> + Copy, F: Float, D: Dim>(
160    a: OMatrix<T, D, D>,
161) -> (T, OVector<T, D>)
162where
163    DefaultAllocator: Allocator<D, D> + Allocator<D>,
164{
165    let (r, _) = a.shape_generic();
166    let n = r.value();
167    if n == 1 {
168        (a[(0, 0)], OVector::from_element_generic(r, U1, T::one()))
169    } else if n == 2 {
170        let (a, b, c) = (a[(0, 0)], a[(0, 1)], a[(1, 1)]);
171        let l = (a + c - ((a - c).powi(2) + b * b * F::from(4.0).unwrap()).sqrt())
172            * F::from(0.5).unwrap();
173        let u = OVector::from_fn_generic(r, U1, |i, _| [b, l - a][i]);
174        let u = u / (b * b + (l - a) * (l - a)).sqrt();
175        (l, u)
176    } else {
177        let (e, vecs) = jacobi_eigenvalue(a, 200);
178        (e[0], vecs.column(0).into_owned())
179    }
180}
181
182/// Eigenvalues and corresponding eigenvectors of a symmetric matrix.
183pub fn jacobi_eigenvalue<T: DualNum<F> + Copy, F: Float, D: Dim>(
184    mut a: OMatrix<T, D, D>,
185    max_iter: usize,
186) -> (OVector<T, D>, OMatrix<T, D, D>)
187where
188    DefaultAllocator: Allocator<D, D> + Allocator<D>,
189{
190    let (r, c) = a.shape_generic();
191    let n = r.value();
192
193    let mut v = OMatrix::identity_generic(r, c);
194    let mut d = a.diagonal().to_owned();
195
196    let mut bw = d.clone();
197    let mut zw = OVector::zeros_generic(r, U1);
198
199    for it_num in 0..max_iter {
200        let mut thresh = F::zero();
201        for j in 0..n {
202            for i in 0..j {
203                thresh = thresh + a[(i, j)].re().powi(2);
204            }
205        }
206        thresh = thresh.sqrt() / F::from(n).unwrap();
207
208        if thresh.is_zero() {
209            break;
210        }
211
212        for p in 0..n {
213            for q in p + 1..n {
214                let gapq = a[(p, q)].abs() * F::from(10.0).unwrap();
215                let termp = gapq + d[p].abs();
216                let termq = gapq + d[q].abs();
217
218                if 4 < it_num && termp == d[p].abs() && termq == d[q].abs() {
219                    a[(p, q)] = T::zero();
220                } else if thresh <= a[(p, q)].re().abs() {
221                    let h = d[q] - d[p];
222                    let term = h.abs() + gapq;
223
224                    let t = if term == h.abs() {
225                        a[(p, q)] / h
226                    } else {
227                        let theta = h * F::from(0.5).unwrap() / a[(p, q)];
228                        let mut t = (theta.abs() + (theta * theta + F::one()).sqrt()).recip();
229                        if theta.is_negative() {
230                            t = -t;
231                        }
232                        t
233                    };
234
235                    let c = (t * t + F::one()).sqrt().recip();
236                    let s = t * c;
237                    let tau = s / (c + F::one());
238                    let h = t * a[(p, q)];
239
240                    zw[p] -= h;
241                    zw[q] += h;
242                    d[p] -= h;
243                    d[q] += h;
244
245                    a[(p, q)] = T::zero();
246
247                    for j in 0..p {
248                        let g = a[(j, p)];
249                        let h = a[(j, q)];
250                        a[(j, p)] = g - s * (h + g * tau);
251                        a[(j, q)] = h + s * (g - h * tau);
252                    }
253
254                    for j in p + 1..q {
255                        let g = a[(p, j)];
256                        let h = a[(j, q)];
257                        a[(p, j)] = g - s * (h + g * tau);
258                        a[(j, q)] = h + s * (g - h * tau);
259                    }
260
261                    for j in q + 1..n {
262                        let g = a[(p, j)];
263                        let h = a[(q, j)];
264                        a[(p, j)] = g - s * (h + g * tau);
265                        a[(q, j)] = h + s * (g - h * tau);
266                    }
267
268                    for j in 0..n {
269                        let g = v[(j, p)];
270                        let h = v[(j, q)];
271                        v[(j, p)] = g - s * (h + g * tau);
272                        v[(j, q)] = h + s * (g - h * tau);
273                    }
274                }
275            }
276        }
277
278        bw += &zw;
279        d = bw.clone();
280        zw.fill(T::zero());
281    }
282
283    for k in 0..n - 1 {
284        let mut m = k;
285
286        for l in k + 1..n {
287            if d[l].re() < d[m].re() {
288                m = l;
289            }
290        }
291
292        if m != k {
293            d.swap_rows(m, k);
294
295            for l in 0..n {
296                v.swap((l, m), (l, k));
297            }
298        }
299    }
300
301    (d, v)
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dual64;
    use approx::assert_abs_diff_eq;
    use nalgebra::{dmatrix, dvector};

    #[test]
    fn test_solve_f64() {
        let lu = LU::new(dmatrix![4.0, 3.0; 6.0, 3.0]).unwrap();
        let rhs = dvector![10.0, 12.0];
        let det = lu.determinant();
        assert_eq!(det, -6.0);
        assert_eq!(lu.solve(&rhs), dvector![1.0, 2.0]);
        assert_eq!(lu.inverse() * det, dmatrix![3.0, -3.0; -6.0, 4.0]);
    }

    #[test]
    fn test_solve_dual64() {
        let lu = LU::new(dmatrix![
            Dual64::new(4.0, 3.0), Dual64::new(3.0, 3.0);
            Dual64::new(6.0, 1.0), Dual64::new(3.0, 2.0)
        ])
        .unwrap();
        let rhs = dvector![Dual64::new(10.0, 20.0), Dual64::new(12.0, 20.0)];

        let det = lu.determinant();
        assert_eq!((det.re, det.eps), (-6.0, -4.0));

        let sol = lu.solve(&rhs);
        assert_eq!(
            (sol[0].re, sol[0].eps, sol[1].re, sol[1].eps),
            (1.0, 2.0, 2.0, 1.0)
        );
    }

    #[test]
    fn test_eig_f64_2() {
        let mat = dmatrix![2.0, 2.0; 2.0, 5.0];
        let (vals, vecs) = jacobi_eigenvalue(mat.clone(), 200);
        let (min_val, min_vec) = smallest_ev(mat.clone());
        let product = mat * &vecs;
        println!("{vals} {vecs}");
        println!("{min_val} {min_vec}");

        // A·v_j == λ_j·v_j for every eigenpair.
        for col in 0..2 {
            for row in 0..2 {
                assert_abs_diff_eq!(
                    product[(row, col)],
                    vals[col] * vecs[(row, col)],
                    epsilon = 1e-14
                );
            }
        }

        // The closed-form smallest pair agrees with the first Jacobi pair.
        assert_abs_diff_eq!(vals[0], min_val, epsilon = 1e-14);
        for row in 0..2 {
            assert_abs_diff_eq!(vecs[(row, 0)], min_vec[row], epsilon = 1e-14);
        }
    }

    #[test]
    fn test_eig_f64_3() {
        let mat = dmatrix![2.0, 2.0, 7.0; 2.0, 5.0, 9.0; 7.0, 9.0, 2.0];
        let (vals, vecs) = jacobi_eigenvalue(mat.clone(), 200);
        let product = mat * &vecs;
        println!("{vals} {vecs}");

        for row in 0..3 {
            for col in 0..3 {
                assert_abs_diff_eq!(
                    product[(row, col)],
                    vals[col] * vecs[(row, col)],
                    epsilon = 1e-14
                );
            }
        }
    }

    #[test]
    fn test_eig_dual64() {
        let mat = dmatrix![
            Dual64::new(2.0, 1.0), Dual64::new(2.0, 2.0);
            Dual64::new(2.0, 2.0), Dual64::new(5.0, 3.0)
        ];
        let (vals, vecs) = jacobi_eigenvalue(mat.clone(), 200);
        let (min_val, min_vec) = smallest_ev(mat.clone());
        let product = mat * &vecs;
        println!("{vals} {vecs}");
        println!("{min_val} {min_vec}");

        // Real parts of A·v_j == λ_j·v_j ...
        for col in 0..2 {
            for row in 0..2 {
                assert_abs_diff_eq!(
                    product[(row, col)].re,
                    (vals[col] * vecs[(row, col)]).re,
                    epsilon = 1e-14
                );
            }
        }
        // ... and their derivatives.
        for col in 0..2 {
            for row in 0..2 {
                assert_abs_diff_eq!(
                    product[(row, col)].eps,
                    (vals[col] * vecs[(row, col)]).eps,
                    epsilon = 1e-14
                );
            }
        }

        assert_abs_diff_eq!(vals[0].re, min_val.re, epsilon = 1e-14);
        assert_abs_diff_eq!(vals[0].eps, min_val.eps, epsilon = 1e-14);
        for row in 0..2 {
            assert_abs_diff_eq!(vecs[(row, 0)].re, min_vec[row].re, epsilon = 1e-14);
            assert_abs_diff_eq!(vecs[(row, 0)].eps, min_vec[row].eps, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_norm_f64() {
        let vec = dvector![3.0, 4.0];
        assert_eq!(vec.norm(), 5.0);
    }

    #[test]
    fn test_norm_dual64() {
        let vec = dvector![Dual64::new(3.0, 1.0), Dual64::new(4.0, 3.0)];
        let norm = vec.norm();
        println!("{}", norm);
        assert_eq!(norm.re, 5.0);
        assert_eq!(norm.eps, 3.0);
    }
}