use super::inertia::Inertia;
use super::pivot::{Block2x2, PivotType};
/// One pivot of a [`MixedDiagonal`], as yielded by [`MixedDiagonal::iter_pivots`].
///
/// `PartialEq` is derived so callers can compare yielded entries directly;
/// both payload types (`f64` and [`Block2x2`]) support equality.
#[derive(Debug, Clone, PartialEq)]
pub enum PivotEntry {
    /// A 1x1 pivot carrying its diagonal value.
    OneByOne(f64),
    /// A symmetric 2x2 block `[[a, b], [b, c]]` beginning at `first_col`.
    TwoByTwo(Block2x2),
    /// A column for which no pivot has been recorded.
    Delayed,
}
/// Block-diagonal matrix storage mixing 1x1 and 2x2 diagonal pivots.
///
/// Each column is tagged in `pivot_map`. A 1x1 pivot at column `i` keeps its
/// value in `diag[i]`. A 2x2 block starting at column `i` keeps `a` in
/// `diag[i]`, `c` in `diag[i + 1]` and the off-diagonal `b` in `off_diag[i]`.
#[derive(Debug)]
pub struct MixedDiagonal {
    /// Pivot kind of every column; `Delayed` until a pivot is recorded.
    pivot_map: Vec<PivotType>,
    /// Diagonal entries: 1x1 values, or the `a` / `c` entries of a 2x2 block.
    diag: Vec<f64>,
    /// Off-diagonal `b` of each 2x2 block, stored at the block's first column;
    /// the methods here never read the other entries.
    off_diag: Vec<f64>,
    /// Number of columns; intended to match the length of each vector above.
    n: usize,
}
impl MixedDiagonal {
    /// Creates storage for `n` columns, all marked [`PivotType::Delayed`] with
    /// zeroed values.
    pub fn new(n: usize) -> Self {
        Self {
            pivot_map: vec![PivotType::Delayed; n],
            diag: vec![0.0; n],
            off_diag: vec![0.0; n],
            n,
        }
    }

    /// Records a 1x1 pivot with value `value` at column `col`.
    ///
    /// # Panics
    ///
    /// Panics if `col >= self.dimension()`. Debug builds also panic if `col`
    /// already holds a pivot.
    pub fn set_1x1(&mut self, col: usize, value: f64) {
        debug_assert!(
            col < self.n,
            "set_1x1: col {} out of bounds (n = {})",
            col,
            self.n
        );
        debug_assert!(
            self.pivot_map[col] == PivotType::Delayed,
            "set_1x1: col {} is already set ({:?})",
            col,
            self.pivot_map[col]
        );
        self.pivot_map[col] = PivotType::OneByOne;
        self.diag[col] = value;
    }

    /// Records the symmetric 2x2 block `[[a, b], [b, c]]` occupying columns
    /// `block.first_col` and `block.first_col + 1`.
    ///
    /// The two columns are linked to each other through their `partner`.
    ///
    /// # Panics
    ///
    /// Panics if `block.first_col + 1 >= self.dimension()`. Debug builds also
    /// panic if either column already holds a pivot.
    pub fn set_2x2(&mut self, block: Block2x2) {
        let col = block.first_col;
        debug_assert!(
            col + 1 < self.n,
            "set_2x2: first_col {} + 1 out of bounds (n = {})",
            col,
            self.n
        );
        debug_assert!(
            self.pivot_map[col] == PivotType::Delayed,
            "set_2x2: col {} is already set ({:?})",
            col,
            self.pivot_map[col]
        );
        debug_assert!(
            self.pivot_map[col + 1] == PivotType::Delayed,
            "set_2x2: col {} is already set ({:?})",
            col + 1,
            self.pivot_map[col + 1]
        );
        self.pivot_map[col] = PivotType::TwoByTwo { partner: col + 1 };
        self.pivot_map[col + 1] = PivotType::TwoByTwo { partner: col };
        self.diag[col] = block.a;
        self.diag[col + 1] = block.c;
        // Only the owning (first) column stores the off-diagonal entry.
        self.off_diag[col] = block.b;
    }

    /// Number of columns currently stored.
    pub fn dimension(&self) -> usize {
        self.n
    }

    /// Pivot kind recorded for column `col`.
    ///
    /// # Panics
    ///
    /// Panics if `col >= self.dimension()`.
    pub fn pivot_type(&self, col: usize) -> PivotType {
        debug_assert!(
            col < self.n,
            "pivot_type: col {} out of bounds (n = {})",
            col,
            self.n
        );
        self.pivot_map[col]
    }

    /// Value of the 1x1 pivot at `col`.
    ///
    /// Debug builds assert that `col` really is a 1x1 pivot; release builds
    /// return the stored diagonal entry regardless.
    ///
    /// # Panics
    ///
    /// Panics if `col >= self.dimension()`.
    pub fn diagonal_1x1(&self, col: usize) -> f64 {
        debug_assert!(
            self.pivot_map[col] == PivotType::OneByOne,
            "diagonal_1x1: col {} is not OneByOne ({:?})",
            col,
            self.pivot_map[col]
        );
        self.diag[col]
    }

    /// The 2x2 block whose first column is `first_col`.
    ///
    /// Debug builds assert that `first_col` owns a 2x2 block; release builds
    /// read the stored entries regardless.
    ///
    /// # Panics
    ///
    /// Panics if `first_col + 1 >= self.dimension()`.
    pub fn diagonal_2x2(&self, first_col: usize) -> Block2x2 {
        debug_assert!(
            matches!(self.pivot_map[first_col], PivotType::TwoByTwo { partner } if partner > first_col),
            "diagonal_2x2: col {} is not a 2x2 block owner ({:?})",
            first_col,
            self.pivot_map[first_col]
        );
        Block2x2 {
            first_col,
            a: self.diag[first_col],
            b: self.off_diag[first_col],
            c: self.diag[first_col + 1],
        }
    }

    /// Number of columns still marked `Delayed` (a linear scan).
    pub fn num_delayed(&self) -> usize {
        self.pivot_map
            .iter()
            .filter(|p| **p == PivotType::Delayed)
            .count()
    }

    /// Number of 1x1 pivots (a linear scan).
    pub fn num_1x1(&self) -> usize {
        self.pivot_map
            .iter()
            .filter(|p| **p == PivotType::OneByOne)
            .count()
    }

    /// Extends storage to `new_n` columns; the new columns are `Delayed` with
    /// zeroed values. Does nothing if `new_n <= self.dimension()`.
    pub fn grow(&mut self, new_n: usize) {
        if new_n > self.n {
            self.pivot_map.resize(new_n, PivotType::Delayed);
            self.diag.resize(new_n, 0.0);
            self.off_diag.resize(new_n, 0.0);
            self.n = new_n;
        }
    }

    /// Drops every column at index `new_n` and beyond.
    ///
    /// If the cut falls inside a 2x2 pair, the surviving first column is reset
    /// to `Delayed`, because its partner no longer exists. This matches how
    /// [`copy_from_offset`](Self::copy_from_offset) drops a pair split by its
    /// `count`. A `new_n` larger than the current dimension is a logic error:
    /// debug builds panic, and release builds leave the storage unchanged.
    pub fn truncate(&mut self, new_n: usize) {
        debug_assert!(
            new_n <= self.n,
            "truncate: new_n {} > current n {}",
            new_n,
            self.n
        );
        // Never grow here: with the assertion compiled out, setting `n` past the
        // vector lengths would make every later index operation panic.
        let new_n = new_n.min(self.n);
        self.pivot_map.truncate(new_n);
        self.diag.truncate(new_n);
        self.off_diag.truncate(new_n);
        self.n = new_n;

        // A pair owner whose partner was cut off would send the solver, the
        // inertia count and the iterator out of bounds.
        let orphan = new_n.checked_sub(1).filter(|&last| {
            matches!(self.pivot_map[last], PivotType::TwoByTwo { partner } if partner >= new_n)
        });
        if let Some(last) = orphan {
            self.pivot_map[last] = PivotType::Delayed;
            self.diag[last] = 0.0;
            self.off_diag[last] = 0.0;
        }
    }

    /// Copies pivots from columns `0..count` of `source` into columns
    /// `self_offset..self_offset + count` of `self`.
    ///
    /// A 2x2 pair is copied only when both of its columns lie inside
    /// `0..count`. `Delayed` source columns, pairs cut off by `count`, and
    /// trailing pair halves whose owner was not visited leave the destination
    /// column untouched.
    ///
    /// # Panics
    ///
    /// Panics if `self_offset + count > self.dimension()` or
    /// `count > source.dimension()`.
    pub fn copy_from_offset(&mut self, source: &MixedDiagonal, self_offset: usize, count: usize) {
        debug_assert!(
            self_offset + count <= self.n,
            "copy_from_offset: self_offset {} + count {} > self.n {}",
            self_offset,
            count,
            self.n
        );
        debug_assert!(
            count <= source.n,
            "copy_from_offset: count {} > source.n {}",
            count,
            source.n
        );
        let mut col = 0;
        while col < count {
            match source.pivot_map[col] {
                PivotType::OneByOne => {
                    self.pivot_map[self_offset + col] = PivotType::OneByOne;
                    self.diag[self_offset + col] = source.diag[col];
                    col += 1;
                }
                PivotType::TwoByTwo { partner } if partner > col && col + 1 < count => {
                    let dest = self_offset + col;
                    self.pivot_map[dest] = PivotType::TwoByTwo { partner: dest + 1 };
                    self.pivot_map[dest + 1] = PivotType::TwoByTwo { partner: dest };
                    self.diag[dest] = source.diag[col];
                    self.diag[dest + 1] = source.diag[col + 1];
                    self.off_diag[dest] = source.off_diag[col];
                    col += 2;
                }
                // Nothing to copy: a delayed column, a pair cut off by `count`,
                // or a trailing half without its owner.
                PivotType::TwoByTwo { .. } | PivotType::Delayed => {
                    col += 1;
                }
            }
        }
    }

    /// Iterates over the pivots in column order. Each 2x2 block is reported
    /// once, at its first column.
    pub fn iter_pivots(&self) -> PivotIter<'_> {
        PivotIter { d: self, col: 0 }
    }

    /// Number of 2x2 blocks, counted through their owning (first) column.
    pub fn num_2x2_pairs(&self) -> usize {
        self.pivot_map
            .iter()
            .enumerate()
            .filter(|(i, p)| matches!(p, PivotType::TwoByTwo { partner } if *partner > *i))
            .count()
    }

    /// Overwrites `x` with the solution `y` of `D y = x`.
    ///
    /// A zero 1x1 pivot, or a 2x2 block whose determinant is exactly zero,
    /// sets the affected entries to `0.0` instead of dividing.
    ///
    /// # Panics
    ///
    /// Panics if any column is still `Delayed`. Debug builds also assert
    /// `x.len() == self.dimension()`; release builds panic only if `x` is too
    /// short.
    pub fn solve_in_place(&self, x: &mut [f64]) {
        debug_assert_eq!(
            x.len(),
            self.n,
            "solve_in_place: x.len() = {} != n = {}",
            x.len(),
            self.n
        );
        debug_assert!(
            self.num_delayed() == 0,
            "solve_in_place: {} delayed columns remain",
            self.num_delayed()
        );
        let mut col = 0;
        while col < self.n {
            match self.pivot_map[col] {
                PivotType::OneByOne => {
                    let d = self.diag[col];
                    if d == 0.0 {
                        x[col] = 0.0;
                    } else {
                        x[col] /= d;
                    }
                    col += 1;
                }
                PivotType::TwoByTwo { partner } => {
                    // Only the owning column solves; the trailing half is skipped.
                    if partner > col {
                        let a = self.diag[col];
                        let b = self.off_diag[col];
                        let c = self.diag[partner];
                        let det = a * c - b * b;
                        if det == 0.0 {
                            x[col] = 0.0;
                            x[partner] = 0.0;
                        } else {
                            // Explicit inverse of [[a, b], [b, c]].
                            let r1 = x[col];
                            let r2 = x[partner];
                            x[col] = (c * r1 - b * r2) / det;
                            x[partner] = (a * r2 - b * r1) / det;
                        }
                    }
                    col += 1;
                }
                PivotType::Delayed => {
                    unreachable!("solve_in_place: delayed column at {}", col);
                }
            }
        }
    }

    /// Counts the positive, negative and zero eigenvalues of the block diagonal.
    ///
    /// A 2x2 block is classified by the signs of its determinant and trace.
    /// All comparisons are exact against `0.0`, with no tolerance.
    ///
    /// # Panics
    ///
    /// Panics if any column is still `Delayed`.
    pub fn compute_inertia(&self) -> Inertia {
        debug_assert!(
            self.num_delayed() == 0,
            "compute_inertia: {} delayed columns remain",
            self.num_delayed()
        );
        let mut positive = 0usize;
        let mut negative = 0usize;
        let mut zero = 0usize;
        let mut col = 0;
        while col < self.n {
            match self.pivot_map[col] {
                PivotType::OneByOne => {
                    let d = self.diag[col];
                    if d > 0.0 {
                        positive += 1;
                    } else if d < 0.0 {
                        negative += 1;
                    } else {
                        zero += 1;
                    }
                    col += 1;
                }
                PivotType::TwoByTwo { partner } => {
                    if partner > col {
                        let a = self.diag[col];
                        let b = self.off_diag[col];
                        let c = self.diag[partner];
                        let det = a * c - b * b;
                        let trace = a + c;
                        if det > 0.0 {
                            // Both eigenvalues share the sign of the trace.
                            if trace > 0.0 {
                                positive += 2;
                            } else {
                                negative += 2;
                            }
                        } else if det < 0.0 {
                            positive += 1;
                            negative += 1;
                        } else {
                            // One eigenvalue is zero; the other equals the trace.
                            if trace > 0.0 {
                                positive += 1;
                                zero += 1;
                            } else if trace < 0.0 {
                                negative += 1;
                                zero += 1;
                            } else {
                                zero += 2;
                            }
                        }
                    }
                    col += 1;
                }
                PivotType::Delayed => {
                    unreachable!("compute_inertia: delayed column at {}", col);
                }
            }
        }
        Inertia {
            positive,
            negative,
            zero,
        }
    }
}
/// Iterator over the pivots of a [`MixedDiagonal`] in column order, created by
/// [`MixedDiagonal::iter_pivots`].
///
/// `Debug` and `Clone` are derived so the iterator can be inspected and forked,
/// like standard-library iterators.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct PivotIter<'a> {
    /// The diagonal being walked.
    d: &'a MixedDiagonal,
    /// Next column to inspect.
    col: usize,
}
impl<'a> Iterator for PivotIter<'a> {
    type Item = (usize, PivotEntry);

    /// Produces the next `(starting column, entry)` pair.
    ///
    /// A 2x2 block is reported once, at its first column, after which the
    /// cursor jumps past both columns. A `TwoByTwo` column whose partner comes
    /// before it yields nothing and is stepped over. Delayed columns are
    /// reported one by one.
    fn next(&mut self) -> Option<Self::Item> {
        let owner = self.d;
        loop {
            let start = self.col;
            if start >= owner.n {
                return None;
            }
            let entry = match owner.pivot_map[start] {
                PivotType::OneByOne => {
                    self.col = start + 1;
                    PivotEntry::OneByOne(owner.diag[start])
                }
                PivotType::TwoByTwo { partner } if partner > start => {
                    self.col = start + 2;
                    PivotEntry::TwoByTwo(Block2x2 {
                        first_col: start,
                        a: owner.diag[start],
                        b: owner.off_diag[start],
                        c: owner.diag[start + 1],
                    })
                }
                PivotType::TwoByTwo { .. } => {
                    // Second half of a pair: nothing to report here.
                    self.col = start + 1;
                    continue;
                }
                PivotType::Delayed => {
                    self.col = start + 1;
                    PivotEntry::Delayed
                }
            };
            return Some((start, entry));
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MixedDiagonal`: construction, pivot bookkeeping,
    //! solves, inertia, iteration, resizing and offset copies.
    use super::*;
    use crate::symmetric::pivot::{Block2x2, PivotType};

    #[test]
    fn new_creates_all_delayed() {
        let diag = MixedDiagonal::new(5);
        assert_eq!(diag.dimension(), 5);
        for col in 0..5 {
            assert_eq!(diag.pivot_type(col), PivotType::Delayed);
        }
        assert_eq!(diag.num_delayed(), 5);
        assert_eq!(diag.num_1x1(), 0);
        assert_eq!(diag.num_2x2_pairs(), 0);
    }

    #[test]
    fn set_1x1_marks_correct_pivot_type() {
        let mut diag = MixedDiagonal::new(4);
        diag.set_1x1(0, 3.5);
        diag.set_1x1(2, -1.0);
        assert_eq!(diag.pivot_type(0), PivotType::OneByOne);
        assert_eq!(diag.pivot_type(1), PivotType::Delayed);
        assert_eq!(diag.pivot_type(2), PivotType::OneByOne);
        assert_eq!(diag.pivot_type(3), PivotType::Delayed);
        assert_eq!(diag.diagonal_1x1(0), 3.5);
        assert_eq!(diag.diagonal_1x1(2), -1.0);
        assert_eq!(diag.num_1x1(), 2);
        assert_eq!(diag.num_delayed(), 2);
    }

    #[test]
    fn set_2x2_marks_both_columns() {
        let mut diag = MixedDiagonal::new(6);
        let block = Block2x2 {
            first_col: 2,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        };
        diag.set_2x2(block);
        assert_eq!(diag.pivot_type(2), PivotType::TwoByTwo { partner: 3 });
        assert_eq!(diag.pivot_type(3), PivotType::TwoByTwo { partner: 2 });
        assert_eq!(diag.diagonal_2x2(2), block);
        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_delayed(), 4);
    }

    #[test]
    fn mixed_pivots_correct_counts() {
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        diag.set_1x1(2, 4.0);
        diag.set_1x1(3, -1.0);
        diag.set_1x1(4, 7.0);
        diag.set_1x1(5, 2.0);
        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_1x1(), 4);
        assert_eq!(diag.num_delayed(), 0);
        assert_eq!(diag.dimension(), 6);
    }

    #[test]
    fn solve_all_1x1() {
        let mut diag = MixedDiagonal::new(4);
        diag.set_1x1(0, 2.0);
        diag.set_1x1(1, 4.0);
        diag.set_1x1(2, -1.0);
        diag.set_1x1(3, 5.0);
        let mut x = vec![6.0, 12.0, -3.0, 20.0];
        let b = x.clone();
        diag.solve_in_place(&mut x);
        assert_eq!(x, vec![3.0, 3.0, 3.0, 4.0]);
        let dx: Vec<f64> = vec![2.0 * x[0], 4.0 * x[1], -x[2], 5.0 * x[3]];
        let norm_b: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_diff: f64 = dx
            .iter()
            .zip(b.iter())
            .map(|(d, bi)| (d - bi).powi(2))
            .sum::<f64>()
            .sqrt();
        assert!(norm_diff / norm_b < 1e-14);
    }

    #[test]
    fn solve_all_2x2() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        let b = vec![4.5, -0.5];
        let mut x = b.clone();
        diag.solve_in_place(&mut x);
        let dx0 = 2.0 * x[0] + 0.5 * x[1];
        let dx1 = 0.5 * x[0] + (-3.0) * x[1];
        let norm_b: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_diff = ((dx0 - b[0]).powi(2) + (dx1 - b[1]).powi(2)).sqrt();
        assert!(
            norm_diff / norm_b < 1e-14,
            "relative error: {:.2e}",
            norm_diff / norm_b
        );
    }

    #[test]
    fn solve_mixed_1x1_and_2x2() {
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        diag.set_1x1(2, 4.0);
        diag.set_1x1(3, -1.0);
        diag.set_1x1(4, 7.0);
        diag.set_1x1(5, 2.0);
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut x = b.clone();
        diag.solve_in_place(&mut x);
        let dx0 = 2.0 * x[0] + 0.5 * x[1];
        let dx1 = 0.5 * x[0] + (-3.0) * x[1];
        let dx2 = 4.0 * x[2];
        let dx3 = -x[3];
        let dx4 = 7.0 * x[4];
        let dx5 = 2.0 * x[5];
        let dx = [dx0, dx1, dx2, dx3, dx4, dx5];
        let norm_b: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_diff: f64 = dx
            .iter()
            .zip(b.iter())
            .map(|(d, bi)| (d - bi).powi(2))
            .sum::<f64>()
            .sqrt();
        assert!(
            norm_diff / norm_b < 1e-14,
            "relative error: {:.2e}",
            norm_diff / norm_b
        );
    }

    #[test]
    fn solve_dimension_0_is_noop() {
        let diag = MixedDiagonal::new(0);
        let mut x: Vec<f64> = vec![];
        diag.solve_in_place(&mut x);
        assert!(x.is_empty());
    }

    #[test]
    fn solve_zero_pivots_yield_zero() {
        // A zero 1x1 pivot and a singular 2x2 block both map their entries to 0.
        let mut diag = MixedDiagonal::new(3);
        diag.set_1x1(0, 0.0);
        diag.set_2x2(Block2x2 {
            first_col: 1,
            a: 1.0,
            b: 1.0,
            c: 1.0,
        });
        let mut x = vec![5.0, 1.0, 2.0];
        diag.solve_in_place(&mut x);
        assert_eq!(x, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn dimension_0() {
        let diag = MixedDiagonal::new(0);
        assert_eq!(diag.dimension(), 0);
        assert_eq!(diag.num_delayed(), 0);
        assert_eq!(diag.num_1x1(), 0);
        assert_eq!(diag.num_2x2_pairs(), 0);
    }

    #[test]
    fn dimension_1_single_1x1() {
        let mut diag = MixedDiagonal::new(1);
        diag.set_1x1(0, 5.0);
        assert_eq!(diag.pivot_type(0), PivotType::OneByOne);
        assert_eq!(diag.diagonal_1x1(0), 5.0);
        assert_eq!(diag.num_1x1(), 1);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    fn dimension_2_single_2x2() {
        let mut diag = MixedDiagonal::new(2);
        let block = Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        };
        diag.set_2x2(block);
        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    fn all_2x2_even_n() {
        let mut diag = MixedDiagonal::new(4);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        });
        diag.set_2x2(Block2x2 {
            first_col: 2,
            a: 2.0,
            b: 0.5,
            c: 3.0,
        });
        assert_eq!(diag.num_2x2_pairs(), 2);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    #[should_panic]
    fn solve_panics_on_delayed_columns() {
        let mut diag = MixedDiagonal::new(3);
        diag.set_1x1(0, 1.0);
        let mut x = vec![1.0, 2.0, 3.0];
        diag.solve_in_place(&mut x);
    }

    #[test]
    #[should_panic]
    fn set_2x2_at_last_column_odd_n_panics() {
        let mut diag = MixedDiagonal::new(3);
        diag.set_2x2(Block2x2 {
            first_col: 2,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        });
    }

    #[test]
    fn inertia_all_positive_1x1() {
        let mut diag = MixedDiagonal::new(4);
        for i in 0..4 {
            diag.set_1x1(i, (i + 1) as f64);
        }
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 4,
                negative: 0,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_mixed_sign_1x1() {
        let mut diag = MixedDiagonal::new(5);
        diag.set_1x1(0, 3.0);
        diag.set_1x1(1, -2.0);
        diag.set_1x1(2, 1.0);
        diag.set_1x1(3, -0.5);
        diag.set_1x1(4, 0.0);
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 2,
                negative: 2,
                zero: 1
            }
        );
    }

    #[test]
    fn inertia_2x2_det_negative_one_plus_one_minus() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 1,
                negative: 1,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_2x2_det_positive_trace_positive() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 5.0,
            b: 1.0,
            c: 3.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 2,
                negative: 0,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_2x2_det_positive_trace_negative() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: -5.0,
            b: 1.0,
            c: -3.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 0,
                negative: 2,
                zero: 0
            }
        );
    }

    #[test]
    fn inertia_2x2_singular_blocks() {
        // det == 0: one zero eigenvalue, the other takes the sign of the trace.
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 1.0,
            c: 1.0,
        });
        diag.set_2x2(Block2x2 {
            first_col: 2,
            a: -1.0,
            b: 1.0,
            c: -1.0,
        });
        diag.set_2x2(Block2x2 {
            first_col: 4,
            a: 0.0,
            b: 0.0,
            c: 0.0,
        });
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 1,
                negative: 1,
                zero: 4
            }
        );
    }

    #[test]
    fn inertia_mixed_1x1_and_2x2() {
        let mut diag = MixedDiagonal::new(6);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        diag.set_1x1(2, 4.0);
        diag.set_1x1(3, -1.0);
        diag.set_1x1(4, 7.0);
        diag.set_1x1(5, 2.0);
        let inertia = diag.compute_inertia();
        assert_eq!(
            inertia,
            Inertia {
                positive: 4,
                negative: 2,
                zero: 0
            }
        );
    }

    #[test]
    fn iter_pivots_yields_each_pivot_once() {
        let mut diag = MixedDiagonal::new(5);
        diag.set_1x1(0, 1.5);
        let block = Block2x2 {
            first_col: 1,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        };
        diag.set_2x2(block);
        // Column 3 stays delayed.
        diag.set_1x1(4, -2.0);
        let entries: Vec<(usize, PivotEntry)> = diag.iter_pivots().collect();
        assert_eq!(entries.len(), 4);
        match &entries[0] {
            (0, PivotEntry::OneByOne(v)) => assert_eq!(*v, 1.5),
            other => panic!("unexpected entry {:?}", other),
        }
        match &entries[1] {
            (1, PivotEntry::TwoByTwo(got)) => assert_eq!(*got, block),
            other => panic!("unexpected entry {:?}", other),
        }
        assert!(matches!(entries[2], (3, PivotEntry::Delayed)));
        match &entries[3] {
            (4, PivotEntry::OneByOne(v)) => assert_eq!(*v, -2.0),
            other => panic!("unexpected entry {:?}", other),
        }
    }

    #[test]
    fn iter_pivots_empty() {
        let diag = MixedDiagonal::new(0);
        assert_eq!(diag.iter_pivots().count(), 0);
    }

    #[test]
    fn grow_appends_delayed_columns() {
        let mut diag = MixedDiagonal::new(2);
        diag.set_1x1(0, 1.0);
        diag.set_1x1(1, 2.0);
        diag.grow(4);
        assert_eq!(diag.dimension(), 4);
        assert_eq!(diag.num_1x1(), 2);
        assert_eq!(diag.num_delayed(), 2);
        assert_eq!(diag.pivot_type(3), PivotType::Delayed);
        // Requests to shrink are ignored.
        diag.grow(1);
        assert_eq!(diag.dimension(), 4);
    }

    #[test]
    fn truncate_drops_trailing_columns() {
        let mut diag = MixedDiagonal::new(5);
        diag.set_2x2(Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        });
        diag.set_1x1(2, 3.0);
        diag.truncate(3);
        assert_eq!(diag.dimension(), 3);
        assert_eq!(diag.num_2x2_pairs(), 1);
        assert_eq!(diag.num_1x1(), 1);
        assert_eq!(diag.num_delayed(), 0);
    }

    #[test]
    fn copy_from_offset_places_pivots_at_offset() {
        let mut src = MixedDiagonal::new(3);
        src.set_2x2(Block2x2 {
            first_col: 0,
            a: 2.0,
            b: 0.5,
            c: -3.0,
        });
        src.set_1x1(2, 4.0);
        let mut dst = MixedDiagonal::new(5);
        dst.copy_from_offset(&src, 2, 3);
        assert_eq!(dst.pivot_type(0), PivotType::Delayed);
        assert_eq!(dst.pivot_type(1), PivotType::Delayed);
        assert_eq!(dst.pivot_type(2), PivotType::TwoByTwo { partner: 3 });
        assert_eq!(dst.pivot_type(3), PivotType::TwoByTwo { partner: 2 });
        assert_eq!(
            dst.diagonal_2x2(2),
            Block2x2 {
                first_col: 2,
                a: 2.0,
                b: 0.5,
                c: -3.0
            }
        );
        assert_eq!(dst.diagonal_1x1(4), 4.0);
        assert_eq!(dst.num_delayed(), 2);
    }

    #[test]
    fn copy_from_offset_skips_pair_split_by_count() {
        let mut src = MixedDiagonal::new(2);
        src.set_2x2(Block2x2 {
            first_col: 0,
            a: 1.0,
            b: 0.0,
            c: 1.0,
        });
        let mut dst = MixedDiagonal::new(1);
        dst.copy_from_offset(&src, 0, 1);
        assert_eq!(dst.pivot_type(0), PivotType::Delayed);
    }

    #[test]
    fn scale_test_n_10000() {
        let n = 10_000;
        let mut diag = MixedDiagonal::new(n);
        let mut col = 0;
        while col < n {
            if col + 1 < n && col % 3 != 2 {
                diag.set_2x2(Block2x2 {
                    first_col: col,
                    a: 2.0 + (col as f64) * 0.001,
                    b: 0.1,
                    c: 3.0 + (col as f64) * 0.001,
                });
                col += 2;
            } else {
                diag.set_1x1(col, 1.0 + (col as f64) * 0.001);
                col += 1;
            }
        }
        assert_eq!(diag.num_delayed(), 0);
        assert_eq!(diag.dimension(), n);
        let b: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
        let mut x = b.clone();
        diag.solve_in_place(&mut x);
        let mut dx = vec![0.0; n];
        for i in 0..n {
            match diag.pivot_type(i) {
                PivotType::OneByOne => {
                    dx[i] = diag.diagonal_1x1(i) * x[i];
                }
                PivotType::TwoByTwo { partner } => {
                    if i < partner {
                        let block = diag.diagonal_2x2(i);
                        dx[i] = block.a * x[i] + block.b * x[partner];
                        dx[partner] = block.b * x[i] + block.c * x[partner];
                    }
                }
                PivotType::Delayed => unreachable!(),
            }
        }
        let norm_b: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_diff: f64 = dx
            .iter()
            .zip(b.iter())
            .map(|(d, bi)| (d - bi).powi(2))
            .sum::<f64>()
            .sqrt();
        let rel_err = norm_diff / norm_b;
        assert!(
            rel_err < 1e-14,
            "scale test: relative error {:.2e} exceeds 1e-14",
            rel_err
        );
    }
}