1use super::inertia::Inertia;
12use super::pivot::{Block2x2, PivotType};
13
14#[derive(Debug, Clone)]
16pub enum PivotEntry {
17 OneByOne(f64),
19 TwoByTwo(Block2x2),
21 Delayed,
23}
24
25#[derive(Debug)]
49pub struct MixedDiagonal {
50 pivot_map: Vec<PivotType>,
51 diag: Vec<f64>,
52 off_diag: Vec<f64>,
53 n: usize,
54}
55
56impl MixedDiagonal {
57 pub fn new(n: usize) -> Self {
61 Self {
62 pivot_map: vec![PivotType::Delayed; n],
63 diag: vec![0.0; n],
64 off_diag: vec![0.0; n],
65 n,
66 }
67 }
68
69 pub fn set_1x1(&mut self, col: usize, value: f64) {
76 debug_assert!(
77 col < self.n,
78 "set_1x1: col {} out of bounds (n = {})",
79 col,
80 self.n
81 );
82 debug_assert!(
83 self.pivot_map[col] == PivotType::Delayed,
84 "set_1x1: col {} is already set ({:?})",
85 col,
86 self.pivot_map[col]
87 );
88 self.pivot_map[col] = PivotType::OneByOne;
89 self.diag[col] = value;
90 }
91
92 pub fn set_2x2(&mut self, block: Block2x2) {
101 let col = block.first_col;
102 debug_assert!(
103 col + 1 < self.n,
104 "set_2x2: first_col {} + 1 out of bounds (n = {})",
105 col,
106 self.n
107 );
108 debug_assert!(
109 self.pivot_map[col] == PivotType::Delayed,
110 "set_2x2: col {} is already set ({:?})",
111 col,
112 self.pivot_map[col]
113 );
114 debug_assert!(
115 self.pivot_map[col + 1] == PivotType::Delayed,
116 "set_2x2: col {} is already set ({:?})",
117 col + 1,
118 self.pivot_map[col + 1]
119 );
120 self.pivot_map[col] = PivotType::TwoByTwo { partner: col + 1 };
121 self.pivot_map[col + 1] = PivotType::TwoByTwo { partner: col };
122 self.diag[col] = block.a;
123 self.diag[col + 1] = block.c;
124 self.off_diag[col] = block.b;
125 }
126
127 pub fn dimension(&self) -> usize {
129 self.n
130 }
131
132 pub fn pivot_type(&self, col: usize) -> PivotType {
138 debug_assert!(
139 col < self.n,
140 "pivot_type: col {} out of bounds (n = {})",
141 col,
142 self.n
143 );
144 self.pivot_map[col]
145 }
146
147 pub fn diagonal_1x1(&self, col: usize) -> f64 {
153 debug_assert!(
154 self.pivot_map[col] == PivotType::OneByOne,
155 "diagonal_1x1: col {} is not OneByOne ({:?})",
156 col,
157 self.pivot_map[col]
158 );
159 self.diag[col]
160 }
161
162 pub fn diagonal_2x2(&self, first_col: usize) -> Block2x2 {
170 debug_assert!(
171 matches!(self.pivot_map[first_col], PivotType::TwoByTwo { partner } if partner > first_col),
172 "diagonal_2x2: col {} is not a 2x2 block owner ({:?})",
173 first_col,
174 self.pivot_map[first_col]
175 );
176 Block2x2 {
177 first_col,
178 a: self.diag[first_col],
179 b: self.off_diag[first_col],
180 c: self.diag[first_col + 1],
181 }
182 }
183
184 pub fn num_delayed(&self) -> usize {
186 self.pivot_map
187 .iter()
188 .filter(|p| **p == PivotType::Delayed)
189 .count()
190 }
191
192 pub fn num_1x1(&self) -> usize {
194 self.pivot_map
195 .iter()
196 .filter(|p| **p == PivotType::OneByOne)
197 .count()
198 }
199
200 pub fn grow(&mut self, new_n: usize) {
202 if new_n > self.n {
203 self.pivot_map.resize(new_n, PivotType::Delayed);
204 self.diag.resize(new_n, 0.0);
205 self.off_diag.resize(new_n, 0.0);
206 self.n = new_n;
207 }
208 }
209
210 pub fn truncate(&mut self, new_n: usize) {
212 debug_assert!(
213 new_n <= self.n,
214 "truncate: new_n {} > current n {}",
215 new_n,
216 self.n
217 );
218 self.pivot_map.truncate(new_n);
219 self.diag.truncate(new_n);
220 self.off_diag.truncate(new_n);
221 self.n = new_n;
222 }
223
224 pub fn copy_from_offset(&mut self, source: &MixedDiagonal, self_offset: usize, count: usize) {
235 debug_assert!(
236 self_offset + count <= self.n,
237 "copy_from_offset: self_offset {} + count {} > self.n {}",
238 self_offset,
239 count,
240 self.n
241 );
242 debug_assert!(
243 count <= source.n,
244 "copy_from_offset: count {} > source.n {}",
245 count,
246 source.n
247 );
248
249 let mut col = 0;
250 while col < count {
251 match source.pivot_map[col] {
252 PivotType::OneByOne => {
253 self.pivot_map[self_offset + col] = PivotType::OneByOne;
254 self.diag[self_offset + col] = source.diag[col];
255 col += 1;
256 }
257 PivotType::TwoByTwo { .. } if col + 1 < count => {
258 let dest = self_offset + col;
259 self.pivot_map[dest] = PivotType::TwoByTwo { partner: dest + 1 };
260 self.pivot_map[dest + 1] = PivotType::TwoByTwo { partner: dest };
261 self.diag[dest] = source.diag[col];
262 self.diag[dest + 1] = source.diag[col + 1];
263 self.off_diag[dest] = source.off_diag[col];
264 col += 2;
265 }
266 PivotType::Delayed => {
267 col += 1;
269 }
270 _ => {
271 col += 1;
273 }
274 }
275 }
276 }
277
278 pub fn iter_pivots(&self) -> PivotIter<'_> {
283 PivotIter { d: self, col: 0 }
284 }
285
286 pub fn num_2x2_pairs(&self) -> usize {
288 self.pivot_map
289 .iter()
290 .enumerate()
291 .filter(|(i, p)| matches!(p, PivotType::TwoByTwo { partner } if *partner > *i))
292 .count()
293 }
294
295 pub fn solve_in_place(&self, x: &mut [f64]) {
313 debug_assert_eq!(
314 x.len(),
315 self.n,
316 "solve_in_place: x.len() = {} != n = {}",
317 x.len(),
318 self.n
319 );
320 debug_assert!(
321 self.num_delayed() == 0,
322 "solve_in_place: {} delayed columns remain",
323 self.num_delayed()
324 );
325
326 let mut col = 0;
327 while col < self.n {
328 match self.pivot_map[col] {
329 PivotType::OneByOne => {
330 let d = self.diag[col];
331 if d == 0.0 {
332 x[col] = 0.0;
334 } else {
335 x[col] /= d;
336 }
337 col += 1;
338 }
339 PivotType::TwoByTwo { partner } => {
340 if partner > col {
341 let a = self.diag[col];
342 let b = self.off_diag[col];
343 let c = self.diag[partner];
344 let det = a * c - b * b;
345 if det == 0.0 {
346 x[col] = 0.0;
348 x[partner] = 0.0;
349 } else {
350 let r1 = x[col];
351 let r2 = x[partner];
352 x[col] = (c * r1 - b * r2) / det;
354 x[partner] = (a * r2 - b * r1) / det;
355 }
356 }
357 col += 1;
359 }
360 PivotType::Delayed => {
361 unreachable!("solve_in_place: delayed column at {}", col);
362 }
363 }
364 }
365 }
366
367 pub fn compute_inertia(&self) -> Inertia {
392 debug_assert!(
393 self.num_delayed() == 0,
394 "compute_inertia: {} delayed columns remain",
395 self.num_delayed()
396 );
397
398 let mut positive = 0usize;
399 let mut negative = 0usize;
400 let mut zero = 0usize;
401
402 let mut col = 0;
403 while col < self.n {
404 match self.pivot_map[col] {
405 PivotType::OneByOne => {
406 let d = self.diag[col];
407 if d > 0.0 {
408 positive += 1;
409 } else if d < 0.0 {
410 negative += 1;
411 } else {
412 zero += 1;
413 }
414 col += 1;
415 }
416 PivotType::TwoByTwo { partner } => {
417 if partner > col {
418 let a = self.diag[col];
420 let b = self.off_diag[col];
421 let c = self.diag[partner];
422 let det = a * c - b * b;
423 let trace = a + c;
424
425 if det > 0.0 {
426 if trace > 0.0 {
427 positive += 2;
428 } else {
429 negative += 2;
431 }
432 } else if det < 0.0 {
433 positive += 1;
434 negative += 1;
435 } else {
436 if trace > 0.0 {
438 positive += 1;
439 zero += 1;
440 } else if trace < 0.0 {
441 negative += 1;
442 zero += 1;
443 } else {
444 zero += 2;
445 }
446 }
447 }
448 col += 1;
449 }
450 PivotType::Delayed => {
451 unreachable!("compute_inertia: delayed column at {}", col);
452 }
453 }
454 }
455
456 Inertia {
457 positive,
458 negative,
459 zero,
460 }
461 }
462}
463
464pub struct PivotIter<'a> {
469 d: &'a MixedDiagonal,
470 col: usize,
471}
472
473impl<'a> Iterator for PivotIter<'a> {
474 type Item = (usize, PivotEntry);
475
476 fn next(&mut self) -> Option<Self::Item> {
477 if self.col >= self.d.n {
478 return None;
479 }
480 let col = self.col;
481 match self.d.pivot_map[col] {
482 PivotType::OneByOne => {
483 self.col += 1;
484 Some((col, PivotEntry::OneByOne(self.d.diag[col])))
485 }
486 PivotType::TwoByTwo { partner } if partner > col => {
487 self.col += 2;
488 Some((
489 col,
490 PivotEntry::TwoByTwo(Block2x2 {
491 first_col: col,
492 a: self.d.diag[col],
493 b: self.d.off_diag[col],
494 c: self.d.diag[col + 1],
495 }),
496 ))
497 }
498 PivotType::TwoByTwo { .. } => {
499 self.col += 1;
501 self.next()
502 }
503 PivotType::Delayed => {
504 self.col += 1;
505 Some((col, PivotEntry::Delayed))
506 }
507 }
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symmetric::pivot::{Block2x2, PivotType};

    /// Relative residual `||dx - b||_2 / ||b||_2`.
    fn relative_error(dx: &[f64], b: &[f64]) -> f64 {
        let norm_b: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_diff: f64 = dx
            .iter()
            .zip(b)
            .map(|(d, bi)| (d - bi).powi(2))
            .sum::<f64>()
            .sqrt();
        norm_diff / norm_b
    }

    #[test]
    fn new_creates_all_delayed() {
        let diag = MixedDiagonal::new(5);
        assert_eq!(diag.dimension(), 5);
        for col in 0..5 {
            assert_eq!(diag.pivot_type(col), PivotType::Delayed);
        }
        assert_eq!(diag.num_delayed(), 5);
        assert_eq!(diag.num_1x1(), 0);
        assert_eq!(diag.num_2x2_pairs(), 0);
    }

    #[test]
    fn set_1x1_marks_correct_pivot_type() {
        let mut diag = MixedDiagonal::new(4);
        diag.set_1x1(0, 3.5);
        diag.set_1x1(2, -1.0);

        assert_eq!(diag.pivot_type(0), PivotType::OneByOne);
        assert_eq!(diag.pivot_type(1), PivotType::Delayed);
        assert_eq!(diag.pivot_type(2), PivotType::OneByOne);
        assert_eq!(diag.pivot_type(3), PivotType::Delayed);

        assert_eq!(diag.diagonal_1x1(0), 3.5);
        assert_eq!(diag.diagonal_1x1(2), -1.0);

        assert_eq!(diag.num_1x1(), 2);
        assert_eq!(diag.num_delayed(), 2);
    }

    #[test]
    fn set_2x2_marks_both_columns() {
        let mut diag = MixedDiagonal::new(6);
        let block = Block2x2 {
            first_col: 2,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        };
        diag.set_2x2(block);

        assert_eq!(diag.pivot_type(2), PivotType::TwoByTwo { partner: 3 });
        assert_eq!(diag.pivot_type(3), PivotType::TwoByTwo { partner: 2 });
        assert_eq!(diag.diagonal_2x2(2), block);
        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_delayed(), 4);
    }

    #[test]
    fn mixed_pivots_correct_counts() {
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        diag.set_1x1(2, 4.0);
        diag.set_1x1(3, -1.0);
        diag.set_1x1(4, 7.0);
        diag.set_1x1(5, 2.0);

        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_1x1(), 4);
        assert_eq!(diag.num_delayed(), 0);
        assert_eq!(diag.dimension(), 6);
    }

    #[test]
    fn solve_all_1x1() {
        let mut diag = MixedDiagonal::new(4);
        diag.set_1x1(0, 2.0);
        diag.set_1x1(1, 4.0);
        diag.set_1x1(2, -1.0);
        diag.set_1x1(3, 5.0);

        let mut x = vec![6.0, 12.0, -3.0, 20.0];
        let b = x.clone();
        diag.solve_in_place(&mut x);

        assert_eq!(x, vec![3.0, 3.0, 3.0, 4.0]);

        let dx = [2.0 * x[0], 4.0 * x[1], -x[2], 5.0 * x[3]];
        let rel = relative_error(&dx, &b);
        assert!(rel < 1e-14, "relative error: {:.2e}", rel);
    }

    #[test]
    fn solve_all_2x2() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });

        let b = vec![4.5, -0.5];
        let mut x = b.clone();
        diag.solve_in_place(&mut x);

        let dx = [2.0 * x[0] + 0.5 * x[1], 0.5 * x[0] + (-3.0) * x[1]];
        let rel = relative_error(&dx, &b);
        assert!(rel < 1e-14, "relative error: {:.2e}", rel);
    }

    #[test]
    fn solve_mixed_1x1_and_2x2() {
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        diag.set_1x1(2, 4.0);
        diag.set_1x1(3, -1.0);
        diag.set_1x1(4, 7.0);
        diag.set_1x1(5, 2.0);

        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut x = b.clone();
        diag.solve_in_place(&mut x);

        let dx = [
            2.0 * x[0] + 0.5 * x[1],
            0.5 * x[0] + (-3.0) * x[1],
            4.0 * x[2],
            -x[3],
            7.0 * x[4],
            2.0 * x[5],
        ];
        let rel = relative_error(&dx, &b);
        assert!(rel < 1e-14, "relative error: {:.2e}", rel);
    }

    #[test]
    fn solve_dimension_0_is_noop() {
        let diag = MixedDiagonal::new(0);
        let mut x: Vec<f64> = vec![];
        diag.solve_in_place(&mut x);
        assert!(x.is_empty());
    }

    #[test]
    fn solve_zero_pivots_yield_zero() {
        let mut diag = MixedDiagonal::new(3);
        diag.set_1x1(0, 0.0);
        // Singular block: det = 1 * 1 - 1 * 1 = 0.
        diag.set_2x2(Block2x2 {
            first_col: 1,
            a: 1.0,
            b: 1.0,
            c: 1.0,
        });

        let mut x = vec![5.0, 6.0, 7.0];
        diag.solve_in_place(&mut x);
        assert_eq!(x, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn dimension_0() {
        let diag = MixedDiagonal::new(0);
        assert_eq!(diag.dimension(), 0);
        assert_eq!(diag.num_delayed(), 0);
        assert_eq!(diag.num_1x1(), 0);
        assert_eq!(diag.num_2x2_pairs(), 0);
    }

    #[test]
    fn dimension_1_single_1x1() {
        let mut diag = MixedDiagonal::new(1);
        diag.set_1x1(0, 5.0);
        assert_eq!(diag.pivot_type(0), PivotType::OneByOne);
        assert_eq!(diag.diagonal_1x1(0), 5.0);
        assert_eq!(diag.num_1x1(), 1);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    fn dimension_2_single_2x2() {
        let mut diag = MixedDiagonal::new(2);
        let block = Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        };
        diag.set_2x2(block);
        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    fn all_2x2_even_n() {
        let mut diag = MixedDiagonal::new(4);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        });
        diag.set_2x2(Block2x2 {
            first_col: 2,
            a: 2.0,
            b: 0.5,
            c: 3.0,
        });
        assert_eq!(diag.num_2x2_pairs(), 2);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    #[should_panic]
    fn solve_panics_on_delayed_columns() {
        let mut diag = MixedDiagonal::new(3);
        diag.set_1x1(0, 1.0);
        let mut x = vec![1.0, 2.0, 3.0];
        diag.solve_in_place(&mut x);
    }

    #[test]
    #[should_panic]
    fn set_2x2_at_last_column_odd_n_panics() {
        let mut diag = MixedDiagonal::new(3);
        diag.set_2x2(Block2x2 {
            first_col: 2,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        });
    }

    #[test]
    fn grow_then_truncate() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_1x1(0, 1.0);
        diag.set_1x1(1, 2.0);

        diag.grow(4);
        assert_eq!(diag.dimension(), 4);
        assert_eq!(diag.pivot_type(2), PivotType::Delayed);
        assert_eq!(diag.pivot_type(3), PivotType::Delayed);
        assert_eq!(diag.diagonal_1x1(1), 2.0);

        // Growing to a smaller size is a no-op.
        diag.grow(3);
        assert_eq!(diag.dimension(), 4);

        diag.truncate(2);
        assert_eq!(diag.dimension(), 2);
        assert_eq!(diag.num_1x1(), 2);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    fn copy_from_offset_shifts_pivots() {
        let mut src = MixedDiagonal::new(4);
        src.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        src.set_1x1(2, 4.0);
        // Source column 3 stays delayed.

        let mut dst = MixedDiagonal::new(6);
        dst.copy_from_offset(&src, 1, 4);

        assert_eq!(dst.pivot_type(0), PivotType::Delayed);
        assert_eq!(dst.pivot_type(1), PivotType::TwoByTwo { partner: 2 });
        assert_eq!(dst.pivot_type(2), PivotType::TwoByTwo { partner: 1 });
        assert_eq!(dst.pivot_type(3), PivotType::OneByOne);
        assert_eq!(dst.pivot_type(4), PivotType::Delayed);
        assert_eq!(dst.pivot_type(5), PivotType::Delayed);
        assert_eq!(
            dst.diagonal_2x2(1),
            Block2x2 {
                first_col: 1,
                a: 2.0,
                b: 0.5,
                c: -3.0,
            }
        );
        assert_eq!(dst.diagonal_1x1(3), 4.0);
    }

    #[test]
    fn copy_from_offset_skips_2x2_cut_by_count() {
        let mut src = MixedDiagonal::new(3);
        src.set_1x1(0, 1.0);
        src.set_2x2(Block2x2 {
            first_col: 1,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });

        let mut dst = MixedDiagonal::new(2);
        dst.copy_from_offset(&src, 0, 2);

        assert_eq!(dst.pivot_type(0), PivotType::OneByOne);
        assert_eq!(dst.diagonal_1x1(0), 1.0);
        assert_eq!(dst.pivot_type(1), PivotType::Delayed);
    }

    #[test]
    fn iter_pivots_yields_each_pivot_once() {
        let mut diag = MixedDiagonal::new(5);
        diag.set_1x1(0, 1.5);
        diag.set_2x2(Block2x2 {
            first_col: 1,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        // Column 3 stays delayed.
        diag.set_1x1(4, -2.0);

        let entries: Vec<(usize, PivotEntry)> = diag.iter_pivots().collect();
        let cols: Vec<usize> = entries.iter().map(|(col, _)| *col).collect();
        assert_eq!(cols, vec![0, 1, 3, 4]);

        assert!(matches!(&entries[0].1, PivotEntry::OneByOne(d) if *d == 1.5));
        match &entries[1].1 {
            PivotEntry::TwoByTwo(block) => assert_eq!(
                *block,
                Block2x2 {
                    first_col: 1,
                    a: 2.0,
                    b: 0.5,
                    c: -3.0,
                }
            ),
            other => panic!("expected a 2x2 pivot, got {:?}", other),
        }
        assert!(matches!(&entries[2].1, PivotEntry::Delayed));
        assert!(matches!(&entries[3].1, PivotEntry::OneByOne(d) if *d == -2.0));
    }

    #[test]
    fn inertia_all_positive_1x1() {
        let mut diag = MixedDiagonal::new(4);
        for i in 0..4 {
            diag.set_1x1(i, (i + 1) as f64);
        }
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 4,
                negative: 0,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_mixed_sign_1x1() {
        let mut diag = MixedDiagonal::new(5);
        diag.set_1x1(0, 3.0);
        diag.set_1x1(1, -2.0);
        diag.set_1x1(2, 1.0);
        diag.set_1x1(3, -0.5);
        diag.set_1x1(4, 0.0);
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 2,
                negative: 2,
                zero: 1
            }
        );
    }

    #[test]
    fn inertia_2x2_det_negative_one_plus_one_minus() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 1,
                negative: 1,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_2x2_det_positive_trace_positive() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 5.0,
            b: 1.0,
            c: 3.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 2,
                negative: 0,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_2x2_det_positive_trace_negative() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: -5.0,
            b: 1.0,
            c: -3.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 0,
                negative: 2,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_2x2_singular_blocks() {
        let mut diag = MixedDiagonal::new(6);
        // det = 0, trace = 2: eigenvalues {2, 0}.
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 1.0,
            c: 1.0,
        });
        // det = 0, trace = -2: eigenvalues {-2, 0}.
        diag.set_2x2(Block2x2 {
            first_col: 2,
            a: -1.0,
            b: 1.0,
            c: -1.0,
        });
        // Zero block: eigenvalues {0, 0}.
        diag.set_2x2(Block2x2 {
            first_col: 4,
            a: 0.0,
            b: 0.0,
            c: 0.0,
        });
        assert_eq!(
            diag.compute_inertia(),
            Inertia {
                positive: 1,
                negative: 1,
                zero: 4
            }
        );
    }

    #[test]
    fn inertia_mixed_1x1_and_2x2() {
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        diag.set_1x1(2, 4.0);
        diag.set_1x1(3, -1.0);
        diag.set_1x1(4, 7.0);
        diag.set_1x1(5, 2.0);
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 4,
                negative: 2,
                zero: 0
            }
        );
    }

    #[test]
    fn scale_test_n_10000() {
        let n = 10_000;
        let mut diag = MixedDiagonal::new(n);

        // Alternate 2x2 blocks and 1x1 pivots across the whole diagonal.
        let mut col = 0;
        while col < n {
            if col + 1 < n && col % 3 != 2 {
                diag.set_2x2(Block2x2 {
                    first_col: col,
                    a: 2.0 + (col as f64) * 0.001,
                    b: 0.1,
                    c: 3.0 + (col as f64) * 0.001,
                });
                col += 2;
            } else {
                diag.set_1x1(col, 1.0 + (col as f64) * 0.001);
                col += 1;
            }
        }

        assert_eq!(diag.num_delayed(), 0);
        assert_eq!(diag.dimension(), n);

        let b: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
        let mut x = b.clone();
        diag.solve_in_place(&mut x);

        // Re-apply D to the solution to form the residual.
        let mut dx = vec![0.0; n];
        for i in 0..n {
            match diag.pivot_type(i) {
                PivotType::OneByOne => {
                    dx[i] = diag.diagonal_1x1(i) * x[i];
                }
                PivotType::TwoByTwo { partner } => {
                    if i < partner {
                        let block = diag.diagonal_2x2(i);
                        dx[i] = block.a * x[i] + block.b * x[partner];
                        dx[partner] = block.b * x[i] + block.c * x[partner];
                    }
                }
                PivotType::Delayed => unreachable!(),
            }
        }

        let rel_err = relative_error(&dx, &b);
        assert!(
            rel_err < 1e-14,
            "scale test: relative error {:.2e} exceeds 1e-14",
            rel_err
        );
    }
}