scirs2_sparse/linalg/decomposition.rs

//! Matrix decomposition algorithms for sparse matrices
//!
//! This module provides matrix decomposition algorithms for sparse matrices,
//! including LU, QR, Cholesky (plain and pivoted), LDLT, and incomplete LU /
//! incomplete Cholesky variants.
5
6use crate::csr_array::CsrArray;
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::ops::{Add, Div, Mul, Sub};
15
16/// LU decomposition result
17#[derive(Debug, Clone)]
18pub struct LUResult<T>
19where
20    T: Float + SparseElement + Debug + Copy + 'static,
21{
22    /// Lower triangular factor
23    pub l: CsrArray<T>,
24    /// Upper triangular factor
25    pub u: CsrArray<T>,
26    /// Permutation matrix (as permutation vector)
27    pub p: Array1<usize>,
28    /// Whether decomposition was successful
29    pub success: bool,
30}
31
32/// QR decomposition result
33#[derive(Debug, Clone)]
34pub struct QRResult<T>
35where
36    T: Float + SparseElement + Debug + Copy + 'static,
37{
38    /// Orthogonal factor Q
39    pub q: CsrArray<T>,
40    /// Upper triangular factor R
41    pub r: CsrArray<T>,
42    /// Whether decomposition was successful
43    pub success: bool,
44}
45
46/// Cholesky decomposition result
47#[derive(Debug, Clone)]
48pub struct CholeskyResult<T>
49where
50    T: Float + SparseElement + Debug + Copy + 'static,
51{
52    /// Lower triangular Cholesky factor
53    pub l: CsrArray<T>,
54    /// Whether decomposition was successful
55    pub success: bool,
56}
57
58/// Pivoted Cholesky decomposition result
59#[derive(Debug, Clone)]
60pub struct PivotedCholeskyResult<T>
61where
62    T: Float + SparseElement + Debug + Copy + 'static,
63{
64    /// Lower triangular Cholesky factor
65    pub l: CsrArray<T>,
66    /// Permutation matrix (as permutation vector)
67    pub p: Array1<usize>,
68    /// Rank of the decomposition (number of positive eigenvalues)
69    pub rank: usize,
70    /// Whether decomposition was successful
71    pub success: bool,
72}
73
/// How pivots are chosen during LU factorization.
///
/// The default is [`PivotingStrategy::Partial`].
#[derive(Clone, Debug, Default)]
pub enum PivotingStrategy {
    /// Take pivots in order without any search. Fastest, but can be unstable.
    None,
    /// Pick the largest-magnitude entry in the current column.
    #[default]
    Partial,
    /// Partial pivoting governed by the supplied threshold value.
    Threshold(f64),
    /// Partial pivoting with each row's magnitude scaled by its largest entry.
    ScaledPartial,
    /// Search the whole remaining submatrix. Most stable, most expensive.
    Complete,
    /// Alternate row and column searches. A compromise between `Partial` and
    /// `Complete`.
    Rook,
}
91
92/// Options for LU decomposition
93#[derive(Debug, Clone)]
94pub struct LUOptions {
95    /// Pivoting strategy to use
96    pub pivoting: PivotingStrategy,
97    /// Threshold for numerical zero (default: 1e-14)
98    pub zero_threshold: f64,
99    /// Whether to check for singularity (default: true)
100    pub check_singular: bool,
101}
102
103impl Default for LUOptions {
104    fn default() -> Self {
105        Self {
106            pivoting: PivotingStrategy::default(),
107            zero_threshold: 1e-14,
108            check_singular: true,
109        }
110    }
111}
112
113/// Options for incomplete LU decomposition
114#[derive(Debug, Clone)]
115pub struct ILUOptions {
116    /// Drop tolerance for numerical stability
117    pub drop_tol: f64,
118    /// Fill factor (maximum fill-in ratio)
119    pub fill_factor: f64,
120    /// Maximum number of fill-in entries per row
121    pub max_fill_per_row: usize,
122    /// Pivoting strategy to use
123    pub pivoting: PivotingStrategy,
124}
125
126impl Default for ILUOptions {
127    fn default() -> Self {
128        Self {
129            drop_tol: 1e-4,
130            fill_factor: 2.0,
131            max_fill_per_row: 20,
132            pivoting: PivotingStrategy::default(),
133        }
134    }
135}
136
/// Configuration for incomplete Cholesky factorization.
#[derive(Clone, Debug)]
pub struct ICOptions {
    /// Magnitude below which values are dropped.
    pub drop_tol: f64,
    /// Maximum fill-in, expressed as a ratio.
    pub fill_factor: f64,
    /// Maximum number of fill-in entries kept per row.
    pub max_fill_per_row: usize,
}

impl Default for ICOptions {
    /// `drop_tol = 1e-4`, `fill_factor = 2.0`, `max_fill_per_row = 20`.
    fn default() -> Self {
        let drop_tol = 1.0e-4;
        let fill_factor = 2.0;
        let max_fill_per_row = 20;
        ICOptions {
            drop_tol,
            fill_factor,
            max_fill_per_row,
        }
    }
}
157
158/// Compute sparse LU decomposition with partial pivoting (backward compatibility)
159///
160/// Computes the LU decomposition of a sparse matrix A such that P*A = L*U,
161/// where P is a permutation matrix, L is lower triangular, and U is upper triangular.
162///
163/// # Arguments
164///
165/// * `matrix` - The sparse matrix to decompose
166/// * `pivot_threshold` - Pivoting threshold for numerical stability (0.0 to 1.0)
167///
168/// # Returns
169///
170/// LU decomposition result
171///
172/// # Examples
173///
174/// ```
175/// use scirs2_sparse::linalg::lu_decomposition;
176/// use scirs2_sparse::csr_array::CsrArray;
177///
178/// // Create a sparse matrix
179/// let rows = vec![0, 0, 1, 2];
180/// let cols = vec![0, 1, 1, 2];
181/// let data = vec![2.0, 1.0, 3.0, 4.0];
182/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
183///
184/// let lu_result = lu_decomposition(&matrix, 0.1).unwrap();
185/// ```
186#[allow(dead_code)]
187pub fn lu_decomposition<T, S>(_matrix: &S, pivotthreshold: f64) -> SparseResult<LUResult<T>>
188where
189    T: Float
190        + SparseElement
191        + Debug
192        + Copy
193        + Add<Output = T>
194        + Sub<Output = T>
195        + Mul<Output = T>
196        + Div<Output = T>,
197    S: SparseArray<T>,
198{
199    // Use _threshold pivoting for backward compatibility
200    let options = LUOptions {
201        pivoting: PivotingStrategy::Threshold(pivotthreshold),
202        zero_threshold: 1e-14,
203        check_singular: true,
204    };
205
206    lu_decomposition_with_options(_matrix, Some(options))
207}
208
209/// Compute sparse LU decomposition with enhanced pivoting strategies
210///
211/// Computes the LU decomposition of a sparse matrix A such that P*A = L*U,
212/// where P is a permutation matrix, L is lower triangular, and U is upper triangular.
213/// This version supports multiple pivoting strategies for enhanced numerical stability.
214///
215/// # Arguments
216///
217/// * `matrix` - The sparse matrix to decompose
218/// * `options` - LU decomposition options (pivoting strategy, thresholds, etc.)
219///
220/// # Returns
221///
222/// LU decomposition result
223///
224/// # Examples
225///
226/// ```
227/// use scirs2_sparse::linalg::{lu_decomposition_with_options, LUOptions, PivotingStrategy};
228/// use scirs2_sparse::csr_array::CsrArray;
229///
230/// // Create a sparse matrix
231/// let rows = vec![0, 0, 1, 2];
232/// let cols = vec![0, 1, 1, 2];
233/// let data = vec![2.0, 1.0, 3.0, 4.0];
234/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
235///
236/// let options = LUOptions {
237///     pivoting: PivotingStrategy::ScaledPartial,
238///     zero_threshold: 1e-12,
239///     check_singular: true,
240/// };
241/// let lu_result = lu_decomposition_with_options(&matrix, Some(options)).unwrap();
242/// ```
243#[allow(dead_code)]
244pub fn lu_decomposition_with_options<T, S>(
245    matrix: &S,
246    options: Option<LUOptions>,
247) -> SparseResult<LUResult<T>>
248where
249    T: Float
250        + SparseElement
251        + Debug
252        + Copy
253        + Add<Output = T>
254        + Sub<Output = T>
255        + Mul<Output = T>
256        + Div<Output = T>,
257    S: SparseArray<T>,
258{
259    let opts = options.unwrap_or_default();
260    let (n, m) = matrix.shape();
261    if n != m {
262        return Err(SparseError::ValueError(
263            "Matrix must be square for LU decomposition".to_string(),
264        ));
265    }
266
267    // Convert to working format
268    let (row_indices, col_indices, values) = matrix.find();
269    let mut working_matrix = SparseWorkingMatrix::from_triplets(
270        row_indices.as_slice().unwrap(),
271        col_indices.as_slice().unwrap(),
272        values.as_slice().unwrap(),
273        n,
274    );
275
276    // Initialize permutations
277    let mut row_perm: Vec<usize> = (0..n).collect();
278    let mut col_perm: Vec<usize> = (0..n).collect();
279
280    // Compute row scaling factors for scaled partial pivoting
281    let mut row_scales = vec![T::sparse_one(); n];
282    if matches!(opts.pivoting, PivotingStrategy::ScaledPartial) {
283        for (i, scale) in row_scales.iter_mut().enumerate().take(n) {
284            let row_data = working_matrix.get_row(i);
285            let max_val = row_data
286                .values()
287                .map(|&v| v.abs())
288                .fold(T::sparse_zero(), |a, b| if a > b { a } else { b });
289            if max_val > T::sparse_zero() {
290                *scale = max_val;
291            }
292        }
293    }
294
295    // Gaussian elimination with enhanced pivoting
296    for k in 0..n - 1 {
297        // Find pivot using selected strategy
298        let (pivot_row, pivot_col) =
299            find_enhanced_pivot(&working_matrix, k, &row_perm, &col_perm, &row_scales, &opts)?;
300
301        // Apply row and column permutations
302        if pivot_row != k {
303            row_perm.swap(k, pivot_row);
304        }
305        if pivot_col != k
306            && matches!(
307                opts.pivoting,
308                PivotingStrategy::Complete | PivotingStrategy::Rook
309            )
310        {
311            col_perm.swap(k, pivot_col);
312            // When columns are swapped, we need to update all matrix elements
313            for &row_idx in row_perm.iter().take(n) {
314                let temp = working_matrix.get(row_idx, k);
315                working_matrix.set(row_idx, k, working_matrix.get(row_idx, pivot_col));
316                working_matrix.set(row_idx, pivot_col, temp);
317            }
318        }
319
320        let actual_pivot_row = row_perm[k];
321        let actual_pivot_col = col_perm[k];
322        let pivot_value = working_matrix.get(actual_pivot_row, actual_pivot_col);
323
324        // Check for numerical singularity
325        if opts.check_singular && pivot_value.abs() < T::from(opts.zero_threshold).unwrap() {
326            return Ok(LUResult {
327                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
328                u: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
329                p: Array1::from_vec(row_perm),
330                success: false,
331            });
332        }
333
334        // Eliminate below pivot
335        for &actual_row_i in row_perm.iter().take(n).skip(k + 1) {
336            let factor = working_matrix.get(actual_row_i, actual_pivot_col) / pivot_value;
337
338            if !SparseElement::is_zero(&factor) {
339                // Store multiplier in L
340                working_matrix.set(actual_row_i, actual_pivot_col, factor);
341
342                // Update row i
343                let pivot_row_data = working_matrix.get_row(actual_pivot_row);
344                for (col, &value) in &pivot_row_data {
345                    if *col > k {
346                        let old_val = working_matrix.get(actual_row_i, *col);
347                        working_matrix.set(actual_row_i, *col, old_val - factor * value);
348                    }
349                }
350            }
351        }
352    }
353
354    // Extract L and U matrices with proper permutation
355    let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
356        extract_lu_factors(&working_matrix, &row_perm, n);
357
358    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
359    let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
360
361    Ok(LUResult {
362        l,
363        u,
364        p: Array1::from_vec(row_perm),
365        success: true,
366    })
367}
368
369/// Compute sparse QR decomposition using Givens rotations
370///
371/// Computes the QR decomposition of a sparse matrix A = Q*R,
372/// where Q is orthogonal and R is upper triangular.
373///
374/// # Arguments
375///
376/// * `matrix` - The sparse matrix to decompose
377///
378/// # Returns
379///
380/// QR decomposition result
381///
382/// # Examples
383///
384/// ```
385/// use scirs2_sparse::linalg::qr_decomposition;
386/// use scirs2_sparse::csr_array::CsrArray;
387///
388/// let rows = vec![0, 1, 2];
389/// let cols = vec![0, 0, 1];
390/// let data = vec![1.0, 2.0, 3.0];
391/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();
392///
393/// let qr_result = qr_decomposition(&matrix).unwrap();
394/// ```
395#[allow(dead_code)]
396pub fn qr_decomposition<T, S>(matrix: &S) -> SparseResult<QRResult<T>>
397where
398    T: Float
399        + SparseElement
400        + Debug
401        + Copy
402        + Add<Output = T>
403        + Sub<Output = T>
404        + Mul<Output = T>
405        + Div<Output = T>,
406    S: SparseArray<T>,
407{
408    let (m, n) = matrix.shape();
409
410    // Convert to dense for QR (sparse QR is complex)
411    let dense_matrix = matrix.to_array();
412
413    // Simple Gram-Schmidt QR decomposition
414    let mut q = Array2::zeros((m, n));
415    let mut r = Array2::zeros((n, n));
416
417    for j in 0..n {
418        // Copy column j
419        for i in 0..m {
420            q[[i, j]] = dense_matrix[[i, j]];
421        }
422
423        // Orthogonalize against previous columns
424        for k in 0..j {
425            let mut dot = T::sparse_zero();
426            for i in 0..m {
427                dot = dot + q[[i, k]] * dense_matrix[[i, j]];
428            }
429            r[[k, j]] = dot;
430
431            for i in 0..m {
432                q[[i, j]] = q[[i, j]] - dot * q[[i, k]];
433            }
434        }
435
436        // Normalize
437        let mut norm = T::sparse_zero();
438        for i in 0..m {
439            norm = norm + q[[i, j]] * q[[i, j]];
440        }
441        norm = norm.sqrt();
442        r[[j, j]] = norm;
443
444        if !SparseElement::is_zero(&norm) {
445            for i in 0..m {
446                q[[i, j]] = q[[i, j]] / norm;
447            }
448        }
449    }
450
451    // Convert back to sparse
452    let q_sparse = dense_to_sparse(&q)?;
453    let r_sparse = dense_to_sparse(&r)?;
454
455    Ok(QRResult {
456        q: q_sparse,
457        r: r_sparse,
458        success: true,
459    })
460}
461
462/// Compute sparse Cholesky decomposition
463///
464/// Computes the Cholesky decomposition of a symmetric positive definite matrix A = L*L^T,
465/// where L is lower triangular.
466///
467/// # Arguments
468///
469/// * `matrix` - The symmetric positive definite sparse matrix
470///
471/// # Returns
472///
473/// Cholesky decomposition result
474///
475/// # Examples
476///
477/// ```
478/// use scirs2_sparse::linalg::cholesky_decomposition;
479/// use scirs2_sparse::csr_array::CsrArray;
480///
481/// // Create a simple SPD matrix
482/// let rows = vec![0, 1, 1, 2, 2, 2];
483/// let cols = vec![0, 0, 1, 0, 1, 2];
484/// let data = vec![4.0, 2.0, 5.0, 1.0, 3.0, 6.0];
485/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
486///
487/// let chol_result = cholesky_decomposition(&matrix).unwrap();
488/// ```
489#[allow(dead_code)]
490pub fn cholesky_decomposition<T, S>(matrix: &S) -> SparseResult<CholeskyResult<T>>
491where
492    T: Float
493        + SparseElement
494        + Debug
495        + Copy
496        + Add<Output = T>
497        + Sub<Output = T>
498        + Mul<Output = T>
499        + Div<Output = T>,
500    S: SparseArray<T>,
501{
502    let (n, m) = matrix.shape();
503    if n != m {
504        return Err(SparseError::ValueError(
505            "Matrix must be square for Cholesky decomposition".to_string(),
506        ));
507    }
508
509    // Convert to working format
510    let (row_indices, col_indices, values) = matrix.find();
511    let mut working_matrix = SparseWorkingMatrix::from_triplets(
512        row_indices.as_slice().unwrap(),
513        col_indices.as_slice().unwrap(),
514        values.as_slice().unwrap(),
515        n,
516    );
517
518    // Cholesky decomposition algorithm
519    for k in 0..n {
520        // Compute diagonal element
521        let mut sum = T::sparse_zero();
522        for j in 0..k {
523            let l_kj = working_matrix.get(k, j);
524            sum = sum + l_kj * l_kj;
525        }
526
527        let a_kk = working_matrix.get(k, k);
528        let diag_val = a_kk - sum;
529
530        if diag_val <= T::sparse_zero() {
531            return Ok(CholeskyResult {
532                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
533                success: false,
534            });
535        }
536
537        let l_kk = diag_val.sqrt();
538        working_matrix.set(k, k, l_kk);
539
540        // Compute below-diagonal elements
541        for i in (k + 1)..n {
542            let mut sum = T::sparse_zero();
543            for j in 0..k {
544                sum = sum + working_matrix.get(i, j) * working_matrix.get(k, j);
545            }
546
547            let a_ik = working_matrix.get(i, k);
548            let l_ik = (a_ik - sum) / l_kk;
549            working_matrix.set(i, k, l_ik);
550        }
551    }
552
553    // Extract lower triangular _matrix
554    let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
555    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
556
557    Ok(CholeskyResult { l, success: true })
558}
559
560/// Compute pivoted Cholesky decomposition
561///
562/// Computes the pivoted Cholesky decomposition of a symmetric matrix A = P^T * L * L^T * P,
563/// where P is a permutation matrix and L is lower triangular. This version can handle
564/// indefinite matrices by determining the rank and producing a partial decomposition.
565///
566/// # Arguments
567///
568/// * `matrix` - The symmetric sparse matrix
569/// * `threshold` - Pivoting threshold for numerical stability (default: 1e-12)
570///
571/// # Returns
572///
573/// Pivoted Cholesky decomposition result with rank determination
574///
575/// # Examples
576///
577/// ```
578/// use scirs2_sparse::linalg::pivoted_cholesky_decomposition;
579/// use scirs2_sparse::csr_array::CsrArray;
580///
581/// // Create a symmetric indefinite matrix
582/// let rows = vec![0, 1, 1, 2, 2, 2];
583/// let cols = vec![0, 0, 1, 0, 1, 2];  
584/// let data = vec![1.0, 2.0, -1.0, 3.0, 1.0, 2.0];
585/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
586///
587/// let chol_result = pivoted_cholesky_decomposition(&matrix, Some(1e-12)).unwrap();
588/// ```
589#[allow(dead_code)]
590pub fn pivoted_cholesky_decomposition<T, S>(
591    matrix: &S,
592    threshold: Option<T>,
593) -> SparseResult<PivotedCholeskyResult<T>>
594where
595    T: Float
596        + SparseElement
597        + Debug
598        + Copy
599        + Add<Output = T>
600        + Sub<Output = T>
601        + Mul<Output = T>
602        + Div<Output = T>,
603    S: SparseArray<T>,
604{
605    let (n, m) = matrix.shape();
606    if n != m {
607        return Err(SparseError::ValueError(
608            "Matrix must be square for Cholesky decomposition".to_string(),
609        ));
610    }
611
612    let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());
613
614    // Convert to working format
615    let (row_indices, col_indices, values) = matrix.find();
616    let mut working_matrix = SparseWorkingMatrix::from_triplets(
617        row_indices.as_slice().unwrap(),
618        col_indices.as_slice().unwrap(),
619        values.as_slice().unwrap(),
620        n,
621    );
622
623    // Initialize permutation
624    let mut perm: Vec<usize> = (0..n).collect();
625    let mut rank = 0;
626
627    // Pivoted Cholesky algorithm
628    for k in 0..n {
629        // Find the pivot: largest diagonal element among remaining
630        let mut max_diag = T::sparse_zero();
631        let mut pivot_idx = k;
632
633        for i in k..n {
634            let mut diag_val = working_matrix.get(perm[i], perm[i]);
635            for j in 0..k {
636                let l_ij = working_matrix.get(perm[i], perm[j]);
637                diag_val = diag_val - l_ij * l_ij;
638            }
639            if diag_val > max_diag {
640                max_diag = diag_val;
641                pivot_idx = i;
642            }
643        }
644
645        // Check if we should stop (matrix is not positive definite beyond this point)
646        if max_diag <= threshold {
647            break;
648        }
649
650        // Swap rows/columns in permutation
651        if pivot_idx != k {
652            perm.swap(k, pivot_idx);
653        }
654
655        // Compute L[k,k]
656        let l_kk = max_diag.sqrt();
657        working_matrix.set(perm[k], perm[k], l_kk);
658        rank += 1;
659
660        // Update column k below diagonal
661        for i in (k + 1)..n {
662            let mut sum = T::sparse_zero();
663            for j in 0..k {
664                sum = sum
665                    + working_matrix.get(perm[i], perm[j]) * working_matrix.get(perm[k], perm[j]);
666            }
667
668            let a_ik = working_matrix.get(perm[i], perm[k]);
669            let l_ik = (a_ik - sum) / l_kk;
670            working_matrix.set(perm[i], perm[k], l_ik);
671        }
672    }
673
674    // Extract lower triangular matrix with proper permutation
675    let mut l_rows = Vec::new();
676    let mut l_cols = Vec::new();
677    let mut l_vals = Vec::new();
678
679    for i in 0..rank {
680        for j in 0..=i {
681            let val = working_matrix.get(perm[i], perm[j]);
682            if val != T::sparse_zero() {
683                l_rows.push(i);
684                l_cols.push(j);
685                l_vals.push(val);
686            }
687        }
688    }
689
690    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, rank), false)?;
691    let p = Array1::from_vec(perm);
692
693    Ok(PivotedCholeskyResult {
694        l,
695        p,
696        rank,
697        success: true,
698    })
699}
700
701/// LDLT decomposition result for symmetric indefinite matrices
702#[derive(Debug, Clone)]
703pub struct LDLTResult<T>
704where
705    T: Float + SparseElement + Debug + Copy + 'static,
706{
707    /// Lower triangular factor L (unit diagonal)
708    pub l: CsrArray<T>,
709    /// Diagonal factor D
710    pub d: Array1<T>,
711    /// Permutation matrix (as permutation vector)
712    pub p: Array1<usize>,
713    /// Whether decomposition was successful
714    pub success: bool,
715}
716
717/// Compute LDLT decomposition for symmetric indefinite matrices
718///
719/// Computes the LDLT decomposition of a symmetric matrix A = P^T * L * D * L^T * P,
720/// where P is a permutation matrix, L is unit lower triangular, and D is diagonal.
721/// This method can handle indefinite matrices unlike Cholesky decomposition.
722///
723/// # Arguments
724///
725/// * `matrix` - The symmetric sparse matrix
726/// * `pivoting` - Whether to use pivoting for numerical stability (default: true)
727/// * `threshold` - Pivoting threshold for numerical stability (default: 1e-12)
728///
729/// # Returns
730///
731/// LDLT decomposition result
732///
733/// # Examples
734///
735/// ```
736/// use scirs2_sparse::linalg::ldlt_decomposition;
737/// use scirs2_sparse::csr_array::CsrArray;
738///
739/// // Create a symmetric indefinite matrix
740/// let rows = vec![0, 1, 1, 2, 2, 2];
741/// let cols = vec![0, 0, 1, 0, 1, 2];  
742/// let data = vec![1.0, 2.0, -1.0, 3.0, 1.0, 2.0];
743/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
744///
745/// let ldlt_result = ldlt_decomposition(&matrix, Some(true), Some(1e-12)).unwrap();
746/// ```
747#[allow(dead_code)]
748pub fn ldlt_decomposition<T, S>(
749    matrix: &S,
750    pivoting: Option<bool>,
751    threshold: Option<T>,
752) -> SparseResult<LDLTResult<T>>
753where
754    T: Float
755        + SparseElement
756        + Debug
757        + Copy
758        + Add<Output = T>
759        + Sub<Output = T>
760        + Mul<Output = T>
761        + Div<Output = T>,
762    S: SparseArray<T>,
763{
764    let (n, m) = matrix.shape();
765    if n != m {
766        return Err(SparseError::ValueError(
767            "Matrix must be square for LDLT decomposition".to_string(),
768        ));
769    }
770
771    let use_pivoting = pivoting.unwrap_or(true);
772    let threshold = threshold.unwrap_or_else(|| T::from(1e-12).unwrap());
773
774    // Convert to working format
775    let (row_indices, col_indices, values) = matrix.find();
776    let mut working_matrix = SparseWorkingMatrix::from_triplets(
777        row_indices.as_slice().unwrap(),
778        col_indices.as_slice().unwrap(),
779        values.as_slice().unwrap(),
780        n,
781    );
782
783    // Initialize permutation
784    let mut perm: Vec<usize> = (0..n).collect();
785    let mut d_values = vec![T::sparse_zero(); n];
786
787    // LDLT decomposition with optional pivoting
788    for k in 0..n {
789        // Find pivot if pivoting is enabled
790        if use_pivoting {
791            let pivot_idx = find_ldlt_pivot(&working_matrix, k, &perm, threshold);
792            if pivot_idx != k {
793                perm.swap(k, pivot_idx);
794            }
795        }
796
797        let actual_k = perm[k];
798
799        // Compute diagonal element D[k,k]
800        let mut diag_val = working_matrix.get(actual_k, actual_k);
801        for j in 0..k {
802            let l_kj = working_matrix.get(actual_k, perm[j]);
803            diag_val = diag_val - l_kj * l_kj * d_values[j];
804        }
805
806        d_values[k] = diag_val;
807
808        // Check for numerical issues
809        if diag_val.abs() < threshold {
810            return Ok(LDLTResult {
811                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
812                d: Array1::from_vec(d_values),
813                p: Array1::from_vec(perm),
814                success: false,
815            });
816        }
817
818        // Compute column k of L below the diagonal
819        for i in (k + 1)..n {
820            let actual_i = perm[i];
821            let mut l_ik = working_matrix.get(actual_i, actual_k);
822
823            for j in 0..k {
824                l_ik = l_ik
825                    - working_matrix.get(actual_i, perm[j])
826                        * working_matrix.get(actual_k, perm[j])
827                        * d_values[j];
828            }
829
830            l_ik = l_ik / diag_val;
831            working_matrix.set(actual_i, actual_k, l_ik);
832        }
833
834        // Set diagonal element of L to 1
835        working_matrix.set(actual_k, actual_k, T::sparse_one());
836    }
837
838    // Extract L matrix (unit lower triangular)
839    let (l_rows, l_cols, l_vals) = extract_unit_lower_triangular(&working_matrix, &perm, n);
840    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
841
842    Ok(LDLTResult {
843        l,
844        d: Array1::from_vec(d_values),
845        p: Array1::from_vec(perm),
846        success: true,
847    })
848}
849
850/// Find pivot for LDLT decomposition using Bunch-Kaufman strategy
851#[allow(dead_code)]
852fn find_ldlt_pivot<T>(
853    matrix: &SparseWorkingMatrix<T>,
854    k: usize,
855    perm: &[usize],
856    threshold: T,
857) -> usize
858where
859    T: Float + SparseElement + Debug + Copy,
860{
861    let n = matrix.n;
862    let mut max_val = T::sparse_zero();
863    let mut pivot_idx = k;
864
865    // Look for largest diagonal element among remaining rows
866    for (i, &actual_i) in perm.iter().enumerate().take(n).skip(k) {
867        let diag_val = matrix.get(actual_i, actual_i).abs();
868
869        if diag_val > max_val {
870            max_val = diag_val;
871            pivot_idx = i;
872        }
873    }
874
875    // Check if pivot is acceptable
876    if max_val >= threshold {
877        pivot_idx
878    } else {
879        k // Use current position if no good pivot found
880    }
881}
882
883/// Extract unit lower triangular matrix from working matrix
884#[allow(dead_code)]
885fn extract_unit_lower_triangular<T>(
886    matrix: &SparseWorkingMatrix<T>,
887    perm: &[usize],
888    n: usize,
889) -> (Vec<usize>, Vec<usize>, Vec<T>)
890where
891    T: Float + SparseElement + Debug + Copy,
892{
893    let mut rows = Vec::new();
894    let mut cols = Vec::new();
895    let mut vals = Vec::new();
896
897    for i in 0..n {
898        let actual_i = perm[i];
899
900        // Add diagonal element (always 1 for unit triangular)
901        rows.push(i);
902        cols.push(i);
903        vals.push(T::sparse_one());
904
905        // Add below-diagonal elements
906        for (j, &perm_j) in perm.iter().enumerate().take(i) {
907            let val = matrix.get(actual_i, perm_j);
908            if val != T::sparse_zero() {
909                rows.push(i);
910                cols.push(j);
911                vals.push(val);
912            }
913        }
914    }
915
916    (rows, cols, vals)
917}
918
919/// Compute incomplete LU decomposition (ILU)
920///
921/// Computes an approximate LU decomposition with controlled fill-in
922/// for use as a preconditioner in iterative methods.
923///
924/// # Arguments
925///
926/// * `matrix` - The sparse matrix to decompose
927/// * `options` - ILU options controlling fill-in and dropping
928///
929/// # Returns
930///
931/// Incomplete LU decomposition result
932#[allow(dead_code)]
933pub fn incomplete_lu<T, S>(matrix: &S, options: Option<ILUOptions>) -> SparseResult<LUResult<T>>
934where
935    T: Float
936        + SparseElement
937        + Debug
938        + Copy
939        + Add<Output = T>
940        + Sub<Output = T>
941        + Mul<Output = T>
942        + Div<Output = T>,
943    S: SparseArray<T>,
944{
945    let opts = options.unwrap_or_default();
946    let (n, m) = matrix.shape();
947
948    if n != m {
949        return Err(SparseError::ValueError(
950            "Matrix must be square for ILU decomposition".to_string(),
951        ));
952    }
953
954    // Convert to working format
955    let (row_indices, col_indices, values) = matrix.find();
956    let mut working_matrix = SparseWorkingMatrix::from_triplets(
957        row_indices.as_slice().unwrap(),
958        col_indices.as_slice().unwrap(),
959        values.as_slice().unwrap(),
960        n,
961    );
962
963    // ILU(0) algorithm - no fill-in beyond original sparsity pattern
964    for k in 0..n - 1 {
965        let pivot_val = working_matrix.get(k, k);
966
967        if pivot_val.abs() < T::from(1e-14).unwrap() {
968            continue; // Skip singular pivot
969        }
970
971        // Get all non-zero entries in column k below diagonal
972        let col_k_entries = working_matrix.get_column_below_diagonal(k);
973
974        for &row_i in &col_k_entries {
975            let factor = working_matrix.get(row_i, k) / pivot_val;
976
977            // Drop small factors
978            if factor.abs() < T::from(opts.drop_tol).unwrap() {
979                working_matrix.set(row_i, k, T::sparse_zero());
980                continue;
981            }
982
983            working_matrix.set(row_i, k, factor);
984
985            // Update row i (only existing non-zeros)
986            let row_k_entries = working_matrix.get_row_after_column(k, k);
987            for (col_j, &val_kj) in &row_k_entries {
988                if working_matrix.has_entry(row_i, *col_j) {
989                    let old_val = working_matrix.get(row_i, *col_j);
990                    let new_val = old_val - factor * val_kj;
991
992                    // Drop small values
993                    if new_val.abs() < T::from(opts.drop_tol).unwrap() {
994                        working_matrix.set(row_i, *col_j, T::sparse_zero());
995                    } else {
996                        working_matrix.set(row_i, *col_j, new_val);
997                    }
998                }
999            }
1000        }
1001    }
1002
1003    // Extract L and U factors
1004    let identity_p: Vec<usize> = (0..n).collect();
1005    let (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals) =
1006        extract_lu_factors(&working_matrix, &identity_p, n);
1007
1008    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1009    let u = CsrArray::from_triplets(&u_rows, &u_cols, &u_vals, (n, n), false)?;
1010
1011    Ok(LUResult {
1012        l,
1013        u,
1014        p: Array1::from_vec(identity_p),
1015        success: true,
1016    })
1017}
1018
1019/// Compute incomplete Cholesky decomposition (IC)
1020///
1021/// Computes an approximate Cholesky decomposition with controlled fill-in
1022/// for use as a preconditioner in iterative methods.
1023///
1024/// # Arguments
1025///
1026/// * `matrix` - The symmetric positive definite sparse matrix
1027/// * `options` - IC options controlling fill-in and dropping
1028///
1029/// # Returns
1030///
1031/// Incomplete Cholesky decomposition result
1032#[allow(dead_code)]
1033pub fn incomplete_cholesky<T, S>(
1034    matrix: &S,
1035    options: Option<ICOptions>,
1036) -> SparseResult<CholeskyResult<T>>
1037where
1038    T: Float
1039        + SparseElement
1040        + Debug
1041        + Copy
1042        + Add<Output = T>
1043        + Sub<Output = T>
1044        + Mul<Output = T>
1045        + Div<Output = T>,
1046    S: SparseArray<T>,
1047{
1048    let opts = options.unwrap_or_default();
1049    let (n, m) = matrix.shape();
1050
1051    if n != m {
1052        return Err(SparseError::ValueError(
1053            "Matrix must be square for IC decomposition".to_string(),
1054        ));
1055    }
1056
1057    // Convert to working format
1058    let (row_indices, col_indices, values) = matrix.find();
1059    let mut working_matrix = SparseWorkingMatrix::from_triplets(
1060        row_indices.as_slice().unwrap(),
1061        col_indices.as_slice().unwrap(),
1062        values.as_slice().unwrap(),
1063        n,
1064    );
1065
1066    // IC(0) algorithm - no fill-in beyond original sparsity pattern
1067    for k in 0..n {
1068        // Compute diagonal element
1069        let mut sum = T::sparse_zero();
1070        let row_k_before_k = working_matrix.get_row_before_column(k, k);
1071        for &val_kj in row_k_before_k.values() {
1072            sum = sum + val_kj * val_kj;
1073        }
1074
1075        let a_kk = working_matrix.get(k, k);
1076        let diag_val = a_kk - sum;
1077
1078        if diag_val <= T::sparse_zero() {
1079            return Ok(CholeskyResult {
1080                l: CsrArray::from_triplets(&[], &[], &[], (n, n), false)?,
1081                success: false,
1082            });
1083        }
1084
1085        let l_kk = diag_val.sqrt();
1086        working_matrix.set(k, k, l_kk);
1087
1088        // Compute below-diagonal elements (only existing entries)
1089        let col_k_below = working_matrix.get_column_below_diagonal(k);
1090        for &row_i in &col_k_below {
1091            let mut sum = T::sparse_zero();
1092            let row_i_before_k = working_matrix.get_row_before_column(row_i, k);
1093            let row_k_before_k = working_matrix.get_row_before_column(k, k);
1094
1095            // Compute dot product of L[i, :k] and L[k, :k]
1096            for (col_j, &val_ij) in &row_i_before_k {
1097                if let Some(&val_kj) = row_k_before_k.get(col_j) {
1098                    sum = sum + val_ij * val_kj;
1099                }
1100            }
1101
1102            let a_ik = working_matrix.get(row_i, k);
1103            let l_ik = (a_ik - sum) / l_kk;
1104
1105            // Drop small values
1106            if l_ik.abs() < T::from(opts.drop_tol).unwrap() {
1107                working_matrix.set(row_i, k, T::sparse_zero());
1108            } else {
1109                working_matrix.set(row_i, k, l_ik);
1110            }
1111        }
1112    }
1113
1114    // Extract lower triangular matrix
1115    let (l_rows, l_cols, l_vals) = extract_lower_triangular(&working_matrix, n);
1116    let l = CsrArray::from_triplets(&l_rows, &l_cols, &l_vals, (n, n), false)?;
1117
1118    Ok(CholeskyResult { l, success: true })
1119}
1120
/// Simple sparse working matrix for decomposition algorithms
///
/// Stores a square matrix as a map from `(row, col)` to value; coordinates
/// absent from `data` are implicitly zero. Trait bounds live on the `impl`
/// blocks that need them rather than on the type itself.
struct SparseWorkingMatrix<T> {
    /// Stored entries keyed by `(row, col)`.
    data: HashMap<(usize, usize), T>,
    /// Order of the (square) matrix: number of rows and of columns.
    n: usize,
}
1129
1130impl<T> SparseWorkingMatrix<T>
1131where
1132    T: Float
1133        + SparseElement
1134        + Debug
1135        + Copy
1136        + Add<Output = T>
1137        + Sub<Output = T>
1138        + Mul<Output = T>
1139        + Div<Output = T>,
1140{
1141    fn from_triplets(rows: &[usize], cols: &[usize], values: &[T], n: usize) -> Self {
1142        let mut data = HashMap::new();
1143
1144        for (i, (&row, &col)) in rows.iter().zip(cols.iter()).enumerate() {
1145            data.insert((row, col), values[i]);
1146        }
1147
1148        Self { data, n }
1149    }
1150
1151    fn get(&self, row: usize, col: usize) -> T {
1152        self.data
1153            .get(&(row, col))
1154            .copied()
1155            .unwrap_or(T::sparse_zero())
1156    }
1157
1158    fn set(&mut self, row: usize, col: usize, value: T) {
1159        if SparseElement::is_zero(&value) {
1160            self.data.remove(&(row, col));
1161        } else {
1162            self.data.insert((row, col), value);
1163        }
1164    }
1165
1166    fn has_entry(&self, row: usize, col: usize) -> bool {
1167        self.data.contains_key(&(row, col))
1168    }
1169
1170    fn get_row(&self, row: usize) -> HashMap<usize, T> {
1171        let mut result = HashMap::new();
1172        for (&(r, c), &value) in &self.data {
1173            if r == row {
1174                result.insert(c, value);
1175            }
1176        }
1177        result
1178    }
1179
1180    fn get_row_after_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1181        let mut result = HashMap::new();
1182        for (&(r, c), &value) in &self.data {
1183            if r == row && c > col {
1184                result.insert(c, value);
1185            }
1186        }
1187        result
1188    }
1189
1190    fn get_row_before_column(&self, row: usize, col: usize) -> HashMap<usize, T> {
1191        let mut result = HashMap::new();
1192        for (&(r, c), &value) in &self.data {
1193            if r == row && c < col {
1194                result.insert(c, value);
1195            }
1196        }
1197        result
1198    }
1199
1200    fn get_column_below_diagonal(&self, col: usize) -> Vec<usize> {
1201        let mut result = Vec::new();
1202        for &(r, c) in self.data.keys() {
1203            if c == col && r > col {
1204                result.push(r);
1205            }
1206        }
1207        result.sort();
1208        result
1209    }
1210}
1211
1212/// Find pivot for LU decomposition (backward compatibility)
1213#[allow(dead_code)]
1214fn find_pivot<T>(
1215    matrix: &SparseWorkingMatrix<T>,
1216    k: usize,
1217    p: &[usize],
1218    threshold: f64,
1219) -> SparseResult<usize>
1220where
1221    T: Float + SparseElement + Debug + Copy,
1222{
1223    // Use threshold pivoting for backward compatibility
1224    let opts = LUOptions {
1225        pivoting: PivotingStrategy::Threshold(threshold),
1226        zero_threshold: 1e-14,
1227        check_singular: true,
1228    };
1229
1230    let row_scales = vec![T::sparse_one(); matrix.n];
1231    let col_perm: Vec<usize> = (0..matrix.n).collect();
1232
1233    let (pivot_row, pivot_col) = find_enhanced_pivot(matrix, k, p, &col_perm, &row_scales, &opts)?;
1234    Ok(pivot_row)
1235}
1236
1237/// Enhanced pivoting function supporting multiple strategies
1238#[allow(dead_code)]
1239fn find_enhanced_pivot<T>(
1240    matrix: &SparseWorkingMatrix<T>,
1241    k: usize,
1242    row_perm: &[usize],
1243    col_perm: &[usize],
1244    row_scales: &[T],
1245    opts: &LUOptions,
1246) -> SparseResult<(usize, usize)>
1247where
1248    T: Float + SparseElement + Debug + Copy,
1249{
1250    let n = matrix.n;
1251
1252    match &opts.pivoting {
1253        PivotingStrategy::None => {
1254            // No pivoting - use diagonal element
1255            Ok((k, k))
1256        }
1257
1258        PivotingStrategy::Partial => {
1259            // Standard partial pivoting - find largest element in column k
1260            let mut max_val = T::sparse_zero();
1261            let mut pivot_row = k;
1262
1263            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1264                let i = k + idx;
1265                let val = matrix.get(actual_row, col_perm[k]).abs();
1266                if val > max_val {
1267                    max_val = val;
1268                    pivot_row = i;
1269                }
1270            }
1271
1272            Ok((pivot_row, k))
1273        }
1274
1275        PivotingStrategy::Threshold(threshold) => {
1276            // Threshold pivoting - use first element above threshold
1277            let threshold_val = T::from(*threshold).unwrap();
1278            let mut max_val = T::sparse_zero();
1279            let mut pivot_row = k;
1280
1281            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1282                let i = k + idx;
1283                let val = matrix.get(actual_row, col_perm[k]).abs();
1284                if val > max_val {
1285                    max_val = val;
1286                    pivot_row = i;
1287                }
1288                // Use first element above threshold for efficiency
1289                if val >= threshold_val {
1290                    pivot_row = i;
1291                    break;
1292                }
1293            }
1294
1295            Ok((pivot_row, k))
1296        }
1297
1298        PivotingStrategy::ScaledPartial => {
1299            // Scaled partial pivoting - account for row scaling
1300            let mut max_ratio = T::sparse_zero();
1301            let mut pivot_row = k;
1302
1303            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1304                let i = k + idx;
1305                let val = matrix.get(actual_row, col_perm[k]).abs();
1306                let scale = row_scales[actual_row];
1307
1308                let ratio = if scale > T::sparse_zero() {
1309                    val / scale
1310                } else {
1311                    val
1312                };
1313
1314                if ratio > max_ratio {
1315                    max_ratio = ratio;
1316                    pivot_row = i;
1317                }
1318            }
1319
1320            Ok((pivot_row, k))
1321        }
1322
1323        PivotingStrategy::Complete => {
1324            // Complete pivoting - find largest element in remaining submatrix
1325            let mut max_val = T::sparse_zero();
1326            let mut pivot_row = k;
1327            let mut pivot_col = k;
1328
1329            for (i_idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1330                let i = k + i_idx;
1331                for (j_idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1332                    let j = k + j_idx;
1333                    let val = matrix.get(actual_row, actual_col).abs();
1334                    if val > max_val {
1335                        max_val = val;
1336                        pivot_row = i;
1337                        pivot_col = j;
1338                    }
1339                }
1340            }
1341
1342            Ok((pivot_row, pivot_col))
1343        }
1344
1345        PivotingStrategy::Rook => {
1346            // Rook pivoting - alternating row and column searches
1347            let mut best_row = k;
1348            let mut best_col = k;
1349            let mut max_val = T::sparse_zero();
1350
1351            // Start with partial pivoting in column k
1352            for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1353                let i = k + idx;
1354                let val = matrix.get(actual_row, col_perm[k]).abs();
1355                if val > max_val {
1356                    max_val = val;
1357                    best_row = i;
1358                }
1359            }
1360
1361            // If we found a good pivot, check if we can improve by column pivoting
1362            if max_val > T::from(opts.zero_threshold).unwrap() {
1363                let actual_best_row = row_perm[best_row];
1364                let mut col_max = T::sparse_zero();
1365
1366                for (idx, &actual_col) in col_perm.iter().enumerate().skip(k).take(n - k) {
1367                    let j = k + idx;
1368                    let val = matrix.get(actual_best_row, actual_col).abs();
1369                    if val > col_max {
1370                        col_max = val;
1371                        best_col = j;
1372                    }
1373                }
1374
1375                // Use column pivot if it's significantly better
1376                let improvement_threshold = T::from(1.5).unwrap();
1377                if col_max > max_val * improvement_threshold {
1378                    // Recompute row pivot for the new column
1379                    max_val = T::sparse_zero();
1380                    for (idx, &actual_row) in row_perm.iter().enumerate().skip(k).take(n - k) {
1381                        let i = k + idx;
1382                        let val = matrix.get(actual_row, col_perm[best_col]).abs();
1383                        if val > max_val {
1384                            max_val = val;
1385                            best_row = i;
1386                        }
1387                    }
1388                }
1389            }
1390
1391            Ok((best_row, best_col))
1392        }
1393    }
1394}
1395
/// `L` and `U` factors in triplet (COO) form:
/// `(l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)`.
///
/// Each `*_rows` vector holds one row index per stored entry (not CSR row
/// pointers) and is parallel to the matching `*_cols` / `*_vals` vectors.
type LuFactors<T> = (
    Vec<usize>, // L row indices
    Vec<usize>, // L column indices
    Vec<T>,     // L values
    Vec<usize>, // U row indices
    Vec<usize>, // U column indices
    Vec<T>,     // U values
);
1405
1406#[allow(dead_code)]
1407fn extract_lu_factors<T>(matrix: &SparseWorkingMatrix<T>, p: &[usize], n: usize) -> LuFactors<T>
1408where
1409    T: Float + SparseElement + Debug + Copy,
1410{
1411    let mut l_rows = Vec::new();
1412    let mut l_cols = Vec::new();
1413    let mut l_vals = Vec::new();
1414    let mut u_rows = Vec::new();
1415    let mut u_cols = Vec::new();
1416    let mut u_vals = Vec::new();
1417
1418    #[allow(clippy::needless_range_loop)]
1419    for i in 0..n {
1420        let actual_row = p[i];
1421
1422        // Add diagonal 1 to L
1423        l_rows.push(i);
1424        l_cols.push(i);
1425        l_vals.push(T::sparse_one());
1426
1427        for j in 0..n {
1428            let val = matrix.get(actual_row, j);
1429            if !SparseElement::is_zero(&val) {
1430                if j < i {
1431                    // Below diagonal - goes to L
1432                    l_rows.push(i);
1433                    l_cols.push(j);
1434                    l_vals.push(val);
1435                } else {
1436                    // On or above diagonal - goes to U
1437                    u_rows.push(i);
1438                    u_cols.push(j);
1439                    u_vals.push(val);
1440                }
1441            }
1442        }
1443    }
1444
1445    (l_rows, l_cols, l_vals, u_rows, u_cols, u_vals)
1446}
1447
1448/// Extract lower triangular matrix
1449#[allow(dead_code)]
1450fn extract_lower_triangular<T>(
1451    matrix: &SparseWorkingMatrix<T>,
1452    n: usize,
1453) -> (Vec<usize>, Vec<usize>, Vec<T>)
1454where
1455    T: Float + SparseElement + Debug + Copy,
1456{
1457    let mut rows = Vec::new();
1458    let mut cols = Vec::new();
1459    let mut vals = Vec::new();
1460
1461    for i in 0..n {
1462        for j in 0..=i {
1463            let val = matrix.get(i, j);
1464            if !SparseElement::is_zero(&val) {
1465                rows.push(i);
1466                cols.push(j);
1467                vals.push(val);
1468            }
1469        }
1470    }
1471
1472    (rows, cols, vals)
1473}
1474
1475/// Convert dense matrix to sparse
1476#[allow(dead_code)]
1477fn dense_to_sparse<T>(matrix: &Array2<T>) -> SparseResult<CsrArray<T>>
1478where
1479    T: Float + SparseElement + Debug + Copy,
1480{
1481    let (m, n) = matrix.dim();
1482    let mut rows = Vec::new();
1483    let mut cols = Vec::new();
1484    let mut vals = Vec::new();
1485
1486    for i in 0..m {
1487        for j in 0..n {
1488            let val = matrix[[i, j]];
1489            if !SparseElement::is_zero(&val) {
1490                rows.push(i);
1491                cols.push(j);
1492                vals.push(val);
1493            }
1494        }
1495    }
1496
1497    CsrArray::from_triplets(&rows, &cols, &vals, (m, n), false)
1498}
1499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Absolute tolerance for comparing computed factor entries.
    const TOL: f64 = 1e-12;

    /// Assert every entry of `actual` matches the dense rows of `expected` within `TOL`.
    fn assert_matches_dense(actual: &CsrArray<f64>, expected: &[[f64; 3]]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                let got = actual.get(i, j);
                assert!(
                    (got - value).abs() < TOL,
                    "entry ({}, {}): expected {}, got {}",
                    i,
                    j,
                    value,
                    got
                );
            }
        }
    }

    /// The matrix [[2, 1, 0], [1, 3, 0], [0, 2, 4]].
    fn create_test_matrix() -> CsrArray<f64> {
        // Create a simple test matrix
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, 1.0, 1.0, 3.0, 2.0, 4.0];

        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    /// Lower triangle of the SPD matrix [[4, 2, 1], [2, 5, 3], [1, 3, 6]].
    fn create_spd_matrix() -> CsrArray<f64> {
        // Create a symmetric positive definite matrix
        let rows = vec![0, 1, 1, 2, 2, 2];
        let cols = vec![0, 0, 1, 0, 1, 2];
        let data = vec![4.0, 2.0, 5.0, 1.0, 3.0, 6.0];

        CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap()
    }

    #[test]
    fn test_lu_decomposition() {
        let matrix = create_test_matrix();
        let lu_result = lu_decomposition(&matrix, 0.1).unwrap();

        assert!(lu_result.success);
        assert_eq!(lu_result.l.shape(), (3, 3));
        assert_eq!(lu_result.u.shape(), (3, 3));
        assert_eq!(lu_result.p.len(), 3);
    }

    #[test]
    fn test_qr_decomposition() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 0, 1];
        let data = vec![1.0, 2.0, 3.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();

        let qr_result = qr_decomposition(&matrix).unwrap();

        assert!(qr_result.success);
        assert_eq!(qr_result.q.shape(), (3, 2));
        assert_eq!(qr_result.r.shape(), (2, 2));
    }

    #[test]
    fn test_cholesky_decomposition() {
        let matrix = create_spd_matrix();
        let chol_result = cholesky_decomposition(&matrix).unwrap();

        assert!(chol_result.success);
        assert_eq!(chol_result.l.shape(), (3, 3));
    }

    #[test]
    fn test_incomplete_lu() {
        let matrix = create_test_matrix();
        let options = ILUOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let ilu_result = incomplete_lu(&matrix, Some(options)).unwrap();

        assert!(ilu_result.success);
        assert_eq!(ilu_result.l.shape(), (3, 3));
        assert_eq!(ilu_result.u.shape(), (3, 3));
        assert_eq!(ilu_result.p.to_vec(), vec![0, 1, 2]);

        // The test matrix needs no fill-in, so ILU(0) equals the exact LU.
        assert_matches_dense(
            &ilu_result.l,
            &[[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.0, 0.8, 1.0]],
        );
        assert_matches_dense(
            &ilu_result.u,
            &[[2.0, 1.0, 0.0], [0.0, 2.5, 0.0], [0.0, 0.0, 4.0]],
        );
    }

    #[test]
    fn test_incomplete_cholesky() {
        let matrix = create_spd_matrix();
        let options = ICOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let ic_result = incomplete_cholesky(&matrix, Some(options)).unwrap();

        assert!(ic_result.success);
        assert_eq!(ic_result.l.shape(), (3, 3));

        // The lower triangle is fully populated, so IC(0) equals the exact Cholesky factor.
        let l22 = 4.1875_f64.sqrt();
        assert_matches_dense(
            &ic_result.l,
            &[[2.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.5, 1.25, l22]],
        );
    }

    #[test]
    fn test_incomplete_cholesky_reports_breakdown() {
        // Lower triangle of [[1, 2], [2, 1]]: symmetric but indefinite.
        let rows = vec![0, 1, 1];
        let cols = vec![0, 0, 1];
        let data = vec![1.0, 2.0, 1.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
        let options = ICOptions {
            drop_tol: 1e-6,
            ..Default::default()
        };

        let ic_result = incomplete_cholesky(&matrix, Some(options)).unwrap();

        assert!(!ic_result.success);
        assert_eq!(ic_result.l.shape(), (2, 2));
        assert_eq!(ic_result.l.nnz(), 0);
    }

    #[test]
    fn test_incomplete_factorizations_reject_non_square() {
        let rows = vec![0, 1];
        let cols = vec![0, 1];
        let data = vec![1.0, 2.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();

        assert!(incomplete_lu(&matrix, None).is_err());
        assert!(incomplete_cholesky(&matrix, None).is_err());
    }

    #[test]
    fn test_sparse_working_matrix() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];

        let mut matrix = SparseWorkingMatrix::from_triplets(&rows, &cols, &vals, 3);

        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(1, 1), 2.0);
        assert_eq!(matrix.get(2, 2), 3.0);
        assert_eq!(matrix.get(0, 1), 0.0);

        matrix.set(0, 1, 5.0);
        assert_eq!(matrix.get(0, 1), 5.0);

        matrix.set(0, 1, 0.0);
        assert_eq!(matrix.get(0, 1), 0.0);
        assert!(!matrix.has_entry(0, 1));
    }

    #[test]
    fn test_dense_to_sparse_conversion() {
        let dense = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 2.0, 3.0]).unwrap();
        let sparse = dense_to_sparse(&dense).unwrap();

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.get(0, 0), 1.0);
        assert_eq!(sparse.get(0, 1), 0.0);
        assert_eq!(sparse.get(1, 0), 2.0);
        assert_eq!(sparse.get(1, 1), 3.0);
    }
}
1618}