use crate::sparse::compensated::{CompensatedField, CompensatedSum};
use crate::twosum::TwoSum;
use faer::prelude::Solve;
use faer::{Mat, MatRef};
use faer_traits::ext::ComplexFieldExt;
use num_traits::{Float, One, Zero};
/// Computes the dense matrix product `lhs * rhs`.
///
/// Entry `(row, col)` of the result is the sum over `k` of
/// `lhs[(row, k)] * rhs[(k, col)]`, accumulated with a [`CompensatedSum`]
/// instead of a plain running total. The result has `lhs.nrows()` rows and
/// `rhs.ncols()` columns.
///
/// # Panics
///
/// In debug builds, panics if `lhs.ncols() != rhs.nrows()`.
pub(crate) fn dense_mul<T>(lhs: MatRef<'_, T>, rhs: MatRef<'_, T>) -> Mat<T>
where
    T: CompensatedField,
    T::Real: Float,
{
    // The loop below only walks `lhs.ncols()` terms, so a `rhs` with extra
    // rows would otherwise be truncated silently instead of being reported.
    debug_assert_eq!(
        lhs.ncols(),
        rhs.nrows(),
        "dense_mul: lhs.ncols() must equal rhs.nrows()"
    );
    let inner = lhs.ncols();
    Mat::from_fn(lhs.nrows(), rhs.ncols(), |row, col| {
        let mut acc = CompensatedSum::<T>::default();
        for k in 0..inner {
            acc.add(lhs[(row, k)] * rhs[(k, col)]);
        }
        acc.finish()
    })
}
/// Computes `lhs * rhs^H`, the product of `lhs` with the conjugate transpose
/// of `rhs`, without materialising the adjoint.
///
/// Entry `(row, col)` of the result is the sum over `k` of
/// `lhs[(row, k)] * rhs[(col, k)].conj()`, accumulated with a
/// [`CompensatedSum`]. The result has `lhs.nrows()` rows and `rhs.nrows()`
/// columns.
///
/// # Panics
///
/// In debug builds, panics if `lhs.ncols() != rhs.ncols()`.
pub(crate) fn dense_mul_adjoint_rhs<T>(lhs: MatRef<'_, T>, rhs: MatRef<'_, T>) -> Mat<T>
where
    T: CompensatedField,
    T::Real: Float,
{
    // Both operands are indexed by `k` along their columns, so their column
    // counts must agree; otherwise a wider `rhs` would be truncated silently.
    debug_assert_eq!(
        lhs.ncols(),
        rhs.ncols(),
        "dense_mul_adjoint_rhs: lhs.ncols() must equal rhs.ncols()"
    );
    let inner = lhs.ncols();
    Mat::from_fn(lhs.nrows(), rhs.nrows(), |row, col| {
        let mut acc = CompensatedSum::<T>::default();
        for k in 0..inner {
            acc.add(lhs[(row, k)] * rhs[(col, k)].conj());
        }
        acc.finish()
    })
}
/// Computes `lhs^H * rhs`, the product of the conjugate transpose of `lhs`
/// with `rhs`, without materialising the adjoint.
///
/// Entry `(row, col)` of the result is the sum over `k` of
/// `lhs[(k, row)].conj() * rhs[(k, col)]`, accumulated with a
/// [`CompensatedSum`]. The result has `lhs.ncols()` rows and `rhs.ncols()`
/// columns.
///
/// # Panics
///
/// In debug builds, panics if `lhs.nrows() != rhs.nrows()`.
pub(crate) fn dense_mul_adjoint_lhs<T>(lhs: MatRef<'_, T>, rhs: MatRef<'_, T>) -> Mat<T>
where
    T: CompensatedField,
    T::Real: Float,
{
    // Both operands are indexed by `k` along their rows, so their row counts
    // must agree; otherwise a taller `rhs` would be truncated silently.
    debug_assert_eq!(
        lhs.nrows(),
        rhs.nrows(),
        "dense_mul_adjoint_lhs: lhs.nrows() must equal rhs.nrows()"
    );
    let inner = lhs.nrows();
    Mat::from_fn(lhs.ncols(), rhs.ncols(), |row, col| {
        let mut acc = CompensatedSum::<T>::default();
        for k in 0..inner {
            acc.add(lhs[(k, row)].conj() * rhs[(k, col)]);
        }
        acc.finish()
    })
}
/// Computes the dense matrix product `lhs * rhs` with ordinary accumulation.
///
/// Entry `(row, col)` of the result is the sum over `k` of
/// `lhs[(row, k)] * rhs[(k, col)]`, added up with `+=` starting from
/// `T::zero()`. The result has `lhs.nrows()` rows and `rhs.ncols()` columns.
///
/// # Panics
///
/// In debug builds, panics if `lhs.ncols() != rhs.nrows()`.
pub(crate) fn dense_mul_plain<T>(lhs: MatRef<'_, T>, rhs: MatRef<'_, T>) -> Mat<T>
where
    T: faer_traits::ComplexField + Copy,
{
    // The loop below only walks `lhs.ncols()` terms, so a `rhs` with extra
    // rows would otherwise be truncated silently instead of being reported.
    debug_assert_eq!(
        lhs.ncols(),
        rhs.nrows(),
        "dense_mul_plain: lhs.ncols() must equal rhs.nrows()"
    );
    let inner = lhs.ncols();
    Mat::from_fn(lhs.nrows(), rhs.ncols(), |row, col| {
        let mut acc = T::zero();
        for k in 0..inner {
            acc += lhs[(row, k)] * rhs[(k, col)];
        }
        acc
    })
}
/// Returns the real part of the inner product of two column vectors.
///
/// Sums `lhs[(row, 0)].conj() * rhs[(row, 0)]` over every row of `lhs` with a
/// [`CompensatedSum`] and returns the real part of the total. Only column `0`
/// of each argument is read.
///
/// # Panics
///
/// In debug builds, panics if `lhs.nrows() != rhs.nrows()`.
pub(crate) fn inner_product_real<T>(lhs: MatRef<'_, T>, rhs: MatRef<'_, T>) -> T::Real
where
    T: CompensatedField,
    T::Real: Float,
{
    // The loop is driven by `lhs.nrows()` alone, so a longer `rhs` would
    // otherwise have its trailing entries ignored without any diagnostic.
    debug_assert_eq!(
        lhs.nrows(),
        rhs.nrows(),
        "inner_product_real: lhs.nrows() must equal rhs.nrows()"
    );
    let mut acc = CompensatedSum::<T>::default();
    for row in 0..lhs.nrows() {
        acc.add(lhs[(row, 0)].conj() * rhs[(row, 0)]);
    }
    acc.finish().real()
}
/// Returns the Euclidean norm of the first column of `vector`.
///
/// The squared magnitudes `abs2()` of the entries in column `0` are added up
/// in row order with a plain running total, and the square root of that total
/// is returned. Columns other than `0` are never read.
pub(crate) fn column_vector_norm<T>(vector: MatRef<'_, T>) -> T::Real
where
    T: CompensatedField,
    T::Real: Float,
{
    let sum_of_squares = (0..vector.nrows()).fold(<T::Real as Zero>::zero(), |mut total, row| {
        total += vector[(row, 0)].abs2();
        total
    });
    sum_of_squares.sqrt()
}
/// Overwrites `matrix` with its Hermitian part, `(M + M^H) / 2`.
///
/// For every index pair `(i, j)` with `i <= j < matrix.ncols()`, the mean of
/// `matrix[(i, j)]` and the conjugate of `matrix[(j, i)]` is stored at
/// `(i, j)`, and the conjugate of that mean is stored at `(j, i)`.
///
/// NOTE(review): both `(i, j)` and `(j, i)` are indexed for every column
/// index `j`, so this presumably expects a square matrix; confirm at callers.
pub(crate) fn hermitian_project_in_place<T>(matrix: &mut Mat<T>)
where
    T: CompensatedField,
    T::Real: Float,
{
    let unit = <T::Real as One>::one();
    let two = unit + unit;
    let half = unit / two;
    let n = matrix.ncols();
    for j in 0..n {
        // `i` runs up to and including `j`, so the diagonal is processed too;
        // there the second store lands on the same entry as the first.
        for i in 0..=j {
            let upper = matrix[(i, j)];
            let lower = matrix[(j, i)];
            let mean = (upper + lower.conj()).mul_real(half);
            matrix[(i, j)] = mean;
            matrix[(j, i)] = mean.conj();
        }
    }
}
/// Returns the Frobenius norm of `matrix` using ordinary accumulation.
///
/// The squared magnitudes `abs2()` of all entries are added to a plain running
/// total, visiting the matrix column by column, and the square root of the
/// total is returned. An empty matrix yields zero.
pub(crate) fn frobenius_norm_plain<T>(matrix: MatRef<'_, T>) -> T::Real
where
    T: faer_traits::ComplexField + Copy,
    T::Real: Float,
{
    let (rows, cols) = (matrix.nrows(), matrix.ncols());
    // Column-major traversal: every row of column 0, then column 1, and so on.
    let entries = (0..cols).flat_map(|c| (0..rows).map(move |r| (r, c)));
    let mut total = <T::Real as Zero>::zero();
    for idx in entries {
        total += matrix[idx].abs2();
    }
    total.sqrt()
}
/// Returns the Frobenius norm of `matrix`, accumulating through [`TwoSum`].
///
/// The squared magnitudes `abs2()` of all entries are visited column by
/// column. The first one seeds a `TwoSum` and each later one is fed to its
/// `add`. The `(sum, residual)` pair returned by `finish` is added together
/// and its square root is returned. A matrix with no entries yields zero.
pub(crate) fn frobenius_norm<T>(matrix: MatRef<'_, T>) -> T::Real
where
    T: faer_traits::ComplexField + Copy,
    T::Real: Float,
{
    let (rows, cols) = (matrix.nrows(), matrix.ncols());
    // Stays `None` until the first entry is seen, because the accumulator is
    // constructed from an initial value.
    let mut running: Option<TwoSum<T::Real>> = None;
    for c in 0..cols {
        for r in 0..rows {
            let square = matrix[(r, c)].abs2();
            if let Some(partial) = running.as_mut() {
                partial.add(square);
            } else {
                running = Some(TwoSum::new(square));
            }
        }
    }
    running.map_or_else(<T::Real as Zero>::zero, |total| {
        let (sum, residual) = total.finish();
        (sum + residual).sqrt()
    })
}
/// Returns the default tolerance for the checked solves: the square root of
/// the machine epsilon of `T::Real`.
pub(crate) fn default_solve_tolerance<T>() -> T::Real
where
    T: CompensatedField,
    T::Real: Float,
{
    let machine_eps = <T::Real as Float>::epsilon();
    machine_eps.sqrt()
}
/// Solves `lhs * X = rhs` for `X` and verifies the result before returning it.
///
/// The system is solved with a full-pivoting LU factorisation of `lhs`. The
/// candidate solution is then accepted only if
///
/// 1. every entry of the solution is finite, and
/// 2. the Frobenius norm of the residual `lhs * X - rhs` is finite and no
///    larger than `2 * tol * max(1, ‖lhs‖ * ‖X‖ + ‖rhs‖)`, where all norms are
///    Frobenius norms.
///
/// # Errors
///
/// Returns `Err(err())` when either acceptance condition fails. This includes
/// the case where the threshold itself is NaN (for example because `tol` is
/// NaN): a solution that cannot be shown to meet the tolerance is rejected.
pub(crate) fn solve_left_checked<T, E, F>(
    lhs: MatRef<'_, T>,
    rhs: MatRef<'_, T>,
    tol: T::Real,
    err: F,
) -> Result<Mat<T>, E>
where
    T: faer_traits::ComplexField + Copy,
    T::Real: Float,
    F: Fn() -> E,
{
    let solution = lhs.full_piv_lu().solve(rhs);
    if !solution.as_ref().is_all_finite() {
        return Err(err());
    }
    let residual = dense_mul_plain(lhs, solution.as_ref()) - rhs;
    let residual_norm = frobenius_norm_plain(residual.as_ref());
    let scale = frobenius_norm_plain(lhs) * frobenius_norm_plain(solution.as_ref())
        + frobenius_norm_plain(rhs);
    let one = <T::Real as One>::one();
    // Clamping the scale to at least one keeps the tolerance absolute for
    // small problems and relative for large ones.
    let threshold = scale.max(one) * tol * (one + one);
    // Phrased as "accept only if `<=` holds" rather than "reject if `>`":
    // every comparison against NaN is false, so a NaN threshold must land on
    // the rejecting side instead of silently passing the check.
    let within_tolerance = residual_norm <= threshold;
    if !residual_norm.is_finite() || !within_tolerance {
        return Err(err());
    }
    Ok(solution)
}
/// Solves `X * lhs_right = rhs_left` for `X` and verifies the result.
///
/// The right-hand problem is reduced to a left-hand one by transposition:
/// `solve_left_checked` is called on `lhs_right.transpose()` and
/// `rhs_left.transpose()`, and the transpose of its solution is returned as an
/// owned matrix. `tol` and `err` are forwarded unchanged.
///
/// # Errors
///
/// Returns whatever error `solve_left_checked` produces for the transposed
/// system.
pub(crate) fn solve_right_checked<T, E, F>(
    rhs_left: MatRef<'_, T>,
    lhs_right: MatRef<'_, T>,
    tol: T::Real,
    err: F,
) -> Result<Mat<T>, E>
where
    T: faer_traits::ComplexField + Copy,
    T::Real: Float,
    // `err` is moved into `solve_left_checked` exactly once, so no `Copy`
    // bound is needed; the bounds now match those of `solve_left_checked`.
    F: Fn() -> E,
{
    let solved_t = solve_left_checked(lhs_right.transpose(), rhs_left.transpose(), tol, err)?;
    Ok(solved_t.transpose().to_owned())
}