use dyn_stack::{PodStack, SizeOverflow, StackReq};
use faer_core::{
assert,
mul::triangular,
permutation::{Index, PermutationRef, SignedIndex},
temp_mat_req, temp_mat_uninit, ComplexField, Entity, MatMut, MatRef, Parallelism,
};
use reborrow::*;
use triangular::BlockStructure;
/// Shared implementation behind [`reconstruct`] and [`reconstruct_in_place`].
///
/// Multiplies the packed `L` (unit lower triangular) and `U` (upper triangular) factors
/// into a temporary `m × n` matrix taken from `stack`, then copies that product into
/// `dst` while looking up source rows and columns through the second array returned by
/// `row_perm.into_arrays()` / `col_perm.into_arrays()`.
///
/// When `lu_factors` is `None`, the factors are read from `dst` itself and then
/// overwritten by the result.
///
/// # Panics
///
/// Panics if the inverse permutation arrays do not have lengths `m` and `n`.
#[track_caller]
fn reconstruct_impl<I: Index, E: ComplexField>(
    mut dst: MatMut<'_, E>,
    lu_factors: Option<MatRef<'_, E>>,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // Without a separate factor matrix, the factors live in `dst`.
    let lu_factors = if let Some(factors) = lu_factors {
        factors
    } else {
        dst.rb()
    };

    let m = lu_factors.nrows();
    let n = lu_factors.ncols();
    let size = m.min(n);

    // Scratch matrix that receives the un-permuted product `L * U`.
    let (mut product_storage, _) = temp_mat_uninit::<E>(m, n, stack);
    let mut product = product_storage.as_mut();

    // The leading `size × size` block holds both the top of `L` (strictly below the
    // diagonal, with an implicit unit diagonal) and the left part of `U` (on and above it).
    let (packed_top_left, u_right, l_bot, _) = lu_factors.split_at(size, size);
    let l_top = packed_top_left;
    let u_left = packed_top_left;

    // Since `size == min(m, n)`, the bottom-right block of the product is empty, so the
    // three products below cover every entry of the scratch matrix.
    let (product_top_left, product_top_right, product_bot_left, _) =
        product.rb_mut().split_at_mut(size, size);

    // Top-left block: L_top * U_left.
    // NOTE(review): `None` is taken to mean "do not accumulate into the destination";
    // confirm against the `triangular::matmul` documentation.
    triangular::matmul(
        product_top_left,
        BlockStructure::Rectangular,
        l_top,
        BlockStructure::UnitTriangularLower,
        u_left,
        BlockStructure::TriangularUpper,
        None,
        E::faer_one(),
        parallelism,
    );
    // Top-right block: L_top * U_right.
    triangular::matmul(
        product_top_right,
        BlockStructure::Rectangular,
        l_top,
        BlockStructure::UnitTriangularLower,
        u_right,
        BlockStructure::Rectangular,
        None,
        E::faer_one(),
        parallelism,
    );
    // Bottom-left block: L_bot * U_left.
    triangular::matmul(
        product_bot_left,
        BlockStructure::Rectangular,
        l_bot,
        BlockStructure::Rectangular,
        u_left,
        BlockStructure::TriangularUpper,
        None,
        E::faer_one(),
        parallelism,
    );

    let row_inv = row_perm.into_arrays().1;
    let col_inv = col_perm.into_arrays().1;
    assert!(row_inv.len() == m);
    assert!(col_inv.len() == n);

    // Choose the loop nest whose inner loop steps along the `dst` dimension with the
    // smaller absolute stride.
    let walk_down_columns = dst.row_stride().unsigned_abs() <= dst.col_stride().unsigned_abs();

    if walk_down_columns {
        for (j, col_entry) in col_inv.iter().enumerate() {
            let src_col = col_entry.to_signed().zx();
            for (i, row_entry) in row_inv.iter().enumerate() {
                let src_row = row_entry.to_signed().zx();
                // SAFETY: `i < m` and `j < n` because the slices have exactly those
                // lengths (asserted above), and both callers supply a `dst` with the
                // same shape as the factors. `src_row`/`src_col` come from the
                // permutation arrays, which are assumed to hold valid indices into an
                // `m × n` matrix.
                unsafe { dst.write_unchecked(i, j, product.read_unchecked(src_row, src_col)) };
            }
        }
    } else {
        for (i, row_entry) in row_inv.iter().enumerate() {
            let src_row = row_entry.to_signed().zx();
            for (j, col_entry) in col_inv.iter().enumerate() {
                let src_col = col_entry.to_signed().zx();
                // SAFETY: same argument as in the branch above.
                unsafe { dst.write_unchecked(i, j, product.read_unchecked(src_row, src_col)) };
            }
        }
    }
}
/// Computes the product of the packed LU factors in `lu_factors`, applies the row and
/// column permutations, and stores the resulting matrix in `dst`.
///
/// Temporary storage is taken from `stack`; [`reconstruct_req`] reports how much is
/// needed.
///
/// # Panics
///
/// - Panics if `dst` and `lu_factors` do not have the same dimensions.
/// - Panics if `row_perm` / `col_perm` do not have lengths equal to the number of rows /
///   columns of `lu_factors`.
#[track_caller]
pub fn reconstruct<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    lu_factors: MatRef<'_, E>,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    // Validate shapes up front; the shared implementation relies on them.
    assert!((dst.nrows(), dst.ncols()) == (lu_factors.nrows(), lu_factors.ncols()));
    assert!((row_perm.len(), col_perm.len()) == (lu_factors.nrows(), lu_factors.ncols()));

    let factors = Some(lu_factors);
    reconstruct_impl(dst, factors, row_perm, col_perm, parallelism, stack);
}
/// Computes the product of the packed LU factors stored in `lu_factors`, applies the row
/// and column permutations, and overwrites `lu_factors` with the resulting matrix.
///
/// Temporary storage is taken from `stack`; [`reconstruct_in_place_req`] reports how
/// much is needed.
///
/// # Panics
///
/// Panics if `row_perm` / `col_perm` do not have lengths equal to the number of rows /
/// columns of `lu_factors`.
#[track_caller]
pub fn reconstruct_in_place<I: Index, E: ComplexField>(
    lu_factors: MatMut<'_, E>,
    row_perm: PermutationRef<'_, I, E>,
    col_perm: PermutationRef<'_, I, E>,
    parallelism: Parallelism,
    stack: PodStack<'_>,
) {
    assert!((row_perm.len(), col_perm.len()) == (lu_factors.nrows(), lu_factors.ncols()));

    // Passing `None` makes the shared implementation read the factors from the
    // destination matrix itself.
    reconstruct_impl(
        lu_factors,
        None,
        row_perm,
        col_perm,
        parallelism,
        stack,
    );
}
/// Returns the workspace requirement for [`reconstruct_in_place`] on an
/// `nrows × ncols` matrix with entries of type `E`.
///
/// The requirement is that of a single temporary `nrows × ncols` matrix.
///
/// # Errors
///
/// Forwards the `SizeOverflow` error reported by `temp_mat_req`.
pub fn reconstruct_in_place_req<I: Index, E: Entity>(
    nrows: usize,
    ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Accepted but unused: the scratch size is the same for every parallelism setting.
    let _ = parallelism;

    temp_mat_req::<E>(nrows, ncols)
}
/// Returns the workspace requirement for [`reconstruct`] on an `nrows × ncols` matrix
/// with entries of type `E`.
///
/// This is the same requirement as [`reconstruct_in_place_req`], to which it delegates.
///
/// # Errors
///
/// Forwards any `SizeOverflow` error from [`reconstruct_in_place_req`].
pub fn reconstruct_req<I: Index, E: Entity>(
    nrows: usize,
    ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Both entry points share one implementation, so they need identical scratch space.
    reconstruct_in_place_req::<I, E>(
        nrows,
        ncols,
        parallelism,
    )
}