Skip to main content

numr/sparse/csr/
matmul.rs

//! CSR matrix multiplication (spmv, spmm) and transpose-to-CSC.
2
3use super::CsrData;
4use crate::dtype::{DType, Element};
5use crate::error::{Error, Result};
6use crate::runtime::Runtime;
7use crate::sparse::{CscData, SparseStorage};
8use crate::tensor::Tensor;
9
10impl<R: Runtime<DType = DType>> CsrData<R> {
11    /// Sparse matrix-vector multiplication: y = A * x
12    ///
13    /// Computes the product of this sparse matrix with a dense vector.
14    ///
15    /// # Arguments
16    ///
17    /// * `x` - Dense vector of length `ncols` (or shape `` `[ncols]` `` or `` `[ncols, 1]` ``)
18    ///
19    /// # Returns
20    ///
21    /// Dense vector of length `` `nrows` ``
22    ///
23    /// # Errors
24    ///
25    /// Returns error if:
26    /// - `x` length doesn't match matrix ncols
27    /// - dtype mismatch between matrix and vector
28    ///
29    /// # Algorithm
30    ///
31    /// For each row i:
32    /// ```text
33    /// `` `y[i] = sum(values[j] * x[col_indices[j]]) for j in row_ptrs[i]..row_ptrs[i+1]` ``
34    /// ```
35    ///
36    /// # Performance
37    ///
38    /// - O(nnz) time complexity
39    /// - CSR format provides optimal memory access pattern for SpMV
40    /// - Each row's non-zeros are contiguous in memory
41    ///
42    /// # Example
43    ///
44    /// ```
45    /// # use numr::prelude::*;
46    /// # #[cfg(feature = "sparse")]
47    /// # {
48    /// # use numr::sparse::SparseTensor;
49    /// # let device = CpuDevice::new();
50    /// # let sp = SparseTensor::<CpuRuntime>::from_coo_slices(&[0, 0, 1], &[0, 1, 0], &[1.0f32, 2.0, 3.0], [2, 2], &device)?.to_csr()?;
51    /// # if let numr::sparse::SparseTensor::Csr(csr) = sp {
52    /// let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
53    /// let y = csr.spmv(&x)?;  // y = [1*1 + 2*2, 3*1] = [5, 3]
54    /// # }
55    /// # }
56    /// # Ok::<(), numr::error::Error>(())
57    /// ```
58    pub fn spmv(&self, x: &Tensor<R>) -> Result<Tensor<R>>
59    where
60        R::Client: crate::sparse::SparseOps<R>,
61    {
62        use crate::sparse::SparseOps;
63
64        let [nrows, ncols] = self.shape;
65        let dtype = self.dtype();
66        let device = self.values.device();
67
68        // Validate vector length
69        let x_len = x.numel();
70        if x_len != ncols {
71            return Err(Error::ShapeMismatch {
72                expected: vec![ncols],
73                got: vec![x_len],
74            });
75        }
76
77        // Validate dtype match
78        if x.dtype() != dtype {
79            return Err(Error::DTypeMismatch {
80                lhs: dtype,
81                rhs: x.dtype(),
82            });
83        }
84
85        // Handle empty matrix case
86        if self.is_empty() {
87            crate::dispatch_dtype!(dtype, T => {
88                let zeros: Vec<T> = vec![T::zero(); nrows];
89                return Ok(Tensor::from_slice(&zeros, &[nrows], device));
90            }, "spmv empty");
91        }
92
93        // Get runtime client to dispatch to backend-specific implementation
94        let client = R::default_client(device);
95
96        // Dispatch on dtype to call backend spmv_csr
97        crate::dispatch_dtype!(dtype, T => {
98            return client.spmv_csr::<T>(
99                &self.row_ptrs,
100                &self.col_indices,
101                &self.values,
102                x,
103                self.shape,
104            );
105        }, "spmv");
106    }
107
108    /// Sparse matrix-dense matrix multiplication: C = A * B
109    ///
110    /// Computes the product of this sparse matrix with a dense matrix.
111    ///
112    /// # Arguments
113    ///
114    /// * `b` - Dense matrix of shape `` `[K, N]` `` where K == ncols of sparse matrix
115    ///
116    /// # Returns
117    ///
118    /// Dense matrix of shape `` `[M, N]` `` where M == nrows of sparse matrix
119    ///
120    /// # Errors
121    ///
122    /// Returns error if:
123    /// - `b` first dimension doesn't match matrix ncols
124    /// - `b` is not 2D
125    /// - dtype mismatch between matrix and input
126    ///
127    /// # Algorithm
128    ///
129    /// For each row i of A and each column n of B:
130    /// ```text
131    /// `` `C[i, n] = sum(A.values[j] * B[A.col_indices[j], n])` ``
132    ///           for j in `` `row_ptrs[i]..row_ptrs[i+1]` ``
133    /// ```
134    ///
135    /// # Performance
136    ///
137    /// - `O(nnz * N)` time complexity
138    /// - CSR format provides good memory access for row-wise traversal
139    ///
140    /// # Example
141    ///
142    /// ```
143    /// # use numr::prelude::*;
144    /// # #[cfg(feature = "sparse")]
145    /// # {
146    /// # use numr::sparse::SparseTensor;
147    /// # let device = CpuDevice::new();
148    /// // A: `[2, 3]` sparse, B: `[3, 2]` dense -> C: `[2, 2]` dense
149    /// # let sp = SparseTensor::<CpuRuntime>::from_coo_slices(&[0, 0, 1], &[0, 1, 2], &[1.0f32, 2.0, 3.0], [2, 3], &device)?.to_csr()?;
150    /// # if let numr::sparse::SparseTensor::Csr(csr) = sp {
151    /// # let b = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0f32], &[3, 2], &device);
152    /// let c = csr.spmm(&b)?;
153    /// # }
154    /// # }
155    /// # Ok::<(), numr::error::Error>(())
156    /// ```
157    pub fn spmm(&self, b: &Tensor<R>) -> Result<Tensor<R>>
158    where
159        R::Client: crate::sparse::SparseOps<R>,
160    {
161        use crate::sparse::SparseOps;
162
163        let [m, k] = self.shape;
164        let dtype = self.dtype();
165        let device = self.values.device();
166
167        // Validate B is 2D
168        if b.ndim() != 2 {
169            return Err(Error::Internal(format!(
170                "Expected 2D tensor for SpMM, got {}D",
171                b.ndim()
172            )));
173        }
174
175        let b_shape = b.shape();
176        let b_k = b_shape[0];
177        let n = b_shape[1];
178
179        // Validate dimensions match
180        if b_k != k {
181            return Err(Error::ShapeMismatch {
182                expected: vec![k],
183                got: vec![b_k],
184            });
185        }
186
187        // Validate dtype match
188        if b.dtype() != dtype {
189            return Err(Error::DTypeMismatch {
190                lhs: dtype,
191                rhs: b.dtype(),
192            });
193        }
194
195        // Handle empty matrix case
196        if self.is_empty() {
197            crate::dispatch_dtype!(dtype, T => {
198                let zeros: Vec<T> = vec![T::zero(); m * n];
199                return Ok(Tensor::from_slice(&zeros, &[m, n], device));
200            }, "spmm empty");
201        }
202
203        // Get runtime client to dispatch to backend-specific implementation
204        let client = R::default_client(device);
205
206        // Dispatch on dtype to call backend spmm_csr
207        crate::dispatch_dtype!(dtype, T => {
208            return client.spmm_csr::<T>(
209                &self.row_ptrs,
210                &self.col_indices,
211                &self.values,
212                b,
213                self.shape,
214            );
215        }, "spmm");
216    }
217
218    /// Transpose the sparse matrix: B = A^T
219    ///
220    /// Returns the transpose as a CSC matrix. This is an `O(1)` operation
221    /// that reinterprets the CSR structure as CSC:
222    /// - `row_ptrs` become `col_ptrs`
223    /// - `col_indices` become `row_indices`
224    /// - `values` remain the same
225    /// - `shape` is swapped
226    ///
227    /// # Returns
228    ///
229    /// CSC matrix representing the transpose
230    ///
231    /// # Performance
232    ///
233    /// `O(1)` - structural reinterpretation, no data copying beyond cloning tensors.
234    ///
235    /// # Example
236    ///
237    /// ```
238    /// # use numr::prelude::*;
239    /// # #[cfg(feature = "sparse")]
240    /// # {
241    /// # use numr::sparse::SparseTensor;
242    /// # let device = CpuDevice::new();
243    /// // A `[2, 3]` in CSR:
244    /// // `[1, 0, 2]`
245    /// // `[0, 3, 0]`
246    /// # let sp = SparseTensor::<CpuRuntime>::from_coo_slices(&[0, 0, 1], &[0, 2, 1], &[1.0f32, 2.0, 3.0], [2, 3], &device)?.to_csr()?;
247    /// # if let numr::sparse::SparseTensor::Csr(a) = sp {
248    /// let a_t = a.transpose();
249    /// // A^T `[3, 2]` in CSC (same underlying data)
250    /// # }
251    /// # }
252    /// # Ok::<(), numr::error::Error>(())
253    /// ```
254    pub fn transpose(&self) -> CscData<R> {
255        let [nrows, ncols] = self.shape;
256        // CSR row_ptrs -> CSC col_ptrs
257        // CSR col_indices -> CSC row_indices
258        // Shape [nrows, ncols] -> [ncols, nrows]
259        CscData {
260            col_ptrs: self.row_ptrs.clone(),
261            row_indices: self.col_indices.clone(),
262            values: self.values.clone(),
263            shape: [ncols, nrows],
264        }
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::error::Error;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::CpuRuntime;
    use crate::sparse::{SparseFormat, SparseStorage};
    use crate::tensor::Tensor;

    // =========================================================================
    // SpMV tests
    // =========================================================================

    #[test]
    fn test_spmv_basic() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Matrix:
        // [1, 0, 2]
        // [0, 0, 3]
        // [4, 5, 0]
        let row_ptrs = vec![0i64, 2, 3, 5];
        let col_indices = vec![0i64, 2, 2, 0, 1];
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        // x = [1, 2, 3]
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);

        // y = A * x
        // y[0] = 1*1 + 2*3 = 7
        // y[1] = 3*3 = 9
        // y[2] = 4*1 + 5*2 = 14
        let y = csr.spmv(&x).unwrap();

        assert_eq!(y.shape(), &[3]);
        let y_data: Vec<f32> = y.to_vec();
        assert_eq!(y_data, vec![7.0, 9.0, 14.0]);
    }

    #[test]
    fn test_spmv_empty_matrix() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let csr = CsrData::<CpuRuntime>::empty([3, 3], DType::F32, &device);
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);

        let y = csr.spmv(&x).unwrap();

        assert_eq!(y.shape(), &[3]);
        let y_data: Vec<f32> = y.to_vec();
        assert_eq!(y_data, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_spmv_identity() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Identity matrix:
        // [1, 0, 0]
        // [0, 1, 0]
        // [0, 0, 1]
        let row_ptrs = vec![0i64, 1, 2, 3];
        let col_indices = vec![0i64, 1, 2];
        let values = vec![1.0f32, 1.0, 1.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        let x = Tensor::<CpuRuntime>::from_slice(&[7.0f32, 8.0, 9.0], &[3], &device);
        let y = csr.spmv(&x).unwrap();

        let y_data: Vec<f32> = y.to_vec();
        assert_eq!(y_data, vec![7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_spmv_non_square() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Matrix [2, 4]:
        // [1, 2, 0, 3]
        // [0, 4, 5, 0]
        let row_ptrs = vec![0i64, 3, 5];
        let col_indices = vec![0i64, 1, 3, 1, 2];
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 4], &device)
                .unwrap();

        // x = [1, 2, 3, 4]
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &device);

        // y = A * x
        // y[0] = 1*1 + 2*2 + 3*4 = 17
        // y[1] = 4*2 + 5*3 = 23
        let y = csr.spmv(&x).unwrap();

        assert_eq!(y.shape(), &[2]);
        let y_data: Vec<f32> = y.to_vec();
        assert_eq!(y_data, vec![17.0, 23.0]);
    }

    #[test]
    fn test_spmv_shape_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let row_ptrs = vec![0i64, 2, 3, 5];
        let col_indices = vec![0i64, 2, 2, 0, 1];
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        // Wrong vector length (2 instead of 3)
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);

        // Must fail specifically on the shape check, not some other error.
        let result = csr.spmv(&x);
        assert!(matches!(result, Err(Error::ShapeMismatch { .. })));
    }

    #[test]
    fn test_spmv_dtype_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let row_ptrs = vec![0i64, 2, 3, 5];
        let col_indices = vec![0i64, 2, 2, 0, 1];
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; // F32

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        // F64 vector of the correct length, so only the dtype check can fail.
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);

        let result = csr.spmv(&x);
        assert!(matches!(result, Err(Error::DTypeMismatch { .. })));
    }

    #[test]
    fn test_spmv_f64() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Matrix:
        // [1, 2]
        // [3, 4]
        let row_ptrs = vec![0i64, 2, 4];
        let col_indices = vec![0i64, 1, 0, 1];
        let values = vec![1.0f64, 2.0, 3.0, 4.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 2], &device)
                .unwrap();

        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0], &[2], &device);

        // y = A * x
        // y[0] = 1 + 2 = 3
        // y[1] = 3 + 4 = 7
        let y = csr.spmv(&x).unwrap();

        assert_eq!(y.dtype(), DType::F64);
        let y_data: Vec<f64> = y.to_vec();
        assert_eq!(y_data, vec![3.0, 7.0]);
    }

    #[test]
    fn test_spmv_single_element() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Single element at (1, 2) with value 5
        let row_ptrs = vec![0i64, 0, 1, 1];
        let col_indices = vec![2i64];
        let values = vec![5.0f32];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);

        // y = A * x
        // y[0] = 0
        // y[1] = 5 * 3 = 15
        // y[2] = 0
        let y = csr.spmv(&x).unwrap();

        let y_data: Vec<f32> = y.to_vec();
        assert_eq!(y_data, vec![0.0, 15.0, 0.0]);
    }

    // =========================================================================
    // SpMM tests
    // =========================================================================

    #[test]
    fn test_spmm_basic() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Sparse A [2, 3]:
        // [1, 0, 2]
        // [0, 3, 0]
        let row_ptrs = vec![0i64, 2, 3];
        let col_indices = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 3], &device)
                .unwrap();

        // Dense B [3, 2]:
        // [1, 2]
        // [3, 4]
        // [5, 6]
        let b =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &device);

        // C = A * B [2, 2]:
        // C[0,0] = 1*1 + 2*5 = 11
        // C[0,1] = 1*2 + 2*6 = 14
        // C[1,0] = 3*3 = 9
        // C[1,1] = 3*4 = 12
        let c = csr.spmm(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        let c_data: Vec<f32> = c.to_vec();
        assert_eq!(c_data, vec![11.0, 14.0, 9.0, 12.0]);
    }

    #[test]
    fn test_spmm_empty_matrix() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let csr = CsrData::<CpuRuntime>::empty([2, 3], DType::F32, &device);
        let b =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &device);

        let c = csr.spmm(&b).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        let c_data: Vec<f32> = c.to_vec();
        assert_eq!(c_data, vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_spmm_identity() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Identity matrix [3, 3]
        let row_ptrs = vec![0i64, 1, 2, 3];
        let col_indices = vec![0i64, 1, 2];
        let values = vec![1.0f32, 1.0, 1.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        // B [3, 2]
        let b =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &device);

        // I * B = B
        let c = csr.spmm(&b).unwrap();

        let c_data: Vec<f32> = c.to_vec();
        assert_eq!(c_data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_spmm_shape_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // A [2, 3]
        let row_ptrs = vec![0i64, 2, 3];
        let col_indices = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 3], &device)
                .unwrap();

        // B [2, 2] - wrong dimension (should be [3, ...])
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);

        let result = csr.spmm(&b);
        assert!(matches!(result, Err(Error::ShapeMismatch { .. })));
    }

    #[test]
    fn test_spmm_not_2d() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let row_ptrs = vec![0i64, 2, 3];
        let col_indices = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 3], &device)
                .unwrap();

        // 1D tensor instead of 2D
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);

        // spmm reports a non-2-D operand through Error::Internal.
        let result = csr.spmm(&b);
        assert!(matches!(result, Err(Error::Internal(_))));
    }

    #[test]
    fn test_spmm_dtype_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let row_ptrs = vec![0i64, 2, 3];
        let col_indices = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0]; // F32

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 3], &device)
                .unwrap();

        // F64 matrix of the correct shape, so only the dtype check can fail.
        let b =
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &device);

        let result = csr.spmm(&b);
        assert!(matches!(result, Err(Error::DTypeMismatch { .. })));
    }

    #[test]
    fn test_spmm_f64() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // A [2, 2]
        let row_ptrs = vec![0i64, 2, 4];
        let col_indices = vec![0i64, 1, 0, 1];
        let values = vec![1.0f64, 2.0, 3.0, 4.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 2], &device)
                .unwrap();

        // B [2, 2]
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.0, 0.0, 1.0], &[2, 2], &device);

        // C = A * I = A
        let c = csr.spmm(&b).unwrap();

        assert_eq!(c.dtype(), DType::F64);
        let c_data: Vec<f64> = c.to_vec();
        assert_eq!(c_data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_spmm_single_column() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // A [3, 3]
        let row_ptrs = vec![0i64, 2, 3, 5];
        let col_indices = vec![0i64, 2, 2, 0, 1];
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [3, 3], &device)
                .unwrap();

        // B [3, 1] - single column (like a vector reshaped)
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3, 1], &device);

        // Should match spmv result
        let c = csr.spmm(&b).unwrap();

        assert_eq!(c.shape(), &[3, 1]);
        let c_data: Vec<f32> = c.to_vec();
        // Same as spmv: [7, 9, 14]
        assert_eq!(c_data, vec![7.0, 9.0, 14.0]);
    }

    // =========================================================================
    // Transpose tests
    // =========================================================================

    #[test]
    fn test_csr_transpose() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Matrix [2, 3]:
        // [1, 0, 2]
        // [0, 3, 0]
        let row_ptrs = vec![0i64, 2, 3];
        let col_indices = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 3], &device)
                .unwrap();
        let csc = csr.transpose();

        // Transposed [3, 2] as CSC
        assert_eq!(csc.shape(), [3, 2]);
        assert_eq!(csc.nnz(), 3);
        assert_eq!(csc.format(), SparseFormat::Csc);

        // CSR row_ptrs become CSC col_ptrs
        let col_ptrs: Vec<i64> = csc.col_ptrs().to_vec();
        let row_indices: Vec<i64> = csc.row_indices().to_vec();
        let t_values: Vec<f32> = csc.values().to_vec();

        assert_eq!(col_ptrs, vec![0, 2, 3]); // Same as original row_ptrs
        assert_eq!(row_indices, vec![0, 2, 1]); // Same as original col_indices
        assert_eq!(t_values, vec![1.0, 2.0, 3.0]); // Values unchanged
    }

    #[test]
    fn test_csr_transpose_empty() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let csr = CsrData::<CpuRuntime>::empty([3, 5], DType::F32, &device);
        let csc = csr.transpose();

        assert_eq!(csc.shape(), [5, 3]);
        assert_eq!(csc.nnz(), 0);
        assert_eq!(csc.format(), SparseFormat::Csc);
    }

    #[test]
    fn test_csr_transpose_to_dense_matches() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // Matrix [2, 3]:
        // [1, 0, 2]
        // [0, 3, 0]
        let row_ptrs = vec![0i64, 2, 3];
        let col_indices = vec![0i64, 2, 1];
        let values = vec![1.0f32, 2.0, 3.0];

        let csr =
            CsrData::<CpuRuntime>::from_slices(&row_ptrs, &col_indices, &values, [2, 3], &device)
                .unwrap();

        // Transpose to CSC, then go CSC -> CSR -> COO so the entries can be
        // scattered into a dense buffer by their (row, col) coordinates.
        let csc = csr.transpose();
        let csr_t = csc.to_csr().unwrap();
        let coo_t = csr_t.to_coo().unwrap();

        let t_rows: Vec<i64> = coo_t.row_indices().to_vec();
        let t_cols: Vec<i64> = coo_t.col_indices().to_vec();
        let t_vals: Vec<f32> = coo_t.values().to_vec();
        // COO component arrays must be parallel; otherwise zip would silently truncate.
        assert_eq!(t_rows.len(), t_vals.len());
        assert_eq!(t_cols.len(), t_vals.len());

        // Expected transposed [3, 2]:
        // [1, 0]
        // [0, 3]
        // [2, 0]
        let (t_nrows, t_ncols) = (3usize, 2usize);
        let mut dense_t = vec![0.0f32; t_nrows * t_ncols];
        for ((&r, &c), &v) in t_rows.iter().zip(&t_cols).zip(&t_vals) {
            dense_t[r as usize * t_ncols + c as usize] = v;
        }
        assert_eq!(dense_t, vec![1.0, 0.0, 0.0, 3.0, 2.0, 0.0]);
    }
}