Skip to main content

ferrolearn_core/
backend_faer.rs

//! Default backend implementation using `faer` for linear algebra.
//!
//! [`NdarrayFaerBackend`] implements the [`Backend`](crate::backend::Backend)
//! trait by converting between `ndarray::Array2<f64>` and `faer::Mat<f64>`,
//! then delegating to `faer`'s high-performance decomposition routines.
//!
//! - **gemm**: uses `ndarray`'s `dot` (which can dispatch to BLAS when `ndarray`'s `blas` feature is enabled).
//! - **svd**: delegates to `faer::linalg::solvers::Svd`.
//! - **qr**: delegates to `faer::linalg::solvers::Qr`.
//! - **cholesky**: delegates to `faer::linalg::solvers::Llt`.
//! - **solve**: uses LU decomposition via `faer::linalg::solvers::PartialPivLu`.
//! - **eigh**: delegates to `faer::linalg::solvers::SelfAdjointEigen`.
//! - **det**: computed via `faer::MatRef::determinant`.
//! - **inv**: computed via LU decomposition with `DenseSolveCore::inverse`.
15
16use crate::backend::Backend;
17use crate::error::{FerroError, FerroResult};
18use ndarray::{Array1, Array2};
19
20/// Convert an `ndarray::Array2<f64>` to a `faer::Mat<f64>`.
21fn ndarray_to_faer(a: &Array2<f64>) -> faer::Mat<f64> {
22    let (nrows, ncols) = a.dim();
23    faer::Mat::from_fn(nrows, ncols, |i, j| a[[i, j]])
24}
25
26/// Convert a `faer::Mat<f64>` to an `ndarray::Array2<f64>`.
27fn faer_to_ndarray(m: &faer::Mat<f64>) -> Array2<f64> {
28    let (nrows, ncols) = m.shape();
29    Array2::from_shape_fn((nrows, ncols), |(i, j)| m[(i, j)])
30}
31
32/// Convert a `faer::MatRef<'_, f64>` to an `ndarray::Array2<f64>`.
33fn faer_ref_to_ndarray(m: faer::MatRef<'_, f64>) -> Array2<f64> {
34    let (nrows, ncols) = m.shape();
35    Array2::from_shape_fn((nrows, ncols), |(i, j)| m[(i, j)])
36}
37
38/// Convert a `faer::DiagRef<'_, f64>` to an `ndarray::Array1<f64>`.
39fn faer_diag_to_ndarray(d: faer::diag::DiagRef<'_, f64>) -> Array1<f64> {
40    let vals: Vec<f64> = d.column_vector().iter().copied().collect();
41    Array1::from_vec(vals)
42}
43
/// The default backend using the `faer` crate for linear algebra operations.
///
/// This is a stateless, zero-sized marker type intended for use as a type
/// parameter on algorithms that are generic over [`Backend`]:
///
/// ```ignore
/// fn my_algorithm<B: Backend>(data: &Array2<f64>) -> FerroResult<Array1<f64>> {
///     let (u, s, vt) = B::svd(data)?;
///     // ...
/// }
///
/// // Use the default backend:
/// my_algorithm::<NdarrayFaerBackend>(&data)?;
/// ```
///
/// The common value traits are derived so the marker can be embedded in
/// structs that themselves derive `Debug`, `Clone`, `Default`, etc. without
/// hand-written impls.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NdarrayFaerBackend;
59
60impl Backend for NdarrayFaerBackend {
61    fn gemm(a: &Array2<f64>, b: &Array2<f64>) -> FerroResult<Array2<f64>> {
62        if a.ncols() != b.nrows() {
63            return Err(FerroError::ShapeMismatch {
64                expected: vec![a.nrows(), a.ncols()],
65                actual: vec![b.nrows(), b.ncols()],
66                context: format!(
67                    "gemm: A is {}x{} but B is {}x{} (inner dimensions {} != {})",
68                    a.nrows(),
69                    a.ncols(),
70                    b.nrows(),
71                    b.ncols(),
72                    a.ncols(),
73                    b.nrows()
74                ),
75            });
76        }
77        Ok(a.dot(b))
78    }
79
80    fn svd(a: &Array2<f64>) -> FerroResult<(Array2<f64>, Array1<f64>, Array2<f64>)> {
81        let mat = ndarray_to_faer(a);
82        let decomp = mat.svd().map_err(|e| FerroError::NumericalInstability {
83            message: format!("SVD failed to converge: {e:?}"),
84        })?;
85
86        let u = faer_ref_to_ndarray(decomp.U());
87        let s = faer_diag_to_ndarray(decomp.S());
88        // faer returns V (not V^T), so we transpose it
89        let vt = faer_ref_to_ndarray(decomp.V().transpose());
90
91        Ok((u, s, vt))
92    }
93
94    fn qr(a: &Array2<f64>) -> FerroResult<(Array2<f64>, Array2<f64>)> {
95        let (m, n) = a.dim();
96        let mat = ndarray_to_faer(a);
97        let decomp = mat.qr();
98
99        let q = faer_to_ndarray(&decomp.compute_Q());
100
101        // faer stores R as min(m,n) x n. We need the full (m x n) upper
102        // trapezoidal R with zero rows below.
103        let r_compact = decomp.R();
104        let r_rows = r_compact.nrows();
105        let mut r = Array2::<f64>::zeros((m, n));
106        for i in 0..r_rows {
107            for j in 0..n {
108                r[[i, j]] = r_compact[(i, j)];
109            }
110        }
111
112        Ok((q, r))
113    }
114
115    fn cholesky(a: &Array2<f64>) -> FerroResult<Array2<f64>> {
116        let (nrows, ncols) = a.dim();
117        if nrows != ncols {
118            return Err(FerroError::ShapeMismatch {
119                expected: vec![nrows, nrows],
120                actual: vec![nrows, ncols],
121                context: "cholesky: matrix must be square".into(),
122            });
123        }
124
125        let mat = ndarray_to_faer(a);
126        let decomp = mat
127            .llt(faer::Side::Lower)
128            .map_err(|e| FerroError::NumericalInstability {
129                message: format!(
130                    "Cholesky decomposition failed (matrix not positive definite): {e:?}"
131                ),
132            })?;
133
134        Ok(faer_ref_to_ndarray(decomp.L()))
135    }
136
137    fn solve(a: &Array2<f64>, b: &Array1<f64>) -> FerroResult<Array1<f64>> {
138        let (nrows, ncols) = a.dim();
139        if nrows != ncols {
140            return Err(FerroError::ShapeMismatch {
141                expected: vec![nrows, nrows],
142                actual: vec![nrows, ncols],
143                context: "solve: coefficient matrix must be square".into(),
144            });
145        }
146        if b.len() != nrows {
147            return Err(FerroError::ShapeMismatch {
148                expected: vec![nrows],
149                actual: vec![b.len()],
150                context: format!("solve: b has length {} but A has {} rows", b.len(), nrows),
151            });
152        }
153
154        use faer::linalg::solvers::Solve;
155
156        let mat = ndarray_to_faer(a);
157        let rhs = faer::Mat::from_fn(nrows, 1, |i, _| b[i]);
158        let lu = mat.partial_piv_lu();
159        let result = lu.solve(rhs.as_ref());
160
161        Ok(Array1::from_shape_fn(nrows, |i| result[(i, 0)]))
162    }
163
164    fn eigh(a: &Array2<f64>) -> FerroResult<(Array1<f64>, Array2<f64>)> {
165        let (nrows, ncols) = a.dim();
166        if nrows != ncols {
167            return Err(FerroError::ShapeMismatch {
168                expected: vec![nrows, nrows],
169                actual: vec![nrows, ncols],
170                context: "eigh: matrix must be square".into(),
171            });
172        }
173
174        let mat = ndarray_to_faer(a);
175        let decomp = mat.self_adjoint_eigen(faer::Side::Lower).map_err(|e| {
176            FerroError::NumericalInstability {
177                message: format!("Symmetric eigendecomposition failed to converge: {e:?}"),
178            }
179        })?;
180
181        let eigenvalues = faer_diag_to_ndarray(decomp.S());
182        let eigenvectors = faer_ref_to_ndarray(decomp.U());
183
184        Ok((eigenvalues, eigenvectors))
185    }
186
187    fn det(a: &Array2<f64>) -> FerroResult<f64> {
188        let (nrows, ncols) = a.dim();
189        if nrows != ncols {
190            return Err(FerroError::ShapeMismatch {
191                expected: vec![nrows, nrows],
192                actual: vec![nrows, ncols],
193                context: "det: matrix must be square".into(),
194            });
195        }
196
197        let mat = ndarray_to_faer(a);
198        Ok(mat.as_ref().determinant())
199    }
200
201    fn inv(a: &Array2<f64>) -> FerroResult<Array2<f64>> {
202        let (nrows, ncols) = a.dim();
203        if nrows != ncols {
204            return Err(FerroError::ShapeMismatch {
205                expected: vec![nrows, nrows],
206                actual: vec![nrows, ncols],
207                context: "inv: matrix must be square".into(),
208            });
209        }
210
211        use faer::linalg::solvers::DenseSolveCore;
212
213        let mat = ndarray_to_faer(a);
214        let lu = mat.partial_piv_lu();
215        let inv_mat = lu.inverse();
216
217        Ok(faer_to_ndarray(&inv_mat))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Shorthand so every test reads `B::<op>(...)`.
    type B = NdarrayFaerBackend;

    // -----------------------------------------------------------------------
    // Fixtures and comparison helpers
    // -----------------------------------------------------------------------

    /// The `n x n` identity matrix.
    fn identity(n: usize) -> Array2<f64> {
        Array2::eye(n)
    }

    /// A 2x3 matrix used wherever a non-square input is needed.
    fn wide_2x3() -> Array2<f64> {
        array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    }

    /// A non-symmetric, invertible 3x3 matrix with determinant 1.
    fn mixed_3x3() -> Array2<f64> {
        array![[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]]
    }

    /// A symmetric positive-definite 3x3 matrix shared by the `eigh` tests.
    fn tridiagonal_3x3() -> Array2<f64> {
        array![[5.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]
    }

    /// Element-wise approximate equality for matrices; shapes must match exactly.
    fn assert_mat_eq(actual: &Array2<f64>, expected: &Array2<f64>, eps: f64) {
        assert_eq!(actual.dim(), expected.dim(), "shape mismatch");
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = eps);
        }
    }

    /// Element-wise approximate equality for vectors; lengths must match exactly.
    fn assert_vec_eq(actual: &Array1<f64>, expected: &Array1<f64>, eps: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = eps);
        }
    }

    /// Rebuild `U * diag(S) * Vt`, summing over the `S.len()` singular triplets.
    fn reconstruct_svd(u: &Array2<f64>, s: &Array1<f64>, vt: &Array2<f64>) -> Array2<f64> {
        Array2::from_shape_fn((u.nrows(), vt.ncols()), |(i, j)| {
            (0..s.len()).fold(0.0, |acc, l| acc + u[[i, l]] * s[l] * vt[[l, j]])
        })
    }

    /// Rebuild `V * diag(eigenvalues) * V^T`, with eigenvectors as columns of `V`.
    fn reconstruct_eigh(eigenvalues: &Array1<f64>, v: &Array2<f64>) -> Array2<f64> {
        let n = eigenvalues.len();
        Array2::from_shape_fn((n, n), |(i, j)| {
            (0..n).fold(0.0, |acc, k| acc + v[[i, k]] * eigenvalues[k] * v[[j, k]])
        })
    }

    // -----------------------------------------------------------------------
    // gemm
    // -----------------------------------------------------------------------

    #[test]
    fn test_gemm_identity() {
        let m = array![[1.0, 2.0], [3.0, 4.0]];
        let product = B::gemm(&m, &identity(2)).unwrap();
        assert_mat_eq(&product, &m, 1e-12);
    }

    #[test]
    fn test_gemm_known_result() {
        // [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
        let lhs = array![[1.0, 2.0], [3.0, 4.0]];
        let rhs = array![[5.0, 6.0], [7.0, 8.0]];
        let want = array![[19.0, 22.0], [43.0, 50.0]];
        assert_mat_eq(&B::gemm(&lhs, &rhs).unwrap(), &want, 1e-12);
    }

    #[test]
    fn test_gemm_rectangular() {
        // (2x3) * (3x2) = (2x2)
        let rhs = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let want = array![[58.0, 64.0], [139.0, 154.0]];
        assert_mat_eq(&B::gemm(&wide_2x3(), &rhs).unwrap(), &want, 1e-12);
    }

    #[test]
    fn test_gemm_shape_mismatch() {
        let lhs = array![[1.0, 2.0], [3.0, 4.0]];
        let rhs = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        assert!(B::gemm(&lhs, &rhs).is_err());
    }

    #[test]
    fn test_gemm_single_element() {
        let product = B::gemm(&array![[3.0]], &array![[4.0]]).unwrap();
        assert_relative_eq!(product[[0, 0]], 12.0, epsilon = 1e-12);
    }

    // -----------------------------------------------------------------------
    // svd
    // -----------------------------------------------------------------------

    #[test]
    fn test_svd_identity() {
        let eye = identity(2);
        let (u, sigma, vt) = B::svd(&eye).unwrap();
        // Every singular value of the identity is 1.
        for &value in &sigma {
            assert_relative_eq!(value, 1.0, epsilon = 1e-12);
        }
        assert_mat_eq(&reconstruct_svd(&u, &sigma, &vt), &eye, 1e-12);
    }

    #[test]
    fn test_svd_reconstruction() {
        let m = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (u, sigma, vt) = B::svd(&m).unwrap();
        assert_mat_eq(&reconstruct_svd(&u, &sigma, &vt), &m, 1e-10);
    }

    #[test]
    fn test_svd_singular_values_descending() {
        let (_, sigma, _) = B::svd(&array![[3.0, 1.0], [1.0, 3.0]]).unwrap();
        assert!(sigma[0] >= sigma[1], "singular values should be non-increasing");
    }

    #[test]
    fn test_svd_diagonal() {
        let (_, sigma, _) = B::svd(&array![[3.0, 0.0], [0.0, 5.0]]).unwrap();
        // Largest first: 5, then 3.
        assert_relative_eq!(sigma[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(sigma[1], 3.0, epsilon = 1e-10);
    }

    // -----------------------------------------------------------------------
    // qr
    // -----------------------------------------------------------------------

    #[test]
    fn test_qr_reconstruction() {
        let m = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (q, r) = B::qr(&m).unwrap();
        assert_mat_eq(&q.dot(&r), &m, 1e-10);
    }

    #[test]
    fn test_qr_orthogonality() {
        let (q, _) = B::qr(&array![[1.0, 2.0], [3.0, 4.0]]).unwrap();
        // Q is orthogonal: Q^T * Q == I.
        assert_mat_eq(&q.t().dot(&q), &identity(2), 1e-10);
    }

    #[test]
    fn test_qr_identity() {
        let eye = identity(2);
        let (q, r) = B::qr(&eye).unwrap();
        assert_mat_eq(&q.dot(&r), &eye, 1e-12);
    }

    // -----------------------------------------------------------------------
    // cholesky
    // -----------------------------------------------------------------------

    #[test]
    fn test_cholesky_known() {
        // A = [[4, 2], [2, 3]], L = [[2, 0], [1, sqrt(2)]]
        let spd = array![[4.0, 2.0], [2.0, 3.0]];
        let lower = B::cholesky(&spd).unwrap();
        assert_mat_eq(&lower.dot(&lower.t()), &spd, 1e-10);
    }

    #[test]
    fn test_cholesky_identity() {
        let eye = identity(2);
        assert_mat_eq(&B::cholesky(&eye).unwrap(), &eye, 1e-12);
    }

    #[test]
    fn test_cholesky_not_positive_definite() {
        // Negative definite, so no real Cholesky factor exists.
        assert!(B::cholesky(&array![[-1.0, 0.0], [0.0, -1.0]]).is_err());
    }

    #[test]
    fn test_cholesky_3x3() {
        // X^T * X with full-rank X is positive definite.
        let x = array![[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
        let spd = x.t().dot(&x);
        let lower = B::cholesky(&spd).unwrap();
        assert_mat_eq(&lower.dot(&lower.t()), &spd, 1e-10);
    }

    #[test]
    fn test_cholesky_non_square_error() {
        assert!(B::cholesky(&wide_2x3()).is_err());
    }

    // -----------------------------------------------------------------------
    // solve
    // -----------------------------------------------------------------------

    #[test]
    fn test_solve_simple() {
        // [[2, 1], [1, 3]] * x = [5, 7] => x = [8/5, 9/5] = [1.6, 1.8]
        let x = B::solve(&array![[2.0, 1.0], [1.0, 3.0]], &array![5.0, 7.0]).unwrap();
        assert_relative_eq!(x[0], 1.6, epsilon = 1e-10);
        assert_relative_eq!(x[1], 1.8, epsilon = 1e-10);
    }

    #[test]
    fn test_solve_identity() {
        let rhs = array![3.0, 7.0];
        let x = B::solve(&identity(2), &rhs).unwrap();
        assert_vec_eq(&x, &rhs, 1e-12);
    }

    #[test]
    fn test_solve_3x3() {
        let m = mixed_3x3();
        let rhs = array![1.0, 2.0, 3.0];
        let x = B::solve(&m, &rhs).unwrap();
        // Check the residual rather than hard-coding x.
        assert_vec_eq(&m.dot(&x), &rhs, 1e-10);
    }

    #[test]
    fn test_solve_shape_mismatch() {
        let m = array![[1.0, 2.0], [3.0, 4.0]];
        let too_long = array![1.0, 2.0, 3.0];
        assert!(B::solve(&m, &too_long).is_err());
    }

    #[test]
    fn test_solve_non_square_error() {
        assert!(B::solve(&wide_2x3(), &array![1.0, 2.0]).is_err());
    }

    // -----------------------------------------------------------------------
    // eigh
    // -----------------------------------------------------------------------

    #[test]
    fn test_eigh_identity() {
        let eye = identity(2);
        let (values, vectors) = B::eigh(&eye).unwrap();
        for &lambda in &values {
            assert_relative_eq!(lambda, 1.0, epsilon = 1e-12);
        }
        // The eigenvector matrix is orthogonal: V * V^T == I.
        assert_mat_eq(&vectors.dot(&vectors.t()), &eye, 1e-12);
    }

    #[test]
    fn test_eigh_symmetric() {
        // [[2, 1], [1, 2]] has eigenvalues {1, 3}, reported in ascending order.
        let m = array![[2.0, 1.0], [1.0, 2.0]];
        let (values, vectors) = B::eigh(&m).unwrap();
        assert_relative_eq!(values[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(values[1], 3.0, epsilon = 1e-10);
        assert_mat_eq(&reconstruct_eigh(&values, &vectors), &m, 1e-10);
    }

    #[test]
    fn test_eigh_eigenvalues_sorted() {
        let (values, _) = B::eigh(&tridiagonal_3x3()).unwrap();
        for (prev, next) in values.iter().zip(values.iter().skip(1)) {
            assert!(next >= prev, "eigenvalues should be non-decreasing");
        }
    }

    #[test]
    fn test_eigh_not_square() {
        assert!(B::eigh(&wide_2x3()).is_err());
    }

    #[test]
    fn test_eigh_reconstruction_3x3() {
        let m = tridiagonal_3x3();
        let (values, vectors) = B::eigh(&m).unwrap();
        assert_mat_eq(&reconstruct_eigh(&values, &vectors), &m, 1e-10);
    }

    // -----------------------------------------------------------------------
    // det
    // -----------------------------------------------------------------------

    #[test]
    fn test_det_identity() {
        assert_relative_eq!(B::det(&identity(2)).unwrap(), 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_det_known() {
        // det([[1,2],[3,4]]) = 1*4 - 2*3 = -2
        let d = B::det(&array![[1.0, 2.0], [3.0, 4.0]]).unwrap();
        assert_relative_eq!(d, -2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_det_singular() {
        // The second row is twice the first, so det = 0.
        let d = B::det(&array![[1.0, 2.0], [2.0, 4.0]]).unwrap();
        assert_relative_eq!(d, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_det_3x3() {
        // 1*(0-24) - 2*(0-20) + 3*(0-5) = -24 + 40 - 15 = 1
        assert_relative_eq!(B::det(&mixed_3x3()).unwrap(), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_det_not_square() {
        assert!(B::det(&wide_2x3()).is_err());
    }

    // -----------------------------------------------------------------------
    // inv
    // -----------------------------------------------------------------------

    #[test]
    fn test_inv_identity() {
        let eye = identity(2);
        assert_mat_eq(&B::inv(&eye).unwrap(), &eye, 1e-12);
    }

    #[test]
    fn test_inv_known() {
        // inv([[1,2],[3,4]]) = 1/(-2) * [[4,-2],[-3,1]] = [[-2,1],[1.5,-0.5]]
        let inverse = B::inv(&array![[1.0, 2.0], [3.0, 4.0]]).unwrap();
        assert_mat_eq(&inverse, &array![[-2.0, 1.0], [1.5, -0.5]], 1e-10);
    }

    #[test]
    fn test_inv_roundtrip() {
        let m = array![[4.0, 7.0], [2.0, 6.0]];
        let inverse = B::inv(&m).unwrap();
        assert_mat_eq(&m.dot(&inverse), &identity(2), 1e-10);
    }

    #[test]
    fn test_inv_3x3() {
        let m = mixed_3x3();
        let inverse = B::inv(&m).unwrap();
        assert_mat_eq(&m.dot(&inverse), &identity(3), 1e-10);
    }

    #[test]
    fn test_inv_not_square() {
        assert!(B::inv(&wide_2x3()).is_err());
    }

    // -----------------------------------------------------------------------
    // Cross-operation consistency
    // -----------------------------------------------------------------------

    #[test]
    fn test_solve_via_inv() {
        // solve(A, b) must agree with inv(A) * b.
        let m = array![[2.0, 1.0], [1.0, 3.0]];
        let rhs = array![5.0, 7.0];
        let via_solve = B::solve(&m, &rhs).unwrap();
        let via_inverse = B::inv(&m).unwrap().dot(&rhs);
        assert_vec_eq(&via_solve, &via_inverse, 1e-10);
    }

    #[test]
    fn test_det_via_eigh() {
        // For a symmetric matrix, det is the product of its eigenvalues.
        let m = array![[4.0, 2.0], [2.0, 3.0]];
        let direct = B::det(&m).unwrap();
        let (values, _) = B::eigh(&m).unwrap();
        let from_spectrum = values.iter().fold(1.0, |acc: f64, &lambda| acc * lambda);
        assert_relative_eq!(direct, from_spectrum, epsilon = 1e-10);
    }

    // -----------------------------------------------------------------------
    // Type-level guarantees
    // -----------------------------------------------------------------------

    #[test]
    fn test_backend_is_send_sync() {
        fn assert_send_sync<T: Send + Sync + 'static>() {}
        assert_send_sync::<NdarrayFaerBackend>();
    }
}