1use super::CsrData;
4use crate::dtype::{DType, Element};
5use crate::error::{Error, Result};
6use crate::runtime::Runtime;
7use crate::sparse::{CscData, SparseStorage};
8use crate::tensor::Tensor;
9
10impl<R: Runtime<DType = DType>> CsrData<R> {
11 pub fn spmv(&self, x: &Tensor<R>) -> Result<Tensor<R>>
59 where
60 R::Client: crate::sparse::SparseOps<R>,
61 {
62 use crate::sparse::SparseOps;
63
64 let [nrows, ncols] = self.shape;
65 let dtype = self.dtype();
66 let device = self.values.device();
67
68 let x_len = x.numel();
70 if x_len != ncols {
71 return Err(Error::ShapeMismatch {
72 expected: vec![ncols],
73 got: vec![x_len],
74 });
75 }
76
77 if x.dtype() != dtype {
79 return Err(Error::DTypeMismatch {
80 lhs: dtype,
81 rhs: x.dtype(),
82 });
83 }
84
85 if self.is_empty() {
87 crate::dispatch_dtype!(dtype, T => {
88 let zeros: Vec<T> = vec![T::zero(); nrows];
89 return Ok(Tensor::from_slice(&zeros, &[nrows], device));
90 }, "spmv empty");
91 }
92
93 let client = R::default_client(device);
95
96 crate::dispatch_dtype!(dtype, T => {
98 return client.spmv_csr::<T>(
99 &self.row_ptrs,
100 &self.col_indices,
101 &self.values,
102 x,
103 self.shape,
104 );
105 }, "spmv");
106 }
107
108 pub fn spmm(&self, b: &Tensor<R>) -> Result<Tensor<R>>
158 where
159 R::Client: crate::sparse::SparseOps<R>,
160 {
161 use crate::sparse::SparseOps;
162
163 let [m, k] = self.shape;
164 let dtype = self.dtype();
165 let device = self.values.device();
166
167 if b.ndim() != 2 {
169 return Err(Error::Internal(format!(
170 "Expected 2D tensor for SpMM, got {}D",
171 b.ndim()
172 )));
173 }
174
175 let b_shape = b.shape();
176 let b_k = b_shape[0];
177 let n = b_shape[1];
178
179 if b_k != k {
181 return Err(Error::ShapeMismatch {
182 expected: vec![k],
183 got: vec![b_k],
184 });
185 }
186
187 if b.dtype() != dtype {
189 return Err(Error::DTypeMismatch {
190 lhs: dtype,
191 rhs: b.dtype(),
192 });
193 }
194
195 if self.is_empty() {
197 crate::dispatch_dtype!(dtype, T => {
198 let zeros: Vec<T> = vec![T::zero(); m * n];
199 return Ok(Tensor::from_slice(&zeros, &[m, n], device));
200 }, "spmm empty");
201 }
202
203 let client = R::default_client(device);
205
206 crate::dispatch_dtype!(dtype, T => {
208 return client.spmm_csr::<T>(
209 &self.row_ptrs,
210 &self.col_indices,
211 &self.values,
212 b,
213 self.shape,
214 );
215 }, "spmm");
216 }
217
218 pub fn transpose(&self) -> CscData<R> {
255 let [nrows, ncols] = self.shape;
256 CscData {
260 col_ptrs: self.row_ptrs.clone(),
261 row_indices: self.col_indices.clone(),
262 values: self.values.clone(),
263 shape: [ncols, nrows],
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::CpuRuntime;
    use crate::sparse::{SparseFormat, SparseStorage};
    use crate::tensor::Tensor;

    /// Default CPU device shared by every test in this module.
    fn cpu_device() -> <CpuRuntime as Runtime>::Device {
        <CpuRuntime as Runtime>::Device::default()
    }

    // ---------------------------------------------------------------------
    // spmv
    // ---------------------------------------------------------------------

    #[test]
    fn test_spmv_basic() {
        let dev = cpu_device();

        // Dense form of the matrix:
        //   [1 0 2]
        //   [0 0 3]
        //   [4 5 0]
        let ptrs = vec![0i64, 2, 3, 5];
        let cols = vec![0i64, 2, 2, 0, 1];
        let vals = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);

        let product = matrix.spmv(&operand).unwrap();

        // Row sums: 1*1 + 2*3 = 7, 3*3 = 9, 4*1 + 5*2 = 14.
        assert_eq!(product.shape(), &[3]);
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![7.0, 9.0, 14.0]);
    }

    #[test]
    fn test_spmv_empty_matrix() {
        let dev = cpu_device();

        let matrix = CsrData::<CpuRuntime>::empty([3, 3], DType::F32, &dev);
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);

        let product = matrix.spmv(&operand).unwrap();

        assert_eq!(product.shape(), &[3]);
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_spmv_identity() {
        let dev = cpu_device();

        // 3x3 identity: one unit entry on the diagonal of every row.
        let ptrs = vec![0i64, 1, 2, 3];
        let cols = vec![0i64, 1, 2];
        let vals = vec![1.0f32, 1.0, 1.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();
        let operand = Tensor::<CpuRuntime>::from_slice(&[7.0f32, 8.0, 9.0], &[3], &dev);

        let product = matrix.spmv(&operand).unwrap();

        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_spmv_non_square() {
        let dev = cpu_device();

        // Dense form of the 2x4 matrix:
        //   [1 2 0 3]
        //   [0 4 5 0]
        let ptrs = vec![0i64, 3, 5];
        let cols = vec![0i64, 1, 3, 1, 2];
        let vals = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 4], &dev).unwrap();
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &dev);

        let product = matrix.spmv(&operand).unwrap();

        // Row sums: 1 + 4 + 12 = 17, 8 + 15 = 23.
        assert_eq!(product.shape(), &[2]);
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![17.0, 23.0]);
    }

    #[test]
    fn test_spmv_shape_mismatch() {
        let dev = cpu_device();

        let ptrs = vec![0i64, 2, 3, 5];
        let cols = vec![0i64, 2, 2, 0, 1];
        let vals = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();

        // Two elements against a matrix with three columns.
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &dev);

        assert!(matrix.spmv(&operand).is_err());
    }

    #[test]
    fn test_spmv_dtype_mismatch() {
        let dev = cpu_device();

        let ptrs = vec![0i64, 2, 3, 5];
        let cols = vec![0i64, 2, 2, 0, 1];
        let vals = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        // f32 matrix ...
        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();

        // ... multiplied by an f64 vector.
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &dev);

        assert!(matrix.spmv(&operand).is_err());
    }

    #[test]
    fn test_spmv_f64() {
        let dev = cpu_device();

        // Dense form of the matrix:
        //   [1 2]
        //   [3 4]
        let ptrs = vec![0i64, 2, 4];
        let cols = vec![0i64, 1, 0, 1];
        let vals = vec![1.0f64, 2.0, 3.0, 4.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 2], &dev).unwrap();
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 1.0], &[2], &dev);

        let product = matrix.spmv(&operand).unwrap();

        assert_eq!(product.dtype(), DType::F64);
        let got: Vec<f64> = product.to_vec();
        assert_eq!(got, vec![3.0, 7.0]);
    }

    #[test]
    fn test_spmv_single_element() {
        let dev = cpu_device();

        // Only stored entry: value 5 at row 1, column 2.
        let ptrs = vec![0i64, 0, 1, 1];
        let cols = vec![2i64];
        let vals = vec![5.0f32];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();
        let operand = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);

        let product = matrix.spmv(&operand).unwrap();

        // Only row 1 is non-zero: 5 * x[2] = 15.
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![0.0, 15.0, 0.0]);
    }

    // ---------------------------------------------------------------------
    // spmm
    // ---------------------------------------------------------------------

    #[test]
    fn test_spmm_basic() {
        let dev = cpu_device();

        // Dense form of the 2x3 matrix:
        //   [1 0 2]
        //   [0 3 0]
        let ptrs = vec![0i64, 2, 3];
        let cols = vec![0i64, 2, 1];
        let vals = vec![1.0f32, 2.0, 3.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 3], &dev).unwrap();

        // Dense 3x2 operand:
        //   [1 2]
        //   [3 4]
        //   [5 6]
        let rhs =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &dev);

        let product = matrix.spmm(&rhs).unwrap();

        // Row 0: [1*1 + 2*5, 1*2 + 2*6] = [11, 14]; row 1: [3*3, 3*4] = [9, 12].
        assert_eq!(product.shape(), &[2, 2]);
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![11.0, 14.0, 9.0, 12.0]);
    }

    #[test]
    fn test_spmm_empty_matrix() {
        let dev = cpu_device();

        let matrix = CsrData::<CpuRuntime>::empty([2, 3], DType::F32, &dev);
        let rhs =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &dev);

        let product = matrix.spmm(&rhs).unwrap();

        assert_eq!(product.shape(), &[2, 2]);
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_spmm_identity() {
        let dev = cpu_device();

        // 3x3 identity.
        let ptrs = vec![0i64, 1, 2, 3];
        let cols = vec![0i64, 1, 2];
        let vals = vec![1.0f32, 1.0, 1.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();
        let rhs =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &dev);

        let product = matrix.spmm(&rhs).unwrap();

        // Multiplying by the identity must reproduce the operand.
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_spmm_shape_mismatch() {
        let dev = cpu_device();

        let ptrs = vec![0i64, 2, 3];
        let cols = vec![0i64, 2, 1];
        let vals = vec![1.0f32, 2.0, 3.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 3], &dev).unwrap();

        // Operand has 2 rows, but the matrix has 3 columns.
        let rhs = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &dev);

        assert!(matrix.spmm(&rhs).is_err());
    }

    #[test]
    fn test_spmm_not_2d() {
        let dev = cpu_device();

        let ptrs = vec![0i64, 2, 3];
        let cols = vec![0i64, 2, 1];
        let vals = vec![1.0f32, 2.0, 3.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 3], &dev).unwrap();

        // A 1D operand is rejected by spmm.
        let rhs = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &dev);

        assert!(matrix.spmm(&rhs).is_err());
    }

    #[test]
    fn test_spmm_dtype_mismatch() {
        let dev = cpu_device();

        let ptrs = vec![0i64, 2, 3];
        let cols = vec![0i64, 2, 1];
        let vals = vec![1.0f32, 2.0, 3.0];

        // f32 matrix ...
        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 3], &dev).unwrap();

        // ... multiplied by an f64 operand.
        let rhs =
            Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], &dev);

        assert!(matrix.spmm(&rhs).is_err());
    }

    #[test]
    fn test_spmm_f64() {
        let dev = cpu_device();

        // Dense form of the matrix:
        //   [1 2]
        //   [3 4]
        let ptrs = vec![0i64, 2, 4];
        let cols = vec![0i64, 1, 0, 1];
        let vals = vec![1.0f64, 2.0, 3.0, 4.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 2], &dev).unwrap();

        // Dense 2x2 identity operand.
        let rhs = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 0.0, 0.0, 1.0], &[2, 2], &dev);

        let product = matrix.spmm(&rhs).unwrap();

        assert_eq!(product.dtype(), DType::F64);
        let got: Vec<f64> = product.to_vec();
        assert_eq!(got, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_spmm_single_column() {
        let dev = cpu_device();

        // Same matrix as in test_spmv_basic:
        //   [1 0 2]
        //   [0 0 3]
        //   [4 5 0]
        let ptrs = vec![0i64, 2, 3, 5];
        let cols = vec![0i64, 2, 2, 0, 1];
        let vals = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [3, 3], &dev).unwrap();

        // A [3, 1] operand: the expected values match the spmv result.
        let rhs = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3, 1], &dev);

        let product = matrix.spmm(&rhs).unwrap();

        assert_eq!(product.shape(), &[3, 1]);
        let got: Vec<f32> = product.to_vec();
        assert_eq!(got, vec![7.0, 9.0, 14.0]);
    }

    // ---------------------------------------------------------------------
    // transpose
    // ---------------------------------------------------------------------

    #[test]
    fn test_csr_transpose() {
        let dev = cpu_device();

        // Dense form of the 2x3 matrix:
        //   [1 0 2]
        //   [0 3 0]
        let ptrs = vec![0i64, 2, 3];
        let cols = vec![0i64, 2, 1];
        let vals = vec![1.0f32, 2.0, 3.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 3], &dev).unwrap();
        let transposed = matrix.transpose();

        assert_eq!(transposed.shape(), [3, 2]);
        assert_eq!(transposed.nnz(), 3);
        assert_eq!(transposed.format(), SparseFormat::Csc);

        let col_ptrs: Vec<i64> = transposed.col_ptrs().to_vec();
        let row_indices: Vec<i64> = transposed.row_indices().to_vec();
        let t_values: Vec<f32> = transposed.values().to_vec();

        // The CSC arrays carry the same contents as the CSR arrays.
        assert_eq!(col_ptrs, vec![0, 2, 3]);
        assert_eq!(row_indices, vec![0, 2, 1]);
        assert_eq!(t_values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_csr_transpose_empty() {
        let dev = cpu_device();

        let matrix = CsrData::<CpuRuntime>::empty([3, 5], DType::F32, &dev);
        let transposed = matrix.transpose();

        assert_eq!(transposed.shape(), [5, 3]);
        assert_eq!(transposed.nnz(), 0);
        assert_eq!(transposed.format(), SparseFormat::Csc);
    }

    #[test]
    fn test_csr_transpose_to_dense_matches() {
        let dev = cpu_device();

        // Dense form of the 2x3 matrix:
        //   [1 0 2]
        //   [0 3 0]
        let ptrs = vec![0i64, 2, 3];
        let cols = vec![0i64, 2, 1];
        let vals = vec![1.0f32, 2.0, 3.0];

        let matrix = CsrData::<CpuRuntime>::from_slices(&ptrs, &cols, &vals, [2, 3], &dev).unwrap();

        let transposed = matrix.transpose();

        // Convert CSC -> CSR -> COO so that explicit (row, col, value)
        // triplets of the transposed matrix can be read back.
        let as_csr = transposed.to_csr().unwrap();
        let as_coo = as_csr.to_coo().unwrap();

        let t_rows: Vec<i64> = as_coo.row_indices().to_vec();
        let t_cols: Vec<i64> = as_coo.col_indices().to_vec();
        let t_vals: Vec<f32> = as_coo.values().to_vec();

        // Scatter the triplets into a row-major 3x2 dense buffer (stride 2).
        let mut dense = vec![0.0f32; 6];
        for (idx, &value) in t_vals.iter().enumerate() {
            let row = t_rows[idx] as usize;
            let col = t_cols[idx] as usize;
            dense[row * 2 + col] = value;
        }

        // Expected transpose:
        //   [1 0]
        //   [0 3]
        //   [2 0]
        assert_eq!(dense, vec![1.0, 0.0, 0.0, 3.0, 2.0, 0.0]);
    }
}