use dyn_stack::{MemBuffer, MemStack};
use faer::diag::{Diag, DiagRef};
use faer::linalg::solvers::{self, Solve};
pub use faer::linalg::solvers::{
Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
};
use faer::linalg::svd::{self, ComputeSvdVectors};
use faer::prelude::ReborrowMut;
use faer::{Conj, Mat, MatMut, MatRef, Par, Side, Unbind, get_global_parallelism};
use ndarray::{Array1, Array2, ArrayBase, ArrayViewMut1, Data, Ix1, Ix2};
use std::marker::PhantomData;
use std::panic::{AssertUnwindSafe, catch_unwind};
use thiserror::Error;
/// Default `rank_alpha` multiplier for rank-revealing QR tolerances.
///
/// NOTE(review): presumably the value passed as `rank_alpha` to `rrqr_nullspace_basis`,
/// which builds its tolerance as `rank_alpha * f64::EPSILON * max(nrows, ncols) * ...`;
/// the call site is outside this view — confirm.
const RRQR_RANK_ALPHA: f64 = 100.0;
/// Errors produced by the faer-backed linear-algebra helpers in this module.
#[derive(Debug, Error)]
pub enum FaerLinalgError {
    /// Generic factorization failure; `context` names the failing step
    /// (for example the eigensolver panic boundary or a non-square input).
    #[error("Factorization failed in {context}")]
    FactorizationFailed { context: &'static str },
    /// faer's SVD returned an error; every SVD error is mapped to this variant.
    #[error("SVD failed to converge in {context}")]
    SvdNoConvergence { context: &'static str },
    /// Non-finite values were found in an eigendecomposition input, or in the
    /// output of a repaired eigendecomposition attempt.
    #[error("Self-adjoint eigendecomposition input contains non-finite values in {context}")]
    SelfAdjointEigenNonFiniteInput { context: &'static str },
    /// Error returned by faer's self-adjoint eigensolver.
    #[error("Self-adjoint eigendecomposition failed: {0:?}")]
    SelfAdjointEigen(solvers::EvdError),
    /// Error returned by faer's Cholesky (L Lᵀ) factorization.
    #[error("Cholesky factorization failed: {0:?}")]
    Cholesky(solvers::LltError),
    /// Error returned by faer's L D Lᵀ factorization. `factorize_symmetricwith_fallback`
    /// also reports it when its final L B Lᵀ attempt panics.
    #[error("LDLT factorization failed: {0:?}")]
    Ldlt(solvers::LdltError),
}
/// Symmetric factorization selected by `factorize_symmetricwith_fallback`.
pub enum FaerSymmetricFactor {
    /// Cholesky `L Lᵀ` (tried first).
    Llt(FaerLlt<f64>),
    /// `L D Lᵀ`, used when Cholesky fails.
    Ldlt(FaerLdlt<f64>),
    /// faer's `Lblt` (L B Lᵀ), used when `L D Lᵀ` also fails.
    Lblt(FaerLblt<f64>),
}
/// Log-determinant of `A` from its Cholesky factor `L` (`A = L Lᵀ`):
/// `log det A = 2 · Σᵢ ln Lᵢᵢ`.
///
/// Non-positive diagonal entries propagate as `-inf`/NaN through `f64::ln`.
#[inline]
pub(crate) fn cholesky_factor_logdet(factor: MatRef<'_, f64>) -> f64 {
    let half_logdet = diagonal_log_sum(factor.diagonal());
    2.0 * half_logdet
}
/// Sum of the natural logarithms of every entry of `diagonal`.
#[inline]
fn diagonal_log_sum(diagonal: DiagRef<'_, f64>) -> f64 {
    let entries = diagonal.column_vector();
    entries.iter().map(|entry| entry.ln()).sum::<f64>()
}
impl FaerSymmetricFactor {
#[inline]
pub fn solve(&self, rhs: MatRef<'_, f64>) -> Mat<f64> {
match self {
FaerSymmetricFactor::Llt(f) => f.solve(rhs),
FaerSymmetricFactor::Ldlt(f) => f.solve(rhs),
FaerSymmetricFactor::Lblt(f) => f.solve(rhs),
}
}
#[inline]
pub fn solve_in_place(&self, rhs: MatMut<'_, f64>) {
match self {
FaerSymmetricFactor::Llt(f) => f.solve_in_place(rhs),
FaerSymmetricFactor::Ldlt(f) => f.solve_in_place(rhs),
FaerSymmetricFactor::Lblt(f) => f.solve_in_place(rhs),
}
}
}
impl crate::matrix::FactorizedSystem for FaerSymmetricFactor {
    /// Solves `A x = rhs` for a single right-hand side.
    ///
    /// Returns an error string if the solution contains NaN or infinity.
    fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
        let mut solution = rhs.clone();
        self.solve_in_place(array1_to_col_matmut(&mut solution));
        if solution.iter().any(|value| !value.is_finite()) {
            return Err("symmetric factor solve produced non-finite values".to_string());
        }
        Ok(solution)
    }
    /// Solves `A X = rhs` for every column of `rhs`.
    ///
    /// Returns an error string if the solution contains NaN or infinity.
    fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
        // Fresh row-major copy of the right-hand sides; the solve overwrites it.
        let mut solution = Array2::<f64>::zeros(rhs.raw_dim());
        solution.assign(rhs);
        self.solve_in_place(array2_to_matmut(&mut solution));
        if solution.iter().any(|value| !value.is_finite()) {
            return Err("symmetric factor multi-solve produced non-finite values".to_string());
        }
        Ok(solution)
    }
    /// Log-determinant of the factored matrix.
    ///
    /// * `Llt`: `2 · Σ ln Lᵢᵢ`.
    /// * `Ldlt`: `Σ ln Dᵢ`, which is NaN if any pivot `Dᵢ` is negative.
    /// * `Lblt`: not computed; always NaN.
    fn logdet(&self) -> f64 {
        match self {
            Self::Llt(cholesky) => cholesky_factor_logdet(cholesky.L()),
            Self::Ldlt(ldlt) => diagonal_log_sum(ldlt.D()),
            Self::Lblt(_) => f64::NAN,
        }
    }
}
/// Factorizes a symmetric matrix, trying progressively more robust methods.
///
/// The order is Cholesky (`L Lᵀ`), then `L D Lᵀ`, then faer's `L B Lᵀ`. Only the
/// triangle selected by `side` is forwarded to faer. The `L B Lᵀ` constructor runs
/// behind a panic boundary.
///
/// # Errors
/// Returns [`FaerLinalgError::Ldlt`] carrying the `L D Lᵀ` error when the final
/// `L B Lᵀ` attempt panics.
#[inline]
pub fn factorize_symmetricwith_fallback(
    matrix: MatRef<'_, f64>,
    side: Side,
) -> Result<FaerSymmetricFactor, FaerLinalgError> {
    FaerLlt::new(matrix, side)
        .map(FaerSymmetricFactor::Llt)
        .or_else(|_| FaerLdlt::new(matrix, side).map(FaerSymmetricFactor::Ldlt))
        .or_else(|ldlt_error| {
            catch_unwind(AssertUnwindSafe(|| FaerLblt::new(matrix, side)))
                .map(FaerSymmetricFactor::Lblt)
                .map_err(|_| FaerLinalgError::Ldlt(ldlt_error))
        })
}
/// Decides whether an `m x k` by `k x n` product is large enough to dispatch to faer
/// rather than ndarray's `dot`.
///
/// At least one dimension must be 32 or more, and the saturating product
/// `m * n * k` must reach `64 * 64`.
#[inline]
const fn should_use_faer_matmul(m: usize, n: usize, k: usize) -> bool {
    const MIN_DIM: usize = 32;
    const MIN_FLOP_SCALE: usize = 64 * 64;
    let has_large_dim = m >= MIN_DIM || n >= MIN_DIM || k >= MIN_DIM;
    if !has_large_dim {
        return false;
    }
    let work = m.saturating_mul(n).saturating_mul(k);
    work >= MIN_FLOP_SCALE
}
#[inline]
pub(crate) fn matmul_parallelism(m: usize, n: usize, k: usize) -> Par {
const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
const PAR_MIN_LONG_DIM: usize = 256;
let flop_scale = m.saturating_mul(n).saturating_mul(k);
let long_dim = m.max(n).max(k);
if flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM {
get_global_parallelism()
} else {
Par::Seq
}
}
/// Borrows an owned ndarray matrix as a mutable faer matrix over the same memory.
///
/// Nothing is copied. The pointer, shape and both strides are forwarded as-is, so
/// row-major, column-major and strided layouts all map one-to-one.
#[inline]
pub fn array2_to_matmut(array: &mut Array2<f64>) -> MatMut<'_, f64> {
    let (nrows, ncols) = array.dim();
    let (row_stride, col_stride) = {
        let strides = array.strides();
        (strides[0], strides[1])
    };
    let data = array.as_mut_ptr();
    // SAFETY: `data`, the shape and the strides all describe `array`'s own elements,
    // and the returned view holds the unique `&mut` borrow of `array` for its entire
    // lifetime, so no other access can alias it.
    unsafe { MatMut::from_raw_parts_mut(data, nrows, ncols, row_stride, col_stride) }
}
/// Borrows an owned ndarray vector as a mutable faer `len x 1` column matrix over
/// the same memory. The vector's stride becomes the row stride.
#[inline]
pub fn array1_to_col_matmut(array: &mut Array1<f64>) -> MatMut<'_, f64> {
    let nrows = array.len();
    let row_stride = array.strides()[0];
    let data = array.as_mut_ptr();
    // There is a single column, so the column stride is never stepped; pass 0.
    // SAFETY: `data`, `nrows` and `row_stride` describe exactly `array`'s elements,
    // and the returned view keeps the unique `&mut` borrow of `array` alive.
    unsafe { MatMut::from_raw_parts_mut(data, nrows, 1, row_stride, 0) }
}
/// Returns the Gram matrix `Aᵀ A` (`p x p` for an `n x p` input).
///
/// Allocates the output and delegates to [`fast_ata_into`].
#[inline]
pub fn fast_ata<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>) -> Array2<f64> {
    let gram_dim = a.ncols();
    let mut gram = Array2::<f64>::zeros((gram_dim, gram_dim));
    fast_ata_into(a, &mut gram);
    gram
}
/// Writes the Gram matrix `Aᵀ A` into `out`, overwriting its contents.
///
/// Small products use ndarray's `dot`. Larger ones compute only the lower triangle
/// with faer's triangular matmul, then copy it into the upper triangle.
///
/// # Panics
/// Panics if `out` is not `p x p`, where `p = a.ncols()`.
#[inline]
pub fn fast_ata_into<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>, out: &mut Array2<f64>) {
    use faer::Accum;
    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
    let (nrows, ncols) = a.dim();
    assert_eq!(out.nrows(), ncols, "output rows must match p");
    assert_eq!(out.ncols(), ncols, "output cols must match p");
    if should_use_faer_matmul(ncols, ncols, nrows) {
        let a_view = FaerArrayView::new(a);
        let a_mat = a_view.as_ref();
        let par = matmul_parallelism(ncols, ncols, nrows);
        {
            let mut dst = array2_to_matmut(out);
            tri_matmul(
                dst.as_mut(),
                BlockStructure::TriangularLower,
                Accum::Replace,
                a_mat.transpose(),
                BlockStructure::Rectangular,
                a_mat,
                BlockStructure::Rectangular,
                1.0,
                par,
            );
        }
        // Only the lower triangle was written; mirror it across the diagonal.
        for col in 1..ncols {
            for row in 0..col {
                out[[row, col]] = out[[col, row]];
            }
        }
    } else {
        out.assign(&a.t().dot(a));
    }
}
/// Computes `Aᵀ B` for `A: n x p` and `B: n x q`.
///
/// Tries the GPU backend first. Otherwise it runs on the CPU with parallelism
/// chosen by `matmul_parallelism`.
#[inline]
pub fn fast_atb<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
) -> Array2<f64> {
    match crate::gpu::linalg::try_fast_atb(a.view(), b.view()) {
        Some(gpu_result) => gpu_result,
        None => {
            let par = matmul_parallelism(a.ncols(), b.ncols(), a.nrows());
            fast_atb_with_parallelism(a, b, par)
        }
    }
}
/// CPU `Aᵀ B` with an explicit faer parallelism setting.
///
/// Small products fall back to ndarray's `dot`, which ignores `par`.
///
/// # Panics
/// Panics if `A` and `B` have different row counts.
#[inline]
pub fn fast_atb_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    par: Par,
) -> Array2<f64> {
    use faer::linalg::matmul::matmul;
    use faer::{Accum, Mat};
    let (shared_a, p) = a.dim();
    let (shared_b, q) = b.dim();
    assert_eq!(shared_a, shared_b, "A and B must have same number of rows");
    if !should_use_faer_matmul(p, q, shared_a) {
        return a.t().dot(b);
    }
    let a_view = FaerArrayView::new(a);
    let b_view = FaerArrayView::new(b);
    let mut product = Mat::<f64>::zeros(p, q);
    matmul(
        product.as_mut(),
        Accum::Replace,
        a_view.as_ref().transpose(),
        b_view.as_ref(),
        1.0,
        par,
    );
    mat_to_array(product.as_ref())
}
/// Computes `A Bᵀ` for `A: m x k` and `B: n x k`.
///
/// Small products use ndarray's `dot`; larger ones go through faer.
///
/// # Panics
/// Panics if `A` and `B` have different column counts.
#[inline]
pub fn fast_abt<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
) -> Array2<f64> {
    use faer::linalg::matmul::matmul;
    use faer::{Accum, Mat};
    let (rows, inner_a) = a.dim();
    let (cols, inner_b) = b.dim();
    assert_eq!(
        inner_a, inner_b,
        "A and B must have same number of columns for A·Bᵀ"
    );
    if !should_use_faer_matmul(rows, cols, inner_a) {
        return a.dot(&b.t());
    }
    let a_view = FaerArrayView::new(a);
    let b_view = FaerArrayView::new(b);
    let mut product = Mat::<f64>::zeros(rows, cols);
    matmul(
        product.as_mut(),
        Accum::Replace,
        a_view.as_ref(),
        b_view.as_ref().transpose(),
        1.0,
        matmul_parallelism(rows, cols, inner_a),
    );
    mat_to_array(product.as_ref())
}
/// Computes `A B`, trying the GPU backend first and otherwise using [`fast_ab_into`].
#[inline]
pub fn fast_ab<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
) -> Array2<f64> {
    crate::gpu::linalg::try_fast_ab(a.view(), b.view()).unwrap_or_else(|| {
        let mut product = Array2::<f64>::zeros((a.nrows(), b.ncols()));
        fast_ab_into(a, b, &mut product);
        product
    })
}
/// Computes the matrix-vector product `A v`, trying the GPU backend first.
#[inline]
pub fn fast_av<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
) -> Array1<f64> {
    crate::gpu::linalg::try_fast_av(a.view(), v.view()).unwrap_or_else(|| fast_av_impl(a, v))
}
/// CPU implementation of `A v`: ndarray `dot` for small problems, faer otherwise.
///
/// # Panics
/// Panics if `v.len() != a.ncols()`.
#[inline]
fn fast_av_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
) -> Array1<f64> {
    use faer::linalg::matmul::matmul;
    use faer::{Accum, Mat};
    let (rows, cols) = a.dim();
    assert_eq!(cols, v.len(), "A cols must match v length");
    if !should_use_faer_matmul(rows, 1, cols) {
        return a.dot(v);
    }
    let a_view = FaerArrayView::new(a);
    let v_view = FaerColView::new(v);
    let mut column = Mat::<f64>::zeros(rows, 1);
    matmul(
        column.as_mut(),
        Accum::Replace,
        a_view.as_ref(),
        v_view.as_ref(),
        1.0,
        matmul_parallelism(rows, 1, cols),
    );
    (0..rows).map(|row| column[(row, 0)]).collect()
}
/// Writes `A v` into `out`, overwriting its contents.
#[inline]
pub fn fast_av_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: &mut Array1<f64>,
) {
    fast_av_into_impl(a, v, out);
}
/// CPU implementation of [`fast_av_into`].
///
/// # Panics
/// Panics if `v.len() != a.ncols()` or `out.len() != a.nrows()`.
#[inline]
fn fast_av_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: &mut Array1<f64>,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    let (rows, cols) = a.dim();
    assert_eq!(v.len(), cols, "vector length must match A cols");
    assert_eq!(out.len(), rows, "output length must match A rows");
    if should_use_faer_matmul(rows, 1, cols) {
        let a_view = FaerArrayView::new(a);
        let v_view = FaerColView::new(v);
        let mut dst = array1_to_col_matmut(out);
        matmul(
            dst.as_mut(),
            Accum::Replace,
            a_view.as_ref(),
            v_view.as_ref(),
            1.0,
            matmul_parallelism(rows, 1, cols),
        );
    } else {
        out.assign(&a.dot(v));
    }
}
/// Writes `A v` into the mutable view `out`, overwriting its contents.
#[inline]
pub fn fast_av_view_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: ArrayViewMut1<'_, f64>,
) {
    fast_av_view_into_impl(a, v, out);
}
/// CPU implementation of [`fast_av_view_into`].
///
/// # Panics
/// Panics if `v.len() != a.ncols()` or `out.len() != a.nrows()`.
#[inline]
fn fast_av_view_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    mut out: ArrayViewMut1<'_, f64>,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    let (rows, cols) = a.dim();
    assert_eq!(v.len(), cols, "vector length must match A cols");
    assert_eq!(out.len(), rows, "output length must match A rows");
    if !should_use_faer_matmul(rows, 1, cols) {
        out.assign(&a.dot(v));
        return;
    }
    let (len, stride) = (out.len(), out.strides()[0]);
    let data = out.as_mut_ptr();
    // SAFETY: `data`, `len` and `stride` describe exactly the elements of `out`, which
    // is exclusively borrowed for this whole function. The single column's stride is
    // never stepped, so 0 is passed.
    let dst = unsafe { MatMut::from_raw_parts_mut(data, len, 1, stride, 0) };
    let a_view = FaerArrayView::new(a);
    let v_view = FaerColView::new(v);
    matmul(
        dst,
        Accum::Replace,
        a_view.as_ref(),
        v_view.as_ref(),
        1.0,
        matmul_parallelism(rows, 1, cols),
    );
}
/// Computes `Aᵀ v`, trying the GPU backend first.
#[inline]
pub fn fast_atv<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
) -> Array1<f64> {
    crate::gpu::linalg::try_fast_atv(a.view(), v.view()).unwrap_or_else(|| fast_atv_impl(a, v))
}
/// CPU implementation of `Aᵀ v`: ndarray `dot` for small problems, faer otherwise.
///
/// # Panics
/// Panics if `v.len() != a.nrows()`.
#[inline]
fn fast_atv_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
) -> Array1<f64> {
    use faer::linalg::matmul::matmul;
    use faer::{Accum, Mat};
    let (rows, cols) = a.dim();
    assert_eq!(rows, v.len(), "A rows must match v length");
    if !should_use_faer_matmul(cols, 1, rows) {
        return a.t().dot(v);
    }
    let a_view = FaerArrayView::new(a);
    let v_view = FaerColView::new(v);
    let mut column = Mat::<f64>::zeros(cols, 1);
    matmul(
        column.as_mut(),
        Accum::Replace,
        a_view.as_ref().transpose(),
        v_view.as_ref(),
        1.0,
        matmul_parallelism(cols, 1, rows),
    );
    (0..cols).map(|row| column[(row, 0)]).collect()
}
/// Writes `Aᵀ v` into `out`, overwriting its contents.
#[inline]
pub fn fast_atv_into<S: Data<Elem = f64>>(
    a: &ArrayBase<S, Ix2>,
    v: &Array1<f64>,
    out: &mut Array1<f64>,
) {
    fast_atv_into_impl(a, v, out);
}
/// CPU implementation of [`fast_atv_into`].
///
/// # Panics
/// Panics if `v.len() != a.nrows()` or `out.len() != a.ncols()`.
#[inline]
fn fast_atv_into_impl<S: Data<Elem = f64>>(
    a: &ArrayBase<S, Ix2>,
    v: &Array1<f64>,
    out: &mut Array1<f64>,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    let (rows, cols) = a.dim();
    assert_eq!(v.len(), rows, "vector length must match A rows");
    assert_eq!(out.len(), cols, "output length must match A cols");
    if should_use_faer_matmul(cols, 1, rows) {
        let a_view = FaerArrayView::new(a);
        let v_view = FaerColView::new(v);
        let mut dst = array1_to_col_matmut(out);
        matmul(
            dst.as_mut(),
            Accum::Replace,
            a_view.as_ref().transpose(),
            v_view.as_ref(),
            1.0,
            matmul_parallelism(cols, 1, rows),
        );
    } else {
        out.assign(&a.t().dot(v));
    }
}
/// Computes `Xᵀ · diag(w) · X`, trying the GPU backend first.
///
/// # Panics
/// Panics if `x.nrows() != w.len()`.
#[inline]
pub fn fast_xt_diag_x<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
) -> Array2<f64> {
    assert_eq!(
        x.nrows(),
        w.len(),
        "fast_xt_diag_x row/weight length mismatch"
    );
    crate::gpu::linalg::try_fast_xt_diag_x(x.view(), w.view()).unwrap_or_else(|| {
        let (nrows, ncols) = x.dim();
        fast_xt_diag_x_with_parallelism(x, w, matmul_parallelism(ncols, ncols, nrows))
    })
}
/// CPU `Xᵀ · diag(w) · X` with an explicit faer parallelism setting.
///
/// # Panics
/// Panics if `x.nrows() != w.len()`.
#[inline]
pub fn fast_xt_diag_x_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    par: Par,
) -> Array2<f64> {
    let weights_len = w.len();
    assert_eq!(
        x.nrows(),
        weights_len,
        "fast_xt_diag_x_with_parallelism row/weight length mismatch"
    );
    fast_xt_diag_x_with_parallelism_impl(x, w, par)
}
/// CPU implementation of `Xᵀ · diag(w) · X`, returning a `p x p` matrix.
///
/// * Returns an all-zero `p x p` matrix when `X` has no rows or no columns.
/// * Small problems (per `should_use_faer_matmul`) are computed with ndarray `dot`.
/// * Otherwise rows are processed in chunks. Each chunk of `X` is scaled row-wise by
///   `w` into a reusable buffer, and faer's triangular matmul accumulates only the
///   lower triangle of `X_chunkᵀ · (W X)_chunk` into a column-major result. That
///   triangle is then mirrored into the upper triangle.
///
/// `par` is forwarded to every chunk's faer product; the ndarray fallback ignores it.
///
/// # Panics
/// Panics if `x.nrows() != w.len()`.
#[inline]
fn fast_xt_diag_x_with_parallelism_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    par: Par,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
    use ndarray::{ShapeBuilder, s};
    let (n, p) = x.dim();
    assert_eq!(n, w.len(), "X rows must match W length");
    if n == 0 || p == 0 {
        return Array2::<f64>::zeros((p, p));
    }
    if !should_use_faer_matmul(p, p, n) {
        // Small problem: materialize W·X and let ndarray do the product.
        let w_x = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        return x.t().dot(&w_x);
    }
    // Chunk height targets ~8 MiB for the weighted-row buffer (p f64 values per row),
    // clamped to [MIN_ROWS, MAX_ROWS] and never taller than n.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let chunk_rows = (TARGET_BYTES / (p.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);
    // Column-major accumulator; only its lower triangle is written by tri_matmul.
    let mut result = Array2::<f64>::zeros((p, p).f());
    // Reusable row-major scratch holding w[i] * X[i, :] for the current chunk.
    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p));
    let x_is_row_major = x.is_standard_layout();
    let w_slice_opt = w.as_slice();
    // Scope the mutable faer view so `result` can be indexed again for the mirror step.
    {
        let mut out_view = array2_to_matmut(&mut result);
        for start in (0..n).step_by(chunk_rows) {
            // The final chunk may be shorter than chunk_rows.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wx_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                // Fast path: X is row-major and both X and w are contiguous, so rows are
                // read straight out of the backing slices.
                if x_is_row_major && let (Some(x_all), Some(w_all)) = (x.as_slice(), w_slice_opt) {
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * p;
                        let dst_off = local * p;
                        let src_row = &x_all[src_off..src_off + p];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for col in 0..p {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    // General path: strided or non-contiguous inputs, indexed via ndarray.
                    let x_slice = x.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let xrow = x_slice.row(local);
                        let dst_off = local * p;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for (col, xij) in xrow.iter().enumerate() {
                            dst_row[col] = xij * wi;
                        }
                    }
                }
            }
            // Only the first `rows` rows of the scratch buffer are valid for this chunk.
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wx_slice = wx_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wx_view = FaerArrayView::new(&wx_slice);
            // result_lower += X_chunkᵀ · (W X)_chunk, lower triangle only.
            tri_matmul(
                out_view.as_mut(),
                BlockStructure::TriangularLower,
                Accum::Add,
                x_view.as_ref().transpose(),
                BlockStructure::Rectangular,
                wx_view.as_ref(),
                BlockStructure::Rectangular,
                1.0,
                par,
            );
        }
    }
    // Mirror the accumulated lower triangle into the upper triangle.
    for i in 0..p {
        for j in (i + 1)..p {
            result[[i, j]] = result[[j, i]];
        }
    }
    result
}
/// Computes `Xᵀ · diag(w) · Y`, trying the GPU backend first.
///
/// # Panics
/// Panics if `X`, `Y` and `w` do not share the same row count.
#[inline]
pub fn fast_xt_diag_y<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    y: &ArrayBase<S3, Ix2>,
) -> Array2<f64> {
    assert_eq!(x.nrows(), y.nrows(), "fast_xt_diag_y X/Y row mismatch");
    assert_eq!(
        y.nrows(),
        w.len(),
        "fast_xt_diag_y row/weight length mismatch"
    );
    crate::gpu::linalg::try_fast_xt_diag_y(x.view(), w.view(), y.view())
        .unwrap_or_else(|| fast_xt_diag_y_impl(x, w, y))
}
/// CPU implementation of `Xᵀ · diag(w) · Y`, returning a `px x q` matrix.
///
/// * Returns an all-zero `px x q` matrix when there are no rows or either side has
///   no columns.
/// * Small problems (per `should_use_faer_matmul`) are computed with ndarray `dot`.
/// * Otherwise rows are processed in chunks. Each chunk of `Y` is scaled row-wise by
///   `w` into a reusable buffer and accumulated with faer matmul (`Accum::Add`) into
///   a column-major result. Parallelism is chosen per chunk via `matmul_parallelism`.
///
/// # Panics
/// Panics if `w.len()` or `x.nrows()` differs from `y.nrows()`.
#[inline]
fn fast_xt_diag_y_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    y: &ArrayBase<S3, Ix2>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};
    let (n, q) = y.dim();
    let px = x.ncols();
    assert_eq!(n, w.len(), "Y rows must match W length");
    assert_eq!(n, x.nrows(), "X rows must match Y rows");
    if n == 0 || px == 0 || q == 0 {
        return Array2::<f64>::zeros((px, q));
    }
    if !should_use_faer_matmul(px, q, n) {
        // Small problem: materialize W·Y and let ndarray do the product.
        let w_y = Array2::from_shape_fn((n, q), |(i, j)| w[i] * y[[i, j]]);
        return x.t().dot(&w_y);
    }
    // Chunk height targets ~8 MiB across px + q columns of f64, clamped to
    // [MIN_ROWS, MAX_ROWS] and never taller than n. Only the q-column W·Y scratch
    // buffer is actually allocated; X chunks are borrowed as views.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let total_cols = px + q;
    let chunk_rows = (TARGET_BYTES / (total_cols.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);
    // Column-major accumulator for the full (non-symmetric) result.
    let mut result = Array2::<f64>::zeros((px, q).f());
    // Reusable row-major scratch holding w[i] * Y[i, :] for the current chunk.
    let mut wy_chunk = Array2::<f64>::zeros((chunk_rows, q));
    let y_is_row_major = y.is_standard_layout();
    let w_slice_opt = w.as_slice();
    // Scope the mutable faer view over `result` so it can be returned afterwards.
    {
        let mut out_view = array2_to_matmut(&mut result);
        for start in (0..n).step_by(chunk_rows) {
            // The final chunk may be shorter than chunk_rows.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wy_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                // Fast path: Y is row-major and both Y and w are contiguous slices.
                if y_is_row_major && let (Some(y_all), Some(w_all)) = (y.as_slice(), w_slice_opt) {
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * q;
                        let dst_off = local * q;
                        let src_row = &y_all[src_off..src_off + q];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for col in 0..q {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    // General path: strided or non-contiguous inputs, indexed via ndarray.
                    let y_slice = y.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let yrow = y_slice.row(local);
                        let dst_off = local * q;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for (col, yij) in yrow.iter().enumerate() {
                            dst_row[col] = yij * wi;
                        }
                    }
                }
            }
            // Only the first `rows` rows of the scratch buffer are valid for this chunk.
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wy_slice = wy_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wy_view = FaerArrayView::new(&wy_slice);
            let par = matmul_parallelism(px, q, rows);
            // result += X_chunkᵀ · (W Y)_chunk
            matmul(
                out_view.as_mut(),
                Accum::Add,
                x_view.as_ref().transpose(),
                wy_view.as_ref(),
                1.0,
                par,
            );
        }
    }
    result
}
/// Assembles the symmetric `(pa + pb) x (pa + pb)` block matrix
///
/// ```text
/// [ Xaᵀ·diag(w_aa)·Xa   Xaᵀ·diag(w_ab)·Xb ]
/// [ (upper-right)ᵀ      Xbᵀ·diag(w_bb)·Xb ]
/// ```
///
/// Tries the GPU backend first, then falls back to the CPU implementation.
///
/// # Panics
/// Panics if `x_a`, `x_b`, `w_aa`, `w_ab` and `w_bb` do not all share the same row
/// count. Shapes are validated before any backend is invoked, matching
/// `fast_xt_diag_x` / `fast_xt_diag_y`, so mismatched inputs never reach the GPU path.
pub fn fast_joint_hessian_2x2<
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
    S5: Data<Elem = f64>,
>(
    x_a: &ArrayBase<S1, Ix2>,
    x_b: &ArrayBase<S2, Ix2>,
    w_aa: &ArrayBase<S3, Ix1>,
    w_ab: &ArrayBase<S4, Ix1>,
    w_bb: &ArrayBase<S5, Ix1>,
) -> Array2<f64> {
    let n = x_a.nrows();
    assert_eq!(n, x_b.nrows(), "fast_joint_hessian_2x2 X_a/X_b row mismatch");
    assert_eq!(n, w_aa.len(), "fast_joint_hessian_2x2 row/w_aa length mismatch");
    assert_eq!(n, w_ab.len(), "fast_joint_hessian_2x2 row/w_ab length mismatch");
    assert_eq!(n, w_bb.len(), "fast_joint_hessian_2x2 row/w_bb length mismatch");
    if let Some(out) = crate::gpu::linalg::try_fast_joint_hessian_2x2(
        x_a.view(),
        x_b.view(),
        w_aa.view(),
        w_ab.view(),
        w_bb.view(),
    ) {
        return out;
    }
    fast_joint_hessian_2x2_impl(x_a, x_b, w_aa, w_ab, w_bb)
}
/// CPU implementation of `fast_joint_hessian_2x2`.
///
/// Builds the symmetric `(pa + pb) x (pa + pb)` matrix whose upper blocks are
/// `Xaᵀ·diag(w_aa)·Xa` (top-left), `Xaᵀ·diag(w_ab)·Xb` (top-right) and
/// `Xbᵀ·diag(w_bb)·Xb` (bottom-right). The lower triangle is then filled by
/// mirroring the upper triangle.
///
/// * Returns an all-zero matrix when there are no rows or no columns in total.
/// * Small problems (per `should_use_faer_matmul` on the larger block width) use ndarray.
/// * Otherwise rows are processed in chunks. The three weighted copies of the chunk
///   are built in reusable buffers, and each block is accumulated with faer matmul
///   (`Accum::Add`) into a column-major output.
///
/// # Panics
/// Panics if the inputs do not all share the same row count.
#[inline]
fn fast_joint_hessian_2x2_impl<
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
    S5: Data<Elem = f64>,
>(
    x_a: &ArrayBase<S1, Ix2>,
    x_b: &ArrayBase<S2, Ix2>,
    w_aa: &ArrayBase<S3, Ix1>,
    w_ab: &ArrayBase<S4, Ix1>,
    w_bb: &ArrayBase<S5, Ix1>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};
    let n = x_a.nrows();
    let pa = x_a.ncols();
    let pb = x_b.ncols();
    let total = pa + pb;
    assert_eq!(n, x_b.nrows());
    assert_eq!(n, w_aa.len());
    assert_eq!(n, w_ab.len());
    assert_eq!(n, w_bb.len());
    if n == 0 || total == 0 {
        return Array2::<f64>::zeros((total, total));
    }
    if !should_use_faer_matmul(pa.max(pb), pa.max(pb), n) {
        // Small problem: compute the three upper blocks with ndarray, then mirror.
        let waa_xa = Array2::from_shape_fn((n, pa), |(i, j)| w_aa[i] * x_a[[i, j]]);
        let wab_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_ab[i] * x_b[[i, j]]);
        let wbb_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_bb[i] * x_b[[i, j]]);
        let mut out = Array2::<f64>::zeros((total, total));
        out.slice_mut(s![..pa, ..pa]).assign(&x_a.t().dot(&waa_xa));
        out.slice_mut(s![..pa, pa..]).assign(&x_a.t().dot(&wab_xb));
        out.slice_mut(s![pa.., pa..]).assign(&x_b.t().dot(&wbb_xb));
        // Overwrite the strict lower triangle with the transpose of the upper one.
        for i in 0..total {
            for j in 0..i {
                out[[i, j]] = out[[j, i]];
            }
        }
        return out;
    }
    // Chunk height targets ~8 MiB across the pa + 2*pb columns of scratch buffers,
    // clamped to [MIN_ROWS, MAX_ROWS] and never taller than n.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let cols_needed = pa + 2 * pb;
    let chunk_rows = (TARGET_BYTES / (cols_needed.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);
    // Column-major output; each block is accumulated into its own submatrix.
    let mut out = Array2::<f64>::zeros((total, total).f());
    // Reusable row-major scratch: w_aa*Xa, w_ab*Xb and w_bb*Xb for the current chunk.
    let mut waa_xa_buf = Array2::<f64>::zeros((chunk_rows, pa));
    let mut wab_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
    let mut wbb_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
    let xa_is_row_major = x_a.is_standard_layout();
    let xb_is_row_major = x_b.is_standard_layout();
    let waa_slice_opt = w_aa.as_slice();
    let wab_slice_opt = w_ab.as_slice();
    let wbb_slice_opt = w_bb.as_slice();
    // Scope the mutable faer view so `out` can be indexed again for the mirror step.
    {
        let mut out_mat = array2_to_matmut(&mut out);
        for start in (0..n).step_by(chunk_rows) {
            // The final chunk may be shorter than chunk_rows.
            let rows = (n - start).min(chunk_rows);
            let xa_slice = x_a.slice(s![start..start + rows, ..]);
            let xb_slice = x_b.slice(s![start..start + rows, ..]);
            {
                let waa_chunk = waa_xa_buf
                    .as_slice_mut()
                    .expect("row-major waa chunk is contiguous");
                let wab_chunk = wab_xb_buf
                    .as_slice_mut()
                    .expect("row-major wab chunk is contiguous");
                let wbb_chunk = wbb_xb_buf
                    .as_slice_mut()
                    .expect("row-major wbb chunk is contiguous");
                // Fast path: both X matrices row-major and all five inputs contiguous,
                // so rows are read straight out of the backing slices.
                if xa_is_row_major
                    && xb_is_row_major
                    && let (Some(xa_all), Some(xb_all)) = (x_a.as_slice(), x_b.as_slice())
                    && let (Some(waa_all), Some(wab_all), Some(wbb_all)) =
                        (waa_slice_opt, wab_slice_opt, wbb_slice_opt)
                {
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = waa_all[i];
                        let wab_i = wab_all[i];
                        let wbb_i = wbb_all[i];
                        let xa_off = i * pa;
                        let xa_row = &xa_all[xa_off..xa_off + pa];
                        let xb_off = i * pb;
                        let xb_row = &xb_all[xb_off..xb_off + pb];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        for col in 0..pa {
                            waa_row[col] = xa_row[col] * waa_i;
                        }
                        // One pass over the Xb row fills both Xb-weighted buffers.
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        for col in 0..pb {
                            let xij = xb_row[col];
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                } else {
                    // General path: strided or non-contiguous inputs, indexed via ndarray.
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = w_aa[i];
                        let wab_i = w_ab[i];
                        let wbb_i = w_bb[i];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        let xa_row = xa_slice.row(local);
                        for (col, xij) in xa_row.iter().enumerate() {
                            waa_row[col] = xij * waa_i;
                        }
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        let xb_row = xb_slice.row(local);
                        for (col, xij) in xb_row.iter().enumerate() {
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                }
            }
            // Only the first `rows` rows of each scratch buffer are valid for this chunk.
            let xa_view = FaerArrayView::new(&xa_slice);
            let xb_view = FaerArrayView::new(&xb_slice);
            let waa_xa_slice = waa_xa_buf.slice(s![0..rows, ..]);
            let wab_xb_slice = wab_xb_buf.slice(s![0..rows, ..]);
            let wbb_xb_slice = wbb_xb_buf.slice(s![0..rows, ..]);
            let waa_xa_view = FaerArrayView::new(&waa_xa_slice);
            let wab_xb_view = FaerArrayView::new(&wab_xb_slice);
            let wbb_xb_view = FaerArrayView::new(&wbb_xb_slice);
            // Top-left block: += Xa_chunkᵀ · (w_aa Xa)_chunk
            matmul(
                out_mat.rb_mut().submatrix_mut(0, 0, pa, pa),
                Accum::Add,
                xa_view.as_ref().transpose(),
                waa_xa_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pa, rows),
            );
            // Top-right block: += Xa_chunkᵀ · (w_ab Xb)_chunk
            matmul(
                out_mat.rb_mut().submatrix_mut(0, pa, pa, pb),
                Accum::Add,
                xa_view.as_ref().transpose(),
                wab_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pb, rows),
            );
            // Bottom-right block: += Xb_chunkᵀ · (w_bb Xb)_chunk
            matmul(
                out_mat.rb_mut().submatrix_mut(pa, pa, pb, pb),
                Accum::Add,
                xb_view.as_ref().transpose(),
                wbb_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pb, pb, rows),
            );
        }
    }
    // Overwrite the strict lower triangle (including the bottom-left block) with the
    // transpose of the upper triangle.
    for i in 0..total {
        for j in 0..i {
            out[[i, j]] = out[[j, i]];
        }
    }
    out
}
/// Copies a faer matrix into a freshly allocated row-major ndarray matrix.
///
/// Empty matrices (zero rows or columns) yield an empty array of the same shape.
fn mat_to_array(mat: MatRef<'_, f64>) -> Array2<f64> {
    // `from_shape_fn` allocates a standard-layout array and fills it in row-major
    // order. This replaces the previous `zeros` + `as_slice_memory_order_mut` version,
    // whose non-contiguous fallback branch was unreachable for a fresh array.
    Array2::from_shape_fn((mat.nrows(), mat.ncols()), |(i, j)| mat[(i, j)])
}
/// Writes the matrix product `A B` into `out`, overwriting its contents.
#[inline]
pub fn fast_ab_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    out: &mut Array2<f64>,
) {
    fast_ab_into_impl(a, b, out);
}
/// CPU implementation of [`fast_ab_into`]: ndarray `dot` for small problems, faer otherwise.
///
/// # Panics
/// Panics if the inner dimensions differ or `out` is not `a.nrows() x b.ncols()`.
#[inline]
fn fast_ab_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    out: &mut Array2<f64>,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    let (rows, inner) = a.dim();
    let (inner_b, cols) = b.dim();
    assert_eq!(inner, inner_b, "A and B must have compatible inner dimensions");
    assert_eq!(out.dim(), (rows, cols), "output dimensions must match A*B result");
    if should_use_faer_matmul(rows, cols, inner) {
        let a_view = FaerArrayView::new(a);
        let b_view = FaerArrayView::new(b);
        let par = matmul_parallelism(rows, cols, inner);
        let mut dst = array2_to_matmut(out);
        matmul(
            dst.as_mut(),
            Accum::Replace,
            a_view.as_ref(),
            b_view.as_ref(),
            1.0,
            par,
        );
    } else {
        out.assign(&a.dot(b));
    }
}
fn diag_to_array(diag: DiagRef<'_, f64>) -> Array1<f64> {
let mat = diag.column_vector().as_mat();
let mut out = Array1::<f64>::zeros(mat.nrows());
for i in 0..mat.nrows() {
out[i] = mat[(i, 0)];
}
out
}
pub struct FaerArrayView<'a> {
ptr: *const f64,
rows: usize,
cols: usize,
row_stride: isize,
col_stride: isize,
owned: Option<Array2<f64>>,
marker: PhantomData<&'a f64>,
}
impl<'a> FaerArrayView<'a> {
#[inline]
pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix2>) -> Self {
let (rows, cols) = array.dim();
let strides = array.strides();
if strides[0] <= 0 || strides[1] <= 0 {
let owned = array.to_owned();
let owned_strides = owned.strides();
return Self {
ptr: owned.as_ptr(),
rows,
cols,
row_stride: owned_strides[0],
col_stride: owned_strides[1],
owned: Some(owned),
marker: PhantomData,
};
}
Self {
ptr: array.as_ptr(),
rows,
cols,
row_stride: strides[0],
col_stride: strides[1],
owned: None,
marker: PhantomData,
}
}
#[inline]
pub fn as_ref(&self) -> MatRef<'_, f64> {
let (ptr, rows, cols, row_stride, col_stride) = if let Some(owned) = &self.owned {
let strides = owned.strides();
(
owned.as_ptr(),
owned.nrows(),
owned.ncols(),
strides[0],
strides[1],
)
} else {
(
self.ptr,
self.rows,
self.cols,
self.row_stride,
self.col_stride,
)
};
unsafe { MatRef::from_raw_parts(ptr, rows, cols, row_stride, col_stride) }
}
}
/// Read-only faer `len x 1` column view over an ndarray vector.
///
/// A vector with a positive stride is borrowed directly. Otherwise it is copied
/// into a private unit-stride vector first: a negative stride from a reversed
/// view, or a zero stride from broadcasting or an empty axis.
pub struct FaerColView<'a> {
    ptr: *const f64,
    len: usize,
    stride: isize,
    // Private unit-stride copy backing the view when the input is not borrowed.
    owned: Option<Array1<f64>>,
    marker: PhantomData<&'a f64>,
}
impl<'a> FaerColView<'a> {
    /// Wraps `array`, borrowing it when its stride is positive and copying it otherwise.
    #[inline]
    pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix1>) -> Self {
        let len = array.len();
        let stride = array.strides()[0];
        if stride > 0 {
            return Self {
                ptr: array.as_ptr(),
                len,
                stride,
                owned: None,
                marker: PhantomData,
            };
        }
        // Collect in logical order into a fresh vector, which has unit stride. The old
        // `to_owned` copy kept stride -1 for reversed contiguous inputs. `as_ref` then
        // walked that buffer with stride +1 from its logical-first (highest-address)
        // element and read past the end of the allocation.
        let owned: Array1<f64> = array.iter().copied().collect();
        let owned_stride = owned.strides()[0];
        Self {
            ptr: owned.as_ptr(),
            len,
            stride: owned_stride,
            owned: Some(owned),
            marker: PhantomData,
        }
    }
    /// Returns the faer column view. Its lifetime is tied to `self`, because the data
    /// may live in the private copy.
    #[inline]
    pub fn as_ref(&self) -> MatRef<'_, f64> {
        // Always use the backing array's real stride rather than assuming 1.
        let (ptr, len, stride) = match &self.owned {
            Some(owned) => (owned.as_ptr(), owned.len(), owned.strides()[0]),
            None => (self.ptr, self.len, self.stride),
        };
        // SAFETY: `ptr`, `len` and `stride` describe exactly the elements of either the
        // private copy (owned by `self`) or the borrowed ndarray (alive for `'a`, which
        // outlives `&self`). The single column's stride is never stepped.
        unsafe { MatRef::from_raw_parts(ptr, len, 1, stride, 0) }
    }
}
pub trait FaerSvd {
fn svd(
&self,
compute_u: bool,
computevt: bool,
) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError>;
}
impl<S: Data<Elem = f64>> FaerSvd for ArrayBase<S, Ix2> {
fn svd(
&self,
compute_u: bool,
computevt: bool,
) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError> {
let faerview = FaerArrayView::new(self);
let faer_mat = faerview.as_ref();
if !compute_u && !computevt {
let (rows, cols) = faer_mat.shape();
let mut singular = Diag::<f64>::zeros(rows.min(cols));
let par = get_global_parallelism();
let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
rows,
cols,
ComputeSvdVectors::No,
ComputeSvdVectors::No,
par,
Default::default(),
));
let stack = MemStack::new(&mut mem);
svd::svd(
faer_mat,
singular.as_mut(),
None,
None,
par,
stack,
Default::default(),
)
.map_err(|_| FaerLinalgError::SvdNoConvergence {
context: "faer SVD singular values only",
})?;
let singularvalues = diag_to_array(singular.as_ref());
return Ok((None, singularvalues, None));
}
let (rows, cols) = faer_mat.shape();
let rank = rows.min(cols);
let compute_u_flag = if compute_u {
ComputeSvdVectors::Thin
} else {
ComputeSvdVectors::No
};
let computev_flag = if computevt {
ComputeSvdVectors::Thin
} else {
ComputeSvdVectors::No
};
let mut singular = Diag::<f64>::zeros(rows.min(cols));
let mut u_storage = compute_u.then(|| Mat::<f64>::zeros(rows, rank));
let mut v_storage = computevt.then(|| Mat::<f64>::zeros(cols, rank));
let par = get_global_parallelism();
let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
rows,
cols,
compute_u_flag,
computev_flag,
par,
Default::default(),
));
let stack = MemStack::new(&mut mem);
svd::svd(
faer_mat.as_ref(),
singular.as_mut(),
u_storage.as_mut().map(|mat| mat.as_mut()),
v_storage.as_mut().map(|mat| mat.as_mut()),
par,
stack,
Default::default(),
)
.map_err(|_| FaerLinalgError::SvdNoConvergence {
context: "faer SVD with vectors",
})?;
let singularvalues = diag_to_array(singular.as_ref());
let u_opt = u_storage.map(|mat| mat_to_array(mat.as_ref()));
let vt_opt = v_storage.map(|mat| {
let mat_ref = mat.as_ref();
let mut out = Array2::<f64>::zeros((mat_ref.ncols(), mat_ref.nrows()));
for j in 0..mat_ref.nrows() {
for i in 0..mat_ref.ncols() {
out[[i, j]] = mat_ref[(j, i)];
}
}
out
});
Ok((u_opt, singularvalues, vt_opt))
}
}
/// Eigendecomposition of real symmetric (self-adjoint) matrices via faer.
pub trait FaerEigh {
    /// Returns `(eigenvalues, eigenvectors)` of a square matrix, forwarding `side`
    /// to faer's `self_adjoint_eigen`.
    ///
    /// The input is tried as given first. If that attempt fails or yields
    /// non-finite output, the matrix is symmetrized and scaled down by its largest
    /// absolute entry (when that exceeds 1). It is then retried with increasing
    /// diagonal jitter, and the jitter and scale are removed from the eigenvalues.
    ///
    /// # Errors
    /// * `FactorizationFailed` for non-square input.
    /// * `SelfAdjointEigenNonFiniteInput` if the input holds NaN or infinity.
    /// * Otherwise, the error from the last repair attempt.
    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError>;
}
impl<S: Data<Elem = f64>> FaerEigh for ArrayBase<S, Ix2> {
    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
        /// Runs one faer eigensolve, converting a panic inside faer into an error.
        fn attempt(
            matrix: &Array2<f64>,
            side: Side,
        ) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
            let view = FaerArrayView::new(matrix);
            let decomposition =
                catch_unwind(AssertUnwindSafe(|| view.as_ref().self_adjoint_eigen(side)))
                    .map_err(|_| FaerLinalgError::FactorizationFailed {
                        context: "self-adjoint eigendecomposition panic boundary",
                    })?
                    .map_err(FaerLinalgError::SelfAdjointEigen)?;
            Ok((
                diag_to_array(decomposition.S()),
                mat_to_array(decomposition.U()),
            ))
        }
        /// True when every eigenvalue and every eigenvector entry is finite.
        fn all_finite(values: &Array1<f64>, vectors: &Array2<f64>) -> bool {
            values.iter().chain(vectors.iter()).all(|x| x.is_finite())
        }

        let input = self.to_owned();
        let n = input.nrows();
        if n != input.ncols() {
            return Err(FaerLinalgError::FactorizationFailed {
                context: "self-adjoint eigendecomposition non-square input",
            });
        }
        if n == 0 {
            return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
        }
        if !input.iter().all(|x| x.is_finite()) {
            return Err(FaerLinalgError::SelfAdjointEigenNonFiniteInput {
                context: "self-adjoint eigendecomposition input validation",
            });
        }
        // First try the matrix exactly as given.
        if let Some(decomposition) = attempt(&input, side)
            .ok()
            .filter(|(values, vectors)| all_finite(values, vectors))
        {
            return Ok(decomposition);
        }
        // Repair: replace each off-diagonal pair with its average.
        let mut symmetric = input.clone();
        for i in 0..n {
            for j in (i + 1)..n {
                let mean = 0.5 * (symmetric[[i, j]] + symmetric[[j, i]]);
                symmetric[[i, j]] = mean;
                symmetric[[j, i]] = mean;
            }
        }
        // Scale into [-1, 1] when the entries are large; never scale up.
        let scale = symmetric
            .iter()
            .map(|x| x.abs())
            .fold(0.0_f64, f64::max)
            .max(1.0);
        let normalized = symmetric.mapv(|x| x / scale);
        const JITTERS: [f64; 6] = [0.0, 1e-12, 1e-10, 1e-8, 1e-6, 1e-4];
        let mut last_error = FaerLinalgError::FactorizationFailed {
            context: "self-adjoint eigendecomposition repair attempts",
        };
        for jitter in JITTERS {
            let mut candidate = normalized.clone();
            if jitter > 0.0 {
                for d in candidate.diag_mut() {
                    *d += jitter;
                }
            }
            last_error = match attempt(&candidate, side) {
                Ok((values, vectors)) if all_finite(&values, &vectors) => {
                    // Undo the diagonal shift and the scaling; eigenvectors are unchanged.
                    let restored = values.mapv(|lambda| (lambda - jitter) * scale);
                    return Ok((restored, vectors));
                }
                Ok(_) => FaerLinalgError::SelfAdjointEigenNonFiniteInput {
                    context: "self-adjoint eigendecomposition repaired output validation",
                },
                Err(err) => err,
            };
        }
        Err(last_error)
    }
}
/// Cholesky factorization `A = L Lᵀ` produced by [`FaerCholesky::cholesky`].
pub struct FaerCholeskyFactor {
    factor: solvers::Llt<f64>,
}
impl FaerCholeskyFactor {
    /// Solves `A x = rhs` for a single right-hand side and returns `x`.
    pub fn solvevec(&self, rhs: &Array1<f64>) -> Array1<f64> {
        let mut solution = rhs.to_owned();
        self.factor
            .solve_in_place(array1_to_col_matmut(&mut solution));
        solution
    }
    /// Solves `A X = rhs` in place, overwriting `rhs` with `X`.
    pub fn solve_mat_in_place(&self, rhs: &mut Array2<f64>) {
        self.factor.solve_in_place(array2_to_matmut(rhs));
    }
    /// Solves `A X = rhs`, writing `X` into `out`.
    ///
    /// `out` is reallocated when its shape differs from `rhs`.
    pub fn solve_mat_into<S: Data<Elem = f64>>(
        &self,
        rhs: &ArrayBase<S, Ix2>,
        out: &mut Array2<f64>,
    ) {
        let shape = rhs.dim();
        if out.dim() != shape {
            *out = Array2::<f64>::zeros(shape);
        }
        out.assign(rhs);
        self.solve_mat_in_place(out);
    }
    /// Solves `A X = rhs` and returns `X` as a new matrix.
    pub fn solve_mat(&self, rhs: &Array2<f64>) -> Array2<f64> {
        let mut solution = Array2::<f64>::zeros(rhs.dim());
        self.solve_mat_into(rhs, &mut solution);
        solution
    }
    /// Returns the diagonal of the Cholesky factor `L`.
    pub fn diag(&self) -> Array1<f64> {
        let lower = self.factor.L();
        diag_to_array(lower.diagonal())
    }
    /// Returns the Cholesky factor `L` as a dense matrix.
    pub fn lower_triangular(&self) -> Array2<f64> {
        let lower = self.factor.L();
        mat_to_array(lower)
    }
}
/// Cholesky (LLT) factorization for dense `f64` matrices stored as `ndarray` arrays.
pub trait FaerCholesky {
    /// Factorizes `self` with faer's LLT routine; `side` is forwarded to faer's `llt`.
    ///
    /// # Errors
    ///
    /// Returns [`FaerLinalgError::FactorizationFailed`] when `self` is not square, and
    /// [`FaerLinalgError::Cholesky`] when faer's LLT factorization reports a failure.
    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError>;
}
impl<S: Data<Elem = f64>> FaerCholesky for ArrayBase<S, Ix2> {
    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError> {
        // A Cholesky factorization only exists for square matrices. faer asserts this
        // (panicking), so surface it as a recoverable error instead. This matches the
        // non-square handling of the self-adjoint eigendecomposition in this module.
        if self.nrows() != self.ncols() {
            return Err(FaerLinalgError::FactorizationFailed {
                context: "cholesky non-square input",
            });
        }
        let faerview = FaerArrayView::new(self);
        let factor = faerview
            .as_ref()
            .llt(side)
            .map_err(FaerLinalgError::Cholesky)?;
        Ok(FaerCholeskyFactor { factor })
    }
}
/// Thin QR decomposition for dense `f64` matrices stored as `ndarray` arrays.
pub trait FaerQr {
    /// Returns the thin `(Q, R)` factors computed by faer's (unpivoted) QR.
    ///
    /// The current implementation always returns `Ok`; the `Result` keeps the signature
    /// uniform with the other factorization traits.
    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError>;
}
impl<S: Data<Elem = f64>> FaerQr for ArrayBase<S, Ix2> {
    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError> {
        let view = FaerArrayView::new(self);
        let decomposition = view.as_ref().qr();
        // Q is materialized from the stored Householder reflectors; R is read in place.
        let thin_q = mat_to_array(decomposition.compute_thin_Q().as_ref());
        let thin_r = mat_to_array(decomposition.thin_R());
        Ok((thin_q, thin_r))
    }
}
pub fn rrqr_nullspace_basis<S: Data<Elem = f64>>(
a: &ArrayBase<S, Ix2>,
rank_alpha: f64,
) -> Result<(Array2<f64>, usize), FaerLinalgError> {
let faerview = FaerArrayView::new(a);
let qr = faerview.as_ref().col_piv_qr();
let r = qr.thin_R();
let diag_len = r.nrows().min(r.ncols());
let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
let tol = rank_alpha
* f64::EPSILON
* (a.nrows().max(a.ncols()).max(1) as f64)
* leading_diag.max(1.0);
let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
let z = if rank >= a.nrows() {
Array2::<f64>::zeros((a.nrows(), 0))
} else {
let nullity = a.nrows() - rank;
let mut selector = Mat::<f64>::zeros(a.nrows(), nullity);
for j in 0..nullity {
selector[(rank + j, j)] = 1.0;
}
let par = get_global_parallelism();
faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
qr.Q_basis(),
qr.Q_coeff(),
Conj::No,
selector.as_mut(),
par,
MemStack::new(&mut MemBuffer::new(
faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<f64>(
a.nrows(),
qr.Q_coeff().nrows(),
nullity,
),
)),
);
mat_to_array(selector.as_ref())
};
Ok((z, rank))
}
/// Default `rank_alpha` multiplier for the RRQR rank tolerance.
///
/// Equals the module constant `RRQR_RANK_ALPHA` (100.0). Pass it to
/// [`rrqr_nullspace_basis`] or [`rrqr_with_permutation`] when no custom threshold is needed.
#[inline]
pub const fn default_rrqr_rank_alpha() -> f64 {
    RRQR_RANK_ALPHA
}
/// Result of [`rrqr_with_permutation`]: numerical rank plus the column pivoting order of a
/// column-pivoted QR.
#[derive(Debug, Clone, PartialEq)]
pub struct RrqrWithPermutation {
    /// Number of diagonal entries of `R` whose magnitude exceeds `rank_tol`.
    pub rank: usize,
    /// faer's forward column-pivot array as plain indices; a bijection on `0..ncols`.
    pub column_permutation: Vec<usize>,
    /// `|R₀₀|`, or `0.0` when `R` has an empty diagonal.
    pub leading_diag_abs: f64,
    /// Threshold used to decide `rank`.
    pub rank_tol: f64,
}
/// Runs a rank-revealing (column-pivoted) QR on `a`. Reports its numerical rank together
/// with the column pivoting order.
///
/// The rank is the number of diagonal entries of `R` whose magnitude exceeds
/// `rank_alpha * ε * max(m, n, 1) * max(|R₀₀|, 1)`, the same rule used by
/// [`rrqr_nullspace_basis`]. `column_permutation` is faer's forward pivot array converted to
/// plain indices. The tests treat the entries from `rank` onward as the columns demoted as
/// numerically dependent.
///
/// # Errors
///
/// Returns [`FaerLinalgError::FactorizationFailed`] in any of these cases:
/// - `a` has zero rows;
/// - `rank_alpha` is negative or not finite;
/// - `a` contains NaN or infinite entries.
pub fn rrqr_with_permutation<S: Data<Elem = f64>>(
    a: &ArrayBase<S, Ix2>,
    rank_alpha: f64,
) -> Result<RrqrWithPermutation, FaerLinalgError> {
    if a.nrows() == 0 {
        return Err(FaerLinalgError::FactorizationFailed {
            context: "rrqr_with_permutation: input has zero rows",
        });
    }
    if !rank_alpha.is_finite() || rank_alpha < 0.0 {
        return Err(FaerLinalgError::FactorizationFailed {
            context: "rrqr_with_permutation: rank_alpha must be finite and non-negative",
        });
    }
    // Non-finite entries would make the pivot magnitudes (and thus rank/order) meaningless.
    if a.iter().any(|value| !value.is_finite()) {
        return Err(FaerLinalgError::FactorizationFailed {
            context: "rrqr_with_permutation: input contains non-finite values",
        });
    }
    let faerview = FaerArrayView::new(a);
    let qr = faerview.as_ref().col_piv_qr();
    let r = qr.thin_R();
    let diag_len = r.nrows().min(r.ncols());
    // Tolerance is relative to the leading pivot, floored at 1 (absolute below that scale).
    let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
    let tol = rank_alpha
        * f64::EPSILON
        * (a.nrows().max(a.ncols()).max(1) as f64)
        * leading_diag.max(1.0);
    let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
    let (forward, _inverse) = qr.P().arrays();
    let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
    Ok(RrqrWithPermutation {
        rank,
        column_permutation,
        leading_diag_abs: leading_diag,
        rank_tol: tol,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Largest absolute value among `values` (0.0 for an empty input).
    fn max_abs<'a>(values: impl IntoIterator<Item = &'a f64>) -> f64 {
        values.into_iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
    }

    #[test]
    fn rrqr_nullspace_basis_is_orthonormal_and_annihilates_transpose() {
        let a = array![[1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 0.0],];
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 2);
        assert_eq!(z.nrows(), 4);
        assert_eq!(z.ncols(), 2);
        let gram = z.t().dot(&z);
        let ident = Array2::<f64>::eye(z.ncols());
        let gram_err = max_abs(&(&gram - &ident));
        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
        let residual = a.t().dot(&z);
        let resid_max = max_abs(&residual);
        assert!(resid_max < 1e-10, "A^T Z residual too large: {resid_max:e}");
    }

    #[test]
    fn rrqr_with_permutation_attributes_redundant_column() {
        let a = array![
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0],
            [0.0, 2.0, 0.0],
            [0.0, 0.0, 0.0],
        ];
        let result =
            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(result.rank, 2);
        assert_eq!(result.column_permutation.len(), 3);
        let demoted = result.column_permutation[result.rank..].to_vec();
        assert!(
            demoted.contains(&2) || demoted.contains(&0),
            "demoted suffix should include one of the aliased columns (0 or 2), got {demoted:?}"
        );
        let mut sorted = result.column_permutation.clone();
        sorted.sort();
        assert_eq!(
            sorted,
            vec![0, 1, 2],
            "permutation must be a valid bijection on 0..n"
        );
    }

    #[test]
    fn rrqr_with_permutation_full_rank_returns_identity_like_order() {
        let a = array![[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
        let result =
            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(result.rank, 2);
        let mut sorted = result.column_permutation.clone();
        sorted.sort();
        assert_eq!(sorted, vec![0, 1]);
    }

    #[test]
    fn rrqr_with_permutation_rejects_zero_rows() {
        let a = Array2::<f64>::zeros((0, 3));
        assert!(rrqr_with_permutation(&a, default_rrqr_rank_alpha()).is_err());
    }

    #[test]
    fn rrqr_nullspace_basis_detects_zero_rank_matrix() {
        let a = Array2::<f64>::zeros((5, 2));
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 0);
        assert_eq!(z.dim(), (5, 5));
        // `z` is already 5x5, so compare it to the identity directly.
        let ident = Array2::<f64>::eye(5);
        let max_err = max_abs(&(&z - &ident));
        assert!(
            max_err < 1e-10,
            "zero matrix should yield identity basis (max error {max_err:e})"
        );
    }

    #[test]
    fn eigh_on_nan_matrix_rejects_non_finite_input() {
        let mat = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, f64::NAN],
            [0.0, 0.0, f64::NAN, 4.0]
        ];
        let err = mat
            .eigh(Side::Lower)
            .expect_err("non-finite symmetric input must be rejected");
        assert!(matches!(
            err,
            FaerLinalgError::SelfAdjointEigenNonFiniteInput { .. }
        ));
    }

    #[test]
    fn fast_ata_matches_full_gemm_above_threshold() {
        let n = 200;
        let p = 40;
        let a: Array2<f64> = Array2::from_shape_fn((n, p), |(i, j)| {
            ((i * 7 + j * 3) as f64).sin() + 0.1 * j as f64
        });
        let expected = a.t().dot(&a);
        let got = fast_ata(&a);
        let max_err = max_abs(&(&got - &expected));
        assert!(max_err < 1e-10, "fast_ata mismatch: {max_err:e}");
        for i in 0..p {
            for j in 0..p {
                assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn fast_xt_diag_x_matches_naive_above_threshold() {
        let n = 400;
        let p = 36;
        let x: Array2<f64> =
            Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.1).cos() + j as f64 * 0.05);
        let w: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64 * 0.03).sin());
        let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        let expected = x.t().dot(&wx);
        let got = fast_xt_diag_x(&x, &w);
        let max_err = max_abs(&(&got - &expected));
        assert!(max_err < 1e-9, "fast_xt_diag_x mismatch: {max_err:e}");
        for i in 0..p {
            for j in 0..p {
                assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn eigh_succeeds_on_same_structure_without_nan() {
        let mat = array![[1.0, 0.5, 0.1], [0.5, 2.0, 0.3], [0.1, 0.3, 1.5]];
        let (evals, _) = mat
            .eigh(Side::Lower)
            .expect("eigh should succeed on a well-conditioned finite matrix");
        assert!(
            evals.iter().all(|&v| v.is_finite()),
            "all eigenvalues should be finite"
        );
    }
}