scirs2_sparse/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format
2//!
3//! This module provides the CSR matrix format implementation, which is
4//! efficient for row operations, matrix-vector multiplication, and more.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use scirs2_core::GpuDataType;
9use std::cmp::PartialEq;
10
/// Compressed Sparse Row (CSR) matrix
///
/// A sparse matrix format that compresses rows, making it efficient for
/// row operations and matrix-vector multiplication.
///
/// Storage layout: the stored entries of row `i` occupy positions
/// `indptr[i]..indptr[i + 1]` of the parallel arrays `indices` (column index)
/// and `data` (value).
///
/// Note: `CsrMatrix::new` keeps duplicate `(row, col)` pairs as separate stored
/// entries, and `from_raw_csr` does not check that column indices within a row
/// are sorted or unique. Because `indptr`, `indices` and `data` are public,
/// callers can also break the layout invariant by mutating them directly.
#[derive(Clone)]
pub struct CsrMatrix<T> {
    /// Number of rows
    rows: usize,
    /// Number of columns
    cols: usize,
    /// Row pointers (size rows+1); row `i` spans `indptr[i]..indptr[i + 1]`
    pub indptr: Vec<usize>,
    /// Column index of each stored entry (parallel to `data`)
    pub indices: Vec<usize>,
    /// Value of each stored entry (parallel to `indices`)
    pub data: Vec<T>,
}
28
29impl<T> CsrMatrix<T>
30where
31    T: Clone + Copy + Zero + PartialEq + SparseElement,
32{
33    /// Get the value at the specified position
34    pub fn get(&self, row: usize, col: usize) -> T {
35        // Check bounds
36        if row >= self.rows || col >= self.cols {
37            return T::sparse_zero();
38        }
39
40        // Find the element in the CSR format
41        for j in self.indptr[row]..self.indptr[row + 1] {
42            if self.indices[j] == col {
43                return self.data[j];
44            }
45        }
46
47        // Element not found, return zero
48        T::sparse_zero()
49    }
50
51    /// Get the triplets (row indices, column indices, data)
52    pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
53        let mut rows = Vec::new();
54        let mut cols = Vec::new();
55        let mut values = Vec::new();
56
57        for i in 0..self.rows {
58            for j in self.indptr[i]..self.indptr[i + 1] {
59                rows.push(i);
60                cols.push(self.indices[j]);
61                values.push(self.data[j]);
62            }
63        }
64
65        (rows, cols, values)
66    }
67    /// Create a new CSR matrix from raw data
68    ///
69    /// # Arguments
70    ///
71    /// * `data` - Vector of non-zero values
72    /// * `rowindices` - Vector of row indices for each non-zero value
73    /// * `colindices` - Vector of column indices for each non-zero value
74    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
75    ///
76    /// # Returns
77    ///
78    /// * A new CSR matrix
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use scirs2_sparse::csr::CsrMatrix;
84    ///
85    /// // Create a 3x3 sparse matrix with 5 non-zero elements
86    /// let rows = vec![0, 0, 1, 2, 2];
87    /// let cols = vec![0, 2, 2, 0, 1];
88    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
89    /// let shape = (3, 3);
90    ///
91    /// let matrix = CsrMatrix::new(data.clone(), rows, cols, shape).unwrap();
92    /// ```
93    pub fn new(
94        data: Vec<T>,
95        rowindices: Vec<usize>,
96        colindices: Vec<usize>,
97        shape: (usize, usize),
98    ) -> SparseResult<Self> {
99        // Validate input data
100        if data.len() != rowindices.len() || data.len() != colindices.len() {
101            return Err(SparseError::DimensionMismatch {
102                expected: data.len(),
103                found: std::cmp::min(rowindices.len(), colindices.len()),
104            });
105        }
106
107        let (rows, cols) = shape;
108
109        // Check indices are within bounds
110        if rowindices.iter().any(|&i| i >= rows) {
111            return Err(SparseError::ValueError(
112                "Row index out of bounds".to_string(),
113            ));
114        }
115
116        if colindices.iter().any(|&i| i >= cols) {
117            return Err(SparseError::ValueError(
118                "Column index out of bounds".to_string(),
119            ));
120        }
121
122        // Convert triplet format to CSR
123        // First, sort by row, then by column
124        let mut triplets: Vec<(usize, usize, T)> = rowindices
125            .into_iter()
126            .zip(colindices)
127            .zip(data)
128            .map(|((r, c), v)| (r, c, v))
129            .collect();
130        triplets.sort_by_key(|&(r, c_, _)| (r, c_));
131
132        // Create indptr, indices, and data arrays
133        let nnz = triplets.len();
134        let mut indptr = vec![0; rows + 1];
135        let mut indices = Vec::with_capacity(nnz);
136        let mut data_out = Vec::with_capacity(nnz);
137
138        // Count elements per row to build indptr
139        for &(r_, _, _) in &triplets {
140            indptr[r_ + 1] += 1;
141        }
142
143        // Compute cumulative sum for indptr
144        for i in 1..=rows {
145            indptr[i] += indptr[i - 1];
146        }
147
148        // Fill indices and data
149        for (_r, c, v) in triplets {
150            indices.push(c);
151            data_out.push(v);
152        }
153
154        Ok(CsrMatrix {
155            rows,
156            cols,
157            indptr,
158            indices,
159            data: data_out,
160        })
161    }
162
163    /// Create a CSR matrix from triplet format (COO-like construction)
164    ///
165    /// This is a convenience constructor that builds a CSR matrix from
166    /// separate row indices, column indices, and values vectors.
167    ///
168    /// # Arguments
169    ///
170    /// * `nrows` - Number of rows in the matrix
171    /// * `ncols` - Number of columns in the matrix
172    /// * `row_indices` - Vector of row indices for each non-zero value
173    /// * `col_indices` - Vector of column indices for each non-zero value
174    /// * `values` - Vector of non-zero values
175    ///
176    /// # Returns
177    ///
178    /// * `Ok(CsrMatrix)` - A new CSR matrix
179    /// * `Err(SparseError)` - If input is invalid
180    ///
181    /// # Examples
182    ///
183    /// ```
184    /// use scirs2_sparse::csr::CsrMatrix;
185    ///
186    /// let row_indices = vec![0, 0, 1, 2, 2];
187    /// let col_indices = vec![0, 2, 2, 0, 1];
188    /// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
189    ///
190    /// let matrix = CsrMatrix::from_triplets(3, 3, row_indices, col_indices, values).unwrap();
191    /// assert_eq!(matrix.nnz(), 5);
192    /// ```
193    pub fn from_triplets(
194        nrows: usize,
195        ncols: usize,
196        row_indices: Vec<usize>,
197        col_indices: Vec<usize>,
198        values: Vec<T>,
199    ) -> SparseResult<Self> {
200        Self::new(values, row_indices, col_indices, (nrows, ncols))
201    }
202
203    /// Create a CSR matrix from triplet tuples
204    ///
205    /// This constructor accepts a slice of (row, col, value) tuples,
206    /// which is convenient for constructing matrices from coordinate lists.
207    ///
208    /// # Arguments
209    ///
210    /// * `nrows` - Number of rows in the matrix
211    /// * `ncols` - Number of columns in the matrix
212    /// * `triplets` - Slice of (row_index, col_index, value) tuples
213    ///
214    /// # Returns
215    ///
216    /// * `Ok(CsrMatrix)` - A new CSR matrix
217    /// * `Err(SparseError)` - If input is invalid
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// use scirs2_sparse::csr::CsrMatrix;
223    ///
224    /// let triplets = vec![
225    ///     (0, 0, 1.0),
226    ///     (0, 2, 2.0),
227    ///     (1, 2, 3.0),
228    ///     (2, 0, 4.0),
229    ///     (2, 1, 5.0),
230    /// ];
231    ///
232    /// let matrix = CsrMatrix::try_from_triplets(3, 3, &triplets).unwrap();
233    /// assert_eq!(matrix.nnz(), 5);
234    /// assert_eq!(matrix.get(0, 0), 1.0);
235    /// assert_eq!(matrix.get(2, 1), 5.0);
236    /// ```
237    pub fn try_from_triplets(
238        nrows: usize,
239        ncols: usize,
240        triplets: &[(usize, usize, T)],
241    ) -> SparseResult<Self> {
242        let mut row_indices = Vec::with_capacity(triplets.len());
243        let mut col_indices = Vec::with_capacity(triplets.len());
244        let mut values = Vec::with_capacity(triplets.len());
245
246        for &(r, c, v) in triplets {
247            row_indices.push(r);
248            col_indices.push(c);
249            values.push(v);
250        }
251
252        Self::from_triplets(nrows, ncols, row_indices, col_indices, values)
253    }
254
255    /// Create a new CSR matrix from raw CSR format
256    ///
257    /// # Arguments
258    ///
259    /// * `data` - Vector of non-zero values
260    /// * `indptr` - Vector of row pointers (size rows+1)
261    /// * `indices` - Vector of column indices
262    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
263    ///
264    /// # Returns
265    ///
266    /// * A new CSR matrix
267    pub fn from_raw_csr(
268        data: Vec<T>,
269        indptr: Vec<usize>,
270        indices: Vec<usize>,
271        shape: (usize, usize),
272    ) -> SparseResult<Self> {
273        let (rows, cols) = shape;
274
275        // Validate input data
276        if indptr.len() != rows + 1 {
277            return Err(SparseError::DimensionMismatch {
278                expected: rows + 1,
279                found: indptr.len(),
280            });
281        }
282
283        if data.len() != indices.len() {
284            return Err(SparseError::DimensionMismatch {
285                expected: data.len(),
286                found: indices.len(),
287            });
288        }
289
290        // Check if indptr is monotonically increasing
291        for i in 1..indptr.len() {
292            if indptr[i] < indptr[i - 1] {
293                return Err(SparseError::ValueError(
294                    "Row pointer array must be monotonically increasing".to_string(),
295                ));
296            }
297        }
298
299        // Check if the last indptr entry matches the data length
300        if indptr[rows] != data.len() {
301            return Err(SparseError::ValueError(
302                "Last row pointer entry must match data length".to_string(),
303            ));
304        }
305
306        // Check if indices are within bounds
307        if indices.iter().any(|&i| i >= cols) {
308            return Err(SparseError::ValueError(
309                "Column index out of bounds".to_string(),
310            ));
311        }
312
313        Ok(CsrMatrix {
314            rows,
315            cols,
316            indptr,
317            indices,
318            data,
319        })
320    }
321
322    /// Create a new empty CSR matrix
323    ///
324    /// # Arguments
325    ///
326    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
327    ///
328    /// # Returns
329    ///
330    /// * A new empty CSR matrix
331    pub fn empty(shape: (usize, usize)) -> Self {
332        let (rows, cols) = shape;
333        let indptr = vec![0; rows + 1];
334
335        CsrMatrix {
336            rows,
337            cols,
338            indptr,
339            indices: Vec::new(),
340            data: Vec::new(),
341        }
342    }
343
344    /// Get the number of rows in the matrix
345    pub fn rows(&self) -> usize {
346        self.rows
347    }
348
349    /// Get the number of columns in the matrix
350    pub fn cols(&self) -> usize {
351        self.cols
352    }
353
354    /// Get the shape (dimensions) of the matrix
355    pub fn shape(&self) -> (usize, usize) {
356        (self.rows, self.cols)
357    }
358
359    /// Get the number of non-zero elements in the matrix
360    pub fn nnz(&self) -> usize {
361        self.data.len()
362    }
363
364    /// Convert to dense matrix (as Vec<Vec<T>>)
365    pub fn to_dense(&self) -> Vec<Vec<T>>
366    where
367        T: Zero + Copy + SparseElement,
368    {
369        let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
370
371        for (row_idx, row) in result.iter_mut().enumerate() {
372            for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
373                let col_idx = self.indices[j];
374                row[col_idx] = self.data[j];
375            }
376        }
377
378        result
379    }
380
381    /// Transpose the matrix
382    pub fn transpose(&self) -> Self {
383        // Compute the number of non-zeros per column
384        let mut col_counts = vec![0; self.cols];
385        for &col in &self.indices {
386            col_counts[col] += 1;
387        }
388
389        // Compute column pointers (cumulative sum)
390        let mut col_ptrs = vec![0; self.cols + 1];
391        for i in 0..self.cols {
392            col_ptrs[i + 1] = col_ptrs[i] + col_counts[i];
393        }
394
395        // Fill the transposed matrix
396        let nnz = self.nnz();
397        let mut indices_t = vec![0; nnz];
398        let mut data_t = vec![T::sparse_zero(); nnz];
399        let mut col_counts = vec![0; self.cols];
400
401        for row in 0..self.rows {
402            for j in self.indptr[row]..self.indptr[row + 1] {
403                let col = self.indices[j];
404                let dest = col_ptrs[col] + col_counts[col];
405
406                indices_t[dest] = row;
407                data_t[dest] = self.data[j];
408                col_counts[col] += 1;
409            }
410        }
411
412        CsrMatrix {
413            rows: self.cols,
414            cols: self.rows,
415            indptr: col_ptrs,
416            indices: indices_t,
417            data: data_t,
418        }
419    }
420}
421
422impl<
423        T: Clone
424            + Copy
425            + std::ops::AddAssign
426            + std::ops::MulAssign
427            + std::cmp::PartialEq
428            + std::fmt::Debug
429            + scirs2_core::numeric::Zero
430            + std::ops::Add<Output = T>
431            + std::ops::Mul<Output = T>
432            + SparseElement,
433    > CsrMatrix<T>
434{
435    /// Check if matrix is symmetric
436    ///
437    /// # Returns
438    ///
439    /// * `true` if the matrix is symmetric, `false` otherwise
440    pub fn is_symmetric(&self) -> bool {
441        if self.rows != self.cols {
442            return false;
443        }
444
445        // Create a transposed matrix
446        let transposed = self.transpose();
447
448        // Compare the sparsity patterns and values
449        if self.nnz() != transposed.nnz() {
450            return false;
451        }
452
453        // Compare row by row
454        for row in 0..self.rows {
455            let self_start = self.indptr[row];
456            let self_end = self.indptr[row + 1];
457            let trans_start = transposed.indptr[row];
458            let trans_end = transposed.indptr[row + 1];
459
460            if self_end - self_start != trans_end - trans_start {
461                return false;
462            }
463
464            // Create sorted columns and values for this row
465            let mut self_entries: Vec<(usize, &T)> = (self_start..self_end)
466                .map(|j| (self.indices[j], &self.data[j]))
467                .collect();
468            self_entries.sort_by_key(|(col_, _)| *col_);
469
470            let mut trans_entries: Vec<(usize, &T)> = (trans_start..trans_end)
471                .map(|j| (transposed.indices[j], &transposed.data[j]))
472                .collect();
473            trans_entries.sort_by_key(|(col_, _)| *col_);
474
475            // Compare columns and values
476            for i in 0..self_entries.len() {
477                if self_entries[i].0 != trans_entries[i].0
478                    || self_entries[i].1 != trans_entries[i].1
479                {
480                    return false;
481                }
482            }
483        }
484
485        true
486    }
487
488    /// Matrix-matrix multiplication
489    ///
490    /// # Arguments
491    ///
492    /// * `other` - Matrix to multiply with
493    ///
494    /// # Returns
495    ///
496    /// * Result containing the product matrix
497    pub fn matmul(&self, other: &CsrMatrix<T>) -> SparseResult<CsrMatrix<T>> {
498        if self.cols != other.rows {
499            return Err(SparseError::DimensionMismatch {
500                expected: self.cols,
501                found: other.rows,
502            });
503        }
504
505        // For simplicity, we'll implement this using dense operations
506        // In a real implementation, you'd use a more efficient sparse algorithm
507        let a_dense = self.to_dense();
508        let b_dense = other.to_dense();
509
510        let m = self.rows;
511        let n = other.cols;
512        let k = self.cols;
513
514        let mut c_dense = vec![vec![T::sparse_zero(); n]; m];
515
516        for (i, c_row) in c_dense.iter_mut().enumerate().take(m) {
517            for (j, val) in c_row.iter_mut().enumerate().take(n) {
518                for (l, &a_val) in a_dense[i].iter().enumerate().take(k) {
519                    let prod = a_val * b_dense[l][j];
520                    *val += prod;
521                }
522            }
523        }
524
525        // Convert back to CSR format
526        let mut rowindices = Vec::new();
527        let mut colindices = Vec::new();
528        let mut values = Vec::new();
529
530        for (i, row) in c_dense.iter().enumerate() {
531            for (j, val) in row.iter().enumerate() {
532                if *val != T::sparse_zero() {
533                    rowindices.push(i);
534                    colindices.push(j);
535                    values.push(*val);
536                }
537            }
538        }
539
540        CsrMatrix::new(values, rowindices, colindices, (m, n))
541    }
542
543    /// Get row range for iterating over elements in a row
544    ///
545    /// # Arguments
546    ///
547    /// * `row` - Row index
548    ///
549    /// # Returns
550    ///
551    /// * Range of indices in the data and indices arrays for this row
552    pub fn row_range(&self, row: usize) -> std::ops::Range<usize> {
553        assert!(row < self.rows, "Row index out of bounds");
554        self.indptr[row]..self.indptr[row + 1]
555    }
556
557    /// Get column indices array
558    pub fn colindices(&self) -> &[usize] {
559        &self.indices
560    }
561}
562
563impl CsrMatrix<f64> {
564    /// Matrix-vector multiplication
565    ///
566    /// # Arguments
567    ///
568    /// * `vec` - Vector to multiply with
569    ///
570    /// # Returns
571    ///
572    /// * Result of matrix-vector multiplication
573    pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
574        if vec.len() != self.cols {
575            return Err(SparseError::DimensionMismatch {
576                expected: self.cols,
577                found: vec.len(),
578            });
579        }
580
581        let mut result = vec![0.0; self.rows];
582
583        for (row_idx, result_val) in result.iter_mut().enumerate() {
584            for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
585                let col_idx = self.indices[j];
586                *result_val += self.data[j] * vec[col_idx];
587            }
588        }
589
590        Ok(result)
591    }
592
593    /// GPU-accelerated matrix-vector multiplication
594    ///
595    /// This method automatically uses GPU acceleration when beneficial,
596    /// falling back to optimized CPU implementation when appropriate.
597    ///
598    /// # Arguments
599    ///
600    /// * `vec` - Vector to multiply with
601    ///
602    /// # Returns
603    ///
604    /// * Result of matrix-vector multiplication
605    ///
606    /// # Examples
607    ///
608    /// ```
609    /// use scirs2_sparse::csr::CsrMatrix;
610    ///
611    /// let rows = vec![0, 0, 1, 2, 2];
612    /// let cols = vec![0, 2, 2, 0, 1];
613    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
614    /// let shape = (3, 3);
615    ///
616    /// let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
617    /// let vec = vec![1.0, 2.0, 3.0];
618    /// let result = matrix.gpu_dot(&vec).unwrap();
619    /// ```
620    #[allow(dead_code)]
621    pub fn gpu_dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
622        // Use the GpuSpMV implementation
623        let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::new()?;
624        gpu_spmv.spmv(
625            self.rows,
626            self.cols,
627            &self.indptr,
628            &self.indices,
629            &self.data,
630            vec,
631        )
632    }
633
634    /// GPU-accelerated matrix-vector multiplication with backend selection
635    ///
636    /// # Arguments
637    ///
638    /// * `vec` - Vector to multiply with
639    /// * `backend` - Preferred GPU backend
640    ///
641    /// # Returns
642    ///
643    /// * Result of matrix-vector multiplication
644    #[allow(dead_code)]
645    pub fn gpu_dot_with_backend(
646        &self,
647        vec: &[f64],
648        backend: scirs2_core::gpu::GpuBackend,
649    ) -> SparseResult<Vec<f64>> {
650        // Use the GpuSpMV implementation with specified backend
651        let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::with_backend(backend)?;
652        gpu_spmv.spmv(
653            self.rows,
654            self.cols,
655            &self.indptr,
656            &self.indices,
657            &self.data,
658            vec,
659        )
660    }
661}
662
663impl<T> CsrMatrix<T>
664where
665    T: scirs2_core::numeric::Float
666        + std::fmt::Debug
667        + Copy
668        + Default
669        + GpuDataType
670        + Send
671        + Sync
672        + SparseElement
673        + std::ops::AddAssign
674        + std::ops::Mul<Output = T>
675        + 'static,
676{
677    /// GPU-accelerated matrix-vector multiplication for generic floating-point types
678    ///
679    /// # Arguments
680    ///
681    /// * `vec` - Vector to multiply with
682    ///
683    /// # Returns
684    ///
685    /// * Result of matrix-vector multiplication
686    #[allow(dead_code)]
687    pub fn gpu_dot_generic(&self, vec: &[T]) -> SparseResult<Vec<T>>
688where {
689        // GPU operations fall back to CPU for stability
690        if vec.len() != self.cols {
691            return Err(SparseError::DimensionMismatch {
692                expected: self.cols,
693                found: vec.len(),
694            });
695        }
696
697        let mut result = vec![T::sparse_zero(); self.rows];
698
699        for (row_idx, result_val) in result.iter_mut().enumerate() {
700            let start = self.indptr[row_idx];
701            let end = self.indptr[row_idx + 1];
702
703            for idx in start..end {
704                let col = self.indices[idx];
705                *result_val += self.data[idx] * vec[col];
706            }
707        }
708
709        Ok(result)
710    }
711
712    /// Check if this matrix should benefit from GPU acceleration
713    ///
714    /// # Returns
715    ///
716    /// * `true` if GPU acceleration is likely to provide benefits
717    pub fn should_use_gpu(&self) -> bool {
718        // Use GPU for matrices with significant computation (> 10k non-zeros)
719        // and reasonable sparsity (< 50% dense)
720        let nnz_threshold = 10000;
721        let density = self.nnz() as f64 / (self.rows * self.cols) as f64;
722
723        self.nnz() > nnz_threshold && density < 0.5
724    }
725
726    /// Get GPU backend information
727    ///
728    /// # Returns
729    ///
730    /// * Information about available GPU backends
731    #[allow(dead_code)]
732    pub fn gpu_backend_info() -> SparseResult<(crate::gpu_ops::GpuBackend, String)> {
733        // GPU operations fall back to CPU for stability
734        Ok((crate::gpu_ops::GpuBackend::Cpu, "CPU Fallback".to_string()))
735    }
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Right-hand side used by the matrix-vector tests.
    const RHS: [f64; 3] = [1.0, 2.0, 3.0];

    /// `sample_matrix() * RHS`:
    /// 1*1 + 0*2 + 2*3 = 7, 0*1 + 0*2 + 3*3 = 9, 4*1 + 5*2 + 0*3 = 14.
    const PRODUCT: [f64; 3] = [7.0, 9.0, 14.0];

    // Shared 3x3 fixture with 5 stored entries:
    // [1 0 2]
    // [0 0 3]
    // [4 5 0]
    fn sample_matrix() -> CsrMatrix<f64> {
        let row_idx = vec![0, 0, 1, 2, 2];
        let col_idx = vec![0, 2, 2, 0, 1];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        CsrMatrix::new(values, row_idx, col_idx, (3, 3)).unwrap()
    }

    // Element-wise approximate comparison against PRODUCT.
    fn assert_is_product(actual: &[f64]) {
        assert_eq!(actual.len(), PRODUCT.len());
        for (got, want) in actual.iter().zip(PRODUCT.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_csr_create() {
        let matrix = sample_matrix();

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_csr_to_dense() {
        let want = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(sample_matrix().to_dense(), want);
    }

    #[test]
    fn test_csr_dot() {
        let product = sample_matrix().dot(&RHS).unwrap();
        assert_is_product(&product);
    }

    #[test]
    fn test_csr_transpose() {
        let flipped = sample_matrix().transpose();

        assert_eq!(flipped.shape(), (3, 3));
        assert_eq!(flipped.nnz(), 5);

        let want = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];
        assert_eq!(flipped.to_dense(), want);
    }

    #[test]
    fn test_gpu_dot() {
        // A missing GPU is an acceptable outcome in CI/local runs.
        match sample_matrix().gpu_dot(&RHS) {
            Ok(product) => assert_is_product(&product),
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {}
            Err(e) => panic!("Unexpected error in GPU SpMV: {:?}", e),
        }
    }

    #[test]
    fn test_should_use_gpu() {
        let tiny = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 2)).unwrap();
        assert!(!tiny.should_use_gpu(), "Small matrix should not use GPU");

        // Identity-like 15000x15000 matrix: many entries, very low density.
        const N: usize = 15000;
        let diagonal: Vec<usize> = (0..N).collect();
        let huge = CsrMatrix::new(vec![1.0; N], diagonal.clone(), diagonal, (N, N)).unwrap();
        assert!(huge.should_use_gpu(), "Large sparse matrix should use GPU");
    }

    #[test]
    fn test_gpu_backend_info() {
        let (backend, name) = CsrMatrix::<f64>::gpu_backend_info()
            .expect("Should be able to get GPU backend info");

        assert!(!name.is_empty(), "Backend name should not be empty");

        // Exhaustive on purpose: adding a backend variant must revisit this test.
        match backend {
            crate::gpu_ops::GpuBackend::Cuda
            | crate::gpu_ops::GpuBackend::OpenCL
            | crate::gpu_ops::GpuBackend::Metal
            | crate::gpu_ops::GpuBackend::Cpu
            | crate::gpu_ops::GpuBackend::Rocm
            | crate::gpu_ops::GpuBackend::Wgpu => {}
        }
    }

    #[test]
    fn test_gpu_dot_generic_f32() {
        let matrix = CsrMatrix::new(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0],
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 2, 0, 1],
            (3, 3),
        )
        .unwrap();

        match matrix.gpu_dot_generic(&[1.0f32, 2.0, 3.0]) {
            Ok(product) => {
                let want = [7.0f32, 9.0, 14.0];
                assert_eq!(product.len(), want.len());
                for (got, expected) in product.iter().zip(want.iter()) {
                    assert_relative_eq!(got, expected, epsilon = 1e-6);
                }
            }
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {}
            Err(e) => panic!("Unexpected error in generic GPU SpMV: {:?}", e),
        }
    }
}