use crate::{
assert, join_raw,
mul::triangular::{self, BlockStructure},
solve, ComplexField, MatMut, MatRef, Parallelism,
};
use reborrow::*;
/// Base case of the recursive lower-triangular inversion, for dimensions 0, 1 and 2.
///
/// Writes the inverse of the lower triangle of `src` into the lower triangle of `dst`
/// (diagonal included). The strictly upper entry of a 2x2 `dst` is left untouched.
///
/// # Safety
/// `dst` must have at most two rows, and `dst` and `src` must both be square with the same
/// dimension: entries are read and written without bounds checks.
unsafe fn invert_lower_triangular_impl_small<E: ComplexField>(
    mut dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
) {
    match dst.nrows() {
        0 => {}
        1 => {
            let inv = src.read_unchecked(0, 0).faer_inv();
            dst.write_unchecked(0, 0, inv);
        }
        2 => {
            let d0 = src.read_unchecked(0, 0).faer_inv();
            let d1 = src.read_unchecked(1, 1).faer_inv();
            // For [[a, 0], [b, c]] the lower-left entry of the inverse is -(c⁻¹ · b · a⁻¹).
            let lower = d1
                .faer_mul(src.read_unchecked(1, 0))
                .faer_mul(d0)
                .faer_neg();
            dst.write_unchecked(0, 0, d0);
            dst.write_unchecked(1, 0, lower);
            dst.write_unchecked(1, 1, d1);
        }
        _ => unreachable!(),
    }
}
/// Base case of the recursive unit-lower-triangular inversion, for dimensions 0, 1 and 2.
///
/// Only the strictly lower part of `dst` is written. The unit diagonal is implicit and is
/// neither read from `src` nor stored in `dst`.
///
/// # Safety
/// `dst` must have at most two rows, and `dst` and `src` must both be square with the same
/// dimension: entries are read and written without bounds checks.
unsafe fn invert_unit_lower_triangular_impl_small<E: ComplexField>(
    mut dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
) {
    match dst.nrows() {
        // Nothing below the (implicit) diagonal to compute.
        0 | 1 => {}
        // [[1, 0], [b, 1]]⁻¹ = [[1, 0], [-b, 1]]
        2 => dst.write_unchecked(1, 0, src.read_unchecked(1, 0).faer_neg()),
        _ => unreachable!(),
    }
}
/// Recursively inverts the lower triangle of `src` into the lower triangle of `dst`.
///
/// With `src = [[A, 0], [B, C]]` the inverse is `[[A⁻¹, 0], [-C⁻¹·B·A⁻¹, C⁻¹]]`. The two
/// diagonal blocks are handled as independent sub-problems through `join_raw`, which may
/// run them concurrently depending on `parallelism`. The off-diagonal block is then formed
/// as `-B·A⁻¹` and finished with an in-place lower-triangular solve against `C`. The
/// top-right quadrant of `dst` is never written.
///
/// # Safety
/// `dst` and `src` must both be square with identical dimensions (the public wrappers
/// assert this). The base case accesses entries without bounds checks.
unsafe fn invert_lower_triangular_impl<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    parallelism: Parallelism,
) {
    // Read both dimensions before `split_at_mut` consumes `dst`.
    let (rows, cols) = (dst.nrows(), dst.ncols());
    if rows <= 2 {
        return invert_lower_triangular_impl_small(dst, src);
    }

    // dst = [[a_inv, _], [off_diag, c_inv]], src = [[a, _], [b, c]]
    let (mut a_inv, _, mut off_diag, mut c_inv) = dst.split_at_mut(rows / 2, cols / 2);
    let (a, _, b, c) = src.split_at(src.nrows() / 2, src.ncols() / 2);

    join_raw(
        |par| invert_lower_triangular_impl(a_inv.rb_mut(), a, par),
        |par| invert_lower_triangular_impl(c_inv.rb_mut(), c, par),
        parallelism,
    );

    // off_diag := -(b · a_inv). `None` is passed for the accumulator scale so the block's
    // previous contents are not added in (assumed faer `matmul` convention).
    triangular::matmul(
        off_diag.rb_mut(),
        BlockStructure::Rectangular,
        b,
        BlockStructure::Rectangular,
        a_inv.rb(),
        BlockStructure::TriangularLower,
        None,
        E::faer_one().faer_neg(),
        parallelism,
    );
    // off_diag := c⁻¹ · off_diag, giving -c⁻¹ · b · a⁻¹.
    solve::solve_lower_triangular_in_place(c, off_diag, parallelism);
}
/// Recursively inverts a unit lower-triangular matrix. Only the strictly lower part of
/// `dst` is written.
///
/// With `src = [[A, 0], [B, C]]` (unit diagonal on `A` and `C`), the inverse is
/// `[[A⁻¹, 0], [-C⁻¹·B·A⁻¹, C⁻¹]]`. The diagonal blocks are independent sub-problems
/// dispatched through `join_raw`. The off-diagonal block is built as `-B·A⁻¹` and finished
/// with a unit-lower-triangular solve against `C`. The diagonal and top-right quadrant of
/// `dst` are never written.
///
/// # Safety
/// `dst` and `src` must both be square with identical dimensions (the public wrappers
/// assert this). The base case accesses entries without bounds checks.
unsafe fn invert_unit_lower_triangular_impl<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    parallelism: Parallelism,
) {
    // Read both dimensions before `split_at_mut` consumes `dst`.
    let (rows, cols) = (dst.nrows(), dst.ncols());
    if rows <= 2 {
        return invert_unit_lower_triangular_impl_small(dst, src);
    }

    // dst = [[a_inv, _], [off_diag, c_inv]], src = [[a, _], [b, c]]
    let (mut a_inv, _, mut off_diag, mut c_inv) = dst.split_at_mut(rows / 2, cols / 2);
    let (a, _, b, c) = src.split_at(src.nrows() / 2, src.ncols() / 2);

    join_raw(
        |par| invert_unit_lower_triangular_impl(a_inv.rb_mut(), a, par),
        |par| invert_unit_lower_triangular_impl(c_inv.rb_mut(), c, par),
        parallelism,
    );

    // off_diag := -(b · a_inv), treating a_inv as unit lower triangular because its
    // diagonal was never stored.
    triangular::matmul(
        off_diag.rb_mut(),
        BlockStructure::Rectangular,
        b,
        BlockStructure::Rectangular,
        a_inv.rb(),
        BlockStructure::UnitTriangularLower,
        None,
        E::faer_one().faer_neg(),
        parallelism,
    );
    // off_diag := c⁻¹ · off_diag, giving -c⁻¹ · b · a⁻¹.
    solve::solve_unit_lower_triangular_in_place(c, off_diag, parallelism);
}
/// Computes the inverse of a unit lower-triangular matrix, storing it in `dst`.
///
/// The diagonal of `src` is treated as all ones and its upper triangle is never read.
/// Only the strictly lower part of `dst` is written. Its diagonal and upper triangle keep
/// whatever values they held before the call.
///
/// # Panics
/// Panics if `dst` is not square or if `dst` and `src` have different dimensions.
#[track_caller]
pub fn invert_unit_lower_triangular<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    parallelism: Parallelism,
) {
    assert!(all(
        dst.nrows() == src.nrows(),
        dst.ncols() == src.ncols(),
        dst.nrows() == dst.ncols()
    ));
    // SAFETY: the assertion above guarantees `dst` and `src` are square with the same
    // dimensions, which is exactly the contract of the unchecked implementation.
    unsafe { invert_unit_lower_triangular_impl(dst, src, parallelism) }
}
/// Computes the inverse of a lower-triangular matrix, storing it in `dst`.
///
/// The upper triangle of `src` is never read. Only the lower triangle of `dst` (diagonal
/// included) is written. Its strictly upper part keeps whatever values it held before the
/// call.
///
/// # Panics
/// Panics if `dst` is not square or if `dst` and `src` have different dimensions.
#[track_caller]
pub fn invert_lower_triangular<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    parallelism: Parallelism,
) {
    assert!(all(
        dst.nrows() == src.nrows(),
        dst.ncols() == src.ncols(),
        dst.nrows() == dst.ncols()
    ));
    // SAFETY: the assertion above guarantees `dst` and `src` are square with the same
    // dimensions, which is exactly the contract of the unchecked implementation.
    unsafe { invert_lower_triangular_impl(dst, src, parallelism) }
}
/// Computes the inverse of a unit upper-triangular matrix, storing it in `dst`.
///
/// The diagonal of `src` is treated as all ones and its lower triangle is never read.
/// Only the strictly upper part of `dst` is written.
///
/// # Panics
/// Panics if `dst` is not square or if `dst` and `src` have different dimensions.
#[track_caller]
pub fn invert_unit_upper_triangular<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    parallelism: Parallelism,
) {
    // Reversing both row and column order maps entry (i, j) to (n-1-i, n-1-j). That turns
    // an upper-triangular matrix into a lower-triangular one, and the inverse of the
    // flipped matrix is the flipped inverse, so the lower-triangular routine does the work.
    let flipped_dst = dst.reverse_rows_and_cols_mut();
    let flipped_src = src.reverse_rows_and_cols();
    invert_unit_lower_triangular(flipped_dst, flipped_src, parallelism)
}
/// Computes the inverse of an upper-triangular matrix, storing it in `dst`.
///
/// The lower triangle of `src` is never read. Only the upper triangle of `dst` (diagonal
/// included) is written.
///
/// # Panics
/// Panics if `dst` is not square or if `dst` and `src` have different dimensions.
#[track_caller]
pub fn invert_upper_triangular<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    parallelism: Parallelism,
) {
    // Flipping both axes turns an upper-triangular matrix into a lower-triangular one,
    // and inversion commutes with that permutation.
    let flipped_dst = dst.reverse_rows_and_cols_mut();
    let flipped_src = src.reverse_rows_and_cols();
    invert_lower_triangular(flipped_dst, flipped_src, parallelism)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert, Mat};
    use assert_approx_eq::assert_approx_eq;
    use rand::random;

    /// For every dimension in `0..32`, inverts a random matrix with `invert`, then multiplies
    /// the original by the computed inverse, interpreting both operands with `structure`.
    /// Asserts the product is the identity to within `1e-4`.
    fn check_inverse(
        invert: for<'a, 'b> fn(MatMut<'a, f64>, MatRef<'b, f64>, Parallelism),
        structure: BlockStructure,
    ) {
        for n in 0..32 {
            // Entries lie in [2, 3). Only the part selected by `structure` is consulted.
            let a = Mat::from_fn(n, n, |_, _| 2.0 + random::<f64>());
            let mut inv = Mat::zeros(n, n);
            invert(inv.as_mut(), a.as_ref(), Parallelism::Rayon(0));

            let mut prod = Mat::zeros(n, n);
            triangular::matmul(
                prod.as_mut(),
                BlockStructure::Rectangular,
                a.as_ref(),
                structure,
                inv.as_ref(),
                structure,
                None,
                1.0,
                Parallelism::Rayon(0),
            );

            for i in 0..n {
                for j in 0..n {
                    let expected = if i == j { 1.0 } else { 0.0 };
                    assert_approx_eq!(prod.read(i, j), expected, 1e-4);
                }
            }
        }
    }

    #[test]
    fn test_invert_lower() {
        check_inverse(
            invert_lower_triangular::<f64>,
            BlockStructure::TriangularLower,
        );
    }

    #[test]
    fn test_invert_unit_lower() {
        check_inverse(
            invert_unit_lower_triangular::<f64>,
            BlockStructure::UnitTriangularLower,
        );
    }

    #[test]
    fn test_invert_upper() {
        check_inverse(
            invert_upper_triangular::<f64>,
            BlockStructure::TriangularUpper,
        );
    }

    #[test]
    fn test_invert_unit_upper() {
        check_inverse(
            invert_unit_upper_triangular::<f64>,
            BlockStructure::UnitTriangularUpper,
        );
    }
}