use crate::{
assert, linalg::triangular_solve as solve, unzipped, zipped_rw, ComplexField, Conj, Entity,
MatMut, MatRef, Parallelism,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;
/// Returns the scratch-memory requirement of [`solve_in_place_with_conj`].
///
/// That solve works directly on its right-hand side and never draws memory from its `stack`.
/// The requirement is therefore empty, whatever the dimension, number of right-hand-side
/// columns or parallelism.
pub fn solve_in_place_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // The arguments are part of the uniform `*_req` API but do not affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(StackReq::default())
}
/// Returns the scratch-memory requirement of [`solve_with_conj`].
///
/// That function copies the right-hand side into its destination and solves there in place.
/// It needs no scratch memory, so the requirement is empty.
pub fn solve_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // The arguments are part of the uniform `*_req` API but do not affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(StackReq::default())
}
/// Returns the scratch-memory requirement of [`solve_transpose_in_place_with_conj`].
///
/// The transposed solve forwards to the in-place solve, which never uses its `stack`.
/// The requirement is therefore empty.
pub fn solve_transpose_in_place_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // The arguments are part of the uniform `*_req` API but do not affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(StackReq::default())
}
/// Returns the scratch-memory requirement of [`solve_transpose_with_conj`].
///
/// That function copies the right-hand side into its destination and runs the transposed solve
/// there in place. It needs no scratch memory, so the requirement is empty.
pub fn solve_transpose_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // The arguments are part of the uniform `*_req` API but do not affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(StackReq::default())
}
/// Solves `A X = rhs` in place, where `A = L D Lᴴ` is described by `cholesky_factors`.
///
/// If `conj_lhs` is `Conj::Yes`, the system solved is `conj(A) X = rhs` instead.
///
/// `cholesky_factors` stores the unit lower-triangular factor `L` below its diagonal and the
/// diagonal factor `D` on its diagonal.
/// NOTE(review): assumes the unit triangular solvers read only the strict triangle they are
/// given, so the upper part of `cholesky_factors` is never read — confirm.
///
/// The solve runs in three stages, all on `rhs`:
/// 1. forward substitution with the unit lower-triangular `L`,
/// 2. scaling of row `i` by `1 / D_ii`,
/// 3. backward substitution with `Lᴴ`.
///
/// `stack` is not used; no scratch memory is required.
///
/// # Panics
/// Panics if `cholesky_factors` is not square, or if `rhs.nrows()` differs from its dimension.
#[track_caller]
pub fn solve_in_place_with_conj<E: ComplexField>(
    cholesky_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    rhs: MatMut<'_, E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
) {
    let n = cholesky_factors.nrows();
    let k = rhs.ncols();
    // Every stage works directly on `rhs`; the scratch stack is intentionally unused.
    let _ = &stack;

    assert!(all(
        cholesky_factors.nrows() == cholesky_factors.ncols(),
        rhs.nrows() == n,
    ));

    let mut rhs = rhs;

    // Stage 1: rhs <- L⁻¹ rhs (using conj(L) when `conj_lhs` is `Yes`).
    solve::solve_unit_lower_triangular_in_place_with_conj(
        cholesky_factors,
        conj_lhs,
        rhs.rb_mut(),
        parallelism,
    );

    // Stage 2: rhs <- D⁻¹ rhs.
    // The diagonal of an LDLᴴ factorization is real, so conj(D) == D and `conj_lhs` has no
    // effect here. Reading only the real part keeps that consistent for both conjugation
    // modes: a stray imaginary part is ignored rather than being used unconjugated when
    // `conj_lhs` is `Yes`. It also lets each entry be scaled by a real reciprocal instead of
    // going through a full complex inversion and multiplication.
    for j in 0..k {
        for i in 0..n {
            // SAFETY: the assertion above guarantees
            // `i < n == cholesky_factors.nrows() == cholesky_factors.ncols()`.
            let d_inv = unsafe { cholesky_factors.read_unchecked(i, i) }
                .faer_real()
                .faer_inv();
            // SAFETY: `i < n == rhs.nrows()` (asserted above) and `j < k == rhs.ncols()`.
            let rhs_elem = unsafe { rhs.read_unchecked(i, j) };
            // SAFETY: same in-bounds indices as the read just above.
            unsafe { rhs.write_unchecked(i, j, rhs_elem.faer_scale_real(d_inv)) };
        }
    }

    // Stage 3: rhs <- L⁻ᴴ rhs. `Lᴴ` is the transposed view of `L` with conjugation toggled,
    // hence the composition of `conj_lhs` with `Conj::Yes`.
    solve::solve_unit_upper_triangular_in_place_with_conj(
        cholesky_factors.transpose(),
        conj_lhs.compose(Conj::Yes),
        rhs.rb_mut(),
        parallelism,
    );
}
/// Solves `Aᵀ X = rhs` in place, where `A = L D Lᴴ` is described by `cholesky_factors`.
///
/// If `conj_lhs` is `Conj::Yes`, the system solved is `conj(A)ᵀ X = rhs` instead.
///
/// With a real diagonal, `(L D Lᴴ)ᵀ = conj(L) D Lᵀ = conj(L D Lᴴ)`. A transposed solve is
/// therefore an ordinary solve with the conjugation flag toggled.
///
/// # Panics
/// Panics under the same conditions as [`solve_in_place_with_conj`].
#[track_caller]
pub fn solve_transpose_in_place_with_conj<E: ComplexField>(
    cholesky_factors: MatRef<'_, E>,
    conj_lhs: Conj,
    rhs: MatMut<'_, E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
) {
    // Composing with `Conj::Yes` flips the flag: No -> Yes, Yes -> No.
    let toggled = conj_lhs.compose(Conj::Yes);
    solve_in_place_with_conj(cholesky_factors, toggled, rhs, parallelism, stack)
}
#[track_caller]
pub fn solve_transpose_with_conj<E: ComplexField>(
dst: MatMut<'_, E>,
cholesky_factors: MatRef<'_, E>,
conj_lhs: Conj,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: &mut PodStack,
) {
let mut dst = dst;
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_transpose_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
}
#[track_caller]
pub fn solve_with_conj<E: ComplexField>(
dst: MatMut<'_, E>,
cholesky_factors: MatRef<'_, E>,
conj_lhs: Conj,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: &mut PodStack,
) {
let mut dst = dst;
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
}