use crate::traits::{Real, State};
use super::Matrix;
/// Solves the 2×2 block linear system
///
/// ```text
/// [ A  B ] [ x ]   [ r ]
/// [ C  D ] [ y ] = [ s ]
/// ```
///
/// by block elimination on the Schur complement `S = D - C A⁻¹ B`:
///
/// 1. `S` is assembled densely, one column at a time, as `D e_j - C A⁻¹ (B e_j)`.
/// 2. `y` solves `S y = s - C A⁻¹ r`.
/// 3. `x` solves `A x = r - B y`.
///
/// Returns `(x, y)`.
///
/// All four blocks must report the same size `n` (their `n` field), and both
/// right-hand sides must have length `n`.
///
/// # Cost
///
/// `A` is solved `n + 2` times (once per column of `S`, once for `A⁻¹ r`, once for
/// `x`). NOTE(review): if `Matrix::lin_solve` refactorizes on every call, factoring
/// `A` once and reusing it would be cheaper — confirm against the `Matrix` API.
///
/// # Panics
///
/// - If `B`, `C`, `D`, `r` or `s` does not match the size of `A`.
/// - If solving with `A` or with the assembled Schur complement fails. Presumably
///   this happens for a singular matrix; the exact condition depends on
///   `Matrix::lin_solve`.
pub fn schur_complement<T: Real, V: State<T>>(
    a: &Matrix<T>,
    b: &Matrix<T>,
    c: &Matrix<T>,
    d: &Matrix<T>,
    r: V,
    s: V,
) -> (V, V) {
    let n = a.n;
    assert_eq!(b.n, n, "block size mismatch: B");
    assert_eq!(c.n, n, "block size mismatch: C");
    assert_eq!(d.n, n, "block size mismatch: D");
    assert_eq!(r.len(), n, "rhs r size mismatch");
    assert_eq!(s.len(), n, "rhs s size mismatch");

    // Every application of A⁻¹ goes through this closure, so every failure in it
    // reports the same, identifiable cause.
    let solve_a = |rhs: V| {
        a.lin_solve(rhs)
            .expect("schur_complement: solve with block A failed (is A singular?)")
    };

    // Assemble S = D - C A⁻¹ B column by column.
    let mut s_dense: Matrix<T> = Matrix::zeros(n, n);
    for j in 0..n {
        let mut e_j = V::zeros();
        e_j.set(j, T::one());
        // Column j of C A⁻¹ B.
        let c_ainv_b_col = c.mul_state(&solve_a(b.mul_state(&e_j)));
        // Column j of D, reusing the same unit vector rather than rebuilding it.
        let d_col = d.mul_state(&e_j);
        for i in 0..n {
            s_dense[(i, j)] = d_col.get(i) - c_ainv_b_col.get(i);
        }
    }

    // Reduced right-hand side for y: w = s - C A⁻¹ r.
    // NOTE(review): `r` is passed by value here and read again below. That relies
    // on `V` being `Copy` (presumably required by `State`).
    let c_ainv_r = c.mul_state(&solve_a(r));
    let mut w = V::zeros();
    for i in 0..n {
        w.set(i, s.get(i) - c_ainv_r.get(i));
    }
    let y = s_dense
        .lin_solve(w)
        .expect("schur_complement: solve with Schur complement D - C A^-1 B failed (is it singular?)");

    // Back-substitute for x: A x = r - B y.
    let b_y = b.mul_state(&y);
    let mut rhs_x = V::zeros();
    for i in 0..n {
        rhs_x.set(i, r.get(i) - b_y.get(i));
    }
    let x = solve_a(rhs_x);
    (x, y)
}
#[cfg(test)]
mod tests {
    use super::{Matrix, schur_complement};
    use nalgebra::Vector2;

    /// Asserts that two scalars agree to within an absolute tolerance of 1e-12.
    fn assert_close(got: f64, want: f64) {
        let diff = (got - want).abs();
        assert!(diff < 1e-12, "{} != {}", got, want);
    }

    /// Applies `assert_close` to the `x` component, then the `y` component, of a 2-vector.
    fn assert_vec2_close(got: &Vector2<f64>, want: &Vector2<f64>) {
        assert_close(got.x, want.x);
        assert_close(got.y, want.y);
    }

    #[test]
    fn schur_trivial_identity_blocks() {
        // With B = C = 0 the system decouples into A x = r and D y = s.
        let a: Matrix<f64> = Matrix::identity(2);
        let b: Matrix<f64> = Matrix::zeros(2, 2);
        let c: Matrix<f64> = Matrix::zeros(2, 2);
        let d: Matrix<f64> = Matrix::identity(2);

        let expected_x = Vector2::new(1.0, -2.0);
        let expected_y = Vector2::new(3.0, 4.0);

        let (x, y) = schur_complement(
            &a,
            &b,
            &c,
            &d,
            a.mul_state(&expected_x),
            d.mul_state(&expected_y),
        );

        assert_vec2_close(&x, &expected_x);
        assert_vec2_close(&y, &expected_y);
    }

    #[test]
    fn schur_mixed_blocks_small_dense() {
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![3.0, 1.0, 2.0, 4.0]);
        let d: Matrix<f64> = Matrix::from_vec(2, 2, vec![2.0, 0.5, 1.0, 3.0]);
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]);
        let c: Matrix<f64> = Matrix::from_vec(2, 2, vec![0.5, 0.0, 0.0, 0.5]);

        let expected_x = Vector2::new(1.0, -2.0);
        let expected_y = Vector2::new(3.0, 4.0);

        // Manufacture the right-hand side from the known solution:
        // r = A x + B y, s = C x + D y.
        let (a_x, b_y) = (a.mul_state(&expected_x), b.mul_state(&expected_y));
        let r = Vector2::new(a_x.x + b_y.x, a_x.y + b_y.y);
        let (c_x, d_y) = (c.mul_state(&expected_x), d.mul_state(&expected_y));
        let s = Vector2::new(c_x.x + d_y.x, c_x.y + d_y.y);

        let (x, y) = schur_complement(&a, &b, &c, &d, r, s);

        assert_vec2_close(&x, &expected_x);
        assert_vec2_close(&y, &expected_y);
    }
}