oxiblas_ndarray/
lapack.rs

1//! LAPACK decompositions and operations on ndarray types.
2//!
3//! This module provides LAPACK decompositions (LU, QR, SVD, EVD, Cholesky)
4//! directly on ndarray types, using OxiBLAS-LAPACK as the backend.
5
6use crate::conversions::{array2_to_mat, mat_ref_to_array2, mat_to_array2};
7use ndarray::{Array1, Array2};
8use oxiblas_core::scalar::Field;
9use oxiblas_lapack::{cholesky, evd, lu, qr, solve, svd};
10use oxiblas_matrix::Mat;
11
12// Re-export useful types
13pub use evd::Eigenvalue;
14
15// =============================================================================
16// Error Types
17// =============================================================================
18
/// Error type for LAPACK operations on ndarray.
///
/// Every variant carries a human-readable detail string. In this module that
/// string is either a fixed message (for dimension checks) or the `Debug`
/// rendering of the error returned by the OxiBLAS-LAPACK backend.
///
/// `PartialEq`/`Eq` are derived so callers can compare errors directly
/// (e.g. `assert_eq!(err, LapackError::DimensionMismatch(..))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LapackError {
    /// Matrix is singular or nearly singular
    Singular(String),
    /// Matrix is not positive definite
    NotPositiveDefinite(String),
    /// Dimension mismatch
    DimensionMismatch(String),
    /// Decomposition did not converge
    NotConverged(String),
    /// Other error
    Other(String),
}

impl std::fmt::Display for LapackError {
    // Renders as "<category>: <detail>".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Singular(msg) => write!(f, "Singular matrix: {msg}"),
            Self::NotPositiveDefinite(msg) => write!(f, "Not positive definite: {msg}"),
            Self::DimensionMismatch(msg) => write!(f, "Dimension mismatch: {msg}"),
            Self::NotConverged(msg) => write!(f, "Did not converge: {msg}"),
            Self::Other(msg) => write!(f, "LAPACK error: {msg}"),
        }
    }
}

impl std::error::Error for LapackError {}

/// Result type for LAPACK operations.
pub type LapackResult<T> = Result<T, LapackError>;
50
51// =============================================================================
52// LU Decomposition
53// =============================================================================
54
55/// Result of LU decomposition.
56#[derive(Debug, Clone)]
57pub struct LuResult<T> {
58    /// L factor (lower triangular with unit diagonal)
59    pub l: Array2<T>,
60    /// U factor (upper triangular)
61    pub u: Array2<T>,
62    /// Permutation vector
63    pub perm: Vec<usize>,
64}
65
66impl<T: Field + Clone> LuResult<T>
67where
68    T: bytemuck::Zeroable,
69{
70    /// Solves Ax = b using the LU decomposition.
71    pub fn solve(&self, b: &Array1<T>) -> Array1<T> {
72        let n = self.l.dim().0;
73        assert_eq!(b.len(), n, "b length must match matrix dimension");
74
75        // Apply permutation to b
76        let mut pb: Vec<T> = vec![T::zero(); n];
77        for i in 0..n {
78            pb[i] = b[self.perm[i]];
79        }
80
81        // Forward substitution: L * y = pb
82        let mut y: Vec<T> = vec![T::zero(); n];
83        for i in 0..n {
84            let mut sum = pb[i];
85            for j in 0..i {
86                sum -= self.l[[i, j]] * y[j];
87            }
88            y[i] = sum;
89        }
90
91        // Back substitution: U * x = y
92        let mut x: Vec<T> = vec![T::zero(); n];
93        for i in (0..n).rev() {
94            let mut sum = y[i];
95            for j in (i + 1)..n {
96                sum -= self.u[[i, j]] * x[j];
97            }
98            x[i] = sum / self.u[[i, i]];
99        }
100
101        Array1::from_vec(x)
102    }
103
104    /// Computes the determinant.
105    pub fn det(&self) -> T {
106        let n = self.l.dim().0;
107        let mut det = T::one();
108
109        // Product of U diagonal elements
110        for i in 0..n {
111            det *= self.u[[i, i]];
112        }
113
114        // Account for permutation sign
115        let mut sign_changes = 0;
116        let mut visited = vec![false; n];
117        for i in 0..n {
118            if visited[i] {
119                continue;
120            }
121            let mut j = i;
122            let mut cycle_len = 0;
123            while !visited[j] {
124                visited[j] = true;
125                j = self.perm[j];
126                cycle_len += 1;
127            }
128            if cycle_len > 1 {
129                sign_changes += cycle_len - 1;
130            }
131        }
132
133        if sign_changes % 2 == 1 {
134            det = T::zero() - det;
135        }
136
137        det
138    }
139}
140
141/// Computes the LU decomposition of a matrix.
142///
143/// A = P * L * U
144///
145/// # Arguments
146/// * `a` - The input matrix (m×n)
147///
148/// # Returns
149/// LU decomposition with L, U, and permutation
150pub fn lu_ndarray<T: Field + Clone>(a: &Array2<T>) -> LapackResult<LuResult<T>>
151where
152    T: bytemuck::Zeroable,
153{
154    let mat = array2_to_mat(a);
155
156    match lu::Lu::compute(mat.as_ref()) {
157        Ok(lu_decomp) => {
158            // Extract L and U factors
159            let l = mat_to_array2(&lu_decomp.l_factor());
160            let u = mat_to_array2(&lu_decomp.u_factor());
161
162            // Get permutation
163            let perm = lu_decomp.pivot().to_vec();
164
165            Ok(LuResult { l, u, perm })
166        }
167        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
168    }
169}
170
171// =============================================================================
172// QR Decomposition
173// =============================================================================
174
175/// Result of QR decomposition.
176#[derive(Debug, Clone)]
177pub struct QrResult<T> {
178    /// Q factor (orthogonal/unitary)
179    pub q: Array2<T>,
180    /// R factor (upper triangular)
181    pub r: Array2<T>,
182}
183
184impl<T: Field + Clone> QrResult<T> {
185    /// Solves the least squares problem min ||Ax - b||.
186    pub fn solve_least_squares(&self, b: &Array1<T>) -> Array1<T> {
187        let (m, n) = (self.q.dim().0, self.r.dim().1);
188        assert_eq!(b.len(), m, "b length must match matrix rows");
189
190        // Compute Q^T * b (or Q^H for complex)
191        let mut qtb: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
192        for j in 0..n {
193            let mut sum = T::zero();
194            for i in 0..m {
195                sum += self.q[[i, j]].conj() * b[i];
196            }
197            qtb[j] = sum;
198        }
199
200        // Back substitution: R * x = Q^T * b
201        let mut x: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
202        for i in (0..n).rev() {
203            let mut sum = qtb[i];
204            for j in (i + 1)..n {
205                sum -= self.r[[i, j]] * x[j];
206            }
207            x[i] = sum / self.r[[i, i]];
208        }
209
210        x
211    }
212}
213
214/// Computes the QR decomposition of a matrix.
215///
216/// A = Q * R
217///
218/// # Arguments
219/// * `a` - The input matrix (m×n)
220///
221/// # Returns
222/// QR decomposition with Q and R
223pub fn qr_ndarray<T: Field + Clone>(a: &Array2<T>) -> LapackResult<QrResult<T>>
224where
225    T: bytemuck::Zeroable + oxiblas_core::scalar::Real,
226{
227    let mat = array2_to_mat(a);
228
229    match qr::Qr::compute(mat.as_ref()) {
230        Ok(qr_decomp) => {
231            let q = mat_to_array2(&qr_decomp.q());
232            let r = mat_to_array2(&qr_decomp.r());
233
234            Ok(QrResult { q, r })
235        }
236        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
237    }
238}
239
240// =============================================================================
241// Singular Value Decomposition
242// =============================================================================
243
244/// Result of SVD decomposition.
245#[derive(Debug, Clone)]
246pub struct SvdResult<T> {
247    /// Left singular vectors U (m×k where k = min(m,n))
248    pub u: Array2<T>,
249    /// Singular values σ (sorted in descending order)
250    pub s: Array1<T>,
251    /// Right singular vectors V^T (k×n)
252    pub vt: Array2<T>,
253}
254
255impl<T: Field + Clone> SvdResult<T> {
256    /// Returns the rank based on a tolerance.
257    pub fn rank(&self, tol: T) -> usize {
258        self.s.iter().filter(|&s| s.abs() > tol.abs()).count()
259    }
260}
261
262/// Computes the SVD of a matrix.
263///
264/// A = U * Σ * V^T
265///
266/// # Arguments
267/// * `a` - The input matrix (m×n)
268///
269/// # Returns
270/// SVD with U, S (singular values), and V^T
271pub fn svd_ndarray<T>(a: &Array2<T>) -> LapackResult<SvdResult<T>>
272where
273    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
274{
275    let mat = array2_to_mat(a);
276
277    match svd::Svd::compute(mat.as_ref()) {
278        Ok(svd_decomp) => {
279            let u = mat_ref_to_array2(svd_decomp.u());
280            let s = Array1::from_vec(svd_decomp.singular_values().to_vec());
281            let vt = mat_ref_to_array2(svd_decomp.vt());
282
283            Ok(SvdResult { u, s, vt })
284        }
285        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
286    }
287}
288
289/// Computes the truncated SVD of a matrix.
290///
291/// Returns only the top k singular values and vectors.
292pub fn svd_truncated<T>(a: &Array2<T>, k: usize) -> LapackResult<SvdResult<T>>
293where
294    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
295{
296    let svd_result = svd_ndarray(a)?;
297
298    let actual_k = k.min(svd_result.s.len());
299
300    // Truncate to k components
301    let u = svd_result.u.slice(ndarray::s![.., ..actual_k]).to_owned();
302    let s = svd_result.s.slice(ndarray::s![..actual_k]).to_owned();
303    let vt = svd_result.vt.slice(ndarray::s![..actual_k, ..]).to_owned();
304
305    Ok(SvdResult { u, s, vt })
306}
307
308// =============================================================================
309// Eigenvalue Decomposition (Symmetric)
310// =============================================================================
311
312/// Result of symmetric eigenvalue decomposition.
313#[derive(Debug, Clone)]
314pub struct SymEvdResult<T> {
315    /// Eigenvalues (sorted in ascending order)
316    pub eigenvalues: Array1<T>,
317    /// Eigenvectors (columns are eigenvectors)
318    pub eigenvectors: Array2<T>,
319}
320
321/// Computes the eigenvalue decomposition of a symmetric matrix.
322///
323/// A * V = V * Λ where Λ = diag(eigenvalues)
324///
325/// # Arguments
326/// * `a` - The input symmetric matrix (n×n)
327///
328/// # Returns
329/// Eigenvalues and eigenvectors
330pub fn eig_symmetric<T>(a: &Array2<T>) -> LapackResult<SymEvdResult<T>>
331where
332    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
333{
334    let (m, n) = a.dim();
335    if m != n {
336        return Err(LapackError::DimensionMismatch(
337            "Matrix must be square".to_string(),
338        ));
339    }
340
341    let mat = array2_to_mat(a);
342
343    match evd::SymmetricEvd::compute(mat.as_ref()) {
344        Ok(evd_result) => {
345            let eigenvalues = Array1::from_vec(evd_result.eigenvalues().to_vec());
346            // Convert MatRef to Array2
347            let evec_ref = evd_result.eigenvectors();
348            let (rows, cols) = (evec_ref.nrows(), evec_ref.ncols());
349            let eigenvectors = Array2::from_shape_fn((rows, cols), |(i, j)| evec_ref[(i, j)]);
350
351            Ok(SymEvdResult {
352                eigenvalues,
353                eigenvectors,
354            })
355        }
356        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
357    }
358}
359
360/// Computes only the eigenvalues of a symmetric matrix.
361pub fn eigvals_symmetric<T>(a: &Array2<T>) -> LapackResult<Array1<T>>
362where
363    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
364{
365    eig_symmetric(a).map(|result| result.eigenvalues)
366}
367
368// =============================================================================
369// Cholesky Decomposition
370// =============================================================================
371
372/// Result of Cholesky decomposition.
373#[derive(Debug, Clone)]
374pub struct CholeskyResult<T> {
375    /// Lower triangular factor L such that A = L * L^T
376    pub l: Array2<T>,
377}
378
379impl<T: Field + Clone> CholeskyResult<T> {
380    /// Solves Ax = b using the Cholesky decomposition.
381    pub fn solve(&self, b: &Array1<T>) -> Array1<T> {
382        let n = self.l.dim().0;
383        assert_eq!(b.len(), n, "b length must match matrix dimension");
384
385        // Forward substitution: L * y = b
386        let mut y: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
387        for i in 0..n {
388            let mut sum = b[i];
389            for j in 0..i {
390                sum -= self.l[[i, j]] * y[j];
391            }
392            y[i] = sum / self.l[[i, i]];
393        }
394
395        // Back substitution: L^T * x = y
396        let mut x: Array1<T> = Array1::from_vec(vec![T::zero(); n]);
397        for i in (0..n).rev() {
398            let mut sum = y[i];
399            for j in (i + 1)..n {
400                sum -= self.l[[j, i]].conj() * x[j];
401            }
402            x[i] = sum / self.l[[i, i]].conj();
403        }
404
405        x
406    }
407
408    /// Computes the determinant.
409    pub fn det(&self) -> T {
410        let n = self.l.dim().0;
411        let mut det = T::one();
412        for i in 0..n {
413            let diag = self.l[[i, i]];
414            det = det * diag * diag;
415        }
416        det
417    }
418}
419
420/// Computes the Cholesky decomposition of a positive definite matrix.
421///
422/// A = L * L^T
423///
424/// # Arguments
425/// * `a` - The input symmetric positive definite matrix (n×n)
426///
427/// # Returns
428/// Cholesky decomposition with lower triangular factor L
429pub fn cholesky_ndarray<T>(a: &Array2<T>) -> LapackResult<CholeskyResult<T>>
430where
431    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
432{
433    let (m, n) = a.dim();
434    if m != n {
435        return Err(LapackError::DimensionMismatch(
436            "Matrix must be square".to_string(),
437        ));
438    }
439
440    let mat = array2_to_mat(a);
441
442    match cholesky::Cholesky::compute(mat.as_ref()) {
443        Ok(chol) => {
444            let l = mat_to_array2(&chol.l_factor());
445            Ok(CholeskyResult { l })
446        }
447        Err(e) => Err(LapackError::NotPositiveDefinite(format!("{e:?}"))),
448    }
449}
450
451// =============================================================================
452// Linear Solve
453// =============================================================================
454
455/// Solves the linear system Ax = b.
456///
457/// # Arguments
458/// * `a` - The coefficient matrix (n×n)
459/// * `b` - The right-hand side vector (n)
460///
461/// # Returns
462/// The solution vector x
463pub fn solve_ndarray<T>(a: &Array2<T>, b: &Array1<T>) -> LapackResult<Array1<T>>
464where
465    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
466{
467    let (m, n) = a.dim();
468    if m != n {
469        return Err(LapackError::DimensionMismatch(
470            "Matrix must be square".to_string(),
471        ));
472    }
473    if b.len() != n {
474        return Err(LapackError::DimensionMismatch(
475            "b length must match matrix dimension".to_string(),
476        ));
477    }
478
479    let a_mat = array2_to_mat(a);
480    // Convert b to a column vector matrix
481    let mut b_mat: Mat<T> = Mat::zeros(n, 1);
482    for i in 0..n {
483        b_mat[(i, 0)] = b[i];
484    }
485
486    match solve::solve(a_mat.as_ref(), b_mat.as_ref()) {
487        Ok(x_mat) => {
488            // Extract column vector from result matrix
489            let x: Vec<T> = (0..n).map(|i| x_mat[(i, 0)]).collect();
490            Ok(Array1::from_vec(x))
491        }
492        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
493    }
494}
495
496/// Solves multiple linear systems AX = B.
497///
498/// # Arguments
499/// * `a` - The coefficient matrix (n×n)
500/// * `b` - The right-hand side matrix (n×k)
501///
502/// # Returns
503/// The solution matrix X (n×k)
504pub fn solve_multiple_ndarray<T>(a: &Array2<T>, b: &Array2<T>) -> LapackResult<Array2<T>>
505where
506    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
507{
508    let (m, n) = a.dim();
509    let (b_rows, _b_cols) = b.dim();
510
511    if m != n {
512        return Err(LapackError::DimensionMismatch(
513            "Matrix must be square".to_string(),
514        ));
515    }
516    if b_rows != n {
517        return Err(LapackError::DimensionMismatch(
518            "b rows must match matrix dimension".to_string(),
519        ));
520    }
521
522    let a_mat = array2_to_mat(a);
523    let b_mat = array2_to_mat(b);
524
525    match solve::solve_multiple(a_mat.as_ref(), b_mat.as_ref()) {
526        Ok(x_mat) => Ok(mat_to_array2(&x_mat)),
527        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
528    }
529}
530
531/// Solves the least squares problem min ||Ax - b||.
532pub fn lstsq_ndarray<T>(a: &Array2<T>, b: &Array1<T>) -> LapackResult<Array1<T>>
533where
534    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
535{
536    let m = a.dim().0;
537    let a_mat = array2_to_mat(a);
538    // Convert b to a column vector matrix
539    let mut b_mat: Mat<T> = Mat::zeros(m, 1);
540    for i in 0..m {
541        b_mat[(i, 0)] = b[i];
542    }
543
544    match solve::lstsq(a_mat.as_ref(), b_mat.as_ref()) {
545        Ok(result) => {
546            // Extract solution column vector
547            let n = result.solution.nrows();
548            let x: Vec<T> = (0..n).map(|i| result.solution[(i, 0)]).collect();
549            Ok(Array1::from_vec(x))
550        }
551        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
552    }
553}
554
555// =============================================================================
556// Matrix Inverse
557// =============================================================================
558
559/// Computes the inverse of a matrix.
560pub fn inv_ndarray<T>(a: &Array2<T>) -> LapackResult<Array2<T>>
561where
562    T: Field + Clone + bytemuck::Zeroable,
563{
564    let (m, n) = a.dim();
565    if m != n {
566        return Err(LapackError::DimensionMismatch(
567            "Matrix must be square".to_string(),
568        ));
569    }
570
571    let a_mat = array2_to_mat(a);
572
573    match oxiblas_lapack::utils::inv(a_mat.as_ref()) {
574        Ok(inv_mat) => Ok(mat_to_array2(&inv_mat)),
575        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
576    }
577}
578
579/// Computes the Moore-Penrose pseudo-inverse.
580pub fn pinv_ndarray<T>(a: &Array2<T>) -> LapackResult<Array2<T>>
581where
582    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
583{
584    let a_mat = array2_to_mat(a);
585
586    match oxiblas_lapack::utils::pinv_default(a_mat.as_ref()) {
587        Ok(result) => Ok(mat_to_array2(&result.pinv)),
588        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
589    }
590}
591
592// =============================================================================
593// Determinant
594// =============================================================================
595
596/// Computes the determinant of a matrix.
597pub fn det_ndarray<T>(a: &Array2<T>) -> LapackResult<T>
598where
599    T: Field + Clone + bytemuck::Zeroable,
600{
601    let (m, n) = a.dim();
602    if m != n {
603        return Err(LapackError::DimensionMismatch(
604            "Matrix must be square".to_string(),
605        ));
606    }
607
608    let a_mat = array2_to_mat(a);
609
610    match oxiblas_lapack::utils::det(a_mat.as_ref()) {
611        Ok(d) => Ok(d),
612        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
613    }
614}
615
616// =============================================================================
617// Condition Number
618// =============================================================================
619
620/// Computes the condition number of a matrix (using 2-norm).
621pub fn cond_ndarray<T>(a: &Array2<T>) -> LapackResult<T>
622where
623    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
624{
625    let svd_result = svd_ndarray(a)?;
626    let n = svd_result.s.len();
627
628    if n == 0 {
629        return Ok(T::one());
630    }
631
632    let sigma_max = svd_result.s[0];
633    let sigma_min = svd_result.s[n - 1];
634
635    // Check if sigma_min is very small
636    if sigma_min == T::zero() {
637        // Return a large number to indicate ill-conditioning
638        Ok(T::from_f64(1e15).unwrap_or(T::one()))
639    } else {
640        Ok(sigma_max / sigma_min)
641    }
642}
643
644// =============================================================================
645// Rank
646// =============================================================================
647
648/// Computes the numerical rank of a matrix.
649pub fn rank_ndarray<T>(a: &Array2<T>) -> LapackResult<usize>
650where
651    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
652{
653    let (m, n) = a.dim();
654    let svd_result = svd_ndarray(a)?;
655
656    if svd_result.s.is_empty() {
657        return Ok(0);
658    }
659
660    // Default tolerance: max(m,n) * eps * sigma_max
661    let sigma_max = svd_result.s[0];
662    let eps = T::from_f64(1e-14).unwrap_or(T::zero());
663    let dim_scale = T::from_f64(m.max(n) as f64).unwrap_or(T::one());
664    let tol = dim_scale * eps * sigma_max;
665
666    Ok(svd_result.rank(tol))
667}
668
669// =============================================================================
670// Randomized SVD
671// =============================================================================
672
673/// Result of randomized SVD.
674#[derive(Debug, Clone)]
675pub struct RandomizedSvdResult<T> {
676    /// Left singular vectors U (m × k)
677    pub u: Array2<T>,
678    /// Singular values σ (k elements, sorted descending)
679    pub s: Array1<T>,
680    /// Right singular vectors V (n × k), NOT V^T
681    pub v: Array2<T>,
682}
683
684/// Computes a randomized SVD approximation of a matrix.
685///
686/// Uses randomized projections to compute a rank-k approximation efficiently,
687/// particularly useful for large matrices where only the top singular values
688/// are needed.
689///
690/// # Arguments
691/// * `a` - The input matrix (m×n)
692/// * `k` - Target rank (number of singular values to compute)
693///
694/// # Returns
695/// Truncated SVD with k singular values and vectors
696///
697/// # Algorithm
698/// Uses the Halko-Martinsson-Tropp randomized algorithm:
699/// 1. Generate random test matrix Ω
700/// 2. Compute Y = A × Ω to sample column space
701/// 3. QR factorize Y to get orthonormal basis Q
702/// 4. Project B = Q^T × A
703/// 5. Compute full SVD of B
704/// 6. Recover U = Q × Ũ
705pub fn rsvd_ndarray<T>(a: &Array2<T>, k: usize) -> LapackResult<RandomizedSvdResult<T>>
706where
707    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
708{
709    let mat = array2_to_mat(a);
710
711    match svd::RandomizedSvd::compute(mat.as_ref(), k) {
712        Ok(rsvd) => {
713            let u = mat_ref_to_array2(rsvd.u());
714            let s = Array1::from_vec(rsvd.singular_values().to_vec());
715            let v = mat_ref_to_array2(rsvd.v());
716
717            Ok(RandomizedSvdResult { u, s, v })
718        }
719        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
720    }
721}
722
723/// Computes randomized SVD with power iteration for improved accuracy.
724///
725/// Power iteration emphasizes dominant singular values and improves accuracy
726/// for matrices with slowly decaying singular values.
727///
728/// # Arguments
729/// * `a` - The input matrix (m×n)
730/// * `k` - Target rank
731/// * `power_iterations` - Number of power iterations (typically 1-3)
732pub fn rsvd_power_ndarray<T>(
733    a: &Array2<T>,
734    k: usize,
735    power_iterations: usize,
736) -> LapackResult<RandomizedSvdResult<T>>
737where
738    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
739{
740    let mat = array2_to_mat(a);
741
742    let config = svd::RandomizedSvdConfig::new(k).with_power_iterations(power_iterations);
743
744    match svd::RandomizedSvd::compute_with_config(mat.as_ref(), config) {
745        Ok(rsvd) => {
746            let u = mat_ref_to_array2(rsvd.u());
747            let s = Array1::from_vec(rsvd.singular_values().to_vec());
748            let v = mat_ref_to_array2(rsvd.v());
749
750            Ok(RandomizedSvdResult { u, s, v })
751        }
752        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
753    }
754}
755
756// =============================================================================
757// Schur Decomposition
758// =============================================================================
759
760/// Result of Schur decomposition.
761#[derive(Debug, Clone)]
762pub struct SchurResult<T> {
763    /// Orthogonal matrix Q (Schur vectors)
764    pub q: Array2<T>,
765    /// Quasi-upper triangular matrix T (Schur form)
766    pub t: Array2<T>,
767    /// Eigenvalues (real and complex pairs)
768    pub eigenvalues: Vec<Eigenvalue<T>>,
769}
770
771/// Computes the real Schur decomposition of a square matrix.
772///
773/// A = Q T Q^T where:
774/// - Q is orthogonal (Q^T Q = I)
775/// - T is quasi-upper triangular (upper triangular with possible 2×2 blocks
776///   on the diagonal for complex eigenvalue pairs)
777///
778/// # Arguments
779/// * `a` - The input square matrix (n×n)
780///
781/// # Returns
782/// Schur decomposition with Q, T, and eigenvalues
783pub fn schur_ndarray<T>(a: &Array2<T>) -> LapackResult<SchurResult<T>>
784where
785    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
786{
787    let (m, n) = a.dim();
788    if m != n {
789        return Err(LapackError::DimensionMismatch(
790            "Matrix must be square".to_string(),
791        ));
792    }
793
794    let mat = array2_to_mat(a);
795
796    match evd::Schur::compute(mat.as_ref()) {
797        Ok(schur) => {
798            let q = mat_ref_to_array2(schur.q());
799            let t = mat_ref_to_array2(schur.t());
800            let eigenvalues = schur.eigenvalues().to_vec();
801
802            Ok(SchurResult { q, t, eigenvalues })
803        }
804        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
805    }
806}
807
808// =============================================================================
809// General Eigenvalue Decomposition
810// =============================================================================
811
812/// Result of general eigenvalue decomposition.
813#[derive(Debug, Clone)]
814pub struct GeneralEvdResult<T> {
815    /// Eigenvalues (real and imaginary parts)
816    pub eigenvalues: Vec<Eigenvalue<T>>,
817    /// Right eigenvectors (real parts), if computed
818    pub eigenvectors_real: Option<Array2<T>>,
819    /// Right eigenvectors (imaginary parts), if computed
820    pub eigenvectors_imag: Option<Array2<T>>,
821    /// Left eigenvectors (real parts), if computed
822    pub left_eigenvectors_real: Option<Array2<T>>,
823    /// Left eigenvectors (imaginary parts), if computed
824    pub left_eigenvectors_imag: Option<Array2<T>>,
825}
826
827/// Computes eigenvalues of a general (non-symmetric) matrix.
828///
829/// For a real matrix, eigenvalues may be complex. They are returned as
830/// real/imaginary pairs. Eigenvectors are also split into real and
831/// imaginary parts.
832///
833/// # Arguments
834/// * `a` - The input square matrix (n×n)
835///
836/// # Returns
837/// Eigenvalues and eigenvectors (split into real/imaginary parts)
838pub fn eig_ndarray<T>(a: &Array2<T>) -> LapackResult<GeneralEvdResult<T>>
839where
840    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
841{
842    let (m, n) = a.dim();
843    if m != n {
844        return Err(LapackError::DimensionMismatch(
845            "Matrix must be square".to_string(),
846        ));
847    }
848
849    let mat = array2_to_mat(a);
850
851    match evd::GeneralEvd::compute(mat.as_ref()) {
852        Ok(evd_result) => {
853            let eigenvalues = evd_result.eigenvalues().to_vec();
854
855            // Get right eigenvectors (real and imaginary parts)
856            let eigenvectors_real = evd_result
857                .eigenvectors_real()
858                .map(|vr| mat_ref_to_array2(vr));
859
860            let eigenvectors_imag = evd_result
861                .eigenvectors_imag()
862                .map(|vi| mat_ref_to_array2(vi));
863
864            // Get left eigenvectors (real and imaginary parts)
865            let left_eigenvectors_real = evd_result
866                .left_eigenvectors_real()
867                .map(|vl| mat_ref_to_array2(vl));
868
869            let left_eigenvectors_imag = evd_result
870                .left_eigenvectors_imag()
871                .map(|vl| mat_ref_to_array2(vl));
872
873            Ok(GeneralEvdResult {
874                eigenvalues,
875                eigenvectors_real,
876                eigenvectors_imag,
877                left_eigenvectors_real,
878                left_eigenvectors_imag,
879            })
880        }
881        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
882    }
883}
884
885/// Computes only the eigenvalues of a general matrix.
886pub fn eigvals_ndarray<T>(a: &Array2<T>) -> LapackResult<Vec<Eigenvalue<T>>>
887where
888    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
889{
890    let (m, n) = a.dim();
891    if m != n {
892        return Err(LapackError::DimensionMismatch(
893            "Matrix must be square".to_string(),
894        ));
895    }
896
897    let mat = array2_to_mat(a);
898
899    match evd::GeneralEvd::eigenvalues_only(mat.as_ref()) {
900        Ok(evd_result) => Ok(evd_result.eigenvalues().to_vec()),
901        Err(e) => Err(LapackError::NotConverged(format!("{e:?}"))),
902    }
903}
904
905// =============================================================================
906// Tridiagonal Solvers
907// =============================================================================
908
909/// Solves a tridiagonal system of equations.
910///
911/// Solves T x = b where T is a tridiagonal matrix.
912///
913/// # Arguments
914/// * `dl` - Lower diagonal (n-1 elements)
915/// * `d` - Main diagonal (n elements)
916/// * `du` - Upper diagonal (n-1 elements)
917/// * `b` - Right-hand side vector (n elements)
918///
919/// # Returns
920/// The solution vector x
921pub fn tridiag_solve_ndarray<T>(
922    dl: &Array1<T>,
923    d: &Array1<T>,
924    du: &Array1<T>,
925    b: &Array1<T>,
926) -> LapackResult<Array1<T>>
927where
928    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
929{
930    let n = d.len();
931
932    if dl.len() != n - 1 || du.len() != n - 1 || b.len() != n {
933        return Err(LapackError::DimensionMismatch(
934            "Tridiagonal dimensions must be consistent".to_string(),
935        ));
936    }
937
938    let dl_vec: Vec<T> = dl.iter().cloned().collect();
939    let d_vec: Vec<T> = d.iter().cloned().collect();
940    let du_vec: Vec<T> = du.iter().cloned().collect();
941    let b_vec: Vec<T> = b.iter().cloned().collect();
942
943    match solve::tridiag_solve(&dl_vec, &d_vec, &du_vec, &b_vec) {
944        Ok(x) => Ok(Array1::from_vec(x)),
945        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
946    }
947}
948
949/// Solves a symmetric positive definite tridiagonal system.
950///
951/// Solves T x = b where T is symmetric positive definite and tridiagonal.
952/// Uses specialized algorithm that's more efficient for SPD matrices.
953///
954/// # Arguments
955/// * `d` - Main diagonal (n elements, positive)
956/// * `e` - Off-diagonal (n-1 elements)
957/// * `b` - Right-hand side vector (n elements)
958///
959/// # Returns
960/// The solution vector x
961pub fn tridiag_solve_spd_ndarray<T>(
962    d: &Array1<T>,
963    e: &Array1<T>,
964    b: &Array1<T>,
965) -> LapackResult<Array1<T>>
966where
967    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
968{
969    let n = d.len();
970
971    if e.len() != n - 1 || b.len() != n {
972        return Err(LapackError::DimensionMismatch(
973            "Tridiagonal dimensions must be consistent".to_string(),
974        ));
975    }
976
977    let d_vec: Vec<T> = d.iter().cloned().collect();
978    let e_vec: Vec<T> = e.iter().cloned().collect();
979    let b_vec: Vec<T> = b.iter().cloned().collect();
980
981    match solve::tridiag_solve_spd(&d_vec, &e_vec, &b_vec) {
982        Ok(x) => Ok(Array1::from_vec(x)),
983        Err(e) => Err(LapackError::NotPositiveDefinite(format!("{e:?}"))),
984    }
985}
986
987/// Solves multiple tridiagonal systems with the same matrix.
988///
989/// # Arguments
990/// * `dl` - Lower diagonal (n-1 elements)
991/// * `d` - Main diagonal (n elements)
992/// * `du` - Upper diagonal (n-1 elements)
993/// * `b` - Right-hand side matrix (n × nrhs)
994///
995/// # Returns
996/// The solution matrix X (n × nrhs)
997pub fn tridiag_solve_multiple_ndarray<T>(
998    dl: &Array1<T>,
999    d: &Array1<T>,
1000    du: &Array1<T>,
1001    b: &Array2<T>,
1002) -> LapackResult<Array2<T>>
1003where
1004    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1005{
1006    let n = d.len();
1007    let (b_rows, _b_cols) = b.dim();
1008
1009    if dl.len() != n - 1 || du.len() != n - 1 || b_rows != n {
1010        return Err(LapackError::DimensionMismatch(
1011            "Tridiagonal dimensions must be consistent".to_string(),
1012        ));
1013    }
1014
1015    let dl_vec: Vec<T> = dl.iter().cloned().collect();
1016    let d_vec: Vec<T> = d.iter().cloned().collect();
1017    let du_vec: Vec<T> = du.iter().cloned().collect();
1018    let b_mat = array2_to_mat(b);
1019
1020    match solve::tridiag_solve_multiple(&dl_vec, &d_vec, &du_vec, b_mat.as_ref()) {
1021        Ok(x_mat) => Ok(mat_to_array2(&x_mat)),
1022        Err(e) => Err(LapackError::Singular(format!("{e:?}"))),
1023    }
1024}
1025
1026// =============================================================================
1027// Low-Rank Approximation
1028// =============================================================================
1029
1030/// Computes a low-rank approximation of a matrix.
1031///
1032/// Returns A_k = U_k Σ_k V_k^T, the best rank-k approximation in Frobenius norm.
1033///
1034/// # Arguments
1035/// * `a` - The input matrix (m×n)
1036/// * `k` - Target rank
1037///
1038/// # Returns
1039/// The rank-k approximation as a matrix
1040pub fn low_rank_approx_ndarray<T>(a: &Array2<T>, k: usize) -> LapackResult<Array2<T>>
1041where
1042    T: Field + Clone + bytemuck::Zeroable + oxiblas_core::scalar::Real,
1043{
1044    let mat = array2_to_mat(a);
1045
1046    match svd::low_rank_approximation(mat.as_ref(), k) {
1047        Ok(approx) => Ok(mat_to_array2(&approx)),
1048        Err(e) => Err(LapackError::Other(format!("{e:?}"))),
1049    }
1050}
1051
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_lu_decomposition() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let lu = lu_ndarray(&a).unwrap();

        // Verify L * U ≈ P * A
        let n = a.dim().0;
        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0f64;
                for k in 0..n {
                    sum += lu.l[[i, k]] * lu.u[[k, j]];
                }
                let perm_i = lu.perm.iter().position(|&p| p == i).unwrap();
                assert!((sum - a[[perm_i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_lu_determinant() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let lu = lu_ndarray(&a).unwrap();
        let det = lu.det();
        // det = 2*3 - 1*1 = 5
        assert!((det - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let b = array![5.0f64, 7.0];
        let lu = lu_ndarray(&a).unwrap();
        let x = lu.solve(&b);

        // Verify A * x ≈ b
        let ax0 = a[[0, 0]] * x[0] + a[[0, 1]] * x[1];
        let ax1 = a[[1, 0]] * x[0] + a[[1, 1]] * x[1];
        assert!((ax0 - b[0]).abs() < 1e-10);
        assert!((ax1 - b[1]).abs() < 1e-10);
    }

    #[test]
    fn test_qr_decomposition() {
        let a = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let qr = qr_ndarray(&a).unwrap();

        // Q should be orthogonal: Q^T * Q = I
        let qt = qr.q.t();
        let qtq = crate::blas::matmul(&qt.to_owned(), &qr.q);
        for i in 0..qtq.dim().0 {
            for j in 0..qtq.dim().1 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[[i, j]] - expected).abs() < 1e-10,
                    "Q^T Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    qtq[[i, j]],
                    expected
                );
            }
        }

        // Q * R should equal A
        let qr_product = crate::blas::matmul(&qr.q, &qr.r);
        for i in 0..a.dim().0 {
            for j in 0..a.dim().1 {
                assert!(
                    (qr_product[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "QR[{},{}] = {}, A = {}",
                    i,
                    j,
                    qr_product[[i, j]],
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_svd() {
        let a = array![[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let svd = svd_ndarray(&a).unwrap();

        // Reconstruct A from SVD: U * S * V^T
        let (m, n) = a.dim();
        let k = svd.s.len();

        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0f64;
                for l in 0..k {
                    sum += svd.u[[i, l]] * svd.s[l] * svd.vt[[l, j]];
                }
                assert!(
                    (sum - a[[i, j]]).abs() < 1e-10,
                    "Reconstructed[{},{}] = {}, A = {}",
                    i,
                    j,
                    sum,
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_symmetric_evd() {
        // Symmetric, diagonally dominant matrix with a positive diagonal, so it
        // is positive definite: eigenvalues are (7 ± √5) / 2, both real and positive.
        let a = array![[4.0f64, 1.0], [1.0, 3.0]];
        let evd = eig_symmetric(&a).unwrap();

        assert_eq!(evd.eigenvalues.len(), 2);

        // Verify A * v = λ * v for each eigenvalue/eigenvector pair
        for (idx, &lambda) in evd.eigenvalues.iter().enumerate() {
            assert!(lambda > 0.0, "Expected positive eigenvalue, got {}", lambda);

            let v = evd.eigenvectors.column(idx);
            let av = crate::blas::matvec(&a, &v.to_owned());
            let lambda_v: Array1<f64> = v.iter().map(|&x| lambda * x).collect();

            for i in 0..2 {
                assert!(
                    (av[i] - lambda_v[i]).abs() < 1e-10,
                    "Av[{}] = {}, λv[{}] = {}",
                    i,
                    av[i],
                    i,
                    lambda_v[i]
                );
            }
        }
    }

    #[test]
    fn test_cholesky() {
        // Positive definite matrix
        let a = array![[4.0f64, 2.0], [2.0, 5.0]];
        let chol = cholesky_ndarray(&a).unwrap();

        // Verify L * L^T = A
        let lt = chol.l.t();
        let llt = crate::blas::matmul(&chol.l, &lt.to_owned());

        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (llt[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "LLT[{},{}] = {}, A = {}",
                    i,
                    j,
                    llt[[i, j]],
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_solve() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let b = array![5.0f64, 7.0];
        let x = solve_ndarray(&a, &b).unwrap();

        // Verify A * x = b
        let ax = crate::blas::matvec(&a, &x);
        assert!((ax[0] - b[0]).abs() < 1e-10);
        assert!((ax[1] - b[1]).abs() < 1e-10);
    }

    #[test]
    fn test_inverse() {
        let a = array![[4.0f64, 7.0], [2.0, 6.0]];
        let a_inv = inv_ndarray(&a).unwrap();

        // A * A^-1 = I
        let product = crate::blas::matmul(&a, &a_inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (product[[i, j]] - expected).abs() < 1e-10,
                    "A*A^-1[{},{}] = {}, expected {}",
                    i,
                    j,
                    product[[i, j]],
                    expected
                );
            }
        }
    }

    #[test]
    fn test_determinant() {
        let a = array![[2.0f64, 1.0], [1.0, 3.0]];
        let det = det_ndarray(&a).unwrap();
        // det = 2*3 - 1*1 = 5
        assert!((det - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_condition_number() {
        let a = array![[1.0f64, 0.0], [0.0, 1.0]];
        let cond = cond_ndarray(&a).unwrap();
        // Identity matrix has condition number 1
        assert!((cond - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank() {
        // Full rank matrix
        let a = array![[1.0f64, 2.0], [3.0, 4.0]];
        let r = rank_ndarray(&a).unwrap();
        assert_eq!(r, 2);

        // Rank deficient matrix
        let b = array![[1.0f64, 2.0], [2.0, 4.0]];
        let r2 = rank_ndarray(&b).unwrap();
        assert_eq!(r2, 1);
    }

    // =========================================================================
    // Randomized SVD Tests
    // =========================================================================

    #[test]
    fn test_rsvd_basic() {
        // Rank-2 matrix (each row is an affine combination of the first two)
        let a = array![
            [1.0f64, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0]
        ];

        let rsvd = rsvd_ndarray(&a, 2).unwrap();

        // Should have 2 singular values
        assert_eq!(rsvd.s.len(), 2);

        // Singular values should be positive and in descending order
        assert!(rsvd.s[0] > rsvd.s[1]);
        assert!(rsvd.s[1] >= 0.0);

        // U should be m×k
        assert_eq!(rsvd.u.dim(), (3, 2));

        // V should be n×k
        assert_eq!(rsvd.v.dim(), (4, 2));
    }

    #[test]
    fn test_rsvd_approximation_quality() {
        // a[i, j] = 0.1 * i + 0.2 * j is the sum of two rank-1 matrices
        // (i * 1^T and 1 * j^T), so it has rank exactly 2.
        let a = Array2::from_shape_fn((10, 8), |(i, j)| (i as f64) * 0.1 + (j as f64) * 0.2);

        let rsvd = rsvd_ndarray(&a, 2).unwrap();

        // Reconstruct: A ≈ U * S * V^T
        let (m, n) = a.dim();
        let k = rsvd.s.len();

        let mut approx: Array2<f64> = Array2::zeros((m, n));
        for i in 0..m {
            for j in 0..n {
                for l in 0..k {
                    approx[[i, j]] += rsvd.u[[i, l]] * rsvd.s[l] * rsvd.v[[j, l]];
                }
            }
        }

        // A rank-2 approximation of a rank-2 matrix reproduces it exactly
        // (up to rounding), so the Frobenius-norm error must be tiny.
        let mut diff_norm = 0.0f64;
        for i in 0..m {
            for j in 0..n {
                let diff = a[[i, j]] - approx[[i, j]];
                diff_norm += diff.powi(2);
            }
        }
        diff_norm = diff_norm.sqrt();

        assert!(diff_norm < 1e-10, "Reconstruction error = {}", diff_norm);
    }

    #[test]
    fn test_rsvd_power_iteration() {
        let a = Array2::from_shape_fn((20, 15), |(i, j)| ((i * j) as f64).sin() + 0.1 * (i as f64));

        let rsvd = rsvd_power_ndarray(&a, 3, 2).unwrap();

        assert_eq!(rsvd.s.len(), 3);
        assert!(rsvd.s[0] >= rsvd.s[1]);
        assert!(rsvd.s[1] >= rsvd.s[2]);
    }

    // =========================================================================
    // Schur Decomposition Tests
    // =========================================================================

    #[test]
    fn test_schur_triangular() {
        // Already upper triangular matrix
        let a = array![[1.0f64, 2.0], [0.0, 3.0]];

        let schur = schur_ndarray(&a).unwrap();

        // Eigenvalues should be the diagonal entries 1 and 3
        assert_eq!(schur.eigenvalues.len(), 2);

        let evs: Vec<f64> = schur.eigenvalues.iter().map(|e| e.real).collect();
        assert!(evs.iter().any(|&x| (x - 1.0).abs() < 1e-10));
        assert!(evs.iter().any(|&x| (x - 3.0).abs() < 1e-10));
    }

    #[test]
    fn test_schur_reconstruction() {
        let a = array![[4.0f64, 1.0], [2.0, 3.0]];

        let schur = schur_ndarray(&a).unwrap();

        // Verify A = Q * T * Q^T
        let qt = schur.q.t();
        let qt_owned = qt.to_owned();
        let qr_temp = crate::blas::matmul(&schur.q, &schur.t);
        let reconstructed = crate::blas::matmul(&qr_temp, &qt_owned);

        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (reconstructed[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "Reconstruction failed at [{},{}]: {} vs {}",
                    i,
                    j,
                    reconstructed[[i, j]],
                    a[[i, j]]
                );
            }
        }
    }

    #[test]
    fn test_schur_orthogonality() {
        let a = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]];

        let schur = schur_ndarray(&a).unwrap();

        // Q should be orthogonal: Q^T * Q = I
        let qt = schur.q.t();
        let qtq = crate::blas::matmul(&qt.to_owned(), &schur.q);

        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (qtq[[i, j]] - expected).abs() < 1e-10,
                    "Q^T Q[{},{}] = {}, expected {}",
                    i,
                    j,
                    qtq[[i, j]],
                    expected
                );
            }
        }
    }

    // =========================================================================
    // General Eigenvalue Decomposition Tests
    // =========================================================================

    #[test]
    fn test_eig_real_eigenvalues() {
        // Symmetric matrix has real eigenvalues
        let a = array![[4.0f64, 1.0], [1.0, 3.0]];

        let evd = eig_ndarray(&a).unwrap();

        assert_eq!(evd.eigenvalues.len(), 2);

        // All eigenvalues should be real (imaginary part ≈ 0)
        for ev in &evd.eigenvalues {
            assert!(
                ev.imag.abs() < 1e-10,
                "Expected real eigenvalue, got imag = {}",
                ev.imag
            );
        }
    }

    #[test]
    fn test_eig_complex_eigenvalues() {
        // Rotation matrix has complex eigenvalues (±i)
        let a = array![[0.0f64, -1.0], [1.0, 0.0]];

        let evd = eig_ndarray(&a).unwrap();

        assert_eq!(evd.eigenvalues.len(), 2);

        // Should have eigenvalues with nonzero imaginary parts
        let has_complex = evd.eigenvalues.iter().any(|e| e.imag.abs() > 0.5);
        assert!(has_complex, "Expected complex eigenvalues");

        // Real parts should be close to 0
        for ev in &evd.eigenvalues {
            assert!(ev.real.abs() < 1e-10, "Expected real part ≈ 0");
        }
    }

    #[test]
    fn test_eigvals_only() {
        let a = array![[1.0f64, 2.0], [0.0, 3.0]];

        let evs = eigvals_ndarray(&a).unwrap();

        assert_eq!(evs.len(), 2);

        // Eigenvalues of upper triangular matrix are diagonal elements
        let reals: Vec<f64> = evs.iter().map(|e| e.real).collect();
        assert!(reals.iter().any(|&x| (x - 1.0).abs() < 1e-10));
        assert!(reals.iter().any(|&x| (x - 3.0).abs() < 1e-10));
    }

    // =========================================================================
    // Tridiagonal Solver Tests
    // =========================================================================

    #[test]
    fn test_tridiag_solve() {
        // Tridiagonal matrix:
        // [2  -1  0 ]   [x0]   [1]
        // [-1  2 -1 ] * [x1] = [0]
        // [0  -1  2 ]   [x2]   [1]
        let dl = array![-1.0f64, -1.0];
        let d = array![2.0f64, 2.0, 2.0];
        let du = array![-1.0f64, -1.0];
        let b = array![1.0f64, 0.0, 1.0];

        let x = tridiag_solve_ndarray(&dl, &d, &du, &b).unwrap();

        assert_eq!(x.len(), 3);

        // Verify solution: T * x ≈ b
        let tx0 = d[0] * x[0] + du[0] * x[1];
        let tx1 = dl[0] * x[0] + d[1] * x[1] + du[1] * x[2];
        let tx2 = dl[1] * x[1] + d[2] * x[2];

        assert!((tx0 - b[0]).abs() < 1e-10);
        assert!((tx1 - b[1]).abs() < 1e-10);
        assert!((tx2 - b[2]).abs() < 1e-10);
    }

    #[test]
    fn test_tridiag_solve_spd() {
        // SPD tridiagonal matrix:
        // [4 1 0]
        // [1 4 1]
        // [0 1 4]
        // Symmetric, strictly diagonally dominant, positive diagonal -> SPD
        let d = array![4.0f64, 4.0, 4.0];
        let e = array![1.0f64, 1.0]; // Off-diagonal elements
        let b = array![5.0f64, 6.0, 5.0];

        let x = tridiag_solve_spd_ndarray(&d, &e, &b).unwrap();

        assert_eq!(x.len(), 3);

        // Verify solution: T * x = b where T is symmetric with d on diagonal, e on off-diagonals
        let tx0 = d[0] * x[0] + e[0] * x[1];
        let tx1 = e[0] * x[0] + d[1] * x[1] + e[1] * x[2];
        let tx2 = e[1] * x[1] + d[2] * x[2];

        assert!((tx0 - b[0]).abs() < 1e-10, "tx0 = {}, b[0] = {}", tx0, b[0]);
        assert!((tx1 - b[1]).abs() < 1e-10, "tx1 = {}, b[1] = {}", tx1, b[1]);
        assert!((tx2 - b[2]).abs() < 1e-10, "tx2 = {}, b[2] = {}", tx2, b[2]);
    }

    #[test]
    fn test_tridiag_solve_multiple() {
        let dl = array![-1.0f64, -1.0];
        let d = array![2.0f64, 2.0, 2.0];
        let du = array![-1.0f64, -1.0];
        let b = array![[1.0f64, 0.0], [0.0, 1.0], [1.0, 0.0]];

        let x = tridiag_solve_multiple_ndarray(&dl, &d, &du, &b).unwrap();

        assert_eq!(x.dim(), (3, 2));

        // Each column should be the solution to T * x_j = b_j
        for j in 0..2 {
            let tx0 = d[0] * x[[0, j]] + du[0] * x[[1, j]];
            let tx1 = dl[0] * x[[0, j]] + d[1] * x[[1, j]] + du[1] * x[[2, j]];
            let tx2 = dl[1] * x[[1, j]] + d[2] * x[[2, j]];

            assert!((tx0 - b[[0, j]]).abs() < 1e-10);
            assert!((tx1 - b[[1, j]]).abs() < 1e-10);
            assert!((tx2 - b[[2, j]]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_tridiag_dimension_mismatch() {
        // Inconsistent diagonal / right-hand-side lengths must be rejected
        // with a DimensionMismatch error rather than reaching the solver.
        let d = array![2.0f64, 2.0, 2.0];
        let off = array![-1.0f64, -1.0];
        let short = array![-1.0f64];
        let b = array![1.0f64, 0.0, 1.0];

        // Sub-diagonal too short
        let r = tridiag_solve_ndarray(&short, &d, &off, &b);
        assert!(matches!(r, Err(LapackError::DimensionMismatch(_))));

        // Right-hand side too short
        let r = tridiag_solve_ndarray(&off, &d, &off, &array![1.0f64, 0.0]);
        assert!(matches!(r, Err(LapackError::DimensionMismatch(_))));

        // SPD: off-diagonal too long (3 entries for a 3×3 matrix)
        let r = tridiag_solve_spd_ndarray(&d, &b, &b);
        assert!(matches!(r, Err(LapackError::DimensionMismatch(_))));

        // Multiple right-hand sides: wrong number of rows
        let b2 = array![[1.0f64, 0.0], [0.0, 1.0]];
        let r = tridiag_solve_multiple_ndarray(&off, &d, &off, &b2);
        assert!(matches!(r, Err(LapackError::DimensionMismatch(_))));
    }

    // =========================================================================
    // Low-Rank Approximation Tests
    // =========================================================================

    #[test]
    fn test_low_rank_approx() {
        // Create a rank-1 matrix: outer product of two vectors
        let u = array![1.0f64, 2.0, 3.0];
        let v = array![4.0, 5.0, 6.0, 7.0];

        let mut a = Array2::zeros((3, 4));
        for i in 0..3 {
            for j in 0..4 {
                a[[i, j]] = u[i] * v[j];
            }
        }

        // Rank-1 approximation should be exact
        let approx = low_rank_approx_ndarray(&a, 1).unwrap();

        assert_eq!(approx.dim(), a.dim());

        for i in 0..3 {
            for j in 0..4 {
                assert!(
                    (approx[[i, j]] - a[[i, j]]).abs() < 1e-10,
                    "Approximation failed at [{},{}]",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_low_rank_approx_truncation() {
        // Rows form an arithmetic progression, so this 4×3 matrix has rank 2.
        let a = array![
            [1.0f64, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let approx = low_rank_approx_ndarray(&a, 2).unwrap();

        assert_eq!(approx.dim(), (4, 3));

        // Since rank(A) = 2, the rank-2 approximation should reproduce A up to
        // rounding; the bound below is intentionally loose.
        let mut diff_norm = 0.0f64;
        let mut orig_norm = 0.0f64;
        for i in 0..4 {
            for j in 0..3 {
                diff_norm += (a[[i, j]] - approx[[i, j]]).powi(2);
                orig_norm += a[[i, j]].powi(2);
            }
        }

        // Relative Frobenius-norm error
        let rel_error = diff_norm.sqrt() / orig_norm.sqrt();
        assert!(rel_error < 0.1, "Relative error = {}", rel_error);
    }
}