use faer::Par;
use faer::dyn_stack::{MemStack, StackReq};
use faer::linalg::{temp_mat_scratch, temp_mat_zeroed};
use faer::mat::AsMatMut;
use faer::matrix_free::LinOp;
use faer::prelude::{Reborrow, ReborrowMut};
use faer::{MatMut, MatRef};
use faer_traits::ComplexField;
use super::{BlockPrecondError, BlockSplit2, Precond};
/// Two-block Schur-complement preconditioner for a system partitioned as
/// `[[A, B], [C, D]]` according to a [`BlockSplit2`].
///
/// Applying it in place performs a block-LU solve with the stored operators:
/// `y0 = A^{-1} r0`, `x1 = S^{-1} (r1 - C y0)`, `x0 = y0 - A^{-1} B x1`,
/// where `ainv` plays the role of `A^{-1}` and `sinv` the role of `S^{-1}`
/// (presumably an approximation of the inverse Schur complement
/// `D - C A^{-1} B` — the quality of the approximation is up to the caller).
#[derive(Debug)]
pub struct SchurPrecond2<AInv, SInv, B, C> {
/// Row/column partition: first block has `n0` rows, second block `n1` rows.
split: BlockSplit2,
/// Operator applied as `A^{-1}`; validated as `n0 x n0` in `new`.
ainv: AInv,
/// Operator applied as `S^{-1}`; validated as `n1 x n1` in `new`.
sinv: SInv,
/// Upper off-diagonal block `B`; validated as `n0 x n1` in `new`.
b: B,
/// Lower off-diagonal block `C`; validated as `n1 x n0` in `new`.
c: C,
}
impl<AInv, SInv, B, C> SchurPrecond2<AInv, SInv, B, C> {
    /// Assembles a Schur-complement preconditioner from its four operator blocks.
    ///
    /// With `n0 = split.n0` and `n1 = split.n1`, the blocks must have shapes:
    /// - `ainv`: `n0 x n0`
    /// - `sinv`: `n1 x n1`
    /// - `b`: `n0 x n1`
    /// - `c`: `n1 x n0`
    ///
    /// # Errors
    ///
    /// Returns [`BlockPrecondError::DimensionMismatch`] for the first block, checked in
    /// the order `ainv`, `sinv`, `b`, `c`, whose shape disagrees with `split`.
    pub fn new<T>(
        split: BlockSplit2,
        ainv: AInv,
        sinv: SInv,
        b: B,
        c: C,
    ) -> Result<Self, BlockPrecondError>
    where
        T: ComplexField,
        AInv: Precond<T>,
        SInv: Precond<T>,
        B: LinOp<T>,
        C: LinOp<T>,
    {
        let (n0, n1) = (split.n0, split.n1);
        // Each stage runs only if every earlier shape check succeeded.
        validate_dims("ainv", ainv.nrows(), ainv.ncols(), n0, n0)
            .and_then(|()| validate_dims("sinv", sinv.nrows(), sinv.ncols(), n1, n1))
            .and_then(|()| validate_dims("b", b.nrows(), b.ncols(), n0, n1))
            .and_then(|()| validate_dims("c", c.nrows(), c.ncols(), n1, n0))
            .map(|()| Self {
                split,
                ainv,
                sinv,
                b,
                c,
            })
    }

    /// Returns the row/column partition this preconditioner was built for.
    #[inline]
    #[must_use]
    pub fn split(&self) -> BlockSplit2 {
        let Self { split, .. } = self;
        *split
    }
}
impl<T, AInv, SInv, B, C> LinOp<T> for SchurPrecond2<AInv, SInv, B, C>
where
    T: ComplexField + Copy,
    AInv: Precond<T>,
    SInv: Precond<T>,
    B: LinOp<T>,
    C: LinOp<T>,
{
    /// Workspace needed to apply the preconditioner (or its conjugate) to
    /// `rhs_ncols` right-hand-side columns.
    ///
    /// The in-place solve runs in three sequential phases whose scratch is never
    /// held simultaneously:
    /// 1. `A^{-1}` on the first block row, using the whole stack;
    /// 2. an `n1 x k` temporary for `C y0`, which may still be live while `C` and
    ///    `S^{-1}` run;
    /// 3. an `n0 x k` temporary for `B x1`, live while `B` and `A^{-1}` run.
    ///
    /// The requirement is the maximum over the phases rather than the sum of every
    /// sub-operation, which is what the phases actually need at their peak.
    fn apply_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
        let (n0, n1) = (self.split.n0, self.split.n1);
        let ainv = self.ainv.apply_in_place_scratch(rhs_ncols, par);
        let lower_row = StackReq::all_of(&[
            temp_mat_scratch::<T>(n1, rhs_ncols),
            StackReq::any_of(&[
                self.c.apply_scratch(rhs_ncols, par),
                self.sinv.apply_in_place_scratch(rhs_ncols, par),
            ]),
        ]);
        let back_substitution = StackReq::all_of(&[
            temp_mat_scratch::<T>(n0, rhs_ncols),
            StackReq::any_of(&[self.b.apply_scratch(rhs_ncols, par), ainv]),
        ]);
        StackReq::any_of(&[ainv, lower_row, back_substitution])
    }

    fn nrows(&self) -> usize {
        self.split.total_dim()
    }

    fn ncols(&self) -> usize {
        self.split.total_dim()
    }

    /// Out-of-place apply: copies `rhs` into `out`, then solves in place there.
    ///
    /// # Panics
    ///
    /// Panics if `rhs`/`out` do not have `total_dim()` rows or their column
    /// counts differ.
    fn apply(&self, mut out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
        assert_eq!(rhs.nrows(), self.ncols());
        assert_eq!(out.nrows(), self.nrows());
        assert_eq!(out.ncols(), rhs.ncols());
        out.rb_mut().copy_from(rhs);
        self.apply_in_place(out, par, stack);
    }

    /// Out-of-place conjugate apply: copies `rhs` into `out`, then applies the
    /// conjugated solve in place there.
    ///
    /// # Panics
    ///
    /// Panics if `rhs`/`out` do not have `total_dim()` rows or their column
    /// counts differ.
    fn conj_apply(
        &self,
        mut out: MatMut<'_, T>,
        rhs: MatRef<'_, T>,
        par: Par,
        stack: &mut MemStack,
    ) {
        assert_eq!(rhs.nrows(), self.ncols());
        assert_eq!(out.nrows(), self.nrows());
        assert_eq!(out.ncols(), rhs.ncols());
        out.rb_mut().copy_from(rhs);
        self.conj_apply_in_place(out, par, stack);
    }
}
impl<T, AInv, SInv, B, C> Precond<T> for SchurPrecond2<AInv, SInv, B, C>
where
    T: ComplexField + Copy,
    AInv: Precond<T>,
    SInv: Precond<T>,
    B: LinOp<T>,
    C: LinOp<T>,
{
    /// Same workspace as the out-of-place apply, which delegates to the in-place solve.
    fn apply_in_place_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
        <Self as LinOp<T>>::apply_scratch(self, rhs_ncols, par)
    }

    /// Overwrites `rhs` with the block-LU solve result.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` does not have `total_dim()` rows.
    fn apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
        self.solve_in_place(rhs, false, par, stack);
    }

    /// Overwrites `rhs` with the block-LU solve computed from the conjugated operators.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` does not have `total_dim()` rows.
    fn conj_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
        self.solve_in_place(rhs, true, par, stack);
    }
}

impl<AInv, SInv, B, C> SchurPrecond2<AInv, SInv, B, C> {
    /// Shared block-LU solve behind `apply_in_place` / `conj_apply_in_place`.
    ///
    /// With `rhs = [r0; r1]` split at `n0`, computes in place:
    /// `y0 = A^{-1} r0`, `x1 = S^{-1} (r1 - C y0)`, `x0 = y0 - A^{-1} B x1`.
    /// When `conj` is true, every operator is replaced by its conjugated apply.
    fn solve_in_place<T>(&self, rhs: MatMut<'_, T>, conj: bool, par: Par, stack: &mut MemStack)
    where
        T: ComplexField + Copy,
        AInv: Precond<T>,
        SInv: Precond<T>,
        B: LinOp<T>,
        C: LinOp<T>,
    {
        assert_eq!(rhs.nrows(), self.split.total_dim());
        let ncols = rhs.ncols();
        // The two row blocks are disjoint, so one split serves the whole solve.
        let (mut x0, mut x1) = rhs.split_at_row_mut(self.split.n0);

        // Forward substitution, first block row: x0 <- A^{-1} r0.
        if conj {
            self.ainv.conj_apply_in_place(x0.rb_mut(), par, stack);
        } else {
            self.ainv.apply_in_place(x0.rb_mut(), par, stack);
        }

        // Second block row, part 1: x1 <- r1 - C x0. The temporary is scoped so it
        // is released before S^{-1} runs, letting S^{-1} use the whole stack.
        {
            let (mut tmp_c, stack) = temp_mat_zeroed::<T, _, _>(self.split.n1, ncols, stack);
            if conj {
                self.c.conj_apply(tmp_c.as_mat_mut(), x0.rb(), par, stack);
            } else {
                self.c.apply(tmp_c.as_mat_mut(), x0.rb(), par, stack);
            }
            subtract_in_place(x1.rb_mut(), tmp_c.as_mat_mut().as_ref());
        }

        // Second block row, part 2: x1 <- S^{-1} x1.
        if conj {
            self.sinv.conj_apply_in_place(x1.rb_mut(), par, stack);
        } else {
            self.sinv.apply_in_place(x1.rb_mut(), par, stack);
        }

        // Back substitution: x0 <- x0 - A^{-1} B x1.
        let (mut tmp_b, stack) = temp_mat_zeroed::<T, _, _>(self.split.n0, ncols, stack);
        if conj {
            self.b.conj_apply(tmp_b.as_mat_mut(), x1.rb(), par, stack);
            self.ainv.conj_apply_in_place(tmp_b.as_mat_mut(), par, stack);
        } else {
            self.b.apply(tmp_b.as_mat_mut(), x1.rb(), par, stack);
            self.ainv.apply_in_place(tmp_b.as_mat_mut(), par, stack);
        }
        subtract_in_place(x0.rb_mut(), tmp_b.as_mat_mut().as_ref());
    }
}
/// Checks that an operator named `which` has exactly the expected shape.
///
/// # Errors
///
/// Returns [`BlockPrecondError::DimensionMismatch`] carrying both the expected and
/// the actual shape when either dimension differs.
fn validate_dims(
    which: &'static str,
    actual_nrows: usize,
    actual_ncols: usize,
    expected_nrows: usize,
    expected_ncols: usize,
) -> Result<(), BlockPrecondError> {
    let shape_matches = (actual_nrows, actual_ncols) == (expected_nrows, expected_ncols);
    if shape_matches {
        Ok(())
    } else {
        Err(BlockPrecondError::DimensionMismatch {
            which,
            expected_nrows,
            expected_ncols,
            actual_nrows,
            actual_ncols,
        })
    }
}
/// Element-wise `lhs -= rhs`, traversing column by column.
///
/// # Panics
///
/// Panics if the two matrices do not have identical shapes.
fn subtract_in_place<T: ComplexField + Copy>(mut lhs: MatMut<'_, T>, rhs: MatRef<'_, T>) {
    let (nrows, ncols) = (lhs.nrows(), lhs.ncols());
    assert_eq!(nrows, rhs.nrows());
    assert_eq!(ncols, rhs.ncols());
    (0..ncols).for_each(|j| {
        (0..nrows).for_each(|i| lhs[(i, j)] -= rhs[(i, j)]);
    });
}
#[cfg(test)]
mod test {
    use super::SchurPrecond2;
    use crate::sparse::precond::DiagonalPrecond;
    use crate::sparse::{BlockPrecondError, BlockSplit2, Precond};
    use faer::dyn_stack::{MemBuffer, MemStack, StackReq};
    use faer::matrix_free::LinOp;
    use faer::{Mat, MatMut, MatRef, Par};

    /// Dense real test operator; `apply` overwrites `out` with `data * rhs`.
    #[derive(Clone, Debug)]
    struct DenseBlockOp {
        data: Mat<f64>,
    }

    impl DenseBlockOp {
        /// Builds the operator from `values` given in column-major order.
        fn new(nrows: usize, ncols: usize, values: &[f64]) -> Self {
            assert_eq!(values.len(), nrows * ncols);
            let data = Mat::from_fn(nrows, ncols, |i, j| values[i + nrows * j]);
            Self { data }
        }
    }

    impl LinOp<f64> for DenseBlockOp {
        fn apply_scratch(&self, _rhs_ncols: usize, _par: Par) -> StackReq {
            StackReq::EMPTY
        }

        fn nrows(&self) -> usize {
            self.data.nrows()
        }

        fn ncols(&self) -> usize {
            self.data.ncols()
        }

        fn apply(
            &self,
            mut out: MatMut<'_, f64>,
            rhs: MatRef<'_, f64>,
            _par: Par,
            _stack: &mut MemStack,
        ) {
            for col in 0..out.ncols() {
                for row in 0..out.nrows() {
                    out[(row, col)] = 0.0;
                }
            }
            for col in 0..rhs.ncols() {
                for k in 0..self.ncols() {
                    let rhs_value = rhs[(k, col)];
                    for row in 0..self.nrows() {
                        out[(row, col)] += self.data[(row, k)] * rhs_value;
                    }
                }
            }
        }

        // Real data: the conjugate apply is the plain apply.
        fn conj_apply(
            &self,
            out: MatMut<'_, f64>,
            rhs: MatRef<'_, f64>,
            par: Par,
            stack: &mut MemStack,
        ) {
            self.apply(out, rhs, par, stack);
        }
    }

    #[test]
    fn schur_preconditioner_matches_exact_block_solve() {
        let split = BlockSplit2::new(2, 1);
        let ainv = DiagonalPrecond::from_inverse_diagonal(&[0.5, 1.0 / 3.0]);
        let sinv = DiagonalPrecond::from_inverse_diagonal(&[0.25]);
        let b = DenseBlockOp::new(2, 1, &[1.0, 2.0]);
        let c = DenseBlockOp::new(1, 2, &[2.0, 3.0]);
        let precond = SchurPrecond2::new::<f64>(split, ainv, sinv, b, c).unwrap();
        let mut rhs = Mat::from_fn(3, 1, |i, _| [5.0, 7.0, 8.0][i]);
        let mut buffer = MemBuffer::new(precond.apply_in_place_scratch(1, Par::Seq));
        let stack = MemStack::new(&mut buffer);
        precond.apply_in_place(rhs.as_mut(), Par::Seq, stack);
        assert!((rhs[(0, 0)] - 3.0).abs() < 1.0e-12);
        assert!((rhs[(1, 0)] - 3.0).abs() < 1.0e-12);
        assert!((rhs[(2, 0)] + 1.0).abs() < 1.0e-12);
    }

    #[test]
    fn schur_preconditioner_conjugate_matches_forward_apply_for_real_nonscalar_blocks() {
        let split = BlockSplit2::new(2, 2);
        let ainv = DiagonalPrecond::from_inverse_diagonal(&[0.5, 0.25]);
        let sinv = DiagonalPrecond::from_inverse_diagonal(&[0.2, 0.5]);
        let b = DenseBlockOp::new(2, 2, &[1.0, 3.0, 2.0, 4.0]);
        let c = DenseBlockOp::new(2, 2, &[2.0, 1.0, 0.0, 5.0]);
        let precond = SchurPrecond2::new::<f64>(split, ainv, sinv, b, c).unwrap();
        let rhs = Mat::from_fn(4, 1, |i, _| [1.0, -2.0, 3.0, 4.0][i]);
        let mut expected = rhs.clone();
        let mut out = rhs.clone();
        let mut buffer = MemBuffer::new(precond.apply_in_place_scratch(1, Par::Seq));
        let stack = MemStack::new(&mut buffer);
        precond.apply_in_place(expected.as_mut(), Par::Seq, stack);
        let stack = MemStack::new(&mut buffer);
        precond.conj_apply_in_place(out.as_mut(), Par::Seq, stack);
        for row in 0..4 {
            assert!((out[(row, 0)] - expected[(row, 0)]).abs() < 1.0e-12);
        }
    }

    #[test]
    fn schur_preconditioner_apply_inverts_block_system_for_multiple_columns() {
        // A = diag(2, 4), B = [[1, 2], [3, 4]], C = [[2, 0], [1, 5]] and
        // D = [[6, 2], [4.25, 8]], chosen so that S = D - C A^{-1} B = diag(5, 2)
        // exactly. The preconditioner is then the exact inverse of [[A, B], [C, D]].
        let split = BlockSplit2::new(2, 2);
        let ainv = DiagonalPrecond::from_inverse_diagonal(&[0.5, 0.25]);
        let sinv = DiagonalPrecond::from_inverse_diagonal(&[0.2, 0.5]);
        let b = DenseBlockOp::new(2, 2, &[1.0, 3.0, 2.0, 4.0]);
        let c = DenseBlockOp::new(2, 2, &[2.0, 1.0, 0.0, 5.0]);
        let precond = SchurPrecond2::new::<f64>(split, ainv, sinv, b, c).unwrap();
        // Each rhs column is M * x for the matching expected column x.
        let rhs_cols = [[13.0, 17.0, 28.0, 35.75], [3.0, 9.0, -2.0, 16.75]];
        let expected_cols = [[1.0, -2.0, 3.0, 4.0], [0.0, 1.0, -1.0, 2.0]];
        let rhs = Mat::from_fn(4, 2, |i, j| rhs_cols[j][i]);
        let mut out = Mat::<f64>::zeros(4, 2);
        let mut buffer = MemBuffer::new(precond.apply_scratch(2, Par::Seq));
        let stack = MemStack::new(&mut buffer);
        precond.apply(out.as_mut(), rhs.as_ref(), Par::Seq, stack);
        for (j, expected) in expected_cols.iter().enumerate() {
            for (i, &value) in expected.iter().enumerate() {
                assert!((out[(i, j)] - value).abs() < 1.0e-12);
            }
        }
    }

    #[test]
    fn schur_preconditioner_rejects_dimension_mismatch() {
        let split = BlockSplit2::new(2, 1);
        let ainv = DiagonalPrecond::from_inverse_diagonal(&[0.5, 1.0 / 3.0]);
        let sinv = DiagonalPrecond::from_inverse_diagonal(&[0.25]);
        let b = DenseBlockOp::new(3, 1, &[1.0, 2.0, 3.0]);
        let c = DenseBlockOp::new(1, 2, &[2.0, 3.0]);
        assert!(matches!(
            SchurPrecond2::new::<f64>(split, ainv, sinv, b, c),
            Err(BlockPrecondError::DimensionMismatch { which: "b", .. })
        ));
    }
}