1use dyn_stack::{MemBuffer, MemStack};
2use faer::diag::{Diag, DiagRef};
3use faer::linalg::solvers::{self, Solve};
4pub use faer::linalg::solvers::{
5 Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
6};
7use faer::linalg::svd::{self, ComputeSvdVectors};
8use faer::prelude::ReborrowMut;
9use faer::{Conj, Mat, MatMut, MatRef, Par, Side, Unbind, get_global_parallelism};
10use ndarray::{Array1, Array2, ArrayBase, ArrayViewMut1, Data, Ix1, Ix2};
11use std::marker::PhantomData;
12use std::panic::{AssertUnwindSafe, catch_unwind};
13use thiserror::Error;
14
// NOTE(review): not referenced in the visible part of this file. Presumably a
// tolerance multiplier for a rank-revealing QR (RRQR) rank decision — confirm
// against the code that consumes it.
const RRQR_RANK_ALPHA: f64 = 100.0;
16
thread_local! {
    // Number of `NestedParallelGuard`s currently alive on this thread.
    static NESTED_PARALLEL_DEPTH: std::cell::Cell<usize> = const { std::cell::Cell::new(0) };
}

/// Scope guard that marks the current thread as being inside a nested
/// parallel region for as long as the guard is alive.
///
/// Creating a guard increments the thread-local depth counter; dropping it
/// decrements the counter again. Both operations saturate, so the counter can
/// neither overflow nor underflow.
struct NestedParallelGuard;

impl NestedParallelGuard {
    /// Increments this thread's nesting depth and returns the guard that will
    /// undo the increment when dropped.
    #[inline]
    fn enter() -> Self {
        NESTED_PARALLEL_DEPTH.with(|counter| {
            let bumped = counter.get().saturating_add(1);
            counter.set(bumped);
        });
        NestedParallelGuard
    }
}

impl Drop for NestedParallelGuard {
    /// Decrements this thread's nesting depth (saturating at zero).
    #[inline]
    fn drop(&mut self) {
        NESTED_PARALLEL_DEPTH.with(|counter| {
            let lowered = counter.get().saturating_sub(1);
            counter.set(lowered);
        });
    }
}
37
38#[inline]
48pub fn with_nested_parallel<T>(body: impl FnOnce() -> T) -> T {
49 let guard = NestedParallelGuard::enter();
50 let out = body();
51 drop(guard);
52 out
53}
54
55#[inline]
58pub fn in_nested_parallel_region() -> bool {
59 NESTED_PARALLEL_DEPTH.with(|depth| depth.get() > 0)
60}
61
62#[inline]
70pub fn effective_global_parallelism() -> Par {
71 if in_nested_parallel_region() {
72 Par::Seq
73 } else {
74 get_global_parallelism()
75 }
76}
77
/// Errors reported by the faer-backed linear-algebra helpers.
#[derive(Debug, Error)]
pub enum FaerLinalgError {
    /// A factorization failed; `context` is a static label that is
    /// interpolated into the error message.
    #[error("Factorization failed in {context}")]
    FactorizationFailed { context: &'static str },
    /// An SVD did not converge; `context` is a static label that is
    /// interpolated into the error message.
    #[error("SVD failed to converge in {context}")]
    SvdNoConvergence { context: &'static str },
    /// The input to a self-adjoint eigendecomposition contained non-finite
    /// values; `context` is a static label interpolated into the message.
    #[error("Self-adjoint eigendecomposition input contains non-finite values in {context}")]
    SelfAdjointEigenNonFiniteInput { context: &'static str },
    /// Wraps the error value returned by faer's self-adjoint eigensolver.
    #[error("Self-adjoint eigendecomposition failed: {0:?}")]
    SelfAdjointEigen(solvers::EvdError),
    /// Wraps the error value returned by faer's Cholesky (LLT) factorization.
    #[error("Cholesky factorization failed: {0:?}")]
    Cholesky(solvers::LltError),
    /// Wraps the error value returned by faer's LDLT factorization. Also
    /// returned by `factorize_symmetricwith_fallback` when its final LBLT
    /// attempt panics, carrying the earlier LDLT error.
    #[error("LDLT factorization failed: {0:?}")]
    Ldlt(solvers::LdltError),
}
93
/// A factorization of a symmetric matrix, holding whichever of faer's three
/// symmetric factorizations was used to build it.
pub enum FaerSymmetricFactor {
    /// faer `Llt` (Cholesky) factorization.
    Llt(FaerLlt<f64>),
    /// faer `Ldlt` factorization.
    Ldlt(FaerLdlt<f64>),
    /// faer `Lblt` factorization.
    Lblt(FaerLblt<f64>),
}
99
100#[inline]
101pub fn cholesky_factor_logdet(factor: MatRef<'_, f64>) -> f64 {
102 2.0 * diagonal_log_sum(factor.diagonal())
103}
104
105#[inline]
106fn diagonal_log_sum(diagonal: DiagRef<'_, f64>) -> f64 {
107 diagonal
108 .column_vector()
109 .iter()
110 .map(|&x| x.ln())
111 .sum::<f64>()
112}
113
114impl FaerSymmetricFactor {
115 #[inline]
117 pub fn n(&self) -> usize {
118 use faer::linalg::solvers::ShapeCore;
119 match self {
120 FaerSymmetricFactor::Llt(f) => f.nrows(),
121 FaerSymmetricFactor::Ldlt(f) => f.nrows(),
122 FaerSymmetricFactor::Lblt(f) => f.nrows(),
123 }
124 }
125
126 #[inline]
127 pub fn solve(&self, rhs: MatRef<'_, f64>) -> Mat<f64> {
128 match self {
129 FaerSymmetricFactor::Llt(f) => f.solve(rhs),
130 FaerSymmetricFactor::Ldlt(f) => f.solve(rhs),
131 FaerSymmetricFactor::Lblt(f) => f.solve(rhs),
132 }
133 }
134
135 #[inline]
136 pub fn solve_in_place(&self, rhs: MatMut<'_, f64>) {
137 match self {
138 FaerSymmetricFactor::Llt(f) => f.solve_in_place(rhs),
139 FaerSymmetricFactor::Ldlt(f) => f.solve_in_place(rhs),
140 FaerSymmetricFactor::Lblt(f) => f.solve_in_place(rhs),
141 }
142 }
143}
144
145impl crate::matrix::FactorizedSystem for FaerSymmetricFactor {
146 fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
147 let mut out = rhs.clone();
148 let mut out_mat = array1_to_col_matmut(&mut out);
149 self.solve_in_place(out_mat.as_mut());
150 if !out.iter().all(|v| v.is_finite()) {
151 return Err("symmetric factor solve produced non-finite values".to_string());
152 }
153 Ok(out)
154 }
155
156 fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
157 let mut out = Array2::<f64>::zeros(rhs.raw_dim());
158 for j in 0..rhs.ncols() {
159 for i in 0..rhs.nrows() {
160 out[[i, j]] = rhs[[i, j]];
161 }
162 }
163 let mut out_mat = array2_to_matmut(&mut out);
164 self.solve_in_place(out_mat.as_mut());
165 if !out.iter().all(|v| v.is_finite()) {
166 return Err("symmetric factor multi-solve produced non-finite values".to_string());
167 }
168 Ok(out)
169 }
170
171 fn logdet(&self) -> f64 {
172 match self {
173 FaerSymmetricFactor::Llt(f) => cholesky_factor_logdet(f.L()),
174 FaerSymmetricFactor::Ldlt(f) => diagonal_log_sum(f.D()),
175 FaerSymmetricFactor::Lblt(..) => {
176 f64::NAN
180 }
181 }
182 }
183}
184
185#[inline]
187pub fn factorize_symmetricwith_fallback(
188 matrix: MatRef<'_, f64>,
189 side: Side,
190) -> Result<FaerSymmetricFactor, FaerLinalgError> {
191 if let Ok(llt) = FaerLlt::new(matrix, side) {
192 return Ok(FaerSymmetricFactor::Llt(llt));
193 }
194 let ldlt_err = match FaerLdlt::new(matrix, side) {
195 Ok(ldlt) => return Ok(FaerSymmetricFactor::Ldlt(ldlt)),
196 Err(err) => err,
197 };
198 let lblt = catch_unwind(AssertUnwindSafe(|| FaerLblt::new(matrix, side)))
199 .map_err(|_| FaerLinalgError::Ldlt(ldlt_err))?;
200 Ok(FaerSymmetricFactor::Lblt(lblt))
201}
202
/// Decides whether an `m × k` by `k × n` product is large enough to be routed
/// through faer.
///
/// Returns `true` only when at least one of the three dimensions reaches
/// `MIN_DIM` and the (saturating) product `m · n · k` reaches
/// `MIN_FLOP_SCALE`.
#[inline]
const fn should_use_faer_matmul(m: usize, n: usize, k: usize) -> bool {
    const MIN_DIM: usize = 32;
    const MIN_FLOP_SCALE: usize = 64 * 64;

    // Every dimension is short: never worth it, whatever the total work.
    if m < MIN_DIM && n < MIN_DIM && k < MIN_DIM {
        return false;
    }
    let work = m.saturating_mul(n).saturating_mul(k);
    work >= MIN_FLOP_SCALE
}
213
214#[inline]
215pub fn matmul_parallelism(m: usize, n: usize, k: usize) -> Par {
216 const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
220 const PAR_MIN_LONG_DIM: usize = 256;
221 let flop_scale = m.saturating_mul(n).saturating_mul(k);
222 let long_dim = m.max(n).max(k);
223 if flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM {
224 effective_global_parallelism()
228 } else {
229 Par::Seq
230 }
231}
232
233#[inline]
234pub fn array2_to_matmut(array: &mut Array2<f64>) -> MatMut<'_, f64> {
235 let (rows, cols) = array.dim();
236 let strides = array.strides();
237
238 let s0 = strides[0];
245 let s1 = strides[1];
246
247 unsafe { MatMut::from_raw_parts_mut(array.as_mut_ptr(), rows, cols, s0, s1) }
251}
252
253#[inline]
254pub fn array1_to_col_matmut(array: &mut Array1<f64>) -> MatMut<'_, f64> {
255 let len = array.len();
256 let stride = array.strides()[0];
257 unsafe {
261 MatMut::from_raw_parts_mut(
262 array.as_mut_ptr(),
263 len,
264 1,
265 stride,
266 0, )
268 }
269}
270
271#[inline]
278pub fn fast_ata<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>) -> Array2<f64> {
279 let p = a.ncols();
280 let mut out = Array2::<f64>::zeros((p, p));
281 fast_ata_into(a, &mut out);
282 out
283}
284
285#[inline]
288pub fn fast_ata_into<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>, out: &mut Array2<f64>) {
289 use faer::Accum;
290 use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
291
292 let (n, p) = a.dim();
293 assert_eq!(out.nrows(), p, "output rows must match p");
294 assert_eq!(out.ncols(), p, "output cols must match p");
295
296 if !should_use_faer_matmul(p, p, n) {
297 out.assign(&a.t().dot(a));
298 return;
299 }
300
301 let mut outview = array2_to_matmut(out);
302
303 let aview = FaerArrayView::new(a);
304 let a_ref = aview.as_ref();
305 let a_t = a_ref.transpose();
306 let par = matmul_parallelism(p, p, n);
307 tri_matmul(
308 outview.as_mut(),
309 BlockStructure::TriangularLower,
310 Accum::Replace,
311 a_t,
312 BlockStructure::Rectangular,
313 a_ref,
314 BlockStructure::Rectangular,
315 1.0,
316 par,
317 );
318 for i in 0..p {
320 for j in (i + 1)..p {
321 out[[i, j]] = out[[j, i]];
322 }
323 }
324}
325
326#[inline]
330pub fn fast_atb<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
331 a: &ArrayBase<S1, Ix2>,
332 b: &ArrayBase<S2, Ix2>,
333) -> Array2<f64> {
334 if let Some(out) =
335 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atb(a.view(), b.view()))
336 {
337 return out;
338 }
339 let (n_a, p) = a.dim();
340 let q = b.ncols();
341 fast_atb_with_parallelism(a, b, matmul_parallelism(p, q, n_a))
342}
343
344#[inline]
347pub fn fast_atb_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
348 a: &ArrayBase<S1, Ix2>,
349 b: &ArrayBase<S2, Ix2>,
350 par: Par,
351) -> Array2<f64> {
352 use faer::linalg::matmul::matmul;
353 use faer::{Accum, Mat};
354
355 let (n_a, p) = a.dim();
356 let (n_b, q) = b.dim();
357 assert_eq!(n_a, n_b, "A and B must have same number of rows");
358
359 if !should_use_faer_matmul(p, q, n_a) {
361 return a.t().dot(b);
362 }
363
364 let mut result = Mat::<f64>::zeros(p, q);
365
366 let aview = FaerArrayView::new(a);
367 let bview = FaerArrayView::new(b);
368 let a_ref = aview.as_ref();
369 let b_ref = bview.as_ref();
370
371 matmul(
373 result.as_mut(),
374 Accum::Replace,
375 a_ref.transpose(),
376 b_ref,
377 1.0,
378 par,
379 );
380
381 mat_to_array(result.as_ref())
382}
383
384#[inline]
387pub fn fast_abt<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
388 a: &ArrayBase<S1, Ix2>,
389 b: &ArrayBase<S2, Ix2>,
390) -> Array2<f64> {
391 use faer::linalg::matmul::matmul;
392 use faer::{Accum, Mat};
393
394 let (m, k_a) = a.dim();
395 let (n, k_b) = b.dim();
396 assert_eq!(
397 k_a, k_b,
398 "A and B must have same number of columns for A·Bᵀ"
399 );
400
401 if !should_use_faer_matmul(m, n, k_a) {
402 return a.dot(&b.t());
403 }
404
405 let mut result = Mat::<f64>::zeros(m, n);
406 let aview = FaerArrayView::new(a);
407 let bview = FaerArrayView::new(b);
408 let par = matmul_parallelism(m, n, k_a);
409 matmul(
410 result.as_mut(),
411 Accum::Replace,
412 aview.as_ref(),
413 bview.as_ref().transpose(),
414 1.0,
415 par,
416 );
417 mat_to_array(result.as_ref())
418}
419
420#[inline]
424pub fn fast_ab<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
425 a: &ArrayBase<S1, Ix2>,
426 b: &ArrayBase<S2, Ix2>,
427) -> Array2<f64> {
428 if let Some(out) =
429 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_ab(a.view(), b.view()))
430 {
431 return out;
432 }
433 let n = a.nrows();
434 let q = b.ncols();
435 let mut out = Array2::<f64>::zeros((n, q));
436 fast_ab_into(a, b, &mut out);
437 out
438}
439
/// Number of independent accumulator lanes used by `fma_dot`; inputs are
/// consumed in chunks of this many elements.
const FMA_LANES: usize = 8;

/// Minimum value of `n · p` (rows times columns) at which
/// `kernel_should_parallelize` allows the row-major kernels to use rayon.
const KERNEL_PAR_MIN_FLOP: usize = 1 << 18;
/// Number of output rows handed to each rayon task by `fast_av_rowmajor_into`.
const AV_PAR_CHUNK_ROWS: usize = 1024;

/// Number of input rows folded into each partial sum by
/// `fast_atv_rowmajor_into`.
const ATV_BLOCK_ROWS: usize = 512;
480
481#[inline]
482fn kernel_should_parallelize(n: usize, p: usize) -> bool {
483 !in_nested_parallel_region()
484 && n.saturating_mul(p) >= KERNEL_PAR_MIN_FLOP
485 && rayon::current_num_threads() > 1
486}
487
/// Compensated dot product of two equal-length slices.
///
/// Each product is split into a rounded value and its exact rounding error
/// with an FMA, and each running sum is updated with an error-free two-term
/// sum. The rounding errors are accumulated separately and added back at the
/// end. Full chunks of `FMA_LANES` elements are spread over `FMA_LANES`
/// independent accumulators; the leftover tail uses one scalar accumulator.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
#[inline(always)]
fn fma_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len());
    // Per-lane running sums and per-lane accumulated rounding errors.
    let mut sum = [0.0f64; FMA_LANES];
    let mut comp = [0.0f64; FMA_LANES];
    let mut ca = a.chunks_exact(FMA_LANES);
    let mut cb = b.chunks_exact(FMA_LANES);
    // `by_ref` keeps the chunk iterators usable afterwards for `remainder()`.
    for (xa, xb) in ca.by_ref().zip(cb.by_ref()) {
        for l in 0..FMA_LANES {
            let x = xa[l];
            let y = xb[l];
            // Rounded product and its rounding error: x*y == p + ep.
            let p = x * y;
            let ep = x.mul_add(y, -p);
            // Rounded sum and its rounding error: sum[l] + p == s + es.
            let s = sum[l] + p;
            let bb = s - sum[l];
            let es = (sum[l] - (s - bb)) + (p - bb);
            sum[l] = s;
            comp[l] += ep + es;
        }
    }
    // Scalar accumulator pair for the tail shorter than one full chunk.
    let mut sr = 0.0f64;
    let mut cr = 0.0f64;
    for (&x, &y) in ca.remainder().iter().zip(cb.remainder().iter()) {
        let p = x * y;
        let ep = x.mul_add(y, -p);
        let s = sr + p;
        let bb = s - sr;
        let es = (sr - (s - bb)) + (p - bb);
        sr = s;
        cr += ep + es;
    }
    // Combine: tail first, then each lane's sum plus its compensation, in
    // lane order.
    let mut total = sr + cr;
    for l in 0..FMA_LANES {
        total += sum[l] + comp[l];
    }
    total
}
543
544fn fast_av_rowmajor_into(x_all: &[f64], v: &[f64], n: usize, p: usize, out: &mut [f64]) {
548 assert_eq!(x_all.len(), n * p);
549 assert_eq!(v.len(), p);
550 assert_eq!(out.len(), n);
551 if kernel_should_parallelize(n, p) {
552 use rayon::prelude::*;
553 out.par_chunks_mut(AV_PAR_CHUNK_ROWS)
554 .enumerate()
555 .for_each(|(c, chunk)| {
556 let base = c * AV_PAR_CHUNK_ROWS;
557 for (k, o) in chunk.iter_mut().enumerate() {
558 let i = base + k;
559 *o = fma_dot(&x_all[i * p..i * p + p], v);
560 }
561 });
562 } else {
563 for (i, o) in out.iter_mut().enumerate() {
564 *o = fma_dot(&x_all[i * p..i * p + p], v);
565 }
566 }
567}
568
/// Element-wise sum of the vectors in `parts`, written into `out`, using a
/// balanced pairwise (tree) reduction.
///
/// `parts` is split in half recursively; each half is summed on its own and
/// the two half-sums are then added element-wise. An empty `parts` fills
/// `out` with zeros.
///
/// The left half is accumulated directly in `out`, so only one scratch vector
/// is allocated per recursion level. The additions performed, and therefore
/// the floating-point results, are the same as summing both halves into
/// separate buffers and adding them.
///
/// # Panics
/// Panics if `parts` is non-empty and a part that is copied into a destination
/// buffer has a different length than `out`.
fn pairwise_sum_into(parts: &[Vec<f64>], out: &mut [f64]) {
    match parts {
        [] => out.fill(0.0),
        [only] => out.copy_from_slice(only),
        _ => {
            let (left, right) = parts.split_at(parts.len() / 2);
            // Every arm overwrites its destination completely, so `out` can
            // hold the left half-sum without being cleared first.
            pairwise_sum_into(left, out);
            let mut right_sum = vec![0.0f64; out.len()];
            pairwise_sum_into(right, &mut right_sum);
            for (o, &r) in out.iter_mut().zip(right_sum.iter()) {
                *o += r;
            }
        }
    }
}
587
588fn fast_atv_rowmajor_into(x_all: &[f64], v: &[f64], n: usize, p: usize, out: &mut [f64]) {
596 assert_eq!(x_all.len(), n * p);
597 assert_eq!(v.len(), n);
598 assert_eq!(out.len(), p);
599 let nblocks = n.div_ceil(ATV_BLOCK_ROWS);
600
601 let block_partial = |b: usize| -> Vec<f64> {
602 let start = b * ATV_BLOCK_ROWS;
603 let end = (start + ATV_BLOCK_ROWS).min(n);
604 let mut acc = vec![0.0f64; p];
605 for i in start..end {
606 let vi = v[i];
607 let row = &x_all[i * p..i * p + p];
608 for (a, &xij) in acc.iter_mut().zip(row.iter()) {
609 *a = xij.mul_add(vi, *a);
610 }
611 }
612 acc
613 };
614
615 let partials: Vec<Vec<f64>> = if kernel_should_parallelize(n, p) {
616 use rayon::prelude::*;
617 (0..nblocks).into_par_iter().map(block_partial).collect()
618 } else {
619 (0..nblocks).map(block_partial).collect()
620 };
621
622 pairwise_sum_into(&partials, out);
623}
624
625#[inline]
628pub fn fast_av<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
629 a: &ArrayBase<S1, Ix2>,
630 v: &ArrayBase<S2, Ix1>,
631) -> Array1<f64> {
632 if let Some(out) =
633 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_av(a.view(), v.view()))
634 {
635 return out;
636 }
637 fast_av_impl(a, v)
638}
639
640#[inline]
641fn fast_av_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
642 a: &ArrayBase<S1, Ix2>,
643 v: &ArrayBase<S2, Ix1>,
644) -> Array1<f64> {
645 use faer::linalg::matmul::matmul;
646 use faer::{Accum, Mat};
647
648 let (n, p) = a.dim();
649 assert_eq!(p, v.len(), "A cols must match v length");
650
651 if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
655 && n != 0
656 && p != 0
657 {
658 let mut out = Array1::<f64>::zeros(n);
659 fast_av_rowmajor_into(
660 x_all,
661 vs,
662 n,
663 p,
664 out.as_slice_mut().expect("fresh Array1 is contiguous"),
665 );
666 return out;
667 }
668
669 if !should_use_faer_matmul(n, 1, p) {
670 return a.dot(v);
671 }
672
673 let mut result = Mat::<f64>::zeros(n, 1);
674
675 let aview = FaerArrayView::new(a);
676 let vview = FaerColView::new(v);
677 let a_ref = aview.as_ref();
678 let v_ref = vview.as_ref();
679
680 let par = matmul_parallelism(n, 1, p);
681 matmul(result.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
682
683 let mut out = Array1::<f64>::zeros(n);
684 for i in 0..n {
685 out[i] = result[(i, 0)];
686 }
687 out
688}
689
/// Computes `A·v` into the caller-provided vector `out`, overwriting it.
///
/// Thin public wrapper around `fast_av_into_impl`; unlike `fast_av`, it does
/// not consult the GPU hook.
///
/// # Panics
/// Panics if `v.len()` differs from the number of columns of `a`, or
/// `out.len()` differs from the number of rows of `a`.
#[inline]
pub fn fast_av_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: &mut Array1<f64>,
) {
    fast_av_into_impl(a, v, out);
}
700
701#[inline]
702fn fast_av_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
703 a: &ArrayBase<S1, Ix2>,
704 v: &ArrayBase<S2, Ix1>,
705 out: &mut Array1<f64>,
706) {
707 use faer::Accum;
708 use faer::linalg::matmul::matmul;
709
710 let (n, p) = a.dim();
711 assert_eq!(v.len(), p, "vector length must match A cols");
712 assert_eq!(out.len(), n, "output length must match A rows");
713
714 if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
715 && n != 0
716 && p != 0
717 && let Some(out_s) = out.as_slice_mut()
718 {
719 fast_av_rowmajor_into(x_all, vs, n, p, out_s);
720 return;
721 }
722
723 if !should_use_faer_matmul(n, 1, p) {
724 out.assign(&a.dot(v));
725 return;
726 }
727
728 let mut outview = array1_to_col_matmut(out);
729
730 let aview = FaerArrayView::new(a);
731 let vview = FaerColView::new(v);
732 let a_ref = aview.as_ref();
733 let v_ref = vview.as_ref();
734 let par = matmul_parallelism(n, 1, p);
735 matmul(outview.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
736}
737
/// Computes `A·v` into a caller-provided mutable vector *view*, overwriting
/// it.
///
/// Thin public wrapper around `fast_av_view_into_impl`; it does not consult
/// the GPU hook.
///
/// # Panics
/// Panics if `v.len()` differs from the number of columns of `a`, or
/// `out.len()` differs from the number of rows of `a`.
#[inline]
pub fn fast_av_view_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: ArrayViewMut1<'_, f64>,
) {
    fast_av_view_into_impl(a, v, out);
}
752
753#[inline]
754fn fast_av_view_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
755 a: &ArrayBase<S1, Ix2>,
756 v: &ArrayBase<S2, Ix1>,
757 mut out: ArrayViewMut1<'_, f64>,
758) {
759 use faer::Accum;
760 use faer::linalg::matmul::matmul;
761
762 let (n, p) = a.dim();
763 assert_eq!(v.len(), p, "vector length must match A cols");
764 assert_eq!(out.len(), n, "output length must match A rows");
765
766 if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
767 && n != 0
768 && p != 0
769 && let Some(out_s) = out.as_slice_mut()
770 {
771 fast_av_rowmajor_into(x_all, vs, n, p, out_s);
772 return;
773 }
774
775 if !should_use_faer_matmul(n, 1, p) {
776 let prod = a.dot(v);
777 out.assign(&prod);
778 return;
779 }
780
781 let len = out.len();
782 let stride = out.strides()[0];
783 let outview = unsafe {
787 MatMut::from_raw_parts_mut(
788 out.as_mut_ptr(),
789 len,
790 1,
791 stride,
792 0, )
794 };
795
796 let aview = FaerArrayView::new(a);
797 let vview = FaerColView::new(v);
798 let a_ref = aview.as_ref();
799 let v_ref = vview.as_ref();
800 let par = matmul_parallelism(n, 1, p);
801 matmul(outview, Accum::Replace, a_ref, v_ref, 1.0, par);
802}
803
804#[inline]
807pub fn fast_atv<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
808 a: &ArrayBase<S1, Ix2>,
809 v: &ArrayBase<S2, Ix1>,
810) -> Array1<f64> {
811 if let Some(out) =
812 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atv(a.view(), v.view()))
813 {
814 return out;
815 }
816 fast_atv_impl(a, v)
817}
818
819#[inline]
820fn fast_atv_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
821 a: &ArrayBase<S1, Ix2>,
822 v: &ArrayBase<S2, Ix1>,
823) -> Array1<f64> {
824 use faer::Accum;
825 use faer::linalg::matmul::matmul;
826
827 let (n, p) = a.dim();
828 assert_eq!(n, v.len(), "A rows must match v length");
829
830 if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
834 && n != 0
835 && p != 0
836 {
837 let mut out = Array1::<f64>::zeros(p);
838 fast_atv_rowmajor_into(
839 x_all,
840 vs,
841 n,
842 p,
843 out.as_slice_mut().expect("fresh Array1 is contiguous"),
844 );
845 return out;
846 }
847
848 if !should_use_faer_matmul(p, 1, n) {
850 return a.t().dot(v);
851 }
852
853 let mut out = Array1::<f64>::zeros(p);
854 let mut outview = array1_to_col_matmut(&mut out);
855
856 let aview = FaerArrayView::new(a);
857 let vview = FaerColView::new(v);
858 let a_ref = aview.as_ref();
859 let v_ref = vview.as_ref();
860
861 let par = matmul_parallelism(p, 1, n);
863 matmul(
864 outview.as_mut(),
865 Accum::Replace,
866 a_ref.transpose(),
867 v_ref,
868 1.0,
869 par,
870 );
871
872 out
873}
874
/// Computes `Aᵀ·v` into the caller-provided vector `out`, overwriting it.
///
/// Thin public wrapper around `fast_atv_into_impl`; unlike `fast_atv`, it
/// does not consult the GPU hook.
///
/// # Panics
/// Panics if `v.len()` differs from the number of rows of `a`, or `out.len()`
/// differs from the number of columns of `a`.
#[inline]
pub fn fast_atv_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: &mut Array1<f64>,
) {
    fast_atv_into_impl(a, v, out);
}
885
886#[inline]
887fn fast_atv_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
888 a: &ArrayBase<S1, Ix2>,
889 v: &ArrayBase<S2, Ix1>,
890 out: &mut Array1<f64>,
891) {
892 use faer::Accum;
893 use faer::linalg::matmul::matmul;
894
895 let (n, p) = a.dim();
896 assert_eq!(v.len(), n, "vector length must match A rows");
897 assert_eq!(out.len(), p, "output length must match A cols");
898
899 if let (Some(x_all), Some(vs)) = (a.as_slice(), v.as_slice())
900 && n != 0
901 && p != 0
902 && let Some(out_s) = out.as_slice_mut()
903 {
904 fast_atv_rowmajor_into(x_all, vs, n, p, out_s);
905 return;
906 }
907
908 if !should_use_faer_matmul(p, 1, n) {
909 out.assign(&a.t().dot(v));
910 return;
911 }
912
913 let mut outview = array1_to_col_matmut(out);
914
915 let aview = FaerArrayView::new(a);
916 let vview = FaerColView::new(v);
917 let a_ref = aview.as_ref();
918 let v_ref = vview.as_ref();
919 let par = matmul_parallelism(p, 1, n);
920 matmul(
921 outview.as_mut(),
922 Accum::Replace,
923 a_ref.transpose(),
924 v_ref,
925 1.0,
926 par,
927 );
928}
929
930#[inline]
932pub fn fast_xt_diag_x<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
933 x: &ArrayBase<S1, Ix2>,
934 w: &ArrayBase<S2, Ix1>,
935) -> Array2<f64> {
936 assert_eq!(
937 x.nrows(),
938 w.len(),
939 "fast_xt_diag_x row/weight length mismatch"
940 );
941 if let Some(out) =
942 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_xt_diag_x(x.view(), w.view()))
943 {
944 return out;
945 }
946 let p = x.ncols();
947 fast_xt_diag_x_with_parallelism(x, w, matmul_parallelism(p, p, x.nrows()))
948}
949
/// Computes the weighted Gram matrix `Xᵀ·diag(w)·X` on the CPU with an
/// explicitly supplied faer parallelism.
///
/// Validates the shapes and delegates to
/// `fast_xt_diag_x_with_parallelism_impl`; it does not consult the GPU hook.
///
/// # Panics
/// Panics if `w.len()` differs from the number of rows of `x`.
#[inline]
pub fn fast_xt_diag_x_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    par: Par,
) -> Array2<f64> {
    assert_eq!(
        x.nrows(),
        w.len(),
        "fast_xt_diag_x_with_parallelism row/weight length mismatch"
    );
    fast_xt_diag_x_with_parallelism_impl(x, w, par)
}
965
966#[inline]
967fn fast_xt_diag_x_with_parallelism_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
968 x: &ArrayBase<S1, Ix2>,
969 w: &ArrayBase<S2, Ix1>,
970 par: Par,
971) -> Array2<f64> {
972 use ndarray::ShapeBuilder;
973
974 let p = x.ncols();
975 let mut result = Array2::<f64>::zeros((p, p).f());
978 stream_weighted_crossprod_into(
979 x,
980 w,
981 &mut result,
982 CrossprodStructure::SymmetricLower,
983 CrossprodAccum::Replace,
984 par,
985 );
986 result
987}
988
/// Selects how `stream_weighted_crossprod_into` computes the `p × p` result
/// on its faer path.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CrossprodStructure {
    /// Every entry is computed with a general matrix product.
    Full,
    /// Only the lower triangle is computed with a triangular product; the
    /// upper triangle is then overwritten with its mirror image.
    SymmetricLower,
}
999
/// Selects what `stream_weighted_crossprod_into` does with the existing
/// contents of its output matrix.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CrossprodAccum {
    /// Discard the previous contents; the output becomes the cross-product.
    Replace,
    /// Add the cross-product to the previous contents.
    Add,
}
1008
/// Computes the weighted cross-product `Xᵀ·diag(w)·X` into `out`, streaming
/// over `x` in row chunks so only one chunk of `diag(w)·X` is materialised at
/// a time.
///
/// * `structure` — whether the faer path computes the full matrix or only the
///   lower triangle (mirrored into the upper triangle at the end).
/// * `accum` — whether `out` is overwritten or added to.
/// * `par` — faer parallelism used for every chunk product.
///
/// Special cases: with `p == 0` nothing is touched; with `n == 0` the output
/// is zeroed in `Replace` mode and left alone in `Add` mode. Small problems
/// (per `should_use_faer_matmul`) are computed in one shot with `ndarray`,
/// which always produces and applies the full matrix regardless of
/// `structure`.
///
/// Note that with `SymmetricLower` on the faer path the upper triangle of
/// `out` is always overwritten by the mirror of the lower triangle, including
/// in `Add` mode.
///
/// # Panics
/// Panics if `w.len()` differs from the number of rows of `x`, or if `out` is
/// not `p × p` where `p` is the number of columns of `x`.
pub fn stream_weighted_crossprod_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    out: &mut Array2<f64>,
    structure: CrossprodStructure,
    accum: CrossprodAccum,
    par: Par,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
    use ndarray::s;

    let (n, p) = x.dim();
    assert_eq!(n, w.len(), "X rows must match W length");
    assert_eq!(out.nrows(), p, "output rows must match X cols");
    assert_eq!(out.ncols(), p, "output cols must match X cols");
    // A 0 × 0 output has nothing to write.
    if p == 0 {
        return;
    }
    // No rows: the cross-product is the zero matrix.
    if n == 0 {
        if accum == CrossprodAccum::Replace {
            out.fill(0.0);
        }
        return;
    }

    // Small problem: materialise diag(w)·X in full and use ndarray's product.
    if !should_use_faer_matmul(p, p, n) {
        let w_x = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        let gram = x.t().dot(&w_x);
        match accum {
            CrossprodAccum::Replace => out.assign(&gram),
            CrossprodAccum::Add => *out += &gram,
        }
        return;
    }

    // Chunk height: as many rows of `p` f64 values as fit in TARGET_BYTES,
    // clamped to [MIN_ROWS, MAX_ROWS] and never more than `n`.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let chunk_rows = (TARGET_BYTES / (p.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Every chunk product below uses `Accum::Add`, so `Replace` mode starts
    // from an explicitly zeroed output.
    if accum == CrossprodAccum::Replace {
        out.fill(0.0);
    }

    // Scratch buffer for one chunk of diag(w)·X, reused across chunks.
    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p));

    let x_is_row_major = x.is_standard_layout();
    let w_slice_opt = w.as_slice();

    {
        // Scoped so the faer view's mutable borrow of `out` ends before the
        // mirroring step indexes `out` directly.
        let mut out_view = array2_to_matmut(out);
        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than `chunk_rows`.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wx_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                if x_is_row_major && let (Some(x_all), Some(w_all)) = (x.as_slice(), w_slice_opt) {
                    // Contiguous inputs: scale each row through plain slices.
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * p;
                        let dst_off = local * p;
                        let src_row = &x_all[src_off..src_off + p];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for col in 0..p {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    // General layout: go through ndarray row views.
                    let x_slice = x.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let xrow = x_slice.row(local);
                        let dst_off = local * p;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for (col, xij) in xrow.iter().enumerate() {
                            dst_row[col] = xij * wi;
                        }
                    }
                }
            }
            // out += X_chunkᵀ · (diag(w)·X)_chunk; only the first `rows` rows
            // of the scratch buffer are valid for this chunk.
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wx_slice = wx_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wx_view = FaerArrayView::new(&wx_slice);
            match structure {
                CrossprodStructure::SymmetricLower => {
                    tri_matmul(
                        out_view.as_mut(),
                        BlockStructure::TriangularLower,
                        Accum::Add,
                        x_view.as_ref().transpose(),
                        BlockStructure::Rectangular,
                        wx_view.as_ref(),
                        BlockStructure::Rectangular,
                        1.0,
                        par,
                    );
                }
                CrossprodStructure::Full => {
                    matmul(
                        out_view.as_mut(),
                        Accum::Add,
                        x_view.as_ref().transpose(),
                        wx_view.as_ref(),
                        1.0,
                        par,
                    );
                }
            }
        }
    }

    // Only the lower triangle was accumulated; copy it into the upper one.
    if structure == CrossprodStructure::SymmetricLower {
        for i in 0..p {
            for j in (i + 1)..p {
                out[[i, j]] = out[[j, i]];
            }
        }
    }
}
1170
1171#[inline]
1173pub fn fast_xt_diag_y<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
1174 x: &ArrayBase<S1, Ix2>,
1175 w: &ArrayBase<S2, Ix1>,
1176 y: &ArrayBase<S3, Ix2>,
1177) -> Array2<f64> {
1178 assert_eq!(x.nrows(), y.nrows(), "fast_xt_diag_y X/Y row mismatch");
1179 assert_eq!(
1180 y.nrows(),
1181 w.len(),
1182 "fast_xt_diag_y row/weight length mismatch"
1183 );
1184 if let Some(out) = crate::gpu_hook::gpu_dispatch()
1185 .and_then(|d| d.try_fast_xt_diag_y(x.view(), w.view(), y.view()))
1186 {
1187 return out;
1188 }
1189 fast_xt_diag_y_impl(x, w, y)
1190}
1191
/// CPU implementation of `Xᵀ·diag(w)·Y`, returning a `px × q` matrix where
/// `px` and `q` are the column counts of `x` and `y`.
///
/// Empty inputs return a zero matrix. Small problems (per
/// `should_use_faer_matmul`) materialise `diag(w)·Y` in full and use
/// `ndarray`'s product. Larger problems stream over the rows in chunks,
/// scaling one chunk of `Y` at a time into a reusable buffer and accumulating
/// `X_chunkᵀ · (diag(w)·Y)_chunk` into a column-major result with faer.
///
/// # Panics
/// Panics if `w.len()` or the number of rows of `x` differs from the number
/// of rows of `y`.
#[inline]
fn fast_xt_diag_y_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    y: &ArrayBase<S3, Ix2>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};

    let (n, q) = y.dim();
    let px = x.ncols();
    assert_eq!(n, w.len(), "Y rows must match W length");
    assert_eq!(n, x.nrows(), "X rows must match Y rows");
    // Any empty dimension makes the product the zero matrix.
    if n == 0 || px == 0 || q == 0 {
        return Array2::<f64>::zeros((px, q));
    }
    if !should_use_faer_matmul(px, q, n) {
        let w_y = Array2::from_shape_fn((n, q), |(i, j)| w[i] * y[[i, j]]);
        return x.t().dot(&w_y);
    }

    // Chunk height: as many rows of `px + q` f64 values as fit in
    // TARGET_BYTES, clamped to [MIN_ROWS, MAX_ROWS] and never more than `n`.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let total_cols = px + q;
    let chunk_rows = (TARGET_BYTES / (total_cols.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Zero-initialised because every chunk product below uses `Accum::Add`.
    let mut result = Array2::<f64>::zeros((px, q).f());
    // Scratch buffer for one chunk of diag(w)·Y, reused across chunks.
    let mut wy_chunk = Array2::<f64>::zeros((chunk_rows, q));

    let y_is_row_major = y.is_standard_layout();
    let w_slice_opt = w.as_slice();

    {
        // Scoped so the faer view's mutable borrow of `result` ends before
        // `result` is returned.
        let mut out_view = array2_to_matmut(&mut result);

        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than `chunk_rows`.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wy_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                if y_is_row_major && let (Some(y_all), Some(w_all)) = (y.as_slice(), w_slice_opt) {
                    // Contiguous inputs: scale each row through plain slices.
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * q;
                        let dst_off = local * q;
                        let src_row = &y_all[src_off..src_off + q];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for col in 0..q {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    // General layout: go through ndarray row views.
                    let y_slice = y.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let yrow = y_slice.row(local);
                        let dst_off = local * q;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for (col, yij) in yrow.iter().enumerate() {
                            dst_row[col] = yij * wi;
                        }
                    }
                }
            }
            // result += X_chunkᵀ · (diag(w)·Y)_chunk; only the first `rows`
            // rows of the scratch buffer are valid for this chunk.
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wy_slice = wy_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wy_view = FaerArrayView::new(&wy_slice);
            // Parallelism is decided per chunk from the chunk's own height.
            let par = matmul_parallelism(px, q, rows);
            matmul(
                out_view.as_mut(),
                Accum::Add,
                x_view.as_ref().transpose(),
                wy_view.as_ref(),
                1.0,
                par,
            );
        }
    }

    result
}
1283
1284pub fn fast_joint_hessian_2x2<
1290 S1: Data<Elem = f64>,
1291 S2: Data<Elem = f64>,
1292 S3: Data<Elem = f64>,
1293 S4: Data<Elem = f64>,
1294 S5: Data<Elem = f64>,
1295>(
1296 x_a: &ArrayBase<S1, Ix2>,
1297 x_b: &ArrayBase<S2, Ix2>,
1298 w_aa: &ArrayBase<S3, Ix1>,
1299 w_ab: &ArrayBase<S4, Ix1>,
1300 w_bb: &ArrayBase<S5, Ix1>,
1301) -> Array2<f64> {
1302 if let Some(out) = crate::gpu_hook::gpu_dispatch().and_then(|d| {
1303 d.try_fast_joint_hessian_2x2(
1304 x_a.view(),
1305 x_b.view(),
1306 w_aa.view(),
1307 w_ab.view(),
1308 w_bb.view(),
1309 )
1310 }) {
1311 return out;
1312 }
1313 fast_joint_hessian_2x2_impl(x_a, x_b, w_aa, w_ab, w_bb)
1314}
1315
1316#[inline]
1317fn fast_joint_hessian_2x2_impl<
1318 S1: Data<Elem = f64>,
1319 S2: Data<Elem = f64>,
1320 S3: Data<Elem = f64>,
1321 S4: Data<Elem = f64>,
1322 S5: Data<Elem = f64>,
1323>(
1324 x_a: &ArrayBase<S1, Ix2>,
1325 x_b: &ArrayBase<S2, Ix2>,
1326 w_aa: &ArrayBase<S3, Ix1>,
1327 w_ab: &ArrayBase<S4, Ix1>,
1328 w_bb: &ArrayBase<S5, Ix1>,
1329) -> Array2<f64> {
1330 use faer::Accum;
1331 use faer::linalg::matmul::matmul;
1332 use ndarray::{ShapeBuilder, s};
1333
1334 let n = x_a.nrows();
1335 let pa = x_a.ncols();
1336 let pb = x_b.ncols();
1337 let total = pa + pb;
1338 assert_eq!(n, x_b.nrows());
1339 assert_eq!(n, w_aa.len());
1340 assert_eq!(n, w_ab.len());
1341 assert_eq!(n, w_bb.len());
1342
1343 if n == 0 || total == 0 {
1344 return Array2::<f64>::zeros((total, total));
1345 }
1346
1347 if !should_use_faer_matmul(pa.max(pb), pa.max(pb), n) {
1349 let waa_xa = Array2::from_shape_fn((n, pa), |(i, j)| w_aa[i] * x_a[[i, j]]);
1350 let wab_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_ab[i] * x_b[[i, j]]);
1351 let wbb_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_bb[i] * x_b[[i, j]]);
1352 let mut out = Array2::<f64>::zeros((total, total));
1353 out.slice_mut(s![..pa, ..pa]).assign(&x_a.t().dot(&waa_xa));
1354 out.slice_mut(s![..pa, pa..]).assign(&x_a.t().dot(&wab_xb));
1355 out.slice_mut(s![pa.., pa..]).assign(&x_b.t().dot(&wbb_xb));
1356 for i in 0..total {
1358 for j in 0..i {
1359 out[[i, j]] = out[[j, i]];
1360 }
1361 }
1362 return out;
1363 }
1364
1365 const TARGET_BYTES: usize = 8 * 1024 * 1024;
1366 const MIN_ROWS: usize = 512;
1367 const MAX_ROWS: usize = 131_072;
1368 let cols_needed = pa + 2 * pb;
1370 let chunk_rows = (TARGET_BYTES / (cols_needed.max(1) * 8))
1371 .clamp(MIN_ROWS, MAX_ROWS)
1372 .min(n);
1373
1374 let mut out = Array2::<f64>::zeros((total, total).f());
1375 let mut waa_xa_buf = Array2::<f64>::zeros((chunk_rows, pa));
1380 let mut wab_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
1381 let mut wbb_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
1382
1383 let xa_is_row_major = x_a.is_standard_layout();
1384 let xb_is_row_major = x_b.is_standard_layout();
1385 let waa_slice_opt = w_aa.as_slice();
1386 let wab_slice_opt = w_ab.as_slice();
1387 let wbb_slice_opt = w_bb.as_slice();
1388
1389 {
1390 let mut out_mat = array2_to_matmut(&mut out);
1391
1392 for start in (0..n).step_by(chunk_rows) {
1393 let rows = (n - start).min(chunk_rows);
1394 let xa_slice = x_a.slice(s![start..start + rows, ..]);
1395 let xb_slice = x_b.slice(s![start..start + rows, ..]);
1396
1397 {
1399 let waa_chunk = waa_xa_buf
1400 .as_slice_mut()
1401 .expect("row-major waa chunk is contiguous");
1402 let wab_chunk = wab_xb_buf
1403 .as_slice_mut()
1404 .expect("row-major wab chunk is contiguous");
1405 let wbb_chunk = wbb_xb_buf
1406 .as_slice_mut()
1407 .expect("row-major wbb chunk is contiguous");
1408
1409 if xa_is_row_major
1410 && xb_is_row_major
1411 && let (Some(xa_all), Some(xb_all)) = (x_a.as_slice(), x_b.as_slice())
1412 && let (Some(waa_all), Some(wab_all), Some(wbb_all)) =
1413 (waa_slice_opt, wab_slice_opt, wbb_slice_opt)
1414 {
1415 for local in 0..rows {
1416 let i = start + local;
1417 let waa_i = waa_all[i];
1418 let wab_i = wab_all[i];
1419 let wbb_i = wbb_all[i];
1420 let xa_off = i * pa;
1421 let xa_row = &xa_all[xa_off..xa_off + pa];
1422 let xb_off = i * pb;
1423 let xb_row = &xb_all[xb_off..xb_off + pb];
1424 let waa_off = local * pa;
1425 let wab_off = local * pb;
1426 let wbb_off = local * pb;
1427 let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
1428 for col in 0..pa {
1429 waa_row[col] = xa_row[col] * waa_i;
1430 }
1431 let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
1432 let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
1433 for col in 0..pb {
1434 let xij = xb_row[col];
1435 wab_row[col] = xij * wab_i;
1436 wbb_row[col] = xij * wbb_i;
1437 }
1438 }
1439 } else {
1440 for local in 0..rows {
1441 let i = start + local;
1442 let waa_i = w_aa[i];
1443 let wab_i = w_ab[i];
1444 let wbb_i = w_bb[i];
1445 let waa_off = local * pa;
1446 let wab_off = local * pb;
1447 let wbb_off = local * pb;
1448 let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
1449 let xa_row = xa_slice.row(local);
1450 for (col, xij) in xa_row.iter().enumerate() {
1451 waa_row[col] = xij * waa_i;
1452 }
1453 let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
1454 let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
1455 let xb_row = xb_slice.row(local);
1456 for (col, xij) in xb_row.iter().enumerate() {
1457 wab_row[col] = xij * wab_i;
1458 wbb_row[col] = xij * wbb_i;
1459 }
1460 }
1461 }
1462 }
1463
1464 let xa_view = FaerArrayView::new(&xa_slice);
1465 let xb_view = FaerArrayView::new(&xb_slice);
1466 let waa_xa_slice = waa_xa_buf.slice(s![0..rows, ..]);
1467 let wab_xb_slice = wab_xb_buf.slice(s![0..rows, ..]);
1468 let wbb_xb_slice = wbb_xb_buf.slice(s![0..rows, ..]);
1469 let waa_xa_view = FaerArrayView::new(&waa_xa_slice);
1470 let wab_xb_view = FaerArrayView::new(&wab_xb_slice);
1471 let wbb_xb_view = FaerArrayView::new(&wbb_xb_slice);
1472
1473 matmul(
1475 out_mat.rb_mut().submatrix_mut(0, 0, pa, pa),
1476 Accum::Add,
1477 xa_view.as_ref().transpose(),
1478 waa_xa_view.as_ref(),
1479 1.0,
1480 matmul_parallelism(pa, pa, rows),
1481 );
1482 matmul(
1484 out_mat.rb_mut().submatrix_mut(0, pa, pa, pb),
1485 Accum::Add,
1486 xa_view.as_ref().transpose(),
1487 wab_xb_view.as_ref(),
1488 1.0,
1489 matmul_parallelism(pa, pb, rows),
1490 );
1491 matmul(
1493 out_mat.rb_mut().submatrix_mut(pa, pa, pb, pb),
1494 Accum::Add,
1495 xb_view.as_ref().transpose(),
1496 wbb_xb_view.as_ref(),
1497 1.0,
1498 matmul_parallelism(pb, pb, rows),
1499 );
1500 }
1501 } for i in 0..total {
1504 for j in 0..i {
1505 out[[i, j]] = out[[j, i]];
1506 }
1507 }
1508 out
1509}
1510
1511fn mat_to_array(mat: MatRef<'_, f64>) -> Array2<f64> {
1512 let nrows = mat.nrows();
1513 let ncols = mat.ncols();
1514 let mut out = Array2::<f64>::zeros((nrows, ncols));
1515 if nrows == 0 || ncols == 0 {
1516 return out;
1517 }
1518 if let Some(out_slice) = out.as_slice_memory_order_mut() {
1521 for i in 0..nrows {
1523 let row_start = i * ncols;
1524 for j in 0..ncols {
1525 out_slice[row_start + j] = mat[(i, j)];
1526 }
1527 }
1528 } else {
1529 for j in 0..ncols {
1530 for i in 0..nrows {
1531 out[[i, j]] = mat[(i, j)];
1532 }
1533 }
1534 }
1535 out
1536}
1537
1538#[inline]
1541pub fn fast_ab_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1542 a: &ArrayBase<S1, Ix2>,
1543 b: &ArrayBase<S2, Ix2>,
1544 out: &mut Array2<f64>,
1545) {
1546 fast_ab_into_impl(a, b, out);
1547}
1548
1549#[inline]
1550fn fast_ab_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1551 a: &ArrayBase<S1, Ix2>,
1552 b: &ArrayBase<S2, Ix2>,
1553 out: &mut Array2<f64>,
1554) {
1555 use faer::Accum;
1556 use faer::linalg::matmul::matmul;
1557
1558 let (n, p) = a.dim();
1559 let (p_b, q) = b.dim();
1560 assert_eq!(p, p_b, "A and B must have compatible inner dimensions");
1561 assert_eq!(out.dim(), (n, q), "output dimensions must match A*B result");
1562
1563 if !should_use_faer_matmul(n, q, p) {
1564 out.assign(&a.dot(b));
1565 return;
1566 }
1567
1568 let aview = FaerArrayView::new(a);
1569 let bview = FaerArrayView::new(b);
1570 let a_ref = aview.as_ref();
1571 let b_ref = bview.as_ref();
1572
1573 let par = matmul_parallelism(n, q, p);
1574 let mut outview = array2_to_matmut(out);
1575 matmul(outview.as_mut(), Accum::Replace, a_ref, b_ref, 1.0, par);
1576}
1577
1578fn diag_to_array(diag: DiagRef<'_, f64>) -> Array1<f64> {
1579 let mat = diag.column_vector().as_mat();
1580 let mut out = Array1::<f64>::zeros(mat.nrows());
1581 for i in 0..mat.nrows() {
1582 out[i] = mat[(i, 0)];
1583 }
1584 out
1585}
1586
1587pub struct FaerArrayView<'a> {
1588 ptr: *const f64,
1589 rows: usize,
1590 cols: usize,
1591 row_stride: isize,
1592 col_stride: isize,
1593 owned: Option<Array2<f64>>,
1594 marker: PhantomData<&'a f64>,
1595}
1596
1597impl<'a> FaerArrayView<'a> {
1598 #[inline]
1599 pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix2>) -> Self {
1600 let (rows, cols) = array.dim();
1601 let strides = array.strides();
1602 if strides[0] <= 0 || strides[1] <= 0 {
1606 let owned = array.to_owned();
1607 let owned_strides = owned.strides();
1608 return Self {
1609 ptr: owned.as_ptr(),
1610 rows,
1611 cols,
1612 row_stride: owned_strides[0],
1613 col_stride: owned_strides[1],
1614 owned: Some(owned),
1615 marker: PhantomData,
1616 };
1617 }
1618
1619 Self {
1620 ptr: array.as_ptr(),
1621 rows,
1622 cols,
1623 row_stride: strides[0],
1624 col_stride: strides[1],
1625 owned: None,
1626 marker: PhantomData,
1627 }
1628 }
1629
1630 #[inline]
1631 pub fn as_ref(&self) -> MatRef<'_, f64> {
1632 let (ptr, rows, cols, row_stride, col_stride) = if let Some(owned) = &self.owned {
1633 let strides = owned.strides();
1634 (
1635 owned.as_ptr(),
1636 owned.nrows(),
1637 owned.ncols(),
1638 strides[0],
1639 strides[1],
1640 )
1641 } else {
1642 (
1643 self.ptr,
1644 self.rows,
1645 self.cols,
1646 self.row_stride,
1647 self.col_stride,
1648 )
1649 };
1650 unsafe { MatRef::from_raw_parts(ptr, rows, cols, row_stride, col_stride) }
1654 }
1655}
1656
1657pub struct FaerColView<'a> {
1658 ptr: *const f64,
1659 len: usize,
1660 stride: isize,
1661 owned: Option<Array1<f64>>,
1662 marker: PhantomData<&'a f64>,
1663}
1664
1665impl<'a> FaerColView<'a> {
1666 #[inline]
1667 pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix1>) -> Self {
1668 let len = array.len();
1669 let stride = array.strides()[0];
1670 if stride <= 0 {
1671 let owned = array.to_owned();
1672 return Self {
1673 ptr: owned.as_ptr(),
1674 len,
1675 stride: 1,
1676 owned: Some(owned),
1677 marker: PhantomData,
1678 };
1679 }
1680 Self {
1681 ptr: array.as_ptr(),
1682 len,
1683 stride,
1684 owned: None,
1685 marker: PhantomData,
1686 }
1687 }
1688
1689 #[inline]
1690 pub fn as_ref(&self) -> MatRef<'_, f64> {
1691 let (ptr, len, stride) = if let Some(owned) = &self.owned {
1692 (owned.as_ptr(), owned.len(), 1)
1693 } else {
1694 (self.ptr, self.len, self.stride)
1695 };
1696 unsafe { MatRef::from_raw_parts(ptr, len, 1, stride, 0) }
1700 }
1701}
1702
1703pub trait FaerSvd {
1704 fn svd(
1705 &self,
1706 compute_u: bool,
1707 computevt: bool,
1708 ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError>;
1709}
1710
1711impl<S: Data<Elem = f64>> FaerSvd for ArrayBase<S, Ix2> {
1712 fn svd(
1713 &self,
1714 compute_u: bool,
1715 computevt: bool,
1716 ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError> {
1717 let faerview = FaerArrayView::new(self);
1718 let faer_mat = faerview.as_ref();
1719 if !compute_u && !computevt {
1720 let (rows, cols) = faer_mat.shape();
1721 let mut singular = Diag::<f64>::zeros(rows.min(cols));
1722 let par = get_global_parallelism();
1723 let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1724 rows,
1725 cols,
1726 ComputeSvdVectors::No,
1727 ComputeSvdVectors::No,
1728 par,
1729 Default::default(),
1730 ));
1731 let stack = MemStack::new(&mut mem);
1732 svd::svd(
1733 faer_mat,
1734 singular.as_mut(),
1735 None,
1736 None,
1737 par,
1738 stack,
1739 Default::default(),
1740 )
1741 .map_err(|_| FaerLinalgError::SvdNoConvergence {
1742 context: "faer SVD singular values only",
1743 })?;
1744 let singularvalues = diag_to_array(singular.as_ref());
1745 return Ok((None, singularvalues, None));
1746 }
1747
1748 let (rows, cols) = faer_mat.shape();
1749 let rank = rows.min(cols);
1750 let compute_u_flag = if compute_u {
1751 ComputeSvdVectors::Thin
1752 } else {
1753 ComputeSvdVectors::No
1754 };
1755 let computev_flag = if computevt {
1756 ComputeSvdVectors::Thin
1757 } else {
1758 ComputeSvdVectors::No
1759 };
1760
1761 let mut singular = Diag::<f64>::zeros(rows.min(cols));
1762 let mut u_storage = compute_u.then(|| Mat::<f64>::zeros(rows, rank));
1763 let mut v_storage = computevt.then(|| Mat::<f64>::zeros(cols, rank));
1764
1765 let par = get_global_parallelism();
1766 let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1767 rows,
1768 cols,
1769 compute_u_flag,
1770 computev_flag,
1771 par,
1772 Default::default(),
1773 ));
1774 let stack = MemStack::new(&mut mem);
1775
1776 svd::svd(
1777 faer_mat.as_ref(),
1778 singular.as_mut(),
1779 u_storage.as_mut().map(|mat| mat.as_mut()),
1780 v_storage.as_mut().map(|mat| mat.as_mut()),
1781 par,
1782 stack,
1783 Default::default(),
1784 )
1785 .map_err(|_| FaerLinalgError::SvdNoConvergence {
1786 context: "faer SVD with vectors",
1787 })?;
1788
1789 let singularvalues = diag_to_array(singular.as_ref());
1790 let u_opt = u_storage.map(|mat| mat_to_array(mat.as_ref()));
1791 let vt_opt = v_storage.map(|mat| {
1792 let mat_ref = mat.as_ref();
1793 let mut out = Array2::<f64>::zeros((mat_ref.ncols(), mat_ref.nrows()));
1794 for j in 0..mat_ref.nrows() {
1795 for i in 0..mat_ref.ncols() {
1796 out[[i, j]] = mat_ref[(j, i)];
1797 }
1798 }
1799 out
1800 });
1801
1802 Ok((u_opt, singularvalues, vt_opt))
1803 }
1804}
1805
1806pub trait FaerEigh {
1807 fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError>;
1808}
1809
1810impl<S: Data<Elem = f64>> FaerEigh for ArrayBase<S, Ix2> {
1811 fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1812 fn try_eigh(
1813 matrix: &Array2<f64>,
1814 side: Side,
1815 ) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1816 let faerview = FaerArrayView::new(matrix);
1817 let eigen = catch_unwind(AssertUnwindSafe(|| {
1818 faerview.as_ref().self_adjoint_eigen(side)
1819 }))
1820 .map_err(|_| FaerLinalgError::FactorizationFailed {
1821 context: "self-adjoint eigendecomposition panic boundary",
1822 })?
1823 .map_err(FaerLinalgError::SelfAdjointEigen)?;
1824 let values = diag_to_array(eigen.S());
1825 let vectors = mat_to_array(eigen.U());
1826 Ok((values, vectors))
1827 }
1828
1829 let owned = self.to_owned();
1830 if owned.nrows() != owned.ncols() {
1831 return Err(FaerLinalgError::FactorizationFailed {
1832 context: "self-adjoint eigendecomposition non-square input",
1833 });
1834 }
1835 if owned.nrows() == 0 {
1836 return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
1837 }
1838 if owned.iter().any(|value| !value.is_finite()) {
1839 return Err(FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1840 context: "self-adjoint eigendecomposition input validation",
1841 });
1842 }
1843 if let Ok((evals, evecs)) = try_eigh(&owned, side)
1844 && evals.iter().all(|value| value.is_finite())
1845 && evecs.iter().all(|value| value.is_finite())
1846 {
1847 return Ok((evals, evecs));
1848 }
1849
1850 let mut repaired = owned.clone();
1851 crate::matrix::symmetrize_in_place(&mut repaired);
1852
1853 let scale = repaired
1854 .iter()
1855 .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
1856 .max(1.0);
1857 let scaled = repaired.mapv(|value| value / scale);
1858 const JITTER_SCHEDULE: [f64; 6] = [0.0, 1e-12, 1e-10, 1e-8, 1e-6, 1e-4];
1864 let jitter_schedule = JITTER_SCHEDULE;
1865 let mut last_error = FaerLinalgError::FactorizationFailed {
1866 context: "self-adjoint eigendecomposition repair attempts",
1867 };
1868
1869 for &jitter in &jitter_schedule {
1870 let mut candidate = scaled.clone();
1871 if jitter > 0.0 {
1872 let n = candidate.nrows();
1873 for i in 0..n {
1874 candidate[[i, i]] += jitter;
1875 }
1876 }
1877
1878 match try_eigh(&candidate, side) {
1879 Ok((mut evals, evecs))
1880 if evals.iter().all(|value| value.is_finite())
1881 && evecs.iter().all(|value| value.is_finite()) =>
1882 {
1883 for value in &mut evals {
1884 *value = (*value - jitter) * scale;
1885 }
1886 return Ok((evals, evecs));
1887 }
1888 Ok((_, _)) => {
1889 last_error = FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1890 context: "self-adjoint eigendecomposition repaired output validation",
1891 };
1892 }
1893 Err(err) => {
1894 last_error = err;
1895 }
1896 }
1897 }
1898
1899 Err(last_error)
1900 }
1901}
1902
1903pub struct FaerCholeskyFactor {
1904 factor: solvers::Llt<f64>,
1905}
1906
1907impl FaerCholeskyFactor {
1908 pub fn solvevec(&self, rhs: &Array1<f64>) -> Array1<f64> {
1909 let mut rhs = rhs.to_owned();
1910 let mut rhsview = array1_to_col_matmut(&mut rhs);
1911 self.factor.solve_in_place(rhsview.as_mut());
1912 rhs
1913 }
1914
1915 pub fn solve_mat_in_place(&self, rhs: &mut Array2<f64>) {
1916 let mut rhsview = array2_to_matmut(rhs);
1917 self.factor.solve_in_place(rhsview.as_mut());
1918 }
1919
1920 pub fn solve_mat_into<S: Data<Elem = f64>>(
1921 &self,
1922 rhs: &ArrayBase<S, Ix2>,
1923 out: &mut Array2<f64>,
1924 ) {
1925 if out.dim() != rhs.dim() {
1926 *out = Array2::<f64>::zeros(rhs.dim());
1927 }
1928 out.assign(rhs);
1929 self.solve_mat_in_place(out);
1930 }
1931
1932 pub fn solve_mat(&self, rhs: &Array2<f64>) -> Array2<f64> {
1933 let mut out = Array2::<f64>::zeros(rhs.dim());
1934 self.solve_mat_into(rhs, &mut out);
1935 out
1936 }
1937
1938 pub fn diag(&self) -> Array1<f64> {
1939 diag_to_array(self.factor.L().diagonal())
1940 }
1941
1942 pub fn lower_triangular(&self) -> Array2<f64> {
1943 mat_to_array(self.factor.L())
1944 }
1945}
1946
1947pub trait FaerCholesky {
1948 fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError>;
1949}
1950
1951impl<S: Data<Elem = f64>> FaerCholesky for ArrayBase<S, Ix2> {
1952 fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError> {
1953 let faerview = FaerArrayView::new(self);
1954 let factor = faerview
1955 .as_ref()
1956 .llt(side)
1957 .map_err(FaerLinalgError::Cholesky)?;
1958 Ok(FaerCholeskyFactor { factor })
1959 }
1960}
1961
1962pub trait FaerQr {
1963 fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError>;
1964}
1965
1966impl<S: Data<Elem = f64>> FaerQr for ArrayBase<S, Ix2> {
1967 fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError> {
1968 let faerview = FaerArrayView::new(self);
1969 let qr = faerview.as_ref().qr();
1970 let q = qr.compute_thin_Q();
1971 let r = qr.thin_R();
1972 Ok((mat_to_array(q.as_ref()), mat_to_array(r)))
1973 }
1974}
1975
1976pub fn rrqr_nullspace_basis<S: Data<Elem = f64>>(
1995 a: &ArrayBase<S, Ix2>,
1996 rank_alpha: f64,
1997) -> Result<(Array2<f64>, usize), FaerLinalgError> {
1998 let faerview = FaerArrayView::new(a);
1999 let qr = faerview.as_ref().col_piv_qr();
2000 let r = qr.thin_R();
2001 let diag_len = r.nrows().min(r.ncols());
2002 let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
2003 let tol = rank_alpha
2004 * f64::EPSILON
2005 * (a.nrows().max(a.ncols()).max(1) as f64)
2006 * leading_diag.max(1.0);
2007 let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
2008 let z = if rank >= a.nrows() {
2009 Array2::<f64>::zeros((a.nrows(), 0))
2010 } else if rank == 0 {
2011 Array2::<f64>::eye(a.nrows())
2015 } else {
2016 let nullity = a.nrows() - rank;
2017 let mut selector = Mat::<f64>::zeros(a.nrows(), nullity);
2018 for j in 0..nullity {
2019 selector[(rank + j, j)] = 1.0;
2020 }
2021 let par = get_global_parallelism();
2022 faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
2023 qr.Q_basis(),
2024 qr.Q_coeff(),
2025 Conj::No,
2026 selector.as_mut(),
2027 par,
2028 MemStack::new(&mut MemBuffer::new(
2029 faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<f64>(
2030 a.nrows(),
2031 qr.Q_coeff().nrows(),
2032 nullity,
2033 ),
2034 )),
2035 );
2036 mat_to_array(selector.as_ref())
2037 };
2038 Ok((z, rank))
2039}
2040
/// Returns the module's default `rank_alpha` multiplier, `RRQR_RANK_ALPHA`,
/// for callers of the RRQR helpers that take a `rank_alpha` argument.
#[inline]
pub const fn default_rrqr_rank_alpha() -> f64 {
    RRQR_RANK_ALPHA
}
2045
/// Result of [`rrqr_with_permutation`]: the numerical rank together with the
/// column pivoting chosen by the QR factorization.
#[derive(Debug, Clone, PartialEq)]
pub struct RrqrWithPermutation {
    /// Number of diagonal entries of `R` whose magnitude exceeds `rank_tol`.
    pub rank: usize,
    /// Forward index array of the column-pivoting permutation.
    // NOTE(review): callers in this file read the entries after `rank` as the
    // demoted original column indices — confirm against faer's convention.
    pub column_permutation: Vec<usize>,
    /// Magnitude of the leading diagonal entry of `R` (`0.0` if `R` has none).
    pub leading_diag_abs: f64,
    /// Threshold the diagonal magnitudes were compared against.
    pub rank_tol: f64,
}
2062
2063pub fn rrqr_with_permutation<S: Data<Elem = f64>>(
2072 a: &ArrayBase<S, Ix2>,
2073 rank_alpha: f64,
2074) -> Result<RrqrWithPermutation, FaerLinalgError> {
2075 if a.nrows() == 0 {
2076 return Err(FaerLinalgError::FactorizationFailed {
2077 context: "rrqr_with_permutation: input has zero rows",
2078 });
2079 }
2080 let faerview = FaerArrayView::new(a);
2081 let qr = faerview.as_ref().col_piv_qr();
2082 let r = qr.thin_R();
2083 let diag_len = r.nrows().min(r.ncols());
2084 let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
2085 let tol = rank_alpha
2086 * f64::EPSILON
2087 * (a.nrows().max(a.ncols()).max(1) as f64)
2088 * leading_diag.max(1.0);
2089 let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
2090 let (forward, _inverse) = qr.P().arrays();
2091 let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
2092 Ok(RrqrWithPermutation {
2093 rank,
2094 column_permutation,
2095 leading_diag_abs: leading_diag,
2096 rank_tol: tol,
2097 })
2098}
2099
/// Result of [`rrqr_from_gram_with_permutation`]: a rank verdict derived from
/// a Gram matrix, plus a measure of how decisive that verdict is.
#[derive(Debug, Clone, PartialEq)]
pub struct RrqrFromGram {
    /// Number of pivots whose magnitude exceeds `rank_tol`.
    pub rank: usize,
    /// Forward index array of the column-pivoting permutation.
    pub column_permutation: Vec<usize>,
    /// Threshold the pivot magnitudes were compared against.
    pub rank_tol: f64,
    /// Magnitude of the leading pivot (`0.0` when there are no pivots).
    pub leading_diag_abs: f64,
    /// Smallest of three ratios: smallest kept pivot over `rank_tol`,
    /// `rank_tol` over the largest dropped pivot, and smallest kept pivot over
    /// `sqrt(EPSILON) * max(leading pivot, 1)`. Ratios that do not apply
    /// (nothing kept / nothing dropped) count as infinite.
    pub verdict_margin: f64,
}
2121
2122pub fn rrqr_from_gram_with_permutation<S: Data<Elem = f64>>(
2158 gram: &ArrayBase<S, Ix2>,
2159 m_rows: usize,
2160 rank_alpha: f64,
2161) -> Result<RrqrFromGram, FaerLinalgError> {
2162 let p = gram.ncols();
2163 if p == 0 {
2164 return Ok(RrqrFromGram {
2165 rank: 0,
2166 column_permutation: Vec::new(),
2167 rank_tol: 0.0,
2168 leading_diag_abs: 0.0,
2169 verdict_margin: 0.0,
2170 });
2171 }
2172 if gram.nrows() != p {
2173 return Err(FaerLinalgError::FactorizationFailed {
2174 context: "rrqr_from_gram_with_permutation: Gram is not square",
2175 });
2176 }
2177 let (evals, evecs) = gram.eigh(Side::Lower)?;
2186 let mut f = Array2::<f64>::zeros((p, p));
2187 for k in 0..p {
2188 let scale = evals[k].max(0.0).sqrt();
2189 if scale == 0.0 {
2190 continue;
2191 }
2192 for i in 0..p {
2193 f[[k, i]] = scale * evecs[[i, k]];
2194 }
2195 }
2196 let faer_f = FaerArrayView::new(&f);
2200 let qr = faer_f.as_ref().col_piv_qr();
2201 let r = qr.thin_R();
2202 let diag_len = r.nrows().min(r.ncols());
2203 let pivots: Vec<f64> = (0..diag_len).map(|i| r[(i, i)].abs()).collect();
2204 let leading_diag = pivots.first().copied().unwrap_or(0.0);
2205 let (forward, _inverse) = qr.P().arrays();
2206 let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
2207 let tol = rank_alpha * f64::EPSILON * (m_rows.max(p).max(1) as f64) * leading_diag.max(1.0);
2211 let rank = pivots.iter().filter(|&&v| v > tol).count();
2212 let min_kept = pivots[..rank].iter().copied().fold(f64::INFINITY, f64::min);
2213 let max_dropped = pivots[rank..].iter().copied().fold(0.0f64, f64::max);
2214 let kept_margin = if rank == 0 {
2218 f64::INFINITY
2219 } else {
2220 min_kept / tol
2221 };
2222 let dropped_margin = if rank == diag_len {
2223 f64::INFINITY
2224 } else {
2225 tol / max_dropped.max(f64::MIN_POSITIVE)
2226 };
2227 let gram_precision_floor = f64::EPSILON.sqrt() * leading_diag.max(1.0);
2249 let kept_floor_margin = if rank == 0 {
2250 f64::INFINITY
2251 } else {
2252 min_kept / gram_precision_floor.max(f64::MIN_POSITIVE)
2253 };
2254 let verdict_margin = kept_margin.min(dropped_margin).min(kept_floor_margin);
2255 Ok(RrqrFromGram {
2256 rank,
2257 column_permutation,
2258 rank_tol: tol,
2259 leading_diag_abs: leading_diag,
2260 verdict_margin,
2261 })
2262}
2263
2264#[cfg(test)]
2265mod tests {
2266 use super::*;
2267 use ndarray::{array, s};
2268
    // Threshold the Gram-RRQR tests in this module compare `verdict_margin`
    // against: below it a verdict is treated as low-margin, at or above it as
    // high-margin.
    const JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST: f64 = 1.0e3;
2273
2274 #[test]
2275 fn rrqr_nullspace_basis_is_orthonormal_and_annihilates_transpose() {
2276 let a = array![[1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 0.0],];
2277 let (z, rank) =
2278 rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2279 assert_eq!(rank, 2);
2280 assert_eq!(z.nrows(), 4);
2281 assert_eq!(z.ncols(), 2);
2282
2283 let gram = z.t().dot(&z);
2284 let ident = Array2::<f64>::eye(z.ncols());
2285 let gram_err = (&gram - &ident)
2286 .iter()
2287 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2288 assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
2289
2290 let residual = a.t().dot(&z);
2291 let resid_max = residual.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2292 assert!(resid_max < 1e-10, "A^T Z residual too large: {resid_max:e}");
2293 }
2294
2295 #[test]
2296 fn rrqr_with_permutation_attributes_redundant_column() {
2297 let a = array![
2301 [1.0, 0.0, 1.0],
2302 [1.0, 0.0, 1.0],
2303 [0.0, 2.0, 0.0],
2304 [0.0, 0.0, 0.0],
2305 ];
2306 let result =
2307 rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2308 assert_eq!(result.rank, 2);
2309 assert_eq!(result.column_permutation.len(), 3);
2310 let demoted = result.column_permutation[result.rank..].to_vec();
2311 assert!(
2312 demoted.contains(&2) || demoted.contains(&0),
2313 "demoted suffix should include one of the aliased columns (0 or 2), got {demoted:?}"
2314 );
2315 let mut sorted = result.column_permutation.clone();
2316 sorted.sort();
2317 assert_eq!(
2318 sorted,
2319 vec![0, 1, 2],
2320 "permutation must be a valid bijection on 0..n"
2321 );
2322 }
2323
2324 #[test]
2325 fn rrqr_with_permutation_full_rank_returns_identity_like_order() {
2326 let a = array![[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
2327 let result =
2328 rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2329 assert_eq!(result.rank, 2);
2330 let mut sorted = result.column_permutation.clone();
2331 sorted.sort();
2332 assert_eq!(sorted, vec![0, 1]);
2333 }
2334
2335 #[test]
2336 fn rrqr_with_permutation_rejects_zero_rows() {
2337 let a = Array2::<f64>::zeros((0, 3));
2338 assert!(rrqr_with_permutation(&a, default_rrqr_rank_alpha()).is_err());
2339 }
2340
2341 #[test]
2342 fn rrqr_nullspace_basis_square_zero_matrix_is_finite_identity() {
2343 let a = Array2::<f64>::zeros((3, 3));
2346 let (z, rank) =
2347 rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2348 assert_eq!(rank, 0);
2349 assert_eq!(z.dim(), (3, 3));
2350 assert!(
2351 z.iter().all(|v| v.is_finite()),
2352 "square zero matrix produced a non-finite null basis: {z:?}"
2353 );
2354 let gram = z.t().dot(&z);
2355 let ident = Array2::<f64>::eye(3);
2356 let gram_err = (&gram - &ident)
2357 .iter()
2358 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2359 assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
2360 }
2361
2362 #[test]
2363 fn rrqr_nullspace_basis_detectszero_rank_matrix() {
2364 let a = Array2::<f64>::zeros((5, 2));
2365 let (z, rank) =
2366 rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
2367 assert_eq!(rank, 0);
2368 assert_eq!(z.dim(), (5, 5));
2369 let ident = Array2::<f64>::eye(5);
2370 let max_err = (&z.slice(s![.., ..5]).to_owned() - &ident)
2371 .iter()
2372 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2373 assert!(max_err < 1e-10, "zero matrix should yield identity basis");
2374 }
2375
2376 #[test]
2385 fn eigh_on_nan_matrix_rejects_non_finite_input() {
2386 let mat = array![
2387 [1.0, 0.0, 0.0, 0.0],
2388 [0.0, 2.0, 0.0, 0.0],
2389 [0.0, 0.0, 3.0, f64::NAN],
2390 [0.0, 0.0, f64::NAN, 4.0]
2391 ];
2392 let err = mat
2393 .eigh(Side::Lower)
2394 .expect_err("non-finite symmetric input must be rejected");
2395 assert!(matches!(
2396 err,
2397 FaerLinalgError::SelfAdjointEigenNonFiniteInput { .. }
2398 ));
2399 }
2400
2401 #[test]
2402 fn fast_ata_matches_full_gemm_above_threshold() {
2403 let n = 200;
2406 let p = 40;
2407 let a: Array2<f64> = Array2::from_shape_fn((n, p), |(i, j)| {
2408 ((i * 7 + j * 3) as f64).sin() + 0.1 * j as f64
2409 });
2410 let expected = a.t().dot(&a);
2411 let got = fast_ata(&a);
2412 let max_err = (&got - &expected)
2413 .iter()
2414 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2415 assert!(max_err < 1e-10, "fast_ata mismatch: {max_err:e}");
2416 for i in 0..p {
2418 for j in 0..p {
2419 assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
2420 }
2421 }
2422 }
2423
2424 #[test]
2425 fn fast_xt_diag_x_matches_naive_above_threshold() {
2426 let n = 400;
2427 let p = 36;
2428 let x: Array2<f64> =
2429 Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.1).cos() + j as f64 * 0.05);
2430 let w: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64 * 0.03).sin());
2431 let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
2433 let expected = x.t().dot(&wx);
2434 let got = fast_xt_diag_x(&x, &w);
2435 let max_err = (&got - &expected)
2436 .iter()
2437 .fold(0.0_f64, |acc, &v| acc.max(v.abs()));
2438 assert!(max_err < 1e-9, "fast_xt_diag_x mismatch: {max_err:e}");
2439 for i in 0..p {
2440 for j in 0..p {
2441 assert!((got[[i, j]] - got[[j, i]]).abs() < 1e-12);
2442 }
2443 }
2444 }
2445
2446 #[test]
2447 fn stream_weighted_crossprod_full_and_triangular_parity_with_negative_weights() {
2448 for &(n, p) in &[(900usize, 40usize), (8usize, 3usize)] {
2457 let x: Array2<f64> =
2458 Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.07).cos() + j as f64 * 0.013);
2459 let w: Array1<f64> =
2462 Array1::from_shape_fn(n, |i| (i as f64 * 0.11).sin() - 0.25 * (i % 3) as f64);
2463 assert!(
2464 w.iter().any(|&v| v < 0.0),
2465 "weight vector must contain negatives to test sign preservation"
2466 );
2467
2468 let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
2470 let expected = x.t().dot(&wx);
2471
2472 let par = matmul_parallelism(p, p, n);
2473
2474 let mut full = Array2::<f64>::ones((p, p));
2476 stream_weighted_crossprod_into(
2477 &x,
2478 &w,
2479 &mut full,
2480 CrossprodStructure::Full,
2481 CrossprodAccum::Replace,
2482 par,
2483 );
2484
2485 let mut tri = Array2::<f64>::from_elem((p, p), -7.0);
2489 stream_weighted_crossprod_into(
2490 &x,
2491 &w,
2492 &mut tri,
2493 CrossprodStructure::SymmetricLower,
2494 CrossprodAccum::Replace,
2495 par,
2496 );
2497
2498 let full_err = (&full - &expected)
2499 .iter()
2500 .fold(0.0_f64, |a, &v| a.max(v.abs()));
2501 let tri_err = (&tri - &expected)
2502 .iter()
2503 .fold(0.0_f64, |a, &v| a.max(v.abs()));
2504 assert!(
2505 full_err < 1e-9,
2506 "full kernel mismatch (n={n}, p={p}): {full_err:e}"
2507 );
2508 assert!(
2509 tri_err < 1e-9,
2510 "triangular kernel mismatch (n={n}, p={p}): {tri_err:e}"
2511 );
2512
2513 for i in 0..p {
2516 for j in 0..p {
2517 assert!(
2518 (full[[i, j]] - tri[[i, j]]).abs() < 1e-12,
2519 "full vs triangular disagree at ({i},{j})"
2520 );
2521 assert!(
2522 (tri[[i, j]] - tri[[j, i]]).abs() < 1e-12,
2523 "triangular output not symmetric at ({i},{j})"
2524 );
2525 }
2526 }
2527
2528 let base = Array2::<f64>::from_elem((p, p), 1.5);
2531 let mut add_full = base.clone();
2532 stream_weighted_crossprod_into(
2533 &x,
2534 &w,
2535 &mut add_full,
2536 CrossprodStructure::Full,
2537 CrossprodAccum::Add,
2538 par,
2539 );
2540 let mut add_tri = base.clone();
2541 stream_weighted_crossprod_into(
2542 &x,
2543 &w,
2544 &mut add_tri,
2545 CrossprodStructure::SymmetricLower,
2546 CrossprodAccum::Add,
2547 par,
2548 );
2549 let expected_add = &base + &expected;
2550 let add_full_err = (&add_full - &expected_add)
2551 .iter()
2552 .fold(0.0_f64, |a, &v| a.max(v.abs()));
2553 let add_tri_err = (&add_tri - &expected_add)
2554 .iter()
2555 .fold(0.0_f64, |a, &v| a.max(v.abs()));
2556 assert!(
2557 add_full_err < 1e-9,
2558 "full Add mismatch (n={n}, p={p}): {add_full_err:e}"
2559 );
2560 assert!(
2561 add_tri_err < 1e-9,
2562 "triangular Add mismatch (n={n}, p={p}): {add_tri_err:e}"
2563 );
2564
2565 let returned = fast_xt_diag_x(&x, &w);
2568 let returned_err = (&returned - &full)
2569 .iter()
2570 .fold(0.0_f64, |a, &v| a.max(v.abs()));
2571 assert!(
2572 returned_err < 1e-12,
2573 "return adapter vs stream-into adapter disagree (n={n}, p={p}): {returned_err:e}"
2574 );
2575 }
2576 }
2577
2578 #[test]
2579 fn eigh_succeeds_on_same_structure_without_nan() {
2580 let mat = array![[1.0, 0.5, 0.1], [0.5, 2.0, 0.3], [0.1, 0.3, 1.5]];
2582 let (evals, _) = mat
2583 .eigh(Side::Lower)
2584 .expect("eigh should succeed on a well-conditioned finite matrix");
2585 assert!(
2586 evals.iter().all(|&v| v.is_finite()),
2587 "all eigenvalues should be finite"
2588 );
2589 }
2590
2591 #[test]
2600 fn gram_rrqr_flags_low_margin_on_exact_collinearity_so_caller_falls_back() {
2601 let n = 48usize;
2604 let x: Vec<f64> = (0..n)
2605 .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
2606 .collect();
2607 let mut a = Array2::<f64>::zeros((n, 4));
2608 for i in 0..n {
2609 a[[i, 0]] = 1.0;
2610 a[[i, 1]] = x[i];
2611 a[[i, 2]] = x[i];
2612 a[[i, 3]] = x[i] * x[i];
2613 }
2614 let alpha = default_rrqr_rank_alpha();
2615
2616 let tall = rrqr_with_permutation(&a, alpha).expect("tall RRQR should succeed");
2619 assert_eq!(tall.rank, 3, "tall RRQR must demote the exact alias");
2620
2621 let unit = Array1::<f64>::ones(n);
2632 let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
2633 let gram_rrqr =
2634 rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
2635 let ok = gram_rrqr.rank == 3
2636 || gram_rrqr.verdict_margin < JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST;
2637 assert!(
2638 ok,
2639 "gam#933: Gram RRQR must either find correct rank=3 OR signal low margin \
2640 (< {:.0e}) to force the tall fallback; got rank={} margin={:.3e}",
2641 JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
2642 gram_rrqr.rank,
2643 gram_rrqr.verdict_margin,
2644 );
2645 }
2646
2647 #[test]
2652 fn gram_rrqr_keeps_high_margin_on_full_rank_design() {
2653 let n = 200usize;
2654 let p = 5usize;
2655 let mut a = Array2::<f64>::zeros((n, p));
2656 for i in 0..n {
2658 let t = (i as f64) / (n as f64 - 1.0);
2659 a[[i, 0]] = 1.0;
2660 a[[i, 1]] = t;
2661 a[[i, 2]] = t * t;
2662 a[[i, 3]] = t * t * t;
2663 a[[i, 4]] = (t * 6.0).sin();
2664 }
2665 let alpha = default_rrqr_rank_alpha();
2666 let unit = Array1::<f64>::ones(n);
2667 let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
2668 let gram_rrqr =
2669 rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
2670 assert_eq!(gram_rrqr.rank, p, "full-rank design must keep all columns");
2671 assert!(
2672 gram_rrqr.verdict_margin >= JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
2673 "full-rank design must keep a high margin (fast Gram path); got {:.3e}",
2674 gram_rrqr.verdict_margin,
2675 );
2676 }
2677
2678 fn max_abs_diff(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
2681 assert_eq!(a.dim(), b.dim(), "shape mismatch in max_abs_diff");
2682 a.iter().zip(b.iter()).fold(0.0_f64, |acc, (&x, &y)| acc.max((x - y).abs()))
2683 }
2684
2685 fn max_abs_diff_1d(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
2686 assert_eq!(a.len(), b.len(), "len mismatch in max_abs_diff_1d");
2687 a.iter().zip(b.iter()).fold(0.0_f64, |acc, (&x, &y)| acc.max((x - y).abs()))
2688 }
2689
2690 #[test]
2692 fn fast_ab_small_matches_ndarray_dot() {
2693 let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
2694 let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
2695 let got = fast_ab(&a, &b);
2696 let want = a.dot(&b);
2697 assert!(max_abs_diff(&got, &want) < 1e-12, "fast_ab small mismatch");
2698 assert_eq!(got.dim(), (2, 2));
2699 }
2700
2701 #[test]
2703 fn fast_ab_large_matches_ndarray_dot() {
2704 let n = 50usize;
2705 let p = 40usize;
2706 let q = 35usize;
2707 let mut a = Array2::<f64>::zeros((n, p));
2708 let mut b = Array2::<f64>::zeros((p, q));
2709 let mut state = 0xDEAD_BEEF_1234_5678u64;
2710 let next = |s: &mut u64| -> f64 {
2711 *s ^= *s << 13;
2712 *s ^= *s >> 7;
2713 *s ^= *s << 17;
2714 ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2715 };
2716 for v in a.iter_mut() { *v = next(&mut state); }
2717 for v in b.iter_mut() { *v = next(&mut state); }
2718 let got = fast_ab(&a, &b);
2719 let want = a.dot(&b);
2720 assert!(max_abs_diff(&got, &want) < 1e-9, "fast_ab large mismatch");
2721 }
2722
2723 #[test]
2725 fn fast_atb_small_matches_ndarray_dot() {
2726 let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
2727 let b = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];
2728 let got = fast_atb(&a, &b);
2729 let want = a.t().dot(&b);
2730 assert!(max_abs_diff(&got, &want) < 1e-12, "fast_atb small mismatch");
2731 assert_eq!(got.dim(), (2, 3));
2732 }
2733
2734 #[test]
2736 fn fast_atb_large_matches_ndarray_dot() {
2737 let n = 50usize;
2738 let p = 40usize;
2739 let q = 35usize;
2740 let mut a = Array2::<f64>::zeros((n, p));
2741 let mut b = Array2::<f64>::zeros((n, q));
2742 let mut state = 0xCAFE_BABE_9876_5432u64;
2743 let next = |s: &mut u64| -> f64 {
2744 *s ^= *s << 13;
2745 *s ^= *s >> 7;
2746 *s ^= *s << 17;
2747 ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2748 };
2749 for v in a.iter_mut() { *v = next(&mut state); }
2750 for v in b.iter_mut() { *v = next(&mut state); }
2751 let got = fast_atb(&a, &b);
2752 let want = a.t().dot(&b);
2753 assert!(max_abs_diff(&got, &want) < 1e-9, "fast_atb large mismatch");
2754 }
2755
2756 #[test]
2758 fn fast_abt_small_matches_ndarray_dot() {
2759 let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
2760 let b = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
2761 let got = fast_abt(&a, &b);
2762 let want = a.dot(&b.t());
2763 assert!(max_abs_diff(&got, &want) < 1e-12, "fast_abt small mismatch");
2764 assert_eq!(got.dim(), (2, 2));
2765 }
2766
2767 #[test]
2769 fn fast_av_small_matches_ndarray_dot() {
2770 let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
2771 let v = array![1.0, -1.0, 2.0];
2772 let got = fast_av(&a, &v);
2773 let want = a.dot(&v);
2774 assert!(max_abs_diff_1d(&got, &want) < 1e-12, "fast_av small mismatch");
2775 assert!((got[0] - 5.0).abs() < 1e-12, "fast_av[0] should be 5");
2777 assert!((got[1] - 11.0).abs() < 1e-12, "fast_av[1] should be 11");
2779 }
2780
2781 #[test]
2783 fn fast_av_large_matches_ndarray_dot() {
2784 let n = 50usize;
2785 let p = 40usize;
2786 let mut a = Array2::<f64>::zeros((n, p));
2787 let mut v = Array1::<f64>::zeros(p);
2788 let mut state = 0xFEED_FACE_ABCD_EF01u64;
2789 let next = |s: &mut u64| -> f64 {
2790 *s ^= *s << 13;
2791 *s ^= *s >> 7;
2792 *s ^= *s << 17;
2793 ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2794 };
2795 for v in a.iter_mut() { *v = next(&mut state); }
2796 for x in v.iter_mut() { *x = next(&mut state); }
2797 let got = fast_av(&a, &v);
2798 let want = a.dot(&v);
2799 assert!(max_abs_diff_1d(&got, &want) < 1e-9, "fast_av large mismatch");
2800 }
2801
2802 #[test]
2804 fn fast_atv_small_matches_ndarray_dot() {
2805 let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
2806 let v = array![1.0, 0.0, -1.0];
2807 let got = fast_atv(&a, &v);
2808 let want = a.t().dot(&v);
2809 assert!(max_abs_diff_1d(&got, &want) < 1e-12, "fast_atv small mismatch");
2811 assert!((got[0] - (-4.0)).abs() < 1e-12, "fast_atv[0]");
2812 assert!((got[1] - (-4.0)).abs() < 1e-12, "fast_atv[1]");
2813 }
2814
2815 #[test]
2817 fn fast_atv_large_matches_ndarray_dot() {
2818 let n = 50usize;
2819 let p = 40usize;
2820 let mut a = Array2::<f64>::zeros((n, p));
2821 let mut v = Array1::<f64>::zeros(n);
2822 let mut state = 0x1234_ABCD_5678_EF90u64;
2823 let next = |s: &mut u64| -> f64 {
2824 *s ^= *s << 13;
2825 *s ^= *s >> 7;
2826 *s ^= *s << 17;
2827 ((*s >> 11) as f64 / ((1u64 << 53) as f64)) - 0.5
2828 };
2829 for x in a.iter_mut() { *x = next(&mut state); }
2830 for x in v.iter_mut() { *x = next(&mut state); }
2831 let got = fast_atv(&a, &v);
2832 let want = a.t().dot(&v);
2833 assert!(max_abs_diff_1d(&got, &want) < 1e-9, "fast_atv large mismatch");
2834 }
2835
2836 #[test]
2839 fn fast_xt_diag_y_small_matches_manual() {
2840 let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
2841 let d = array![2.0, 0.5, 1.0];
2842 let y = array![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];
2843 let got = fast_xt_diag_y(&x, &d, &y);
2844 let diag_y = {
2846 let mut dy = Array2::<f64>::zeros(y.dim());
2847 for i in 0..3 {
2848 for j in 0..3 {
2849 dy[[i, j]] = d[i] * y[[i, j]];
2850 }
2851 }
2852 dy
2853 };
2854 let want = x.t().dot(&diag_y);
2855 assert!(max_abs_diff(&got, &want) < 1e-12, "fast_xt_diag_y small mismatch");
2856 assert_eq!(got.dim(), (2, 3));
2857 }
2858
2859 #[inline]
2866 fn two_prod(a: f64, b: f64) -> (f64, f64) {
2867 let p = a * b;
2868 let e = a.mul_add(b, -p);
2869 (p, e)
2870 }
2871
2872 #[inline]
2873 fn two_sum(a: f64, b: f64) -> (f64, f64) {
2874 let s = a + b;
2875 let bb = s - a;
2876 let e = (a - (s - bb)) + (b - bb);
2877 (s, e)
2878 }
2879
2880 fn grow_expansion(e: &mut Vec<f64>, mut q: f64) {
2882 for h in e.iter_mut() {
2883 let (s, err) = two_sum(*h, q);
2884 *h = err;
2885 q = s;
2886 }
2887 if q != 0.0 {
2888 e.push(q);
2889 }
2890 }
2891
2892 fn exact_dot(a: &[f64], b: &[f64]) -> f64 {
2897 let mut e: Vec<f64> = Vec::new();
2898 for (&x, &y) in a.iter().zip(b.iter()) {
2899 let (p, ep) = two_prod(x, y);
2900 grow_expansion(&mut e, p);
2901 grow_expansion(&mut e, ep);
2902 }
2903 e.iter().fold(0.0f64, |acc, &c| acc + c)
2906 }
2907
2908 fn dd_dot(a: &[f64], b: &[f64]) -> f64 {
2912 let (mut s, mut c) = (0.0f64, 0.0f64);
2913 for (&x, &y) in a.iter().zip(b.iter()) {
2914 let (p, ep) = two_prod(x, y);
2915 let (s2, es) = two_sum(s, p);
2916 s = s2;
2917 c += ep + es;
2918 }
2919 s + c
2920 }
2921
2922 fn naive_dot(a: &[f64], b: &[f64]) -> f64 {
2923 let mut acc = 0.0f64;
2924 for (&x, &y) in a.iter().zip(b.iter()) {
2925 acc += x * y;
2926 }
2927 acc
2928 }
2929
2930 fn ill_conditioned_pair(len: usize, seed: u64) -> (Vec<f64>, Vec<f64>) {
2933 let mut s = seed | 1;
2934 let mut next = || {
2935 s ^= s << 13;
2936 s ^= s >> 7;
2937 s ^= s << 17;
2938 (s >> 11) as f64 / ((1u64 << 53) as f64) - 0.5
2939 };
2940 let mut a = Vec::with_capacity(len);
2941 let mut b = Vec::with_capacity(len);
2942 for i in 0..len {
2943 let scale = 10f64.powi((i % 17) as i32 - 8);
2945 let sign = if i % 2 == 0 { 1.0 } else { -1.0 };
2946 a.push(sign * next() * scale);
2947 b.push(next() * scale);
2948 }
2949 (a, b)
2950 }
2951
2952 #[test]
2955 fn fma_dot_beats_naive_accuracy() {
2956 let mut fma_total = 0.0f64;
2957 let mut naive_total = 0.0f64;
2958 let mut strict_wins = 0;
2959 for seed in 0..64u64 {
2960 let len = 200 + (seed as usize % 57);
2961 let (a, b) = ill_conditioned_pair(len, 0x9E37_79B9 ^ seed.wrapping_mul(2654435761));
2962 let truth = exact_dot(&a, &b);
2963 let fe = (super::fma_dot(&a, &b) - truth).abs();
2964 let ne = (naive_dot(&a, &b) - truth).abs();
2965 let floor = 8.0 * f64::EPSILON * truth.abs();
2969 assert!(
2970 fe <= ne * (1.0 + 1e-6) + floor,
2971 "fma_dot worse than naive: seed={seed} fma_err={fe:.3e} naive_err={ne:.3e}",
2972 );
2973 if fe < ne {
2974 strict_wins += 1;
2975 }
2976 fma_total += fe;
2977 naive_total += ne;
2978 }
2979 assert!(
2980 fma_total < naive_total,
2981 "fma_dot aggregate error {fma_total:.3e} not below naive {naive_total:.3e}",
2982 );
2983 assert!(
2984 strict_wins >= 40,
2985 "expected fma_dot to strictly win the majority; only {strict_wins}/64",
2986 );
2987 }
2988
2989 #[test]
2992 fn fast_atv_blocked_beats_naive_accuracy() {
2993 let n = 200_003usize;
2994 let p = 3usize;
2995 let mut s = 0xD1B5_4A32u64;
2996 let mut next = || {
2997 s ^= s << 13;
2998 s ^= s >> 7;
2999 s ^= s << 17;
3000 (s >> 11) as f64 / ((1u64 << 53) as f64) - 0.5
3001 };
3002 let mut x = Array2::<f64>::zeros((n, p));
3003 let mut v = Array1::<f64>::zeros(n);
3004 for i in 0..n {
3005 let scale = 10f64.powi((i % 17) as i32 - 8);
3006 v[i] = if i % 2 == 0 { scale } else { -scale } * next();
3007 for j in 0..p {
3008 x[[i, j]] = next() * scale;
3009 }
3010 }
3011 let got = fast_atv(&x, &v);
3012 for j in 0..p {
3014 let col: Vec<f64> = (0..n).map(|i| x[[i, j]]).collect();
3015 let vv: Vec<f64> = v.to_vec();
3016 let truth = dd_dot(&col, &vv);
3017 let naive = naive_dot(&col, &vv);
3018 let ge = (got[j] - truth).abs();
3019 let ne = (naive - truth).abs();
3020 assert!(
3021 ge <= ne + f64::MIN_POSITIVE,
3022 "col {j}: blocked err {ge:.3e} exceeds naive {ne:.3e}",
3023 );
3024 }
3025 }
3026
3027 #[test]
3030 fn fast_av_strided_input_matches_ndarray() {
3031 let mut base = Array2::<f64>::zeros((40, 60));
3032 let mut s = 0x0BAD_F00Du64;
3033 let mut next = || {
3034 s ^= s << 13;
3035 s ^= s >> 7;
3036 s ^= s << 17;
3037 (s >> 11) as f64 / ((1u64 << 53) as f64) - 0.5
3038 };
3039 for x in base.iter_mut() {
3040 *x = next();
3041 }
3042 let a = base.t();
3044 let mut v = Array1::<f64>::zeros(40);
3045 for x in v.iter_mut() {
3046 *x = next();
3047 }
3048 let got = fast_av(&a, &v);
3049 let want = a.dot(&v);
3050 assert!(
3051 max_abs_diff_1d(&got, &want) < 1e-11,
3052 "strided fast_av mismatch (fallback path)",
3053 );
3054 }
3055}