use dyn_stack::{PodStack, SizeOverflow, StackReq};
use faer_core::{
permutation::{permute_rows, Index, PermutationRef},
solve::*,
temp_mat_req, temp_mat_uninit, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism,
};
use reborrow::*;
/// Shared implementation behind [`solve`] and [`solve_in_place`].
///
/// The right-hand side is read from `rhs` when it is `Some`, and from the current contents of
/// `dst` otherwise; the result is always written to `dst`. The steps are:
///
/// 1. copy the right-hand side into a temporary matrix, permuting its rows with `row_perm`,
/// 2. solve in place against `lu_factors` as a unit lower triangular matrix,
/// 3. solve in place against `lu_factors` as an upper triangular matrix,
/// 4. copy the temporary into `dst`, permuting its rows with the inverse of `col_perm`.
///
/// `conj_lhs` is forwarded unchanged to both triangular solves. The temporary has
/// `lu_factors.ncols()` rows and `dst.ncols()` columns and is taken from `stack`.
fn solve_impl<I: Index, E: ComplexField>(
    lu_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    dst: MatMut<'_, E>,
    rhs: Option<MatRef<'_, E>>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // Scratch matrix that holds the intermediate results of every step below.
    let (mut workspace, _) = temp_mat_uninit::<E>(lu_factors.ncols(), dst.ncols(), stack);
    let mut workspace = workspace.as_mut();

    // The in-place entry point passes `rhs = None`, in which case `dst` doubles as the input.
    let input = if let Some(given) = rhs {
        given
    } else {
        dst.rb()
    };

    // Step 1: row-permuted copy of the right-hand side.
    permute_rows(workspace.rb_mut(), input, row_perm);

    // Step 2: unit lower triangular solve.
    solve_unit_lower_triangular_in_place_with_conj(
        lu_factors,
        conj_lhs,
        workspace.rb_mut(),
        parallelism,
    );

    // Step 3: upper triangular solve.
    solve_upper_triangular_in_place_with_conj(
        lu_factors,
        conj_lhs,
        workspace.rb_mut(),
        parallelism,
    );

    // Step 4: undo the column permutation while writing the result out.
    permute_rows(dst, workspace.rb(), col_perm.inverse());
}
/// Shared implementation behind [`solve_transpose`] and [`solve_transpose_in_place`].
///
/// The right-hand side is read from `rhs` when it is `Some`, and from the current contents of
/// `dst` otherwise; the result is always written to `dst`. Compared to the non-transposed
/// solve, the roles of the two permutations are swapped and the triangular solves run against
/// `lu_factors.transpose()`:
///
/// 1. copy the right-hand side into a temporary matrix, permuting its rows with `col_perm`,
/// 2. solve in place against the transposed factors as a lower triangular matrix,
/// 3. solve in place against the transposed factors as a unit upper triangular matrix,
/// 4. copy the temporary into `dst`, permuting its rows with the inverse of `row_perm`.
///
/// `conj_lhs` is forwarded unchanged to both triangular solves. The temporary has
/// `lu_factors.ncols()` rows and `dst.ncols()` columns and is taken from `stack`.
fn solve_transpose_impl<I: Index, E: ComplexField>(
    lu_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    dst: MatMut<'_, E>,
    rhs: Option<MatRef<'_, E>>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // Scratch matrix that holds the intermediate results of every step below.
    let (mut workspace, _) = temp_mat_uninit::<E>(lu_factors.ncols(), dst.ncols(), stack);
    let mut workspace = workspace.as_mut();

    // The in-place entry point passes `rhs = None`, in which case `dst` doubles as the input.
    let input = if let Some(given) = rhs {
        given
    } else {
        dst.rb()
    };

    // Step 1: copy of the right-hand side with rows permuted by the column permutation.
    permute_rows(workspace.rb_mut(), input, col_perm);

    // Step 2: lower triangular solve against the transposed factors.
    solve_lower_triangular_in_place_with_conj(
        lu_factors.transpose(),
        conj_lhs,
        workspace.rb_mut(),
        parallelism,
    );

    // Step 3: unit upper triangular solve against the transposed factors.
    solve_unit_upper_triangular_in_place_with_conj(
        lu_factors.transpose(),
        conj_lhs,
        workspace.rb_mut(),
        parallelism,
    );

    // Step 4: undo the row permutation while writing the result out.
    permute_rows(dst, workspace.rb(), row_perm.inverse());
}
/// Computes the workspace requirement for [`solve_in_place`].
///
/// The result is the requirement of one temporary matrix of element type `E` with `lu_nrows`
/// rows and `rhs_ncols` columns. `lu_ncols` and `parallelism` do not influence the result, and
/// the index type `I` is not used by the computation.
///
/// # Errors
///
/// Returns whatever [`SizeOverflow`] error `temp_mat_req` reports for the requested dimensions.
pub fn solve_in_place_req<I: Index, E: Entity>(
    lu_nrows: usize,
    lu_ncols: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Accepted for signature symmetry with the other `*_req` functions; intentionally ignored.
    let _ = (lu_ncols, parallelism);

    let scratch = temp_mat_req::<E>(lu_nrows, rhs_ncols);
    scratch
}
/// Computes the workspace requirement for [`solve`].
///
/// The result is the requirement of one temporary matrix of element type `E` with `lu_nrows`
/// rows and `rhs_ncols` columns. `lu_ncols` and `parallelism` do not influence the result, and
/// the index type `I` is not used by the computation.
///
/// # Errors
///
/// Returns whatever [`SizeOverflow`] error `temp_mat_req` reports for the requested dimensions.
pub fn solve_req<I: Index, E: Entity>(
    lu_nrows: usize,
    lu_ncols: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Accepted for signature symmetry with the other `*_req` functions; intentionally ignored.
    let _ = (lu_ncols, parallelism);

    let scratch = temp_mat_req::<E>(lu_nrows, rhs_ncols);
    scratch
}
/// Computes the workspace requirement for [`solve_transpose_in_place`].
///
/// The result is the requirement of one temporary matrix of element type `E` with `lu_nrows`
/// rows and `rhs_ncols` columns. `lu_ncols` and `parallelism` do not influence the result, and
/// the index type `I` is not used by the computation.
///
/// # Errors
///
/// Returns whatever [`SizeOverflow`] error `temp_mat_req` reports for the requested dimensions.
pub fn solve_transpose_in_place_req<I: Index, E: Entity>(
    lu_nrows: usize,
    lu_ncols: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Accepted for signature symmetry with the other `*_req` functions; intentionally ignored.
    let _ = (lu_ncols, parallelism);

    let scratch = temp_mat_req::<E>(lu_nrows, rhs_ncols);
    scratch
}
/// Computes the workspace requirement for [`solve_transpose`].
///
/// The result is the requirement of one temporary matrix of element type `E` with `lu_nrows`
/// rows and `rhs_ncols` columns. `lu_ncols` and `parallelism` do not influence the result, and
/// the index type `I` is not used by the computation.
///
/// # Errors
///
/// Returns whatever [`SizeOverflow`] error `temp_mat_req` reports for the requested dimensions.
pub fn solve_transpose_req<I: Index, E: Entity>(
    lu_nrows: usize,
    lu_ncols: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Accepted for signature symmetry with the other `*_req` functions; intentionally ignored.
    let _ = (lu_ncols, parallelism);

    let scratch = temp_mat_req::<E>(lu_nrows, rhs_ncols);
    scratch
}
/// Solves a linear system out of place using the full pivoting LU factors of a matrix `A`.
///
/// Given the factors in `lu_factors` together with the row and column permutations produced by
/// the factorization, this reads the right-hand side from `rhs` and stores in `dst` a matrix
/// `X` such that `Op(A) * X = rhs`, where `Op` is the identity or the conjugation depending on
/// `conj_lhs`.
///
/// `rhs` is left untouched. Temporary storage is taken from `stack`; see [`solve_req`] for the
/// amount required.
pub fn solve<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    lu_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    rhs: MatRef<'_, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // An explicit right-hand side selects the out-of-place path of the shared implementation.
    let source = Some(rhs);
    solve_impl::<I, E>(
        lu_factors,
        conj_lhs,
        row_perm,
        col_perm,
        dst,
        source,
        parallelism,
        stack,
    );
}
/// Solves a linear system in place using the full pivoting LU factors of a matrix `A`.
///
/// Given the factors in `lu_factors` together with the row and column permutations produced by
/// the factorization, this treats the contents of `rhs` as the right-hand side `B` and
/// overwrites `rhs` with a matrix `X` such that `Op(A) * X = B`, where `Op` is the identity or
/// the conjugation depending on `conj_lhs`.
///
/// Temporary storage is taken from `stack`; see [`solve_in_place_req`] for the amount required.
pub fn solve_in_place<I: Index, E: ComplexField>(
    lu_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    rhs: MatMut<'_, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // No separate source: the shared implementation reads its input from the destination.
    let source = None;
    solve_impl::<I, E>(
        lu_factors,
        conj_lhs,
        row_perm,
        col_perm,
        rhs,
        source,
        parallelism,
        stack,
    )
}
/// Solves a transposed linear system out of place using the full pivoting LU factors of a
/// matrix `A`.
///
/// Given the factors in `lu_factors` together with the row and column permutations produced by
/// the factorization, this reads the right-hand side from `rhs` and stores in `dst` a matrix
/// `X` such that `Op(A)^T * X = rhs`, where `Op` is the identity or the conjugation depending
/// on `conj_lhs`.
///
/// `rhs` is left untouched. Temporary storage is taken from `stack`; see
/// [`solve_transpose_req`] for the amount required.
pub fn solve_transpose<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    lu_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    rhs: MatRef<'_, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // An explicit right-hand side selects the out-of-place path of the shared implementation.
    let source = Some(rhs);
    solve_transpose_impl::<I, E>(
        lu_factors,
        conj_lhs,
        row_perm,
        col_perm,
        dst,
        source,
        parallelism,
        stack,
    );
}
/// Solves a transposed linear system in place using the full pivoting LU factors of a matrix
/// `A`.
///
/// Given the factors in `lu_factors` together with the row and column permutations produced by
/// the factorization, this treats the contents of `rhs` as the right-hand side `B` and
/// overwrites `rhs` with a matrix `X` such that `Op(A)^T * X = B`, where `Op` is the identity
/// or the conjugation depending on `conj_lhs`.
///
/// Temporary storage is taken from `stack`; see [`solve_transpose_in_place_req`] for the amount
/// required.
pub fn solve_transpose_in_place<I: Index, E: ComplexField>(
    lu_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    rhs: MatMut<'_, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // No separate source: the shared implementation reads its input from the destination.
    let source = None;
    solve_transpose_impl::<I, E>(
        lu_factors,
        conj_lhs,
        row_perm,
        col_perm,
        rhs,
        source,
        parallelism,
        stack,
    )
}
// Round-trip tests for the full pivoting LU solvers: factorize a random matrix, solve against a
// random right-hand side, multiply the solution back by the original matrix and compare with the
// right-hand side entry by entry.
#[cfg(test)]
mod tests {
use super::*;
use crate::full_pivoting::compute::{lu_in_place, lu_in_place_req};
use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat};
use std::cell::RefCell;
// Builds a `PodStack` on top of a freshly allocated `GlobalPodBuffer`. `$req` is expected to be
// the `Result` returned by one of the `*_req` functions; it is unwrapped here.
macro_rules! make_stack {
($req: expr) => {
::dyn_stack::PodStack::new(&mut ::dyn_stack::GlobalPodBuffer::new($req.unwrap()))
};
}
// Checks `solve` for element type `E`. `gen` produces the matrix entries and `epsilon` is the
// strict upper bound on the absolute error allowed for each reconstructed entry.
fn test_solve<E: ComplexField>(mut gen: impl FnMut() -> E, epsilon: E::Real) {
// Sizes 0..=31, then 32, 64, ..., 224.
// NOTE(review): `test_solve_transpose` goes up to 480 (`1..16`); confirm whether the smaller
// range here is intentional.
(0..32).chain((1..8).map(|i| i * 32)).for_each(|n| {
for conj_lhs in [Conj::No, Conj::Yes] {
// `a` keeps the original matrix for the final check; `lu` is a copy that is factorized
// in place.
let a = Mat::from_fn(n, n, |_, _| gen());
let mut lu = a.clone();
let a = a.as_ref();
let mut lu = lu.as_mut();
// Number of right-hand side columns.
let k = 32;
let rhs = Mat::from_fn(n, k, |_, _| gen());
let rhs = rhs.as_ref();
let mut sol = Mat::<E>::zeros(n, k);
let mut sol = sol.as_mut();
// Storage for the forward and inverse permutations filled in by `lu_in_place`.
let mut row_perm = vec![0_usize; n];
let mut row_perm_inv = vec![0_usize; n];
let mut col_perm = vec![0_usize; n];
let mut col_perm_inv = vec![0_usize; n];
// NOTE(review): `Rayon(0)` presumably lets rayon pick the thread count — confirm.
let parallelism = Parallelism::Rayon(0);
// The first returned value is discarded; only the two permutations are needed.
let (_, row_perm, col_perm) = lu_in_place(
lu.rb_mut(),
&mut row_perm,
&mut row_perm_inv,
&mut col_perm,
&mut col_perm_inv,
parallelism,
make_stack!(lu_in_place_req::<usize, E>(
n,
n,
parallelism,
Default::default()
)),
Default::default(),
);
solve(
sol.rb_mut(),
lu.rb(),
conj_lhs,
row_perm.rb(),
col_perm.rb(),
rhs,
parallelism,
make_stack!(solve_req::<usize, E>(n, n, k, parallelism)),
);
// Multiply the original matrix (conjugated according to `conj_lhs`) by the solution; the
// product should reproduce `rhs`.
let mut rhs_reconstructed = Mat::zeros(n, k);
let mut rhs_reconstructed = rhs_reconstructed.as_mut();
matmul_with_conj(
rhs_reconstructed.rb_mut(),
a,
conj_lhs,
sol.rb(),
Conj::No,
None,
E::faer_one(),
parallelism,
);
// Entry-wise comparison against the original right-hand side.
for j in 0..k {
for i in 0..n {
assert!(
(rhs_reconstructed.read(i, j).faer_sub(rhs.read(i, j))).faer_abs()
< epsilon
)
}
}
}
});
}
// Checks `solve_transpose` for element type `E`. Same structure as `test_solve`, except that the
// reconstruction multiplies by `a.transpose()`.
fn test_solve_transpose<E: ComplexField>(mut gen: impl FnMut() -> E, epsilon: E::Real) {
// Sizes 0..=31, then 32, 64, ..., 480.
(0..32).chain((1..16).map(|i| i * 32)).for_each(|n| {
for conj_lhs in [Conj::No, Conj::Yes] {
// `a` keeps the original matrix for the final check; `lu` is a copy that is factorized
// in place.
let a = Mat::from_fn(n, n, |_, _| gen());
let mut lu = a.clone();
let a = a.as_ref();
let mut lu = lu.as_mut();
// Number of right-hand side columns.
let k = 32;
let rhs = Mat::from_fn(n, k, |_, _| gen());
let rhs = rhs.as_ref();
let mut sol = Mat::<E>::zeros(n, k);
let mut sol = sol.as_mut();
// Storage for the forward and inverse permutations filled in by `lu_in_place`.
let mut row_perm = vec![0_usize; n];
let mut row_perm_inv = vec![0_usize; n];
let mut col_perm = vec![0_usize; n];
let mut col_perm_inv = vec![0_usize; n];
// NOTE(review): `Rayon(0)` presumably lets rayon pick the thread count — confirm.
let parallelism = Parallelism::Rayon(0);
// The first returned value is discarded; only the two permutations are needed.
let (_, row_perm, col_perm) = lu_in_place(
lu.rb_mut(),
&mut row_perm,
&mut row_perm_inv,
&mut col_perm,
&mut col_perm_inv,
parallelism,
make_stack!(lu_in_place_req::<usize, E>(
n,
n,
parallelism,
Default::default()
)),
Default::default(),
);
solve_transpose(
sol.rb_mut(),
lu.rb(),
conj_lhs,
row_perm.rb(),
col_perm.rb(),
rhs,
parallelism,
make_stack!(solve_transpose_req::<usize, E>(n, n, k, parallelism)),
);
// Multiply the transposed original matrix (conjugated according to `conj_lhs`) by the
// solution; the product should reproduce `rhs`.
let mut rhs_reconstructed = Mat::zeros(n, k);
let mut rhs_reconstructed = rhs_reconstructed.as_mut();
matmul_with_conj(
rhs_reconstructed.rb_mut(),
a.transpose(),
conj_lhs,
sol.rb(),
Conj::No,
None,
E::faer_one(),
parallelism,
);
// Entry-wise comparison against the original right-hand side.
for j in 0..k {
for i in 0..n {
assert!(
(rhs_reconstructed.read(i, j).faer_sub(rhs.read(i, j))).faer_abs()
< epsilon
)
}
}
}
});
}
use rand::prelude::*;
// One generator per thread, always seeded with 0, so each test thread draws the same sequence
// on every run. Because all generators below share it, the data a test sees depends on the
// order in which values are drawn.
thread_local! {
static RNG: RefCell<StdRng> = RefCell::new(StdRng::seed_from_u64(0));
}
// Draws the next `f64` from the thread-local generator.
fn random_f64() -> f64 {
RNG.with(|rng| {
let mut rng = rng.borrow_mut();
let rng = &mut *rng;
rng.gen()
})
}
// Draws the next `f32` from the thread-local generator.
fn random_f32() -> f32 {
RNG.with(|rng| {
let mut rng = rng.borrow_mut();
let rng = &mut *rng;
rng.gen()
})
}
// Builds a `c64` from two consecutive `random_f64` draws (real part first).
fn random_c64() -> c64 {
c64 {
re: random_f64(),
im: random_f64(),
}
}
// Builds a `c32` from two consecutive `random_f32` draws (real part first).
fn random_c32() -> c32 {
c32 {
re: random_f32(),
im: random_f32(),
}
}
// Each test below runs the plain and the transposed solver for one scalar type. The double
// precision types use a tolerance of 1e-6, the single precision types 1e-1.
#[test]
fn test_solve_f64() {
test_solve(random_f64, 1e-6_f64);
test_solve_transpose(random_f64, 1e-6_f64);
}
#[test]
fn test_solve_f32() {
test_solve(random_f32, 1e-1_f32);
test_solve_transpose(random_f32, 1e-1_f32);
}
#[test]
fn test_solve_c64() {
test_solve(random_c64, 1e-6_f64);
test_solve_transpose(random_c64, 1e-6_f64);
}
#[test]
fn test_solve_c32() {
test_solve(random_c32, 1e-1_f32);
test_solve_transpose(random_c32, 1e-1_f32);
}
}