1use num_bigint::{BigInt, Sign};
32use num_rational::BigRational;
33use num_traits::ToPrimitive;
34
35use crate::LaError;
36use crate::matrix::Matrix;
37use crate::vector::Vector;
38
39fn validate_finite<const D: usize>(m: &Matrix<D>) -> Result<(), LaError> {
44 for r in 0..D {
45 for c in 0..D {
46 if !m.rows[r][c].is_finite() {
47 return Err(LaError::NonFinite { col: c });
48 }
49 }
50 }
51 Ok(())
52}
53
54fn validate_finite_vec<const D: usize>(v: &Vector<D>) -> Result<(), LaError> {
59 for (i, &x) in v.data.iter().enumerate() {
60 if !x.is_finite() {
61 return Err(LaError::NonFinite { col: i });
62 }
63 }
64 Ok(())
65}
66
67fn f64_to_bigrational(x: f64) -> BigRational {
75 BigRational::from_float(x).expect("non-finite matrix entry in exact determinant")
76}
77
78fn bareiss_det<const D: usize>(m: &Matrix<D>) -> BigRational {
83 if D == 0 {
84 return BigRational::from_integer(BigInt::from(1));
85 }
86 if D == 1 {
87 return f64_to_bigrational(m.rows[0][0]);
88 }
89
90 let mut a: Vec<Vec<BigRational>> = Vec::with_capacity(D);
92 for r in 0..D {
93 let mut row = Vec::with_capacity(D);
94 for c in 0..D {
95 row.push(f64_to_bigrational(m.rows[r][c]));
96 }
97 a.push(row);
98 }
99
100 let zero = BigRational::from_integer(BigInt::from(0));
101 let mut prev_pivot = BigRational::from_integer(BigInt::from(1));
102 let mut sign: i8 = 1;
103
104 for k in 0..D {
105 if a[k][k] == zero {
107 let mut found = false;
108 for i in (k + 1)..D {
109 if a[i][k] != zero {
110 a.swap(k, i);
111 sign = -sign;
112 found = true;
113 break;
114 }
115 }
116 if !found {
117 return zero;
119 }
120 }
121
122 for i in (k + 1)..D {
124 for j in (k + 1)..D {
125 a[i][j] = (&a[k][k] * &a[i][j] - &a[i][k] * &a[k][j]) / &prev_pivot;
127 }
128 a[i][k] = zero.clone();
131 }
132
133 prev_pivot = a[k][k].clone();
134 }
135
136 let det = &a[D - 1][D - 1];
137 if sign < 0 { -det } else { det.clone() }
138}
139
140fn gauss_solve<const D: usize>(m: &Matrix<D>, b: &Vector<D>) -> Result<Vec<BigRational>, LaError> {
150 if D == 0 {
151 return Ok(Vec::new());
152 }
153
154 let zero = BigRational::from_integer(BigInt::from(0));
155
156 let mut aug: Vec<Vec<BigRational>> = Vec::with_capacity(D);
158 for r in 0..D {
159 let mut row = Vec::with_capacity(D + 1);
160 for c in 0..D {
161 row.push(f64_to_bigrational(m.rows[r][c]));
162 }
163 row.push(f64_to_bigrational(b.data[r]));
164 aug.push(row);
165 }
166
167 for k in 0..D {
169 if aug[k][k] == zero {
171 if let Some(swap_row) = ((k + 1)..D).find(|&i| aug[i][k] != zero) {
172 aug.swap(k, swap_row);
173 } else {
174 return Err(LaError::Singular { pivot_col: k });
175 }
176 }
177
178 let pivot = aug[k][k].clone();
180 for i in (k + 1)..D {
181 if aug[i][k] != zero {
182 let factor = &aug[i][k] / &pivot;
183 #[allow(clippy::needless_range_loop)]
186 for j in (k + 1)..=D {
187 let term = &factor * &aug[k][j];
188 aug[i][j] -= term;
189 }
190 aug[i][k] = zero.clone();
191 }
192 }
193 }
194
195 let mut x: Vec<BigRational> = vec![zero; D];
197 for i in (0..D).rev() {
198 let mut sum = aug[i][D].clone();
199 for j in (i + 1)..D {
200 sum -= &aug[i][j] * &x[j];
201 }
202 x[i] = sum / &aug[i][i];
203 }
204
205 Ok(x)
206}
207
208impl<const D: usize> Matrix<D> {
209 #[inline]
235 pub fn det_exact(&self) -> Result<BigRational, LaError> {
236 validate_finite(self)?;
237 Ok(bareiss_det(self))
238 }
239
240 #[inline]
261 pub fn det_exact_f64(&self) -> Result<f64, LaError> {
262 let exact = self.det_exact()?;
263 let val = exact.to_f64().unwrap_or(f64::INFINITY);
264 if val.is_finite() {
265 Ok(val)
266 } else {
267 Err(LaError::Overflow)
268 }
269 }
270
271 #[inline]
301 pub fn solve_exact(&self, b: Vector<D>) -> Result<[BigRational; D], LaError> {
302 validate_finite(self)?;
303 validate_finite_vec(&b)?;
304 let solution = gauss_solve(self, &b)?;
305 Ok(std::array::from_fn(|i| solution[i].clone()))
306 }
307
308 #[inline]
333 pub fn solve_exact_f64(&self, b: Vector<D>) -> Result<Vector<D>, LaError> {
334 let exact = self.solve_exact(b)?;
335 let mut result = [0.0f64; D];
336 for (i, val) in exact.iter().enumerate() {
337 let f = val.to_f64().unwrap_or(f64::INFINITY);
338 if !f.is_finite() {
339 return Err(LaError::Overflow);
340 }
341 result[i] = f;
342 }
343 Ok(Vector::new(result))
344 }
345
346 #[inline]
381 pub fn det_sign_exact(&self) -> Result<i8, LaError> {
382 validate_finite(self)?;
383
384 if let (Some(det_f64), Some(err)) = (self.det_direct(), self.det_errbound()) {
386 if det_f64.is_finite() {
391 if det_f64 > err {
392 return Ok(1);
393 }
394 if det_f64 < -err {
395 return Ok(-1);
396 }
397 }
398 }
399
400 let det = bareiss_det(self);
402 Ok(match det.numer().sign() {
403 Sign::Plus => 1,
404 Sign::Minus => -1,
405 Sign::NoSign => 0,
406 })
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_PIVOT_TOL;

    use pastey::paste;

    // Generates, for dimension `$d`, the basic `det_exact` tests: identity
    // has determinant 1, and a NaN or infinity at (0, 0) is rejected with
    // `NonFinite { col: 0 }`.
    macro_rules! gen_det_exact_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<det_exact_identity_ $d d>]() {
                    let det = Matrix::<$d>::identity().det_exact().unwrap();
                    assert_eq!(det, BigRational::from_integer(BigInt::from(1)));
                }

                #[test]
                fn [<det_exact_err_on_nan_ $d d>]() {
                    let mut m = Matrix::<$d>::identity();
                    m.set(0, 0, f64::NAN);
                    assert_eq!(m.det_exact(), Err(LaError::NonFinite { col: 0 }));
                }

                #[test]
                fn [<det_exact_err_on_inf_ $d d>]() {
                    let mut m = Matrix::<$d>::identity();
                    m.set(0, 0, f64::INFINITY);
                    assert_eq!(m.det_exact(), Err(LaError::NonFinite { col: 0 }));
                }
            }
        };
    }

    gen_det_exact_tests!(2);
    gen_det_exact_tests!(3);
    gen_det_exact_tests!(4);
    gen_det_exact_tests!(5);

    // Generates, for dimension `$d`, the basic `det_exact_f64` tests:
    // identity converts to 1.0 and a NaN entry is rejected.
    macro_rules! gen_det_exact_f64_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<det_exact_f64_identity_ $d d>]() {
                    let det = Matrix::<$d>::identity().det_exact_f64().unwrap();
                    assert!((det - 1.0).abs() <= f64::EPSILON);
                }

                #[test]
                fn [<det_exact_f64_err_on_nan_ $d d>]() {
                    let mut m = Matrix::<$d>::identity();
                    m.set(0, 0, f64::NAN);
                    assert_eq!(m.det_exact_f64(), Err(LaError::NonFinite { col: 0 }));
                }
            }
        };
    }

    gen_det_exact_f64_tests!(2);
    gen_det_exact_f64_tests!(3);
    gen_det_exact_f64_tests!(4);
    gen_det_exact_f64_tests!(5);

    // Generates a test that `det_exact_f64` and `det_direct` agree, within a
    // mixed relative/absolute tolerance of 1e-12, on a matrix with a large
    // diagonal and small off-diagonal entries.
    macro_rules! gen_det_exact_f64_agrees_with_det_direct {
        ($d:literal) => {
            paste! {
                #[test]
                #[allow(clippy::cast_precision_loss)]
                fn [<det_exact_f64_agrees_with_det_direct_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];
                    // Diagonal entries are r + d + 1; off-diagonal entries
                    // are 0.1 / (r + c + 1).
                    for r in 0..$d {
                        for c in 0..$d {
                            rows[r][c] = if r == c {
                                (r as f64) + f64::from($d) + 1.0
                            } else {
                                0.1 / ((r + c + 1) as f64)
                            };
                        }
                    }
                    let m = Matrix::<$d>::from_rows(rows);
                    let exact = m.det_exact_f64().unwrap();
                    let direct = m.det_direct().unwrap();
                    // eps = |direct| * 1e-12 + 1e-12
                    let eps = direct.abs().mul_add(1e-12, 1e-12);
                    assert!((exact - direct).abs() <= eps);
                }
            }
        };
    }

    // Only instantiated for D = 2..=4; the test unwraps `det_direct`.
    gen_det_exact_f64_agrees_with_det_direct!(2);
    gen_det_exact_f64_agrees_with_det_direct!(3);
    gen_det_exact_f64_agrees_with_det_direct!(4);

    // ---- det_sign_exact ----

    #[test]
    fn det_sign_exact_d0_is_positive() {
        assert_eq!(Matrix::<0>::zero().det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_d1_positive() {
        let m = Matrix::<1>::from_rows([[42.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_d1_negative() {
        let m = Matrix::<1>::from_rows([[-3.5]]);
        assert_eq!(m.det_sign_exact().unwrap(), -1);
    }

    #[test]
    fn det_sign_exact_d1_zero() {
        let m = Matrix::<1>::from_rows([[0.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), 0);
    }

    #[test]
    fn det_sign_exact_identity_2d() {
        assert_eq!(Matrix::<2>::identity().det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_identity_3d() {
        assert_eq!(Matrix::<3>::identity().det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_identity_4d() {
        assert_eq!(Matrix::<4>::identity().det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_identity_5d() {
        assert_eq!(Matrix::<5>::identity().det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_singular_duplicate_rows() {
        let m = Matrix::<3>::from_rows([
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            // Identical to row 0, so the determinant is exactly zero.
            [1.0, 2.0, 3.0], ]);
        assert_eq!(m.det_sign_exact().unwrap(), 0);
    }

    #[test]
    fn det_sign_exact_singular_linear_combination() {
        // Row 2 is the sum of rows 0 and 1.
        let m = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [5.0, 7.0, 9.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), 0);
    }

    #[test]
    fn det_sign_exact_negative_det_row_swap() {
        // Identity with rows 0 and 1 swapped: determinant is -1.
        let m = Matrix::<3>::from_rows([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), -1);
    }

    #[test]
    fn det_sign_exact_negative_det_known() {
        // det = 1*4 - 2*3 = -2
        let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), -1);
    }

    #[test]
    fn det_sign_exact_agrees_with_det_for_spd() {
        let m = Matrix::<3>::from_rows([[4.0, 2.0, 0.0], [2.0, 5.0, 1.0], [0.0, 1.0, 3.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), 1);
        assert!(m.det(DEFAULT_PIVOT_TOL).unwrap() > 0.0);
    }

    #[test]
    fn det_sign_exact_near_singular_perturbation() {
        // The bit pattern encodes 2^-50 (biased exponent 0x3CD = 973, zero
        // mantissa). Perturbing entry (0, 0) of the singular matrix
        // [[1,2,3],[4,5,6],[7,8,9]] by it yields
        // det = 2^-50 * (5*9 - 6*8) = -3 * 2^-50, which is negative.
        let perturbation = f64::from_bits(0x3CD0_0000_0000_0000);
        let m = Matrix::<3>::from_rows([
            [1.0 + perturbation, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
        ]);
        assert_eq!(m.det_sign_exact().unwrap(), -1);
    }

    #[test]
    fn det_sign_exact_fast_filter_positive_4x4() {
        // NOTE(review): named for the floating-point filter in
        // `det_sign_exact`; whether that path is actually taken depends on
        // `det_direct`/`det_errbound`. The asserted sign holds on either path.
        let m = Matrix::<4>::from_rows([
            [2.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 4.0, 1.0],
            [0.0, 0.0, 1.0, 5.0],
        ]);
        assert_eq!(m.det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_fast_filter_negative_4x4() {
        // Same matrix as the positive 4x4 case with rows 0 and 1 swapped,
        // which flips the sign of the determinant.
        let m = Matrix::<4>::from_rows([
            [1.0, 3.0, 1.0, 0.0],
            [2.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 4.0, 1.0],
            [0.0, 0.0, 1.0, 5.0],
        ]);
        assert_eq!(m.det_sign_exact().unwrap(), -1);
    }

    #[test]
    fn det_sign_exact_subnormal_entries() {
        let tiny = 5e-324_f64;
        assert!(tiny.is_subnormal());

        // det = tiny * tiny, which is positive as an exact rational.
        let m = Matrix::<2>::from_rows([[tiny, 0.0], [0.0, tiny]]);
        assert_eq!(m.det_sign_exact().unwrap(), 1);
    }

    #[test]
    fn det_sign_exact_returns_err_on_nan() {
        let m = Matrix::<2>::from_rows([[f64::NAN, 0.0], [0.0, 1.0]]);
        assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 0 }));
    }

    #[test]
    fn det_sign_exact_returns_err_on_infinity() {
        let m = Matrix::<2>::from_rows([[f64::INFINITY, 0.0], [0.0, 1.0]]);
        assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 0 }));
    }

    #[test]
    fn det_sign_exact_returns_err_on_nan_5x5() {
        let mut m = Matrix::<5>::identity();
        // NaN placed at (2, 3); the error reports the column, 3.
        m.set(2, 3, f64::NAN);
        assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 3 }));
    }

    #[test]
    fn det_sign_exact_returns_err_on_infinity_5x5() {
        let mut m = Matrix::<5>::identity();
        m.set(0, 0, f64::INFINITY);
        assert_eq!(m.det_sign_exact(), Err(LaError::NonFinite { col: 0 }));
    }

    #[test]
    fn det_sign_exact_pivot_needed_5x5() {
        // Identity with rows 0 and 1 swapped: one transposition, det = -1.
        let m = Matrix::<5>::from_rows([
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ]);
        assert_eq!(m.det_sign_exact().unwrap(), -1);
    }

    #[test]
    fn det_sign_exact_5x5_known() {
        // Identity with rows 0<->1 and 2<->3 swapped: two transpositions,
        // det = +1.
        let m = Matrix::<5>::from_rows([
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ]);
        assert_eq!(m.det_sign_exact().unwrap(), 1);
    }

    // ---- det_errbound ----

    #[test]
    fn det_errbound_d0_is_zero() {
        assert_eq!(Matrix::<0>::zero().det_errbound(), Some(0.0));
    }

    #[test]
    fn det_errbound_d1_is_zero() {
        assert_eq!(Matrix::<1>::from_rows([[42.0]]).det_errbound(), Some(0.0));
    }

    #[test]
    fn det_errbound_d2_positive() {
        let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
        let bound = m.det_errbound().unwrap();
        assert!(bound > 0.0);
        // Checks bound == ERR_COEFF_2 * 10 (to within 1e-30).
        // NOTE(review): 10 presumably is |1*4| + |2*3| — confirm against
        // the definition of `det_errbound`.
        assert!(crate::ERR_COEFF_2.mul_add(-10.0, bound).abs() < 1e-30);
    }

    #[test]
    fn det_errbound_d3_positive() {
        let m = Matrix::<3>::identity();
        let bound = m.det_errbound().unwrap();
        assert!(bound > 0.0);
    }

    #[test]
    fn det_errbound_d3_non_identity() {
        let m = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]]);
        let bound = m.det_errbound().unwrap();
        assert!(bound > 0.0);
    }

    #[test]
    fn det_errbound_d4_positive() {
        let m = Matrix::<4>::identity();
        let bound = m.det_errbound().unwrap();
        assert!(bound > 0.0);
    }

    #[test]
    fn det_errbound_d4_non_identity() {
        let m = Matrix::<4>::from_rows([
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, 0.0],
            [0.0, 0.0, 0.0, 4.0],
        ]);
        let bound = m.det_errbound().unwrap();
        assert!(bound > 0.0);
    }

    #[test]
    fn det_errbound_d5_is_none() {
        // With no error bound for D = 5, `det_sign_exact` always takes the
        // exact path at this dimension.
        assert_eq!(Matrix::<5>::identity().det_errbound(), None);
    }

    // ---- bareiss_det (private helper) ----

    #[test]
    fn bareiss_det_d0_is_one() {
        let det = bareiss_det(&Matrix::<0>::zero());
        assert_eq!(det, BigRational::from_integer(BigInt::from(1)));
    }

    #[test]
    fn bareiss_det_d1_returns_entry() {
        let det = bareiss_det(&Matrix::<1>::from_rows([[7.0]]));
        assert_eq!(det, f64_to_bigrational(7.0));
    }

    #[test]
    fn bareiss_det_d3_with_pivoting() {
        // a[0][0] is zero, so the first step must swap rows; the single swap
        // gives det = -1.
        let m = Matrix::<3>::from_rows([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let det = bareiss_det(&m);
        assert_eq!(det, BigRational::from_integer(BigInt::from(-1)));
    }

    #[test]
    fn bareiss_det_singular_all_zeros_in_column() {
        // Column 1 has no nonzero entry on or below the diagonal, so the
        // pivot search fails and zero is returned.
        let m = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let det = bareiss_det(&m);
        assert_eq!(det, BigRational::from_integer(BigInt::from(0)));
    }

    #[test]
    fn det_sign_exact_overflow_determinant_finite_entries() {
        // All entries are finite, but det = big * big exceeds f64::MAX; the
        // sign must still come out as positive.
        let big = f64::MAX / 2.0;
        assert!(big.is_finite());
        let m = Matrix::<3>::from_rows([[0.0, 0.0, 1.0], [big, 0.0, 1.0], [0.0, big, 1.0]]);
        assert_eq!(m.det_sign_exact().unwrap(), 1);
    }

    // ---- det_exact ----

    #[test]
    fn det_exact_d0_is_one() {
        let det = Matrix::<0>::zero().det_exact().unwrap();
        assert_eq!(det, BigRational::from_integer(BigInt::from(1)));
    }

    #[test]
    fn det_exact_known_2x2() {
        // det = 1*4 - 2*3 = -2
        let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
        let det = m.det_exact().unwrap();
        assert_eq!(det, BigRational::from_integer(BigInt::from(-2)));
    }

    #[test]
    fn det_exact_singular_returns_zero() {
        let m = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let det = m.det_exact().unwrap();
        assert_eq!(det, BigRational::from_integer(BigInt::from(0)));
    }

    #[test]
    fn det_exact_near_singular_perturbation() {
        // perturbation = 2^-50 (biased exponent 0x3CD = 973, zero mantissa).
        let perturbation = f64::from_bits(0x3CD0_0000_0000_0000);
        let m = Matrix::<3>::from_rows([
            [1.0 + perturbation, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
        ]);
        let det = m.det_exact().unwrap();
        // det = perturbation * (5*9 - 6*8) = -3 / 2^50
        let expected = BigRational::new(BigInt::from(-3), BigInt::from(1u64 << 50));
        assert_eq!(det, expected);
    }

    #[test]
    fn det_exact_5x5_permutation() {
        // Identity with rows 0 and 1 swapped: det = -1.
        let m = Matrix::<5>::from_rows([
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ]);
        let det = m.det_exact().unwrap();
        assert_eq!(det, BigRational::from_integer(BigInt::from(-1)));
    }

    // ---- det_exact_f64 ----

    #[test]
    fn det_exact_f64_known_2x2() {
        let m = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
        let det = m.det_exact_f64().unwrap();
        assert!((det - (-2.0)).abs() <= f64::EPSILON);
    }

    #[test]
    fn det_exact_f64_overflow_returns_err() {
        let big = f64::MAX / 2.0;
        // det = big * big, which is not representable as a finite f64.
        let m = Matrix::<3>::from_rows([[0.0, 0.0, 1.0], [big, 0.0, 1.0], [0.0, big, 1.0]]);
        assert_eq!(m.det_exact_f64(), Err(LaError::Overflow));
    }

    // Builds a fixed right-hand side of dimension `D`: the first
    // min(D, 5) components come from a fixed list and any remaining
    // components stay 0.0.
    fn arbitrary_rhs<const D: usize>() -> Vector<D> {
        let values = [1.0, -2.5, 3.0, 0.25, -4.0];
        let mut arr = [0.0f64; D];
        for (dst, src) in arr.iter_mut().zip(values.iter()) {
            *dst = *src;
        }
        Vector::<D>::new(arr)
    }

    // Generates, for dimension `$d`, the basic `solve_exact` tests: identity
    // returns the right-hand side unchanged, non-finite input in the matrix
    // or the vector is rejected, and the zero matrix is singular at column 0.
    macro_rules! gen_solve_exact_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_exact_identity_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let b = arbitrary_rhs::<$d>();
                    let x = a.solve_exact(b).unwrap();
                    for (i, xi) in x.iter().enumerate() {
                        assert_eq!(*xi, f64_to_bigrational(b.data[i]));
                    }
                }

                #[test]
                fn [<solve_exact_err_on_nan_matrix_ $d d>]() {
                    let mut a = Matrix::<$d>::identity();
                    a.set(0, 0, f64::NAN);
                    let b = arbitrary_rhs::<$d>();
                    assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: 0 }));
                }

                #[test]
                fn [<solve_exact_err_on_inf_matrix_ $d d>]() {
                    let mut a = Matrix::<$d>::identity();
                    a.set(0, 0, f64::INFINITY);
                    let b = arbitrary_rhs::<$d>();
                    assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: 0 }));
                }

                #[test]
                fn [<solve_exact_err_on_nan_vector_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let mut b_arr = [1.0f64; $d];
                    b_arr[0] = f64::NAN;
                    let b = Vector::<$d>::new(b_arr);
                    assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: 0 }));
                }

                #[test]
                fn [<solve_exact_err_on_inf_vector_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let mut b_arr = [1.0f64; $d];
                    // For vectors, `col` carries the index of the bad
                    // component; here that is the last one.
                    b_arr[$d - 1] = f64::INFINITY;
                    let b = Vector::<$d>::new(b_arr);
                    assert_eq!(a.solve_exact(b), Err(LaError::NonFinite { col: $d - 1 }));
                }

                #[test]
                fn [<solve_exact_singular_ $d d>]() {
                    let a = Matrix::<$d>::zero();
                    let b = arbitrary_rhs::<$d>();
                    assert_eq!(a.solve_exact(b), Err(LaError::Singular { pivot_col: 0 }));
                }
            }
        };
    }

    gen_solve_exact_tests!(2);
    gen_solve_exact_tests!(3);
    gen_solve_exact_tests!(4);
    gen_solve_exact_tests!(5);

    // Generates, for dimension `$d`, the basic `solve_exact_f64` tests:
    // identity returns the right-hand side and a NaN entry is rejected.
    macro_rules! gen_solve_exact_f64_tests {
        ($d:literal) => {
            paste! {
                #[test]
                fn [<solve_exact_f64_identity_ $d d>]() {
                    let a = Matrix::<$d>::identity();
                    let b = arbitrary_rhs::<$d>();
                    let x = a.solve_exact_f64(b).unwrap().into_array();
                    for i in 0..$d {
                        assert!((x[i] - b.data[i]).abs() <= f64::EPSILON);
                    }
                }

                #[test]
                fn [<solve_exact_f64_err_on_nan_ $d d>]() {
                    let mut a = Matrix::<$d>::identity();
                    a.set(0, 0, f64::NAN);
                    let b = arbitrary_rhs::<$d>();
                    assert_eq!(a.solve_exact_f64(b), Err(LaError::NonFinite { col: 0 }));
                }
            }
        };
    }

    gen_solve_exact_f64_tests!(2);
    gen_solve_exact_f64_tests!(3);
    gen_solve_exact_f64_tests!(4);
    gen_solve_exact_f64_tests!(5);

    // Generates a test that `solve_exact_f64` agrees with the LU-based
    // solver, within a mixed relative/absolute tolerance of 1e-12 per
    // component, on a matrix with a large diagonal and small off-diagonal
    // entries.
    macro_rules! gen_solve_exact_f64_agrees_with_lu {
        ($d:literal) => {
            paste! {
                #[test]
                #[allow(clippy::cast_precision_loss)]
                fn [<solve_exact_f64_agrees_with_lu_ $d d>]() {
                    let mut rows = [[0.0f64; $d]; $d];
                    // Diagonal entries are r + d + 1; off-diagonal entries
                    // are 0.1 / (r + c + 1).
                    for r in 0..$d {
                        for c in 0..$d {
                            rows[r][c] = if r == c {
                                (r as f64) + f64::from($d) + 1.0
                            } else {
                                0.1 / ((r + c + 1) as f64)
                            };
                        }
                    }
                    let a = Matrix::<$d>::from_rows(rows);
                    let b = arbitrary_rhs::<$d>();
                    let exact = a.solve_exact_f64(b).unwrap().into_array();
                    let lu_sol = a.lu(DEFAULT_PIVOT_TOL).unwrap()
                        .solve_vec(b).unwrap().into_array();
                    for i in 0..$d {
                        // eps = |lu_sol[i]| * 1e-12 + 1e-12
                        let eps = lu_sol[i].abs().mul_add(1e-12, 1e-12);
                        assert!((exact[i] - lu_sol[i]).abs() <= eps);
                    }
                }
            }
        };
    }

    gen_solve_exact_f64_agrees_with_lu!(2);
    gen_solve_exact_f64_agrees_with_lu!(3);
    gen_solve_exact_f64_agrees_with_lu!(4);
    gen_solve_exact_f64_agrees_with_lu!(5);

    // ---- solve_exact ----

    #[test]
    fn solve_exact_d0_returns_empty() {
        let a = Matrix::<0>::zero();
        let b = Vector::<0>::zero();
        let x = a.solve_exact(b).unwrap();
        assert!(x.is_empty());
    }

    #[test]
    fn solve_exact_known_2x2() {
        // x + 2y = 5 and 3x + 4y = 11 have the solution x = 1, y = 2.
        let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
        let b = Vector::<2>::new([5.0, 11.0]);
        let x = a.solve_exact(b).unwrap();
        assert_eq!(x[0], BigRational::from_integer(BigInt::from(1)));
        assert_eq!(x[1], BigRational::from_integer(BigInt::from(2)));
    }

    #[test]
    fn solve_exact_pivoting_needed() {
        // Permutation matrix swapping the first two coordinates, so
        // x0 = b1 and x1 = b0; a[0][0] is zero, forcing a row swap.
        let a = Matrix::<3>::from_rows([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let b = Vector::<3>::new([2.0, 3.0, 4.0]);
        let x = a.solve_exact(b).unwrap();
        assert_eq!(x[0], f64_to_bigrational(3.0));
        assert_eq!(x[1], f64_to_bigrational(2.0));
        assert_eq!(x[2], f64_to_bigrational(4.0));
    }

    #[test]
    fn solve_exact_fractional_result() {
        // 2x + y = 1 and x + 3y = 1 have the solution x = 2/5, y = 1/5.
        let a = Matrix::<2>::from_rows([[2.0, 1.0], [1.0, 3.0]]);
        let b = Vector::<2>::new([1.0, 1.0]);
        let x = a.solve_exact(b).unwrap();
        assert_eq!(x[0], BigRational::new(BigInt::from(2), BigInt::from(5)));
        assert_eq!(x[1], BigRational::new(BigInt::from(1), BigInt::from(5)));
    }

    #[test]
    fn solve_exact_singular_duplicate_rows() {
        // Rows 0 and 2 are identical; only the error variant is checked,
        // not the reported pivot column.
        let a = Matrix::<3>::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]);
        let b = Vector::<3>::new([1.0, 2.0, 3.0]);
        assert!(matches!(a.solve_exact(b), Err(LaError::Singular { .. })));
    }

    #[test]
    fn solve_exact_5x5_permutation() {
        // Identity with rows 0 and 1 swapped: the solution is `b` with its
        // first two components exchanged.
        let a = Matrix::<5>::from_rows([
            [0.0, 1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ]);
        let b = Vector::<5>::new([10.0, 20.0, 30.0, 40.0, 50.0]);
        let x = a.solve_exact(b).unwrap();
        assert_eq!(x[0], f64_to_bigrational(20.0));
        assert_eq!(x[1], f64_to_bigrational(10.0));
        assert_eq!(x[2], f64_to_bigrational(30.0));
        assert_eq!(x[3], f64_to_bigrational(40.0));
        assert_eq!(x[4], f64_to_bigrational(50.0));
    }

    // ---- solve_exact_f64 ----

    #[test]
    fn solve_exact_f64_known_2x2() {
        let a = Matrix::<2>::from_rows([[1.0, 2.0], [3.0, 4.0]]);
        let b = Vector::<2>::new([5.0, 11.0]);
        let x = a.solve_exact_f64(b).unwrap().into_array();
        assert!((x[0] - 1.0).abs() <= f64::EPSILON);
        assert!((x[1] - 2.0).abs() <= f64::EPSILON);
    }

    #[test]
    fn solve_exact_f64_overflow_returns_err() {
        let big = f64::MAX / 2.0;
        // Each solution component is big / (1 / big), roughly big^2, which
        // is too large for a finite f64.
        let a = Matrix::<2>::from_rows([[1.0 / big, 0.0], [0.0, 1.0 / big]]);
        let b = Vector::<2>::new([big, big]);
        assert_eq!(a.solve_exact_f64(b), Err(LaError::Overflow));
    }

    // ---- gauss_solve (private helper) ----

    #[test]
    fn gauss_solve_d0_returns_empty() {
        let a = Matrix::<0>::zero();
        let b = Vector::<0>::zero();
        assert_eq!(gauss_solve(&a, &b).unwrap().len(), 0);
    }

    #[test]
    fn gauss_solve_d1() {
        // 2x = 6 gives x = 3.
        let a = Matrix::<1>::from_rows([[2.0]]);
        let b = Vector::<1>::new([6.0]);
        let x = gauss_solve(&a, &b).unwrap();
        assert_eq!(x[0], f64_to_bigrational(3.0));
    }

    #[test]
    fn gauss_solve_singular_column_all_zero() {
        // Column 1 has no nonzero entry on or below the diagonal, so
        // elimination fails at pivot column 1.
        let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]);
        let b = Vector::<3>::new([1.0, 2.0, 3.0]);
        assert_eq!(gauss_solve(&a, &b), Err(LaError::Singular { pivot_col: 1 }));
    }

    // ---- validate_finite_vec (private helper) ----

    #[test]
    fn validate_finite_vec_ok() {
        assert!(validate_finite_vec(&Vector::<3>::new([1.0, 2.0, 3.0])).is_ok());
    }

    #[test]
    fn validate_finite_vec_err_on_nan() {
        assert_eq!(
            validate_finite_vec(&Vector::<2>::new([f64::NAN, 1.0])),
            Err(LaError::NonFinite { col: 0 })
        );
    }

    #[test]
    fn validate_finite_vec_err_on_inf() {
        assert_eq!(
            validate_finite_vec(&Vector::<2>::new([1.0, f64::NEG_INFINITY])),
            Err(LaError::NonFinite { col: 1 })
        );
    }

    // ---- validate_finite (private helper) ----

    #[test]
    fn validate_finite_ok_for_finite() {
        assert!(validate_finite(&Matrix::<3>::identity()).is_ok());
    }

    #[test]
    fn validate_finite_err_on_nan() {
        let mut m = Matrix::<2>::identity();
        // NaN at (1, 0): the error reports only the column, 0.
        m.set(1, 0, f64::NAN);
        assert_eq!(validate_finite(&m), Err(LaError::NonFinite { col: 0 }));
    }

    #[test]
    fn validate_finite_err_on_inf() {
        let mut m = Matrix::<2>::identity();
        m.set(0, 1, f64::NEG_INFINITY);
        assert_eq!(validate_finite(&m), Err(LaError::NonFinite { col: 1 }));
    }
}