oxiblas_sparse/
convert.rs

1//! Format conversion utilities for sparse matrices.
2//!
3//! Provides efficient conversions between:
4//! - CSR ↔ CSC
//! - COO ↔ CSR/CSC
6//! - CSR ↔ DIA
7//! - CSR ↔ ELL
8//! - CSR ↔ BSR
9//! - CSR ↔ BSC
10//! - CSR ↔ HYB
11//! - CSR ↔ SELL
12//! - Sparse ↔ Dense
13//!
14//! # Automatic Format Selection
15//!
//! Use [`analyze_sparsity_pattern`] to obtain a heuristic format recommendation for a matrix.
17
18use crate::bsc::BscMatrix;
19use crate::bsr::BsrMatrix;
20use crate::coo::CooMatrix;
21use crate::csc::CscMatrix;
22use crate::csr::CsrMatrix;
23use crate::dia::DiaMatrix;
24use crate::ell::EllMatrix;
25use crate::hyb::{HybMatrix, HybWidthStrategy};
26use crate::sell::{SellMatrix, SliceSize};
27use oxiblas_core::scalar::{Field, Real, Scalar};
28
29/// Converts a CSR matrix to CSC format.
30///
31/// Time complexity: O(nnz)
32/// Space complexity: O(nnz) for the output
33pub fn csr_to_csc<T: Scalar + Clone>(csr: &CsrMatrix<T>) -> CscMatrix<T> {
34    let nrows = csr.nrows();
35    let ncols = csr.ncols();
36    let nnz = csr.nnz();
37
38    if nnz == 0 {
39        return CscMatrix::zeros(nrows, ncols);
40    }
41
42    // Count entries per column
43    let mut col_counts = vec![0usize; ncols];
44    for &col in csr.col_indices() {
45        col_counts[col] += 1;
46    }
47
48    // Build column pointers
49    let mut col_ptrs = vec![0usize; ncols + 1];
50    for i in 0..ncols {
51        col_ptrs[i + 1] = col_ptrs[i] + col_counts[i];
52    }
53
54    // Fill in values and row indices
55    let mut row_indices = vec![0usize; nnz];
56    let mut values = vec![T::zero(); nnz];
57    let mut write_pos = col_ptrs.clone();
58
59    for row in 0..nrows {
60        let start = csr.row_ptrs()[row];
61        let end = csr.row_ptrs()[row + 1];
62
63        for i in start..end {
64            let col = csr.col_indices()[i];
65            let pos = write_pos[col];
66
67            row_indices[pos] = row;
68            values[pos] = csr.values()[i].clone();
69
70            write_pos[col] += 1;
71        }
72    }
73
74    // SAFETY: We've constructed valid CSC data
75    unsafe { CscMatrix::new_unchecked(nrows, ncols, col_ptrs, row_indices, values) }
76}
77
78/// Converts a CSC matrix to CSR format.
79///
80/// Time complexity: O(nnz)
81/// Space complexity: O(nnz) for the output
82pub fn csc_to_csr<T: Scalar + Clone>(csc: &CscMatrix<T>) -> CsrMatrix<T> {
83    let nrows = csc.nrows();
84    let ncols = csc.ncols();
85    let nnz = csc.nnz();
86
87    if nnz == 0 {
88        return CsrMatrix::zeros(nrows, ncols);
89    }
90
91    // Count entries per row
92    let mut row_counts = vec![0usize; nrows];
93    for &row in csc.row_indices() {
94        row_counts[row] += 1;
95    }
96
97    // Build row pointers
98    let mut row_ptrs = vec![0usize; nrows + 1];
99    for i in 0..nrows {
100        row_ptrs[i + 1] = row_ptrs[i] + row_counts[i];
101    }
102
103    // Fill in values and column indices
104    let mut col_indices = vec![0usize; nnz];
105    let mut values = vec![T::zero(); nnz];
106    let mut write_pos = row_ptrs.clone();
107
108    for col in 0..ncols {
109        let start = csc.col_ptrs()[col];
110        let end = csc.col_ptrs()[col + 1];
111
112        for i in start..end {
113            let row = csc.row_indices()[i];
114            let pos = write_pos[row];
115
116            col_indices[pos] = col;
117            values[pos] = csc.values()[i].clone();
118
119            write_pos[row] += 1;
120        }
121    }
122
123    // SAFETY: We've constructed valid CSR data
124    unsafe { CsrMatrix::new_unchecked(nrows, ncols, row_ptrs, col_indices, values) }
125}
126
127/// Converts a COO matrix to CSR format, summing duplicate entries.
128///
129/// Time complexity: O(nnz log nnz) due to sorting
130/// Space complexity: O(nnz) for the output
131pub fn coo_to_csr<T: Scalar<Real = T> + Clone + Field + Real>(coo: &CooMatrix<T>) -> CsrMatrix<T> {
132    let nrows = coo.nrows();
133    let ncols = coo.ncols();
134
135    if coo.is_empty() {
136        return CsrMatrix::zeros(nrows, ncols);
137    }
138
139    // Sort entries by (row, col)
140    let mut indices: Vec<usize> = (0..coo.len()).collect();
141    indices.sort_by_key(|&i| (coo.row_indices()[i], coo.col_indices()[i]));
142
143    // Build CSR data, summing duplicates
144    let mut row_ptrs = Vec::with_capacity(nrows + 1);
145    let mut col_indices = Vec::with_capacity(coo.len());
146    let mut values: Vec<T> = Vec::with_capacity(coo.len());
147
148    row_ptrs.push(0);
149    let mut current_row = 0;
150
151    for &idx in &indices {
152        let row = coo.row_indices()[idx];
153        let col = coo.col_indices()[idx];
154        let val = coo.values()[idx].clone();
155
156        // Fill in empty rows
157        while current_row < row {
158            row_ptrs.push(values.len());
159            current_row += 1;
160        }
161
162        // Check for duplicate
163        if !values.is_empty() && col_indices.last() == Some(&col) && current_row == row {
164            // Same position as last entry, accumulate
165            let last = values.len() - 1;
166            values[last] = values[last].clone() + val;
167        } else {
168            // Skip zeros after accumulation
169            if !values.is_empty() {
170                let last = values.len() - 1;
171                if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
172                    values.pop();
173                    col_indices.pop();
174                }
175            }
176            // New entry
177            if Scalar::abs(val.clone()) > <T as Scalar>::epsilon() {
178                col_indices.push(col);
179                values.push(val);
180            }
181        }
182    }
183
184    // Clean up last entry if it became zero
185    if !values.is_empty() {
186        let last = values.len() - 1;
187        if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
188            values.pop();
189            col_indices.pop();
190        }
191    }
192
193    // Fill remaining row pointers
194    while current_row < nrows {
195        row_ptrs.push(values.len());
196        current_row += 1;
197    }
198    row_ptrs.push(values.len());
199
200    // SAFETY: We've constructed valid CSR data
201    unsafe { CsrMatrix::new_unchecked(nrows, ncols, row_ptrs, col_indices, values) }
202}
203
204/// Converts a COO matrix to CSC format, summing duplicate entries.
205///
206/// Time complexity: O(nnz log nnz) due to sorting
207/// Space complexity: O(nnz) for the output
208pub fn coo_to_csc<T: Scalar<Real = T> + Clone + Field + Real>(coo: &CooMatrix<T>) -> CscMatrix<T> {
209    let nrows = coo.nrows();
210    let ncols = coo.ncols();
211
212    if coo.is_empty() {
213        return CscMatrix::zeros(nrows, ncols);
214    }
215
216    // Sort entries by (col, row)
217    let mut indices: Vec<usize> = (0..coo.len()).collect();
218    indices.sort_by_key(|&i| (coo.col_indices()[i], coo.row_indices()[i]));
219
220    // Build CSC data, summing duplicates
221    let mut col_ptrs = Vec::with_capacity(ncols + 1);
222    let mut row_indices = Vec::with_capacity(coo.len());
223    let mut values: Vec<T> = Vec::with_capacity(coo.len());
224
225    col_ptrs.push(0);
226    let mut current_col = 0;
227
228    for &idx in &indices {
229        let row = coo.row_indices()[idx];
230        let col = coo.col_indices()[idx];
231        let val = coo.values()[idx].clone();
232
233        // Fill in empty columns
234        while current_col < col {
235            col_ptrs.push(values.len());
236            current_col += 1;
237        }
238
239        // Check for duplicate
240        if !values.is_empty() && row_indices.last() == Some(&row) && current_col == col {
241            // Same position as last entry, accumulate
242            let last = values.len() - 1;
243            values[last] = values[last].clone() + val;
244        } else {
245            // Skip zeros after accumulation
246            if !values.is_empty() {
247                let last = values.len() - 1;
248                if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
249                    values.pop();
250                    row_indices.pop();
251                }
252            }
253            // New entry
254            if Scalar::abs(val.clone()) > <T as Scalar>::epsilon() {
255                row_indices.push(row);
256                values.push(val);
257            }
258        }
259    }
260
261    // Clean up last entry if it became zero
262    if !values.is_empty() {
263        let last = values.len() - 1;
264        if Scalar::abs(values[last].clone()) <= <T as Scalar>::epsilon() {
265            values.pop();
266            row_indices.pop();
267        }
268    }
269
270    // Fill remaining column pointers
271    while current_col < ncols {
272        col_ptrs.push(values.len());
273        current_col += 1;
274    }
275    col_ptrs.push(values.len());
276
277    // SAFETY: We've constructed valid CSC data
278    unsafe { CscMatrix::new_unchecked(nrows, ncols, col_ptrs, row_indices, values) }
279}
280
281/// Converts a CSR matrix to COO format.
282pub fn csr_to_coo<T: Scalar + Clone>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
283    let nrows = csr.nrows();
284    let ncols = csr.ncols();
285    let nnz = csr.nnz();
286
287    let mut row_indices = Vec::with_capacity(nnz);
288    let mut col_indices = Vec::with_capacity(nnz);
289    let mut values = Vec::with_capacity(nnz);
290
291    for row in 0..nrows {
292        let start = csr.row_ptrs()[row];
293        let end = csr.row_ptrs()[row + 1];
294
295        for i in start..end {
296            row_indices.push(row);
297            col_indices.push(csr.col_indices()[i]);
298            values.push(csr.values()[i].clone());
299        }
300    }
301
302    // SAFETY: Valid COO data derived from valid CSR
303    unsafe { CooMatrix::new_unchecked(nrows, ncols, row_indices, col_indices, values) }
304}
305
306/// Converts a CSC matrix to COO format.
307pub fn csc_to_coo<T: Scalar + Clone>(csc: &CscMatrix<T>) -> CooMatrix<T> {
308    let nrows = csc.nrows();
309    let ncols = csc.ncols();
310    let nnz = csc.nnz();
311
312    let mut row_indices = Vec::with_capacity(nnz);
313    let mut col_indices = Vec::with_capacity(nnz);
314    let mut values = Vec::with_capacity(nnz);
315
316    for col in 0..ncols {
317        let start = csc.col_ptrs()[col];
318        let end = csc.col_ptrs()[col + 1];
319
320        for i in start..end {
321            row_indices.push(csc.row_indices()[i]);
322            col_indices.push(col);
323            values.push(csc.values()[i].clone());
324        }
325    }
326
327    // SAFETY: Valid COO data derived from valid CSC
328    unsafe { CooMatrix::new_unchecked(nrows, ncols, row_indices, col_indices, values) }
329}
330
331// ============================================================================
332// DIA Conversions
333// ============================================================================
334
335/// Converts a CSR matrix to DIA format.
336///
337/// # Arguments
338///
339/// * `csr` - Source CSR matrix
340/// * `offsets` - Optional list of diagonal offsets to extract. If None, all non-empty diagonals are extracted.
341///
342/// Time complexity: O(nnz)
343pub fn csr_to_dia<T: Scalar + Clone + Field>(
344    csr: &CsrMatrix<T>,
345    offsets: Option<Vec<isize>>,
346) -> DiaMatrix<T> {
347    let (nrows, ncols) = csr.shape();
348    let eps = <T as Scalar>::epsilon();
349
350    // Find all non-empty diagonals if not specified
351    let offsets = offsets.unwrap_or_else(|| {
352        let mut found = std::collections::HashSet::new();
353        for (row, col, val) in csr.iter() {
354            if Scalar::abs(val.clone()) > eps {
355                found.insert(col as isize - row as isize);
356            }
357        }
358        let mut offsets: Vec<_> = found.into_iter().collect();
359        offsets.sort();
360        offsets
361    });
362
363    if offsets.is_empty() {
364        return DiaMatrix::zeros(nrows, ncols);
365    }
366
367    let diag_len = nrows.min(ncols);
368    let mut data = Vec::with_capacity(offsets.len());
369
370    for &offset in &offsets {
371        let mut diag = vec![T::zero(); diag_len];
372
373        // Fill diagonal from CSR
374        // Element A[row, col] where col = row + offset goes to data index (row + offset)
375        // This matches DiaMatrix::data_index which uses (row as isize + offset) as usize
376        for (row, col, val) in csr.iter() {
377            let expected_col = (row as isize + offset) as usize;
378            if col == expected_col && row < nrows && col < ncols {
379                // data_index = row + offset (accounting for padding)
380                let idx = (row as isize + offset) as usize;
381                if idx < diag_len {
382                    diag[idx] = val.clone();
383                }
384            }
385        }
386
387        data.push(diag);
388    }
389
390    // Safety: we constructed valid DIA data
391    unsafe { DiaMatrix::new_unchecked(nrows, ncols, offsets, data) }
392}
393
394/// Converts a DIA matrix to CSR format.
395///
396/// Time complexity: O(nrows * ndiag)
397pub fn dia_to_csr<T: Scalar + Clone + Field>(dia: &DiaMatrix<T>) -> CsrMatrix<T> {
398    dia.to_csr()
399}
400
401// ============================================================================
402// ELL Conversions
403// ============================================================================
404
405/// Converts a CSR matrix to ELL format.
406///
407/// # Arguments
408///
409/// * `csr` - Source CSR matrix
410/// * `max_width` - Optional maximum width (if None, uses actual max non-zeros per row)
411///
412/// Time complexity: O(nnz)
413pub fn csr_to_ell<T: Scalar + Clone + Field>(
414    csr: &CsrMatrix<T>,
415    max_width: Option<usize>,
416) -> Result<EllMatrix<T>, crate::ell::EllError> {
417    EllMatrix::from_csr(csr, max_width)
418}
419
420/// Converts an ELL matrix to CSR format.
421///
422/// Time complexity: O(nrows * width)
423pub fn ell_to_csr<T: Scalar + Clone + Field>(ell: &EllMatrix<T>) -> CsrMatrix<T> {
424    ell.to_csr()
425}
426
427// ============================================================================
428// BSR Conversions
429// ============================================================================
430
431/// Converts a CSR matrix to BSR format.
432///
433/// # Arguments
434///
435/// * `csr` - Source CSR matrix
436/// * `block_rows` - Block row size
437/// * `block_cols` - Block column size
438///
439/// Time complexity: O(nnz)
440pub fn csr_to_bsr<T: Scalar + Clone + Field>(
441    csr: &CsrMatrix<T>,
442    block_rows: usize,
443    block_cols: usize,
444) -> BsrMatrix<T> {
445    BsrMatrix::from_csr(csr, block_rows, block_cols)
446}
447
448/// Converts a BSR matrix to CSR format.
449///
450/// Time complexity: O(nblocks * block_size)
451pub fn bsr_to_csr<T: Scalar + Clone + Field>(bsr: &BsrMatrix<T>) -> CsrMatrix<T> {
452    bsr.to_csr()
453}
454
455// ============================================================================
456// Cross-format conversions
457// ============================================================================
458
459/// Converts a DIA matrix to ELL format.
460pub fn dia_to_ell<T: Scalar + Clone + Field>(
461    dia: &DiaMatrix<T>,
462    max_width: Option<usize>,
463) -> Result<EllMatrix<T>, crate::ell::EllError> {
464    let csr = dia.to_csr();
465    EllMatrix::from_csr(&csr, max_width)
466}
467
468/// Converts an ELL matrix to DIA format.
469pub fn ell_to_dia<T: Scalar + Clone + Field>(
470    ell: &EllMatrix<T>,
471    offsets: Option<Vec<isize>>,
472) -> DiaMatrix<T> {
473    let csr = ell.to_csr();
474    csr_to_dia(&csr, offsets)
475}
476
477/// Converts a DIA matrix to BSR format.
478pub fn dia_to_bsr<T: Scalar + Clone + Field>(
479    dia: &DiaMatrix<T>,
480    block_rows: usize,
481    block_cols: usize,
482) -> BsrMatrix<T> {
483    let csr = dia.to_csr();
484    BsrMatrix::from_csr(&csr, block_rows, block_cols)
485}
486
487/// Converts a BSR matrix to DIA format.
488pub fn bsr_to_dia<T: Scalar + Clone + Field>(
489    bsr: &BsrMatrix<T>,
490    offsets: Option<Vec<isize>>,
491) -> DiaMatrix<T> {
492    let csr = bsr.to_csr();
493    csr_to_dia(&csr, offsets)
494}
495
496/// Converts an ELL matrix to BSR format.
497pub fn ell_to_bsr<T: Scalar + Clone + Field>(
498    ell: &EllMatrix<T>,
499    block_rows: usize,
500    block_cols: usize,
501) -> BsrMatrix<T> {
502    let csr = ell.to_csr();
503    BsrMatrix::from_csr(&csr, block_rows, block_cols)
504}
505
506/// Converts a BSR matrix to ELL format.
507pub fn bsr_to_ell<T: Scalar + Clone + Field>(
508    bsr: &BsrMatrix<T>,
509    max_width: Option<usize>,
510) -> Result<EllMatrix<T>, crate::ell::EllError> {
511    let csr = bsr.to_csr();
512    EllMatrix::from_csr(&csr, max_width)
513}
514
515// ============================================================================
516// BSC Conversions
517// ============================================================================
518
519/// Converts a CSR matrix to BSC format.
520///
521/// # Arguments
522///
523/// * `csr` - Source CSR matrix
524/// * `block_rows` - Block row size
525/// * `block_cols` - Block column size
526pub fn csr_to_bsc<T: Scalar + Clone + Field>(
527    csr: &CsrMatrix<T>,
528    block_rows: usize,
529    block_cols: usize,
530) -> BscMatrix<T> {
531    let bsr = BsrMatrix::from_csr(csr, block_rows, block_cols);
532    BscMatrix::from_bsr(&bsr)
533}
534
535/// Converts a BSC matrix to CSR format.
536pub fn bsc_to_csr<T: Scalar + Clone + Field>(bsc: &BscMatrix<T>) -> CsrMatrix<T> {
537    let bsr = bsc.to_bsr();
538    bsr.to_csr()
539}
540
541/// Converts a BSC matrix to BSR format.
542pub fn bsc_to_bsr<T: Scalar + Clone + Field>(bsc: &BscMatrix<T>) -> BsrMatrix<T> {
543    bsc.to_bsr()
544}
545
546/// Converts a BSR matrix to BSC format.
547pub fn bsr_to_bsc<T: Scalar + Clone + Field>(bsr: &BsrMatrix<T>) -> BscMatrix<T> {
548    BscMatrix::from_bsr(bsr)
549}
550
551// ============================================================================
552// HYB Conversions
553// ============================================================================
554
555/// Converts a CSR matrix to HYB format.
556///
557/// # Arguments
558///
559/// * `csr` - Source CSR matrix
560/// * `strategy` - Strategy for determining ELL width
561pub fn csr_to_hyb<T: Scalar + Clone + Field>(
562    csr: &CsrMatrix<T>,
563    strategy: HybWidthStrategy,
564) -> HybMatrix<T> {
565    HybMatrix::from_csr(csr, strategy)
566}
567
568/// Converts a HYB matrix to CSR format.
569pub fn hyb_to_csr<T: Scalar + Clone + Field>(hyb: &HybMatrix<T>) -> CsrMatrix<T> {
570    hyb.to_csr()
571}
572
573/// Converts an ELL matrix to HYB format (no COO overflow).
574pub fn ell_to_hyb<T: Scalar + Clone + Field>(ell: &EllMatrix<T>) -> HybMatrix<T> {
575    HybMatrix::from_ell(ell)
576}
577
578/// Converts a HYB matrix to ELL format.
579pub fn hyb_to_ell<T: Scalar + Clone + Field>(hyb: &HybMatrix<T>) -> EllMatrix<T> {
580    hyb.to_ell()
581}
582
583// ============================================================================
584// SELL Conversions
585// ============================================================================
586
587/// Converts a CSR matrix to SELL (Sliced ELLPACK) format.
588///
589/// # Arguments
590///
591/// * `csr` - Source CSR matrix
592/// * `slice_size` - Size of each slice (typically 32 or 64 for GPU)
593pub fn csr_to_sell<T: Scalar + Clone + Field>(
594    csr: &CsrMatrix<T>,
595    slice_size: SliceSize,
596) -> SellMatrix<T> {
597    SellMatrix::from_csr(csr, slice_size)
598}
599
600/// Converts a SELL matrix to CSR format.
601pub fn sell_to_csr<T: Scalar + Clone + Field>(sell: &SellMatrix<T>) -> CsrMatrix<T> {
602    sell.to_csr()
603}
604
605// ============================================================================
606// Format Detection and Analysis
607// ============================================================================
608
/// Sparse storage format suggested by [`analyze_sparsity_pattern`].
///
/// Implements `Hash` alongside `Eq`, so recommendations can be used as keys in hashed
/// collections. One use is tallying which formats are chosen across a set of matrices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecommendedFormat {
    /// CSR: General purpose, good for row-wise operations.
    Csr,
    /// CSC: Good for column-wise operations and direct solvers.
    Csc,
    /// DIA: Optimal for banded/diagonal matrices.
    Dia,
    /// ELL: Good for matrices with uniform row lengths.
    Ell,
    /// HYB: Good for matrices with mostly uniform rows but some outliers.
    Hyb,
    /// SELL: Good for GPU computation with variable row lengths.
    Sell,
    /// BSR: Good for block-structured matrices.
    Bsr,
    /// BSC: Good for column-oriented block-structured matrices.
    Bsc,
}
629
630/// Analysis of a sparse matrix's sparsity pattern.
631#[derive(Debug, Clone)]
632pub struct SparsityAnalysis {
633    /// Number of rows.
634    pub nrows: usize,
635    /// Number of columns.
636    pub ncols: usize,
637    /// Number of non-zeros.
638    pub nnz: usize,
639    /// Density (nnz / (nrows * ncols)).
640    pub density: f64,
641    /// Maximum row length.
642    pub max_row_length: usize,
643    /// Minimum row length.
644    pub min_row_length: usize,
645    /// Average row length.
646    pub avg_row_length: f64,
647    /// Standard deviation of row lengths.
648    pub row_length_stddev: f64,
649    /// Number of distinct diagonals with entries.
650    pub num_diagonals: usize,
651    /// True if matrix appears to have block structure.
652    pub has_block_structure: bool,
653    /// Detected block size (if any).
654    pub detected_block_size: Option<(usize, usize)>,
655    /// Recommended format for this matrix.
656    pub recommended_format: RecommendedFormat,
657}
658
659/// Analyzes the sparsity pattern of a CSR matrix and recommends a format.
660///
661/// # Returns
662///
663/// A `SparsityAnalysis` containing statistics and a recommended format.
664pub fn analyze_sparsity_pattern<T: Scalar + Clone + Field>(csr: &CsrMatrix<T>) -> SparsityAnalysis {
665    let (nrows, ncols) = csr.shape();
666    let nnz = csr.nnz();
667    let eps = <T as Scalar>::epsilon();
668
669    if nrows == 0 || ncols == 0 {
670        return SparsityAnalysis {
671            nrows,
672            ncols,
673            nnz,
674            density: 0.0,
675            max_row_length: 0,
676            min_row_length: 0,
677            avg_row_length: 0.0,
678            row_length_stddev: 0.0,
679            num_diagonals: 0,
680            has_block_structure: false,
681            detected_block_size: None,
682            recommended_format: RecommendedFormat::Csr,
683        };
684    }
685
686    // Compute row lengths
687    let mut row_lengths = Vec::with_capacity(nrows);
688    for row in 0..nrows {
689        let mut count = 0;
690        for (_, val) in csr.row_iter(row) {
691            if Scalar::abs(val.clone()) > eps {
692                count += 1;
693            }
694        }
695        row_lengths.push(count);
696    }
697
698    let max_row_length = row_lengths.iter().max().copied().unwrap_or(0);
699    let min_row_length = row_lengths.iter().min().copied().unwrap_or(0);
700    let avg_row_length = if nrows > 0 {
701        row_lengths.iter().sum::<usize>() as f64 / nrows as f64
702    } else {
703        0.0
704    };
705
706    // Compute standard deviation
707    let variance: f64 = row_lengths
708        .iter()
709        .map(|&x| {
710            let diff = x as f64 - avg_row_length;
711            diff * diff
712        })
713        .sum::<f64>()
714        / nrows.max(1) as f64;
715    let row_length_stddev = variance.sqrt();
716
717    // Count distinct diagonals
718    let mut diagonals = std::collections::HashSet::new();
719    for (row, col, val) in csr.iter() {
720        if Scalar::abs(val.clone()) > eps {
721            diagonals.insert(col as isize - row as isize);
722        }
723    }
724    let num_diagonals = diagonals.len();
725
726    // Check for block structure (simple heuristic)
727    let (has_block_structure, detected_block_size) = detect_block_structure(csr);
728
729    let density = if nrows * ncols > 0 {
730        nnz as f64 / (nrows * ncols) as f64
731    } else {
732        0.0
733    };
734
735    // Determine recommended format
736    let recommended_format = determine_recommended_format(
737        nrows,
738        ncols,
739        nnz,
740        max_row_length,
741        min_row_length,
742        row_length_stddev,
743        num_diagonals,
744        has_block_structure,
745    );
746
747    SparsityAnalysis {
748        nrows,
749        ncols,
750        nnz,
751        density,
752        max_row_length,
753        min_row_length,
754        avg_row_length,
755        row_length_stddev,
756        num_diagonals,
757        has_block_structure,
758        detected_block_size,
759        recommended_format,
760    }
761}
762
763/// Detects if a matrix has block structure.
764fn detect_block_structure<T: Scalar + Clone + Field>(
765    csr: &CsrMatrix<T>,
766) -> (bool, Option<(usize, usize)>) {
767    let (nrows, ncols) = csr.shape();
768    let eps = <T as Scalar>::epsilon();
769
770    if nrows < 4 || ncols < 4 {
771        return (false, None);
772    }
773
774    // Try common block sizes
775    for block_size in [2, 3, 4, 6, 8] {
776        if nrows % block_size != 0 || ncols % block_size != 0 {
777            continue;
778        }
779
780        let _num_block_rows = nrows / block_size;
781        let _num_block_cols = ncols / block_size;
782
783        // Check if entries align with blocks
784        let block_aligned = true;
785        let mut blocks_found = std::collections::HashSet::new();
786
787        for (row, col, val) in csr.iter() {
788            if Scalar::abs(val.clone()) > eps {
789                let block_row = row / block_size;
790                let block_col = col / block_size;
791                blocks_found.insert((block_row, block_col));
792            }
793        }
794
795        // Verify that within each block, we have dense or near-dense entries
796        let mut dense_blocks = 0;
797        for &(br, bc) in &blocks_found {
798            let mut count = 0;
799            for i in 0..block_size {
800                for j in 0..block_size {
801                    let row = br * block_size + i;
802                    let col = bc * block_size + j;
803                    if let Some(val) = csr.get(row, col) {
804                        if Scalar::abs(val.clone()) > eps {
805                            count += 1;
806                        }
807                    }
808                }
809            }
810            // Consider block dense if > 50% full
811            if count * 2 >= block_size * block_size {
812                dense_blocks += 1;
813            }
814        }
815
816        // Consider it block-structured if > 70% of found blocks are dense
817        if !blocks_found.is_empty() && dense_blocks * 10 >= blocks_found.len() * 7 {
818            return (true, Some((block_size, block_size)));
819        }
820        if !block_aligned {
821            // Just to avoid warnings, this is always true
822            continue;
823        }
824    }
825
826    (false, None)
827}
828
829/// Determines the recommended format based on matrix characteristics.
830fn determine_recommended_format(
831    nrows: usize,
832    ncols: usize,
833    nnz: usize,
834    max_row_length: usize,
835    min_row_length: usize,
836    row_length_stddev: f64,
837    num_diagonals: usize,
838    has_block_structure: bool,
839) -> RecommendedFormat {
840    // Empty or very small matrix
841    if nnz == 0 || nrows <= 10 || ncols <= 10 {
842        return RecommendedFormat::Csr;
843    }
844
845    let avg_row_length = nnz as f64 / nrows.max(1) as f64;
846
847    // Block structure
848    if has_block_structure {
849        return RecommendedFormat::Bsr;
850    }
851
852    // Diagonal/banded structure
853    // If number of diagonals is small relative to matrix size
854    if num_diagonals <= 10 && num_diagonals * 2 <= nrows.max(1) {
855        return RecommendedFormat::Dia;
856    }
857
858    // Uniform row lengths (low variance)
859    let coefficient_of_variation = row_length_stddev / avg_row_length.max(1.0);
860
861    if coefficient_of_variation < 0.3 {
862        // Very uniform - ELL is efficient
863        return RecommendedFormat::Ell;
864    }
865
866    if coefficient_of_variation < 0.8 {
867        // Moderately uniform but with some variation - HYB is good
868        return RecommendedFormat::Hyb;
869    }
870
871    // High variance in row lengths
872    if max_row_length > min_row_length * 10 {
873        // Very irregular - SELL handles this well for GPU
874        return RecommendedFormat::Sell;
875    }
876
877    // Default to CSR
878    RecommendedFormat::Csr
879}
880
881#[cfg(test)]
882mod tests {
883    use super::*;
884    use crate::bsr::DenseBlock;
885
886    #[test]
887    fn test_csr_to_csc() {
888        // [1 0 2]
889        // [0 3 0]
890        // [4 0 5]
891        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
892        let col_indices = vec![0, 2, 1, 0, 2];
893        let row_ptrs = vec![0, 2, 3, 5];
894
895        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
896        let csc = csr_to_csc(&csr);
897
898        assert_eq!(csc.nnz(), 5);
899        assert_eq!(csc.get(0, 0), Some(&1.0));
900        assert_eq!(csc.get(0, 2), Some(&2.0));
901        assert_eq!(csc.get(1, 1), Some(&3.0));
902        assert_eq!(csc.get(2, 0), Some(&4.0));
903        assert_eq!(csc.get(2, 2), Some(&5.0));
904    }
905
906    #[test]
907    fn test_csc_to_csr() {
908        // [1 0 4]
909        // [0 3 0]
910        // [2 0 5]
911        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
912        let row_indices = vec![0, 2, 1, 0, 2];
913        let col_ptrs = vec![0, 2, 3, 5];
914
915        let csc = CscMatrix::new(3, 3, col_ptrs, row_indices, values).unwrap();
916        let csr = csc_to_csr(&csc);
917
918        assert_eq!(csr.nnz(), 5);
919        assert_eq!(csr.get(0, 0), Some(&1.0));
920        assert_eq!(csr.get(0, 2), Some(&4.0));
921        assert_eq!(csr.get(1, 1), Some(&3.0));
922        assert_eq!(csr.get(2, 0), Some(&2.0));
923        assert_eq!(csr.get(2, 2), Some(&5.0));
924    }
925
926    #[test]
927    fn test_coo_to_csr() {
928        let row_indices = vec![0, 1, 2, 0, 2];
929        let col_indices = vec![0, 1, 0, 2, 2];
930        let values = vec![1.0f64, 3.0, 4.0, 2.0, 5.0];
931
932        let coo = CooMatrix::new(3, 3, row_indices, col_indices, values).unwrap();
933        let csr = coo_to_csr(&coo);
934
935        assert_eq!(csr.nnz(), 5);
936        assert_eq!(csr.get(0, 0), Some(&1.0));
937        assert_eq!(csr.get(0, 2), Some(&2.0));
938        assert_eq!(csr.get(1, 1), Some(&3.0));
939        assert_eq!(csr.get(2, 0), Some(&4.0));
940        assert_eq!(csr.get(2, 2), Some(&5.0));
941    }
942
943    #[test]
944    fn test_coo_to_csr_duplicates() {
945        // Duplicate entries at (0,0)
946        let row_indices = vec![0, 0, 1];
947        let col_indices = vec![0, 0, 1];
948        let values = vec![1.0f64, 2.0, 3.0];
949
950        let coo = CooMatrix::new(2, 2, row_indices, col_indices, values).unwrap();
951        let csr = coo_to_csr(&coo);
952
953        assert_eq!(csr.nnz(), 2);
954        assert_eq!(csr.get(0, 0), Some(&3.0)); // 1 + 2
955        assert_eq!(csr.get(1, 1), Some(&3.0));
956    }
957
958    #[test]
959    fn test_coo_to_csc() {
960        let row_indices = vec![0, 1, 2, 0, 2];
961        let col_indices = vec![0, 1, 0, 2, 2];
962        let values = vec![1.0f64, 3.0, 4.0, 2.0, 5.0];
963
964        let coo = CooMatrix::new(3, 3, row_indices, col_indices, values).unwrap();
965        let csc = coo_to_csc(&coo);
966
967        assert_eq!(csc.nnz(), 5);
968        assert_eq!(csc.get(0, 0), Some(&1.0));
969        assert_eq!(csc.get(0, 2), Some(&2.0));
970        assert_eq!(csc.get(1, 1), Some(&3.0));
971        assert_eq!(csc.get(2, 0), Some(&4.0));
972        assert_eq!(csc.get(2, 2), Some(&5.0));
973    }
974
975    #[test]
976    fn test_roundtrip_csr_csc_csr() {
977        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
978        let col_indices = vec![0, 2, 1, 0, 2];
979        let row_ptrs = vec![0, 2, 3, 5];
980
981        let csr1 = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
982        let csc = csr_to_csc(&csr1);
983        let csr2 = csc_to_csr(&csc);
984
985        assert_eq!(csr1.nnz(), csr2.nnz());
986        for row in 0..3 {
987            for col in 0..3 {
988                assert_eq!(csr1.get(row, col), csr2.get(row, col));
989            }
990        }
991    }
992
993    #[test]
994    fn test_csr_to_coo() {
995        let values = vec![1.0f64, 2.0, 3.0];
996        let col_indices = vec![0, 1, 2];
997        let row_ptrs = vec![0, 1, 2, 3];
998
999        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
1000        let coo = csr_to_coo(&csr);
1001
1002        assert_eq!(coo.len(), 3);
1003        let entries: Vec<_> = coo.iter().map(|(r, c, v)| (r, c, *v)).collect();
1004        assert_eq!(entries, vec![(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0)]);
1005    }
1006
1007    #[test]
1008    fn test_empty_matrix_conversion() {
1009        let csr: CsrMatrix<f64> = CsrMatrix::zeros(5, 3);
1010        let csc = csr_to_csc(&csr);
1011
1012        assert_eq!(csc.nrows(), 5);
1013        assert_eq!(csc.ncols(), 3);
1014        assert_eq!(csc.nnz(), 0);
1015    }
1016
1017    // ========================================================================
1018    // DIA conversion tests
1019    // ========================================================================
1020
1021    #[test]
1022    fn test_csr_to_dia_tridiagonal() {
1023        // Tridiagonal matrix:
1024        // [4 1 0]
1025        // [2 5 1]
1026        // [0 3 6]
1027        let values = vec![4.0f64, 1.0, 2.0, 5.0, 1.0, 3.0, 6.0];
1028        let col_indices = vec![0, 1, 0, 1, 2, 1, 2];
1029        let row_ptrs = vec![0, 2, 5, 7];
1030
1031        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
1032        let dia = csr_to_dia(&csr, None);
1033
1034        assert_eq!(dia.ndiag(), 3);
1035        assert_eq!(dia.get(0, 0), Some(&4.0));
1036        assert_eq!(dia.get(0, 1), Some(&1.0));
1037        assert_eq!(dia.get(1, 0), Some(&2.0));
1038        assert_eq!(dia.get(1, 1), Some(&5.0));
1039        assert_eq!(dia.get(2, 2), Some(&6.0));
1040    }
1041
1042    #[test]
1043    fn test_dia_to_csr() {
1044        let offsets = vec![-1, 0, 1];
1045        let data = vec![
1046            vec![2.0, 3.0, 0.0],
1047            vec![4.0, 5.0, 6.0],
1048            vec![0.0, 1.0, 1.0],
1049        ];
1050
1051        let dia = DiaMatrix::new(3, 3, offsets, data).unwrap();
1052        let csr = dia_to_csr(&dia);
1053
1054        assert_eq!(csr.nrows(), 3);
1055        assert_eq!(csr.get(0, 0), Some(&4.0));
1056        assert_eq!(csr.get(1, 0), Some(&2.0));
1057    }
1058
1059    #[test]
1060    fn test_csr_dia_roundtrip() {
1061        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
1062        let col_indices = vec![0, 1, 1, 0, 2];
1063        let row_ptrs = vec![0, 2, 3, 5];
1064
1065        let csr1 = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();
1066        let dia = csr_to_dia(&csr1, None);
1067        let csr2 = dia_to_csr(&dia);
1068
1069        for row in 0..3 {
1070            for col in 0..3 {
1071                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
1072                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
1073                assert!((v1 - v2).abs() < 1e-10);
1074            }
1075        }
1076    }
1077
1078    // ========================================================================
1079    // ELL conversion tests
1080    // ========================================================================
1081
1082    #[test]
1083    fn test_csr_to_ell() {
1084        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
1085        let col_indices = vec![0, 1, 1, 2, 0, 3];
1086        let row_ptrs = vec![0, 2, 4, 6];
1087
1088        let csr = CsrMatrix::new(3, 4, row_ptrs, col_indices, values).unwrap();
1089        let ell = csr_to_ell(&csr, None).unwrap();
1090
1091        assert_eq!(ell.width(), 2);
1092        assert_eq!(ell.get(0, 0), Some(&1.0));
1093        assert_eq!(ell.get(1, 2), Some(&4.0));
1094    }
1095
1096    #[test]
1097    fn test_ell_to_csr() {
1098        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1099        let indices = vec![vec![0, 1], vec![1, 2]];
1100
1101        let ell = EllMatrix::new(2, 3, 2, data, indices).unwrap();
1102        let csr = ell_to_csr(&ell);
1103
1104        assert_eq!(csr.nrows(), 2);
1105        assert_eq!(csr.get(0, 0), Some(&1.0));
1106        assert_eq!(csr.get(1, 2), Some(&4.0));
1107    }
1108
1109    #[test]
1110    fn test_csr_ell_roundtrip() {
1111        let values = vec![1.0f64, 2.0, 3.0, 4.0];
1112        let col_indices = vec![0, 1, 1, 2];
1113        let row_ptrs = vec![0, 2, 4];
1114
1115        let csr1 = CsrMatrix::new(2, 3, row_ptrs, col_indices, values).unwrap();
1116        let ell = csr_to_ell(&csr1, None).unwrap();
1117        let csr2 = ell_to_csr(&ell);
1118
1119        for row in 0..2 {
1120            for col in 0..3 {
1121                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
1122                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
1123                assert!((v1 - v2).abs() < 1e-10);
1124            }
1125        }
1126    }
1127
1128    // ========================================================================
1129    // BSR conversion tests
1130    // ========================================================================
1131
1132    #[test]
1133    fn test_csr_to_bsr() {
1134        // 4x4 matrix with 2x2 block structure
1135        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1136        let col_indices = vec![0, 1, 0, 1, 2, 3, 2, 3];
1137        let row_ptrs = vec![0, 2, 4, 6, 8];
1138
1139        let csr = CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap();
1140        let bsr = csr_to_bsr(&csr, 2, 2);
1141
1142        assert_eq!(bsr.nblocks(), 2);
1143        assert_eq!(bsr.get(0, 0), Some(1.0));
1144        assert_eq!(bsr.get(3, 3), Some(8.0));
1145    }
1146
1147    #[test]
1148    fn test_bsr_to_csr() {
1149        let block1 = DenseBlock::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
1150        let block2 = DenseBlock::new(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
1151
1152        let bsr =
1153            BsrMatrix::new(4, 4, 2, 2, vec![0, 1, 2], vec![0, 1], vec![block1, block2]).unwrap();
1154
1155        let csr = bsr_to_csr(&bsr);
1156
1157        assert_eq!(csr.nrows(), 4);
1158        assert_eq!(csr.get(0, 0), Some(&1.0));
1159        assert_eq!(csr.get(2, 2), Some(&5.0));
1160    }
1161
1162    #[test]
1163    fn test_csr_bsr_roundtrip() {
1164        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1165        let col_indices = vec![0, 1, 0, 1, 2, 3, 2, 3];
1166        let row_ptrs = vec![0, 2, 4, 6, 8];
1167
1168        let csr1 = CsrMatrix::new(4, 4, row_ptrs, col_indices, values).unwrap();
1169        let bsr = csr_to_bsr(&csr1, 2, 2);
1170        let csr2 = bsr_to_csr(&bsr);
1171
1172        for row in 0..4 {
1173            for col in 0..4 {
1174                let v1 = csr1.get(row, col).cloned().unwrap_or(0.0);
1175                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
1176                assert!((v1 - v2).abs() < 1e-10);
1177            }
1178        }
1179    }
1180
1181    // ========================================================================
1182    // Cross-format conversion tests
1183    // ========================================================================
1184
1185    #[test]
1186    fn test_dia_to_ell() {
1187        let offsets = vec![0];
1188        let data = vec![vec![1.0, 2.0, 3.0]];
1189
1190        let dia = DiaMatrix::new(3, 3, offsets, data).unwrap();
1191        let ell = dia_to_ell(&dia, None).unwrap();
1192
1193        assert_eq!(ell.width(), 1);
1194        assert_eq!(ell.get(0, 0), Some(&1.0));
1195        assert_eq!(ell.get(1, 1), Some(&2.0));
1196    }
1197
1198    #[test]
1199    fn test_dia_to_bsr() {
1200        let offsets = vec![0];
1201        let data = vec![vec![1.0, 2.0, 3.0, 4.0]];
1202
1203        let dia = DiaMatrix::new(4, 4, offsets, data).unwrap();
1204        let bsr = dia_to_bsr(&dia, 2, 2);
1205
1206        assert_eq!(bsr.get(0, 0), Some(1.0));
1207        assert_eq!(bsr.get(1, 1), Some(2.0));
1208    }
1209
1210    #[test]
1211    fn test_ell_to_bsr() {
1212        let data = vec![
1213            vec![1.0, 2.0],
1214            vec![3.0, 4.0],
1215            vec![5.0, 6.0],
1216            vec![7.0, 8.0],
1217        ];
1218        let indices = vec![vec![0, 1], vec![0, 1], vec![2, 3], vec![2, 3]];
1219
1220        let ell = EllMatrix::new(4, 4, 2, data, indices).unwrap();
1221        let bsr = ell_to_bsr(&ell, 2, 2);
1222
1223        assert_eq!(bsr.nrows(), 4);
1224        assert_eq!(bsr.get(0, 0), Some(1.0));
1225        assert_eq!(bsr.get(3, 3), Some(8.0));
1226    }
1227}