1use crate::DualNum;
3use nalgebra::allocator::Allocator;
4use nalgebra::{DefaultAllocator, Dim, OMatrix, OVector, U1};
5use num_traits::Float;
6use std::fmt;
7use std::iter::Product;
8use std::marker::PhantomData;
9
/// Error signalling that a matrix could not be factored because it appears to
/// be singular.
///
/// In this module it is returned by [`LU::new`] when, at some elimination
/// step, every remaining candidate pivot in a column has an absolute value
/// whose real part is exactly zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct LinAlgError();

impl fmt::Display for LinAlgError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The matrix appears to be singular.")
    }
}

impl std::error::Error for LinAlgError {}
21
22pub struct LU<T: DualNum<F>, F, D: Dim>
24where
25 DefaultAllocator: Allocator<D, D> + Allocator<D>,
26{
27 a: OMatrix<T, D, D>,
28 p: OVector<usize, D>,
29 p_count: usize,
30 f: PhantomData<F>,
31}
32
33impl<T: DualNum<F> + Copy, F: Float, D: Dim> LU<T, F, D>
34where
35 DefaultAllocator: Allocator<D, D> + Allocator<D>,
36{
37 pub fn new(mut a: OMatrix<T, D, D>) -> Result<Self, LinAlgError> {
38 let (n, _) = a.shape_generic();
39 let mut p = OVector::zeros_generic(n, U1);
40 let n = n.value();
41 let mut p_count = n;
42
43 for i in 0..n {
44 p[i] = i;
45 }
46
47 for i in 0..n {
48 let mut max_a = F::zero();
49 let mut imax = i;
50
51 for k in i..n {
52 let abs_a = a[(k, i)].abs();
53 if abs_a.re() > max_a {
54 max_a = abs_a.re();
55 imax = k;
56 }
57 }
58
59 if max_a.is_zero() {
60 return Err(LinAlgError());
61 }
62
63 if imax != i {
64 let j = p[i];
65 p[i] = p[imax];
66 p[imax] = j;
67
68 for j in 0..n {
69 let ptr = a[(i, j)];
70 a[(i, j)] = a[(imax, j)];
71 a[(imax, j)] = ptr;
72 }
73
74 p_count += 1;
75 }
76
77 for j in i + 1..n {
78 a[(j, i)] = a[(j, i)] / a[(i, i)];
79
80 for k in i + 1..n {
81 a[(j, k)] = a[(j, k)] - a[(j, i)] * a[(i, k)];
82 }
83 }
84 }
85 Ok(LU {
86 a,
87 p,
88 p_count,
89 f: PhantomData,
90 })
91 }
92
93 pub fn solve(&self, b: &OVector<T, D>) -> OVector<T, D> {
94 let (n, _) = b.shape_generic();
95 let mut x = OVector::zeros_generic(n, U1);
96 let n = n.value();
97
98 for i in 0..n {
99 x[i] = b[self.p[i]];
100
101 for k in 0..i {
102 x[i] = x[i] - self.a[(i, k)] * x[k];
103 }
104 }
105
106 for i in (0..n).rev() {
107 for k in i + 1..n {
108 x[i] = x[i] - self.a[(i, k)] * x[k];
109 }
110
111 x[i] /= self.a[(i, i)];
112 }
113
114 x
115 }
116
117 pub fn determinant(&self) -> T
118 where
119 T: Product,
120 {
121 let n = self.p.len();
122 let det = (0..n).map(|i| self.a[(i, i)]).product();
123
124 if (self.p_count - n).is_multiple_of(2) {
125 det
126 } else {
127 -det
128 }
129 }
130
131 pub fn inverse(&self) -> OMatrix<T, D, D> {
132 let (r, c) = self.a.shape_generic();
133 let n = self.p.len();
134 let mut ia = OMatrix::zeros_generic(r, c);
135
136 for j in 0..n {
137 for i in 0..n {
138 ia[(i, j)] = if self.p[i] == j { T::one() } else { T::zero() };
139
140 for k in 0..i {
141 ia[(i, j)] = ia[(i, j)] - self.a[(i, k)] * ia[(k, j)];
142 }
143 }
144
145 for i in (0..n).rev() {
146 for k in i + 1..n {
147 ia[(i, j)] = ia[(i, j)] - self.a[(i, k)] * ia[(k, j)];
148 }
149 ia[(i, j)] /= self.a[(i, i)];
150 }
151 }
152
153 ia
154 }
155}
156
157pub fn smallest_ev<T: DualNum<F> + Copy, F: Float, D: Dim>(
160 a: OMatrix<T, D, D>,
161) -> (T, OVector<T, D>)
162where
163 DefaultAllocator: Allocator<D, D> + Allocator<D>,
164{
165 let (r, _) = a.shape_generic();
166 let n = r.value();
167 if n == 1 {
168 (a[(0, 0)], OVector::from_element_generic(r, U1, T::one()))
169 } else if n == 2 {
170 let (a, b, c) = (a[(0, 0)], a[(0, 1)], a[(1, 1)]);
171 let l = (a + c - ((a - c).powi(2) + b * b * F::from(4.0).unwrap()).sqrt())
172 * F::from(0.5).unwrap();
173 let theta = (b + b).atan2(a - c) * F::from(0.5).unwrap();
174 let (s, c) = theta.sin_cos();
175 let mut u = OVector::from_fn_generic(r, U1, |i, _| [-s, c][i]);
176 if u[0].re() < F::zero() || u[0].re().is_zero() && u[1].re() < F::zero() {
177 u = -u;
178 }
179 (l, u)
180 } else {
181 let (e, vecs) = jacobi_eigenvalue(a, 200);
182 (e[0], vecs.column(0).into_owned())
183 }
184}
185
186pub fn jacobi_eigenvalue<T: DualNum<F> + Copy, F: Float, D: Dim>(
188 mut a: OMatrix<T, D, D>,
189 max_iter: usize,
190) -> (OVector<T, D>, OMatrix<T, D, D>)
191where
192 DefaultAllocator: Allocator<D, D> + Allocator<D>,
193{
194 let (r, c) = a.shape_generic();
195 let n = r.value();
196
197 let mut v = OMatrix::identity_generic(r, c);
198 let mut d = a.diagonal().to_owned();
199
200 let mut bw = d.clone();
201 let mut zw = OVector::zeros_generic(r, U1);
202
203 for it_num in 0..max_iter {
204 let mut thresh = F::zero();
205 for j in 0..n {
206 for i in 0..j {
207 thresh = thresh + a[(i, j)].re().powi(2);
208 }
209 }
210 thresh = thresh.sqrt() / F::from(n).unwrap();
211
212 if thresh.is_zero() {
213 break;
214 }
215
216 for p in 0..n {
217 for q in p + 1..n {
218 let gapq = a[(p, q)].abs() * F::from(10.0).unwrap();
219 let termp = gapq + d[p].abs();
220 let termq = gapq + d[q].abs();
221
222 if 4 < it_num && termp == d[p].abs() && termq == d[q].abs() {
223 a[(p, q)] = T::zero();
224 } else if thresh <= a[(p, q)].re().abs() {
225 let h = d[q] - d[p];
226 let term = h.abs() + gapq;
227
228 let t = if term == h.abs() {
229 a[(p, q)] / h
230 } else {
231 let theta = h * F::from(0.5).unwrap() / a[(p, q)];
232 let mut t = (theta.abs() + (theta * theta + F::one()).sqrt()).recip();
233 if theta.is_negative() {
234 t = -t;
235 }
236 t
237 };
238
239 let c = (t * t + F::one()).sqrt().recip();
240 let s = t * c;
241 let tau = s / (c + F::one());
242 let h = t * a[(p, q)];
243
244 zw[p] -= h;
245 zw[q] += h;
246 d[p] -= h;
247 d[q] += h;
248
249 a[(p, q)] = T::zero();
250
251 for j in 0..p {
252 let g = a[(j, p)];
253 let h = a[(j, q)];
254 a[(j, p)] = g - s * (h + g * tau);
255 a[(j, q)] = h + s * (g - h * tau);
256 }
257
258 for j in p + 1..q {
259 let g = a[(p, j)];
260 let h = a[(j, q)];
261 a[(p, j)] = g - s * (h + g * tau);
262 a[(j, q)] = h + s * (g - h * tau);
263 }
264
265 for j in q + 1..n {
266 let g = a[(p, j)];
267 let h = a[(q, j)];
268 a[(p, j)] = g - s * (h + g * tau);
269 a[(q, j)] = h + s * (g - h * tau);
270 }
271
272 for j in 0..n {
273 let g = v[(j, p)];
274 let h = v[(j, q)];
275 v[(j, p)] = g - s * (h + g * tau);
276 v[(j, q)] = h + s * (g - h * tau);
277 }
278 }
279 }
280 }
281
282 bw += &zw;
283 d = bw.clone();
284 zw.fill(T::zero());
285 }
286
287 for k in 0..n - 1 {
288 let mut m = k;
289
290 for l in k + 1..n {
291 if d[l].re() < d[m].re() {
292 m = l;
293 }
294 }
295
296 if m != k {
297 d.swap_rows(m, k);
298
299 for l in 0..n {
300 v.swap((l, m), (l, k));
301 }
302 }
303 }
304
305 (d, v)
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dual64;
    use approx::assert_abs_diff_eq;
    use nalgebra::{dmatrix, dvector};

    /// Absolute tolerance shared by the eigen-decomposition checks.
    const TOL: f64 = 1e-14;

    /// Checks a 2×2 real symmetric matrix given row by row:
    /// * every Jacobi eigenpair satisfies `A·v = λ·v`;
    /// * `smallest_ev` agrees with the first Jacobi eigenpair.
    fn check_eig_f64_2x2(m: [[f64; 2]; 2]) {
        let a = dmatrix![m[0][0], m[0][1]; m[1][0], m[1][1]];
        let (l, v) = jacobi_eigenvalue(a.clone(), 200);
        let (l1, v1) = smallest_ev(a.clone());
        let av = a * &v;
        println!("{l} {v}");
        println!("{l1} {v1}");

        for col in 0..2 {
            for row in 0..2 {
                assert_abs_diff_eq!(av[(row, col)], (l[col] * v[(row, col)]), epsilon = TOL);
            }
        }
        assert_abs_diff_eq!(l[0], l1, epsilon = TOL);
        assert_abs_diff_eq!(v[(0, 0)], v1[0], epsilon = TOL);
        assert_abs_diff_eq!(v[(1, 0)], v1[1], epsilon = TOL);
    }

    #[test]
    fn test_solve_f64() {
        let lu = LU::new(dmatrix![4.0, 3.0; 6.0, 3.0]).unwrap();

        let det = lu.determinant();
        assert_eq!(det, -6.0);

        let rhs = dvector![10.0, 12.0];
        assert_eq!(lu.solve(&rhs), dvector![1.0, 2.0]);

        assert_eq!(lu.inverse() * det, dmatrix![3.0, -3.0; -6.0, 4.0]);
    }

    #[test]
    fn test_solve_dual64() {
        let a = dmatrix![
            Dual64::new(4.0, 3.0), Dual64::new(3.0, 3.0);
            Dual64::new(6.0, 1.0), Dual64::new(3.0, 2.0)
        ];
        let lu = LU::new(a).unwrap();

        let det = lu.determinant();
        assert_eq!((det.re, det.eps), (-6.0, -4.0));

        let rhs = dvector![Dual64::new(10.0, 20.0), Dual64::new(12.0, 20.0)];
        let x = lu.solve(&rhs);
        assert_eq!((x[0].re, x[0].eps, x[1].re, x[1].eps), (1.0, 2.0, 2.0, 1.0));
    }

    #[test]
    fn test_eig_f64_2() {
        check_eig_f64_2x2([[2.0, 2.0], [2.0, 5.0]]);
    }

    #[test]
    fn test_eig_f64_zeros1() {
        check_eig_f64_2x2([[1.0, 0.0], [0.0, 0.0]]);
    }

    #[test]
    fn test_eig_f64_zeros2() {
        check_eig_f64_2x2([[0.0, 0.0], [0.0, 1.0]]);
    }

    #[test]
    fn test_eig_f64_3() {
        let a = dmatrix![2.0, 2.0, 7.0; 2.0, 5.0, 9.0; 7.0, 9.0, 2.0];
        let (l, v) = jacobi_eigenvalue(a.clone(), 200);
        let av = a * &v;
        println!("{l} {v}");

        for row in 0..3 {
            for col in 0..3 {
                let expected = l[col] * v[(row, col)];
                assert_abs_diff_eq!(av[(row, col)], expected, epsilon = TOL);
            }
        }
    }

    #[test]
    fn test_eig_dual64() {
        let a = dmatrix![
            Dual64::new(2.0, 1.0), Dual64::new(2.0, 2.0);
            Dual64::new(2.0, 2.0), Dual64::new(5.0, 3.0)
        ];
        let (l, v) = jacobi_eigenvalue(a.clone(), 200);
        let (l1, v1) = smallest_ev(a.clone());
        let av = a * &v;
        println!("{l} {v}");
        println!("{l1} {v1}");

        // Eigen-equation A·v = λ·v: real parts first, then derivative parts.
        for col in 0..2 {
            for row in 0..2 {
                let expected = l[col] * v[(row, col)];
                assert_abs_diff_eq!(av[(row, col)].re, expected.re, epsilon = TOL);
            }
        }
        for col in 0..2 {
            for row in 0..2 {
                let expected = l[col] * v[(row, col)];
                assert_abs_diff_eq!(av[(row, col)].eps, expected.eps, epsilon = TOL);
            }
        }

        // `smallest_ev` must match the first Jacobi eigenpair, value and derivative.
        assert_abs_diff_eq!(l[0].re, l1.re, epsilon = TOL);
        assert_abs_diff_eq!(l[0].eps, l1.eps, epsilon = TOL);
        for row in 0..2 {
            assert_abs_diff_eq!(v[(row, 0)].re, v1[row].re, epsilon = TOL);
            assert_abs_diff_eq!(v[(row, 0)].eps, v1[row].eps, epsilon = TOL);
        }
    }

    #[test]
    fn test_norm_f64() {
        let v = dvector![3.0, 4.0];
        assert_eq!(v.norm(), 5.0);
    }

    #[test]
    fn test_norm_dual64() {
        let v = dvector![Dual64::new(3.0, 1.0), Dual64::new(4.0, 3.0)];
        let norm = v.norm();
        println!("{}", norm);
        assert_eq!(norm.re, 5.0);
        assert_eq!(norm.eps, 3.0);
    }
}