oxiblas_ndarray/
lapack.rs

1//! LAPACK decompositions and operations on ndarray types.
2//!
3//! This module provides LAPACK decompositions (LU, QR, SVD, EVD, Cholesky)
4//! directly on ndarray types, using OxiBLAS-LAPACK as the backend.
5
6use crate::conversions::{array2_to_mat, mat_ref_to_array2, mat_to_array2};
7use ndarray::{Array1, Array2};
8use oxiblas_core::scalar::Field;
9use oxiblas_lapack::{cholesky, evd, lu, qr, solve, svd};
10use oxiblas_matrix::Mat;
11
12// Re-export useful types
13pub use evd::Eigenvalue;
14
15// =============================================================================
16// Error Types
17// =============================================================================
18
/// Error type for LAPACK operations on ndarray.
///
/// Every variant carries a free-form message. Backend failures are forwarded
/// as the `Debug` rendering of the backend error; the up-front shape checks in
/// this module use a fixed description instead.
///
/// `PartialEq`/`Eq` are derived so callers (and tests) can compare errors
/// directly instead of matching on the variant and message by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LapackError {
    /// Matrix is singular or nearly singular
    Singular(String),
    /// Matrix is not positive definite
    NotPositiveDefinite(String),
    /// Dimension mismatch
    DimensionMismatch(String),
    /// Decomposition did not converge
    NotConverged(String),
    /// Other error
    Other(String),
}

impl std::fmt::Display for LapackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Singular(msg) => write!(f, "Singular matrix: {msg}"),
            Self::NotPositiveDefinite(msg) => write!(f, "Not positive definite: {msg}"),
            Self::DimensionMismatch(msg) => write!(f, "Dimension mismatch: {msg}"),
            Self::NotConverged(msg) => write!(f, "Did not converge: {msg}"),
            Self::Other(msg) => write!(f, "LAPACK error: {msg}"),
        }
    }
}

impl std::error::Error for LapackError {}

/// Result type for LAPACK operations.
pub type LapackResult<T> = Result<T, LapackError>;
50
51// =============================================================================
52// LU Decomposition
53// =============================================================================
54
55/// Result of LU decomposition.
56#[derive(Debug, Clone)]
57pub struct LuResult<T> {
58    /// L factor (lower triangular with unit diagonal)
59    pub l: Array2<T>,
60    /// U factor (upper triangular)
61    pub u: Array2<T>,
62    /// Permutation vector
63    pub perm: Vec<usize>,
64}
65
66impl<T: Field + Clone> LuResult<T>
67where
68    T: bytemuck::Zeroable,
69{
70    /// Solves Ax = b using the LU decomposition.
71    pub fn solve(&self, b: &Array1<T>) -> Array1<T> {
72        let n = self.l.dim().0;
73        assert_eq!(b.len(), n, "b length must match matrix dimension");
74
75        // Apply permutation to b
76        let mut pb: Vec<T> = vec![T::zero(); n];
77        for i in 0..n {
78            pb[i] = b[self.perm[i]];
79        }
80
81        // Forward substitution: L * y = pb
82        let mut y: Vec<T> = vec![T::zero(); n];
83        for i in 0..n {
84            let mut sum = pb[i];
85            for j in 0..i {
86                sum -= self.l[[i, j]] * y[j];
87            }
88            y[i] = sum;
89        }
90
91        // Back substitution: U * x = y
92        let mut x: Vec<T> = vec![T::zero(); n];
93        for i in (0..n).rev() {
94            let mut sum = y[i];
95            for j in (i + 1)..n {
96                sum -= self.u[[i, j]] * x[j];
97            }
98            x[i] = sum / self.u[[i, i]];
99        }
100
101        Array1::from_vec(x)
102    }
103
104    /// Computes the determinant.
105    pub fn det(&self) -> T {
106        let n = self.l.dim().0;
107        let mut det = T::one();
108
109        // Product of U diagonal elements
110        for i in 0..n {
111            det *= self.u[[i, i]];
112        }
113
114        // Account for permutation sign
115        let mut sign_changes = 0;
116        let mut visited = vec![false; n];
117        for i in 0..n {
118            if visited[i] {
119                continue;
120            }
121            let mut j = i;
122            let mut cycle_len = 0;
123            while !visited[j] {
124                visited[j] = true;
125                j = self.perm[j];
126                cycle_len += 1;
127            }
128            if cycle_len > 1 {
129                sign_changes += cycle_len - 1;
130            }
131        }
132
133        if sign_changes % 2 == 1 {
134            det = T::zero() - det;
135        }
136
137        det
138    }
139}
140
141/// Computes the LU decomposition of a matrix.
142///
143/// A = P * L * U
144///
145/// # Arguments
146/// * `a` - The input matrix (m×n)
147///
148/// # Returns
149/// LU decomposition with L, U, and permutation
150pub fn lu_ndarray<T: Field + Clone>(a: &Array2<T>) -> LapackResult<LuResult<T>>
151where
152    T: bytemuck::Zeroable,
153{
154    let mat = array2_to_mat(a);
155
156    match lu::Lu::compute(mat.as_ref()) {
157        Ok(lu_decomp) => {
158            // Extract L and U factors
159            let l = mat_to_array2(&lu_decomp.l_factor());
160            let u = mat_to_array2(&lu_decomp.u_factor());
161
162            // Get permutation
163            let perm = lu_decomp.pivot().to_vec();
164
165            Ok(LuResult { l, u, perm })
166        }
167        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
168    }
169}
170
171// =============================================================================
172// QR Decomposition
173// =============================================================================
174
175/// Result of QR decomposition.
176#[derive(Debug, Clone)]
177pub struct QrResult<T> {
178    /// Q factor (orthogonal/unitary)
179    pub q: Array2<T>,
180    /// R factor (upper triangular)
181    pub r: Array2<T>,
182}
183
184impl<T: Field + Clone> QrResult<T> {
185    /// Solves the least squares problem min ||Ax - b||.
186    pub fn solve_least_squares(&self, b: &Array1<T>) -> Array1<T> {
187        let (m, n) = (self.q.dim().0, self.r.dim().1);
188        assert_eq!(b.len(), m, "b length must match matrix rows");
189
190        // Compute Q^T * b (or Q^H for complex)
191        let mut qtb: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
192        for j in 0..n {
193            let mut sum = T::zero();
194            for i in 0..m {
195                sum += self.q[[i, j]].conj() * b[i];
196            }
197            qtb[j] = sum;
198        }
199
200        // Back substitution: R * x = Q^T * b
201        let mut x: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
202        for i in (0..n).rev() {
203            let mut sum = qtb[i];
204            for j in (i + 1)..n {
205                sum -= self.r[[i, j]] * x[j];
206            }
207            x[i] = sum / self.r[[i, i]];
208        }
209
210        x
211    }
212}
213
214/// Computes the QR decomposition of a matrix.
215///
216/// A = Q * R
217///
218/// # Arguments
219/// * `a` - The input matrix (m×n)
220///
221/// # Returns
222/// QR decomposition with Q and R
223pub fn qr_ndarray<T: Field + Clone>(a: &Array2<T>) -> LapackResult<QrResult<T>>
224where
225    T: bytemuck::Zeroable + oxiblas_core::scalar::Real,
226{
227    let mat = array2_to_mat(a);
228
229    match qr::Qr::compute(mat.as_ref()) {
230        Ok(qr_decomp) => {
231            let q = mat_to_array2(&qr_decomp.q());
232            let r = mat_to_array2(&qr_decomp.r());
233
234            Ok(QrResult { q, r })
235        }
236        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
237    }
238}
239
240// =============================================================================
241// Singular Value Decomposition
242// =============================================================================
243
244/// Result of SVD decomposition.
245#[derive(Debug, Clone)]
246pub struct SvdResult<T> {
247    /// Left singular vectors U (m×k where k = min(m,n))
248    pub u: Array2<T>,
249    /// Singular values σ (sorted in descending order)
250    pub s: Array1<T>,
251    /// Right singular vectors V^T (k×n)
252    pub vt: Array2<T>,
253}
254
255impl<T: Field + Clone> SvdResult<T> {
256    /// Returns the rank based on a tolerance.
257    pub fn rank(&self, tol: T) -> usize {
258        self.s.iter().filter(|&s| s.abs() > tol.abs()).count()
259    }
260}
261
262/// Computes the SVD of a matrix.
263///
264/// A = U * Σ * V^T
265///
266/// # Arguments
267/// * `a` - The input matrix (m×n)
268///
269/// # Returns
270/// SVD with U, S (singular values), and V^T
271pub fn svd_ndarray<T>(a: &Array2<T>) -> LapackResult<SvdResult<T>>
272where
273    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
274{
275    let mat = array2_to_mat(a);
276
277    match svd::Svd::compute(mat.as_ref()) {
278        Ok(svd_decomp) => {
279            let u = mat_ref_to_array2(svd_decomp.u());
280            let s = Array1::from_vec(svd_decomp.singular_values().to_vec());
281            let vt = mat_ref_to_array2(svd_decomp.vt());
282
283            Ok(SvdResult { u, s, vt })
284        }
285        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
286    }
287}
288
289/// Computes the truncated SVD of a matrix.
290///
291/// Returns only the top k singular values and vectors.
292pub fn svd_truncated<T>(a: &Array2<T>, k: usize) -> LapackResult<SvdResult<T>>
293where
294    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
295{
296    let svd_result = svd_ndarray(a)?;
297
298    let actual_k = k.min(svd_result.s.len());
299
300    // Truncate to k components
301    let u = svd_result.u.slice(ndarray::s![.., ..actual_k]).to_owned();
302    let s = svd_result.s.slice(ndarray::s![..actual_k]).to_owned();
303    let vt = svd_result.vt.slice(ndarray::s![..actual_k, ..]).to_owned();
304
305    Ok(SvdResult { u, s, vt })
306}
307
308// =============================================================================
309// Eigenvalue Decomposition (Symmetric)
310// =============================================================================
311
312/// Result of symmetric eigenvalue decomposition.
313#[derive(Debug, Clone)]
314pub struct SymEvdResult<T> {
315    /// Eigenvalues (sorted in ascending order)
316    pub eigenvalues: Array1<T>,
317    /// Eigenvectors (columns are eigenvectors)
318    pub eigenvectors: Array2<T>,
319}
320
321/// Computes the eigenvalue decomposition of a symmetric matrix.
322///
323/// A * V = V * Λ where Λ = diag(eigenvalues)
324///
325/// # Arguments
326/// * `a` - The input symmetric matrix (n×n)
327///
328/// # Returns
329/// Eigenvalues and eigenvectors
330pub fn eig_symmetric<T>(a: &Array2<T>) -> LapackResult<SymEvdResult<T>>
331where
332    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
333{
334    let (m, n) = a.dim();
335    if m != n {
336        return Err(LapackError::DimensionMismatch(
337            "Matrix must be square".to_string(),
338        ));
339    }
340
341    let mat = array2_to_mat(a);
342
343    match evd::SymmetricEvd::compute(mat.as_ref()) {
344        Ok(evd_result) => {
345            let eigenvalues = Array1::from_vec(evd_result.eigenvalues().to_vec());
346            // Convert MatRef to Array2
347            let evec_ref = evd_result.eigenvectors();
348            let (rows, cols) = (evec_ref.nrows(), evec_ref.ncols());
349            let eigenvectors = Array2::from_shape_fn((rows, cols), |(i, j)| evec_ref[(i, j)]);
350
351            Ok(SymEvdResult {
352                eigenvalues,
353                eigenvectors,
354            })
355        }
356        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
357    }
358}
359
360/// Computes only the eigenvalues of a symmetric matrix.
361pub fn eigvals_symmetric<T>(a: &Array2<T>) -> LapackResult<Array1<T>>
362where
363    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
364{
365    eig_symmetric(a).map(|result| result.eigenvalues)
366}
367
368// =============================================================================
369// Complex-Specific Functions (for Complex<f64>, Complex<f32>)
370// =============================================================================
371
372/// Result of complex SVD decomposition.
373#[derive(Debug, Clone)]
374pub struct ComplexSvdResult<T>
375where
376    T: oxiblas_core::Scalar,
377{
378    /// Left singular vectors U (m×k, complex unitary)
379    pub u: Array2<T>,
380    /// Singular values σ (real, sorted in descending order)
381    pub s: Array1<T::Real>,
382    /// Right singular vectors V^H (k×n, complex unitary)
383    pub vt: Array2<T>,
384}
385
386/// Computes the SVD of a complex matrix using ComplexSvd algorithm.
387///
388/// For complex matrices, this function uses the one-sided Jacobi algorithm
389/// specifically designed for complex numbers.
390///
391/// # Arguments
392/// * `a` - The input complex matrix (m×n)
393///
394/// # Returns
395/// U, singular values (real), and V^H
396pub fn svd_complex_ndarray<T>(a: &Array2<T>) -> LapackResult<ComplexSvdResult<T>>
397where
398    T: Field + oxiblas_core::scalar::ComplexScalar + Clone + bytemuck::Zeroable,
399    T::Real: oxiblas_core::scalar::Real,
400{
401    let mat = array2_to_mat(a);
402
403    match svd::ComplexSvd::compute(mat.as_ref()) {
404        Ok(svd_decomp) => {
405            let u = mat_ref_to_array2(svd_decomp.u().as_ref());
406            let s = Array1::from_vec(svd_decomp.singular_values().to_vec());
407            let vh = mat_ref_to_array2(svd_decomp.vh().as_ref());
408
409            Ok(ComplexSvdResult { u, s, vt: vh })
410        }
411        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
412    }
413}
414
415/// Computes the QR decomposition of a complex matrix.
416///
417/// # Arguments
418/// * `a` - The input complex matrix (m×n)
419///
420/// # Returns
421/// Q (unitary) and R (upper triangular)
422pub fn qr_complex_ndarray<T>(a: &Array2<T>) -> LapackResult<QrResult<T>>
423where
424    T: Field + oxiblas_core::scalar::ComplexScalar + Clone + bytemuck::Zeroable,
425    T::Real: oxiblas_core::scalar::Real,
426{
427    let mat = array2_to_mat(a);
428
429    match qr::UnitaryQr::compute(mat.as_ref()) {
430        Ok(qr_decomp) => {
431            let q_mat = qr_decomp.q();
432            let r_mat = qr_decomp.r();
433            let q = mat_to_array2(&q_mat);
434            let r = mat_to_array2(&r_mat);
435            Ok(QrResult { q, r })
436        }
437        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
438    }
439}
440
441/// Computes the Cholesky decomposition of a Hermitian positive definite matrix.
442///
443/// # Arguments
444/// * `a` - The input Hermitian positive definite matrix (n×n)
445///
446/// # Returns
447/// Lower triangular factor L such that A = LL^H
448pub fn cholesky_hermitian_ndarray<T>(a: &Array2<T>) -> LapackResult<CholeskyResult<T>>
449where
450    T: Field + oxiblas_core::scalar::ComplexScalar + Clone + bytemuck::Zeroable,
451    T::Real: oxiblas_core::scalar::Real,
452{
453    let (m, n) = a.dim();
454    if m != n {
455        return Err(LapackError::DimensionMismatch(
456            "Matrix must be square".to_string(),
457        ));
458    }
459
460    let mat = array2_to_mat(a);
461
462    match cholesky::HermitianCholesky::compute(mat.as_ref()) {
463        Ok(chol) => {
464            let l_mat = chol.l_factor();
465            let l = mat_to_array2(&l_mat);
466            Ok(CholeskyResult { l })
467        }
468        Err(e) => Err(LapackError::NotPositiveDefinite(format!("{e:?}"))),
469    }
470}
471
472/// Result of Hermitian eigenvalue decomposition for complex matrices.
473#[derive(Debug, Clone)]
474pub struct HermitianEvdResult<T>
475where
476    T: oxiblas_core::Scalar,
477{
478    /// Eigenvalues (real, sorted in ascending order)
479    pub eigenvalues: Array1<T>,
480    /// Eigenvectors (complex columns)
481    pub eigenvectors: Array2<T>,
482}
483
484/// Computes the eigenvalue decomposition of a Hermitian matrix.
485///
486/// For Hermitian matrices (A = A^H), all eigenvalues are real but eigenvectors are complex.
487///
488/// # Arguments
489/// * `a` - The input Hermitian matrix (n×n, only upper triangle is used)
490///
491/// # Returns
492/// Eigenvalues (real, sorted in ascending order) and eigenvectors (complex columns)
493pub fn eig_hermitian_ndarray<T>(a: &Array2<T>) -> LapackResult<(Array1<T::Real>, Array2<T>)>
494where
495    T: Field + oxiblas_core::scalar::ComplexScalar + Clone + bytemuck::Zeroable,
496    T::Real: oxiblas_core::scalar::Real + Clone + bytemuck::Zeroable,
497{
498    let (m, n) = a.dim();
499    if m != n {
500        return Err(LapackError::DimensionMismatch(
501            "Matrix must be square".to_string(),
502        ));
503    }
504
505    let mat = array2_to_mat(a);
506
507    match evd::HermitianEvd::compute(mat.as_ref()) {
508        Ok(evd_result) => {
509            let eigenvalues = Array1::from_vec(evd_result.eigenvalues().to_vec());
510
511            // Convert eigenvectors MatRef<T> to Array2<T>
512            let evec_ref = evd_result.eigenvectors();
513            let eigenvectors = mat_ref_to_array2(evec_ref);
514
515            Ok((eigenvalues, eigenvectors))
516        }
517        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
518    }
519}
520
521// =============================================================================
522// Cholesky Decomposition
523// =============================================================================
524
525/// Result of Cholesky decomposition.
526#[derive(Debug, Clone)]
527pub struct CholeskyResult<T> {
528    /// Lower triangular factor L such that A = L * L^T
529    pub l: Array2<T>,
530}
531
532impl<T: Field + Clone> CholeskyResult<T> {
533    /// Solves Ax = b using the Cholesky decomposition.
534    pub fn solve(&self, b: &Array1<T>) -> Array1<T> {
535        let n = self.l.dim().0;
536        assert_eq!(b.len(), n, "b length must match matrix dimension");
537
538        // Forward substitution: L * y = b
539        let mut y: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
540        for i in 0..n {
541            let mut sum = b[i];
542            for j in 0..i {
543                sum -= self.l[[i, j]] * y[j];
544            }
545            y[i] = sum / self.l[[i, i]];
546        }
547
548        // Back substitution: L^T * x = y
549        let mut x: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
550        for i in (0..n).rev() {
551            let mut sum = y[i];
552            for j in (i + 1)..n {
553                sum -= self.l[[j, i]].conj() * x[j];
554            }
555            x[i] = sum / self.l[[i, i]].conj();
556        }
557
558        x
559    }
560
561    /// Computes the determinant.
562    pub fn det(&self) -> T {
563        let n = self.l.dim().0;
564        let mut det = T::one();
565        for i in 0..n {
566            let diag = self.l[[i, i]];
567            det = det * diag * diag;
568        }
569        det
570    }
571}
572
573/// Computes the Cholesky decomposition of a positive definite matrix.
574///
575/// A = L * L^T
576///
577/// # Arguments
578/// * `a` - The input symmetric positive definite matrix (n×n)
579///
580/// # Returns
581/// Cholesky decomposition with lower triangular factor L
582pub fn cholesky_ndarray<T>(a: &Array2<T>) -> LapackResult<CholeskyResult<T>>
583where
584    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
585{
586    let (m, n) = a.dim();
587    if m != n {
588        return Err(LapackError::DimensionMismatch(
589            "Matrix must be square".to_string(),
590        ));
591    }
592
593    let mat = array2_to_mat(a);
594
595    match cholesky::Cholesky::compute(mat.as_ref()) {
596        Ok(chol) => {
597            let l = mat_to_array2(&chol.l_factor());
598            Ok(CholeskyResult { l })
599        }
600        Err(e) => Err(LapackError::NotPositiveDefinite(format!("{e:?}"))),
601    }
602}
603
604// =============================================================================
605// Linear Solve
606// =============================================================================
607
608/// Solves the linear system Ax = b.
609///
610/// # Arguments
611/// * `a` - The coefficient matrix (n×n)
612/// * `b` - The right-hand side vector (n)
613///
614/// # Returns
615/// The solution vector x
616pub fn solve_ndarray<T>(a: &Array2<T>, b: &Array1<T>) -> LapackResult<Array1<T>>
617where
618    T: Field + Clone + bytemuck::Zeroable,
619{
620    let (m, n) = a.dim();
621    if m != n {
622        return Err(LapackError::DimensionMismatch(
623            "Matrix must be square".to_string(),
624        ));
625    }
626    if b.len() != n {
627        return Err(LapackError::DimensionMismatch(
628            "b length must match matrix dimension".to_string(),
629        ));
630    }
631
632    let a_mat = array2_to_mat(a);
633    // Convert b to a column vector matrix
634    let mut b_mat: Mat<T> = Mat::zeros(n, 1);
635    for i in 0..n {
636        b_mat[(i, 0)] = b[i];
637    }
638
639    match solve::solve(a_mat.as_ref(), b_mat.as_ref()) {
640        Ok(x_mat) => {
641            // Extract column vector from result matrix
642            let x: Vec<T> = (0..n).map(|i| x_mat[(i, 0)]).collect();
643            Ok(Array1::from_vec(x))
644        }
645        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
646    }
647}
648
649/// Solves multiple linear systems AX = B.
650///
651/// # Arguments
652/// * `a` - The coefficient matrix (n×n)
653/// * `b` - The right-hand side matrix (n×k)
654///
655/// # Returns
656/// The solution matrix X (n×k)
657pub fn solve_multiple_ndarray<T>(a: &Array2<T>, b: &Array2<T>) -> LapackResult<Array2<T>>
658where
659    T: Field + Clone + bytemuck::Zeroable,
660{
661    let (m, n) = a.dim();
662    let (b_rows, _b_cols) = b.dim();
663
664    if m != n {
665        return Err(LapackError::DimensionMismatch(
666            "Matrix must be square".to_string(),
667        ));
668    }
669    if b_rows != n {
670        return Err(LapackError::DimensionMismatch(
671            "b rows must match matrix dimension".to_string(),
672        ));
673    }
674
675    let a_mat = array2_to_mat(a);
676    let b_mat = array2_to_mat(b);
677
678    match solve::solve_multiple(a_mat.as_ref(), b_mat.as_ref()) {
679        Ok(x_mat) => Ok(mat_to_array2(&x_mat)),
680        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
681    }
682}
683
684/// Solves the least squares problem min ||Ax - b||.
685pub fn lstsq_ndarray<T>(a: &Array2<T>, b: &Array1<T>) -> LapackResult<Array1<T>>
686where
687    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
688{
689    let m = a.dim().0;
690    let a_mat = array2_to_mat(a);
691    // Convert b to a column vector matrix
692    let mut b_mat: Mat<T> = Mat::zeros(m, 1);
693    for i in 0..m {
694        b_mat[(i, 0)] = b[i];
695    }
696
697    match solve::lstsq(a_mat.as_ref(), b_mat.as_ref()) {
698        Ok(result) => {
699            // Extract solution column vector
700            let n = result.solution.nrows();
701            let x: Vec<T> = (0..n).map(|i| result.solution[(i, 0)]).collect();
702            Ok(Array1::from_vec(x))
703        }
704        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
705    }
706}
707
708// =============================================================================
709// Matrix Inverse
710// =============================================================================
711
712/// Computes the inverse of a matrix.
713pub fn inv_ndarray<T>(a: &Array2<T>) -> LapackResult<Array2<T>>
714where
715    T: Field + Clone + bytemuck::Zeroable,
716{
717    let (m, n) = a.dim();
718    if m != n {
719        return Err(LapackError::DimensionMismatch(
720            "Matrix must be square".to_string(),
721        ));
722    }
723
724    let a_mat = array2_to_mat(a);
725
726    match oxiblas_lapack::utils::inv(a_mat.as_ref()) {
727        Ok(inv_mat) => Ok(mat_to_array2(&inv_mat)),
728        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
729    }
730}
731
732/// Computes the Moore-Penrose pseudo-inverse.
733pub fn pinv_ndarray<T>(a: &Array2<T>) -> LapackResult<Array2<T>>
734where
735    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
736{
737    let a_mat = array2_to_mat(a);
738
739    match oxiblas_lapack::utils::pinv_default(a_mat.as_ref()) {
740        Ok(result) => Ok(mat_to_array2(&result.pinv)),
741        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
742    }
743}
744
745// =============================================================================
746// Determinant
747// =============================================================================
748
749/// Computes the determinant of a matrix.
750pub fn det_ndarray<T>(a: &Array2<T>) -> LapackResult<T>
751where
752    T: Field + Clone + bytemuck::Zeroable,
753{
754    let (m, n) = a.dim();
755    if m != n {
756        return Err(LapackError::DimensionMismatch(
757            "Matrix must be square".to_string(),
758        ));
759    }
760
761    let a_mat = array2_to_mat(a);
762
763    match oxiblas_lapack::utils::det(a_mat.as_ref()) {
764        Ok(d) => Ok(d),
765        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
766    }
767}
768
769// =============================================================================
770// Condition Number
771// =============================================================================
772
773/// Computes the condition number of a matrix (using 2-norm).
774pub fn cond_ndarray<T>(a: &Array2<T>) -> LapackResult<T>
775where
776    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
777{
778    let svd_result = svd_ndarray(a)?;
779    let n = svd_result.s.len();
780
781    if n == 0 {
782        return Ok(T::one());
783    }
784
785    let sigma_max = svd_result.s[0];
786    let sigma_min = svd_result.s[n - 1];
787
788    // Check if sigma_min is very small
789    if sigma_min == T::zero() {
790        // Return a large number to indicate ill-conditioning
791        Ok(T::from_f64(1e15).unwrap_or(T::one()))
792    } else {
793        Ok(sigma_max / sigma_min)
794    }
795}
796
797// =============================================================================
798// Rank
799// =============================================================================
800
801/// Computes the numerical rank of a matrix.
802pub fn rank_ndarray<T>(a: &Array2<T>) -> LapackResult<usize>
803where
804    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
805{
806    let (m, n) = a.dim();
807    let svd_result = svd_ndarray(a)?;
808
809    if svd_result.s.is_empty() {
810        return Ok(0);
811    }
812
813    // Default tolerance: max(m,n) * eps * sigma_max
814    let sigma_max = svd_result.s[0];
815    let eps = T::from_f64(1e-14).unwrap_or(T::zero());
816    let dim_scale = T::from_f64(m.max(n) as f64).unwrap_or(T::one());
817    let tol = dim_scale * eps * sigma_max;
818
819    Ok(svd_result.rank(tol))
820}
821
822// =============================================================================
823// Randomized SVD
824// =============================================================================
825
826/// Result of randomized SVD.
827#[derive(Debug, Clone)]
828pub struct RandomizedSvdResult<T> {
829    /// Left singular vectors U (m × k)
830    pub u: Array2<T>,
831    /// Singular values σ (k elements, sorted descending)
832    pub s: Array1<T>,
833    /// Right singular vectors V (n × k), NOT V^T
834    pub v: Array2<T>,
835}
836
837/// Computes a randomized SVD approximation of a matrix.
838///
839/// Uses randomized projections to compute a rank-k approximation efficiently,
840/// particularly useful for large matrices where only the top singular values
841/// are needed.
842///
843/// # Arguments
844/// * `a` - The input matrix (m×n)
845/// * `k` - Target rank (number of singular values to compute)
846///
847/// # Returns
848/// Truncated SVD with k singular values and vectors
849///
850/// # Algorithm
851/// Uses the Halko-Martinsson-Tropp randomized algorithm:
852/// 1. Generate random test matrix Ω
853/// 2. Compute Y = A × Ω to sample column space
854/// 3. QR factorize Y to get orthonormal basis Q
855/// 4. Project B = Q^T × A
856/// 5. Compute full SVD of B
857/// 6. Recover U = Q × Ũ
858pub fn rsvd_ndarray<T>(a: &Array2<T>, k: usize) -> LapackResult<RandomizedSvdResult<T>>
859where
860    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
861{
862    let mat = array2_to_mat(a);
863
864    match svd::RandomizedSvd::compute(mat.as_ref(), k) {
865        Ok(rsvd) => {
866            let u = mat_ref_to_array2(rsvd.u());
867            let s = Array1::from_vec(rsvd.singular_values().to_vec());
868            let v = mat_ref_to_array2(rsvd.v());
869
870            Ok(RandomizedSvdResult { u, s, v })
871        }
872        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
873    }
874}
875
876/// Computes randomized SVD with power iteration for improved accuracy.
877///
878/// Power iteration emphasizes dominant singular values and improves accuracy
879/// for matrices with slowly decaying singular values.
880///
881/// # Arguments
882/// * `a` - The input matrix (m×n)
883/// * `k` - Target rank
884/// * `power_iterations` - Number of power iterations (typically 1-3)
885pub fn rsvd_power_ndarray<T>(
886    a: &Array2<T>,
887    k: usize,
888    power_iterations: usize,
889) -> LapackResult<RandomizedSvdResult<T>>
890where
891    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
892{
893    let mat = array2_to_mat(a);
894
895    let config = svd::RandomizedSvdConfig::new(k).with_power_iterations(power_iterations);
896
897    match svd::RandomizedSvd::compute_with_config(mat.as_ref(), config) {
898        Ok(rsvd) => {
899            let u = mat_ref_to_array2(rsvd.u());
900            let s = Array1::from_vec(rsvd.singular_values().to_vec());
901            let v = mat_ref_to_array2(rsvd.v());
902
903            Ok(RandomizedSvdResult { u, s, v })
904        }
905        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
906    }
907}
908
909// =============================================================================
910// Schur Decomposition
911// =============================================================================
912
913/// Result of Schur decomposition.
914#[derive(Debug, Clone)]
915pub struct SchurResult<T> {
916    /// Orthogonal matrix Q (Schur vectors)
917    pub q: Array2<T>,
918    /// Quasi-upper triangular matrix T (Schur form)
919    pub t: Array2<T>,
920    /// Eigenvalues (real and complex pairs)
921    pub eigenvalues: Vec<Eigenvalue<T>>,
922}
923
924/// Computes the real Schur decomposition of a square matrix.
925///
926/// A = Q T Q^T where:
927/// - Q is orthogonal (Q^T Q = I)
928/// - T is quasi-upper triangular (upper triangular with possible 2×2 blocks
929///   on the diagonal for complex eigenvalue pairs)
930///
931/// # Arguments
932/// * `a` - The input square matrix (n×n)
933///
934/// # Returns
935/// Schur decomposition with Q, T, and eigenvalues
936pub fn schur_ndarray<T>(a: &Array2<T>) -> LapackResult<SchurResult<T>>
937where
938    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
939{
940    let (m, n) = a.dim();
941    if m != n {
942        return Err(LapackError::DimensionMismatch(
943            "Matrix must be square".to_string(),
944        ));
945    }
946
947    let mat = array2_to_mat(a);
948
949    match evd::Schur::compute(mat.as_ref()) {
950        Ok(schur) => {
951            let q = mat_ref_to_array2(schur.q());
952            let t = mat_ref_to_array2(schur.t());
953            let eigenvalues = schur.eigenvalues().to_vec();
954
955            Ok(SchurResult { q, t, eigenvalues })
956        }
957        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
958    }
959}
960
961// =============================================================================
962// General Eigenvalue Decomposition
963// =============================================================================
964
/// Result of a general (non-symmetric) eigenvalue decomposition, as returned
/// by [`eig_ndarray`].
///
/// A real matrix may have complex eigenvalues and eigenvectors, so every
/// eigenvector matrix is split into separate real and imaginary parts. Each
/// eigenvector field is `None` when the backend did not provide that factor.
#[derive(Debug, Clone)]
pub struct GeneralEvdResult<T> {
    /// Eigenvalues, each stored as a real/imaginary pair
    pub eigenvalues: Vec<Eigenvalue<T>>,
    /// Right eigenvectors (real parts), if computed
    pub eigenvectors_real: Option<Array2<T>>,
    /// Right eigenvectors (imaginary parts), if computed
    pub eigenvectors_imag: Option<Array2<T>>,
    /// Left eigenvectors (real parts), if computed
    pub left_eigenvectors_real: Option<Array2<T>>,
    /// Left eigenvectors (imaginary parts), if computed
    pub left_eigenvectors_imag: Option<Array2<T>>,
}
979
980/// Computes eigenvalues of a general (non-symmetric) matrix.
981///
982/// For a real matrix, eigenvalues may be complex. They are returned as
983/// real/imaginary pairs. Eigenvectors are also split into real and
984/// imaginary parts.
985///
986/// # Arguments
987/// * `a` - The input square matrix (n×n)
988///
989/// # Returns
990/// Eigenvalues and eigenvectors (split into real/imaginary parts)
991pub fn eig_ndarray<T>(a: &Array2<T>) -> LapackResult<GeneralEvdResult<T>>
992where
993    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
994{
995    let (m, n) = a.dim();
996    if m != n {
997        return Err(LapackError::DimensionMismatch(
998            "Matrix must be square".to_string(),
999        ));
1000    }
1001
1002    let mat = array2_to_mat(a);
1003
1004    match evd::GeneralEvd::compute(mat.as_ref()) {
1005        Ok(evd_result) => {
1006            let eigenvalues = evd_result.eigenvalues().to_vec();
1007
1008            // Get right eigenvectors (real and imaginary parts)
1009            let eigenvectors_real = evd_result
1010                .eigenvectors_real()
1011                .map(|vr| mat_ref_to_array2(vr));
1012
1013            let eigenvectors_imag = evd_result
1014                .eigenvectors_imag()
1015                .map(|vi| mat_ref_to_array2(vi));
1016
1017            // Get left eigenvectors (real and imaginary parts)
1018            let left_eigenvectors_real = evd_result
1019                .left_eigenvectors_real()
1020                .map(|vl| mat_ref_to_array2(vl));
1021
1022            let left_eigenvectors_imag = evd_result
1023                .left_eigenvectors_imag()
1024                .map(|vl| mat_ref_to_array2(vl));
1025
1026            Ok(GeneralEvdResult {
1027                eigenvalues,
1028                eigenvectors_real,
1029                eigenvectors_imag,
1030                left_eigenvectors_real,
1031                left_eigenvectors_imag,
1032            })
1033        }
1034        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
1035    }
1036}
1037
1038/// Computes only the eigenvalues of a general matrix.
1039pub fn eigvals_ndarray<T>(a: &Array2<T>) -> LapackResult<Vec<Eigenvalue<T>>>
1040where
1041    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1042{
1043    let (m, n) = a.dim();
1044    if m != n {
1045        return Err(LapackError::DimensionMismatch(
1046            "Matrix must be square".to_string(),
1047        ));
1048    }
1049
1050    let mat = array2_to_mat(a);
1051
1052    match evd::GeneralEvd::eigenvalues_only(mat.as_ref()) {
1053        Ok(evd_result) => Ok(evd_result.eigenvalues().to_vec()),
1054        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
1055    }
1056}
1057
1058// =============================================================================
1059// Tridiagonal Solvers
1060// =============================================================================
1061
1062/// Solves a tridiagonal system of equations.
1063///
1064/// Solves T x = b where T is a tridiagonal matrix.
1065///
1066/// # Arguments
1067/// * `dl` - Lower diagonal (n-1 elements)
1068/// * `d` - Main diagonal (n elements)
1069/// * `du` - Upper diagonal (n-1 elements)
1070/// * `b` - Right-hand side vector (n elements)
1071///
1072/// # Returns
1073/// The solution vector x
1074pub fn tridiag_solve_ndarray<T>(
1075    dl: &Array1<T>,
1076    d: &Array1<T>,
1077    du: &Array1<T>,
1078    b: &Array1<T>,
1079) -> LapackResult<Array1<T>>
1080where
1081    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1082{
1083    let n = d.len();
1084
1085    if dl.len() != n - 1 || du.len() != n - 1 || b.len() != n {
1086        return Err(LapackError::DimensionMismatch(
1087            "Tridiagonal dimensions must be consistent".to_string(),
1088        ));
1089    }
1090
1091    let dl_vec: Vec<T> = dl.iter().cloned().collect();
1092    let d_vec: Vec<T> = d.iter().cloned().collect();
1093    let du_vec: Vec<T> = du.iter().cloned().collect();
1094    let b_vec: Vec<T> = b.iter().cloned().collect();
1095
1096    match solve::tridiag_solve(&dl_vec, &d_vec, &du_vec, &b_vec) {
1097        Ok(x) => Ok(Array1::from_vec(x)),
1098        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
1099    }
1100}
1101
1102/// Solves a symmetric positive definite tridiagonal system.
1103///
1104/// Solves T x = b where T is symmetric positive definite and tridiagonal.
1105/// Uses specialized algorithm that's more efficient for SPD matrices.
1106///
1107/// # Arguments
1108/// * `d` - Main diagonal (n elements, positive)
1109/// * `e` - Off-diagonal (n-1 elements)
1110/// * `b` - Right-hand side vector (n elements)
1111///
1112/// # Returns
1113/// The solution vector x
1114pub fn tridiag_solve_spd_ndarray<T>(
1115    d: &Array1<T>,
1116    e: &Array1<T>,
1117    b: &Array1<T>,
1118) -> LapackResult<Array1<T>>
1119where
1120    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1121{
1122    let n = d.len();
1123
1124    if e.len() != n - 1 || b.len() != n {
1125        return Err(LapackError::DimensionMismatch(
1126            "Tridiagonal dimensions must be consistent".to_string(),
1127        ));
1128    }
1129
1130    let d_vec: Vec<T> = d.iter().cloned().collect();
1131    let e_vec: Vec<T> = e.iter().cloned().collect();
1132    let b_vec: Vec<T> = b.iter().cloned().collect();
1133
1134    match solve::tridiag_solve_spd(&d_vec, &e_vec, &b_vec) {
1135        Ok(x) => Ok(Array1::from_vec(x)),
1136        Err(e) => Err(LapackError::NotPositiveDefinite(format!("{e:?}"))),
1137    }
1138}
1139
1140/// Solves multiple tridiagonal systems with the same matrix.
1141///
1142/// # Arguments
1143/// * `dl` - Lower diagonal (n-1 elements)
1144/// * `d` - Main diagonal (n elements)
1145/// * `du` - Upper diagonal (n-1 elements)
1146/// * `b` - Right-hand side matrix (n × nrhs)
1147///
1148/// # Returns
1149/// The solution matrix X (n × nrhs)
1150pub fn tridiag_solve_multiple_ndarray<T>(
1151    dl: &Array1<T>,
1152    d: &Array1<T>,
1153    du: &Array1<T>,
1154    b: &Array2<T>,
1155) -> LapackResult<Array2<T>>
1156where
1157    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1158{
1159    let n = d.len();
1160    let (b_rows, _b_cols) = b.dim();
1161
1162    if dl.len() != n - 1 || du.len() != n - 1 || b_rows != n {
1163        return Err(LapackError::DimensionMismatch(
1164            "Tridiagonal dimensions must be consistent".to_string(),
1165        ));
1166    }
1167
1168    let dl_vec: Vec<T> = dl.iter().cloned().collect();
1169    let d_vec: Vec<T> = d.iter().cloned().collect();
1170    let du_vec: Vec<T> = du.iter().cloned().collect();
1171    let b_mat = array2_to_mat(b);
1172
1173    match solve::tridiag_solve_multiple(&dl_vec, &d_vec, &du_vec, b_mat.as_ref()) {
1174        Ok(x_mat) => Ok(mat_to_array2(&x_mat)),
1175        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
1176    }
1177}
1178
1179// =============================================================================
1180// Low-Rank Approximation
1181// =============================================================================
1182
1183/// Computes a low-rank approximation of a matrix.
1184///
1185/// Returns A_k = U_k Σ_k V_k^T, the best rank-k approximation in Frobenius norm.
1186///
1187/// # Arguments
1188/// * `a` - The input matrix (m×n)
1189/// * `k` - Target rank
1190///
1191/// # Returns
1192/// The rank-k approximation as a matrix
1193pub fn low_rank_approx_ndarray<T>(a: &Array2<T>, k: usize) -> LapackResult<Array2<T>>
1194where
1195    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1196{
1197    let mat = array2_to_mat(a);
1198
1199    match svd::low_rank_approximation(mat.as_ref(), k) {
1200        Ok(approx) => Ok(mat_to_array2(&approx)),
1201        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
1202    }
1203}
1204
// Unit tests for the ndarray LAPACK wrappers. Most checks rebuild the input
// (or an identity matrix) from the returned factors and compare element-wise
// against a small absolute tolerance.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_lu_decomposition() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let lu = lu_ndarray(&a).unwrap();

        // Verify L * U ≈ P * A: row i of L*U is compared with row r of A,
        // where r is the index for which perm[r] == i.
        let n = a.dim().0;
        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0f64;
                for k in 0..n {
                    sum += lu.l[[i, k]] * lu.u[[k, j]];
                }
                let perm_i = lu.perm.iter().position(|&p| p == i).unwrap();
                assert!((sum - a[[perm_i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_lu_determinant() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let lu = lu_ndarray(&a).unwrap();
        let det = lu.det();
        // det = 2*3 - 1*1 = 5
        assert!((det - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let b = array![5.0f64, 7.0];
        let lu = lu_ndarray(&a).unwrap();
        let x = lu.solve(&b);

        // Verify A * x ≈ b
        let ax0 = a[[0, 0]] * x[0] + a[[0, 1]] * x[1];
        let ax1 = a[[1, 0]] * x[0] + a[[1, 1]] * x[1];
        assert!((ax0 - b[0]).abs() < 1e-10);
        assert!((ax1 - b[1]).abs() < 1e-10);
    }

    #[test]
    fn test_qr_decomposition() {
        let a = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let qr = qr_ndarray(&a).unwrap();

        // Q should have orthonormal columns: Q^T * Q = I
        let qt = qr.q.t();
        let qtq = crate::blas::matmul(&qt.to_owned(), &qr.q);
        for i in 0..qtq.dim().0 {
            for j in 0..qtq.dim().1 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[[i, j]] - expected).abs() < 1e-10,
                    "Q^T Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    qtq[[i, j]],
                    expected
                );
            }
        }

        // Q * R should equal A
        let qr_product = crate::blas::matmul(&qr.q, &qr.r);
        for i in 0..a.dim().0 {
            for j in 0..a.dim().1 {
                assert!(
                    (qr_product[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "QR[{},{}] = {}, A = {}",
                    i,
                    j,
                    qr_product[[i, j]],
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_svd() {
        let a = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let svd = svd_ndarray(&a).unwrap();

        // Reconstruct A from SVD: U * S * V^T, summing over the k singular values
        let (m, n) = a.dim();
        let k = svd.s.len();

        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0f64;
                for l in 0..k {
                    sum += svd.u[[i, l]] * svd.s[l] * svd.vt[[l, j]];
                }
                assert!(
                    (sum - a[[i, j]]).abs() < 1e-10,
                    "Reconstructed[{},{}] = {}, A = {}",
                    i,
                    j,
                    sum,
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_symmetric_evd() {
        // Symmetric matrix
        let a = array![[4.0f64, 1.0], [1.0, 3.0]];
        let evd = eig_symmetric(&a).unwrap();

        // Eigenvalues should be real and positive for this matrix
        assert!(evd.eigenvalues.len() == 2);

        // Verify A * v = λ * v for each eigenvalue and its eigenvector column
        for (idx, &lambda) in evd.eigenvalues.iter().enumerate() {
            let v = evd.eigenvectors.column(idx);
            let av = crate::blas::matvec(&a, &v.to_owned());
            let lambda_v: Array1<f64> = v.iter().map(|&x| lambda * x).collect();

            for i in 0..2 {
                assert!(
                    (av[i] - lambda_v[i]).abs() < 1e-10,
                    "Av[{}] = {}, λv[{}] = {}",
                    i,
                    av[i],
                    i,
                    lambda_v[i]
                );
            }
        }
    }

    #[test]
    fn test_cholesky() {
        // Positive definite matrix
        let a = array![[4.0f64, 2.0], [2.0, 5.0]];
        let chol = cholesky_ndarray(&a).unwrap();

        // Verify L * L^T = A
        let lt = chol.l.t();
        let llt = crate::blas::matmul(&chol.l, &lt.to_owned());

        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (llt[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "LLT[{},{}] = {}, A = {}",
                    i,
                    j,
                    llt[[i, j]],
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_solve() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let b = array![5.0f64, 7.0];
        let x = solve_ndarray(&a, &b).unwrap();

        // Verify A * x = b
        let ax = crate::blas::matvec(&a, &x);
        assert!((ax[0] - b[0]).abs() < 1e-10);
        assert!((ax[1] - b[1]).abs() < 1e-10);
    }

    #[test]
    fn test_inverse() {
        let a = array![[4.0f64, 7.0], [2.0, 6.0]];
        let a_inv = inv_ndarray(&a).unwrap();

        // A * A^-1 = I
        let product = crate::blas::matmul(&a, &a_inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (product[[i, j]] - expected).abs() < 1e-10,
                    "A*A^-1[{},{}] = {}, expected {}",
                    i,
                    j,
                    product[[i, j]],
                    expected
                );
            }
        }
    }

    #[test]
    fn test_determinant() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let det = det_ndarray(&a).unwrap();
        // det = 2*3 - 1*1 = 5
        assert!((det - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_condition_number() {
        let a = array![[1.0f64, 0.0], [0.0, 1.0]];
        let cond = cond_ndarray(&a).unwrap();
        // Identity matrix has condition number 1
        assert!((cond - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank() {
        // Full rank matrix
        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
        let r = rank_ndarray(&a).unwrap();
        assert_eq!(r, 2);

        // Rank deficient matrix (second row is twice the first)
        let b = array![[1.0f64, 2.0], [2.0, 4.0]];
        let r2 = rank_ndarray(&b).unwrap();
        assert_eq!(r2, 1);
    }

    // =========================================================================
    // Randomized SVD Tests
    // =========================================================================

    #[test]
    fn test_rsvd_basic() {
        // Create a low-rank matrix (rows are in arithmetic progression -> rank 2)
        let a = array![
            [1.0f64, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0]
        ];

        let rsvd = rsvd_ndarray(&a, 2).unwrap();

        // Should have 2 singular values
        assert_eq!(rsvd.s.len(), 2);

        // Singular values should be non-negative and in descending order
        assert!(rsvd.s[0] > rsvd.s[1]);
        assert!(rsvd.s[1] >= 0.0);

        // U should be m×k
        assert_eq!(rsvd.u.dim(), (3, 2));

        // V should be n×k
        assert_eq!(rsvd.v.dim(), (4, 2));
    }

    #[test]
    fn test_rsvd_approximation_quality() {
        // Create a matrix with clear rank structure
        let a = Array2::from_shape_fn((10, 8), |(i, j)| (i as f64) * 0.1 + (j as f64) * 0.2);

        let rsvd = rsvd_ndarray(&a, 2).unwrap();

        // Reconstruct: A ≈ U * S * V^T (V is n×k, so V^T[l, j] is v[[j, l]])
        let (m, n) = a.dim();
        let k = rsvd.s.len();

        let mut approx: Array2<f64> = Array2::zeros((m, n));
        for i in 0..m {
            for j in 0..n {
                for l in 0..k {
                    approx[[i, j]] += rsvd.u[[i, l]] * rsvd.s[l] * rsvd.v[[j, l]];
                }
            }
        }

        // a[i][j] = 0.1*i + 0.2*j is the sum of two rank-1 terms, so the matrix
        // has rank 2 and a rank-2 sketch should reproduce it essentially exactly.
        let mut diff_norm = 0.0f64;
        for i in 0..m {
            for j in 0..n {
                let diff = a[[i, j]] - approx[[i, j]];
                diff_norm += diff.powi(2);
            }
        }
        diff_norm = diff_norm.sqrt();

        // Frobenius-norm error should be at rounding level
        assert!(diff_norm < 1e-10, "Reconstruction error = {}", diff_norm);
    }

    #[test]
    fn test_rsvd_power_iteration() {
        let a = Array2::from_shape_fn((20, 15), |(i, j)| ((i * j) as f64).sin() + 0.1 * (i as f64));

        let rsvd = rsvd_power_ndarray(&a, 3, 2).unwrap();

        // Three singular values, in non-increasing order
        assert_eq!(rsvd.s.len(), 3);
        assert!(rsvd.s[0] >= rsvd.s[1]);
        assert!(rsvd.s[1] >= rsvd.s[2]);
    }

    // =========================================================================
    // Schur Decomposition Tests
    // =========================================================================

    #[test]
    fn test_schur_triangular() {
        // Already upper triangular matrix
        let a = array![[1.0f64, 2.0], [0.0, 3.0]];

        let schur = schur_ndarray(&a).unwrap();

        // Eigenvalues should be 1 and 3 (the diagonal entries)
        assert_eq!(schur.eigenvalues.len(), 2);

        // Accept either an exact match or one within 1e-10 of the expected value
        let evs: Vec<f64> = schur.eigenvalues.iter().map(|e| e.real).collect();
        assert!(evs.contains(&1.0) || evs.iter().any(|&x| (x - 1.0).abs() < 1e-10));
        assert!(evs.contains(&3.0) || evs.iter().any(|&x| (x - 3.0).abs() < 1e-10));
    }

    #[test]
    fn test_schur_reconstruction() {
        let a = array![[4.0f64, 1.0], [2.0, 3.0]];

        let schur = schur_ndarray(&a).unwrap();

        // Verify A = Q * T * Q^T
        let qt = schur.q.t();
        let qt_owned = qt.to_owned();
        let qr_temp = crate::blas::matmul(&schur.q, &schur.t);
        let reconstructed = crate::blas::matmul(&qr_temp, &qt_owned);

        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "Reconstruction failed at [{},{}]: {} vs {}",
                    i,
                    j,
                    reconstructed[[i, j]],
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_schur_orthogonality() {
        let a = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]];

        let schur = schur_ndarray(&a).unwrap();

        // Q should be orthogonal: Q^T * Q = I
        let qt = schur.q.t();
        let qtq = crate::blas::matmul(&qt.to_owned(), &schur.q);

        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[[i, j]] - expected).abs() < 1e-10,
                    "Q^T Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    qtq[[i, j]],
                    expected
                );
            }
        }
    }

    // =========================================================================
    // General Eigenvalue Decomposition Tests
    // =========================================================================

    #[test]
    fn test_eig_real_eigenvalues() {
        // Symmetric matrix has real eigenvalues
        let a = array![[4.0f64, 1.0], [1.0, 3.0]];

        let evd = eig_ndarray(&a).unwrap();

        assert_eq!(evd.eigenvalues.len(), 2);

        // All eigenvalues should be real (imaginary part ≈ 0)
        for ev in &evd.eigenvalues {
            assert!(
                ev.imag.abs() < 1e-10,
                "Expected real eigenvalue, got imag = {}",
                ev.imag
            );
        }
    }

    #[test]
    fn test_eig_complex_eigenvalues() {
        // 90° rotation matrix has purely imaginary eigenvalues (±i)
        let a = array![[0.0f64, -1.0], [1.0, 0.0]];

        let evd = eig_ndarray(&a).unwrap();

        assert_eq!(evd.eigenvalues.len(), 2);

        // Should have eigenvalues with nonzero imaginary parts
        let has_complex = evd.eigenvalues.iter().any(|e| e.imag.abs() > 0.5);
        assert!(has_complex, "Expected complex eigenvalues");

        // Real parts should be close to 0
        for ev in &evd.eigenvalues {
            assert!(ev.real.abs() < 1e-10, "Expected real part ≈ 0");
        }
    }

    #[test]
    fn test_eigvals_only() {
        let a = array![[1.0f64, 2.0], [0.0, 3.0]];

        let evs = eigvals_ndarray(&a).unwrap();

        assert_eq!(evs.len(), 2);

        // Eigenvalues of upper triangular matrix are diagonal elements
        let reals: Vec<f64> = evs.iter().map(|e| e.real).collect();
        assert!(reals.iter().any(|&x| (x - 1.0).abs() < 1e-10));
        assert!(reals.iter().any(|&x| (x - 3.0).abs() < 1e-10));
    }

    // =========================================================================
    // Tridiagonal Solver Tests
    // =========================================================================

    #[test]
    fn test_tridiag_solve() {
        // Tridiagonal matrix:
        // [2  -1  0 ]   [x0]   [1]
        // [-1  2 -1 ] * [x1] = [0]
        // [0  -1  2 ]   [x2]   [1]
        let dl = array![-1.0f64, -1.0];
        let d = array![2.0f64, 2.0, 2.0];
        let du = array![-1.0f64, -1.0];
        let b = array![1.0f64, 0.0, 1.0];

        let x = tridiag_solve_ndarray(&dl, &d, &du, &b).unwrap();

        assert_eq!(x.len(), 3);

        // Verify solution: T * x ≈ b
        let tx0 = d[0] * x[0] + du[0] * x[1];
        let tx1 = dl[0] * x[0] + d[1] * x[1] + du[1] * x[2];
        let tx2 = dl[1] * x[1] + d[2] * x[2];

        assert!((tx0 - b[0]).abs() < 1e-10);
        assert!((tx1 - b[1]).abs() < 1e-10);
        assert!((tx2 - b[2]).abs() < 1e-10);
    }

    #[test]
    fn test_tridiag_solve_spd() {
        // SPD tridiagonal matrix:
        // [4 1 0]
        // [1 4 1]
        // [0 1 4]
        // Symmetric, strictly diagonally dominant, positive diagonal -> SPD
        let d = array![4.0f64, 4.0, 4.0];
        let e = array![1.0f64, 1.0]; // Off-diagonal elements
        let b = array![5.0f64, 6.0, 5.0];

        let x = tridiag_solve_spd_ndarray(&d, &e, &b).unwrap();

        assert_eq!(x.len(), 3);

        // Verify solution: T * x = b where T is symmetric with d on diagonal, e on off-diagonals
        let tx0 = d[0] * x[0] + e[0] * x[1];
        let tx1 = e[0] * x[0] + d[1] * x[1] + e[1] * x[2];
        let tx2 = e[1] * x[1] + d[2] * x[2];

        assert!((tx0 - b[0]).abs() < 1e-10, "tx0 = {}, b[0] = {}", tx0, b[0]);
        assert!((tx1 - b[1]).abs() < 1e-10, "tx1 = {}, b[1] = {}", tx1, b[1]);
        assert!((tx2 - b[2]).abs() < 1e-10, "tx2 = {}, b[2] = {}", tx2, b[2]);
    }

    #[test]
    fn test_tridiag_solve_multiple() {
        let dl = array![-1.0f64, -1.0];
        let d = array![2.0f64, 2.0, 2.0];
        let du = array![-1.0f64, -1.0];
        let b = array![[1.0f64, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let x = tridiag_solve_multiple_ndarray(&dl, &d, &du, &b).unwrap();

        assert_eq!(x.dim(), (3, 2));

        // Each column should be the solution to T * x_j = b_j
        for j in 0..2 {
            let tx0 = d[0] * x[[0, j]] + du[0] * x[[1, j]];
            let tx1 = dl[0] * x[[0, j]] + d[1] * x[[1, j]] + du[1] * x[[2, j]];
            let tx2 = dl[1] * x[[1, j]] + d[2] * x[[2, j]];

            assert!((tx0 - b[[0, j]]).abs() < 1e-10);
            assert!((tx1 - b[[1, j]]).abs() < 1e-10);
            assert!((tx2 - b[[2, j]]).abs() < 1e-10);
        }
    }

    // =========================================================================
    // Low-Rank Approximation Tests
    // =========================================================================

    #[test]
    fn test_low_rank_approx() {
        // Create a rank-1 matrix: outer product of two vectors
        let u = array![1.0f64, 2.0, 3.0];
        let v = array![4.0, 5.0, 6.0, 7.0];

        let mut a = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                a[[i, j]] = u[i] * v[j];
            }
        }

        // Rank-1 approximation should be exact
        let approx = low_rank_approx_ndarray(&a, 1).unwrap();

        assert_eq!(approx.dim(), a.dim());

        for i in 0..3 {
            for j in 0..4 {
                assert!(
                    (approx[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "Approximation failed at [{},{}]",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_low_rank_approx_truncation() {
        let a = array![
            [1.0f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let approx = low_rank_approx_ndarray(&a, 2).unwrap();

        assert_eq!(approx.dim(), (4, 3));

        // Row i equals 3i*[1, 1, 1] + [1, 2, 3], so A has rank 2 and the rank-2
        // approximation should reproduce it closely. The loose bound below only
        // guards against truncation discarding the dominant structure.
        let mut diff_norm = 0.0f64;
        let mut orig_norm = 0.0f64;
        for i in 0..4 {
            for j in 0..3 {
                diff_norm += (a[[i, j]] - approx[[i, j]]).powi(2);
                orig_norm += a[[i, j]].powi(2);
            }
        }

        // Relative error should be small (this matrix has rank 2)
        let rel_error = diff_norm.sqrt() / orig_norm.sqrt();
        assert!(rel_error < 0.1, "Relative error = {}", rel_error);
    }
}