1#![warn(missing_debug_implementations)]
13
14use core::ffi::c_void;
15use std::marker::PhantomData;
16
17use baracuda_cusparse_sys::{
18 cudaDataType, cusparse, cusparseDiagType_t, cusparseDnMatDescr_t, cusparseDnVecDescr_t,
19 cusparseFillMode_t, cusparseHandle_t, cusparseIndexBase_t, cusparseIndexType_t,
20 cusparseOperation_t, cusparseOrder_t, cusparseSpGEMMDescr_t, cusparseSpMatAttribute_t,
21 cusparseSpMatDescr_t, cusparseSpSMDescr_t, cusparseSpSVDescr_t, cusparseStatus_t,
22};
23use baracuda_driver::{DeviceBuffer, Stream};
24use baracuda_types::{Complex32, Complex64};
25
26pub use baracuda_cusparse_sys::{
27 cusparseCsr2CscAlg_t as Csr2CscAlg, cusparseSDDMMAlg_t as SDDMMAlg,
28 cusparseSpGEMMAlg_t as SpGEMMAlg, cusparseSpMMAlg_t as SpMMAlg, cusparseSpMVAlg_t as SpMVAlg,
29 cusparseSpSMAlg_t as SpSMAlg, cusparseSpSVAlg_t as SpSVAlg,
30};
31
/// Error type of this crate: the shared `baracuda_core` error parameterized by
/// the raw cuSPARSE status code.
pub type Error = baracuda_core::Error<cusparseStatus_t>;
/// `Result` alias whose error type defaults to [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;

/// Converts a raw cuSPARSE status into a [`Result`] by delegating to
/// `Error::check` (presumably `Ok(())` only for `CUSPARSE_STATUS_SUCCESS` —
/// confirm in `baracuda_core`).
#[inline]
fn check(status: cusparseStatus_t) -> Result<()> {
    Error::check(status)
}
41
42pub trait SparseScalar: sealed::Sealed + Copy + 'static {
46 fn data_type() -> cudaDataType;
48}
49
50impl SparseScalar for f32 {
51 fn data_type() -> cudaDataType {
52 cudaDataType::R_32F
53 }
54}
55impl SparseScalar for f64 {
56 fn data_type() -> cudaDataType {
57 cudaDataType::R_64F
58 }
59}
60impl SparseScalar for Complex32 {
61 fn data_type() -> cudaDataType {
62 cudaDataType::C_32F
63 }
64}
65impl SparseScalar for Complex64 {
66 fn data_type() -> cudaDataType {
67 cudaDataType::C_64F
68 }
69}
70
71mod sealed {
72 use baracuda_types::{Complex32, Complex64};
73 pub trait Sealed {}
74 impl Sealed for f32 {}
75 impl Sealed for f64 {}
76 impl Sealed for Complex32 {}
77 impl Sealed for Complex64 {}
78}
79
80#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
84pub enum Op {
85 #[default]
86 N,
87 T,
88 C,
89}
90
91impl Op {
92 fn raw(self) -> cusparseOperation_t {
93 match self {
94 Op::N => cusparseOperation_t::N,
95 Op::T => cusparseOperation_t::T,
96 Op::C => cusparseOperation_t::C,
97 }
98 }
99}
100
101#[derive(Copy, Clone, Debug, Eq, PartialEq)]
103pub enum Order {
104 Row,
105 Col,
106}
107
108impl Order {
109 fn raw(self) -> cusparseOrder_t {
110 match self {
111 Order::Row => cusparseOrder_t::Row,
112 Order::Col => cusparseOrder_t::Col,
113 }
114 }
115}
116
117#[derive(Copy, Clone, Debug, Eq, PartialEq)]
118pub enum Fill {
119 Lower,
120 Upper,
121}
122
123impl Fill {
124 fn raw(self) -> cusparseFillMode_t {
125 match self {
126 Fill::Lower => cusparseFillMode_t::Lower,
127 Fill::Upper => cusparseFillMode_t::Upper,
128 }
129 }
130}
131
132#[derive(Copy, Clone, Debug, Eq, PartialEq)]
133pub enum Diag {
134 NonUnit,
135 Unit,
136}
137
138impl Diag {
139 fn raw(self) -> cusparseDiagType_t {
140 match self {
141 Diag::NonUnit => cusparseDiagType_t::NonUnit,
142 Diag::Unit => cusparseDiagType_t::Unit,
143 }
144 }
145}
146
147pub struct Handle {
151 handle: cusparseHandle_t,
152}
153
154unsafe impl Send for Handle {}
155
156impl core::fmt::Debug for Handle {
157 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
158 f.debug_struct("cusparse::Handle")
159 .field("handle", &self.handle)
160 .finish()
161 }
162}
163
164impl Handle {
165 pub fn new() -> Result<Self> {
166 let c = cusparse()?;
167 let cu = c.cusparse_create()?;
168 let mut h: cusparseHandle_t = core::ptr::null_mut();
169 check(unsafe { cu(&mut h) })?;
170 Ok(Self { handle: h })
171 }
172
173 pub fn set_stream(&self, stream: &Stream) -> Result<()> {
174 let c = cusparse()?;
175 let cu = c.cusparse_set_stream()?;
176 check(unsafe { cu(self.handle, stream.as_raw() as _) })
177 }
178
179 pub fn version(&self) -> Result<i32> {
180 let c = cusparse()?;
181 let cu = c.cusparse_get_version()?;
182 let mut v: core::ffi::c_int = 0;
183 check(unsafe { cu(self.handle, &mut v) })?;
184 Ok(v)
185 }
186
187 #[inline]
188 pub fn as_raw(&self) -> cusparseHandle_t {
189 self.handle
190 }
191}
192
193impl Drop for Handle {
194 fn drop(&mut self) {
195 if let Ok(c) = cusparse() {
196 if let Ok(cu) = c.cusparse_destroy() {
197 let _ = unsafe { cu(self.handle) };
198 }
199 }
200 }
201}
202
203pub struct SpMat<'buf, T> {
209 descr: cusparseSpMatDescr_t,
210 _markers: PhantomData<&'buf mut T>,
211}
212
213unsafe impl<T> Send for SpMat<'_, T> {}
214
215impl<T> core::fmt::Debug for SpMat<'_, T> {
216 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
217 f.debug_struct("SpMat")
218 .field("descr", &self.descr)
219 .finish_non_exhaustive()
220 }
221}
222
223impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> SpMat<'buf, T> {
224 pub fn csr(
229 rows: i64,
230 cols: i64,
231 nnz: i64,
232 row_offsets: &'buf mut DeviceBuffer<i32>,
233 col_indices: &'buf mut DeviceBuffer<i32>,
234 values: &'buf mut DeviceBuffer<T>,
235 ) -> Result<Self> {
236 let c = cusparse()?;
237 let cu = c.cusparse_create_csr()?;
238 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
239 check(unsafe {
240 cu(
241 &mut descr,
242 rows,
243 cols,
244 nnz,
245 row_offsets.as_raw().0 as *mut c_void,
246 col_indices.as_raw().0 as *mut c_void,
247 values.as_raw().0 as *mut c_void,
248 cusparseIndexType_t::I32I,
249 cusparseIndexType_t::I32I,
250 cusparseIndexBase_t::Zero,
251 T::data_type(),
252 )
253 })?;
254 Ok(Self {
255 descr,
256 _markers: PhantomData,
257 })
258 }
259
260 pub fn csc(
262 rows: i64,
263 cols: i64,
264 nnz: i64,
265 col_offsets: &'buf mut DeviceBuffer<i32>,
266 row_indices: &'buf mut DeviceBuffer<i32>,
267 values: &'buf mut DeviceBuffer<T>,
268 ) -> Result<Self> {
269 let c = cusparse()?;
270 let cu = c.cusparse_create_csc()?;
271 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
272 check(unsafe {
273 cu(
274 &mut descr,
275 rows,
276 cols,
277 nnz,
278 col_offsets.as_raw().0 as *mut c_void,
279 row_indices.as_raw().0 as *mut c_void,
280 values.as_raw().0 as *mut c_void,
281 cusparseIndexType_t::I32I,
282 cusparseIndexType_t::I32I,
283 cusparseIndexBase_t::Zero,
284 T::data_type(),
285 )
286 })?;
287 Ok(Self {
288 descr,
289 _markers: PhantomData,
290 })
291 }
292
293 #[allow(clippy::too_many_arguments)]
295 pub fn bsr(
296 brows: i64,
297 bcols: i64,
298 bnnz: i64,
299 row_block_dim: i64,
300 col_block_dim: i64,
301 order: Order,
302 row_offsets: &'buf mut DeviceBuffer<i32>,
303 col_indices: &'buf mut DeviceBuffer<i32>,
304 values: &'buf mut DeviceBuffer<T>,
305 ) -> Result<Self> {
306 let c = cusparse()?;
307 let cu = c.cusparse_create_bsr()?;
308 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
309 check(unsafe {
310 cu(
311 &mut descr,
312 brows,
313 bcols,
314 bnnz,
315 row_block_dim,
316 col_block_dim,
317 row_offsets.as_raw().0 as *mut c_void,
318 col_indices.as_raw().0 as *mut c_void,
319 values.as_raw().0 as *mut c_void,
320 cusparseIndexType_t::I32I,
321 cusparseIndexType_t::I32I,
322 cusparseIndexBase_t::Zero,
323 T::data_type(),
324 order.raw(),
325 )
326 })?;
327 Ok(Self {
328 descr,
329 _markers: PhantomData,
330 })
331 }
332
333 pub fn coo(
335 rows: i64,
336 cols: i64,
337 nnz: i64,
338 row_indices: &'buf mut DeviceBuffer<i32>,
339 col_indices: &'buf mut DeviceBuffer<i32>,
340 values: &'buf mut DeviceBuffer<T>,
341 ) -> Result<Self> {
342 let c = cusparse()?;
343 let cu = c.cusparse_create_coo()?;
344 let mut descr: cusparseSpMatDescr_t = core::ptr::null_mut();
345 check(unsafe {
346 cu(
347 &mut descr,
348 rows,
349 cols,
350 nnz,
351 row_indices.as_raw().0 as *mut c_void,
352 col_indices.as_raw().0 as *mut c_void,
353 values.as_raw().0 as *mut c_void,
354 cusparseIndexType_t::I32I,
355 cusparseIndexBase_t::Zero,
356 T::data_type(),
357 )
358 })?;
359 Ok(Self {
360 descr,
361 _markers: PhantomData,
362 })
363 }
364}
365
366impl<T> SpMat<'_, T> {
367 pub fn shape(&self) -> Result<(i64, i64, i64)> {
369 let c = cusparse()?;
370 let cu = c.cusparse_sp_mat_get_size()?;
371 let (mut r, mut col, mut nz) = (0i64, 0i64, 0i64);
372 check(unsafe { cu(self.descr, &mut r, &mut col, &mut nz) })?;
373 Ok((r, col, nz))
374 }
375
376 pub fn set_fill(&self, fill: Fill) -> Result<()> {
378 let c = cusparse()?;
379 let cu = c.cusparse_sp_mat_set_attribute()?;
380 let raw = fill.raw();
381 check(unsafe {
382 cu(
383 self.descr,
384 cusparseSpMatAttribute_t::FillMode,
385 &raw as *const _ as *const c_void,
386 core::mem::size_of::<cusparseFillMode_t>(),
387 )
388 })
389 }
390
391 pub fn set_diag(&self, diag: Diag) -> Result<()> {
393 let c = cusparse()?;
394 let cu = c.cusparse_sp_mat_set_attribute()?;
395 let raw = diag.raw();
396 check(unsafe {
397 cu(
398 self.descr,
399 cusparseSpMatAttribute_t::DiagType,
400 &raw as *const _ as *const c_void,
401 core::mem::size_of::<cusparseDiagType_t>(),
402 )
403 })
404 }
405
406 #[inline]
407 pub fn as_raw(&self) -> cusparseSpMatDescr_t {
408 self.descr
409 }
410}
411
412impl<T> Drop for SpMat<'_, T> {
413 fn drop(&mut self) {
414 if let Ok(c) = cusparse() {
415 if let Ok(cu) = c.cusparse_destroy_sp_mat() {
416 let _ = unsafe { cu(self.descr) };
417 }
418 }
419 }
420}
421
422pub struct DnVec<'buf, T> {
425 descr: cusparseDnVecDescr_t,
426 _marker: PhantomData<&'buf mut T>,
427}
428
429unsafe impl<T> Send for DnVec<'_, T> {}
430
431impl<T> core::fmt::Debug for DnVec<'_, T> {
432 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
433 f.debug_struct("DnVec")
434 .field("descr", &self.descr)
435 .finish_non_exhaustive()
436 }
437}
438
439impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnVec<'buf, T> {
440 pub fn new(values: &'buf mut DeviceBuffer<T>) -> Result<Self> {
441 let c = cusparse()?;
442 let cu = c.cusparse_create_dn_vec()?;
443 let mut descr: cusparseDnVecDescr_t = core::ptr::null_mut();
444 check(unsafe {
445 cu(
446 &mut descr,
447 values.len() as i64,
448 values.as_raw().0 as *mut c_void,
449 T::data_type(),
450 )
451 })?;
452 Ok(Self {
453 descr,
454 _marker: PhantomData,
455 })
456 }
457}
458
459impl<T> DnVec<'_, T> {
460 #[inline]
461 pub fn as_raw(&self) -> cusparseDnVecDescr_t {
462 self.descr
463 }
464}
465
466impl<T> Drop for DnVec<'_, T> {
467 fn drop(&mut self) {
468 if let Ok(c) = cusparse() {
469 if let Ok(cu) = c.cusparse_destroy_dn_vec() {
470 let _ = unsafe { cu(self.descr) };
471 }
472 }
473 }
474}
475
476pub struct DnMat<'buf, T> {
477 descr: cusparseDnMatDescr_t,
478 _marker: PhantomData<&'buf mut T>,
479}
480
481unsafe impl<T> Send for DnMat<'_, T> {}
482
483impl<T> core::fmt::Debug for DnMat<'_, T> {
484 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
485 f.debug_struct("DnMat")
486 .field("descr", &self.descr)
487 .finish_non_exhaustive()
488 }
489}
490
491impl<'buf, T: SparseScalar + baracuda_types::DeviceRepr> DnMat<'buf, T> {
492 pub fn new(
493 rows: i64,
494 cols: i64,
495 ld: i64,
496 order: Order,
497 values: &'buf mut DeviceBuffer<T>,
498 ) -> Result<Self> {
499 let c = cusparse()?;
500 let cu = c.cusparse_create_dn_mat()?;
501 let mut descr: cusparseDnMatDescr_t = core::ptr::null_mut();
502 check(unsafe {
503 cu(
504 &mut descr,
505 rows,
506 cols,
507 ld,
508 values.as_raw().0 as *mut c_void,
509 T::data_type(),
510 order.raw(),
511 )
512 })?;
513 Ok(Self {
514 descr,
515 _marker: PhantomData,
516 })
517 }
518}
519
520impl<T> DnMat<'_, T> {
521 #[inline]
522 pub fn as_raw(&self) -> cusparseDnMatDescr_t {
523 self.descr
524 }
525}
526
527impl<T> Drop for DnMat<'_, T> {
528 fn drop(&mut self) {
529 if let Ok(c) = cusparse() {
530 if let Ok(cu) = c.cusparse_destroy_dn_mat() {
531 let _ = unsafe { cu(self.descr) };
532 }
533 }
534 }
535}
536
537#[allow(clippy::too_many_arguments)]
541pub fn spmv_buffer_size<T: SparseScalar>(
542 handle: &Handle,
543 op: Op,
544 alpha: &T,
545 a: &SpMat<'_, T>,
546 x: &DnVec<'_, T>,
547 beta: &T,
548 y: &DnVec<'_, T>,
549 alg: SpMVAlg,
550) -> Result<usize> {
551 let c = cusparse()?;
552 let cu = c.cusparse_spmv_buffer_size()?;
553 let mut size: usize = 0;
554 check(unsafe {
555 cu(
556 handle.as_raw(),
557 op.raw(),
558 alpha as *const T as *const c_void,
559 a.descr,
560 x.descr,
561 beta as *const T as *const c_void,
562 y.descr,
563 T::data_type(),
564 alg,
565 &mut size,
566 )
567 })?;
568 Ok(size)
569}
570
571#[allow(clippy::too_many_arguments)]
573pub fn spmv<T: SparseScalar>(
574 handle: &Handle,
575 op: Op,
576 alpha: &T,
577 a: &SpMat<'_, T>,
578 x: &DnVec<'_, T>,
579 beta: &T,
580 y: &mut DnVec<'_, T>,
581 alg: SpMVAlg,
582 workspace: &mut DeviceBuffer<u8>,
583) -> Result<()> {
584 let c = cusparse()?;
585 let cu = c.cusparse_spmv()?;
586 check(unsafe {
587 cu(
588 handle.as_raw(),
589 op.raw(),
590 alpha as *const T as *const c_void,
591 a.descr,
592 x.descr,
593 beta as *const T as *const c_void,
594 y.descr,
595 T::data_type(),
596 alg,
597 workspace.as_raw().0 as *mut c_void,
598 )
599 })
600}
601
602#[allow(clippy::too_many_arguments)]
606pub fn spmm_buffer_size<T: SparseScalar>(
607 handle: &Handle,
608 op_a: Op,
609 op_b: Op,
610 alpha: &T,
611 a: &SpMat<'_, T>,
612 b: &DnMat<'_, T>,
613 beta: &T,
614 c: &DnMat<'_, T>,
615 alg: SpMMAlg,
616) -> Result<usize> {
617 let c_api = cusparse()?;
618 let cu = c_api.cusparse_spmm_buffer_size()?;
619 let mut size = 0usize;
620 check(unsafe {
621 cu(
622 handle.as_raw(),
623 op_a.raw(),
624 op_b.raw(),
625 alpha as *const T as *const c_void,
626 a.descr,
627 b.descr,
628 beta as *const T as *const c_void,
629 c.descr,
630 T::data_type(),
631 alg,
632 &mut size,
633 )
634 })?;
635 Ok(size)
636}
637
638#[allow(clippy::too_many_arguments)]
639pub fn spmm<T: SparseScalar>(
640 handle: &Handle,
641 op_a: Op,
642 op_b: Op,
643 alpha: &T,
644 a: &SpMat<'_, T>,
645 b: &DnMat<'_, T>,
646 beta: &T,
647 c: &mut DnMat<'_, T>,
648 alg: SpMMAlg,
649 workspace: &mut DeviceBuffer<u8>,
650) -> Result<()> {
651 let c_api = cusparse()?;
652 let cu = c_api.cusparse_spmm()?;
653 check(unsafe {
654 cu(
655 handle.as_raw(),
656 op_a.raw(),
657 op_b.raw(),
658 alpha as *const T as *const c_void,
659 a.descr,
660 b.descr,
661 beta as *const T as *const c_void,
662 c.descr,
663 T::data_type(),
664 alg,
665 workspace.as_raw().0 as *mut c_void,
666 )
667 })
668}
669
670#[derive(Debug)]
674pub struct SpGEMMPlan {
675 raw: cusparseSpGEMMDescr_t,
676}
677
678impl SpGEMMPlan {
679 pub fn new() -> Result<Self> {
680 let c = cusparse()?;
681 let cu = c.cusparse_spgemm_create_descr()?;
682 let mut d: cusparseSpGEMMDescr_t = core::ptr::null_mut();
683 check(unsafe { cu(&mut d) })?;
684 Ok(Self { raw: d })
685 }
686}
687
688impl Drop for SpGEMMPlan {
689 fn drop(&mut self) {
690 if let Ok(c) = cusparse() {
691 if let Ok(cu) = c.cusparse_spgemm_destroy_descr() {
692 let _ = unsafe { cu(self.raw) };
693 }
694 }
695 }
696}
697
698#[allow(clippy::too_many_arguments)]
701pub unsafe fn spgemm_work_estimation<T: SparseScalar>(
702 handle: &Handle,
703 op_a: Op,
704 op_b: Op,
705 alpha: &T,
706 a: &SpMat<'_, T>,
707 b: &SpMat<'_, T>,
708 beta: &T,
709 c: &mut SpMat<'_, T>,
710 alg: SpGEMMAlg,
711 plan: &SpGEMMPlan,
712 size1: &mut usize,
713 buffer1: *mut c_void,
714) -> Result<()> {
715 let c_api = cusparse()?;
716 let cu = c_api.cusparse_spgemm_work_estimation()?;
717 check(cu(
718 handle.as_raw(),
719 op_a.raw(),
720 op_b.raw(),
721 alpha as *const T as *const c_void,
722 a.descr,
723 b.descr,
724 beta as *const T as *const c_void,
725 c.descr,
726 T::data_type(),
727 alg,
728 plan.raw,
729 size1,
730 buffer1,
731 ))
732}
733
734#[allow(clippy::too_many_arguments)]
736pub unsafe fn spgemm_compute<T: SparseScalar>(
737 handle: &Handle,
738 op_a: Op,
739 op_b: Op,
740 alpha: &T,
741 a: &SpMat<'_, T>,
742 b: &SpMat<'_, T>,
743 beta: &T,
744 c: &mut SpMat<'_, T>,
745 alg: SpGEMMAlg,
746 plan: &SpGEMMPlan,
747 size2: &mut usize,
748 buffer2: *mut c_void,
749) -> Result<()> {
750 let c_api = cusparse()?;
751 let cu = c_api.cusparse_spgemm_compute()?;
752 check(cu(
753 handle.as_raw(),
754 op_a.raw(),
755 op_b.raw(),
756 alpha as *const T as *const c_void,
757 a.descr,
758 b.descr,
759 beta as *const T as *const c_void,
760 c.descr,
761 T::data_type(),
762 alg,
763 plan.raw,
764 size2,
765 buffer2,
766 ))
767}
768
769#[allow(clippy::too_many_arguments)]
771pub fn spgemm_copy<T: SparseScalar>(
772 handle: &Handle,
773 op_a: Op,
774 op_b: Op,
775 alpha: &T,
776 a: &SpMat<'_, T>,
777 b: &SpMat<'_, T>,
778 beta: &T,
779 c: &mut SpMat<'_, T>,
780 alg: SpGEMMAlg,
781 plan: &SpGEMMPlan,
782) -> Result<()> {
783 let c_api = cusparse()?;
784 let cu = c_api.cusparse_spgemm_copy()?;
785 check(unsafe {
786 cu(
787 handle.as_raw(),
788 op_a.raw(),
789 op_b.raw(),
790 alpha as *const T as *const c_void,
791 a.descr,
792 b.descr,
793 beta as *const T as *const c_void,
794 c.descr,
795 T::data_type(),
796 alg,
797 plan.raw,
798 )
799 })
800}
801
802#[derive(Debug)]
805pub struct SpSVPlan {
806 raw: cusparseSpSVDescr_t,
807}
808
809impl SpSVPlan {
810 pub fn new() -> Result<Self> {
811 let c = cusparse()?;
812 let cu = c.cusparse_spsv_create_descr()?;
813 let mut d: cusparseSpSVDescr_t = core::ptr::null_mut();
814 check(unsafe { cu(&mut d) })?;
815 Ok(Self { raw: d })
816 }
817}
818
819impl Drop for SpSVPlan {
820 fn drop(&mut self) {
821 if let Ok(c) = cusparse() {
822 if let Ok(cu) = c.cusparse_spsv_destroy_descr() {
823 let _ = unsafe { cu(self.raw) };
824 }
825 }
826 }
827}
828
829#[allow(clippy::too_many_arguments)]
830pub fn spsv_buffer_size<T: SparseScalar>(
831 handle: &Handle,
832 op: Op,
833 alpha: &T,
834 a: &SpMat<'_, T>,
835 x: &DnVec<'_, T>,
836 y: &DnVec<'_, T>,
837 alg: SpSVAlg,
838 plan: &SpSVPlan,
839) -> Result<usize> {
840 let c = cusparse()?;
841 let cu = c.cusparse_spsv_buffer_size()?;
842 let mut size = 0usize;
843 check(unsafe {
844 cu(
845 handle.as_raw(),
846 op.raw(),
847 alpha as *const T as *const c_void,
848 a.descr,
849 x.descr,
850 y.descr,
851 T::data_type(),
852 alg,
853 plan.raw,
854 &mut size,
855 )
856 })?;
857 Ok(size)
858}
859
860#[allow(clippy::too_many_arguments)]
861pub fn spsv_analysis<T: SparseScalar>(
862 handle: &Handle,
863 op: Op,
864 alpha: &T,
865 a: &SpMat<'_, T>,
866 x: &DnVec<'_, T>,
867 y: &DnVec<'_, T>,
868 alg: SpSVAlg,
869 plan: &SpSVPlan,
870 workspace: &mut DeviceBuffer<u8>,
871) -> Result<()> {
872 let c = cusparse()?;
873 let cu = c.cusparse_spsv_analysis()?;
874 check(unsafe {
875 cu(
876 handle.as_raw(),
877 op.raw(),
878 alpha as *const T as *const c_void,
879 a.descr,
880 x.descr,
881 y.descr,
882 T::data_type(),
883 alg,
884 plan.raw,
885 workspace.as_raw().0 as *mut c_void,
886 )
887 })
888}
889
890#[allow(clippy::too_many_arguments)]
891pub fn spsv_solve<T: SparseScalar>(
892 handle: &Handle,
893 op: Op,
894 alpha: &T,
895 a: &SpMat<'_, T>,
896 x: &DnVec<'_, T>,
897 y: &mut DnVec<'_, T>,
898 alg: SpSVAlg,
899 plan: &SpSVPlan,
900) -> Result<()> {
901 let c = cusparse()?;
902 let cu = c.cusparse_spsv_solve()?;
903 check(unsafe {
904 cu(
905 handle.as_raw(),
906 op.raw(),
907 alpha as *const T as *const c_void,
908 a.descr,
909 x.descr,
910 y.descr,
911 T::data_type(),
912 alg,
913 plan.raw,
914 )
915 })
916}
917
918#[derive(Debug)]
919pub struct SpSMPlan {
920 raw: cusparseSpSMDescr_t,
921}
922
923impl SpSMPlan {
924 pub fn new() -> Result<Self> {
925 let c = cusparse()?;
926 let cu = c.cusparse_spsm_create_descr()?;
927 let mut d: cusparseSpSMDescr_t = core::ptr::null_mut();
928 check(unsafe { cu(&mut d) })?;
929 Ok(Self { raw: d })
930 }
931}
932
933impl Drop for SpSMPlan {
934 fn drop(&mut self) {
935 if let Ok(c) = cusparse() {
936 if let Ok(cu) = c.cusparse_spsm_destroy_descr() {
937 let _ = unsafe { cu(self.raw) };
938 }
939 }
940 }
941}
942
943#[allow(clippy::too_many_arguments)]
944pub fn spsm_buffer_size<T: SparseScalar>(
945 handle: &Handle,
946 op_a: Op,
947 op_b: Op,
948 alpha: &T,
949 a: &SpMat<'_, T>,
950 b: &DnMat<'_, T>,
951 c: &DnMat<'_, T>,
952 alg: SpSMAlg,
953 plan: &SpSMPlan,
954) -> Result<usize> {
955 let c_api = cusparse()?;
956 let cu = c_api.cusparse_spsm_buffer_size()?;
957 let mut size = 0usize;
958 check(unsafe {
959 cu(
960 handle.as_raw(),
961 op_a.raw(),
962 op_b.raw(),
963 alpha as *const T as *const c_void,
964 a.descr,
965 b.descr,
966 c.descr,
967 T::data_type(),
968 alg,
969 plan.raw,
970 &mut size,
971 )
972 })?;
973 Ok(size)
974}
975
976#[allow(clippy::too_many_arguments)]
977pub fn spsm_analysis<T: SparseScalar>(
978 handle: &Handle,
979 op_a: Op,
980 op_b: Op,
981 alpha: &T,
982 a: &SpMat<'_, T>,
983 b: &DnMat<'_, T>,
984 c: &DnMat<'_, T>,
985 alg: SpSMAlg,
986 plan: &SpSMPlan,
987 workspace: &mut DeviceBuffer<u8>,
988) -> Result<()> {
989 let c_api = cusparse()?;
990 let cu = c_api.cusparse_spsm_analysis()?;
991 check(unsafe {
992 cu(
993 handle.as_raw(),
994 op_a.raw(),
995 op_b.raw(),
996 alpha as *const T as *const c_void,
997 a.descr,
998 b.descr,
999 c.descr,
1000 T::data_type(),
1001 alg,
1002 plan.raw,
1003 workspace.as_raw().0 as *mut c_void,
1004 )
1005 })
1006}
1007
1008#[allow(clippy::too_many_arguments)]
1009pub fn spsm_solve<T: SparseScalar>(
1010 handle: &Handle,
1011 op_a: Op,
1012 op_b: Op,
1013 alpha: &T,
1014 a: &SpMat<'_, T>,
1015 b: &DnMat<'_, T>,
1016 c: &mut DnMat<'_, T>,
1017 alg: SpSMAlg,
1018 plan: &SpSMPlan,
1019) -> Result<()> {
1020 let c_api = cusparse()?;
1021 let cu = c_api.cusparse_spsm_solve()?;
1022 check(unsafe {
1023 cu(
1024 handle.as_raw(),
1025 op_a.raw(),
1026 op_b.raw(),
1027 alpha as *const T as *const c_void,
1028 a.descr,
1029 b.descr,
1030 c.descr,
1031 T::data_type(),
1032 alg,
1033 plan.raw,
1034 )
1035 })
1036}
1037
1038#[allow(clippy::too_many_arguments)]
1041pub fn sddmm_buffer_size<T: SparseScalar>(
1042 handle: &Handle,
1043 op_a: Op,
1044 op_b: Op,
1045 alpha: &T,
1046 a: &DnMat<'_, T>,
1047 b: &DnMat<'_, T>,
1048 beta: &T,
1049 c: &SpMat<'_, T>,
1050 alg: SDDMMAlg,
1051) -> Result<usize> {
1052 let c_api = cusparse()?;
1053 let cu = c_api.cusparse_sddmm_buffer_size()?;
1054 let mut size = 0usize;
1055 check(unsafe {
1056 cu(
1057 handle.as_raw(),
1058 op_a.raw(),
1059 op_b.raw(),
1060 alpha as *const T as *const c_void,
1061 a.descr,
1062 b.descr,
1063 beta as *const T as *const c_void,
1064 c.descr,
1065 T::data_type(),
1066 alg,
1067 &mut size,
1068 )
1069 })?;
1070 Ok(size)
1071}
1072
1073#[allow(clippy::too_many_arguments)]
1074pub fn sddmm<T: SparseScalar>(
1075 handle: &Handle,
1076 op_a: Op,
1077 op_b: Op,
1078 alpha: &T,
1079 a: &DnMat<'_, T>,
1080 b: &DnMat<'_, T>,
1081 beta: &T,
1082 c: &mut SpMat<'_, T>,
1083 alg: SDDMMAlg,
1084 workspace: &mut DeviceBuffer<u8>,
1085) -> Result<()> {
1086 let c_api = cusparse()?;
1087 let cu = c_api.cusparse_sddmm()?;
1088 check(unsafe {
1089 cu(
1090 handle.as_raw(),
1091 op_a.raw(),
1092 op_b.raw(),
1093 alpha as *const T as *const c_void,
1094 a.descr,
1095 b.descr,
1096 beta as *const T as *const c_void,
1097 c.descr,
1098 T::data_type(),
1099 alg,
1100 workspace.as_raw().0 as *mut c_void,
1101 )
1102 })
1103}
1104
1105pub fn sparse_to_dense_buffer_size<T: SparseScalar>(
1108 handle: &Handle,
1109 sp: &SpMat<'_, T>,
1110 dn: &DnMat<'_, T>,
1111) -> Result<usize> {
1112 let c = cusparse()?;
1113 let cu = c.cusparse_sparse_to_dense_buffer_size()?;
1114 let mut size = 0usize;
1115 check(unsafe { cu(handle.as_raw(), sp.descr, dn.descr, 0, &mut size) })?;
1116 Ok(size)
1117}
1118
1119pub fn sparse_to_dense<T: SparseScalar>(
1120 handle: &Handle,
1121 sp: &SpMat<'_, T>,
1122 dn: &mut DnMat<'_, T>,
1123 workspace: &mut DeviceBuffer<u8>,
1124) -> Result<()> {
1125 let c = cusparse()?;
1126 let cu = c.cusparse_sparse_to_dense()?;
1127 check(unsafe {
1128 cu(
1129 handle.as_raw(),
1130 sp.descr,
1131 dn.descr,
1132 0,
1133 workspace.as_raw().0 as *mut c_void,
1134 )
1135 })
1136}
1137
1138pub fn dense_to_sparse_buffer_size<T: SparseScalar>(
1139 handle: &Handle,
1140 dn: &DnMat<'_, T>,
1141 sp: &SpMat<'_, T>,
1142) -> Result<usize> {
1143 let c = cusparse()?;
1144 let cu = c.cusparse_dense_to_sparse_buffer_size()?;
1145 let mut size = 0usize;
1146 check(unsafe { cu(handle.as_raw(), dn.descr, sp.descr, 0, &mut size) })?;
1147 Ok(size)
1148}
1149
1150pub fn dense_to_sparse_analysis<T: SparseScalar>(
1151 handle: &Handle,
1152 dn: &DnMat<'_, T>,
1153 sp: &SpMat<'_, T>,
1154 workspace: &mut DeviceBuffer<u8>,
1155) -> Result<()> {
1156 let c = cusparse()?;
1157 let cu = c.cusparse_dense_to_sparse_analysis()?;
1158 check(unsafe {
1159 cu(
1160 handle.as_raw(),
1161 dn.descr,
1162 sp.descr,
1163 0,
1164 workspace.as_raw().0 as *mut c_void,
1165 )
1166 })
1167}
1168
1169pub fn dense_to_sparse_convert<T: SparseScalar>(
1170 handle: &Handle,
1171 dn: &DnMat<'_, T>,
1172 sp: &mut SpMat<'_, T>,
1173 workspace: &mut DeviceBuffer<u8>,
1174) -> Result<()> {
1175 let c = cusparse()?;
1176 let cu = c.cusparse_dense_to_sparse_convert()?;
1177 check(unsafe {
1178 cu(
1179 handle.as_raw(),
1180 dn.descr,
1181 sp.descr,
1182 0,
1183 workspace.as_raw().0 as *mut c_void,
1184 )
1185 })
1186}
1187
1188pub fn axpby<T: SparseScalar>(
1191 handle: &Handle,
1192 alpha: &T,
1193 x: &DnVec<'_, T>,
1194 beta: &T,
1195 y: &mut DnVec<'_, T>,
1196) -> Result<()> {
1197 let c = cusparse()?;
1198 let cu = c.cusparse_axpby()?;
1199 check(unsafe {
1200 cu(
1201 handle.as_raw(),
1202 alpha as *const T as *const c_void,
1203 x.descr,
1204 beta as *const T as *const c_void,
1205 y.descr,
1206 )
1207 })
1208}
1209
1210pub fn gather<T: SparseScalar>(
1211 handle: &Handle,
1212 y: &DnVec<'_, T>,
1213 x: &mut DnVec<'_, T>,
1214) -> Result<()> {
1215 let c = cusparse()?;
1216 let cu = c.cusparse_gather()?;
1217 check(unsafe { cu(handle.as_raw(), y.descr, x.descr) })
1218}
1219
1220pub fn scatter<T: SparseScalar>(
1221 handle: &Handle,
1222 x: &DnVec<'_, T>,
1223 y: &mut DnVec<'_, T>,
1224) -> Result<()> {
1225 let c = cusparse()?;
1226 let cu = c.cusparse_scatter()?;
1227 check(unsafe { cu(handle.as_raw(), x.descr, y.descr) })
1228}
1229
1230pub fn rot<T: SparseScalar>(
1231 handle: &Handle,
1232 c_cos: &T,
1233 s_sin: &T,
1234 x: &mut DnVec<'_, T>,
1235 y: &mut DnVec<'_, T>,
1236) -> Result<()> {
1237 let c_api = cusparse()?;
1238 let cu = c_api.cusparse_rot()?;
1239 check(unsafe {
1240 cu(
1241 handle.as_raw(),
1242 c_cos as *const T as *const c_void,
1243 s_sin as *const T as *const c_void,
1244 x.descr,
1245 y.descr,
1246 )
1247 })
1248}
1249
/// Convenience alias for an `f32` sparse-matrix descriptor.
///
/// NOTE(review): despite the name, any `SpMat<f32>` matches this alias,
/// including ones built with `csc`, `coo` or `bsr`. The storage format is not
/// enforced by the type.
pub type CsrMatrix<'buf> = SpMat<'buf, f32>;
/// Convenience alias for [`DnVec`].
pub type DenseVector<'buf, T> = DnVec<'buf, T>;