// scirs2_sparse/sparse_utils.rs
1//! Sparse matrix utility operations
2//!
3//! This module provides fundamental sparse matrix operations including:
4//!
5//! - **Norms**: 1-norm, infinity-norm, Frobenius norm
6//! - **SpGEMM**: Sparse matrix-matrix multiplication
7//! - **Arithmetic**: Sparse addition, subtraction, scaling
8//! - **Kronecker product**: Sparse Kronecker (tensor) product
9//! - **Reordering**: Reverse Cuthill-McKee bandwidth reduction
10//! - **Condition number estimate**: Cheap 1-norm-based condition estimate
11
12use crate::csr::CsrMatrix;
13use crate::error::{SparseError, SparseResult};
14use scirs2_core::numeric::{Float, NumAssign, SparseElement};
15use std::collections::VecDeque;
16use std::fmt::Debug;
17use std::iter::Sum;
18
19// ---------------------------------------------------------------------------
20// Sparse matrix norms
21// ---------------------------------------------------------------------------
22
/// Type of matrix norm to compute.
///
/// Selects the norm computed by `sparse_matrix_norm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseNorm {
    /// 1-norm: maximum absolute column sum.
    One,
    /// Infinity norm: maximum absolute row sum.
    Inf,
    /// Frobenius norm: sqrt of the sum of squared elements.
    Frobenius,
}
33
34/// Compute a matrix norm of a sparse CSR matrix.
35///
36/// - `One`: max over columns of sum of absolute values (||A||_1)
37/// - `Inf`: max over rows of sum of absolute values (||A||_inf)
38/// - `Frobenius`: sqrt(sum(|a_ij|^2))
39pub fn sparse_matrix_norm<F>(a: &CsrMatrix<F>, norm_type: SparseNorm) -> SparseResult<F>
40where
41    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
42{
43    let (m, n_cols) = a.shape();
44    match norm_type {
45        SparseNorm::Inf => {
46            let mut max_row_sum = F::sparse_zero();
47            for i in 0..m {
48                let range = a.row_range(i);
49                let vals = &a.data[range];
50                let row_sum: F = vals.iter().map(|v| v.abs()).sum();
51                if row_sum > max_row_sum {
52                    max_row_sum = row_sum;
53                }
54            }
55            Ok(max_row_sum)
56        }
57        SparseNorm::One => {
58            let mut col_sums = vec![F::sparse_zero(); n_cols];
59            for i in 0..m {
60                let range = a.row_range(i);
61                let indices = &a.indices[range.clone()];
62                let vals = &a.data[range];
63                for (idx, &col) in indices.iter().enumerate() {
64                    col_sums[col] += vals[idx].abs();
65                }
66            }
67            let max_col =
68                col_sums
69                    .iter()
70                    .copied()
71                    .fold(F::sparse_zero(), |acc, x| if x > acc { x } else { acc });
72            Ok(max_col)
73        }
74        SparseNorm::Frobenius => {
75            let mut sum_sq = F::sparse_zero();
76            for val in &a.data {
77                sum_sq += *val * *val;
78            }
79            Ok(sum_sq.sqrt())
80        }
81    }
82}
83
84// ---------------------------------------------------------------------------
85// Sparse matrix-matrix multiplication (SpGEMM)
86// ---------------------------------------------------------------------------
87
88/// Sparse matrix-matrix multiplication: C = A * B.
89///
90/// Both A and B are in CSR format. The result is also in CSR format.
91/// Uses a symbolic + numeric two-phase approach for efficiency.
92pub fn spgemm<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
93where
94    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
95{
96    let (m, ka) = a.shape();
97    let (kb, n) = b.shape();
98    if ka != kb {
99        return Err(SparseError::ShapeMismatch {
100            expected: (m, ka),
101            found: (kb, n),
102        });
103    }
104
105    // Dense accumulator approach (scatter-gather): for each row of C,
106    // scatter contributions into a dense workspace, then gather non-zeros.
107    let mut values = vec![F::sparse_zero(); n];
108    let mut active = vec![false; n];
109    let mut rows_out = Vec::new();
110    let mut cols_out = Vec::new();
111    let mut data_out = Vec::new();
112
113    for i in 0..m {
114        let a_range = a.row_range(i);
115        let a_cols = &a.indices[a_range.clone()];
116        let a_vals = &a.data[a_range];
117
118        // Scatter: accumulate row i of C = sum_k a_ik * b_row_k
119        let mut col_list: Vec<usize> = Vec::new();
120        for (a_idx, &k_col) in a_cols.iter().enumerate() {
121            let a_ik = a_vals[a_idx];
122            let b_range = b.row_range(k_col);
123            let b_cols = &b.indices[b_range.clone()];
124            let b_vals = &b.data[b_range];
125
126            for (b_idx, &j) in b_cols.iter().enumerate() {
127                values[j] += a_ik * b_vals[b_idx];
128                if !active[j] {
129                    active[j] = true;
130                    col_list.push(j);
131                }
132            }
133        }
134
135        // Gather: collect non-zero entries in sorted column order
136        col_list.sort_unstable();
137        for &j in &col_list {
138            let val = values[j];
139            if val.abs() > F::epsilon() * F::from(0.01).unwrap_or(F::sparse_zero()) {
140                rows_out.push(i);
141                cols_out.push(j);
142                data_out.push(val);
143            }
144            // Reset workspace
145            values[j] = F::sparse_zero();
146            active[j] = false;
147        }
148    }
149
150    CsrMatrix::new(data_out, rows_out, cols_out, (m, n))
151}
152
153// ---------------------------------------------------------------------------
154// Sparse matrix addition / subtraction
155// ---------------------------------------------------------------------------
156
157/// Sparse matrix addition: C = alpha * A + beta * B.
158///
159/// Both A and B must have the same shape. The result is in CSR format.
160pub fn sparse_add<F>(
161    a: &CsrMatrix<F>,
162    b: &CsrMatrix<F>,
163    alpha: F,
164    beta: F,
165) -> SparseResult<CsrMatrix<F>>
166where
167    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
168{
169    let (ma, na) = a.shape();
170    let (mb, nb) = b.shape();
171    if ma != mb || na != nb {
172        return Err(SparseError::ShapeMismatch {
173            expected: (ma, na),
174            found: (mb, nb),
175        });
176    }
177
178    let mut rows_out = Vec::new();
179    let mut cols_out = Vec::new();
180    let mut data_out = Vec::new();
181
182    let mut b_vals = vec![F::sparse_zero(); na]; // workspace for one row of B
183    let mut b_flags = vec![false; na];
184
185    for i in 0..ma {
186        // Load row i of B into workspace
187        let b_range = b.row_range(i);
188        let b_cols = &b.indices[b_range.clone()];
189        let b_data = &b.data[b_range];
190        for (idx, &col) in b_cols.iter().enumerate() {
191            b_vals[col] = b_data[idx];
192            b_flags[col] = true;
193        }
194
195        // Process row i of A
196        let a_range = a.row_range(i);
197        let a_cols = &a.indices[a_range.clone()];
198        let a_data = &a.data[a_range];
199        let mut used_cols: Vec<usize> = Vec::new();
200
201        for (idx, &col) in a_cols.iter().enumerate() {
202            let val = alpha * a_data[idx]
203                + if b_flags[col] {
204                    beta * b_vals[col]
205                } else {
206                    F::sparse_zero()
207                };
208            if val.abs() > F::epsilon() {
209                rows_out.push(i);
210                cols_out.push(col);
211                data_out.push(val);
212            }
213            if b_flags[col] {
214                b_flags[col] = false;
215                b_vals[col] = F::sparse_zero();
216            }
217            used_cols.push(col);
218        }
219
220        // Remaining entries from B not in A
221        for (idx, &col) in b_cols.iter().enumerate() {
222            if b_flags[col] {
223                let val = beta * b_data[idx];
224                if val.abs() > F::epsilon() {
225                    rows_out.push(i);
226                    cols_out.push(col);
227                    data_out.push(val);
228                }
229                b_flags[col] = false;
230                b_vals[col] = F::sparse_zero();
231            }
232        }
233    }
234
235    CsrMatrix::new(data_out, rows_out, cols_out, (ma, na))
236}
237
238/// Sparse matrix subtraction: C = A - B.
239pub fn sparse_sub<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
240where
241    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
242{
243    sparse_add(a, b, F::sparse_one(), -F::sparse_one())
244}
245
246/// Scale a sparse matrix: C = alpha * A.
247pub fn sparse_scale<F>(a: &CsrMatrix<F>, alpha: F) -> SparseResult<CsrMatrix<F>>
248where
249    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
250{
251    let (m, n) = a.shape();
252    let (rows_in, cols_in, data_in) = a.get_triplets();
253    let data_out: Vec<F> = data_in.iter().map(|&v| alpha * v).collect();
254    CsrMatrix::new(data_out, rows_in, cols_in, (m, n))
255}
256
257// ---------------------------------------------------------------------------
258// Sparse Kronecker product
259// ---------------------------------------------------------------------------
260
261/// Compute the Kronecker product C = A kron B.
262///
263/// If A is (m x n) and B is (p x q), then C is (mp x nq).
264pub fn sparse_kronecker<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
265where
266    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
267{
268    let (ma, na) = a.shape();
269    let (mb, nb) = b.shape();
270    let out_rows = ma * mb;
271    let out_cols = na * nb;
272
273    let mut rows_out = Vec::new();
274    let mut cols_out = Vec::new();
275    let mut data_out = Vec::new();
276
277    for ia in 0..ma {
278        let a_range = a.row_range(ia);
279        let a_cols = &a.indices[a_range.clone()];
280        let a_vals = &a.data[a_range];
281
282        for ib in 0..mb {
283            let b_range = b.row_range(ib);
284            let b_cols = &b.indices[b_range.clone()];
285            let b_vals = &b.data[b_range];
286
287            let out_row = ia * mb + ib;
288
289            for (a_idx, &ja) in a_cols.iter().enumerate() {
290                for (b_idx, &jb) in b_cols.iter().enumerate() {
291                    let out_col = ja * nb + jb;
292                    let val = a_vals[a_idx] * b_vals[b_idx];
293                    if val.abs() > F::epsilon() {
294                        rows_out.push(out_row);
295                        cols_out.push(out_col);
296                        data_out.push(val);
297                    }
298                }
299            }
300        }
301    }
302
303    CsrMatrix::new(data_out, rows_out, cols_out, (out_rows, out_cols))
304}
305
306// ---------------------------------------------------------------------------
307// Reverse Cuthill-McKee (RCM) reordering
308// ---------------------------------------------------------------------------
309
/// Result of the Reverse Cuthill-McKee algorithm.
#[derive(Debug, Clone)]
pub struct RcmResult {
    /// Permutation vector mapping new positions to old ones:
    /// `permutation[new_index] = old_index`.
    pub permutation: Vec<usize>,
    /// Inverse permutation mapping old positions to new ones:
    /// `inverse_permutation[old_index] = new_index`.
    pub inverse_permutation: Vec<usize>,
    /// Bandwidth of the matrix before reordering.
    pub original_bandwidth: usize,
    /// Bandwidth of the matrix after reordering.
    pub new_bandwidth: usize,
}
322
323/// Compute the bandwidth of a sparse matrix (max |i - j| for non-zero a_ij).
324fn compute_bandwidth<F>(a: &CsrMatrix<F>) -> usize
325where
326    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
327{
328    let m = a.rows();
329    let mut bw = 0usize;
330    for i in 0..m {
331        let range = a.row_range(i);
332        for &col in &a.indices[range] {
333            let diff = i.abs_diff(col);
334            if diff > bw {
335                bw = diff;
336            }
337        }
338    }
339    bw
340}
341
342/// Compute the degree of node `i` in the adjacency graph (number of off-diagonal entries in row i).
343fn node_degree<F>(a: &CsrMatrix<F>, i: usize) -> usize
344where
345    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
346{
347    let range = a.row_range(i);
348    a.indices[range].iter().filter(|&&col| col != i).count()
349}
350
351/// Find a pseudo-peripheral node (good starting node for RCM).
352fn find_pseudo_peripheral<F>(a: &CsrMatrix<F>) -> usize
353where
354    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
355{
356    let n = a.rows();
357    if n == 0 {
358        return 0;
359    }
360
361    // Start from the node with minimum degree
362    let mut start = 0;
363    let mut min_deg = usize::MAX;
364    for i in 0..n {
365        let deg = node_degree(a, i);
366        if deg < min_deg {
367            min_deg = deg;
368            start = i;
369        }
370    }
371
372    // BFS to find a peripheral node
373    for _ in 0..5 {
374        let levels = bfs_levels(a, start);
375        let max_level = levels.iter().copied().max().unwrap_or(0);
376        if max_level == 0 {
377            break;
378        }
379        // Among the nodes at the last level, pick the one with minimum degree
380        let mut best = start;
381        let mut best_deg = usize::MAX;
382        for i in 0..n {
383            if levels[i] == max_level {
384                let deg = node_degree(a, i);
385                if deg < best_deg {
386                    best_deg = deg;
387                    best = i;
388                }
389            }
390        }
391        if best == start {
392            break;
393        }
394        start = best;
395    }
396
397    start
398}
399
400/// BFS from `start`, returning level numbers for each node.
401fn bfs_levels<F>(a: &CsrMatrix<F>, start: usize) -> Vec<usize>
402where
403    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
404{
405    let n = a.rows();
406    let mut levels = vec![usize::MAX; n];
407    let mut queue = VecDeque::new();
408    levels[start] = 0;
409    queue.push_back(start);
410
411    while let Some(node) = queue.pop_front() {
412        let range = a.row_range(node);
413        for &neighbor in &a.indices[range] {
414            if levels[neighbor] == usize::MAX {
415                levels[neighbor] = levels[node] + 1;
416                queue.push_back(neighbor);
417            }
418        }
419    }
420
421    levels
422}
423
/// Compute the Reverse Cuthill-McKee permutation of a sparse matrix.
///
/// The RCM algorithm reduces the bandwidth of a sparse matrix by
/// reordering its rows and columns. This can significantly improve
/// the performance of direct solvers and incomplete factorizations.
///
/// # Arguments
///
/// * `a` - Square sparse matrix in CSR format
///
/// # Returns
///
/// An `RcmResult` containing the permutation and bandwidth information.
///
/// # Errors
///
/// Returns `SparseError::ValueError` if the matrix is not square.
pub fn reverse_cuthill_mckee<F>(a: &CsrMatrix<F>) -> SparseResult<RcmResult>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let (m, n) = a.shape();
    if m != n {
        return Err(SparseError::ValueError(
            "RCM requires a square matrix".to_string(),
        ));
    }

    let original_bandwidth = compute_bandwidth(a);

    // Empty matrix: nothing to reorder.
    if m == 0 {
        return Ok(RcmResult {
            permutation: Vec::new(),
            inverse_permutation: Vec::new(),
            original_bandwidth: 0,
            new_bandwidth: 0,
        });
    }

    // Cuthill-McKee ordering (breadth-first, neighbors visited in
    // ascending-degree order).
    let mut visited = vec![false; m];
    let mut cm_order = Vec::with_capacity(m);

    // Outer loop handles potentially disconnected graphs: each pass
    // consumes one connected component.
    while cm_order.len() < m {
        // Find starting node for next component
        let start = if cm_order.is_empty() {
            find_pseudo_peripheral(a)
        } else {
            // Find first unvisited node
            let mut s = 0;
            for i in 0..m {
                if !visited[i] {
                    s = i;
                    break;
                }
            }
            s
        };

        // Defensive: if the chosen start is already ordered, stop rather
        // than loop forever.
        if visited[start] {
            break;
        }

        visited[start] = true;
        cm_order.push(start);
        // The BFS queue is embedded in cm_order itself; queue_start is
        // the head pointer into the not-yet-expanded suffix.
        let mut queue_start = cm_order.len() - 1;

        while queue_start < cm_order.len() {
            let node = cm_order[queue_start];
            queue_start += 1;

            // Get neighbors sorted by degree (ascending); sort_by_key is
            // stable, so ties keep CSR column order.
            let range = a.row_range(node);
            let mut neighbors: Vec<usize> = a.indices[range]
                .iter()
                .copied()
                .filter(|&nb| !visited[nb])
                .collect();
            neighbors.sort_by_key(|&nb| node_degree(a, nb));

            for nb in neighbors {
                // Re-check visited: guards against duplicate column
                // indices within a row.
                if !visited[nb] {
                    visited[nb] = true;
                    cm_order.push(nb);
                }
            }
        }
    }

    // Reverse the ordering for RCM
    cm_order.reverse();

    // Compute inverse permutation: inv_perm[old_index] = new_index.
    let mut inv_perm = vec![0usize; m];
    for (new_idx, &old_idx) in cm_order.iter().enumerate() {
        inv_perm[old_idx] = new_idx;
    }

    // Compute new bandwidth by mapping every stored entry through the
    // permutation without materializing the reordered matrix.
    let mut new_bw = 0usize;
    for i in 0..m {
        let range = a.row_range(i);
        let new_i = inv_perm[i];
        for &col in &a.indices[range] {
            let new_j = inv_perm[col];
            let diff = new_i.abs_diff(new_j);
            if diff > new_bw {
                new_bw = diff;
            }
        }
    }

    Ok(RcmResult {
        permutation: cm_order,
        inverse_permutation: inv_perm,
        original_bandwidth,
        new_bandwidth: new_bw,
    })
}
540
541/// Apply a permutation to a sparse matrix: P * A * P^T.
542///
543/// This reorders both rows and columns using the given permutation vector
544/// where `perm[new_i] = old_i`.
545pub fn permute_matrix<F>(
546    a: &CsrMatrix<F>,
547    perm: &[usize],
548    inv_perm: &[usize],
549) -> SparseResult<CsrMatrix<F>>
550where
551    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
552{
553    let (m, n) = a.shape();
554    if perm.len() != m || inv_perm.len() != n {
555        return Err(SparseError::ValueError(
556            "Permutation size mismatch".to_string(),
557        ));
558    }
559
560    let mut rows_out = Vec::new();
561    let mut cols_out = Vec::new();
562    let mut data_out = Vec::new();
563
564    for new_i in 0..m {
565        let old_i = perm[new_i];
566        let range = a.row_range(old_i);
567        let old_cols = &a.indices[range.clone()];
568        let vals = &a.data[range];
569
570        for (idx, &old_j) in old_cols.iter().enumerate() {
571            let new_j = inv_perm[old_j];
572            rows_out.push(new_i);
573            cols_out.push(new_j);
574            data_out.push(vals[idx]);
575        }
576    }
577
578    CsrMatrix::new(data_out, rows_out, cols_out, (m, n))
579}
580
581// ---------------------------------------------------------------------------
582// Condition number estimate
583// ---------------------------------------------------------------------------
584
585/// Estimate the 1-norm condition number of a sparse matrix.
586///
587/// Uses Hager's algorithm (1-norm estimation) combined with a simple
588/// triangular solve estimate. This is much cheaper than computing the
589/// actual condition number via SVD.
590///
591/// Returns `None` if the matrix appears singular.
592pub fn condest_1norm<F>(a: &CsrMatrix<F>) -> SparseResult<Option<F>>
593where
594    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
595{
596    let (m, n) = a.shape();
597    if m != n || m == 0 {
598        return Err(SparseError::ValueError(
599            "condest requires a non-empty square matrix".to_string(),
600        ));
601    }
602
603    let a_norm = sparse_matrix_norm(a, SparseNorm::One)?;
604    if a_norm < F::epsilon() {
605        return Ok(None); // Essentially zero matrix
606    }
607
608    // Estimate ||A^{-1}||_1 using Hager's algorithm
609    // Start with x = (1/n, 1/n, ..., 1/n)
610    let inv_n = F::sparse_one()
611        / F::from(n as f64)
612            .ok_or_else(|| SparseError::ValueError("Failed to convert n".to_string()))?;
613
614    let mut x = vec![inv_n; n];
615    let max_iter = 5;
616    let mut gamma = F::sparse_zero();
617
618    for _ in 0..max_iter {
619        // y = A^{-1} x  (approximate via iterative refinement)
620        let y = approximate_solve(a, &x)?;
621
622        // gamma = ||y||_1
623        let new_gamma: F = y.iter().map(|v| v.abs()).sum();
624        if new_gamma <= gamma {
625            break;
626        }
627        gamma = new_gamma;
628
629        // z = A^{-T} sign(y)
630        let sign_y: Vec<F> = y
631            .iter()
632            .map(|&v| {
633                if v >= F::sparse_zero() {
634                    F::sparse_one()
635                } else {
636                    -F::sparse_one()
637                }
638            })
639            .collect();
640
641        let at = a.transpose();
642        let z = approximate_solve(&at, &sign_y)?;
643
644        // Find the index of maximum |z_j|
645        let mut max_abs = F::sparse_zero();
646        let mut max_idx = 0;
647        for (j, &zj) in z.iter().enumerate() {
648            if zj.abs() > max_abs {
649                max_abs = zj.abs();
650                max_idx = j;
651            }
652        }
653
654        // Check if we can improve: ||z||_inf <= z^T x
655        let ztx: F = z.iter().zip(x.iter()).map(|(&zi, &xi)| zi * xi).sum();
656        if max_abs <= ztx {
657            break;
658        }
659
660        // x = e_{max_idx}
661        for xi in x.iter_mut() {
662            *xi = F::sparse_zero();
663        }
664        x[max_idx] = F::sparse_one();
665    }
666
667    if gamma < F::epsilon() {
668        return Ok(None);
669    }
670
671    Ok(Some(a_norm * gamma))
672}
673
674/// Approximate solve of A * x = b using a few Jacobi iterations.
675/// This is used internally for condition number estimation and does NOT
676/// need to be highly accurate.
677fn approximate_solve<F>(a: &CsrMatrix<F>, b: &[F]) -> SparseResult<Vec<F>>
678where
679    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
680{
681    let n = b.len();
682    let (m, _) = a.shape();
683    if m != n {
684        return Err(SparseError::DimensionMismatch {
685            expected: m,
686            found: n,
687        });
688    }
689
690    // Extract diagonal
691    let mut diag = vec![F::sparse_one(); n];
692    for i in 0..n {
693        let d = a.get(i, i);
694        if d.abs() > F::epsilon() {
695            diag[i] = d;
696        }
697    }
698
699    // A few Jacobi iterations: x_{k+1} = D^{-1} (b - (A - D) x_k)
700    let mut x = vec![F::sparse_zero(); n];
701    for _ in 0..10 {
702        let mut x_new = vec![F::sparse_zero(); n];
703        for i in 0..n {
704            let range = a.row_range(i);
705            let cols = &a.indices[range.clone()];
706            let vals = &a.data[range];
707            let mut sum = b[i];
708            for (idx, &col) in cols.iter().enumerate() {
709                if col != i {
710                    sum -= vals[idx] * x[col];
711                }
712            }
713            x_new[i] = sum / diag[i];
714        }
715        x = x_new;
716    }
717
718    Ok(x)
719}
720
721// ---------------------------------------------------------------------------
722// Sparse matrix transpose
723// ---------------------------------------------------------------------------
724
725/// Compute the transpose of a sparse CSR matrix (returns a new CSR matrix).
726pub fn sparse_transpose<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
727where
728    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
729{
730    Ok(a.transpose())
731}
732
733// ---------------------------------------------------------------------------
734// Sparse diagonal extraction
735// ---------------------------------------------------------------------------
736
737/// Extract the diagonal of a sparse CSR matrix as a dense vector.
738pub fn sparse_extract_diagonal<F>(a: &CsrMatrix<F>) -> Vec<F>
739where
740    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
741{
742    let n = a.rows().min(a.cols());
743    let mut diag = vec![F::sparse_zero(); n];
744    for i in 0..n {
745        diag[i] = a.get(i, i);
746    }
747    diag
748}
749
750/// Compute the trace of a sparse matrix (sum of diagonal elements).
751pub fn sparse_matrix_trace<F>(a: &CsrMatrix<F>) -> F
752where
753    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
754{
755    let diag = sparse_extract_diagonal(a);
756    diag.iter().copied().sum()
757}
758
759// ---------------------------------------------------------------------------
760// Sparse identity matrix construction
761// ---------------------------------------------------------------------------
762
763/// Create an n x n sparse identity matrix in CSR format.
764pub fn sparse_identity<F>(n: usize) -> SparseResult<CsrMatrix<F>>
765where
766    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
767{
768    let rows: Vec<usize> = (0..n).collect();
769    let cols: Vec<usize> = (0..n).collect();
770    let data: Vec<F> = vec![F::sparse_one(); n];
771    CsrMatrix::new(data, rows, cols, (n, n))
772}
773
774// ---------------------------------------------------------------------------
775// Tests
776// ---------------------------------------------------------------------------
777
778#[cfg(test)]
779mod tests {
780    use super::*;
781
    /// Build the shared 3x3 test matrix:
    /// [1  2  0]
    /// [3  4  5]
    /// [0  6  7]
    fn build_test_matrix() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 2, 1, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        CsrMatrix::new(data, rows, cols, (3, 3)).expect("valid matrix")
    }
792
793    fn build_identity(n: usize) -> CsrMatrix<f64> {
794        let rows: Vec<usize> = (0..n).collect();
795        let cols: Vec<usize> = (0..n).collect();
796        let data = vec![1.0; n];
797        CsrMatrix::new(data, rows, cols, (n, n)).expect("valid identity")
798    }
799
800    fn build_tridiag(n: usize) -> CsrMatrix<f64> {
801        let mut rows = Vec::new();
802        let mut cols = Vec::new();
803        let mut data = Vec::new();
804        for i in 0..n {
805            if i > 0 {
806                rows.push(i);
807                cols.push(i - 1);
808                data.push(-1.0);
809            }
810            rows.push(i);
811            cols.push(i);
812            data.push(2.0);
813            if i + 1 < n {
814                rows.push(i);
815                cols.push(i + 1);
816                data.push(-1.0);
817            }
818        }
819        CsrMatrix::new(data, rows, cols, (n, n)).expect("valid matrix")
820    }
821
    /// Frobenius norm of the test matrix is sqrt of the sum of squares.
    #[test]
    fn test_frobenius_norm() {
        let a = build_test_matrix();
        let nrm = sparse_matrix_norm(&a, SparseNorm::Frobenius).expect("frobenius norm");
        // sqrt(1 + 4 + 9 + 16 + 25 + 36 + 49) = sqrt(140)
        let expected = (140.0_f64).sqrt();
        assert!(
            (nrm - expected).abs() < 1e-10,
            "Expected {expected}, got {nrm}"
        );
    }
833
    /// 1-norm is the largest absolute column sum.
    #[test]
    fn test_one_norm() {
        let a = build_test_matrix();
        let nrm = sparse_matrix_norm(&a, SparseNorm::One).expect("1-norm");
        // Column sums: col0=|1|+|3|=4, col1=|2|+|4|+|6|=12, col2=|5|+|7|=12
        assert!((nrm - 12.0).abs() < 1e-10, "Expected 12.0, got {nrm}");
    }
841
    /// Infinity norm is the largest absolute row sum.
    #[test]
    fn test_inf_norm() {
        let a = build_test_matrix();
        let nrm = sparse_matrix_norm(&a, SparseNorm::Inf).expect("inf-norm");
        // Row sums: row0=1+2=3, row1=3+4+5=12, row2=6+7=13
        assert!((nrm - 13.0).abs() < 1e-10, "Expected 13.0, got {nrm}");
    }
849
850    #[test]
851    fn test_spgemm_identity() {
852        let a = build_test_matrix();
853        let eye = build_identity(3);
854        let c = spgemm(&a, &eye).expect("spgemm A*I");
855        for i in 0..3 {
856            for j in 0..3 {
857                assert!(
858                    (c.get(i, j) - a.get(i, j)).abs() < 1e-10,
859                    "Mismatch at ({i},{j})"
860                );
861            }
862        }
863    }
864
865    #[test]
866    fn test_spgemm_square() {
867        // A = [[1,2],[3,4]]
868        // A^2 = [[7,10],[15,22]]
869        let rows = vec![0, 0, 1, 1];
870        let cols = vec![0, 1, 0, 1];
871        let data = vec![1.0, 2.0, 3.0, 4.0];
872        let a = CsrMatrix::new(data, rows, cols, (2, 2)).expect("valid matrix");
873        let c = spgemm(&a, &a).expect("spgemm A*A");
874        assert!((c.get(0, 0) - 7.0).abs() < 1e-10);
875        assert!((c.get(0, 1) - 10.0).abs() < 1e-10);
876        assert!((c.get(1, 0) - 15.0).abs() < 1e-10);
877        assert!((c.get(1, 1) - 22.0).abs() < 1e-10);
878    }
879
880    #[test]
881    fn test_spgemm_dimension_mismatch() {
882        let a = CsrMatrix::new(vec![1.0], vec![0], vec![0], (1, 2)).expect("valid");
883        let b = CsrMatrix::new(vec![1.0], vec![0], vec![0], (3, 1)).expect("valid");
884        assert!(spgemm(&a, &b).is_err());
885    }
886
887    #[test]
888    fn test_sparse_add() {
889        let a = build_identity(3);
890        let b = build_identity(3);
891        let c = sparse_add(&a, &b, 2.0, 3.0).expect("sparse add");
892        // C = 2I + 3I = 5I
893        for i in 0..3 {
894            assert!((c.get(i, i) - 5.0).abs() < 1e-10);
895        }
896    }
897
898    #[test]
899    fn test_sparse_sub() {
900        let a = build_test_matrix();
901        let c = sparse_sub(&a, &a).expect("sparse sub");
902        // A - A should be zero
903        for i in 0..3 {
904            for j in 0..3 {
905                assert!(c.get(i, j).abs() < 1e-10);
906            }
907        }
908    }
909
910    #[test]
911    fn test_sparse_scale() {
912        let a = build_identity(3);
913        let c = sparse_scale(&a, 5.0).expect("sparse scale");
914        for i in 0..3 {
915            assert!((c.get(i, i) - 5.0).abs() < 1e-10);
916        }
917    }
918
    /// Kronecker product of identities: I2 kron I3 = I6.
    #[test]
    fn test_sparse_kronecker_identity() {
        let i2 = build_identity(2);
        let i3 = build_identity(3);
        let c = sparse_kronecker(&i2, &i3).expect("kronecker I2 x I3");
        // Result should be I6
        let (m, n) = c.shape();
        assert_eq!(m, 6);
        assert_eq!(n, 6);
        for i in 0..6 {
            for j in 0..6 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (c.get(i, j) - expected).abs() < 1e-10,
                    "Kronecker I2xI3 mismatch at ({i},{j})"
                );
            }
        }
    }
938
939    #[test]
940    fn test_sparse_kronecker_small() {
941        // A = [[1,2],[3,4]], B = [[5,6],[7,8]]
942        // A kron B = [[5,6,10,12],[7,8,14,16],[15,18,20,24],[21,24,28,32]]
943        let a = CsrMatrix::new(
944            vec![1.0, 2.0, 3.0, 4.0],
945            vec![0, 0, 1, 1],
946            vec![0, 1, 0, 1],
947            (2, 2),
948        )
949        .expect("valid");
950        let b = CsrMatrix::new(
951            vec![5.0, 6.0, 7.0, 8.0],
952            vec![0, 0, 1, 1],
953            vec![0, 1, 0, 1],
954            (2, 2),
955        )
956        .expect("valid");
957        let c = sparse_kronecker(&a, &b).expect("kronecker");
958        assert!((c.get(0, 0) - 5.0).abs() < 1e-10);
959        assert!((c.get(0, 2) - 10.0).abs() < 1e-10);
960        assert!((c.get(3, 3) - 32.0).abs() < 1e-10);
961    }
962
963    #[test]
964    fn test_rcm_tridiagonal() {
965        let n = 10;
966        let a = build_tridiag(n);
967        let result = reverse_cuthill_mckee(&a).expect("rcm");
968        assert_eq!(result.permutation.len(), n);
969        assert_eq!(result.inverse_permutation.len(), n);
970        // Tridiagonal already has bandwidth 1, RCM shouldn't make it worse
971        assert!(result.new_bandwidth <= result.original_bandwidth + 1);
972    }
973
974    #[test]
975    fn test_rcm_sparse_matrix() {
976        // A banded matrix with wider bandwidth
977        let n = 8;
978        let mut rows = Vec::new();
979        let mut cols = Vec::new();
980        let mut data = Vec::new();
981        for i in 0..n {
982            rows.push(i);
983            cols.push(i);
984            data.push(4.0);
985            if i + 1 < n {
986                rows.push(i);
987                cols.push(i + 1);
988                data.push(-1.0);
989                rows.push(i + 1);
990                cols.push(i);
991                data.push(-1.0);
992            }
993            if i + 3 < n {
994                rows.push(i);
995                cols.push(i + 3);
996                data.push(-0.5);
997                rows.push(i + 3);
998                cols.push(i);
999                data.push(-0.5);
1000            }
1001        }
1002        let a = CsrMatrix::new(data, rows, cols, (n, n)).expect("valid matrix");
1003        let result = reverse_cuthill_mckee(&a).expect("rcm");
1004        // Just verify it produces a valid permutation
1005        let mut sorted_perm = result.permutation.clone();
1006        sorted_perm.sort();
1007        let expected: Vec<usize> = (0..n).collect();
1008        assert_eq!(sorted_perm, expected);
1009    }
1010
1011    #[test]
1012    fn test_rcm_identity() {
1013        let eye = build_identity(5);
1014        let result = reverse_cuthill_mckee(&eye).expect("rcm identity");
1015        assert_eq!(result.original_bandwidth, 0);
1016        assert_eq!(result.new_bandwidth, 0);
1017    }
1018
1019    #[test]
1020    fn test_rcm_error_non_square() {
1021        let a = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 3)).expect("valid");
1022        assert!(reverse_cuthill_mckee(&a).is_err());
1023    }
1024
1025    #[test]
1026    fn test_permute_matrix() {
1027        // A = [[1,2],[3,4]]
1028        // Permutation [1,0] => P*A*P^T = [[4,3],[2,1]]
1029        let a = CsrMatrix::new(
1030            vec![1.0, 2.0, 3.0, 4.0],
1031            vec![0, 0, 1, 1],
1032            vec![0, 1, 0, 1],
1033            (2, 2),
1034        )
1035        .expect("valid");
1036        let perm = vec![1, 0];
1037        let inv_perm = vec![1, 0];
1038        let b = permute_matrix(&a, &perm, &inv_perm).expect("permute");
1039        assert!((b.get(0, 0) - 4.0).abs() < 1e-10);
1040        assert!((b.get(0, 1) - 3.0).abs() < 1e-10);
1041        assert!((b.get(1, 0) - 2.0).abs() < 1e-10);
1042        assert!((b.get(1, 1) - 1.0).abs() < 1e-10);
1043    }
1044
1045    #[test]
1046    fn test_condest_identity() {
1047        let eye = build_identity(5);
1048        let cond = condest_1norm(&eye).expect("condest");
1049        // Condition number of identity is 1
1050        if let Some(c) = cond {
1051            assert!((c - 1.0).abs() < 1.0, "Expected cond(I) ~ 1, got {c}");
1052        }
1053    }
1054
1055    #[test]
1056    fn test_condest_diagonal() {
1057        // diag(1, 100) => cond_1 = 100
1058        let a = CsrMatrix::new(vec![1.0, 100.0], vec![0, 1], vec![0, 1], (2, 2)).expect("valid");
1059        let cond = condest_1norm(&a).expect("condest");
1060        if let Some(c) = cond {
1061            // Should be around 100
1062            assert!(c > 10.0 && c < 1000.0, "Expected cond ~ 100, got {c}");
1063        }
1064    }
1065
1066    #[test]
1067    fn test_condest_error_non_square() {
1068        let a = CsrMatrix::new(vec![1.0], vec![0], vec![0], (1, 2)).expect("valid");
1069        assert!(condest_1norm(&a).is_err());
1070    }
1071
1072    #[test]
1073    fn test_sparse_extract_diagonal() {
1074        let a = build_test_matrix();
1075        let diag = sparse_extract_diagonal(&a);
1076        assert_eq!(diag.len(), 3);
1077        assert!((diag[0] - 1.0).abs() < 1e-10);
1078        assert!((diag[1] - 4.0).abs() < 1e-10);
1079        assert!((diag[2] - 7.0).abs() < 1e-10);
1080    }
1081
1082    #[test]
1083    fn test_sparse_matrix_trace() {
1084        let a = build_test_matrix();
1085        let tr = sparse_matrix_trace(&a);
1086        assert!((tr - 12.0).abs() < 1e-10); // 1 + 4 + 7 = 12
1087    }
1088
1089    #[test]
1090    fn test_sparse_identity() {
1091        let eye: CsrMatrix<f64> = sparse_identity(4).expect("sparse identity");
1092        for i in 0..4 {
1093            for j in 0..4 {
1094                let expected = if i == j { 1.0 } else { 0.0 };
1095                assert!(
1096                    (eye.get(i, j) - expected).abs() < 1e-10,
1097                    "Identity mismatch at ({i},{j})"
1098                );
1099            }
1100        }
1101    }
1102
1103    #[test]
1104    fn test_sparse_transpose() {
1105        let a = build_test_matrix();
1106        let at = sparse_transpose(&a).expect("transpose");
1107        for i in 0..3 {
1108            for j in 0..3 {
1109                assert!(
1110                    (at.get(i, j) - a.get(j, i)).abs() < 1e-10,
1111                    "Transpose mismatch at ({i},{j})"
1112                );
1113            }
1114        }
1115    }
1116
1117    #[test]
1118    fn test_compute_bandwidth() {
1119        let a = build_tridiag(5);
1120        assert_eq!(compute_bandwidth(&a), 1);
1121    }
1122}