use crate::{
assert, linalg::triangular_solve as solve, unzipped, zipped_rw, ComplexField, Conj, Entity,
MatMut, MatRef, Parallelism,
};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;
/// Returns the scratch-space requirement of [`solve_in_place_with_conj`].
///
/// The in-place solve never uses its workspace argument, so this always reports
/// `StackReq::default()`, whatever the arguments are.
///
/// # Arguments
/// - `cholesky_dimension`: order of the Cholesky factor (not used).
/// - `rhs_ncols`: number of right-hand-side columns (not used).
/// - `parallelism`: requested parallelism (not used).
pub fn solve_in_place_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // The requirement is independent of every input; bind them once to keep the
    // signature uniform with other `*_req` queries without triggering lints.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(Default::default())
}
/// Returns the scratch-space requirement of [`solve_with_conj`].
///
/// That routine copies the right-hand side into its output and then solves in place.
/// Neither step uses the workspace, so the requirement is always `StackReq::default()`.
///
/// # Arguments
/// - `cholesky_dimension`: order of the Cholesky factor (not used).
/// - `rhs_ncols`: number of right-hand-side columns (not used).
/// - `parallelism`: requested parallelism (not used).
pub fn solve_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Inputs are accepted only for API uniformity; none affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(Default::default())
}
/// Returns the scratch-space requirement of [`solve_transpose_in_place_with_conj`].
///
/// The transposed in-place solve forwards to the non-transposed one, which never uses
/// its workspace, so this always reports `StackReq::default()`.
///
/// # Arguments
/// - `cholesky_dimension`: order of the Cholesky factor (not used).
/// - `rhs_ncols`: number of right-hand-side columns (not used).
/// - `parallelism`: requested parallelism (not used).
pub fn solve_transpose_in_place_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Inputs are accepted only for API uniformity; none affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(Default::default())
}
/// Returns the scratch-space requirement of [`solve_transpose_with_conj`].
///
/// The out-of-place transposed solve is a copy followed by an in-place solve. Neither
/// step uses the workspace, so this always reports `StackReq::default()`.
///
/// # Arguments
/// - `cholesky_dimension`: order of the Cholesky factor (not used).
/// - `rhs_ncols`: number of right-hand-side columns (not used).
/// - `parallelism`: requested parallelism (not used).
pub fn solve_transpose_req<E: Entity>(
    cholesky_dimension: usize,
    rhs_ncols: usize,
    parallelism: Parallelism,
) -> Result<StackReq, SizeOverflow> {
    // Inputs are accepted only for API uniformity; none affect the result.
    let _ = (cholesky_dimension, rhs_ncols, parallelism);
    Ok(Default::default())
}
/// Solves `A X = B` in place, where `A = L L^H` is given by its lower Cholesky factor `L`.
///
/// On entry `rhs` holds `B`; on exit it holds `X`. When `conj_lhs` is [`Conj::Yes`], the
/// system solved is `conj(A) X = B` instead.
///
/// The solve is done as two triangular sweeps:
/// 1. forward substitution with `L` (conjugated when `conj_lhs` is `Conj::Yes`);
/// 2. backward substitution with `L^T`, using the opposite conjugation.
///    That is `L^H` for `Conj::No` and `L^T` for `Conj::Yes`.
///
/// `stack` is accepted for API uniformity but is not used.
///
/// # Panics
/// - if `cholesky_factor` is not square;
/// - if `rhs.nrows()` differs from the order of `cholesky_factor`.
#[track_caller]
pub fn solve_in_place_with_conj<E: ComplexField>(
    cholesky_factor: MatRef<'_, E>,
    conj_lhs: Conj,
    rhs: MatMut<'_, E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
) {
    // No workspace is needed by either triangular sweep.
    let _ = stack;
    let n = cholesky_factor.nrows();
    assert!(all(
        cholesky_factor.nrows() == cholesky_factor.ncols(),
        rhs.nrows() == n,
    ));

    let lower = cholesky_factor;
    let upper = cholesky_factor.transpose();
    // L^H == conj(L^T): flip the caller's conjugation for the backward sweep.
    let conj_upper = conj_lhs.compose(Conj::Yes);

    let mut sol = rhs;
    // Forward substitution: L y = b.
    solve::solve_lower_triangular_in_place_with_conj(lower, conj_lhs, sol.rb_mut(), parallelism);
    // Backward substitution: L^H x = y.
    solve::solve_upper_triangular_in_place_with_conj(upper, conj_upper, sol.rb_mut(), parallelism);
}
#[track_caller]
pub fn solve_with_conj<E: ComplexField>(
dst: MatMut<'_, E>,
cholesky_factor: MatRef<'_, E>,
conj_lhs: Conj,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: &mut PodStack,
) {
let mut dst = dst;
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
}
/// Solves `A^T X = B` in place, where `A = L L^H` is given by its Cholesky factor `L`.
///
/// On entry `rhs` holds `B`; on exit it holds `X`. When `conj_lhs` is [`Conj::Yes`], the
/// system solved is `A^H X = B`, which is the same as `A X = B`.
///
/// `A` is Hermitian, so `A^T == conj(A)`. This therefore forwards to
/// [`solve_in_place_with_conj`] with the conjugation flag flipped.
///
/// # Panics
/// Same conditions as [`solve_in_place_with_conj`].
#[track_caller]
pub fn solve_transpose_in_place_with_conj<E: ComplexField>(
    cholesky_factor: MatRef<'_, E>,
    conj_lhs: Conj,
    rhs: MatMut<'_, E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
) {
    // Transposing a Hermitian matrix is equivalent to conjugating it.
    let flipped = conj_lhs.compose(Conj::Yes);
    solve_in_place_with_conj(cholesky_factor, flipped, rhs, parallelism, stack);
}
#[track_caller]
pub fn solve_transpose_with_conj<E: ComplexField>(
dst: MatMut<'_, E>,
cholesky_factor: MatRef<'_, E>,
conj_lhs: Conj,
rhs: MatRef<'_, E>,
parallelism: Parallelism,
stack: &mut PodStack,
) {
let mut dst = dst;
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
solve_transpose_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
}