use super::CholeskyError;
use crate::{
assert, debug_assert,
linalg::{
cholesky::ldlt_diagonal::compute::new_cholesky,
entity::{self, *},
matmul::triangular::BlockStructure,
triangular_solve,
},
utils::DivCeil,
ComplexField, Entity, MatMut, Parallelism,
};
use core::marker::PhantomData;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use pulp::Simd;
use reborrow::*;
/// Base case of the recursive LLT factorization: factors a small square block
/// in place by delegating to `new_cholesky`.
///
/// `offset` is the index of this block's first row/column inside the full
/// matrix being factored. On failure it is added to
/// `non_positive_definite_minor`, so the reported index refers to the full
/// matrix rather than to this sub-block.
///
/// Returns the dynamic-regularization count reported by `new_cholesky`.
fn cholesky_in_place_left_looking_impl<E: ComplexField>(
    offset: usize,
    matrix: MatMut<'_, E>,
    regularization: LltRegularization<E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
    params: LltParams,
) -> Result<usize, CholeskyError> {
    // Neither knob is consumed by the unblocked kernel.
    let _ = (parallelism, params);
    // Pass the regularization by reference. The original text had the
    // mis-encoded `®ularization` here, which did not compile.
    new_cholesky(matrix, &regularization, stack).map_err(|mut e| {
        // Translate the block-local pivot index into a global one.
        e.non_positive_definite_minor += offset;
        e
    })
}
/// Tuning parameters for the LLT (Cholesky) factorization.
///
/// This currently has no fields. It is `#[non_exhaustive]` so fields can be
/// added later without breaking downstream code; build it with
/// `LltParams::default()`.
#[derive(Default, Copy, Clone, Debug)]
#[non_exhaustive]
pub struct LltParams {}
/// Dynamic-regularization settings forwarded to the base-case LLT kernel.
///
/// Both fields default to zero. NOTE(review): the exact meaning of `delta`
/// and `epsilon` is defined by `new_cholesky`, which is not visible here.
/// Presumably `epsilon` is a pivot threshold and `delta` the replacement
/// value; confirm against that kernel.
#[derive(Copy, Clone, Debug)]
pub struct LltRegularization<E: ComplexField> {
    /// Regularization "delta" value (zero by default).
    pub dynamic_regularization_delta: E::Real,
    /// Regularization "epsilon" value (zero by default).
    pub dynamic_regularization_epsilon: E::Real,
}

impl<E: ComplexField> Default for LltRegularization<E> {
    /// Returns settings with both `delta` and `epsilon` set to zero.
    fn default() -> Self {
        let delta = E::Real::faer_zero();
        let epsilon = E::Real::faer_zero();
        Self {
            dynamic_regularization_delta: delta,
            dynamic_regularization_epsilon: epsilon,
        }
    }
}
/// Computes the scratch-space requirement for factoring a `dim × dim` matrix
/// with `cholesky_in_place`.
///
/// The recursive driver only hands blocks of at most 64 rows to its base-case
/// kernel, so the workspace is sized for `min(dim, 64)`. It consists of one
/// temporary square matrix of that size plus a buffer of that many `E` values.
///
/// # Errors
/// Returns `SizeOverflow` if the requirement cannot be represented.
pub fn cholesky_in_place_req<E: Entity>(
    dim: usize,
    parallelism: Parallelism,
    params: LltParams,
) -> Result<StackReq, SizeOverflow> {
    // Neither knob influences the workspace size today.
    _ = (parallelism, params);
    let block_dim = dim.min(64);
    let temp_matrix = crate::linalg::temp_mat_req::<E>(block_dim, block_dim)?;
    let temp_vector = StackReq::try_new::<E>(block_dim)?;
    temp_matrix.try_and(temp_vector)
}
/// Recursive blocked LLT factorization of a square matrix, in place.
///
/// Blocks that are small enough go straight to the base-case kernel. A block
/// is small enough when it has at most 64 rows and spans at most four SIMD
/// registers per column. Larger blocks are split in half:
///
/// 1. Factor the leading block recursively.
/// 2. Solve for the block below it.
/// 3. Update the trailing block (lower triangle only).
/// 4. Recurse on the trailing block.
///
/// `offset` is this block's position in the full matrix. `count` accumulates
/// the dynamic-regularization counts reported by the base cases.
fn cholesky_in_place_impl<E: ComplexField>(
    offset: usize,
    count: &mut usize,
    matrix: MatMut<'_, E>,
    regularization: LltRegularization<E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
    params: LltParams,
) -> Result<(), CholeskyError> {
    debug_assert!(matrix.nrows() == matrix.ncols());

    // Reports how many `E::Unit` scalars fit in one SIMD register for the
    // instruction set selected at runtime.
    struct SimdLanes<E> {
        __marker: PhantomData<E>,
    }
    impl<E: ComplexField> pulp::WithSimd for SimdLanes<E> {
        type Output = usize;
        fn with_simd<S: Simd>(self, _: S) -> Self::Output {
            let register_bytes = core::mem::size_of::<entity::SimdUnitFor<E, S>>();
            let scalar_bytes = core::mem::size_of::<E::Unit>();
            register_bytes / scalar_bytes
        }
    }

    let dim = matrix.nrows();
    let lanes = E::Simd::default().dispatch(SimdLanes {
        __marker: PhantomData::<E>,
    });
    // Number of SIMD registers needed to hold one column of this block.
    let registers_per_column = dim.msrv_div_ceil(lanes);

    // Base case: small blocks go straight to the unblocked kernel.
    if registers_per_column <= 4 && dim <= 64 {
        let regularized = cholesky_in_place_left_looking_impl(
            offset,
            matrix,
            regularization,
            parallelism,
            stack,
            params,
        )?;
        *count += regularized;
        return Ok(());
    }

    let mut matrix = matrix;
    let mut stack = stack;
    let half = dim / 2;
    // The upper-right quadrant is never referenced.
    let (mut top_left, _, mut bottom_left, mut bottom_right) =
        matrix.rb_mut().split_at_mut(half, half);

    // Step 1: A00 = L00 L00^H.
    cholesky_in_place_impl(
        offset,
        count,
        top_left.rb_mut(),
        regularization,
        parallelism,
        stack.rb_mut(),
        params,
    )?;
    let top_left = top_left.into_const();

    // Step 2: L10 = A10 L00^{-H}, solved as conj(L00) * L10^T = A10^T.
    triangular_solve::solve_lower_triangular_in_place(
        top_left.conjugate(),
        bottom_left.rb_mut().transpose_mut(),
        parallelism,
    );

    // Step 3: Schur complement A11 <- A11 - L10 L10^H, lower triangle only.
    crate::linalg::matmul::triangular::matmul(
        bottom_right.rb_mut(),
        BlockStructure::TriangularLower,
        bottom_left.rb(),
        BlockStructure::Rectangular,
        bottom_left.rb().adjoint(),
        BlockStructure::Rectangular,
        Some(E::faer_one()),
        E::faer_one().faer_neg(),
        parallelism,
    );

    // Step 4: factor the trailing block, shifting the reported offset.
    cholesky_in_place_impl(
        offset + half,
        count,
        bottom_right,
        regularization,
        parallelism,
        stack,
        params,
    )
}
/// Summary returned by a successful `cholesky_in_place` call.
#[derive(Clone, Copy, Debug)]
pub struct LltInfo {
    /// Total dynamic-regularization count reported by the base-case kernels
    /// during the factorization.
    pub dynamic_regularization_count: usize,
}
/// Computes the LLT (Cholesky) factorization of a square matrix, in place.
///
/// The work is delegated to a recursive blocked driver. That driver never
/// references the upper-right quadrant of any split block, so presumably only
/// the lower triangle is read and written. Confirm against the base-case
/// kernel.
///
/// `stack` is expected to provide at least `cholesky_in_place_req` bytes of
/// scratch space. `params` is currently forwarded but unused.
///
/// # Panics
/// Panics if `matrix` is not square.
///
/// # Errors
/// Returns `CholeskyError` from the base-case kernel. Its
/// `non_positive_definite_minor` is expressed relative to the full matrix.
#[track_caller]
#[inline]
pub fn cholesky_in_place<E: ComplexField>(
    matrix: MatMut<'_, E>,
    regularization: LltRegularization<E>,
    parallelism: Parallelism,
    stack: &mut PodStack,
    params: LltParams,
) -> Result<LltInfo, CholeskyError> {
    assert!(matrix.ncols() == matrix.nrows());

    // The kernels are tuned for unit row stride (column-major storage).
    #[cfg(feature = "perf-warn")]
    if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) {
        if matrix.col_stride().unsigned_abs() == 1 {
            log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found row-major matrix.");
        } else {
            log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found matrix with generic strides.");
        }
    }

    let mut dynamic_regularization_count = 0usize;
    cholesky_in_place_impl(
        0,
        &mut dynamic_regularization_count,
        matrix,
        regularization,
        parallelism,
        stack,
        params,
    )?;
    Ok(LltInfo {
        dynamic_regularization_count,
    })
}