// differential-equations 0.6.0
//
// A Rust library for solving differential equations.
// Documentation
//! Schur complement helpers for block systems used in IRK solvers.

use crate::traits::{Real, State};

use super::Matrix;

/// Solve the 2x2 block linear system `[A B; C D] [x; y] = [r; s]` using the
/// (explicit) Schur complement of `A`.
///
/// All four blocks must share the same dimension `n` (taken from `a.n`), and
/// both right-hand sides must have length `n`. The solution is computed as
///
/// ```text
/// S = D - C A^{-1} B        (Schur complement of A)
/// S y = s - C A^{-1} r
/// A x = r - B y
/// ```
///
/// Returns `(x, y)`.
///
/// Notes:
/// - This forms the dense Schur complement `S` explicitly, one column per basis
///   vector `e_j`. In total it performs `n + 2` solves with `A` (`n` for the
///   columns of `S`, one for `A^{-1} r`, one for the back-substitution) and one
///   solve with `S`. For small per-stage blocks (common in IRK), this is
///   acceptable and simple.
/// - For larger blocks, prefer an operator-based approach that applies `S`
///   without forming it.
/// - A `0 x 0` system returns the (empty) right-hand sides unchanged, without
///   calling any solver.
///
/// # Errors
///
/// Returns the first [`crate::linalg::LinalgError`] produced by `lin_solve`,
/// either on `A` (column construction, `A^{-1} r`, or back-substitution) or on
/// the assembled Schur complement `S`.
///
/// # Panics
///
/// Panics if `b.n`, `c.n` or `d.n` differs from `a.n`, or if `r.len()` or
/// `s.len()` differs from `a.n`.
pub fn schur_complement<T: Real, V: State<T>>(
    a: &Matrix<T>,
    b: &Matrix<T>,
    c: &Matrix<T>,
    d: &Matrix<T>,
    r: V,
    s: V,
) -> Result<(V, V), crate::linalg::LinalgError> {
    let n = a.n;
    assert_eq!(b.n, n, "block size mismatch: B");
    assert_eq!(c.n, n, "block size mismatch: C");
    assert_eq!(d.n, n, "block size mismatch: D");
    assert_eq!(r.len(), n, "rhs r size mismatch");
    assert_eq!(s.len(), n, "rhs s size mismatch");

    // A 0x0 system has exactly one (empty) solution; skip assembling and
    // factorising empty matrices altogether.
    if n == 0 {
        return Ok((r, s));
    }

    // Every solve with A goes through `Matrix::lin_solve`. Only A is solved
    // here; D is used solely via matrix-vector products.
    let solve_a = |rhs: V| a.lin_solve(rhs);

    // Build dense Schur complement S = D - C A^{-1} B as a dense matrix,
    // filling it column-by-column by applying the operator to basis vectors e_j.
    let mut s_dense = Matrix::zeros(n, n);
    for j in 0..n {
        // e_j
        let mut e = r.zeros_like();
        e.set_component(j, T::one());
        // u = B e_j
        let u = b.mul_state::<V>(&e);
        // v = A^{-1} u
        let v = solve_a(u)?;
        // z = C v
        let z = c.mul_state::<V>(&v);
        // column j of S is (D e_j - z)
        let d_ej = d.mul_state::<V>(&e);
        let diff = d_ej.minus(&z);
        for i in 0..n {
            s_dense[(i, j)] = diff.get_component(i);
        }
    }

    // Reduced right-hand side w = s - C A^{-1} r
    // (`r` is cloned because it is needed again for the back-substitution).
    let ar = solve_a(r.clone())?;
    let car = c.mul_state::<V>(&ar);
    let w = s.minus(&car);

    // Solve S y = w
    let y = s_dense.lin_solve(w)?;

    // Back-substitute for x: A x = r - B y
    let by = b.mul_state::<V>(&y);
    let rhs_x = r.minus(&by);
    let x = solve_a(rhs_x)?;

    Ok((x, y))
}

#[cfg(all(test, feature = "nalgebra"))]
mod tests {
    use super::{Matrix, schur_complement};
    use nalgebra::Vector2;

    /// Absolute tolerance used when comparing solver output to the known solution.
    const TOL: f64 = 1e-12;

    fn approx_eq(a: f64, b: f64) {
        assert!((a - b).abs() < TOL, "{} != {}", a, b);
    }

    /// Component-wise comparison of two 2-vectors within `TOL`.
    fn assert_vec2_close(actual: &Vector2<f64>, expected: &Vector2<f64>) {
        approx_eq(actual.x, expected.x);
        approx_eq(actual.y, expected.y);
    }

    /// Manufacture right-hand sides for a known solution:
    /// `r = A x + B y`, `s = C x + D y`.
    fn manufactured_rhs(
        a: &Matrix<f64>,
        b: &Matrix<f64>,
        c: &Matrix<f64>,
        d: &Matrix<f64>,
        x: &Vector2<f64>,
        y: &Vector2<f64>,
    ) -> (Vector2<f64>, Vector2<f64>) {
        let r = a.mul_state(x) + b.mul_state(y);
        let s = c.mul_state(x) + d.mul_state(y);
        (r, s)
    }

    #[test]
    fn schur_trivial_identity_blocks() {
        let a: Matrix<f64> = Matrix::identity(2);
        let d: Matrix<f64> = Matrix::identity(2);
        let b: Matrix<f64> = Matrix::zeros(2, 2);
        let c: Matrix<f64> = Matrix::zeros(2, 2);

        let x_true = Vector2::new(1.0, -2.0);
        let y_true = Vector2::new(3.0, 4.0);

        // With B = C = 0 and A = D = I: r = x, s = y.
        let (r, s) = manufactured_rhs(&a, &b, &c, &d, &x_true, &y_true);

        let (x, y) = schur_complement(&a, &b, &c, &d, r, s).unwrap();
        assert_vec2_close(&x, &x_true);
        assert_vec2_close(&y, &y_true);
    }

    #[test]
    fn schur_mixed_blocks_small_dense() {
        // Choose a small invertible A and simple diagonal B, C.
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![3.0, 1.0, 2.0, 4.0]).unwrap();
        let d: Matrix<f64> = Matrix::from_vec(2, 2, vec![2.0, 0.5, 1.0, 3.0]).unwrap();
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let c: Matrix<f64> = Matrix::from_vec(2, 2, vec![0.5, 0.0, 0.0, 0.5]).unwrap();

        let x_true = Vector2::new(1.0, -2.0);
        let y_true = Vector2::new(3.0, 4.0);

        let (r, s) = manufactured_rhs(&a, &b, &c, &d, &x_true, &y_true);

        let (x, y) = schur_complement(&a, &b, &c, &d, r, s).unwrap();
        assert_vec2_close(&x, &x_true);
        assert_vec2_close(&y, &y_true);
    }

    #[test]
    fn schur_full_coupling_blocks() {
        // Dense, non-symmetric coupling blocks so that every entry of
        // C A^{-1} B contributes to the Schur complement. Both A and
        // S = D - C A^{-1} B are diagonally dominant (hence nonsingular) under
        // either row- or column-major interpretation of `from_vec`.
        let a: Matrix<f64> = Matrix::from_vec(2, 2, vec![4.0, 1.0, 2.0, 3.0]).unwrap();
        let b: Matrix<f64> = Matrix::from_vec(2, 2, vec![1.0, 2.0, 0.0, 1.0]).unwrap();
        let c: Matrix<f64> = Matrix::from_vec(2, 2, vec![0.5, 1.0, 1.0, 0.25]).unwrap();
        let d: Matrix<f64> = Matrix::from_vec(2, 2, vec![5.0, 1.0, 2.0, 6.0]).unwrap();

        let x_true = Vector2::new(0.5, -1.5);
        let y_true = Vector2::new(2.0, -0.25);

        let (r, s) = manufactured_rhs(&a, &b, &c, &d, &x_true, &y_true);

        let (x, y) = schur_complement(&a, &b, &c, &d, r, s).unwrap();
        assert_vec2_close(&x, &x_true);
        assert_vec2_close(&y, &y_true);
    }
}