Skip to main content

rivrs_sparse/symmetric/
diagonal.rs

//! Mixed 1x1/2x2 block diagonal storage for the D factor in LDL^T.
//!
//! Provides [`MixedDiagonal`], the data structure representing the D factor in
//! P^T A P = L D L^T, where D contains a mix of 1x1 scalar pivots and 2x2
//! symmetric Bunch-Kaufman pivot blocks. Supports incremental construction,
//! queries, solves, and inertia computation.
//!
//! Reference: Hogg, Duff & Lopez (2020), "A New Sparse LDL^T Solver Using
//! A Posteriori Threshold Pivoting", SIAM J. Sci. Comput. 42(4), Section 3.
10
11use super::inertia::Inertia;
12use super::pivot::{Block2x2, PivotType};
13
14/// A single pivot entry from a [`MixedDiagonal`], yielded by [`MixedDiagonal::iter_pivots`].
15#[derive(Debug, Clone)]
16pub enum PivotEntry {
17    /// 1x1 pivot with scalar diagonal value.
18    OneByOne(f64),
19    /// 2x2 Bunch-Kaufman pivot block (yielded only at the lower-indexed column).
20    TwoByTwo(Block2x2),
21    /// Unresolved pivot.
22    Delayed,
23}
24
25/// The D factor in P^T A P = L D L^T with mixed 1x1 and 2x2 blocks.
26///
27/// Stores a block diagonal matrix where each block is either a 1x1 scalar
28/// pivot or a 2x2 symmetric Bunch-Kaufman pivot block. The structure is built
29/// incrementally during factorization (each column starts as [`PivotType::Delayed`]
30/// and is set to 1x1 or 2x2 as pivots are decided).
31///
32/// # Storage layout
33///
34/// Uses parallel arrays indexed by column for O(1) access during solve:
35/// - `pivot_map[col]`: pivot classification per column
36/// - `diag[col]`: diagonal value — for 1x1 pivots the scalar value, for 2x2
37///   blocks the `a` value at the owner column and `c` at the partner column
38/// - `off_diag[col]`: off-diagonal `b` for 2x2 blocks (at the owner column);
39///   0.0 at 1x1 and delayed columns
40///
41/// [`Block2x2`] serves as the API type for input ([`set_2x2`](Self::set_2x2))
42/// and output ([`diagonal_2x2`](Self::diagonal_2x2)) but is not the storage format.
43///
44/// # References
45///
46/// - Hogg, Duff & Lopez (2020), Section 3: mixed diagonal D storage in APTP
47/// - Bunch & Kaufman (1977): 2x2 pivot block structure
48#[derive(Debug)]
49pub struct MixedDiagonal {
50    pivot_map: Vec<PivotType>,
51    diag: Vec<f64>,
52    off_diag: Vec<f64>,
53    n: usize,
54}
55
56impl MixedDiagonal {
57    /// Create a new `MixedDiagonal` of dimension `n`.
58    ///
59    /// All columns start as [`PivotType::Delayed`] (unset).
60    pub fn new(n: usize) -> Self {
61        Self {
62            pivot_map: vec![PivotType::Delayed; n],
63            diag: vec![0.0; n],
64            off_diag: vec![0.0; n],
65            n,
66        }
67    }
68
69    /// Set column `col` as a 1x1 pivot with the given diagonal value.
70    ///
71    /// # Panics (debug only)
72    ///
73    /// - `col >= n` (bounds check)
74    /// - Column is not currently [`PivotType::Delayed`] (cannot overwrite a set pivot)
75    pub fn set_1x1(&mut self, col: usize, value: f64) {
76        debug_assert!(
77            col < self.n,
78            "set_1x1: col {} out of bounds (n = {})",
79            col,
80            self.n
81        );
82        debug_assert!(
83            self.pivot_map[col] == PivotType::Delayed,
84            "set_1x1: col {} is already set ({:?})",
85            col,
86            self.pivot_map[col]
87        );
88        self.pivot_map[col] = PivotType::OneByOne;
89        self.diag[col] = value;
90    }
91
92    /// Set a 2x2 pivot block starting at `block.first_col`.
93    ///
94    /// Marks both `first_col` and `first_col + 1` as [`PivotType::TwoByTwo`].
95    ///
96    /// # Panics (debug only)
97    ///
98    /// - `first_col + 1 >= n` (block must fit)
99    /// - Either column is not currently [`PivotType::Delayed`]
100    pub fn set_2x2(&mut self, block: Block2x2) {
101        let col = block.first_col;
102        debug_assert!(
103            col + 1 < self.n,
104            "set_2x2: first_col {} + 1 out of bounds (n = {})",
105            col,
106            self.n
107        );
108        debug_assert!(
109            self.pivot_map[col] == PivotType::Delayed,
110            "set_2x2: col {} is already set ({:?})",
111            col,
112            self.pivot_map[col]
113        );
114        debug_assert!(
115            self.pivot_map[col + 1] == PivotType::Delayed,
116            "set_2x2: col {} is already set ({:?})",
117            col + 1,
118            self.pivot_map[col + 1]
119        );
120        self.pivot_map[col] = PivotType::TwoByTwo { partner: col + 1 };
121        self.pivot_map[col + 1] = PivotType::TwoByTwo { partner: col };
122        self.diag[col] = block.a;
123        self.diag[col + 1] = block.c;
124        self.off_diag[col] = block.b;
125    }
126
127    /// Matrix dimension.
128    pub fn dimension(&self) -> usize {
129        self.n
130    }
131
132    /// Pivot type for the given column.
133    ///
134    /// # Panics (debug only)
135    ///
136    /// `col >= n`
137    pub fn pivot_type(&self, col: usize) -> PivotType {
138        debug_assert!(
139            col < self.n,
140            "pivot_type: col {} out of bounds (n = {})",
141            col,
142            self.n
143        );
144        self.pivot_map[col]
145    }
146
147    /// Diagonal value for a 1x1 pivot column.
148    ///
149    /// # Panics (debug only)
150    ///
151    /// Column is not [`PivotType::OneByOne`].
152    pub fn diagonal_1x1(&self, col: usize) -> f64 {
153        debug_assert!(
154            self.pivot_map[col] == PivotType::OneByOne,
155            "diagonal_1x1: col {} is not OneByOne ({:?})",
156            col,
157            self.pivot_map[col]
158        );
159        self.diag[col]
160    }
161
162    /// Block data for a 2x2 pivot (by the lower-indexed column).
163    ///
164    /// Constructs a [`Block2x2`] from the inline parallel arrays. This is O(1).
165    ///
166    /// # Panics (debug only)
167    ///
168    /// Column is not the owner (lower-indexed) of a [`PivotType::TwoByTwo`] block.
169    pub fn diagonal_2x2(&self, first_col: usize) -> Block2x2 {
170        debug_assert!(
171            matches!(self.pivot_map[first_col], PivotType::TwoByTwo { partner } if partner > first_col),
172            "diagonal_2x2: col {} is not a 2x2 block owner ({:?})",
173            first_col,
174            self.pivot_map[first_col]
175        );
176        Block2x2 {
177            first_col,
178            a: self.diag[first_col],
179            b: self.off_diag[first_col],
180            c: self.diag[first_col + 1],
181        }
182    }
183
184    /// Number of columns still marked as [`PivotType::Delayed`].
185    pub fn num_delayed(&self) -> usize {
186        self.pivot_map
187            .iter()
188            .filter(|p| **p == PivotType::Delayed)
189            .count()
190    }
191
192    /// Number of 1x1 pivots.
193    pub fn num_1x1(&self) -> usize {
194        self.pivot_map
195            .iter()
196            .filter(|p| **p == PivotType::OneByOne)
197            .count()
198    }
199
200    /// Extend to accommodate `new_n` entries. New entries are initialized as Delayed.
201    pub fn grow(&mut self, new_n: usize) {
202        if new_n > self.n {
203            self.pivot_map.resize(new_n, PivotType::Delayed);
204            self.diag.resize(new_n, 0.0);
205            self.off_diag.resize(new_n, 0.0);
206            self.n = new_n;
207        }
208    }
209
210    /// Shrink to `new_n` entries, discarding trailing entries.
211    pub fn truncate(&mut self, new_n: usize) {
212        debug_assert!(
213            new_n <= self.n,
214            "truncate: new_n {} > current n {}",
215            new_n,
216            self.n
217        );
218        self.pivot_map.truncate(new_n);
219        self.diag.truncate(new_n);
220        self.off_diag.truncate(new_n);
221        self.n = new_n;
222    }
223
224    /// Copy `count` pivot entries from `source` (starting at index 0) into `self`
225    /// (starting at `self_offset`).
226    ///
227    /// Handles 1x1, 2x2, and Delayed pivots. For 2x2 pivots, both columns of
228    /// the pair are copied as a unit.
229    ///
230    /// # Panics (debug only)
231    ///
232    /// - `self_offset + count > self.n`
233    /// - `count > source.n`
234    pub fn copy_from_offset(&mut self, source: &MixedDiagonal, self_offset: usize, count: usize) {
235        debug_assert!(
236            self_offset + count <= self.n,
237            "copy_from_offset: self_offset {} + count {} > self.n {}",
238            self_offset,
239            count,
240            self.n
241        );
242        debug_assert!(
243            count <= source.n,
244            "copy_from_offset: count {} > source.n {}",
245            count,
246            source.n
247        );
248
249        let mut col = 0;
250        while col < count {
251            match source.pivot_map[col] {
252                PivotType::OneByOne => {
253                    self.pivot_map[self_offset + col] = PivotType::OneByOne;
254                    self.diag[self_offset + col] = source.diag[col];
255                    col += 1;
256                }
257                PivotType::TwoByTwo { .. } if col + 1 < count => {
258                    let dest = self_offset + col;
259                    self.pivot_map[dest] = PivotType::TwoByTwo { partner: dest + 1 };
260                    self.pivot_map[dest + 1] = PivotType::TwoByTwo { partner: dest };
261                    self.diag[dest] = source.diag[col];
262                    self.diag[dest + 1] = source.diag[col + 1];
263                    self.off_diag[dest] = source.off_diag[col];
264                    col += 2;
265                }
266                PivotType::Delayed => {
267                    // Leave as Delayed (already the default)
268                    col += 1;
269                }
270                _ => {
271                    // Second column of a 2x2 pair, or 2x2 at boundary — skip
272                    col += 1;
273                }
274            }
275        }
276    }
277
278    /// Iterate over pivot entries, yielding `(col_index, PivotEntry)` pairs.
279    ///
280    /// Advances by 1 for 1x1 and Delayed pivots, by 2 for 2x2 pivots
281    /// (yielding the block only at the lower-indexed column and skipping the partner).
282    pub fn iter_pivots(&self) -> PivotIter<'_> {
283        PivotIter { d: self, col: 0 }
284    }
285
286    /// Number of 2x2 pivot pairs.
287    pub fn num_2x2_pairs(&self) -> usize {
288        self.pivot_map
289            .iter()
290            .enumerate()
291            .filter(|(i, p)| matches!(p, PivotType::TwoByTwo { partner } if *partner > *i))
292            .count()
293    }
294
295    /// Solve D x = b in place, where `x` initially contains the right-hand side b.
296    ///
297    /// For 1x1 pivots: `x[i] /= d[i]`.
298    /// For 2x2 blocks: solves `[[a, b], [b, c]] * [x1, x2]^T = [r1, r2]^T`
299    /// via Cramer's rule (analytical 2x2 inverse).
300    ///
301    /// # Panics (debug only)
302    ///
303    /// - Any column is still [`PivotType::Delayed`]
304    /// - Any 1x1 pivot value is zero
305    /// - Any 2x2 block has zero determinant
306    /// - `x.len() != n`
307    ///
308    /// # References
309    ///
310    /// - Cramer's rule for 2x2 symmetric systems
311    /// - Hogg, Duff & Lopez (2020), Section 3: D-solve in APTP context
312    pub fn solve_in_place(&self, x: &mut [f64]) {
313        debug_assert_eq!(
314            x.len(),
315            self.n,
316            "solve_in_place: x.len() = {} != n = {}",
317            x.len(),
318            self.n
319        );
320        debug_assert!(
321            self.num_delayed() == 0,
322            "solve_in_place: {} delayed columns remain",
323            self.num_delayed()
324        );
325
326        let mut col = 0;
327        while col < self.n {
328            match self.pivot_map[col] {
329                PivotType::OneByOne => {
330                    let d = self.diag[col];
331                    if d == 0.0 {
332                        // Zero pivot: set solution component to zero
333                        x[col] = 0.0;
334                    } else {
335                        x[col] /= d;
336                    }
337                    col += 1;
338                }
339                PivotType::TwoByTwo { partner } => {
340                    if partner > col {
341                        let a = self.diag[col];
342                        let b = self.off_diag[col];
343                        let c = self.diag[partner];
344                        let det = a * c - b * b;
345                        if det == 0.0 {
346                            // Zero-determinant 2x2 block: set both components to zero
347                            x[col] = 0.0;
348                            x[partner] = 0.0;
349                        } else {
350                            let r1 = x[col];
351                            let r2 = x[partner];
352                            // Cramer's rule: [[a,b],[b,c]]^-1 = (1/det) * [[c,-b],[-b,a]]
353                            x[col] = (c * r1 - b * r2) / det;
354                            x[partner] = (a * r2 - b * r1) / det;
355                        }
356                    }
357                    // Skip partner column if we've already processed this pair
358                    col += 1;
359                }
360                PivotType::Delayed => {
361                    unreachable!("solve_in_place: delayed column at {}", col);
362                }
363            }
364        }
365    }
366
367    /// Compute eigenvalue sign counts from stored pivots.
368    ///
369    /// For 1x1 pivots, the sign of the diagonal value determines the eigenvalue
370    /// sign. For 2x2 blocks, the trace and determinant classify the eigenvalue
371    /// signs without computing actual eigenvalues:
372    ///
373    /// | Condition | Eigenvalue signs |
374    /// |-----------|-----------------|
375    /// | det > 0, trace > 0 | both positive |
376    /// | det > 0, trace < 0 | both negative |
377    /// | det < 0 | one positive, one negative |
378    /// | det = 0, trace > 0 | one positive, one zero |
379    /// | det = 0, trace < 0 | one negative, one zero |
380    /// | det = 0, trace = 0 | both zero |
381    ///
382    /// # Panics (debug only)
383    ///
384    /// Any column is still [`PivotType::Delayed`].
385    ///
386    /// # References
387    ///
388    /// - Standard eigenvalue sign classification from trace/determinant
389    /// - Hogg, Duff & Lopez (2020), Section 2: inertia in APTP context
390    /// - Bunch & Kaufman (1977): inertia from pivot classifications
391    pub fn compute_inertia(&self) -> Inertia {
392        debug_assert!(
393            self.num_delayed() == 0,
394            "compute_inertia: {} delayed columns remain",
395            self.num_delayed()
396        );
397
398        let mut positive = 0usize;
399        let mut negative = 0usize;
400        let mut zero = 0usize;
401
402        let mut col = 0;
403        while col < self.n {
404            match self.pivot_map[col] {
405                PivotType::OneByOne => {
406                    let d = self.diag[col];
407                    if d > 0.0 {
408                        positive += 1;
409                    } else if d < 0.0 {
410                        negative += 1;
411                    } else {
412                        zero += 1;
413                    }
414                    col += 1;
415                }
416                PivotType::TwoByTwo { partner } => {
417                    if partner > col {
418                        // Read a, b, c directly from parallel arrays — O(1).
419                        let a = self.diag[col];
420                        let b = self.off_diag[col];
421                        let c = self.diag[partner];
422                        let det = a * c - b * b;
423                        let trace = a + c;
424
425                        if det > 0.0 {
426                            if trace > 0.0 {
427                                positive += 2;
428                            } else {
429                                // trace < 0 (trace == 0 impossible with det > 0 for real symmetric)
430                                negative += 2;
431                            }
432                        } else if det < 0.0 {
433                            positive += 1;
434                            negative += 1;
435                        } else {
436                            // det == 0
437                            if trace > 0.0 {
438                                positive += 1;
439                                zero += 1;
440                            } else if trace < 0.0 {
441                                negative += 1;
442                                zero += 1;
443                            } else {
444                                zero += 2;
445                            }
446                        }
447                    }
448                    col += 1;
449                }
450                PivotType::Delayed => {
451                    unreachable!("compute_inertia: delayed column at {}", col);
452                }
453            }
454        }
455
456        Inertia {
457            positive,
458            negative,
459            zero,
460        }
461    }
462}
463
464/// Iterator over pivot entries in a [`MixedDiagonal`].
465///
466/// Yields `(col_index, PivotEntry)` pairs, advancing by 2 for 2x2 pivots
467/// (yielding only at the lower-indexed column).
468pub struct PivotIter<'a> {
469    d: &'a MixedDiagonal,
470    col: usize,
471}
472
473impl<'a> Iterator for PivotIter<'a> {
474    type Item = (usize, PivotEntry);
475
476    fn next(&mut self) -> Option<Self::Item> {
477        if self.col >= self.d.n {
478            return None;
479        }
480        let col = self.col;
481        match self.d.pivot_map[col] {
482            PivotType::OneByOne => {
483                self.col += 1;
484                Some((col, PivotEntry::OneByOne(self.d.diag[col])))
485            }
486            PivotType::TwoByTwo { partner } if partner > col => {
487                self.col += 2;
488                Some((
489                    col,
490                    PivotEntry::TwoByTwo(Block2x2 {
491                        first_col: col,
492                        a: self.d.diag[col],
493                        b: self.d.off_diag[col],
494                        c: self.d.diag[col + 1],
495                    }),
496                ))
497            }
498            PivotType::TwoByTwo { .. } => {
499                // Second column of a 2x2 pair — skip
500                self.col += 1;
501                self.next()
502            }
503            PivotType::Delayed => {
504                self.col += 1;
505                Some((col, PivotEntry::Delayed))
506            }
507        }
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symmetric::pivot::{Block2x2, PivotType};

    // ---- shared helpers ----

    /// Relative residual `||dx - b||_2 / ||b||_2`.
    fn relative_residual(dx: &[f64], b: &[f64]) -> f64 {
        let norm_b = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_diff = dx
            .iter()
            .zip(b)
            .map(|(d, bi)| (d - bi).powi(2))
            .sum::<f64>()
            .sqrt();
        norm_diff / norm_b
    }

    /// Shorthand for a [`Block2x2`] literal.
    fn block(first_col: usize, a: f64, b: f64, c: f64) -> Block2x2 {
        Block2x2 { first_col, a, b, c }
    }

    /// Shorthand for an [`Inertia`] literal.
    fn inertia(positive: usize, negative: usize, zero: usize) -> Inertia {
        Inertia {
            positive,
            negative,
            zero,
        }
    }

    /// Dimension-6 fixture: 2x2 block [[2, 0.5], [0.5, -3]] at columns 0-1,
    /// followed by 1x1 pivots 4, -1, 7, 2 at columns 2-5.
    fn sample_mixed_6() -> MixedDiagonal {
        let mut d = MixedDiagonal::new(6);
        d.set_2x2(block(0, 2.0, 0.5, -3.0));
        for (col, value) in [(2, 4.0), (3, -1.0), (4, 7.0), (5, 2.0)] {
            d.set_1x1(col, value);
        }
        d
    }

    /// Inertia of a dimension-2 diagonal holding the single block [[a, b], [b, c]].
    fn inertia_of_single_block(a: f64, b: f64, c: f64) -> Inertia {
        let mut d = MixedDiagonal::new(2);
        d.set_2x2(block(0, a, b, c));
        d.compute_inertia()
    }

    // ---- MixedDiagonal construction and query ----

    #[test]
    fn new_creates_all_delayed() {
        let d = MixedDiagonal::new(5);
        assert_eq!(d.dimension(), 5);
        assert!((0..5).all(|col| d.pivot_type(col) == PivotType::Delayed));
        assert_eq!(d.num_delayed(), 5);
        assert_eq!(d.num_1x1(), 0);
        assert_eq!(d.num_2x2_pairs(), 0);
    }

    #[test]
    fn set_1x1_marks_correct_pivot_type() {
        let mut d = MixedDiagonal::new(4);
        d.set_1x1(0, 3.5);
        d.set_1x1(2, -1.0);

        let expected = [
            PivotType::OneByOne,
            PivotType::Delayed,
            PivotType::OneByOne,
            PivotType::Delayed,
        ];
        for (col, want) in expected.iter().enumerate() {
            assert_eq!(d.pivot_type(col), *want);
        }

        assert_eq!(d.diagonal_1x1(0), 3.5);
        assert_eq!(d.diagonal_1x1(2), -1.0);
        assert_eq!(d.num_1x1(), 2);
        assert_eq!(d.num_delayed(), 2);
    }

    #[test]
    fn set_2x2_marks_both_columns() {
        let mut d = MixedDiagonal::new(6);
        let blk = block(2, 2.0, 0.5, -3.0);
        d.set_2x2(blk);

        assert_eq!(d.pivot_type(2), PivotType::TwoByTwo { partner: 3 });
        assert_eq!(d.pivot_type(3), PivotType::TwoByTwo { partner: 2 });
        assert_eq!(d.diagonal_2x2(2), blk);
        assert_eq!(d.num_2x2_pairs(), 1);
        assert_eq!(d.num_delayed(), 4);
    }

    #[test]
    fn mixed_pivots_correct_counts() {
        let d = sample_mixed_6();
        assert_eq!(d.num_2x2_pairs(), 1);
        assert_eq!(d.num_1x1(), 4);
        assert_eq!(d.num_delayed(), 0);
        assert_eq!(d.dimension(), 6);
    }

    // ---- solve_in_place ----

    #[test]
    fn solve_all_1x1() {
        // D = diag(2, 4, -1, 5), b = [6, 12, -3, 20]  =>  x = [3, 3, 3, 4]
        let pivots = [2.0, 4.0, -1.0, 5.0];
        let mut d = MixedDiagonal::new(pivots.len());
        for (col, &p) in pivots.iter().enumerate() {
            d.set_1x1(col, p);
        }

        let b = vec![6.0, 12.0, -3.0, 20.0];
        let mut x = b.clone();
        d.solve_in_place(&mut x);
        assert_eq!(x, vec![3.0, 3.0, 3.0, 4.0]);

        let dx: Vec<f64> = pivots.iter().zip(&x).map(|(p, xi)| p * xi).collect();
        assert!(relative_residual(&dx, &b) < 1e-14);
    }

    #[test]
    fn solve_all_2x2() {
        // D = [[2, 0.5], [0.5, -3]], b = [4.5, -0.5], det = -6.25
        // => x = [2.12, 0.52]
        let mut d = MixedDiagonal::new(2);
        d.set_2x2(block(0, 2.0, 0.5, -3.0));

        let b = vec![4.5, -0.5];
        let mut x = b.clone();
        d.solve_in_place(&mut x);

        let dx = [2.0 * x[0] + 0.5 * x[1], 0.5 * x[0] - 3.0 * x[1]];
        let err = relative_residual(&dx, &b);
        assert!(err < 1e-14, "relative error: {:.2e}", err);
    }

    #[test]
    fn solve_mixed_1x1_and_2x2() {
        let d = sample_mixed_6();

        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut x = b.clone();
        d.solve_in_place(&mut x);

        // Rebuild D*x: 2x2 block on [0, 1], then 1x1 pivots on [2..6].
        let dx = [
            2.0 * x[0] + 0.5 * x[1],
            0.5 * x[0] - 3.0 * x[1],
            4.0 * x[2],
            -x[3],
            7.0 * x[4],
            2.0 * x[5],
        ];
        let err = relative_residual(&dx, &b);
        assert!(err < 1e-14, "relative error: {:.2e}", err);
    }

    #[test]
    fn solve_dimension_0_is_noop() {
        let d = MixedDiagonal::new(0);
        let mut x: Vec<f64> = Vec::new();
        d.solve_in_place(&mut x);
        assert!(x.is_empty());
    }

    // ---- Edge case tests ----

    #[test]
    fn dimension_0() {
        let d = MixedDiagonal::new(0);
        assert_eq!(d.dimension(), 0);
        assert_eq!(d.num_delayed(), 0);
        assert_eq!(d.num_1x1(), 0);
        assert_eq!(d.num_2x2_pairs(), 0);
    }

    #[test]
    fn dimension_1_single_1x1() {
        let mut d = MixedDiagonal::new(1);
        d.set_1x1(0, 5.0);
        assert_eq!(d.pivot_type(0), PivotType::OneByOne);
        assert_eq!(d.diagonal_1x1(0), 5.0);
        assert_eq!(d.num_1x1(), 1);
        assert_eq!(d.num_delayed(), 0);
    }

    #[test]
    fn dimension_2_single_2x2() {
        let mut d = MixedDiagonal::new(2);
        d.set_2x2(block(0, 1.0, 0.0, 1.0));
        assert_eq!(d.num_2x2_pairs(), 1);
        assert_eq!(d.num_delayed(), 0);
    }

    #[test]
    fn all_2x2_even_n() {
        let mut d = MixedDiagonal::new(4);
        d.set_2x2(block(0, 1.0, 0.0, 1.0));
        d.set_2x2(block(2, 2.0, 0.5, 3.0));
        assert_eq!(d.num_2x2_pairs(), 2);
        assert_eq!(d.num_delayed(), 0);
    }

    #[test]
    #[should_panic]
    fn solve_panics_on_delayed_columns() {
        let mut d = MixedDiagonal::new(3);
        d.set_1x1(0, 1.0);
        // Columns 1 and 2 are still delayed, so the solve must refuse.
        let mut x = vec![1.0, 2.0, 3.0];
        d.solve_in_place(&mut x);
    }

    #[test]
    #[should_panic]
    fn set_2x2_at_last_column_odd_n_panics() {
        // A block starting at the final column of an odd-sized diagonal does not fit.
        let mut d = MixedDiagonal::new(3);
        d.set_2x2(block(2, 1.0, 0.0, 1.0));
    }

    // ---- compute_inertia unit tests ----

    #[test]
    fn inertia_all_positive_1x1() {
        let mut d = MixedDiagonal::new(4);
        for col in 0..4 {
            d.set_1x1(col, (col + 1) as f64);
        }
        assert_eq!(d.compute_inertia(), inertia(4, 0, 0));
    }

    #[test]
    fn inertia_mixed_sign_1x1() {
        let mut d = MixedDiagonal::new(5);
        // Signs: +, -, +, -, zero
        for (col, value) in [(0, 3.0), (1, -2.0), (2, 1.0), (3, -0.5), (4, 0.0)] {
            d.set_1x1(col, value);
        }
        assert_eq!(d.compute_inertia(), inertia(2, 2, 1));
    }

    #[test]
    fn inertia_2x2_det_negative_one_plus_one_minus() {
        // det = -6.25 < 0 => one positive, one negative
        assert_eq!(inertia_of_single_block(2.0, 0.5, -3.0), inertia(1, 1, 0));
    }

    #[test]
    fn inertia_2x2_det_positive_trace_positive() {
        // det = 14 > 0, trace = 8 > 0 => both positive
        assert_eq!(inertia_of_single_block(5.0, 1.0, 3.0), inertia(2, 0, 0));
    }

    #[test]
    fn inertia_2x2_det_positive_trace_negative() {
        // det = 14 > 0, trace = -8 < 0 => both negative
        assert_eq!(inertia_of_single_block(-5.0, 1.0, -3.0), inertia(0, 2, 0));
    }

    #[test]
    fn inertia_mixed_1x1_and_2x2() {
        // 2x2 block contributes 1+ and 1-; the 1x1 pivots 4, -1, 7, 2 add 3+ and 1-.
        assert_eq!(sample_mixed_6().compute_inertia(), inertia(4, 2, 0));
    }

    #[test]
    fn scale_test_n_10000() {
        let n = 10_000;
        let mut d = MixedDiagonal::new(n);

        // Repeating pattern: 2x2 at (0,1), 1x1 at 2, 2x2 at (3,4), 1x1 at 5, ...
        let mut col = 0;
        while col < n {
            let shift = col as f64 * 0.001;
            let two_by_two = col % 3 != 2 && col + 1 < n;
            if two_by_two {
                d.set_2x2(block(col, 2.0 + shift, 0.1, 3.0 + shift));
                col += 2;
            } else {
                d.set_1x1(col, 1.0 + shift);
                col += 1;
            }
        }

        assert_eq!(d.num_delayed(), 0);
        assert_eq!(d.dimension(), n);

        // Round trip: b = [1, 2, ..., n], solve D*x = b, then rebuild D*x.
        let b: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let mut x = b.clone();
        d.solve_in_place(&mut x);

        let mut dx = vec![0.0; n];
        for i in 0..n {
            match d.pivot_type(i) {
                PivotType::OneByOne => dx[i] = d.diagonal_1x1(i) * x[i],
                PivotType::TwoByTwo { partner } if i < partner => {
                    let blk = d.diagonal_2x2(i);
                    dx[i] = blk.a * x[i] + blk.b * x[partner];
                    dx[partner] = blk.b * x[i] + blk.c * x[partner];
                }
                // Partner column: already filled in together with its owner.
                PivotType::TwoByTwo { .. } => {}
                PivotType::Delayed => unreachable!(),
            }
        }

        let rel_err = relative_residual(&dx, &b);
        assert!(
            rel_err < 1e-14,
            "scale test: relative error {:.2e} exceeds 1e-14",
            rel_err
        );
    }
}