1use crate::backend::Backend;
17use crate::error::{FerroError, FerroResult};
18use ndarray::{Array1, Array2};
19
20fn ndarray_to_faer(a: &Array2<f64>) -> faer::Mat<f64> {
22 let (nrows, ncols) = a.dim();
23 faer::Mat::from_fn(nrows, ncols, |i, j| a[[i, j]])
24}
25
26fn faer_to_ndarray(m: &faer::Mat<f64>) -> Array2<f64> {
28 let (nrows, ncols) = m.shape();
29 Array2::from_shape_fn((nrows, ncols), |(i, j)| m[(i, j)])
30}
31
32fn faer_ref_to_ndarray(m: faer::MatRef<'_, f64>) -> Array2<f64> {
34 let (nrows, ncols) = m.shape();
35 Array2::from_shape_fn((nrows, ncols), |(i, j)| m[(i, j)])
36}
37
38fn faer_diag_to_ndarray(d: faer::diag::DiagRef<'_, f64>) -> Array1<f64> {
40 let vals: Vec<f64> = d.column_vector().iter().copied().collect();
41 Array1::from_vec(vals)
42}
43
/// Linear-algebra backend that stores data in `ndarray` arrays and delegates
/// the matrix factorizations to `faer`.
///
/// This is a zero-sized marker type: every operation is an associated
/// function of the `Backend` trait (no `self` receiver), so there is no state
/// to construct, share, or synchronize.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NdarrayFaerBackend;
59
60impl Backend for NdarrayFaerBackend {
61 fn gemm(a: &Array2<f64>, b: &Array2<f64>) -> FerroResult<Array2<f64>> {
62 if a.ncols() != b.nrows() {
63 return Err(FerroError::ShapeMismatch {
64 expected: vec![a.nrows(), a.ncols()],
65 actual: vec![b.nrows(), b.ncols()],
66 context: format!(
67 "gemm: A is {}x{} but B is {}x{} (inner dimensions {} != {})",
68 a.nrows(),
69 a.ncols(),
70 b.nrows(),
71 b.ncols(),
72 a.ncols(),
73 b.nrows()
74 ),
75 });
76 }
77 Ok(a.dot(b))
78 }
79
80 fn svd(a: &Array2<f64>) -> FerroResult<(Array2<f64>, Array1<f64>, Array2<f64>)> {
81 let mat = ndarray_to_faer(a);
82 let decomp = mat.svd().map_err(|e| FerroError::NumericalInstability {
83 message: format!("SVD failed to converge: {e:?}"),
84 })?;
85
86 let u = faer_ref_to_ndarray(decomp.U());
87 let s = faer_diag_to_ndarray(decomp.S());
88 let vt = faer_ref_to_ndarray(decomp.V().transpose());
90
91 Ok((u, s, vt))
92 }
93
94 fn qr(a: &Array2<f64>) -> FerroResult<(Array2<f64>, Array2<f64>)> {
95 let (m, n) = a.dim();
96 let mat = ndarray_to_faer(a);
97 let decomp = mat.qr();
98
99 let q = faer_to_ndarray(&decomp.compute_Q());
100
101 let r_compact = decomp.R();
104 let r_rows = r_compact.nrows();
105 let mut r = Array2::<f64>::zeros((m, n));
106 for i in 0..r_rows {
107 for j in 0..n {
108 r[[i, j]] = r_compact[(i, j)];
109 }
110 }
111
112 Ok((q, r))
113 }
114
115 fn cholesky(a: &Array2<f64>) -> FerroResult<Array2<f64>> {
116 let (nrows, ncols) = a.dim();
117 if nrows != ncols {
118 return Err(FerroError::ShapeMismatch {
119 expected: vec![nrows, nrows],
120 actual: vec![nrows, ncols],
121 context: "cholesky: matrix must be square".into(),
122 });
123 }
124
125 let mat = ndarray_to_faer(a);
126 let decomp = mat
127 .llt(faer::Side::Lower)
128 .map_err(|e| FerroError::NumericalInstability {
129 message: format!(
130 "Cholesky decomposition failed (matrix not positive definite): {e:?}"
131 ),
132 })?;
133
134 Ok(faer_ref_to_ndarray(decomp.L()))
135 }
136
137 fn solve(a: &Array2<f64>, b: &Array1<f64>) -> FerroResult<Array1<f64>> {
138 let (nrows, ncols) = a.dim();
139 if nrows != ncols {
140 return Err(FerroError::ShapeMismatch {
141 expected: vec![nrows, nrows],
142 actual: vec![nrows, ncols],
143 context: "solve: coefficient matrix must be square".into(),
144 });
145 }
146 if b.len() != nrows {
147 return Err(FerroError::ShapeMismatch {
148 expected: vec![nrows],
149 actual: vec![b.len()],
150 context: format!("solve: b has length {} but A has {} rows", b.len(), nrows),
151 });
152 }
153
154 use faer::linalg::solvers::Solve;
155
156 let mat = ndarray_to_faer(a);
157 let rhs = faer::Mat::from_fn(nrows, 1, |i, _| b[i]);
158 let lu = mat.partial_piv_lu();
159 let result = lu.solve(rhs.as_ref());
160
161 Ok(Array1::from_shape_fn(nrows, |i| result[(i, 0)]))
162 }
163
164 fn eigh(a: &Array2<f64>) -> FerroResult<(Array1<f64>, Array2<f64>)> {
165 let (nrows, ncols) = a.dim();
166 if nrows != ncols {
167 return Err(FerroError::ShapeMismatch {
168 expected: vec![nrows, nrows],
169 actual: vec![nrows, ncols],
170 context: "eigh: matrix must be square".into(),
171 });
172 }
173
174 let mat = ndarray_to_faer(a);
175 let decomp = mat.self_adjoint_eigen(faer::Side::Lower).map_err(|e| {
176 FerroError::NumericalInstability {
177 message: format!("Symmetric eigendecomposition failed to converge: {e:?}"),
178 }
179 })?;
180
181 let eigenvalues = faer_diag_to_ndarray(decomp.S());
182 let eigenvectors = faer_ref_to_ndarray(decomp.U());
183
184 Ok((eigenvalues, eigenvectors))
185 }
186
187 fn det(a: &Array2<f64>) -> FerroResult<f64> {
188 let (nrows, ncols) = a.dim();
189 if nrows != ncols {
190 return Err(FerroError::ShapeMismatch {
191 expected: vec![nrows, nrows],
192 actual: vec![nrows, ncols],
193 context: "det: matrix must be square".into(),
194 });
195 }
196
197 let mat = ndarray_to_faer(a);
198 Ok(mat.as_ref().determinant())
199 }
200
201 fn inv(a: &Array2<f64>) -> FerroResult<Array2<f64>> {
202 let (nrows, ncols) = a.dim();
203 if nrows != ncols {
204 return Err(FerroError::ShapeMismatch {
205 expected: vec![nrows, nrows],
206 actual: vec![nrows, ncols],
207 context: "inv: matrix must be square".into(),
208 });
209 }
210
211 use faer::linalg::solvers::DenseSolveCore;
212
213 let mat = ndarray_to_faer(a);
214 let lu = mat.partial_piv_lu();
215 let inv_mat = lu.inverse();
216
217 Ok(faer_to_ndarray(&inv_mat))
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    type B = NdarrayFaerBackend;

    /// Asserts `actual` and `expected` have the same shape and agree
    /// element-wise within tolerance `eps` (via `assert_relative_eq!`).
    fn assert_mat_eq(actual: &Array2<f64>, expected: &Array2<f64>, eps: f64) {
        assert_eq!(actual.dim(), expected.dim(), "shape mismatch");
        for ((i, j), &val) in actual.indexed_iter() {
            assert_relative_eq!(val, expected[[i, j]], epsilon = eps);
        }
    }

    /// Asserts `actual` and `expected` have the same length and agree
    /// element-wise within tolerance `eps`.
    fn assert_vec_eq(actual: &Array1<f64>, expected: &Array1<f64>, eps: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, &val) in actual.iter().enumerate() {
            assert_relative_eq!(val, expected[i], epsilon = eps);
        }
    }

    // ---- gemm ----

    #[test]
    fn test_gemm_identity() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let c = B::gemm(&a, &eye).unwrap();
        assert_mat_eq(&c, &a, 1e-12);
    }

    #[test]
    fn test_gemm_known_result() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[5.0, 6.0], [7.0, 8.0]];
        let c = B::gemm(&a, &b).unwrap();
        let expected = array![[19.0, 22.0], [43.0, 50.0]];
        assert_mat_eq(&c, &expected, 1e-12);
    }

    #[test]
    fn test_gemm_rectangular() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = B::gemm(&a, &b).unwrap();
        let expected = array![[58.0, 64.0], [139.0, 154.0]];
        assert_mat_eq(&c, &expected, 1e-12);
    }

    #[test]
    fn test_gemm_shape_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let result = B::gemm(&a, &b);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    // ---- svd ----

    #[test]
    fn test_svd_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let (u, s, vt) = B::svd(&eye).unwrap();
        // Every singular value of the identity is 1.
        for &val in s.iter() {
            assert_relative_eq!(val, 1.0, epsilon = 1e-12);
        }
        let reconstructed = reconstruct_svd(&u, &s, &vt);
        assert_mat_eq(&reconstructed, &eye, 1e-12);
    }

    #[test]
    fn test_svd_reconstruction() {
        let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (u, s, vt) = B::svd(&a).unwrap();
        let reconstructed = reconstruct_svd(&u, &s, &vt);
        assert_mat_eq(&reconstructed, &a, 1e-10);
    }

    #[test]
    fn test_svd_singular_values_descending() {
        let a = array![[3.0, 1.0], [1.0, 3.0]];
        let (_, s, _) = B::svd(&a).unwrap();
        assert!(s[0] >= s[1], "singular values should be non-increasing");
    }

    #[test]
    fn test_svd_diagonal() {
        let a = array![[3.0, 0.0], [0.0, 5.0]];
        let (_, s, _) = B::svd(&a).unwrap();
        // Largest singular value comes first.
        assert_relative_eq!(s[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(s[1], 3.0, epsilon = 1e-10);
    }

    // ---- qr ----

    #[test]
    fn test_qr_reconstruction() {
        let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (q, r) = B::qr(&a).unwrap();
        let reconstructed = q.dot(&r);
        assert_mat_eq(&reconstructed, &a, 1e-10);
    }

    #[test]
    fn test_qr_orthogonality() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let (q, _) = B::qr(&a).unwrap();
        // Q must be orthogonal: Qᵀ·Q = I.
        let qtq = q.t().dot(&q);
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        assert_mat_eq(&qtq, &eye, 1e-10);
    }

    #[test]
    fn test_qr_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let (q, r) = B::qr(&eye).unwrap();
        let reconstructed = q.dot(&r);
        assert_mat_eq(&reconstructed, &eye, 1e-12);
    }

    #[test]
    fn test_qr_r_upper_triangular() {
        let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let (_, r) = B::qr(&a).unwrap();
        assert_eq!(r.dim(), a.dim());
        for ((i, j), &val) in r.indexed_iter() {
            if i > j {
                assert_relative_eq!(val, 0.0, epsilon = 1e-12);
            }
        }
    }

    // ---- cholesky ----

    #[test]
    fn test_cholesky_known() {
        let a = array![[4.0, 2.0], [2.0, 3.0]];
        let l = B::cholesky(&a).unwrap();
        let reconstructed = l.dot(&l.t());
        assert_mat_eq(&reconstructed, &a, 1e-10);
    }

    #[test]
    fn test_cholesky_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let l = B::cholesky(&eye).unwrap();
        assert_mat_eq(&l, &eye, 1e-12);
    }

    #[test]
    fn test_cholesky_not_positive_definite() {
        let a = array![[-1.0, 0.0], [0.0, -1.0]];
        let result = B::cholesky(&a);
        assert!(matches!(
            result,
            Err(FerroError::NumericalInstability { .. })
        ));
    }

    #[test]
    fn test_cholesky_3x3() {
        // Xᵀ·X is symmetric positive definite for full-rank X.
        let x = array![[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
        let a = x.t().dot(&x);
        let l = B::cholesky(&a).unwrap();
        let reconstructed = l.dot(&l.t());
        assert_mat_eq(&reconstructed, &a, 1e-10);
    }

    #[test]
    fn test_cholesky_lower_triangular() {
        let x = array![[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
        let a = x.t().dot(&x);
        let l = B::cholesky(&a).unwrap();
        for ((i, j), &val) in l.indexed_iter() {
            if j > i {
                assert_relative_eq!(val, 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_cholesky_non_square_error() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result = B::cholesky(&a);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    // ---- solve ----

    #[test]
    fn test_solve_simple() {
        // 2x + y = 5, x + 3y = 7  =>  x = 1.6, y = 1.8
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 7.0];
        let x = B::solve(&a, &b).unwrap();
        assert_relative_eq!(x[0], 1.6, epsilon = 1e-10);
        assert_relative_eq!(x[1], 1.8, epsilon = 1e-10);
    }

    #[test]
    fn test_solve_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![3.0, 7.0];
        let x = B::solve(&eye, &b).unwrap();
        assert_vec_eq(&x, &b, 1e-12);
    }

    #[test]
    fn test_solve_3x3() {
        let a = array![[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]];
        let b = array![1.0, 2.0, 3.0];
        let x = B::solve(&a, &b).unwrap();
        // Verify via the residual rather than hard-coded values.
        let ax = a.dot(&x);
        assert_vec_eq(&ax, &b, 1e-10);
    }

    #[test]
    fn test_solve_shape_mismatch() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![1.0, 2.0, 3.0];
        let result = B::solve(&a, &b);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_solve_non_square_error() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![1.0, 2.0];
        let result = B::solve(&a, &b);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    // ---- eigh ----

    #[test]
    fn test_eigh_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let (eigenvalues, eigenvectors) = B::eigh(&eye).unwrap();
        for &val in eigenvalues.iter() {
            assert_relative_eq!(val, 1.0, epsilon = 1e-12);
        }
        let vvt = eigenvectors.dot(&eigenvectors.t());
        assert_mat_eq(&vvt, &eye, 1e-12);
    }

    #[test]
    fn test_eigh_symmetric() {
        // Eigenvalues of [[2, 1], [1, 2]] are 1 and 3.
        let a = array![[2.0, 1.0], [1.0, 2.0]];
        let (eigenvalues, eigenvectors) = B::eigh(&a).unwrap();
        assert_relative_eq!(eigenvalues[0], 1.0, epsilon = 1e-10);
        assert_relative_eq!(eigenvalues[1], 3.0, epsilon = 1e-10);

        let reconstructed = reconstruct_eigh(&eigenvalues, &eigenvectors);
        assert_mat_eq(&reconstructed, &a, 1e-10);
    }

    #[test]
    fn test_eigh_eigenvalues_sorted() {
        let a = array![[5.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]];
        let (eigenvalues, _) = B::eigh(&a).unwrap();
        for i in 1..eigenvalues.len() {
            assert!(
                eigenvalues[i] >= eigenvalues[i - 1],
                "eigenvalues should be non-decreasing"
            );
        }
    }

    #[test]
    fn test_eigh_not_square() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result = B::eigh(&a);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_eigh_reconstruction_3x3() {
        let a = array![[5.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]];
        let (eigenvalues, eigenvectors) = B::eigh(&a).unwrap();
        let reconstructed = reconstruct_eigh(&eigenvalues, &eigenvectors);
        assert_mat_eq(&reconstructed, &a, 1e-10);
    }

    #[test]
    fn test_eigh_orthonormal_3x3() {
        let a = array![[5.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]];
        let (_, v) = B::eigh(&a).unwrap();
        let vtv = v.t().dot(&v);
        let eye = Array2::from_shape_fn((3, 3), |(i, j)| if i == j { 1.0 } else { 0.0 });
        assert_mat_eq(&vtv, &eye, 1e-10);
    }

    // ---- det ----

    #[test]
    fn test_det_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let d = B::det(&eye).unwrap();
        assert_relative_eq!(d, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_det_known() {
        // 1·4 − 2·3 = −2
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let d = B::det(&a).unwrap();
        assert_relative_eq!(d, -2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_det_singular() {
        // Second row is twice the first.
        let a = array![[1.0, 2.0], [2.0, 4.0]];
        let d = B::det(&a).unwrap();
        assert_relative_eq!(d, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_det_3x3() {
        let a = array![[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]];
        let d = B::det(&a).unwrap();
        assert_relative_eq!(d, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_det_not_square() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result = B::det(&a);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    // ---- inv ----

    #[test]
    fn test_inv_identity() {
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        let inv = B::inv(&eye).unwrap();
        assert_mat_eq(&inv, &eye, 1e-12);
    }

    #[test]
    fn test_inv_known() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let inv = B::inv(&a).unwrap();
        let expected = array![[-2.0, 1.0], [1.5, -0.5]];
        assert_mat_eq(&inv, &expected, 1e-10);
    }

    #[test]
    fn test_inv_roundtrip() {
        let a = array![[4.0, 7.0], [2.0, 6.0]];
        let inv = B::inv(&a).unwrap();
        let product = a.dot(&inv);
        let eye = array![[1.0, 0.0], [0.0, 1.0]];
        assert_mat_eq(&product, &eye, 1e-10);
    }

    #[test]
    fn test_inv_3x3() {
        let a = array![[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]];
        let inv = B::inv(&a).unwrap();
        let product = a.dot(&inv);
        let eye = Array2::from_shape_fn((3, 3), |(i, j)| if i == j { 1.0 } else { 0.0 });
        assert_mat_eq(&product, &eye, 1e-10);
    }

    #[test]
    fn test_inv_not_square() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let result = B::inv(&a);
        assert!(matches!(result, Err(FerroError::ShapeMismatch { .. })));
    }

    // ---- cross-operation consistency ----

    #[test]
    fn test_solve_via_inv() {
        let a = array![[2.0, 1.0], [1.0, 3.0]];
        let b = array![5.0, 7.0];
        let x_solve = B::solve(&a, &b).unwrap();
        let a_inv = B::inv(&a).unwrap();
        let x_inv = a_inv.dot(&b);
        assert_vec_eq(&x_solve, &x_inv, 1e-10);
    }

    #[test]
    fn test_det_via_eigh() {
        // For a symmetric matrix, det(A) equals the product of its eigenvalues.
        let a = array![[4.0, 2.0], [2.0, 3.0]];
        let det_direct = B::det(&a).unwrap();
        let (eigenvalues, _) = B::eigh(&a).unwrap();
        let det_from_eig: f64 = eigenvalues.iter().product();
        assert_relative_eq!(det_direct, det_from_eig, epsilon = 1e-10);
    }

    // ---- misc ----

    #[test]
    fn test_gemm_single_element() {
        let a = array![[3.0]];
        let b = array![[4.0]];
        let c = B::gemm(&a, &b).unwrap();
        assert_relative_eq!(c[[0, 0]], 12.0, epsilon = 1e-12);
    }

    #[test]
    fn test_backend_is_send_sync() {
        fn assert_send_sync<T: Send + Sync + 'static>() {}
        assert_send_sync::<NdarrayFaerBackend>();
    }

    /// Computes `U · diag(s) · Vᵀ` using only the first `s.len()` columns of
    /// `U` and rows of `Vᵀ`, so it works for both thin and full factors.
    fn reconstruct_svd(u: &Array2<f64>, s: &Array1<f64>, vt: &Array2<f64>) -> Array2<f64> {
        let m = u.nrows();
        let n = vt.ncols();
        let k = s.len();
        let mut result = Array2::zeros((m, n));
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for l in 0..k {
                    sum += u[[i, l]] * s[l] * vt[[l, j]];
                }
                result[[i, j]] = sum;
            }
        }
        result
    }

    /// Computes `V · diag(λ) · Vᵀ` from eigenvalues and column eigenvectors.
    fn reconstruct_eigh(eigenvalues: &Array1<f64>, v: &Array2<f64>) -> Array2<f64> {
        let n = eigenvalues.len();
        let mut result = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += v[[i, k]] * eigenvalues[k] * v[[j, k]];
                }
                result[[i, j]] = sum;
            }
        }
        result
    }
}