1use dyn_stack::{MemBuffer, MemStack};
2use faer::diag::{Diag, DiagRef};
3use faer::linalg::solvers::{self, Solve};
4pub use faer::linalg::solvers::{
5 Lblt as FaerLblt, Ldlt as FaerLdlt, Llt as FaerLlt, Solve as FaerSolve,
6};
7use faer::linalg::svd::{self, ComputeSvdVectors};
8use faer::prelude::ReborrowMut;
9use faer::{Conj, Mat, MatMut, MatRef, Par, Side, Unbind, get_global_parallelism};
10use ndarray::{Array1, Array2, ArrayBase, ArrayViewMut1, Data, Ix1, Ix2};
11use std::marker::PhantomData;
12use std::panic::{AssertUnwindSafe, catch_unwind};
13use thiserror::Error;
14
/// Scaling factor with the fixed value `100.0`.
// NOTE(review): judging by the name this presumably scales a rank-revealing QR
// rank tolerance; the code that consumes it is not visible in this section —
// confirm before changing the value.
const RRQR_RANK_ALPHA: f64 = 100.0;
16
thread_local! {
    /// Per-thread count of currently live [`NestedParallelGuard`] values.
    static NESTED_PARALLEL_DEPTH: std::cell::Cell<usize> = const { std::cell::Cell::new(0) };
}

/// Scope token for a nested parallel region on the current thread.
///
/// Creating one via [`NestedParallelGuard::enter`] bumps the thread-local
/// depth counter; dropping it lowers the counter again. Both adjustments
/// saturate, so the counter can neither overflow nor go below zero.
struct NestedParallelGuard;

impl NestedParallelGuard {
    /// Increments the current thread's nesting depth and returns the token
    /// whose drop undoes the increment.
    #[inline]
    fn enter() -> Self {
        NESTED_PARALLEL_DEPTH.with(|counter| {
            let current = counter.get();
            counter.set(current.saturating_add(1));
        });
        NestedParallelGuard
    }
}

impl Drop for NestedParallelGuard {
    /// Decrements the current thread's nesting depth (saturating at zero).
    #[inline]
    fn drop(&mut self) {
        NESTED_PARALLEL_DEPTH.with(|counter| {
            let current = counter.get();
            counter.set(current.saturating_sub(1));
        });
    }
}
37
38#[inline]
48pub fn with_nested_parallel<T>(body: impl FnOnce() -> T) -> T {
49 let guard = NestedParallelGuard::enter();
50 let out = body();
51 drop(guard);
52 out
53}
54
55#[inline]
58pub fn in_nested_parallel_region() -> bool {
59 NESTED_PARALLEL_DEPTH.with(|depth| depth.get() > 0)
60}
61
62#[inline]
70pub fn effective_global_parallelism() -> Par {
71 if in_nested_parallel_region() {
72 Par::Seq
73 } else {
74 get_global_parallelism()
75 }
76}
77
/// Errors reported by the faer-backed linear-algebra helpers.
///
/// The `context` fields are static labels embedded in the rendered message.
#[derive(Debug, Error)]
pub enum FaerLinalgError {
    /// A factorization failed; `context` names the call site.
    #[error("Factorization failed in {context}")]
    FactorizationFailed { context: &'static str },
    /// An SVD did not converge; `context` names the call site.
    #[error("SVD failed to converge in {context}")]
    SvdNoConvergence { context: &'static str },
    /// The input to a self-adjoint eigendecomposition held non-finite values.
    #[error("Self-adjoint eigendecomposition input contains non-finite values in {context}")]
    SelfAdjointEigenNonFiniteInput { context: &'static str },
    /// faer's self-adjoint eigendecomposition returned an error.
    #[error("Self-adjoint eigendecomposition failed: {0:?}")]
    SelfAdjointEigen(solvers::EvdError),
    /// faer's Cholesky (LLT) factorization returned an error.
    #[error("Cholesky factorization failed: {0:?}")]
    Cholesky(solvers::LltError),
    /// faer's LDLT factorization returned an error.
    #[error("LDLT factorization failed: {0:?}")]
    Ldlt(solvers::LdltError),
}
93
/// An owned factorization of a symmetric `f64` matrix, tagged by which faer
/// solver produced it.
pub enum FaerSymmetricFactor {
    /// Cholesky (LLT) factorization.
    Llt(FaerLlt<f64>),
    /// LDLT factorization.
    Ldlt(FaerLdlt<f64>),
    /// LBLT factorization (faer's `Lblt` solver).
    Lblt(FaerLblt<f64>),
}
99
100#[inline]
101pub fn cholesky_factor_logdet(factor: MatRef<'_, f64>) -> f64 {
102 2.0 * diagonal_log_sum(factor.diagonal())
103}
104
105#[inline]
106fn diagonal_log_sum(diagonal: DiagRef<'_, f64>) -> f64 {
107 diagonal
108 .column_vector()
109 .iter()
110 .map(|&x| x.ln())
111 .sum::<f64>()
112}
113
114impl FaerSymmetricFactor {
115 #[inline]
117 pub fn n(&self) -> usize {
118 use faer::linalg::solvers::ShapeCore;
119 match self {
120 FaerSymmetricFactor::Llt(f) => f.nrows(),
121 FaerSymmetricFactor::Ldlt(f) => f.nrows(),
122 FaerSymmetricFactor::Lblt(f) => f.nrows(),
123 }
124 }
125
126 #[inline]
127 pub fn solve(&self, rhs: MatRef<'_, f64>) -> Mat<f64> {
128 match self {
129 FaerSymmetricFactor::Llt(f) => f.solve(rhs),
130 FaerSymmetricFactor::Ldlt(f) => f.solve(rhs),
131 FaerSymmetricFactor::Lblt(f) => f.solve(rhs),
132 }
133 }
134
135 #[inline]
136 pub fn solve_in_place(&self, rhs: MatMut<'_, f64>) {
137 match self {
138 FaerSymmetricFactor::Llt(f) => f.solve_in_place(rhs),
139 FaerSymmetricFactor::Ldlt(f) => f.solve_in_place(rhs),
140 FaerSymmetricFactor::Lblt(f) => f.solve_in_place(rhs),
141 }
142 }
143}
144
145impl crate::matrix::FactorizedSystem for FaerSymmetricFactor {
146 fn solve(&self, rhs: &Array1<f64>) -> Result<Array1<f64>, String> {
147 let mut out = rhs.clone();
148 let mut out_mat = array1_to_col_matmut(&mut out);
149 self.solve_in_place(out_mat.as_mut());
150 if !out.iter().all(|v| v.is_finite()) {
151 return Err("symmetric factor solve produced non-finite values".to_string());
152 }
153 Ok(out)
154 }
155
156 fn solvemulti(&self, rhs: &Array2<f64>) -> Result<Array2<f64>, String> {
157 let mut out = Array2::<f64>::zeros(rhs.raw_dim());
158 for j in 0..rhs.ncols() {
159 for i in 0..rhs.nrows() {
160 out[[i, j]] = rhs[[i, j]];
161 }
162 }
163 let mut out_mat = array2_to_matmut(&mut out);
164 self.solve_in_place(out_mat.as_mut());
165 if !out.iter().all(|v| v.is_finite()) {
166 return Err("symmetric factor multi-solve produced non-finite values".to_string());
167 }
168 Ok(out)
169 }
170
171 fn logdet(&self) -> f64 {
172 match self {
173 FaerSymmetricFactor::Llt(f) => cholesky_factor_logdet(f.L()),
174 FaerSymmetricFactor::Ldlt(f) => diagonal_log_sum(f.D()),
175 FaerSymmetricFactor::Lblt(..) => {
176 f64::NAN
180 }
181 }
182 }
183}
184
185#[inline]
187pub fn factorize_symmetricwith_fallback(
188 matrix: MatRef<'_, f64>,
189 side: Side,
190) -> Result<FaerSymmetricFactor, FaerLinalgError> {
191 if let Ok(llt) = FaerLlt::new(matrix, side) {
192 return Ok(FaerSymmetricFactor::Llt(llt));
193 }
194 let ldlt_err = match FaerLdlt::new(matrix, side) {
195 Ok(ldlt) => return Ok(FaerSymmetricFactor::Ldlt(ldlt)),
196 Err(err) => err,
197 };
198 let lblt = catch_unwind(AssertUnwindSafe(|| FaerLblt::new(matrix, side)))
199 .map_err(|_| FaerLinalgError::Ldlt(ldlt_err))?;
200 Ok(FaerSymmetricFactor::Lblt(lblt))
201}
202
/// Decides whether an `m × k` by `k × n` product is large enough to route
/// through faer instead of ndarray's `dot`.
///
/// Both conditions must hold: at least one dimension is 32 or more, and the
/// (saturating) product `m * n * k` is at least `64 * 64`.
#[inline]
const fn should_use_faer_matmul(m: usize, n: usize, k: usize) -> bool {
    const SMALLEST_USEFUL_DIM: usize = 32;
    const SMALLEST_USEFUL_VOLUME: usize = 64 * 64;

    if m < SMALLEST_USEFUL_DIM && n < SMALLEST_USEFUL_DIM && k < SMALLEST_USEFUL_DIM {
        return false;
    }
    let volume = m.saturating_mul(n).saturating_mul(k);
    volume >= SMALLEST_USEFUL_VOLUME
}
213
214#[inline]
215pub fn matmul_parallelism(m: usize, n: usize, k: usize) -> Par {
216 const PAR_MIN_FLOP_SCALE: usize = 2_000_000;
220 const PAR_MIN_LONG_DIM: usize = 256;
221 let flop_scale = m.saturating_mul(n).saturating_mul(k);
222 let long_dim = m.max(n).max(k);
223 if flop_scale >= PAR_MIN_FLOP_SCALE && long_dim >= PAR_MIN_LONG_DIM {
224 effective_global_parallelism()
228 } else {
229 Par::Seq
230 }
231}
232
233#[inline]
234pub fn array2_to_matmut(array: &mut Array2<f64>) -> MatMut<'_, f64> {
235 let (rows, cols) = array.dim();
236 let strides = array.strides();
237
238 let s0 = strides[0];
245 let s1 = strides[1];
246
247 unsafe { MatMut::from_raw_parts_mut(array.as_mut_ptr(), rows, cols, s0, s1) }
251}
252
253#[inline]
254pub fn array1_to_col_matmut(array: &mut Array1<f64>) -> MatMut<'_, f64> {
255 let len = array.len();
256 let stride = array.strides()[0];
257 unsafe {
261 MatMut::from_raw_parts_mut(
262 array.as_mut_ptr(),
263 len,
264 1,
265 stride,
266 0, )
268 }
269}
270
271#[inline]
278pub fn fast_ata<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>) -> Array2<f64> {
279 let p = a.ncols();
280 let mut out = Array2::<f64>::zeros((p, p));
281 fast_ata_into(a, &mut out);
282 out
283}
284
285#[inline]
288pub fn fast_ata_into<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>, out: &mut Array2<f64>) {
289 use faer::Accum;
290 use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
291
292 let (n, p) = a.dim();
293 assert_eq!(out.nrows(), p, "output rows must match p");
294 assert_eq!(out.ncols(), p, "output cols must match p");
295
296 if !should_use_faer_matmul(p, p, n) {
297 out.assign(&a.t().dot(a));
298 return;
299 }
300
301 let mut outview = array2_to_matmut(out);
302
303 let aview = FaerArrayView::new(a);
304 let a_ref = aview.as_ref();
305 let a_t = a_ref.transpose();
306 let par = matmul_parallelism(p, p, n);
307 tri_matmul(
308 outview.as_mut(),
309 BlockStructure::TriangularLower,
310 Accum::Replace,
311 a_t,
312 BlockStructure::Rectangular,
313 a_ref,
314 BlockStructure::Rectangular,
315 1.0,
316 par,
317 );
318 for i in 0..p {
320 for j in (i + 1)..p {
321 out[[i, j]] = out[[j, i]];
322 }
323 }
324}
325
326#[inline]
330pub fn fast_atb<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
331 a: &ArrayBase<S1, Ix2>,
332 b: &ArrayBase<S2, Ix2>,
333) -> Array2<f64> {
334 if let Some(out) =
335 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atb(a.view(), b.view()))
336 {
337 return out;
338 }
339 let (n_a, p) = a.dim();
340 let q = b.ncols();
341 fast_atb_with_parallelism(a, b, matmul_parallelism(p, q, n_a))
342}
343
344#[inline]
347pub fn fast_atb_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
348 a: &ArrayBase<S1, Ix2>,
349 b: &ArrayBase<S2, Ix2>,
350 par: Par,
351) -> Array2<f64> {
352 use faer::linalg::matmul::matmul;
353 use faer::{Accum, Mat};
354
355 let (n_a, p) = a.dim();
356 let (n_b, q) = b.dim();
357 assert_eq!(n_a, n_b, "A and B must have same number of rows");
358
359 if !should_use_faer_matmul(p, q, n_a) {
361 return a.t().dot(b);
362 }
363
364 let mut result = Mat::<f64>::zeros(p, q);
365
366 let aview = FaerArrayView::new(a);
367 let bview = FaerArrayView::new(b);
368 let a_ref = aview.as_ref();
369 let b_ref = bview.as_ref();
370
371 matmul(
373 result.as_mut(),
374 Accum::Replace,
375 a_ref.transpose(),
376 b_ref,
377 1.0,
378 par,
379 );
380
381 mat_to_array(result.as_ref())
382}
383
384#[inline]
387pub fn fast_abt<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
388 a: &ArrayBase<S1, Ix2>,
389 b: &ArrayBase<S2, Ix2>,
390) -> Array2<f64> {
391 use faer::linalg::matmul::matmul;
392 use faer::{Accum, Mat};
393
394 let (m, k_a) = a.dim();
395 let (n, k_b) = b.dim();
396 assert_eq!(
397 k_a, k_b,
398 "A and B must have same number of columns for A·Bᵀ"
399 );
400
401 if !should_use_faer_matmul(m, n, k_a) {
402 return a.dot(&b.t());
403 }
404
405 let mut result = Mat::<f64>::zeros(m, n);
406 let aview = FaerArrayView::new(a);
407 let bview = FaerArrayView::new(b);
408 let par = matmul_parallelism(m, n, k_a);
409 matmul(
410 result.as_mut(),
411 Accum::Replace,
412 aview.as_ref(),
413 bview.as_ref().transpose(),
414 1.0,
415 par,
416 );
417 mat_to_array(result.as_ref())
418}
419
420#[inline]
424pub fn fast_ab<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
425 a: &ArrayBase<S1, Ix2>,
426 b: &ArrayBase<S2, Ix2>,
427) -> Array2<f64> {
428 if let Some(out) =
429 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_ab(a.view(), b.view()))
430 {
431 return out;
432 }
433 let n = a.nrows();
434 let q = b.ncols();
435 let mut out = Array2::<f64>::zeros((n, q));
436 fast_ab_into(a, b, &mut out);
437 out
438}
439
440#[inline]
443pub fn fast_av<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
444 a: &ArrayBase<S1, Ix2>,
445 v: &ArrayBase<S2, Ix1>,
446) -> Array1<f64> {
447 if let Some(out) =
448 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_av(a.view(), v.view()))
449 {
450 return out;
451 }
452 fast_av_impl(a, v)
453}
454
455#[inline]
456fn fast_av_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
457 a: &ArrayBase<S1, Ix2>,
458 v: &ArrayBase<S2, Ix1>,
459) -> Array1<f64> {
460 use faer::linalg::matmul::matmul;
461 use faer::{Accum, Mat};
462
463 let (n, p) = a.dim();
464 assert_eq!(p, v.len(), "A cols must match v length");
465
466 if !should_use_faer_matmul(n, 1, p) {
467 return a.dot(v);
468 }
469
470 let mut result = Mat::<f64>::zeros(n, 1);
471
472 let aview = FaerArrayView::new(a);
473 let vview = FaerColView::new(v);
474 let a_ref = aview.as_ref();
475 let v_ref = vview.as_ref();
476
477 let par = matmul_parallelism(n, 1, p);
478 matmul(result.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
479
480 let mut out = Array1::<f64>::zeros(n);
481 for i in 0..n {
482 out[i] = result[(i, 0)];
483 }
484 out
485}
486
/// Computes `A v` and stores it in `out`, overwriting its previous contents.
///
/// Thin public wrapper that forwards to `fast_av_into_impl`.
///
/// # Panics
/// Panics if `v.len()` differs from the column count of `a`, or `out.len()`
/// differs from the row count of `a` (asserted by the implementation).
#[inline]
pub fn fast_av_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: &mut Array1<f64>,
) {
    fast_av_into_impl(a, v, out);
}
497
498#[inline]
499fn fast_av_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
500 a: &ArrayBase<S1, Ix2>,
501 v: &ArrayBase<S2, Ix1>,
502 out: &mut Array1<f64>,
503) {
504 use faer::Accum;
505 use faer::linalg::matmul::matmul;
506
507 let (n, p) = a.dim();
508 assert_eq!(v.len(), p, "vector length must match A cols");
509 assert_eq!(out.len(), n, "output length must match A rows");
510
511 if !should_use_faer_matmul(n, 1, p) {
512 out.assign(&a.dot(v));
513 return;
514 }
515
516 let mut outview = array1_to_col_matmut(out);
517
518 let aview = FaerArrayView::new(a);
519 let vview = FaerColView::new(v);
520 let a_ref = aview.as_ref();
521 let v_ref = vview.as_ref();
522 let par = matmul_parallelism(n, 1, p);
523 matmul(outview.as_mut(), Accum::Replace, a_ref, v_ref, 1.0, par);
524}
525
/// Computes `A v` and stores it in the mutable view `out`, overwriting its
/// previous contents.
///
/// Thin public wrapper that forwards to `fast_av_view_into_impl`.
///
/// # Panics
/// Panics if `v.len()` differs from the column count of `a`, or `out.len()`
/// differs from the row count of `a` (asserted by the implementation).
#[inline]
pub fn fast_av_view_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: ArrayViewMut1<'_, f64>,
) {
    fast_av_view_into_impl(a, v, out);
}
540
541#[inline]
542fn fast_av_view_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
543 a: &ArrayBase<S1, Ix2>,
544 v: &ArrayBase<S2, Ix1>,
545 mut out: ArrayViewMut1<'_, f64>,
546) {
547 use faer::Accum;
548 use faer::linalg::matmul::matmul;
549
550 let (n, p) = a.dim();
551 assert_eq!(v.len(), p, "vector length must match A cols");
552 assert_eq!(out.len(), n, "output length must match A rows");
553
554 if !should_use_faer_matmul(n, 1, p) {
555 let prod = a.dot(v);
556 out.assign(&prod);
557 return;
558 }
559
560 let len = out.len();
561 let stride = out.strides()[0];
562 let outview = unsafe {
566 MatMut::from_raw_parts_mut(
567 out.as_mut_ptr(),
568 len,
569 1,
570 stride,
571 0, )
573 };
574
575 let aview = FaerArrayView::new(a);
576 let vview = FaerColView::new(v);
577 let a_ref = aview.as_ref();
578 let v_ref = vview.as_ref();
579 let par = matmul_parallelism(n, 1, p);
580 matmul(outview, Accum::Replace, a_ref, v_ref, 1.0, par);
581}
582
583#[inline]
586pub fn fast_atv<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
587 a: &ArrayBase<S1, Ix2>,
588 v: &ArrayBase<S2, Ix1>,
589) -> Array1<f64> {
590 if let Some(out) =
591 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_atv(a.view(), v.view()))
592 {
593 return out;
594 }
595 fast_atv_impl(a, v)
596}
597
598#[inline]
599fn fast_atv_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
600 a: &ArrayBase<S1, Ix2>,
601 v: &ArrayBase<S2, Ix1>,
602) -> Array1<f64> {
603 use faer::Accum;
604 use faer::linalg::matmul::matmul;
605
606 let (n, p) = a.dim();
607 assert_eq!(n, v.len(), "A rows must match v length");
608
609 if !should_use_faer_matmul(p, 1, n) {
611 return a.t().dot(v);
612 }
613
614 let mut out = Array1::<f64>::zeros(p);
615 let mut outview = array1_to_col_matmut(&mut out);
616
617 let aview = FaerArrayView::new(a);
618 let vview = FaerColView::new(v);
619 let a_ref = aview.as_ref();
620 let v_ref = vview.as_ref();
621
622 let par = matmul_parallelism(p, 1, n);
624 matmul(
625 outview.as_mut(),
626 Accum::Replace,
627 a_ref.transpose(),
628 v_ref,
629 1.0,
630 par,
631 );
632
633 out
634}
635
/// Computes `Aᵀ v` and stores it in `out`, overwriting its previous contents.
///
/// Thin public wrapper that forwards to `fast_atv_into_impl`.
///
/// # Panics
/// Panics if `v.len()` differs from the row count of `a`, or `out.len()`
/// differs from the column count of `a` (asserted by the implementation).
#[inline]
pub fn fast_atv_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    v: &ArrayBase<S2, Ix1>,
    out: &mut Array1<f64>,
) {
    fast_atv_into_impl(a, v, out);
}
646
647#[inline]
648fn fast_atv_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
649 a: &ArrayBase<S1, Ix2>,
650 v: &ArrayBase<S2, Ix1>,
651 out: &mut Array1<f64>,
652) {
653 use faer::Accum;
654 use faer::linalg::matmul::matmul;
655
656 let (n, p) = a.dim();
657 assert_eq!(v.len(), n, "vector length must match A rows");
658 assert_eq!(out.len(), p, "output length must match A cols");
659
660 if !should_use_faer_matmul(p, 1, n) {
661 out.assign(&a.t().dot(v));
662 return;
663 }
664
665 let mut outview = array1_to_col_matmut(out);
666
667 let aview = FaerArrayView::new(a);
668 let vview = FaerColView::new(v);
669 let a_ref = aview.as_ref();
670 let v_ref = vview.as_ref();
671 let par = matmul_parallelism(p, 1, n);
672 matmul(
673 outview.as_mut(),
674 Accum::Replace,
675 a_ref.transpose(),
676 v_ref,
677 1.0,
678 par,
679 );
680}
681
682#[inline]
684pub fn fast_xt_diag_x<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
685 x: &ArrayBase<S1, Ix2>,
686 w: &ArrayBase<S2, Ix1>,
687) -> Array2<f64> {
688 assert_eq!(
689 x.nrows(),
690 w.len(),
691 "fast_xt_diag_x row/weight length mismatch"
692 );
693 if let Some(out) =
694 crate::gpu_hook::gpu_dispatch().and_then(|d| d.try_fast_xt_diag_x(x.view(), w.view()))
695 {
696 return out;
697 }
698 let p = x.ncols();
699 fast_xt_diag_x_with_parallelism(x, w, matmul_parallelism(p, p, x.nrows()))
700}
701
/// Computes `Xᵀ diag(w) X` on the CPU with a caller-chosen faer parallelism.
///
/// Validates the row/weight lengths and forwards to
/// `fast_xt_diag_x_with_parallelism_impl`.
///
/// # Panics
/// Panics if the row count of `x` differs from the length of `w`.
#[inline]
pub fn fast_xt_diag_x_with_parallelism<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    par: Par,
) -> Array2<f64> {
    assert_eq!(
        x.nrows(),
        w.len(),
        "fast_xt_diag_x_with_parallelism row/weight length mismatch"
    );
    fast_xt_diag_x_with_parallelism_impl(x, w, par)
}
717
718#[inline]
719fn fast_xt_diag_x_with_parallelism_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
720 x: &ArrayBase<S1, Ix2>,
721 w: &ArrayBase<S2, Ix1>,
722 par: Par,
723) -> Array2<f64> {
724 use ndarray::ShapeBuilder;
725
726 let p = x.ncols();
727 let mut result = Array2::<f64>::zeros((p, p).f());
730 stream_weighted_crossprod_into(
731 x,
732 w,
733 &mut result,
734 CrossprodStructure::SymmetricLower,
735 CrossprodAccum::Replace,
736 par,
737 );
738 result
739}
740
/// How [`stream_weighted_crossprod_into`] computes the `p × p` product on its
/// faer path.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CrossprodStructure {
    /// Compute every entry with a general matmul.
    Full,
    /// Compute only the lower triangle with a triangular matmul, then copy it
    /// into the upper triangle.
    SymmetricLower,
}
751
/// Whether [`stream_weighted_crossprod_into`] overwrites or accumulates into
/// its output.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CrossprodAccum {
    /// Discard the previous contents of the output.
    Replace,
    /// Add the product to the existing contents of the output.
    Add,
}
760
/// Computes the weighted cross-product `Xᵀ diag(w) X` into `out`, processing
/// `x` in row chunks so only one chunk of weighted rows is materialized at a
/// time.
///
/// * `structure` selects a full matmul or a lower-triangle matmul followed by
///   mirroring into the upper triangle.
/// * `accum` selects whether `out` is overwritten or added to.
/// * `par` is the faer parallelism used on the chunked path; the small-problem
///   path uses ndarray's `dot` and ignores it.
///
/// With `p == 0` the function returns immediately; with `n == 0` it only
/// zeroes `out` in `Replace` mode.
///
/// # Panics
/// Panics if `w.len()` differs from the row count of `x`, or `out` is not
/// `p × p` where `p` is the column count of `x`.
pub fn stream_weighted_crossprod_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    out: &mut Array2<f64>,
    structure: CrossprodStructure,
    accum: CrossprodAccum,
    par: Par,
) {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use faer::linalg::matmul::triangular::{BlockStructure, matmul as tri_matmul};
    use ndarray::s;

    let (n, p) = x.dim();
    assert_eq!(n, w.len(), "X rows must match W length");
    assert_eq!(out.nrows(), p, "output rows must match X cols");
    assert_eq!(out.ncols(), p, "output cols must match X cols");
    // Degenerate shapes: nothing to compute.
    if p == 0 {
        return;
    }
    if n == 0 {
        if accum == CrossprodAccum::Replace {
            out.fill(0.0);
        }
        return;
    }

    // Small problems: materialize diag(w)·X once and use ndarray's dot. The
    // full product is computed here regardless of `structure`.
    if !should_use_faer_matmul(p, p, n) {
        let w_x = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        let gram = x.t().dot(&w_x);
        match accum {
            CrossprodAccum::Replace => out.assign(&gram),
            CrossprodAccum::Add => *out += &gram,
        }
        return;
    }

    // Chunk height: aim for roughly TARGET_BYTES of f64 per weighted chunk,
    // clamped to [MIN_ROWS, MAX_ROWS] and never more than `n`.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let chunk_rows = (TARGET_BYTES / (p.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Every chunk below is accumulated with `Accum::Add`, so `Replace` is
    // implemented by zeroing the output up front.
    if accum == CrossprodAccum::Replace {
        out.fill(0.0);
    }

    // Scratch buffer for the weighted rows of one chunk, reused across chunks.
    let mut wx_chunk = Array2::<f64>::zeros((chunk_rows, p));

    let x_is_row_major = x.is_standard_layout();
    let w_slice_opt = w.as_slice();

    {
        let mut out_view = array2_to_matmut(out);
        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than `chunk_rows`.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wx_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                if x_is_row_major && let (Some(x_all), Some(w_all)) = (x.as_slice(), w_slice_opt) {
                    // Fast path: both inputs expose flat slices, so rows are
                    // scaled via direct slice indexing.
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * p;
                        let dst_off = local * p;
                        let src_row = &x_all[src_off..src_off + p];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for col in 0..p {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    // General path: go through ndarray row views.
                    let x_slice = x.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let xrow = x_slice.row(local);
                        let dst_off = local * p;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + p];
                        for (col, xij) in xrow.iter().enumerate() {
                            dst_row[col] = xij * wi;
                        }
                    }
                }
            }
            // out += X_chunkᵀ · (diag(w_chunk) · X_chunk)
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wx_slice = wx_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wx_view = FaerArrayView::new(&wx_slice);
            match structure {
                CrossprodStructure::SymmetricLower => {
                    tri_matmul(
                        out_view.as_mut(),
                        BlockStructure::TriangularLower,
                        Accum::Add,
                        x_view.as_ref().transpose(),
                        BlockStructure::Rectangular,
                        wx_view.as_ref(),
                        BlockStructure::Rectangular,
                        1.0,
                        par,
                    );
                }
                CrossprodStructure::Full => {
                    matmul(
                        out_view.as_mut(),
                        Accum::Add,
                        x_view.as_ref().transpose(),
                        wx_view.as_ref(),
                        1.0,
                        par,
                    );
                }
            }
        }
    }

    // Only the lower triangle was accumulated; mirror it into the upper one.
    if structure == CrossprodStructure::SymmetricLower {
        for i in 0..p {
            for j in (i + 1)..p {
                out[[i, j]] = out[[j, i]];
            }
        }
    }
}
922
923#[inline]
925pub fn fast_xt_diag_y<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
926 x: &ArrayBase<S1, Ix2>,
927 w: &ArrayBase<S2, Ix1>,
928 y: &ArrayBase<S3, Ix2>,
929) -> Array2<f64> {
930 assert_eq!(x.nrows(), y.nrows(), "fast_xt_diag_y X/Y row mismatch");
931 assert_eq!(
932 y.nrows(),
933 w.len(),
934 "fast_xt_diag_y row/weight length mismatch"
935 );
936 if let Some(out) = crate::gpu_hook::gpu_dispatch()
937 .and_then(|d| d.try_fast_xt_diag_y(x.view(), w.view(), y.view()))
938 {
939 return out;
940 }
941 fast_xt_diag_y_impl(x, w, y)
942}
943
/// CPU implementation of `Xᵀ diag(w) Y`, returning a `px × q` array.
///
/// Small problems materialize `diag(w)·Y` once and use ndarray's `dot`.
/// Larger ones process the rows in chunks: each chunk of `Y` is scaled into a
/// reusable scratch buffer and accumulated into a column-major result with
/// faer's matmul. Any zero dimension yields a zero-filled `px × q` array.
///
/// # Panics
/// Panics if the row count of `y` differs from `w.len()` or from the row
/// count of `x`.
#[inline]
fn fast_xt_diag_y_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>, S3: Data<Elem = f64>>(
    x: &ArrayBase<S1, Ix2>,
    w: &ArrayBase<S2, Ix1>,
    y: &ArrayBase<S3, Ix2>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};

    let (n, q) = y.dim();
    let px = x.ncols();
    assert_eq!(n, w.len(), "Y rows must match W length");
    assert_eq!(n, x.nrows(), "X rows must match Y rows");
    if n == 0 || px == 0 || q == 0 {
        return Array2::<f64>::zeros((px, q));
    }
    if !should_use_faer_matmul(px, q, n) {
        let w_y = Array2::from_shape_fn((n, q), |(i, j)| w[i] * y[[i, j]]);
        return x.t().dot(&w_y);
    }

    // Chunk height: aim for roughly TARGET_BYTES of f64 across the X and Y
    // columns of one chunk, clamped to [MIN_ROWS, MAX_ROWS] and at most `n`.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let total_cols = px + q;
    let chunk_rows = (TARGET_BYTES / (total_cols.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Column-major accumulator, zero-initialized because every chunk is added.
    let mut result = Array2::<f64>::zeros((px, q).f());
    // Scratch buffer for the weighted rows of one Y chunk, reused across chunks.
    let mut wy_chunk = Array2::<f64>::zeros((chunk_rows, q));

    let y_is_row_major = y.is_standard_layout();
    let w_slice_opt = w.as_slice();

    {
        let mut out_view = array2_to_matmut(&mut result);

        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than `chunk_rows`.
            let rows = (n - start).min(chunk_rows);
            {
                let chunk_slice = wy_chunk
                    .as_slice_mut()
                    .expect("row-major chunk is contiguous");
                if y_is_row_major && let (Some(y_all), Some(w_all)) = (y.as_slice(), w_slice_opt) {
                    // Fast path: both inputs expose flat slices.
                    for local in 0..rows {
                        let src = start + local;
                        let wi = w_all[src];
                        let src_off = src * q;
                        let dst_off = local * q;
                        let src_row = &y_all[src_off..src_off + q];
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for col in 0..q {
                            dst_row[col] = src_row[col] * wi;
                        }
                    }
                } else {
                    // General path: go through ndarray row views.
                    let y_slice = y.slice(s![start..start + rows, ..]);
                    for local in 0..rows {
                        let wi = w[start + local];
                        let yrow = y_slice.row(local);
                        let dst_off = local * q;
                        let dst_row = &mut chunk_slice[dst_off..dst_off + q];
                        for (col, yij) in yrow.iter().enumerate() {
                            dst_row[col] = yij * wi;
                        }
                    }
                }
            }
            // result += X_chunkᵀ · (diag(w_chunk) · Y_chunk)
            let x_slice = x.slice(s![start..start + rows, ..]);
            let wy_slice = wy_chunk.slice(s![0..rows, ..]);
            let x_view = FaerArrayView::new(&x_slice);
            let wy_view = FaerArrayView::new(&wy_slice);
            // Parallelism is decided per chunk from the chunk's own size.
            let par = matmul_parallelism(px, q, rows);
            matmul(
                out_view.as_mut(),
                Accum::Add,
                x_view.as_ref().transpose(),
                wy_view.as_ref(),
                1.0,
                par,
            );
        }
    }

    result
}
1035
1036pub fn fast_joint_hessian_2x2<
1042 S1: Data<Elem = f64>,
1043 S2: Data<Elem = f64>,
1044 S3: Data<Elem = f64>,
1045 S4: Data<Elem = f64>,
1046 S5: Data<Elem = f64>,
1047>(
1048 x_a: &ArrayBase<S1, Ix2>,
1049 x_b: &ArrayBase<S2, Ix2>,
1050 w_aa: &ArrayBase<S3, Ix1>,
1051 w_ab: &ArrayBase<S4, Ix1>,
1052 w_bb: &ArrayBase<S5, Ix1>,
1053) -> Array2<f64> {
1054 if let Some(out) = crate::gpu_hook::gpu_dispatch().and_then(|d| {
1055 d.try_fast_joint_hessian_2x2(
1056 x_a.view(),
1057 x_b.view(),
1058 w_aa.view(),
1059 w_ab.view(),
1060 w_bb.view(),
1061 )
1062 }) {
1063 return out;
1064 }
1065 fast_joint_hessian_2x2_impl(x_a, x_b, w_aa, w_ab, w_bb)
1066}
1067
/// CPU implementation of the joint 2×2 block cross-product.
///
/// Returns a `(pa + pb) × (pa + pb)` matrix whose upper blocks are
/// `X_aᵀ diag(w_aa) X_a`, `X_aᵀ diag(w_ab) X_b` and `X_bᵀ diag(w_bb) X_b`;
/// the strict lower triangle is then filled by mirroring the upper triangle.
///
/// Small problems use ndarray's `dot` on fully materialized weighted copies.
/// Larger ones process the rows in chunks, scaling each chunk into reusable
/// scratch buffers and accumulating the three blocks with faer's matmul.
///
/// # Panics
/// Panics if `x_b` or any weight vector does not have `x_a.nrows()` entries.
#[inline]
fn fast_joint_hessian_2x2_impl<
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: Data<Elem = f64>,
    S4: Data<Elem = f64>,
    S5: Data<Elem = f64>,
>(
    x_a: &ArrayBase<S1, Ix2>,
    x_b: &ArrayBase<S2, Ix2>,
    w_aa: &ArrayBase<S3, Ix1>,
    w_ab: &ArrayBase<S4, Ix1>,
    w_bb: &ArrayBase<S5, Ix1>,
) -> Array2<f64> {
    use faer::Accum;
    use faer::linalg::matmul::matmul;
    use ndarray::{ShapeBuilder, s};

    let n = x_a.nrows();
    let pa = x_a.ncols();
    let pb = x_b.ncols();
    let total = pa + pb;
    assert_eq!(n, x_b.nrows());
    assert_eq!(n, w_aa.len());
    assert_eq!(n, w_ab.len());
    assert_eq!(n, w_bb.len());

    if n == 0 || total == 0 {
        return Array2::<f64>::zeros((total, total));
    }

    // Small problems: materialize the three weighted matrices and use dot.
    if !should_use_faer_matmul(pa.max(pb), pa.max(pb), n) {
        let waa_xa = Array2::from_shape_fn((n, pa), |(i, j)| w_aa[i] * x_a[[i, j]]);
        let wab_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_ab[i] * x_b[[i, j]]);
        let wbb_xb = Array2::from_shape_fn((n, pb), |(i, j)| w_bb[i] * x_b[[i, j]]);
        let mut out = Array2::<f64>::zeros((total, total));
        out.slice_mut(s![..pa, ..pa]).assign(&x_a.t().dot(&waa_xa));
        out.slice_mut(s![..pa, pa..]).assign(&x_a.t().dot(&wab_xb));
        out.slice_mut(s![pa.., pa..]).assign(&x_b.t().dot(&wbb_xb));
        // Mirror the upper triangle into the strict lower triangle.
        for i in 0..total {
            for j in 0..i {
                out[[i, j]] = out[[j, i]];
            }
        }
        return out;
    }

    // Chunk height: aim for roughly TARGET_BYTES of f64 across the three
    // scratch buffers (pa + 2·pb columns), clamped and at most `n`.
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_ROWS: usize = 512;
    const MAX_ROWS: usize = 131_072;
    let cols_needed = pa + 2 * pb;
    let chunk_rows = (TARGET_BYTES / (cols_needed.max(1) * 8))
        .clamp(MIN_ROWS, MAX_ROWS)
        .min(n);

    // Column-major accumulator, zero-initialized because every chunk is added.
    let mut out = Array2::<f64>::zeros((total, total).f());
    // Scratch buffers for the weighted rows of one chunk, reused across chunks.
    let mut waa_xa_buf = Array2::<f64>::zeros((chunk_rows, pa));
    let mut wab_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));
    let mut wbb_xb_buf = Array2::<f64>::zeros((chunk_rows, pb));

    let xa_is_row_major = x_a.is_standard_layout();
    let xb_is_row_major = x_b.is_standard_layout();
    let waa_slice_opt = w_aa.as_slice();
    let wab_slice_opt = w_ab.as_slice();
    let wbb_slice_opt = w_bb.as_slice();

    {
        let mut out_mat = array2_to_matmut(&mut out);

        for start in (0..n).step_by(chunk_rows) {
            // The last chunk may be shorter than `chunk_rows`.
            let rows = (n - start).min(chunk_rows);
            let xa_slice = x_a.slice(s![start..start + rows, ..]);
            let xb_slice = x_b.slice(s![start..start + rows, ..]);

            {
                let waa_chunk = waa_xa_buf
                    .as_slice_mut()
                    .expect("row-major waa chunk is contiguous");
                let wab_chunk = wab_xb_buf
                    .as_slice_mut()
                    .expect("row-major wab chunk is contiguous");
                let wbb_chunk = wbb_xb_buf
                    .as_slice_mut()
                    .expect("row-major wbb chunk is contiguous");

                if xa_is_row_major
                    && xb_is_row_major
                    && let (Some(xa_all), Some(xb_all)) = (x_a.as_slice(), x_b.as_slice())
                    && let (Some(waa_all), Some(wab_all), Some(wbb_all)) =
                        (waa_slice_opt, wab_slice_opt, wbb_slice_opt)
                {
                    // Fast path: every input exposes a flat slice.
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = waa_all[i];
                        let wab_i = wab_all[i];
                        let wbb_i = wbb_all[i];
                        let xa_off = i * pa;
                        let xa_row = &xa_all[xa_off..xa_off + pa];
                        let xb_off = i * pb;
                        let xb_row = &xb_all[xb_off..xb_off + pb];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        for col in 0..pa {
                            waa_row[col] = xa_row[col] * waa_i;
                        }
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        // One pass over the X_b row feeds both weighted copies.
                        for col in 0..pb {
                            let xij = xb_row[col];
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                } else {
                    // General path: go through ndarray row views.
                    for local in 0..rows {
                        let i = start + local;
                        let waa_i = w_aa[i];
                        let wab_i = w_ab[i];
                        let wbb_i = w_bb[i];
                        let waa_off = local * pa;
                        let wab_off = local * pb;
                        let wbb_off = local * pb;
                        let waa_row = &mut waa_chunk[waa_off..waa_off + pa];
                        let xa_row = xa_slice.row(local);
                        for (col, xij) in xa_row.iter().enumerate() {
                            waa_row[col] = xij * waa_i;
                        }
                        let wab_row = &mut wab_chunk[wab_off..wab_off + pb];
                        let wbb_row = &mut wbb_chunk[wbb_off..wbb_off + pb];
                        let xb_row = xb_slice.row(local);
                        for (col, xij) in xb_row.iter().enumerate() {
                            wab_row[col] = xij * wab_i;
                            wbb_row[col] = xij * wbb_i;
                        }
                    }
                }
            }

            let xa_view = FaerArrayView::new(&xa_slice);
            let xb_view = FaerArrayView::new(&xb_slice);
            let waa_xa_slice = waa_xa_buf.slice(s![0..rows, ..]);
            let wab_xb_slice = wab_xb_buf.slice(s![0..rows, ..]);
            let wbb_xb_slice = wbb_xb_buf.slice(s![0..rows, ..]);
            let waa_xa_view = FaerArrayView::new(&waa_xa_slice);
            let wab_xb_view = FaerArrayView::new(&wab_xb_slice);
            let wbb_xb_view = FaerArrayView::new(&wbb_xb_slice);

            // Top-left block: X_aᵀ · (W_aa X_a)
            matmul(
                out_mat.rb_mut().submatrix_mut(0, 0, pa, pa),
                Accum::Add,
                xa_view.as_ref().transpose(),
                waa_xa_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pa, rows),
            );
            // Top-right block: X_aᵀ · (W_ab X_b)
            matmul(
                out_mat.rb_mut().submatrix_mut(0, pa, pa, pb),
                Accum::Add,
                xa_view.as_ref().transpose(),
                wab_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pa, pb, rows),
            );
            // Bottom-right block: X_bᵀ · (W_bb X_b)
            matmul(
                out_mat.rb_mut().submatrix_mut(pa, pa, pb, pb),
                Accum::Add,
                xb_view.as_ref().transpose(),
                wbb_xb_view.as_ref(),
                1.0,
                matmul_parallelism(pb, pb, rows),
            );
        }
    }
    // Mirror the upper triangle into the strict lower triangle; this also
    // fills the bottom-left block, which was never computed directly.
    for i in 0..total {
        for j in 0..i {
            out[[i, j]] = out[[j, i]];
        }
    }
    out
}
1262
1263fn mat_to_array(mat: MatRef<'_, f64>) -> Array2<f64> {
1264 let nrows = mat.nrows();
1265 let ncols = mat.ncols();
1266 let mut out = Array2::<f64>::zeros((nrows, ncols));
1267 if nrows == 0 || ncols == 0 {
1268 return out;
1269 }
1270 if let Some(out_slice) = out.as_slice_memory_order_mut() {
1273 for i in 0..nrows {
1275 let row_start = i * ncols;
1276 for j in 0..ncols {
1277 out_slice[row_start + j] = mat[(i, j)];
1278 }
1279 }
1280 } else {
1281 for j in 0..ncols {
1282 for i in 0..nrows {
1283 out[[i, j]] = mat[(i, j)];
1284 }
1285 }
1286 }
1287 out
1288}
1289
/// Computes `A B` and stores it in `out`, overwriting its previous contents.
///
/// Thin public wrapper that forwards to `fast_ab_into_impl`.
///
/// # Panics
/// Panics if the inner dimensions of `a` and `b` disagree, or `out` does not
/// have shape `(a.nrows(), b.ncols())` (asserted by the implementation).
#[inline]
pub fn fast_ab_into<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    out: &mut Array2<f64>,
) {
    fast_ab_into_impl(a, b, out);
}
1300
1301#[inline]
1302fn fast_ab_into_impl<S1: Data<Elem = f64>, S2: Data<Elem = f64>>(
1303 a: &ArrayBase<S1, Ix2>,
1304 b: &ArrayBase<S2, Ix2>,
1305 out: &mut Array2<f64>,
1306) {
1307 use faer::Accum;
1308 use faer::linalg::matmul::matmul;
1309
1310 let (n, p) = a.dim();
1311 let (p_b, q) = b.dim();
1312 assert_eq!(p, p_b, "A and B must have compatible inner dimensions");
1313 assert_eq!(out.dim(), (n, q), "output dimensions must match A*B result");
1314
1315 if !should_use_faer_matmul(n, q, p) {
1316 out.assign(&a.dot(b));
1317 return;
1318 }
1319
1320 let aview = FaerArrayView::new(a);
1321 let bview = FaerArrayView::new(b);
1322 let a_ref = aview.as_ref();
1323 let b_ref = bview.as_ref();
1324
1325 let par = matmul_parallelism(n, q, p);
1326 let mut outview = array2_to_matmut(out);
1327 matmul(outview.as_mut(), Accum::Replace, a_ref, b_ref, 1.0, par);
1328}
1329
1330fn diag_to_array(diag: DiagRef<'_, f64>) -> Array1<f64> {
1331 let mat = diag.column_vector().as_mat();
1332 let mut out = Array1::<f64>::zeros(mat.nrows());
1333 for i in 0..mat.nrows() {
1334 out[i] = mat[(i, 0)];
1335 }
1336 out
1337}
1338
/// Read-only adapter that exposes a two-dimensional ndarray as a faer
/// `MatRef<f64>`.
///
/// The view either points straight into the borrowed array (`owned` is
/// `None`) or into a private copy stored in `owned`; the constructor decides
/// which, based on the input's strides.
pub struct FaerArrayView<'a> {
    /// Pointer to the first logical element, as captured at construction.
    ptr: *const f64,
    /// Number of rows of the viewed matrix.
    rows: usize,
    /// Number of columns of the viewed matrix.
    cols: usize,
    /// Element stride between consecutive rows.
    row_stride: isize,
    /// Element stride between consecutive columns.
    col_stride: isize,
    /// Private copy of the data, present only when the input could not be
    /// viewed in place.
    owned: Option<Array2<f64>>,
    /// Ties the raw pointer to the lifetime of the borrowed input array.
    marker: PhantomData<&'a f64>,
}
1348
1349impl<'a> FaerArrayView<'a> {
1350 #[inline]
1351 pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix2>) -> Self {
1352 let (rows, cols) = array.dim();
1353 let strides = array.strides();
1354 if strides[0] <= 0 || strides[1] <= 0 {
1358 let owned = array.to_owned();
1359 let owned_strides = owned.strides();
1360 return Self {
1361 ptr: owned.as_ptr(),
1362 rows,
1363 cols,
1364 row_stride: owned_strides[0],
1365 col_stride: owned_strides[1],
1366 owned: Some(owned),
1367 marker: PhantomData,
1368 };
1369 }
1370
1371 Self {
1372 ptr: array.as_ptr(),
1373 rows,
1374 cols,
1375 row_stride: strides[0],
1376 col_stride: strides[1],
1377 owned: None,
1378 marker: PhantomData,
1379 }
1380 }
1381
1382 #[inline]
1383 pub fn as_ref(&self) -> MatRef<'_, f64> {
1384 let (ptr, rows, cols, row_stride, col_stride) = if let Some(owned) = &self.owned {
1385 let strides = owned.strides();
1386 (
1387 owned.as_ptr(),
1388 owned.nrows(),
1389 owned.ncols(),
1390 strides[0],
1391 strides[1],
1392 )
1393 } else {
1394 (
1395 self.ptr,
1396 self.rows,
1397 self.cols,
1398 self.row_stride,
1399 self.col_stride,
1400 )
1401 };
1402 unsafe { MatRef::from_raw_parts(ptr, rows, cols, row_stride, col_stride) }
1406 }
1407}
1408
/// Read-only adapter that exposes a one-dimensional ndarray as a
/// single-column faer `MatRef<f64>`.
///
/// The view either points into the borrowed array (`owned` is `None`) or
/// into a private copy stored in `owned`.
pub struct FaerColView<'a> {
    /// Pointer to the first logical element, as captured at construction.
    ptr: *const f64,
    /// Number of elements (rows of the resulting column).
    len: usize,
    /// Element stride between consecutive entries.
    stride: isize,
    /// Private copy of the data, present only when the input stride was not
    /// strictly positive.
    owned: Option<Array1<f64>>,
    /// Ties the raw pointer to the lifetime of the borrowed input array.
    marker: PhantomData<&'a f64>,
}
1416
1417impl<'a> FaerColView<'a> {
1418 #[inline]
1419 pub fn new<S: Data<Elem = f64>>(array: &'a ArrayBase<S, Ix1>) -> Self {
1420 let len = array.len();
1421 let stride = array.strides()[0];
1422 if stride <= 0 {
1423 let owned = array.to_owned();
1424 return Self {
1425 ptr: owned.as_ptr(),
1426 len,
1427 stride: 1,
1428 owned: Some(owned),
1429 marker: PhantomData,
1430 };
1431 }
1432 Self {
1433 ptr: array.as_ptr(),
1434 len,
1435 stride,
1436 owned: None,
1437 marker: PhantomData,
1438 }
1439 }
1440
1441 #[inline]
1442 pub fn as_ref(&self) -> MatRef<'_, f64> {
1443 let (ptr, len, stride) = if let Some(owned) = &self.owned {
1444 (owned.as_ptr(), owned.len(), 1)
1445 } else {
1446 (self.ptr, self.len, self.stride)
1447 };
1448 unsafe { MatRef::from_raw_parts(ptr, len, 1, stride, 0) }
1452 }
1453}
1454
/// Singular value decomposition backed by faer.
pub trait FaerSvd {
    /// Computes the singular values of `self` and, on request, the singular
    /// vectors.
    ///
    /// * `compute_u` - when `true`, the first tuple element is `Some`.
    /// * `computevt` - when `true`, the third tuple element is `Some`.
    ///
    /// The middle tuple element always carries the singular values.
    ///
    /// # Errors
    ///
    /// Returns a [`FaerLinalgError`] when the decomposition fails.
    fn svd(
        &self,
        compute_u: bool,
        computevt: bool,
    ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError>;
}
1462
1463impl<S: Data<Elem = f64>> FaerSvd for ArrayBase<S, Ix2> {
1464 fn svd(
1465 &self,
1466 compute_u: bool,
1467 computevt: bool,
1468 ) -> Result<(Option<Array2<f64>>, Array1<f64>, Option<Array2<f64>>), FaerLinalgError> {
1469 let faerview = FaerArrayView::new(self);
1470 let faer_mat = faerview.as_ref();
1471 if !compute_u && !computevt {
1472 let (rows, cols) = faer_mat.shape();
1473 let mut singular = Diag::<f64>::zeros(rows.min(cols));
1474 let par = get_global_parallelism();
1475 let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1476 rows,
1477 cols,
1478 ComputeSvdVectors::No,
1479 ComputeSvdVectors::No,
1480 par,
1481 Default::default(),
1482 ));
1483 let stack = MemStack::new(&mut mem);
1484 svd::svd(
1485 faer_mat,
1486 singular.as_mut(),
1487 None,
1488 None,
1489 par,
1490 stack,
1491 Default::default(),
1492 )
1493 .map_err(|_| FaerLinalgError::SvdNoConvergence {
1494 context: "faer SVD singular values only",
1495 })?;
1496 let singularvalues = diag_to_array(singular.as_ref());
1497 return Ok((None, singularvalues, None));
1498 }
1499
1500 let (rows, cols) = faer_mat.shape();
1501 let rank = rows.min(cols);
1502 let compute_u_flag = if compute_u {
1503 ComputeSvdVectors::Thin
1504 } else {
1505 ComputeSvdVectors::No
1506 };
1507 let computev_flag = if computevt {
1508 ComputeSvdVectors::Thin
1509 } else {
1510 ComputeSvdVectors::No
1511 };
1512
1513 let mut singular = Diag::<f64>::zeros(rows.min(cols));
1514 let mut u_storage = compute_u.then(|| Mat::<f64>::zeros(rows, rank));
1515 let mut v_storage = computevt.then(|| Mat::<f64>::zeros(cols, rank));
1516
1517 let par = get_global_parallelism();
1518 let mut mem = MemBuffer::new(svd::svd_scratch::<f64>(
1519 rows,
1520 cols,
1521 compute_u_flag,
1522 computev_flag,
1523 par,
1524 Default::default(),
1525 ));
1526 let stack = MemStack::new(&mut mem);
1527
1528 svd::svd(
1529 faer_mat.as_ref(),
1530 singular.as_mut(),
1531 u_storage.as_mut().map(|mat| mat.as_mut()),
1532 v_storage.as_mut().map(|mat| mat.as_mut()),
1533 par,
1534 stack,
1535 Default::default(),
1536 )
1537 .map_err(|_| FaerLinalgError::SvdNoConvergence {
1538 context: "faer SVD with vectors",
1539 })?;
1540
1541 let singularvalues = diag_to_array(singular.as_ref());
1542 let u_opt = u_storage.map(|mat| mat_to_array(mat.as_ref()));
1543 let vt_opt = v_storage.map(|mat| {
1544 let mat_ref = mat.as_ref();
1545 let mut out = Array2::<f64>::zeros((mat_ref.ncols(), mat_ref.nrows()));
1546 for j in 0..mat_ref.nrows() {
1547 for i in 0..mat_ref.ncols() {
1548 out[[i, j]] = mat_ref[(j, i)];
1549 }
1550 }
1551 out
1552 });
1553
1554 Ok((u_opt, singularvalues, vt_opt))
1555 }
1556}
1557
/// Self-adjoint (symmetric) eigendecomposition backed by faer.
pub trait FaerEigh {
    /// Computes eigenvalues and eigenvectors of `self`.
    ///
    /// `side` is forwarded to faer's self-adjoint eigensolver. On success the
    /// tuple holds the eigenvalues followed by the eigenvector matrix.
    ///
    /// # Errors
    ///
    /// Returns a [`FaerLinalgError`] when the decomposition cannot be
    /// produced.
    fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError>;
}
1561
1562impl<S: Data<Elem = f64>> FaerEigh for ArrayBase<S, Ix2> {
1563 fn eigh(&self, side: Side) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1564 fn try_eigh(
1565 matrix: &Array2<f64>,
1566 side: Side,
1567 ) -> Result<(Array1<f64>, Array2<f64>), FaerLinalgError> {
1568 let faerview = FaerArrayView::new(matrix);
1569 let eigen = catch_unwind(AssertUnwindSafe(|| {
1570 faerview.as_ref().self_adjoint_eigen(side)
1571 }))
1572 .map_err(|_| FaerLinalgError::FactorizationFailed {
1573 context: "self-adjoint eigendecomposition panic boundary",
1574 })?
1575 .map_err(FaerLinalgError::SelfAdjointEigen)?;
1576 let values = diag_to_array(eigen.S());
1577 let vectors = mat_to_array(eigen.U());
1578 Ok((values, vectors))
1579 }
1580
1581 let owned = self.to_owned();
1582 if owned.nrows() != owned.ncols() {
1583 return Err(FaerLinalgError::FactorizationFailed {
1584 context: "self-adjoint eigendecomposition non-square input",
1585 });
1586 }
1587 if owned.nrows() == 0 {
1588 return Ok((Array1::zeros(0), Array2::zeros((0, 0))));
1589 }
1590 if owned.iter().any(|value| !value.is_finite()) {
1591 return Err(FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1592 context: "self-adjoint eigendecomposition input validation",
1593 });
1594 }
1595 if let Ok((evals, evecs)) = try_eigh(&owned, side)
1596 && evals.iter().all(|value| value.is_finite())
1597 && evecs.iter().all(|value| value.is_finite())
1598 {
1599 return Ok((evals, evecs));
1600 }
1601
1602 let mut repaired = owned.clone();
1603 crate::matrix::symmetrize_in_place(&mut repaired);
1604
1605 let scale = repaired
1606 .iter()
1607 .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
1608 .max(1.0);
1609 let scaled = repaired.mapv(|value| value / scale);
1610 const JITTER_SCHEDULE: [f64; 6] = [0.0, 1e-12, 1e-10, 1e-8, 1e-6, 1e-4];
1616 let jitter_schedule = JITTER_SCHEDULE;
1617 let mut last_error = FaerLinalgError::FactorizationFailed {
1618 context: "self-adjoint eigendecomposition repair attempts",
1619 };
1620
1621 for &jitter in &jitter_schedule {
1622 let mut candidate = scaled.clone();
1623 if jitter > 0.0 {
1624 let n = candidate.nrows();
1625 for i in 0..n {
1626 candidate[[i, i]] += jitter;
1627 }
1628 }
1629
1630 match try_eigh(&candidate, side) {
1631 Ok((mut evals, evecs))
1632 if evals.iter().all(|value| value.is_finite())
1633 && evecs.iter().all(|value| value.is_finite()) =>
1634 {
1635 for value in &mut evals {
1636 *value = (*value - jitter) * scale;
1637 }
1638 return Ok((evals, evecs));
1639 }
1640 Ok((_, _)) => {
1641 last_error = FaerLinalgError::SelfAdjointEigenNonFiniteInput {
1642 context: "self-adjoint eigendecomposition repaired output validation",
1643 };
1644 }
1645 Err(err) => {
1646 last_error = err;
1647 }
1648 }
1649 }
1650
1651 Err(last_error)
1652 }
1653}
1654
/// Owned Cholesky (LLT) factorization that can be reused for several solves.
pub struct FaerCholeskyFactor {
    /// The underlying faer LLT factorization.
    factor: solvers::Llt<f64>,
}
1658
1659impl FaerCholeskyFactor {
1660 pub fn solvevec(&self, rhs: &Array1<f64>) -> Array1<f64> {
1661 let mut rhs = rhs.to_owned();
1662 let mut rhsview = array1_to_col_matmut(&mut rhs);
1663 self.factor.solve_in_place(rhsview.as_mut());
1664 rhs
1665 }
1666
1667 pub fn solve_mat_in_place(&self, rhs: &mut Array2<f64>) {
1668 let mut rhsview = array2_to_matmut(rhs);
1669 self.factor.solve_in_place(rhsview.as_mut());
1670 }
1671
1672 pub fn solve_mat_into<S: Data<Elem = f64>>(
1673 &self,
1674 rhs: &ArrayBase<S, Ix2>,
1675 out: &mut Array2<f64>,
1676 ) {
1677 if out.dim() != rhs.dim() {
1678 *out = Array2::<f64>::zeros(rhs.dim());
1679 }
1680 out.assign(rhs);
1681 self.solve_mat_in_place(out);
1682 }
1683
1684 pub fn solve_mat(&self, rhs: &Array2<f64>) -> Array2<f64> {
1685 let mut out = Array2::<f64>::zeros(rhs.dim());
1686 self.solve_mat_into(rhs, &mut out);
1687 out
1688 }
1689
1690 pub fn diag(&self) -> Array1<f64> {
1691 diag_to_array(self.factor.L().diagonal())
1692 }
1693
1694 pub fn lower_triangular(&self) -> Array2<f64> {
1695 mat_to_array(self.factor.L())
1696 }
1697}
1698
/// Cholesky (LLT) factorization backed by faer.
pub trait FaerCholesky {
    /// Factorizes `self`, forwarding `side` to faer's `llt`.
    ///
    /// # Errors
    ///
    /// Returns a [`FaerLinalgError`] when the factorization fails.
    fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError>;
}
1702
1703impl<S: Data<Elem = f64>> FaerCholesky for ArrayBase<S, Ix2> {
1704 fn cholesky(&self, side: Side) -> Result<FaerCholeskyFactor, FaerLinalgError> {
1705 let faerview = FaerArrayView::new(self);
1706 let factor = faerview
1707 .as_ref()
1708 .llt(side)
1709 .map_err(FaerLinalgError::Cholesky)?;
1710 Ok(FaerCholeskyFactor { factor })
1711 }
1712}
1713
/// QR factorization backed by faer.
pub trait FaerQr {
    /// Factorizes `self` and returns the pair `(Q, R)`.
    ///
    /// # Errors
    ///
    /// The signature allows a [`FaerLinalgError`]; whether an implementation
    /// ever returns one depends on that implementation.
    fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError>;
}
1717
1718impl<S: Data<Elem = f64>> FaerQr for ArrayBase<S, Ix2> {
1719 fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), FaerLinalgError> {
1720 let faerview = FaerArrayView::new(self);
1721 let qr = faerview.as_ref().qr();
1722 let q = qr.compute_thin_Q();
1723 let r = qr.thin_R();
1724 Ok((mat_to_array(q.as_ref()), mat_to_array(r)))
1725 }
1726}
1727
1728pub fn rrqr_nullspace_basis<S: Data<Elem = f64>>(
1747 a: &ArrayBase<S, Ix2>,
1748 rank_alpha: f64,
1749) -> Result<(Array2<f64>, usize), FaerLinalgError> {
1750 let faerview = FaerArrayView::new(a);
1751 let qr = faerview.as_ref().col_piv_qr();
1752 let r = qr.thin_R();
1753 let diag_len = r.nrows().min(r.ncols());
1754 let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
1755 let tol = rank_alpha
1756 * f64::EPSILON
1757 * (a.nrows().max(a.ncols()).max(1) as f64)
1758 * leading_diag.max(1.0);
1759 let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
1760 let z = if rank >= a.nrows() {
1761 Array2::<f64>::zeros((a.nrows(), 0))
1762 } else if rank == 0 {
1763 Array2::<f64>::eye(a.nrows())
1767 } else {
1768 let nullity = a.nrows() - rank;
1769 let mut selector = Mat::<f64>::zeros(a.nrows(), nullity);
1770 for j in 0..nullity {
1771 selector[(rank + j, j)] = 1.0;
1772 }
1773 let par = get_global_parallelism();
1774 faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
1775 qr.Q_basis(),
1776 qr.Q_coeff(),
1777 Conj::No,
1778 selector.as_mut(),
1779 par,
1780 MemStack::new(&mut MemBuffer::new(
1781 faer::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<f64>(
1782 a.nrows(),
1783 qr.Q_coeff().nrows(),
1784 nullity,
1785 ),
1786 )),
1787 );
1788 mat_to_array(selector.as_ref())
1789 };
1790 Ok((z, rank))
1791}
1792
/// Returns the default `rank_alpha` multiplier for the RRQR rank tolerance,
/// i.e. the module constant `RRQR_RANK_ALPHA`.
#[inline]
pub const fn default_rrqr_rank_alpha() -> f64 {
    RRQR_RANK_ALPHA
}
1797
/// Rank verdict of a column-pivoted QR together with its column ordering.
pub struct RrqrWithPermutation {
    /// Number of diagonal entries of `R` whose magnitude exceeds `rank_tol`.
    pub rank: usize,
    /// Forward index array of the QR column permutation, as plain `usize`
    /// values. NOTE(review): entries appear to be original column indices in
    /// pivot order, so the suffix past `rank` names the demoted columns —
    /// confirm against faer's permutation convention.
    pub column_permutation: Vec<usize>,
    /// Magnitude of the first diagonal entry of `R` (`0.0` when `R` has no
    /// diagonal).
    pub leading_diag_abs: f64,
    /// Tolerance the diagonal magnitudes were compared against.
    pub rank_tol: f64,
}
1814
1815pub fn rrqr_with_permutation<S: Data<Elem = f64>>(
1824 a: &ArrayBase<S, Ix2>,
1825 rank_alpha: f64,
1826) -> Result<RrqrWithPermutation, FaerLinalgError> {
1827 if a.nrows() == 0 {
1828 return Err(FaerLinalgError::FactorizationFailed {
1829 context: "rrqr_with_permutation: input has zero rows",
1830 });
1831 }
1832 let faerview = FaerArrayView::new(a);
1833 let qr = faerview.as_ref().col_piv_qr();
1834 let r = qr.thin_R();
1835 let diag_len = r.nrows().min(r.ncols());
1836 let leading_diag = if diag_len > 0 { r[(0, 0)].abs() } else { 0.0 };
1837 let tol = rank_alpha
1838 * f64::EPSILON
1839 * (a.nrows().max(a.ncols()).max(1) as f64)
1840 * leading_diag.max(1.0);
1841 let rank = (0..diag_len).filter(|&i| r[(i, i)].abs() > tol).count();
1842 let (forward, _inverse) = qr.P().arrays();
1843 let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
1844 Ok(RrqrWithPermutation {
1845 rank,
1846 column_permutation,
1847 leading_diag_abs: leading_diag,
1848 rank_tol: tol,
1849 })
1850}
1851
/// Rank verdict derived from a Gram matrix, with a confidence margin.
pub struct RrqrFromGram {
    /// Number of pivot magnitudes that exceed `rank_tol`.
    pub rank: usize,
    /// Forward index array of the QR column permutation, as plain `usize`
    /// values.
    pub column_permutation: Vec<usize>,
    /// Tolerance the pivot magnitudes were compared against.
    pub rank_tol: f64,
    /// Magnitude of the first pivot (`0.0` when there are no pivots).
    pub leading_diag_abs: f64,
    /// Smallest of the safety ratios computed for the verdict; larger values
    /// mean the kept and dropped pivots sit further from the decision
    /// thresholds.
    pub verdict_margin: f64,
}
1873
1874pub fn rrqr_from_gram_with_permutation<S: Data<Elem = f64>>(
1910 gram: &ArrayBase<S, Ix2>,
1911 m_rows: usize,
1912 rank_alpha: f64,
1913) -> Result<RrqrFromGram, FaerLinalgError> {
1914 let p = gram.ncols();
1915 if p == 0 {
1916 return Ok(RrqrFromGram {
1917 rank: 0,
1918 column_permutation: Vec::new(),
1919 rank_tol: 0.0,
1920 leading_diag_abs: 0.0,
1921 verdict_margin: 0.0,
1922 });
1923 }
1924 if gram.nrows() != p {
1925 return Err(FaerLinalgError::FactorizationFailed {
1926 context: "rrqr_from_gram_with_permutation: Gram is not square",
1927 });
1928 }
1929 let (evals, evecs) = gram.eigh(Side::Lower)?;
1938 let mut f = Array2::<f64>::zeros((p, p));
1939 for k in 0..p {
1940 let scale = evals[k].max(0.0).sqrt();
1941 if scale == 0.0 {
1942 continue;
1943 }
1944 for i in 0..p {
1945 f[[k, i]] = scale * evecs[[i, k]];
1946 }
1947 }
1948 let faer_f = FaerArrayView::new(&f);
1952 let qr = faer_f.as_ref().col_piv_qr();
1953 let r = qr.thin_R();
1954 let diag_len = r.nrows().min(r.ncols());
1955 let pivots: Vec<f64> = (0..diag_len).map(|i| r[(i, i)].abs()).collect();
1956 let leading_diag = pivots.first().copied().unwrap_or(0.0);
1957 let (forward, _inverse) = qr.P().arrays();
1958 let column_permutation: Vec<usize> = forward.iter().copied().map(|idx| idx.unbound()).collect();
1959 let tol = rank_alpha * f64::EPSILON * (m_rows.max(p).max(1) as f64) * leading_diag.max(1.0);
1963 let rank = pivots.iter().filter(|&&v| v > tol).count();
1964 let min_kept = pivots[..rank].iter().copied().fold(f64::INFINITY, f64::min);
1965 let max_dropped = pivots[rank..].iter().copied().fold(0.0f64, f64::max);
1966 let kept_margin = if rank == 0 {
1970 f64::INFINITY
1971 } else {
1972 min_kept / tol
1973 };
1974 let dropped_margin = if rank == diag_len {
1975 f64::INFINITY
1976 } else {
1977 tol / max_dropped.max(f64::MIN_POSITIVE)
1978 };
1979 let gram_precision_floor = f64::EPSILON.sqrt() * leading_diag.max(1.0);
2001 let kept_floor_margin = if rank == 0 {
2002 f64::INFINITY
2003 } else {
2004 min_kept / gram_precision_floor.max(f64::MIN_POSITIVE)
2005 };
2006 let verdict_margin = kept_margin.min(dropped_margin).min(kept_floor_margin);
2007 Ok(RrqrFromGram {
2008 rank,
2009 column_permutation,
2010 rank_tol: tol,
2011 leading_diag_abs: leading_diag,
2012 verdict_margin,
2013 })
2014}
2015
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, s};

    /// Verdict-margin threshold the Gram-RRQR tests compare against.
    const JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST: f64 = 1.0e3;

    /// Largest absolute entry of `values` (`0.0` for an empty array).
    fn max_abs(values: &Array2<f64>) -> f64 {
        values.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()))
    }

    /// The returned basis must have orthonormal columns and satisfy
    /// `A^T Z = 0`.
    #[test]
    fn rrqr_nullspace_basis_is_orthonormal_and_annihilates_transpose() {
        let a = array![[1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 0.0],];
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 2);
        assert_eq!(z.nrows(), 4);
        assert_eq!(z.ncols(), 2);

        let gram = z.t().dot(&z);
        let ident = Array2::<f64>::eye(z.ncols());
        let gram_err = max_abs(&(&gram - &ident));
        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");

        let resid_max = max_abs(&a.t().dot(&z));
        assert!(resid_max < 1e-10, "A^T Z residual too large: {resid_max:e}");
    }

    /// Columns 0 and 2 are identical, so one of them must land in the
    /// demoted suffix of the permutation.
    #[test]
    fn rrqr_with_permutation_attributes_redundant_column() {
        let a = array![
            [1.0, 0.0, 1.0],
            [1.0, 0.0, 1.0],
            [0.0, 2.0, 0.0],
            [0.0, 0.0, 0.0],
        ];
        let result =
            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(result.rank, 2);
        assert_eq!(result.column_permutation.len(), 3);

        let demoted = result.column_permutation[result.rank..].to_vec();
        assert!(
            demoted.contains(&2) || demoted.contains(&0),
            "demoted suffix should include one of the aliased columns (0 or 2), got {demoted:?}"
        );

        let mut sorted = result.column_permutation.clone();
        sorted.sort_unstable();
        assert_eq!(
            sorted,
            vec![0, 1, 2],
            "permutation must be a valid bijection on 0..n"
        );
    }

    /// A full-rank input keeps every column and still returns a valid
    /// permutation.
    #[test]
    fn rrqr_with_permutation_full_rank_returns_identity_like_order() {
        let a = array![[1.0, 0.0], [0.0, 2.0], [0.0, 0.0]];
        let result =
            rrqr_with_permutation(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(result.rank, 2);

        let mut sorted = result.column_permutation.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1]);
    }

    /// Inputs without rows are rejected with an error.
    #[test]
    fn rrqr_with_permutation_rejects_zero_rows() {
        let a = Array2::<f64>::zeros((0, 3));
        let outcome = rrqr_with_permutation(&a, default_rrqr_rank_alpha());
        assert!(outcome.is_err());
    }

    /// A square zero matrix has rank 0 and a finite, orthonormal basis.
    #[test]
    fn rrqr_nullspace_basis_square_zero_matrix_is_finite_identity() {
        let a = Array2::<f64>::zeros((3, 3));
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 0);
        assert_eq!(z.dim(), (3, 3));
        assert!(
            z.iter().all(|v| v.is_finite()),
            "square zero matrix produced a non-finite null basis: {z:?}"
        );

        let gram = z.t().dot(&z);
        let ident = Array2::<f64>::eye(3);
        let gram_err = max_abs(&(&gram - &ident));
        assert!(gram_err < 1e-10, "Z is not orthonormal: {gram_err:e}");
    }

    /// A tall zero matrix has rank 0 and yields the identity as basis.
    #[test]
    fn rrqr_nullspace_basis_detectszero_rank_matrix() {
        let a = Array2::<f64>::zeros((5, 2));
        let (z, rank) =
            rrqr_nullspace_basis(&a, default_rrqr_rank_alpha()).expect("RRQR should succeed");
        assert_eq!(rank, 0);
        assert_eq!(z.dim(), (5, 5));

        let ident = Array2::<f64>::eye(5);
        let leading = z.slice(s![.., ..5]).to_owned();
        let max_err = max_abs(&(&leading - &ident));
        assert!(max_err < 1e-10, "zero matrix should yield identity basis");
    }

    /// NaN entries must be rejected by input validation.
    #[test]
    fn eigh_on_nan_matrix_rejects_non_finite_input() {
        let mat = array![
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, f64::NAN],
            [0.0, 0.0, f64::NAN, 4.0]
        ];
        let err = mat
            .eigh(Side::Lower)
            .expect_err("non-finite symmetric input must be rejected");
        assert!(matches!(
            err,
            FaerLinalgError::SelfAdjointEigenNonFiniteInput { .. }
        ));
    }

    /// `fast_ata` must agree with the plain `A^T A` product and be symmetric.
    #[test]
    fn fast_ata_matches_full_gemm_above_threshold() {
        let n = 200;
        let p = 40;
        let a: Array2<f64> = Array2::from_shape_fn((n, p), |(i, j)| {
            ((i * 7 + j * 3) as f64).sin() + 0.1 * j as f64
        });
        let expected = a.t().dot(&a);
        let got = fast_ata(&a);
        let max_err = max_abs(&(&got - &expected));
        assert!(max_err < 1e-10, "fast_ata mismatch: {max_err:e}");

        for ((i, j), &value) in got.indexed_iter() {
            assert!((value - got[[j, i]]).abs() < 1e-12);
        }
    }

    /// `fast_xt_diag_x` must agree with the naive `X^T (W X)` product and be
    /// symmetric.
    #[test]
    fn fast_xt_diag_x_matches_naive_above_threshold() {
        let n = 400;
        let p = 36;
        let x: Array2<f64> =
            Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.1).cos() + j as f64 * 0.05);
        let w: Array1<f64> = Array1::from_shape_fn(n, |i| (i as f64 * 0.03).sin());
        let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
        let expected = x.t().dot(&wx);
        let got = fast_xt_diag_x(&x, &w);
        let max_err = max_abs(&(&got - &expected));
        assert!(max_err < 1e-9, "fast_xt_diag_x mismatch: {max_err:e}");

        for ((i, j), &value) in got.indexed_iter() {
            assert!((value - got[[j, i]]).abs() < 1e-12);
        }
    }

    /// Full and triangular streaming kernels must agree with the naive
    /// product, with each other, and with the returning adapter, in both
    /// `Replace` and `Add` modes, for weights that include negatives.
    #[test]
    fn stream_weighted_crossprod_full_and_triangular_parity_with_negative_weights() {
        for &(n, p) in &[(900usize, 40usize), (8usize, 3usize)] {
            let x: Array2<f64> =
                Array2::from_shape_fn((n, p), |(i, j)| (i as f64 * 0.07).cos() + j as f64 * 0.013);
            let w: Array1<f64> =
                Array1::from_shape_fn(n, |i| (i as f64 * 0.11).sin() - 0.25 * (i % 3) as f64);
            assert!(
                w.iter().any(|&v| v < 0.0),
                "weight vector must contain negatives to test sign preservation"
            );

            let wx = Array2::from_shape_fn((n, p), |(i, j)| w[i] * x[[i, j]]);
            let expected = x.t().dot(&wx);
            let par = matmul_parallelism(p, p, n);

            // Replace mode: the pre-filled garbage must be overwritten.
            let mut full = Array2::<f64>::ones((p, p));
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut full,
                CrossprodStructure::Full,
                CrossprodAccum::Replace,
                par,
            );
            let mut tri = Array2::<f64>::from_elem((p, p), -7.0);
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut tri,
                CrossprodStructure::SymmetricLower,
                CrossprodAccum::Replace,
                par,
            );

            let full_err = max_abs(&(&full - &expected));
            let tri_err = max_abs(&(&tri - &expected));
            assert!(
                full_err < 1e-9,
                "full kernel mismatch (n={n}, p={p}): {full_err:e}"
            );
            assert!(
                tri_err < 1e-9,
                "triangular kernel mismatch (n={n}, p={p}): {tri_err:e}"
            );

            for ((i, j), &tri_ij) in tri.indexed_iter() {
                assert!(
                    (full[[i, j]] - tri_ij).abs() < 1e-12,
                    "full vs triangular disagree at ({i},{j})"
                );
                assert!(
                    (tri_ij - tri[[j, i]]).abs() < 1e-12,
                    "triangular output not symmetric at ({i},{j})"
                );
            }

            // Add mode: the product is accumulated on top of `base`.
            let base = Array2::<f64>::from_elem((p, p), 1.5);
            let mut add_full = base.clone();
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut add_full,
                CrossprodStructure::Full,
                CrossprodAccum::Add,
                par,
            );
            let mut add_tri = base.clone();
            stream_weighted_crossprod_into(
                &x,
                &w,
                &mut add_tri,
                CrossprodStructure::SymmetricLower,
                CrossprodAccum::Add,
                par,
            );
            let expected_add = &base + &expected;
            let add_full_err = max_abs(&(&add_full - &expected_add));
            let add_tri_err = max_abs(&(&add_tri - &expected_add));
            assert!(
                add_full_err < 1e-9,
                "full Add mismatch (n={n}, p={p}): {add_full_err:e}"
            );
            assert!(
                add_tri_err < 1e-9,
                "triangular Add mismatch (n={n}, p={p}): {add_tri_err:e}"
            );

            let returned = fast_xt_diag_x(&x, &w);
            let returned_err = max_abs(&(&returned - &full));
            assert!(
                returned_err < 1e-12,
                "return adapter vs stream-into adapter disagree (n={n}, p={p}): {returned_err:e}"
            );
        }
    }

    /// A finite, well-conditioned symmetric matrix decomposes cleanly.
    #[test]
    fn eigh_succeeds_on_same_structure_without_nan() {
        let mat = array![[1.0, 0.5, 0.1], [0.5, 2.0, 0.3], [0.1, 0.3, 1.5]];
        let (evals, _) = mat
            .eigh(Side::Lower)
            .expect("eigh should succeed on a well-conditioned finite matrix");
        assert!(
            evals.iter().all(|&v| v.is_finite()),
            "all eigenvalues should be finite"
        );
    }

    /// With two identical columns the Gram-based verdict must report a
    /// margin below the trust threshold.
    #[test]
    fn gram_rrqr_flags_low_margin_on_exact_collinearity_so_caller_falls_back() {
        let n = 48usize;
        let x: Vec<f64> = (0..n)
            .map(|i| -1.0 + 2.0 * (i as f64) / (n as f64 - 1.0))
            .collect();
        // Columns: 1, x, x (exact alias), x^2.
        let a = Array2::<f64>::from_shape_fn((n, 4), |(i, j)| match j {
            0 => 1.0,
            1 | 2 => x[i],
            _ => x[i] * x[i],
        });
        let alpha = default_rrqr_rank_alpha();

        let tall = rrqr_with_permutation(&a, alpha).expect("tall RRQR should succeed");
        assert_eq!(tall.rank, 3, "tall RRQR must demote the exact alias");

        let unit = Array1::<f64>::ones(n);
        let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
        let gram_rrqr =
            rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
        assert!(
            gram_rrqr.verdict_margin < JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
            "exact-collinearity Gram verdict must report low margin to force tall \
             fallback; got margin={:.3e} (rank={})",
            gram_rrqr.verdict_margin,
            gram_rrqr.rank,
        );
    }

    /// A full-rank design keeps all columns and a margin at or above the
    /// trust threshold.
    #[test]
    fn gram_rrqr_keeps_high_margin_on_full_rank_design() {
        let n = 200usize;
        let p = 5usize;
        // Columns: 1, t, t^2, t^3, sin(6 t) with t on [0, 1].
        let a = Array2::<f64>::from_shape_fn((n, p), |(i, j)| {
            let t = (i as f64) / (n as f64 - 1.0);
            match j {
                0 => 1.0,
                1 => t,
                2 => t * t,
                3 => t * t * t,
                _ => (t * 6.0).sin(),
            }
        });
        let alpha = default_rrqr_rank_alpha();
        let unit = Array1::<f64>::ones(n);
        let gram = fast_xt_diag_x_with_parallelism(&a, &unit, faer::get_global_parallelism());
        let gram_rrqr =
            rrqr_from_gram_with_permutation(&gram, n, alpha).expect("Gram RRQR should succeed");
        assert_eq!(gram_rrqr.rank, p, "full-rank design must keep all columns");
        assert!(
            gram_rrqr.verdict_margin >= JOINT_GRAM_RRQR_TRUST_MARGIN_FOR_TEST,
            "full-rank design must keep a high margin (fast Gram path); got {:.3e}",
            gram_rrqr.verdict_margin,
        );
    }
}