1#![warn(missing_debug_implementations)]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5#![allow(clippy::not_unsafe_ptr_arg_deref)]
12#![allow(clippy::type_complexity)]
16
17use std::ffi::CString;
18use std::marker::PhantomData;
19
20pub use aocl_error::{Error, Result};
21use aocl_sparse_sys as sys;
22use aocl_types::sealed::Sealed;
23pub use aocl_types::{Complex32, Complex64, Trans};
24
25pub mod complex;
26
27fn trans_raw(t: Trans) -> sys::aoclsparse_operation {
28 match t {
29 Trans::No => sys::aoclsparse_operation__aoclsparse_operation_none,
30 Trans::T => sys::aoclsparse_operation__aoclsparse_operation_transpose,
31 Trans::C => sys::aoclsparse_operation__aoclsparse_operation_conjugate_transpose,
32 }
33}
34
35fn check_status(component: &'static str, status: sys::aoclsparse_status) -> Result<()> {
36 if status == sys::aoclsparse_status__aoclsparse_status_success {
37 return Ok(());
38 }
39 let message = match status {
40 s if s == sys::aoclsparse_status__aoclsparse_status_not_implemented => "not implemented",
41 s if s == sys::aoclsparse_status__aoclsparse_status_invalid_pointer => "invalid pointer",
42 s if s == sys::aoclsparse_status__aoclsparse_status_invalid_size => "invalid size",
43 s if s == sys::aoclsparse_status__aoclsparse_status_internal_error => "internal error",
44 s if s == sys::aoclsparse_status__aoclsparse_status_invalid_value => "invalid value",
45 s if s == sys::aoclsparse_status__aoclsparse_status_invalid_index_value => {
46 "invalid index value"
47 }
48 s if s == sys::aoclsparse_status__aoclsparse_status_maxit => "max iterations reached",
49 s if s == sys::aoclsparse_status__aoclsparse_status_user_stop => "user stop",
50 s if s == sys::aoclsparse_status__aoclsparse_status_wrong_type => "wrong type",
51 s if s == sys::aoclsparse_status__aoclsparse_status_memory_error => "memory error",
52 _ => "unknown sparse status",
53 }
54 .to_string();
55 Err(Error::Status {
56 component,
57 code: status as i64,
58 message,
59 })
60}
61
/// Owning wrapper around an AOCL-Sparse matrix descriptor handle.
///
/// The handle is obtained from `aoclsparse_create_mat_descr` in `MatDescr::new`
/// and released with `aoclsparse_destroy_mat_descr` when the value is dropped.
pub struct MatDescr {
    // Raw descriptor handle; non-null for every value built by `MatDescr::new`.
    raw: sys::aoclsparse_mat_descr,
}
66
67impl MatDescr {
68 pub fn new() -> Result<Self> {
70 let mut raw: sys::aoclsparse_mat_descr = std::ptr::null_mut();
71 let status = unsafe { sys::aoclsparse_create_mat_descr(&mut raw) };
72 check_status("sparse", status)?;
73 if raw.is_null() {
74 return Err(Error::AllocationFailed("sparse"));
75 }
76 Ok(MatDescr { raw })
77 }
78
79 pub fn as_raw(&self) -> sys::aoclsparse_mat_descr {
85 self.raw
86 }
87
88 pub fn set_type(&mut self, ty: MatType) -> Result<()> {
91 let status = unsafe { sys::aoclsparse_set_mat_type(self.raw, ty.raw()) };
92 check_status("sparse", status)
93 }
94
95 pub fn set_index_base(&mut self, base: IndexBase) -> Result<()> {
97 let status = unsafe { sys::aoclsparse_set_mat_index_base(self.raw, base.raw()) };
98 check_status("sparse", status)
99 }
100
101 pub fn set_fill_mode(&mut self, fill: FillMode) -> Result<()> {
104 let status = unsafe { sys::aoclsparse_set_mat_fill_mode(self.raw, fill.raw()) };
105 check_status("sparse", status)
106 }
107
108 pub fn set_diag_type(&mut self, diag: DiagType) -> Result<()> {
111 let status = unsafe { sys::aoclsparse_set_mat_diag_type(self.raw, diag.raw()) };
112 check_status("sparse", status)
113 }
114
115 pub fn ty(&self) -> MatType {
117 let raw = unsafe { sys::aoclsparse_get_mat_type(self.raw) };
118 MatType::from_raw(raw).unwrap_or(MatType::General)
119 }
120
121 pub fn index_base(&self) -> IndexBase {
123 let raw = unsafe { sys::aoclsparse_get_mat_index_base(self.raw) };
124 if raw == sys::aoclsparse_index_base__aoclsparse_index_base_one {
125 IndexBase::One
126 } else {
127 IndexBase::Zero
128 }
129 }
130
131 pub fn fill_mode(&self) -> FillMode {
133 let raw = unsafe { sys::aoclsparse_get_mat_fill_mode(self.raw) };
134 FillMode::from_raw(raw).unwrap_or(FillMode::Lower)
135 }
136
137 pub fn diag_type(&self) -> DiagType {
139 let raw = unsafe { sys::aoclsparse_get_mat_diag_type(self.raw) };
140 DiagType::from_raw(raw).unwrap_or(DiagType::NonUnit)
141 }
142}
143
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
146pub enum MatType {
147 General,
148 Symmetric,
149 Hermitian,
150 Triangular,
151}
152
153impl MatType {
154 fn raw(self) -> sys::aoclsparse_matrix_type {
155 match self {
156 MatType::General => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_general,
157 MatType::Symmetric => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric,
158 MatType::Hermitian => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_hermitian,
159 MatType::Triangular => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_triangular,
160 }
161 }
162
163 fn from_raw(raw: sys::aoclsparse_matrix_type) -> Option<Self> {
164 Some(match raw {
165 r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_general => {
166 MatType::General
167 }
168 r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric => {
169 MatType::Symmetric
170 }
171 r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_hermitian => {
172 MatType::Hermitian
173 }
174 r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_triangular => {
175 MatType::Triangular
176 }
177 _ => return None,
178 })
179 }
180}
181
182#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
184pub enum FillMode {
185 Lower,
186 Upper,
187}
188
189impl FillMode {
190 fn raw(self) -> sys::aoclsparse_fill_mode {
191 match self {
192 FillMode::Lower => sys::aoclsparse_fill_mode__aoclsparse_fill_mode_lower,
193 FillMode::Upper => sys::aoclsparse_fill_mode__aoclsparse_fill_mode_upper,
194 }
195 }
196 fn from_raw(raw: sys::aoclsparse_fill_mode) -> Option<Self> {
197 Some(match raw {
198 r if r == sys::aoclsparse_fill_mode__aoclsparse_fill_mode_lower => FillMode::Lower,
199 r if r == sys::aoclsparse_fill_mode__aoclsparse_fill_mode_upper => FillMode::Upper,
200 _ => return None,
201 })
202 }
203}
204
205#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
208pub enum DiagType {
209 Unit,
210 NonUnit,
211}
212
213impl DiagType {
214 fn raw(self) -> sys::aoclsparse_diag_type {
215 match self {
216 DiagType::Unit => sys::aoclsparse_diag_type__aoclsparse_diag_type_unit,
217 DiagType::NonUnit => sys::aoclsparse_diag_type__aoclsparse_diag_type_non_unit,
218 }
219 }
220 fn from_raw(raw: sys::aoclsparse_diag_type) -> Option<Self> {
221 Some(match raw {
222 r if r == sys::aoclsparse_diag_type__aoclsparse_diag_type_unit => DiagType::Unit,
223 r if r == sys::aoclsparse_diag_type__aoclsparse_diag_type_non_unit => DiagType::NonUnit,
224 _ => return None,
225 })
226 }
227}
228
229pub fn copy_mat_descr(src: &MatDescr) -> Result<MatDescr> {
231 let dest = MatDescr::new()?;
232 let status = unsafe { sys::aoclsparse_copy_mat_descr(dest.raw, src.raw) };
233 check_status("sparse", status)?;
234 Ok(dest)
235}
236
237pub fn optimize<T: Scalar>(mat: &mut SparseMatrix<T>) -> Result<()> {
241 let status = unsafe { sys::aoclsparse_optimize(mat.as_raw()) };
242 check_status("sparse", status)
243}
244
245pub fn version() -> &'static str {
247 unsafe {
248 let p = sys::aoclsparse_get_version();
249 if p.is_null() {
250 return "";
251 }
252 std::ffi::CStr::from_ptr(p).to_str().unwrap_or("")
253 }
254}
255
256impl Drop for MatDescr {
257 fn drop(&mut self) {
258 if !self.raw.is_null() {
259 unsafe {
260 let _ = sys::aoclsparse_destroy_mat_descr(self.raw);
261 }
262 self.raw = std::ptr::null_mut();
263 }
264 }
265}
266
267impl std::fmt::Debug for MatDescr {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 f.debug_struct("MatDescr").finish_non_exhaustive()
270 }
271}
272
/// Element types supported by the AOCL-Sparse bindings.
///
/// The `Sealed` supertrait (from `aocl_types::sealed`) restricts which types can
/// implement this trait. Each method forwards to the precision-specific
/// `aoclsparse_*` C routine; the `f32` and `f64` implementations in this file are
/// generated by the `impl_scalar!` macro.
pub trait Scalar: Copy + Sized + Sealed {
    /// CSR sparse matrix-vector product via `aoclsparse_?csrmv`
    /// (per the AOCL-Sparse docs: `y := alpha*op(A)*x + beta*y`).
    ///
    /// `m`×`n` is the shape of `A`. `csr_val.len()` is passed as the number of
    /// non-zeros.
    #[allow(clippy::too_many_arguments)]
    fn csrmv(
        op: Trans,
        alpha: Self,
        m: usize,
        n: usize,
        csr_val: &[Self],
        csr_col_ind: &[sys::aoclsparse_int],
        csr_row_ptr: &[sys::aoclsparse_int],
        descr: &MatDescr,
        x: &[Self],
        beta: Self,
        y: &mut [Self],
    ) -> Result<()>;

    /// Sparse AXPY via `aoclsparse_?axpyi`.
    ///
    /// `x.len()` is passed as the number of non-zeros; `indx` gives their
    /// positions in the dense vector `y`.
    fn axpyi(alpha: Self, x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()>;

    /// Sparse gather via `aoclsparse_?gthr`.
    ///
    /// Reads entries of `y` at the positions in `indx` into `x`; `x.len()` is
    /// passed as the number of non-zeros.
    fn gthr(y: &[Self], indx: &[sys::aoclsparse_int], x: &mut [Self]) -> Result<()>;

    /// Sparse scatter via `aoclsparse_?sctr`.
    ///
    /// Writes `x` into `y` at the positions in `indx`; `x.len()` is passed as the
    /// number of non-zeros.
    fn sctr(x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()>;

    /// Sparse triangular solve with an `m`×`m` CSR matrix via `aoclsparse_?csrsv`.
    #[allow(clippy::too_many_arguments)]
    fn csrsv(
        op: Trans,
        alpha: Self,
        m: usize,
        csr_val: &[Self],
        csr_col_ind: &[sys::aoclsparse_int],
        csr_row_ptr: &[sys::aoclsparse_int],
        descr: &MatDescr,
        x: &[Self],
        y: &mut [Self],
    ) -> Result<()>;

    /// Expands an `m`×`n` CSR matrix into the dense buffer `a` via `aoclsparse_?csr2dense`.
    ///
    /// `a` has leading dimension `ld` and layout `order`.
    #[allow(clippy::too_many_arguments)]
    fn csr_to_dense(
        m: usize,
        n: usize,
        descr: &MatDescr,
        csr_val: &[Self],
        csr_row_ptr: &[sys::aoclsparse_int],
        csr_col_ind: &[sys::aoclsparse_int],
        a: &mut [Self],
        ld: usize,
        order: Order,
    ) -> Result<()>;

    /// Converts an `m`×`n` CSR matrix to CSC via `aoclsparse_?csr2csc`.
    ///
    /// The output uses index base `base_csc`; `csr_val.len()` is passed as the
    /// number of non-zeros.
    #[allow(clippy::too_many_arguments)]
    fn csr_to_csc(
        m: usize,
        n: usize,
        descr: &MatDescr,
        base_csc: IndexBase,
        csr_row_ptr: &[sys::aoclsparse_int],
        csr_col_ind: &[sys::aoclsparse_int],
        csr_val: &[Self],
        csc_row_ind: &mut [sys::aoclsparse_int],
        csc_col_ptr: &mut [sys::aoclsparse_int],
        csc_val: &mut [Self],
    ) -> Result<()>;

    /// ELLPACK matrix-vector product via `aoclsparse_?ellmv`.
    ///
    /// `ell_width` is the number of stored entries per row; `ell_val.len()` is
    /// passed as the number of non-zeros.
    #[allow(clippy::too_many_arguments)]
    fn ellmv(
        op: Trans,
        alpha: Self,
        m: usize,
        n: usize,
        ell_val: &[Self],
        ell_col_ind: &[sys::aoclsparse_int],
        ell_width: usize,
        descr: &MatDescr,
        x: &[Self],
        beta: Self,
        y: &mut [Self],
    ) -> Result<()>;

    /// Block-CSR matrix-vector product via `aoclsparse_?bsrmv`.
    ///
    /// `mb`/`nb` count block rows/columns and `bsr_dim` is the square block
    /// size, so the dense operand lengths are multiples of `bsr_dim`.
    #[allow(clippy::too_many_arguments)]
    fn bsrmv(
        op: Trans,
        alpha: Self,
        mb: usize,
        nb: usize,
        bsr_dim: usize,
        bsr_val: &[Self],
        bsr_col_ind: &[sys::aoclsparse_int],
        bsr_row_ptr: &[sys::aoclsparse_int],
        descr: &MatDescr,
        x: &[Self],
        beta: Self,
        y: &mut [Self],
    ) -> Result<()>;

    /// Creates an opaque CSR matrix handle over caller-provided arrays via
    /// `aoclsparse_create_?csr`.
    ///
    /// NOTE(review): the raw pointers are handed to the library as-is. Presumably
    /// they must stay valid for the lifetime of the returned handle; confirm
    /// against the AOCL-Sparse documentation.
    #[allow(clippy::too_many_arguments)]
    fn create_csr(
        base: IndexBase,
        m: usize,
        n: usize,
        nnz: usize,
        row_ptr: *mut sys::aoclsparse_int,
        col_idx: *mut sys::aoclsparse_int,
        val: *mut Self,
    ) -> Result<sys::aoclsparse_matrix>;

    /// Reads back `(base, m, n, nnz, row_ptr, col_ind, val)` from a CSR handle
    /// via `aoclsparse_export_?csr`.
    ///
    /// NOTE(review): the returned pointers presumably alias storage owned by
    /// the handle; do not free them.
    fn export_csr(
        mat: sys::aoclsparse_matrix,
    ) -> Result<(
        IndexBase,
        usize,
        usize,
        usize,
        *mut sys::aoclsparse_int,
        *mut sys::aoclsparse_int,
        *mut Self,
    )>;

    /// ILU smoothing step via `aoclsparse_?ilu_smoother`.
    ///
    /// The generated implementations discard the preconditioner-values output
    /// and pass a null approximate inverse diagonal.
    fn ilu_smoother(
        op: Trans,
        a: sys::aoclsparse_matrix,
        descr: &MatDescr,
        x: &mut [Self],
        b: &[Self],
    ) -> Result<()>;

    /// Initialises an iterative-solver handle for this precision via
    /// `aoclsparse_itsol_?_init`.
    fn itsol_init(handle: &mut sys::aoclsparse_itsol_handle) -> Result<()>;

    /// Runs the iterative solver via `aoclsparse_itsol_?_solve`.
    ///
    /// No preconditioner or monitor callback and no user data are supplied.
    /// `rinfo` receives the library's 100-entry information array.
    #[allow(clippy::too_many_arguments)]
    fn itsol_solve(
        handle: sys::aoclsparse_itsol_handle,
        n: usize,
        mat: sys::aoclsparse_matrix,
        descr: &MatDescr,
        b: &[Self],
        x: &mut [Self],
        rinfo: &mut [Self; 100],
    ) -> Result<()>;

    /// Raw forwarder to `aoclsparse_?csr2m` (sparse × sparse product); the
    /// status is returned unchecked.
    ///
    /// # Safety
    /// The caller must pass live matrix handles and descriptors, and `out`
    /// must be valid for writing a matrix handle.
    #[allow(clippy::too_many_arguments)]
    unsafe fn csr2m_ffi(
        op_a: sys::aoclsparse_operation,
        descr_a: sys::aoclsparse_mat_descr,
        a: sys::aoclsparse_matrix,
        op_b: sys::aoclsparse_operation,
        descr_b: sys::aoclsparse_mat_descr,
        b: sys::aoclsparse_matrix,
        request: sys::aoclsparse_request,
        out: *mut sys::aoclsparse_matrix,
    ) -> sys::aoclsparse_status;

    /// Sparse × dense product via `aoclsparse_?csrmm`.
    ///
    /// `b`/`c` are dense matrices with leading dimensions `ldb`/`ldc` in layout
    /// `order`, and `n` is their column count. Buffer sizes are not checked on
    /// the Rust side.
    #[allow(clippy::too_many_arguments)]
    fn csrmm(
        op: Trans,
        alpha: Self,
        a: sys::aoclsparse_matrix,
        descr: &MatDescr,
        order: Order,
        b: &[Self],
        n: usize,
        ldb: usize,
        beta: Self,
        c: &mut [Self],
        ldc: usize,
    ) -> Result<()>;

    /// Sparse × sparse product written to a dense matrix via `aoclsparse_?spmmd`.
    #[allow(clippy::too_many_arguments)]
    fn spmmd(
        op: Trans,
        a: sys::aoclsparse_matrix,
        b: sys::aoclsparse_matrix,
        layout: Order,
        c: &mut [Self],
        ldc: usize,
    ) -> Result<()>;

    /// Scaled sparse × sparse product accumulated into a dense matrix via
    /// `aoclsparse_?sp2md`.
    #[allow(clippy::too_many_arguments)]
    fn sp2md(
        op_a: Trans,
        descr_a: &MatDescr,
        a: sys::aoclsparse_matrix,
        op_b: Trans,
        descr_b: &MatDescr,
        b: sys::aoclsparse_matrix,
        alpha: Self,
        beta: Self,
        c: &mut [Self],
        layout: Order,
        ldc: usize,
    ) -> Result<()>;

    /// Raw forwarder to `aoclsparse_?add` (sparse matrix addition); the status
    /// is returned unchecked.
    ///
    /// # Safety
    /// The caller must pass live matrix handles, and `out` must be valid for
    /// writing a matrix handle.
    unsafe fn add_ffi(
        op: sys::aoclsparse_operation,
        a: sys::aoclsparse_matrix,
        alpha: Self,
        b: sys::aoclsparse_matrix,
        out: *mut sys::aoclsparse_matrix,
    ) -> sys::aoclsparse_status;

    /// SOR sweep via `aoclsparse_?sorv` with relaxation factor `omega`.
    #[allow(clippy::too_many_arguments)]
    fn sorv(
        sor_type: SorType,
        descr: &MatDescr,
        a: sys::aoclsparse_matrix,
        omega: Self,
        alpha: Self,
        x: &mut [Self],
        b: &[Self],
    ) -> Result<()>;
}
522
523#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
525pub enum SorType {
526 Forward,
528 Backward,
530 Symmetric,
532}
533
534impl SorType {
535 pub(crate) fn raw(self) -> sys::aoclsparse_sor_type {
536 match self {
537 SorType::Forward => sys::aoclsparse_sor_type__aoclsparse_sor_forward,
538 SorType::Backward => sys::aoclsparse_sor_type__aoclsparse_sor_backward,
539 SorType::Symmetric => sys::aoclsparse_sor_type__aoclsparse_sor_symmetric,
540 }
541 }
542}
543
544macro_rules! impl_scalar {
545 (
546 $t:ty,
547 csrmv = $csrmv:ident,
548 axpyi = $axpyi:ident,
549 gthr = $gthr:ident,
550 sctr = $sctr:ident,
551 csrsv = $csrsv:ident,
552 csr2dense = $csr2dense:ident,
553 csr2csc = $csr2csc:ident,
554 ellmv = $ellmv:ident,
555 bsrmv = $bsrmv:ident,
556 create_csr = $create_csr:ident,
557 export_csr = $export_csr:ident,
558 ilu_smoother = $ilu_smoother:ident,
559 itsol_init = $itsol_init:ident,
560 itsol_solve = $itsol_solve:ident,
561 csr2m = $csr2m:ident,
562 csrmm = $csrmm:ident,
563 spmmd = $spmmd:ident,
564 sp2md = $sp2md:ident,
565 add = $add:ident,
566 sorv = $sorv:ident
567 ) => {
568 impl Scalar for $t {
569 fn csrmv(
570 op: Trans,
571 alpha: Self,
572 m: usize,
573 n: usize,
574 csr_val: &[Self],
575 csr_col_ind: &[sys::aoclsparse_int],
576 csr_row_ptr: &[sys::aoclsparse_int],
577 descr: &MatDescr,
578 x: &[Self],
579 beta: Self,
580 y: &mut [Self],
581 ) -> Result<()> {
582 if csr_row_ptr.len() != m + 1 {
583 return Err(Error::InvalidArgument(format!(
584 "csrmv: csr_row_ptr length {} != m+1 = {}",
585 csr_row_ptr.len(),
586 m + 1
587 )));
588 }
589 let nnz = csr_val.len();
590 if csr_col_ind.len() != nnz {
591 return Err(Error::InvalidArgument(format!(
592 "csrmv: csr_col_ind length {} != csr_val length {}",
593 csr_col_ind.len(),
594 nnz
595 )));
596 }
597 let (x_len, y_len) = match op {
598 Trans::No => (n, m),
599 Trans::T | Trans::C => (m, n),
600 };
601 if x.len() < x_len {
602 return Err(Error::InvalidArgument(format!(
603 "csrmv: x length {} < expected {x_len}",
604 x.len()
605 )));
606 }
607 if y.len() < y_len {
608 return Err(Error::InvalidArgument(format!(
609 "csrmv: y length {} < expected {y_len}",
610 y.len()
611 )));
612 }
613
614 let status = unsafe {
615 sys::$csrmv(
616 trans_raw(op),
617 &alpha,
618 m as sys::aoclsparse_int,
619 n as sys::aoclsparse_int,
620 nnz as sys::aoclsparse_int,
621 csr_val.as_ptr(),
622 csr_col_ind.as_ptr(),
623 csr_row_ptr.as_ptr(),
624 descr.as_raw(),
625 x.as_ptr(),
626 &beta,
627 y.as_mut_ptr(),
628 )
629 };
630 check_status("sparse", status)
631 }
632
633 fn axpyi(
634 alpha: Self,
635 x: &[Self],
636 indx: &[sys::aoclsparse_int],
637 y: &mut [Self],
638 ) -> Result<()> {
639 let status = unsafe {
640 sys::$axpyi(
641 x.len() as sys::aoclsparse_int,
642 alpha,
643 x.as_ptr(),
644 indx.as_ptr(),
645 y.as_mut_ptr(),
646 )
647 };
648 check_status("sparse", status)
649 }
650
651 fn gthr(y: &[Self], indx: &[sys::aoclsparse_int], x: &mut [Self]) -> Result<()> {
652 let status = unsafe {
653 sys::$gthr(
654 x.len() as sys::aoclsparse_int,
655 y.as_ptr(),
656 x.as_mut_ptr(),
657 indx.as_ptr(),
658 )
659 };
660 check_status("sparse", status)
661 }
662
663 fn sctr(x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()> {
664 let status = unsafe {
665 sys::$sctr(
666 x.len() as sys::aoclsparse_int,
667 x.as_ptr(),
668 indx.as_ptr(),
669 y.as_mut_ptr(),
670 )
671 };
672 check_status("sparse", status)
673 }
674
675 #[allow(clippy::too_many_arguments)]
676 fn csrsv(
677 op: Trans,
678 alpha: Self,
679 m: usize,
680 csr_val: &[Self],
681 csr_col_ind: &[sys::aoclsparse_int],
682 csr_row_ptr: &[sys::aoclsparse_int],
683 descr: &MatDescr,
684 x: &[Self],
685 y: &mut [Self],
686 ) -> Result<()> {
687 if csr_row_ptr.len() != m + 1 {
688 return Err(Error::InvalidArgument(format!(
689 "csrsv: csr_row_ptr length {} != m+1 = {}",
690 csr_row_ptr.len(),
691 m + 1
692 )));
693 }
694 if x.len() < m || y.len() < m {
695 return Err(Error::InvalidArgument(format!(
696 "csrsv: x.len()={}, y.len()={}, m={m}",
697 x.len(),
698 y.len()
699 )));
700 }
701 let status = unsafe {
702 sys::$csrsv(
703 trans_raw(op),
704 &alpha,
705 m as sys::aoclsparse_int,
706 csr_val.as_ptr(),
707 csr_col_ind.as_ptr(),
708 csr_row_ptr.as_ptr(),
709 descr.as_raw(),
710 x.as_ptr(),
711 y.as_mut_ptr(),
712 )
713 };
714 check_status("sparse", status)
715 }
716
717 #[allow(clippy::too_many_arguments)]
718 fn csr_to_dense(
719 m: usize,
720 n: usize,
721 descr: &MatDescr,
722 csr_val: &[Self],
723 csr_row_ptr: &[sys::aoclsparse_int],
724 csr_col_ind: &[sys::aoclsparse_int],
725 a: &mut [Self],
726 ld: usize,
727 order: Order,
728 ) -> Result<()> {
729 let status = unsafe {
730 sys::$csr2dense(
731 m as sys::aoclsparse_int,
732 n as sys::aoclsparse_int,
733 descr.as_raw(),
734 csr_val.as_ptr(),
735 csr_row_ptr.as_ptr(),
736 csr_col_ind.as_ptr(),
737 a.as_mut_ptr(),
738 ld as sys::aoclsparse_int,
739 order.raw(),
740 )
741 };
742 check_status("sparse", status)
743 }
744
745 #[allow(clippy::too_many_arguments)]
746 fn csr_to_csc(
747 m: usize,
748 n: usize,
749 descr: &MatDescr,
750 base_csc: IndexBase,
751 csr_row_ptr: &[sys::aoclsparse_int],
752 csr_col_ind: &[sys::aoclsparse_int],
753 csr_val: &[Self],
754 csc_row_ind: &mut [sys::aoclsparse_int],
755 csc_col_ptr: &mut [sys::aoclsparse_int],
756 csc_val: &mut [Self],
757 ) -> Result<()> {
758 let nnz = csr_val.len();
759 let status = unsafe {
760 sys::$csr2csc(
761 m as sys::aoclsparse_int,
762 n as sys::aoclsparse_int,
763 nnz as sys::aoclsparse_int,
764 descr.as_raw(),
765 base_csc.raw(),
766 csr_row_ptr.as_ptr(),
767 csr_col_ind.as_ptr(),
768 csr_val.as_ptr(),
769 csc_row_ind.as_mut_ptr(),
770 csc_col_ptr.as_mut_ptr(),
771 csc_val.as_mut_ptr(),
772 )
773 };
774 check_status("sparse", status)
775 }
776
777 #[allow(clippy::too_many_arguments)]
778 fn ellmv(
779 op: Trans,
780 alpha: Self,
781 m: usize,
782 n: usize,
783 ell_val: &[Self],
784 ell_col_ind: &[sys::aoclsparse_int],
785 ell_width: usize,
786 descr: &MatDescr,
787 x: &[Self],
788 beta: Self,
789 y: &mut [Self],
790 ) -> Result<()> {
791 let nnz = ell_val.len();
792 if ell_col_ind.len() != nnz {
793 return Err(Error::InvalidArgument(format!(
794 "ellmv: ell_col_ind length {} != ell_val length {nnz}",
795 ell_col_ind.len()
796 )));
797 }
798 let needed = m.checked_mul(ell_width).ok_or_else(|| {
799 Error::InvalidArgument("ellmv: m * ell_width overflows".into())
800 })?;
801 if nnz < needed {
802 return Err(Error::InvalidArgument(format!(
803 "ellmv: ell_val length {nnz} < m*ell_width = {needed}"
804 )));
805 }
806 let (x_len, y_len) = match op {
807 Trans::No => (n, m),
808 Trans::T | Trans::C => (m, n),
809 };
810 if x.len() < x_len || y.len() < y_len {
811 return Err(Error::InvalidArgument(format!(
812 "ellmv: x.len()={}, y.len()={}, expected ({x_len}, {y_len})",
813 x.len(),
814 y.len()
815 )));
816 }
817 let status = unsafe {
818 sys::$ellmv(
819 trans_raw(op),
820 &alpha,
821 m as sys::aoclsparse_int,
822 n as sys::aoclsparse_int,
823 nnz as sys::aoclsparse_int,
824 ell_val.as_ptr(),
825 ell_col_ind.as_ptr(),
826 ell_width as sys::aoclsparse_int,
827 descr.as_raw(),
828 x.as_ptr(),
829 &beta,
830 y.as_mut_ptr(),
831 )
832 };
833 check_status("sparse", status)
834 }
835
836 #[allow(clippy::too_many_arguments)]
837 fn bsrmv(
838 op: Trans,
839 alpha: Self,
840 mb: usize,
841 nb: usize,
842 bsr_dim: usize,
843 bsr_val: &[Self],
844 bsr_col_ind: &[sys::aoclsparse_int],
845 bsr_row_ptr: &[sys::aoclsparse_int],
846 descr: &MatDescr,
847 x: &[Self],
848 beta: Self,
849 y: &mut [Self],
850 ) -> Result<()> {
851 if bsr_row_ptr.len() != mb + 1 {
852 return Err(Error::InvalidArgument(format!(
853 "bsrmv: bsr_row_ptr length {} != mb+1 = {}",
854 bsr_row_ptr.len(),
855 mb + 1
856 )));
857 }
858 let block_area = bsr_dim.checked_mul(bsr_dim).ok_or_else(|| {
859 Error::InvalidArgument("bsrmv: bsr_dim*bsr_dim overflows".into())
860 })?;
861 let nnzb = bsr_col_ind.len();
862 if bsr_val.len() < nnzb * block_area {
863 return Err(Error::InvalidArgument(format!(
864 "bsrmv: bsr_val length {} < nnzb*bsr_dim^2 = {}",
865 bsr_val.len(),
866 nnzb * block_area
867 )));
868 }
869 let (x_len, y_len) = match op {
870 Trans::No => (nb * bsr_dim, mb * bsr_dim),
871 Trans::T | Trans::C => (mb * bsr_dim, nb * bsr_dim),
872 };
873 if x.len() < x_len || y.len() < y_len {
874 return Err(Error::InvalidArgument(format!(
875 "bsrmv: x.len()={}, y.len()={}, expected ({x_len}, {y_len})",
876 x.len(),
877 y.len()
878 )));
879 }
880 let status = unsafe {
881 sys::$bsrmv(
882 trans_raw(op),
883 &alpha,
884 mb as sys::aoclsparse_int,
885 nb as sys::aoclsparse_int,
886 bsr_dim as sys::aoclsparse_int,
887 bsr_val.as_ptr(),
888 bsr_col_ind.as_ptr(),
889 bsr_row_ptr.as_ptr(),
890 descr.as_raw(),
891 x.as_ptr(),
892 &beta,
893 y.as_mut_ptr(),
894 )
895 };
896 check_status("sparse", status)
897 }
898
899 fn create_csr(
900 base: IndexBase,
901 m: usize,
902 n: usize,
903 nnz: usize,
904 row_ptr: *mut sys::aoclsparse_int,
905 col_idx: *mut sys::aoclsparse_int,
906 val: *mut Self,
907 ) -> Result<sys::aoclsparse_matrix> {
908 let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
909 let status = unsafe {
910 sys::$create_csr(
911 &mut raw,
912 base.raw(),
913 m as sys::aoclsparse_int,
914 n as sys::aoclsparse_int,
915 nnz as sys::aoclsparse_int,
916 row_ptr,
917 col_idx,
918 val,
919 )
920 };
921 check_status("sparse", status)?;
922 if raw.is_null() {
923 return Err(Error::AllocationFailed("sparse"));
924 }
925 Ok(raw)
926 }
927
928 fn export_csr(
929 mat: sys::aoclsparse_matrix,
930 ) -> Result<(
931 IndexBase,
932 usize,
933 usize,
934 usize,
935 *mut sys::aoclsparse_int,
936 *mut sys::aoclsparse_int,
937 *mut Self,
938 )> {
939 let mut base: sys::aoclsparse_index_base = 0;
940 let mut m: sys::aoclsparse_int = 0;
941 let mut n: sys::aoclsparse_int = 0;
942 let mut nnz: sys::aoclsparse_int = 0;
943 let mut row_ptr: *mut sys::aoclsparse_int = std::ptr::null_mut();
944 let mut col_ind: *mut sys::aoclsparse_int = std::ptr::null_mut();
945 let mut val: *mut Self = std::ptr::null_mut();
946 let status = unsafe {
947 sys::$export_csr(
948 mat,
949 &mut base,
950 &mut m,
951 &mut n,
952 &mut nnz,
953 &mut row_ptr,
954 &mut col_ind,
955 &mut val,
956 )
957 };
958 check_status("sparse", status)?;
959 let base_e = if base == sys::aoclsparse_index_base__aoclsparse_index_base_one {
960 IndexBase::One
961 } else {
962 IndexBase::Zero
963 };
964 Ok((
965 base_e,
966 m as usize,
967 n as usize,
968 nnz as usize,
969 row_ptr,
970 col_ind,
971 val,
972 ))
973 }
974
975 fn ilu_smoother(
976 op: Trans,
977 a: sys::aoclsparse_matrix,
978 descr: &MatDescr,
979 x: &mut [Self],
980 b: &[Self],
981 ) -> Result<()> {
982 let mut precond_csr_val: *mut Self = std::ptr::null_mut();
983 let status = unsafe {
984 sys::$ilu_smoother(
985 trans_raw(op),
986 a,
987 descr.as_raw(),
988 &mut precond_csr_val,
989 std::ptr::null(),
990 x.as_mut_ptr(),
991 b.as_ptr(),
992 )
993 };
994 check_status("sparse", status)
995 }
996
997 fn itsol_init(handle: &mut sys::aoclsparse_itsol_handle) -> Result<()> {
998 let status = unsafe { sys::$itsol_init(handle) };
999 check_status("sparse", status)
1000 }
1001
1002 fn itsol_solve(
1003 handle: sys::aoclsparse_itsol_handle,
1004 n: usize,
1005 mat: sys::aoclsparse_matrix,
1006 descr: &MatDescr,
1007 b: &[Self],
1008 x: &mut [Self],
1009 rinfo: &mut [Self; 100],
1010 ) -> Result<()> {
1011 if b.len() < n || x.len() < n {
1012 return Err(Error::InvalidArgument(format!(
1013 "itsol_solve: b.len()={}, x.len()={}, n={n}",
1014 b.len(),
1015 x.len()
1016 )));
1017 }
1018 let status = unsafe {
1019 sys::$itsol_solve(
1020 handle,
1021 n as sys::aoclsparse_int,
1022 mat,
1023 descr.as_raw(),
1024 b.as_ptr(),
1025 x.as_mut_ptr(),
1026 rinfo.as_mut_ptr(),
1027 None,
1028 None,
1029 std::ptr::null_mut(),
1030 )
1031 };
1032 check_status("sparse", status)
1033 }
1034
1035 unsafe fn csr2m_ffi(
1036 op_a: sys::aoclsparse_operation,
1037 descr_a: sys::aoclsparse_mat_descr,
1038 a: sys::aoclsparse_matrix,
1039 op_b: sys::aoclsparse_operation,
1040 descr_b: sys::aoclsparse_mat_descr,
1041 b: sys::aoclsparse_matrix,
1042 request: sys::aoclsparse_request,
1043 out: *mut sys::aoclsparse_matrix,
1044 ) -> sys::aoclsparse_status {
1045 sys::$csr2m(op_a, descr_a, a, op_b, descr_b, b, request, out)
1046 }
1047
1048 #[allow(clippy::too_many_arguments)]
1049 fn csrmm(
1050 op: Trans,
1051 alpha: Self,
1052 a: sys::aoclsparse_matrix,
1053 descr: &MatDescr,
1054 order: Order,
1055 b: &[Self],
1056 n: usize,
1057 ldb: usize,
1058 beta: Self,
1059 c: &mut [Self],
1060 ldc: usize,
1061 ) -> Result<()> {
1062 let status = unsafe {
1063 sys::$csrmm(
1064 trans_raw(op),
1065 alpha,
1066 a,
1067 descr.as_raw(),
1068 order.raw(),
1069 b.as_ptr(),
1070 n as sys::aoclsparse_int,
1071 ldb as sys::aoclsparse_int,
1072 beta,
1073 c.as_mut_ptr(),
1074 ldc as sys::aoclsparse_int,
1075 )
1076 };
1077 check_status("sparse", status)
1078 }
1079
1080 fn spmmd(
1081 op: Trans,
1082 a: sys::aoclsparse_matrix,
1083 b: sys::aoclsparse_matrix,
1084 layout: Order,
1085 c: &mut [Self],
1086 ldc: usize,
1087 ) -> Result<()> {
1088 let status = unsafe {
1089 sys::$spmmd(
1090 trans_raw(op),
1091 a,
1092 b,
1093 layout.raw(),
1094 c.as_mut_ptr(),
1095 ldc as sys::aoclsparse_int,
1096 )
1097 };
1098 check_status("sparse", status)
1099 }
1100
1101 #[allow(clippy::too_many_arguments)]
1102 fn sp2md(
1103 op_a: Trans,
1104 descr_a: &MatDescr,
1105 a: sys::aoclsparse_matrix,
1106 op_b: Trans,
1107 descr_b: &MatDescr,
1108 b: sys::aoclsparse_matrix,
1109 alpha: Self,
1110 beta: Self,
1111 c: &mut [Self],
1112 layout: Order,
1113 ldc: usize,
1114 ) -> Result<()> {
1115 let status = unsafe {
1116 sys::$sp2md(
1117 trans_raw(op_a),
1118 descr_a.as_raw(),
1119 a,
1120 trans_raw(op_b),
1121 descr_b.as_raw(),
1122 b,
1123 alpha,
1124 beta,
1125 c.as_mut_ptr(),
1126 layout.raw(),
1127 ldc as sys::aoclsparse_int,
1128 )
1129 };
1130 check_status("sparse", status)
1131 }
1132
1133 unsafe fn add_ffi(
1134 op: sys::aoclsparse_operation,
1135 a: sys::aoclsparse_matrix,
1136 alpha: Self,
1137 b: sys::aoclsparse_matrix,
1138 out: *mut sys::aoclsparse_matrix,
1139 ) -> sys::aoclsparse_status {
1140 sys::$add(op, a, alpha, b, out)
1141 }
1142
1143 #[allow(clippy::too_many_arguments)]
1144 fn sorv(
1145 sor_type: SorType,
1146 descr: &MatDescr,
1147 a: sys::aoclsparse_matrix,
1148 omega: Self,
1149 alpha: Self,
1150 x: &mut [Self],
1151 b: &[Self],
1152 ) -> Result<()> {
1153 let status = unsafe {
1154 sys::$sorv(
1155 sor_type.raw(),
1156 descr.as_raw(),
1157 a,
1158 omega,
1159 alpha,
1160 x.as_mut_ptr(),
1161 b.as_ptr(),
1162 )
1163 };
1164 check_status("sparse", status)
1165 }
1166 }
1167 };
1168}
1169
// Single-precision (`f32`) bindings: each trait method is wired to the `s`
// variant of the corresponding AOCL-Sparse symbol.
impl_scalar!(
    f32,
    csrmv = aoclsparse_scsrmv,
    axpyi = aoclsparse_saxpyi,
    gthr = aoclsparse_sgthr,
    sctr = aoclsparse_ssctr,
    csrsv = aoclsparse_scsrsv,
    csr2dense = aoclsparse_scsr2dense,
    csr2csc = aoclsparse_scsr2csc,
    ellmv = aoclsparse_sellmv,
    bsrmv = aoclsparse_sbsrmv,
    create_csr = aoclsparse_create_scsr,
    export_csr = aoclsparse_export_scsr,
    ilu_smoother = aoclsparse_silu_smoother,
    itsol_init = aoclsparse_itsol_s_init,
    itsol_solve = aoclsparse_itsol_s_solve,
    csr2m = aoclsparse_scsr2m,
    csrmm = aoclsparse_scsrmm,
    spmmd = aoclsparse_sspmmd,
    sp2md = aoclsparse_ssp2md,
    add = aoclsparse_sadd,
    sorv = aoclsparse_ssorv
);
// Double-precision (`f64`) bindings: the same mapping using the `d` variants.
impl_scalar!(
    f64,
    csrmv = aoclsparse_dcsrmv,
    axpyi = aoclsparse_daxpyi,
    gthr = aoclsparse_dgthr,
    sctr = aoclsparse_dsctr,
    csrsv = aoclsparse_dcsrsv,
    csr2dense = aoclsparse_dcsr2dense,
    csr2csc = aoclsparse_dcsr2csc,
    ellmv = aoclsparse_dellmv,
    bsrmv = aoclsparse_dbsrmv,
    create_csr = aoclsparse_create_dcsr,
    export_csr = aoclsparse_export_dcsr,
    ilu_smoother = aoclsparse_dilu_smoother,
    itsol_init = aoclsparse_itsol_d_init,
    itsol_solve = aoclsparse_itsol_d_solve,
    csr2m = aoclsparse_dcsr2m,
    csrmm = aoclsparse_dcsrmm,
    spmmd = aoclsparse_dspmmd,
    sp2md = aoclsparse_dsp2md,
    add = aoclsparse_dadd,
    sorv = aoclsparse_dsorv
);
1216
1217#[allow(clippy::too_many_arguments)]
1219pub fn csrmv<T: Scalar>(
1220 alpha: T,
1221 m: usize,
1222 n: usize,
1223 csr_val: &[T],
1224 csr_col_ind: &[sys::aoclsparse_int],
1225 csr_row_ptr: &[sys::aoclsparse_int],
1226 descr: &MatDescr,
1227 x: &[T],
1228 beta: T,
1229 y: &mut [T],
1230) -> Result<()> {
1231 T::csrmv(
1232 Trans::No,
1233 alpha,
1234 m,
1235 n,
1236 csr_val,
1237 csr_col_ind,
1238 csr_row_ptr,
1239 descr,
1240 x,
1241 beta,
1242 y,
1243 )
1244}
1245
1246pub fn axpyi<T: Scalar>(
1255 alpha: T,
1256 x: &[T],
1257 indx: &[sys::aoclsparse_int],
1258 y: &mut [T],
1259) -> Result<()> {
1260 if x.len() != indx.len() {
1261 return Err(Error::InvalidArgument(format!(
1262 "axpyi: x.len()={}, indx.len()={}",
1263 x.len(),
1264 indx.len()
1265 )));
1266 }
1267 T::axpyi(alpha, x, indx, y)
1268}
1269
1270pub fn gthr<T: Scalar>(y: &[T], indx: &[sys::aoclsparse_int], x: &mut [T]) -> Result<()> {
1272 if x.len() != indx.len() {
1273 return Err(Error::InvalidArgument(format!(
1274 "gthr: x.len()={}, indx.len()={}",
1275 x.len(),
1276 indx.len()
1277 )));
1278 }
1279 T::gthr(y, indx, x)
1280}
1281
1282pub fn sctr<T: Scalar>(x: &[T], indx: &[sys::aoclsparse_int], y: &mut [T]) -> Result<()> {
1284 if x.len() != indx.len() {
1285 return Err(Error::InvalidArgument(format!(
1286 "sctr: x.len()={}, indx.len()={}",
1287 x.len(),
1288 indx.len()
1289 )));
1290 }
1291 T::sctr(x, indx, y)
1292}
1293
1294#[allow(clippy::too_many_arguments)]
1302pub fn csrsv<T: Scalar>(
1303 op: Trans,
1304 alpha: T,
1305 m: usize,
1306 csr_val: &[T],
1307 csr_col_ind: &[sys::aoclsparse_int],
1308 csr_row_ptr: &[sys::aoclsparse_int],
1309 descr: &MatDescr,
1310 x: &[T],
1311 y: &mut [T],
1312) -> Result<()> {
1313 T::csrsv(op, alpha, m, csr_val, csr_col_ind, csr_row_ptr, descr, x, y)
1314}
1315
1316#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1322pub enum Order {
1323 RowMajor,
1324 ColMajor,
1325}
1326
1327impl Order {
1328 pub(crate) fn raw(self) -> sys::aoclsparse_order {
1329 match self {
1330 Order::RowMajor => sys::aoclsparse_order__aoclsparse_order_row,
1331 Order::ColMajor => sys::aoclsparse_order__aoclsparse_order_column,
1332 }
1333 }
1334}
1335
1336#[allow(clippy::too_many_arguments)]
1338pub fn csr_to_dense<T: Scalar>(
1339 m: usize,
1340 n: usize,
1341 descr: &MatDescr,
1342 csr_val: &[T],
1343 csr_row_ptr: &[sys::aoclsparse_int],
1344 csr_col_ind: &[sys::aoclsparse_int],
1345 a: &mut [T],
1346 ld: usize,
1347 order: Order,
1348) -> Result<()> {
1349 if csr_row_ptr.len() != m + 1 {
1350 return Err(Error::InvalidArgument(format!(
1351 "csr_to_dense: csr_row_ptr length {} != m+1 = {}",
1352 csr_row_ptr.len(),
1353 m + 1
1354 )));
1355 }
1356 let needed = match order {
1357 Order::RowMajor => m.saturating_sub(1) * ld + n,
1358 Order::ColMajor => n.saturating_sub(1) * ld + m,
1359 };
1360 if a.len() < needed {
1361 return Err(Error::InvalidArgument(format!(
1362 "csr_to_dense: A length {} < needed {needed}",
1363 a.len()
1364 )));
1365 }
1366 T::csr_to_dense(m, n, descr, csr_val, csr_row_ptr, csr_col_ind, a, ld, order)
1367}
1368
1369#[allow(clippy::too_many_arguments)]
1371pub fn csr_to_csc<T: Scalar>(
1372 m: usize,
1373 n: usize,
1374 descr: &MatDescr,
1375 base_csc: IndexBase,
1376 csr_row_ptr: &[sys::aoclsparse_int],
1377 csr_col_ind: &[sys::aoclsparse_int],
1378 csr_val: &[T],
1379 csc_row_ind: &mut [sys::aoclsparse_int],
1380 csc_col_ptr: &mut [sys::aoclsparse_int],
1381 csc_val: &mut [T],
1382) -> Result<()> {
1383 let nnz = csr_val.len();
1384 if csr_col_ind.len() != nnz || csc_row_ind.len() < nnz || csc_val.len() < nnz {
1385 return Err(Error::InvalidArgument(format!(
1386 "csr_to_csc: nnz mismatch (csr_val={}, csr_col_ind={}, csc_row_ind={}, csc_val={})",
1387 nnz,
1388 csr_col_ind.len(),
1389 csc_row_ind.len(),
1390 csc_val.len()
1391 )));
1392 }
1393 if csr_row_ptr.len() != m + 1 || csc_col_ptr.len() != n + 1 {
1394 return Err(Error::InvalidArgument(format!(
1395 "csr_to_csc: row_ptr length {} != m+1 = {} or col_ptr length {} != n+1 = {}",
1396 csr_row_ptr.len(),
1397 m + 1,
1398 csc_col_ptr.len(),
1399 n + 1
1400 )));
1401 }
1402 T::csr_to_csc(
1403 m,
1404 n,
1405 descr,
1406 base_csc,
1407 csr_row_ptr,
1408 csr_col_ind,
1409 csr_val,
1410 csc_row_ind,
1411 csc_col_ptr,
1412 csc_val,
1413 )
1414}
1415
/// Base of the row/column indices stored in sparse index arrays.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IndexBase {
    /// The first index is 0.
    Zero,
    /// The first index is 1.
    One,
}
1424
1425impl IndexBase {
1426 fn raw(self) -> sys::aoclsparse_index_base {
1427 match self {
1428 IndexBase::Zero => sys::aoclsparse_index_base__aoclsparse_index_base_zero,
1429 IndexBase::One => sys::aoclsparse_index_base__aoclsparse_index_base_one,
1430 }
1431 }
1432}
1433
1434#[allow(clippy::too_many_arguments)]
1445pub fn ellmv<T: Scalar>(
1446 op: Trans,
1447 alpha: T,
1448 m: usize,
1449 n: usize,
1450 ell_val: &[T],
1451 ell_col_ind: &[sys::aoclsparse_int],
1452 ell_width: usize,
1453 descr: &MatDescr,
1454 x: &[T],
1455 beta: T,
1456 y: &mut [T],
1457) -> Result<()> {
1458 T::ellmv(
1459 op,
1460 alpha,
1461 m,
1462 n,
1463 ell_val,
1464 ell_col_ind,
1465 ell_width,
1466 descr,
1467 x,
1468 beta,
1469 y,
1470 )
1471}
1472
1473#[allow(clippy::too_many_arguments)]
1480pub fn bsrmv<T: Scalar>(
1481 op: Trans,
1482 alpha: T,
1483 mb: usize,
1484 nb: usize,
1485 bsr_dim: usize,
1486 bsr_val: &[T],
1487 bsr_col_ind: &[sys::aoclsparse_int],
1488 bsr_row_ptr: &[sys::aoclsparse_int],
1489 descr: &MatDescr,
1490 x: &[T],
1491 beta: T,
1492 y: &mut [T],
1493) -> Result<()> {
1494 T::bsrmv(
1495 op,
1496 alpha,
1497 mb,
1498 nb,
1499 bsr_dim,
1500 bsr_val,
1501 bsr_col_ind,
1502 bsr_row_ptr,
1503 descr,
1504 x,
1505 beta,
1506 y,
1507 )
1508}
1509
/// Selects which stage of a staged sparse routine (for example [`csr2m`]) to run.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Stage {
    /// Maps to `aoclsparse_stage_nnz_count`.
    NnzCount,
    /// Maps to `aoclsparse_stage_finalize`.
    Finalize,
    /// Maps to `aoclsparse_stage_full_computation`.
    FullComputation,
}
1524
1525impl Stage {
1526 fn raw(self) -> sys::aoclsparse_request {
1527 match self {
1528 Stage::NnzCount => sys::aoclsparse_request__aoclsparse_stage_nnz_count,
1529 Stage::Finalize => sys::aoclsparse_request__aoclsparse_stage_finalize,
1530 Stage::FullComputation => sys::aoclsparse_request__aoclsparse_stage_full_computation,
1531 }
1532 }
1533}
1534
1535enum CsrStorage<T: Scalar> {
1536 Owned {
1538 _row_ptr: Vec<sys::aoclsparse_int>,
1539 _col_ind: Vec<sys::aoclsparse_int>,
1540 _val: Vec<T>,
1541 },
1542 LibraryOwned,
1544}
1545
1546pub struct SparseMatrix<T: Scalar> {
1552 raw: sys::aoclsparse_matrix,
1553 #[allow(dead_code)] storage: CsrStorage<T>,
1555 base: IndexBase,
1556 m: usize,
1557 n: usize,
1558 nnz: usize,
1559}
1560
1561impl<T: Scalar> SparseMatrix<T> {
1562 pub fn from_csr(
1565 base: IndexBase,
1566 m: usize,
1567 n: usize,
1568 row_ptr: &[sys::aoclsparse_int],
1569 col_ind: &[sys::aoclsparse_int],
1570 val: &[T],
1571 ) -> Result<Self> {
1572 if row_ptr.len() != m + 1 {
1573 return Err(Error::InvalidArgument(format!(
1574 "from_csr: row_ptr length {} != m+1 = {}",
1575 row_ptr.len(),
1576 m + 1
1577 )));
1578 }
1579 let nnz = val.len();
1580 if col_ind.len() != nnz {
1581 return Err(Error::InvalidArgument(format!(
1582 "from_csr: col_ind length {} != val length {nnz}",
1583 col_ind.len()
1584 )));
1585 }
1586 let mut row_ptr = row_ptr.to_vec();
1587 let mut col_ind = col_ind.to_vec();
1588 let mut val = val.to_vec();
1589 let raw = T::create_csr(
1590 base,
1591 m,
1592 n,
1593 nnz,
1594 row_ptr.as_mut_ptr(),
1595 col_ind.as_mut_ptr(),
1596 val.as_mut_ptr(),
1597 )?;
1598 Ok(Self {
1599 raw,
1600 storage: CsrStorage::Owned {
1601 _row_ptr: row_ptr,
1602 _col_ind: col_ind,
1603 _val: val,
1604 },
1605 base,
1606 m,
1607 n,
1608 nnz,
1609 })
1610 }
1611
1612 pub unsafe fn from_library_owned(raw: sys::aoclsparse_matrix) -> Result<Self> {
1620 if raw.is_null() {
1621 return Err(Error::AllocationFailed("sparse"));
1622 }
1623 let (base, m, n, nnz, _, _, _) = T::export_csr(raw)?;
1624 Ok(Self {
1625 raw,
1626 storage: CsrStorage::LibraryOwned,
1627 base,
1628 m,
1629 n,
1630 nnz,
1631 })
1632 }
1633
1634 pub fn dims(&self) -> (usize, usize) {
1636 (self.m, self.n)
1637 }
1638
1639 pub fn nnz(&self) -> usize {
1641 self.nnz
1642 }
1643
1644 pub fn base(&self) -> IndexBase {
1646 self.base
1647 }
1648
1649 pub fn as_raw(&self) -> sys::aoclsparse_matrix {
1655 self.raw
1656 }
1657
1658 pub fn export_csr(
1660 &self,
1661 ) -> Result<(
1662 IndexBase,
1663 Vec<sys::aoclsparse_int>,
1664 Vec<sys::aoclsparse_int>,
1665 Vec<T>,
1666 )> {
1667 let (base, m, _, nnz, row_ptr, col_ind, val) = T::export_csr(self.raw)?;
1668 let row_ptr = unsafe { std::slice::from_raw_parts(row_ptr, m + 1).to_vec() };
1669 let col_ind = unsafe { std::slice::from_raw_parts(col_ind, nnz).to_vec() };
1670 let val = unsafe { std::slice::from_raw_parts(val, nnz).to_vec() };
1671 Ok((base, row_ptr, col_ind, val))
1672 }
1673
1674 pub fn set_mv_hint(
1680 &mut self,
1681 op: Trans,
1682 descr: &MatDescr,
1683 expected_calls: usize,
1684 ) -> Result<()> {
1685 let status = unsafe {
1686 sys::aoclsparse_set_mv_hint(
1687 self.raw,
1688 trans_raw(op),
1689 descr.as_raw(),
1690 expected_calls as sys::aoclsparse_int,
1691 )
1692 };
1693 check_status("sparse", status)
1694 }
1695
1696 pub fn set_sv_hint(
1698 &mut self,
1699 op: Trans,
1700 descr: &MatDescr,
1701 expected_calls: usize,
1702 ) -> Result<()> {
1703 let status = unsafe {
1704 sys::aoclsparse_set_sv_hint(
1705 self.raw,
1706 trans_raw(op),
1707 descr.as_raw(),
1708 expected_calls as sys::aoclsparse_int,
1709 )
1710 };
1711 check_status("sparse", status)
1712 }
1713
1714 pub fn set_mm_hint(
1716 &mut self,
1717 op: Trans,
1718 descr: &MatDescr,
1719 expected_calls: usize,
1720 ) -> Result<()> {
1721 let status = unsafe {
1722 sys::aoclsparse_set_mm_hint(
1723 self.raw,
1724 trans_raw(op),
1725 descr.as_raw(),
1726 expected_calls as sys::aoclsparse_int,
1727 )
1728 };
1729 check_status("sparse", status)
1730 }
1731
1732 pub fn set_2m_hint(
1734 &mut self,
1735 op: Trans,
1736 descr: &MatDescr,
1737 expected_calls: usize,
1738 ) -> Result<()> {
1739 let status = unsafe {
1740 sys::aoclsparse_set_2m_hint(
1741 self.raw,
1742 trans_raw(op),
1743 descr.as_raw(),
1744 expected_calls as sys::aoclsparse_int,
1745 )
1746 };
1747 check_status("sparse", status)
1748 }
1749
1750 pub fn set_sm_hint(
1753 &mut self,
1754 op: Trans,
1755 descr: &MatDescr,
1756 order: Order,
1757 expected_calls: usize,
1758 ) -> Result<()> {
1759 let status = unsafe {
1760 sys::aoclsparse_set_sm_hint(
1761 self.raw,
1762 trans_raw(op),
1763 descr.as_raw(),
1764 order.raw(),
1765 expected_calls as sys::aoclsparse_int,
1766 )
1767 };
1768 check_status("sparse", status)
1769 }
1770
1771 pub fn set_lu_smoother_hint(
1773 &mut self,
1774 op: Trans,
1775 descr: &MatDescr,
1776 expected_calls: usize,
1777 ) -> Result<()> {
1778 let status = unsafe {
1779 sys::aoclsparse_set_lu_smoother_hint(
1780 self.raw,
1781 trans_raw(op),
1782 descr.as_raw(),
1783 expected_calls as sys::aoclsparse_int,
1784 )
1785 };
1786 check_status("sparse", status)
1787 }
1788
1789 pub fn set_symgs_hint(
1791 &mut self,
1792 op: Trans,
1793 descr: &MatDescr,
1794 expected_calls: usize,
1795 ) -> Result<()> {
1796 let status = unsafe {
1797 sys::aoclsparse_set_symgs_hint(
1798 self.raw,
1799 trans_raw(op),
1800 descr.as_raw(),
1801 expected_calls as sys::aoclsparse_int,
1802 )
1803 };
1804 check_status("sparse", status)
1805 }
1806
1807 pub fn set_dotmv_hint(
1809 &mut self,
1810 op: Trans,
1811 descr: &MatDescr,
1812 expected_calls: usize,
1813 ) -> Result<()> {
1814 let status = unsafe {
1815 sys::aoclsparse_set_dotmv_hint(
1816 self.raw,
1817 trans_raw(op),
1818 descr.as_raw(),
1819 expected_calls as sys::aoclsparse_int,
1820 )
1821 };
1822 check_status("sparse", status)
1823 }
1824
1825 pub fn set_sorv_hint(
1827 &mut self,
1828 sor_type: SorType,
1829 descr: &MatDescr,
1830 expected_calls: usize,
1831 ) -> Result<()> {
1832 let status = unsafe {
1833 sys::aoclsparse_set_sorv_hint(
1834 self.raw,
1835 descr.as_raw(),
1836 sor_type.raw(),
1837 expected_calls as sys::aoclsparse_int,
1838 )
1839 };
1840 check_status("sparse", status)
1841 }
1842}
1843
1844impl<T: Scalar> Drop for SparseMatrix<T> {
1845 fn drop(&mut self) {
1846 if !self.raw.is_null() {
1847 unsafe {
1848 let _ = sys::aoclsparse_destroy(&mut self.raw);
1849 }
1850 self.raw = std::ptr::null_mut();
1851 }
1852 }
1853}
1854
1855impl<T: Scalar> std::fmt::Debug for SparseMatrix<T> {
1856 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1857 f.debug_struct("SparseMatrix")
1858 .field("m", &self.m)
1859 .field("n", &self.n)
1860 .field("nnz", &self.nnz)
1861 .field("base", &self.base)
1862 .finish()
1863 }
1864}
1865
1866#[allow(clippy::too_many_arguments)]
1878pub fn csr2m<T: Scalar>(
1879 op_a: Trans,
1880 descr_a: &MatDescr,
1881 a: &SparseMatrix<T>,
1882 op_b: Trans,
1883 descr_b: &MatDescr,
1884 b: &SparseMatrix<T>,
1885 stage: Stage,
1886) -> Result<SparseMatrix<T>> {
1887 let mut c_raw: sys::aoclsparse_matrix = std::ptr::null_mut();
1888 let status = unsafe {
1889 T::csr2m_ffi(
1892 trans_raw(op_a),
1893 descr_a.as_raw(),
1894 a.raw,
1895 trans_raw(op_b),
1896 descr_b.as_raw(),
1897 b.raw,
1898 stage.raw(),
1899 &mut c_raw,
1900 )
1901 };
1902 check_status("sparse", status)?;
1903 unsafe { SparseMatrix::from_library_owned(c_raw) }
1904}
1905
1906#[allow(clippy::too_many_arguments)]
1916pub fn csrmm<T: Scalar>(
1917 op: Trans,
1918 alpha: T,
1919 a: &SparseMatrix<T>,
1920 descr: &MatDescr,
1921 order: Order,
1922 b: &[T],
1923 n: usize,
1924 ldb: usize,
1925 beta: T,
1926 c: &mut [T],
1927 ldc: usize,
1928) -> Result<()> {
1929 T::csrmm(op, alpha, a.as_raw(), descr, order, b, n, ldb, beta, c, ldc)
1930}
1931
1932pub fn spmmd<T: Scalar>(
1935 op: Trans,
1936 a: &SparseMatrix<T>,
1937 b: &SparseMatrix<T>,
1938 layout: Order,
1939 c: &mut [T],
1940 ldc: usize,
1941) -> Result<()> {
1942 T::spmmd(op, a.as_raw(), b.as_raw(), layout, c, ldc)
1943}
1944
1945#[allow(clippy::too_many_arguments)]
1948pub fn sp2md<T: Scalar>(
1949 op_a: Trans,
1950 descr_a: &MatDescr,
1951 a: &SparseMatrix<T>,
1952 op_b: Trans,
1953 descr_b: &MatDescr,
1954 b: &SparseMatrix<T>,
1955 alpha: T,
1956 beta: T,
1957 c: &mut [T],
1958 layout: Order,
1959 ldc: usize,
1960) -> Result<()> {
1961 T::sp2md(
1962 op_a,
1963 descr_a,
1964 a.as_raw(),
1965 op_b,
1966 descr_b,
1967 b.as_raw(),
1968 alpha,
1969 beta,
1970 c,
1971 layout,
1972 ldc,
1973 )
1974}
1975
1976pub fn add<T: Scalar>(
1982 op: Trans,
1983 a: &SparseMatrix<T>,
1984 alpha: T,
1985 b: &SparseMatrix<T>,
1986) -> Result<SparseMatrix<T>> {
1987 let mut c_raw: sys::aoclsparse_matrix = std::ptr::null_mut();
1988 let status = unsafe { T::add_ffi(trans_raw(op), a.as_raw(), alpha, b.as_raw(), &mut c_raw) };
1989 check_status("sparse", status)?;
1990 unsafe { SparseMatrix::from_library_owned(c_raw) }
1991}
1992
1993pub fn sorv<T: Scalar>(
1997 sor_type: SorType,
1998 descr: &MatDescr,
1999 a: &SparseMatrix<T>,
2000 omega: T,
2001 alpha: T,
2002 x: &mut [T],
2003 b: &[T],
2004) -> Result<()> {
2005 if x.len() < a.dims().1 || b.len() < a.dims().0 {
2006 return Err(Error::InvalidArgument(format!(
2007 "sorv: x.len()={}, b.len()={}, dims=({}, {})",
2008 x.len(),
2009 b.len(),
2010 a.dims().0,
2011 a.dims().1
2012 )));
2013 }
2014 T::sorv(sor_type, descr, a.as_raw(), omega, alpha, x, b)
2015}
2016
2017#[allow(clippy::too_many_arguments)]
2029pub fn mv_f64(
2030 op: Trans,
2031 alpha: f64,
2032 a: &SparseMatrix<f64>,
2033 descr: &MatDescr,
2034 x: &[f64],
2035 beta: f64,
2036 y: &mut [f64],
2037) -> Result<()> {
2038 let status = unsafe {
2039 sys::aoclsparse_dmv(
2040 trans_raw(op),
2041 &alpha,
2042 a.as_raw(),
2043 descr.as_raw(),
2044 x.as_ptr(),
2045 &beta,
2046 y.as_mut_ptr(),
2047 )
2048 };
2049 check_status("sparse", status)
2050}
2051
2052#[allow(clippy::too_many_arguments)]
2054pub fn mv_f32(
2055 op: Trans,
2056 alpha: f32,
2057 a: &SparseMatrix<f32>,
2058 descr: &MatDescr,
2059 x: &[f32],
2060 beta: f32,
2061 y: &mut [f32],
2062) -> Result<()> {
2063 let status = unsafe {
2064 sys::aoclsparse_smv(
2065 trans_raw(op),
2066 &alpha,
2067 a.as_raw(),
2068 descr.as_raw(),
2069 x.as_ptr(),
2070 &beta,
2071 y.as_mut_ptr(),
2072 )
2073 };
2074 check_status("sparse", status)
2075}
2076
2077pub fn trsv_f64(
2081 op: Trans,
2082 alpha: f64,
2083 a: &SparseMatrix<f64>,
2084 descr: &MatDescr,
2085 b: &[f64],
2086 x: &mut [f64],
2087) -> Result<()> {
2088 let status = unsafe {
2089 sys::aoclsparse_dtrsv(
2090 trans_raw(op),
2091 alpha,
2092 a.as_raw(),
2093 descr.as_raw(),
2094 b.as_ptr(),
2095 x.as_mut_ptr(),
2096 )
2097 };
2098 check_status("sparse", status)
2099}
2100
2101pub fn trsv_f32(
2103 op: Trans,
2104 alpha: f32,
2105 a: &SparseMatrix<f32>,
2106 descr: &MatDescr,
2107 b: &[f32],
2108 x: &mut [f32],
2109) -> Result<()> {
2110 let status = unsafe {
2111 sys::aoclsparse_strsv(
2112 trans_raw(op),
2113 alpha,
2114 a.as_raw(),
2115 descr.as_raw(),
2116 b.as_ptr(),
2117 x.as_mut_ptr(),
2118 )
2119 };
2120 check_status("sparse", status)
2121}
2122
2123#[allow(clippy::too_many_arguments)]
2126pub fn trsm_f64(
2127 op: Trans,
2128 alpha: f64,
2129 a: &SparseMatrix<f64>,
2130 descr: &MatDescr,
2131 order: Order,
2132 b: &[f64],
2133 n_rhs: usize,
2134 ldb: usize,
2135 x: &mut [f64],
2136 ldx: usize,
2137) -> Result<()> {
2138 let status = unsafe {
2139 sys::aoclsparse_dtrsm(
2140 trans_raw(op),
2141 alpha,
2142 a.as_raw(),
2143 descr.as_raw(),
2144 order.raw(),
2145 b.as_ptr(),
2146 n_rhs as sys::aoclsparse_int,
2147 ldb as sys::aoclsparse_int,
2148 x.as_mut_ptr(),
2149 ldx as sys::aoclsparse_int,
2150 )
2151 };
2152 check_status("sparse", status)
2153}
2154
2155#[allow(clippy::too_many_arguments)]
2157pub fn trsm_f32(
2158 op: Trans,
2159 alpha: f32,
2160 a: &SparseMatrix<f32>,
2161 descr: &MatDescr,
2162 order: Order,
2163 b: &[f32],
2164 n_rhs: usize,
2165 ldb: usize,
2166 x: &mut [f32],
2167 ldx: usize,
2168) -> Result<()> {
2169 let status = unsafe {
2170 sys::aoclsparse_strsm(
2171 trans_raw(op),
2172 alpha,
2173 a.as_raw(),
2174 descr.as_raw(),
2175 order.raw(),
2176 b.as_ptr(),
2177 n_rhs as sys::aoclsparse_int,
2178 ldb as sys::aoclsparse_int,
2179 x.as_mut_ptr(),
2180 ldx as sys::aoclsparse_int,
2181 )
2182 };
2183 check_status("sparse", status)
2184}
2185
2186pub fn doti_f64(x: &[f64], indx: &[sys::aoclsparse_int], y: &[f64]) -> Result<f64> {
2189 if x.len() != indx.len() {
2190 return Err(Error::InvalidArgument(format!(
2191 "doti: x.len()={} != indx.len()={}",
2192 x.len(),
2193 indx.len()
2194 )));
2195 }
2196 let r = unsafe {
2197 sys::aoclsparse_ddoti(
2198 x.len() as sys::aoclsparse_int,
2199 x.as_ptr(),
2200 indx.as_ptr(),
2201 y.as_ptr(),
2202 )
2203 };
2204 Ok(r)
2205}
2206
2207pub fn doti_f32(x: &[f32], indx: &[sys::aoclsparse_int], y: &[f32]) -> Result<f32> {
2209 if x.len() != indx.len() {
2210 return Err(Error::InvalidArgument(format!(
2211 "doti: x.len()={} != indx.len()={}",
2212 x.len(),
2213 indx.len()
2214 )));
2215 }
2216 let r = unsafe {
2217 sys::aoclsparse_sdoti(
2218 x.len() as sys::aoclsparse_int,
2219 x.as_ptr(),
2220 indx.as_ptr(),
2221 y.as_ptr(),
2222 )
2223 };
2224 Ok(r)
2225}
2226
2227#[allow(clippy::too_many_arguments)]
2236pub fn csr2ell_f64(
2237 m: usize,
2238 descr: &MatDescr,
2239 csr_row_ptr: &[sys::aoclsparse_int],
2240 csr_col_ind: &[sys::aoclsparse_int],
2241 csr_val: &[f64],
2242 ell_col_ind: &mut [sys::aoclsparse_int],
2243 ell_val: &mut [f64],
2244 ell_width: usize,
2245) -> Result<()> {
2246 let status = unsafe {
2247 sys::aoclsparse_dcsr2ell(
2248 m as sys::aoclsparse_int,
2249 descr.as_raw(),
2250 csr_row_ptr.as_ptr(),
2251 csr_col_ind.as_ptr(),
2252 csr_val.as_ptr(),
2253 ell_col_ind.as_mut_ptr(),
2254 ell_val.as_mut_ptr(),
2255 ell_width as sys::aoclsparse_int,
2256 )
2257 };
2258 check_status("sparse", status)
2259}
2260
2261#[allow(clippy::too_many_arguments)]
2263pub fn csr2ell_f32(
2264 m: usize,
2265 descr: &MatDescr,
2266 csr_row_ptr: &[sys::aoclsparse_int],
2267 csr_col_ind: &[sys::aoclsparse_int],
2268 csr_val: &[f32],
2269 ell_col_ind: &mut [sys::aoclsparse_int],
2270 ell_val: &mut [f32],
2271 ell_width: usize,
2272) -> Result<()> {
2273 let status = unsafe {
2274 sys::aoclsparse_scsr2ell(
2275 m as sys::aoclsparse_int,
2276 descr.as_raw(),
2277 csr_row_ptr.as_ptr(),
2278 csr_col_ind.as_ptr(),
2279 csr_val.as_ptr(),
2280 ell_col_ind.as_mut_ptr(),
2281 ell_val.as_mut_ptr(),
2282 ell_width as sys::aoclsparse_int,
2283 )
2284 };
2285 check_status("sparse", status)
2286}
2287
2288#[allow(clippy::too_many_arguments)]
2292pub fn csr2dia_f64(
2293 m: usize,
2294 n: usize,
2295 descr: &MatDescr,
2296 csr_row_ptr: &[sys::aoclsparse_int],
2297 csr_col_ind: &[sys::aoclsparse_int],
2298 csr_val: &[f64],
2299 dia_num_diag: usize,
2300 dia_offset: &mut [sys::aoclsparse_int],
2301 dia_val: &mut [f64],
2302) -> Result<()> {
2303 let status = unsafe {
2304 sys::aoclsparse_dcsr2dia(
2305 m as sys::aoclsparse_int,
2306 n as sys::aoclsparse_int,
2307 descr.as_raw(),
2308 csr_row_ptr.as_ptr(),
2309 csr_col_ind.as_ptr(),
2310 csr_val.as_ptr(),
2311 dia_num_diag as sys::aoclsparse_int,
2312 dia_offset.as_mut_ptr(),
2313 dia_val.as_mut_ptr(),
2314 )
2315 };
2316 check_status("sparse", status)
2317}
2318
2319#[allow(clippy::too_many_arguments)]
2321pub fn csr2dia_f32(
2322 m: usize,
2323 n: usize,
2324 descr: &MatDescr,
2325 csr_row_ptr: &[sys::aoclsparse_int],
2326 csr_col_ind: &[sys::aoclsparse_int],
2327 csr_val: &[f32],
2328 dia_num_diag: usize,
2329 dia_offset: &mut [sys::aoclsparse_int],
2330 dia_val: &mut [f32],
2331) -> Result<()> {
2332 let status = unsafe {
2333 sys::aoclsparse_scsr2dia(
2334 m as sys::aoclsparse_int,
2335 n as sys::aoclsparse_int,
2336 descr.as_raw(),
2337 csr_row_ptr.as_ptr(),
2338 csr_col_ind.as_ptr(),
2339 csr_val.as_ptr(),
2340 dia_num_diag as sys::aoclsparse_int,
2341 dia_offset.as_mut_ptr(),
2342 dia_val.as_mut_ptr(),
2343 )
2344 };
2345 check_status("sparse", status)
2346}
2347
2348pub fn csr2bsr_nnz(
2351 m: usize,
2352 n: usize,
2353 descr: &MatDescr,
2354 csr_row_ptr: &[sys::aoclsparse_int],
2355 csr_col_ind: &[sys::aoclsparse_int],
2356 block_dim: usize,
2357 bsr_row_ptr: &mut [sys::aoclsparse_int],
2358) -> Result<usize> {
2359 let mut bsr_nnz: sys::aoclsparse_int = 0;
2360 let status = unsafe {
2361 sys::aoclsparse_csr2bsr_nnz(
2362 m as sys::aoclsparse_int,
2363 n as sys::aoclsparse_int,
2364 descr.as_raw(),
2365 csr_row_ptr.as_ptr(),
2366 csr_col_ind.as_ptr(),
2367 block_dim as sys::aoclsparse_int,
2368 bsr_row_ptr.as_mut_ptr(),
2369 &mut bsr_nnz,
2370 )
2371 };
2372 check_status("sparse", status)?;
2373 Ok(bsr_nnz as usize)
2374}
2375
2376#[allow(clippy::too_many_arguments)]
2378pub fn csr2bsr_f64(
2379 m: usize,
2380 n: usize,
2381 descr: &MatDescr,
2382 csr_val: &[f64],
2383 csr_row_ptr: &[sys::aoclsparse_int],
2384 csr_col_ind: &[sys::aoclsparse_int],
2385 block_dim: usize,
2386 bsr_val: &mut [f64],
2387 bsr_row_ptr: &mut [sys::aoclsparse_int],
2388 bsr_col_ind: &mut [sys::aoclsparse_int],
2389) -> Result<()> {
2390 let status = unsafe {
2391 sys::aoclsparse_dcsr2bsr(
2392 m as sys::aoclsparse_int,
2393 n as sys::aoclsparse_int,
2394 descr.as_raw(),
2395 csr_val.as_ptr(),
2396 csr_row_ptr.as_ptr(),
2397 csr_col_ind.as_ptr(),
2398 block_dim as sys::aoclsparse_int,
2399 bsr_val.as_mut_ptr(),
2400 bsr_row_ptr.as_mut_ptr(),
2401 bsr_col_ind.as_mut_ptr(),
2402 )
2403 };
2404 check_status("sparse", status)
2405}
2406
2407#[allow(clippy::too_many_arguments)]
2409pub fn csr2bsr_f32(
2410 m: usize,
2411 n: usize,
2412 descr: &MatDescr,
2413 csr_val: &[f32],
2414 csr_row_ptr: &[sys::aoclsparse_int],
2415 csr_col_ind: &[sys::aoclsparse_int],
2416 block_dim: usize,
2417 bsr_val: &mut [f32],
2418 bsr_row_ptr: &mut [sys::aoclsparse_int],
2419 bsr_col_ind: &mut [sys::aoclsparse_int],
2420) -> Result<()> {
2421 let status = unsafe {
2422 sys::aoclsparse_scsr2bsr(
2423 m as sys::aoclsparse_int,
2424 n as sys::aoclsparse_int,
2425 descr.as_raw(),
2426 csr_val.as_ptr(),
2427 csr_row_ptr.as_ptr(),
2428 csr_col_ind.as_ptr(),
2429 block_dim as sys::aoclsparse_int,
2430 bsr_val.as_mut_ptr(),
2431 bsr_row_ptr.as_mut_ptr(),
2432 bsr_col_ind.as_mut_ptr(),
2433 )
2434 };
2435 check_status("sparse", status)
2436}
2437
2438#[allow(clippy::too_many_arguments)]
2442pub fn blkcsrmv_f64(
2443 op: Trans,
2444 alpha: f64,
2445 m: usize,
2446 n: usize,
2447 masks: &[u8],
2448 csr_val: &[f64],
2449 csr_col_ind: &[sys::aoclsparse_int],
2450 csr_row_ptr: &[sys::aoclsparse_int],
2451 descr: &MatDescr,
2452 x: &[f64],
2453 beta: f64,
2454 y: &mut [f64],
2455 n_rows_blk: usize,
2456) -> Result<()> {
2457 let nnz = csr_val.len();
2458 let status = unsafe {
2459 sys::aoclsparse_dblkcsrmv(
2460 trans_raw(op),
2461 &alpha,
2462 m as sys::aoclsparse_int,
2463 n as sys::aoclsparse_int,
2464 nnz as sys::aoclsparse_int,
2465 masks.as_ptr(),
2466 csr_val.as_ptr(),
2467 csr_col_ind.as_ptr(),
2468 csr_row_ptr.as_ptr(),
2469 descr.as_raw(),
2470 x.as_ptr(),
2471 &beta,
2472 y.as_mut_ptr(),
2473 n_rows_blk as sys::aoclsparse_int,
2474 )
2475 };
2476 check_status("sparse", status)
2477}
2478
2479pub fn symgs_f64(
2482 op: Trans,
2483 a: &SparseMatrix<f64>,
2484 descr: &MatDescr,
2485 alpha: f64,
2486 b: &[f64],
2487 x: &mut [f64],
2488) -> Result<()> {
2489 let status = unsafe {
2490 sys::aoclsparse_dsymgs(
2491 trans_raw(op),
2492 a.as_raw(),
2493 descr.as_raw(),
2494 alpha,
2495 b.as_ptr(),
2496 x.as_mut_ptr(),
2497 )
2498 };
2499 check_status("sparse", status)
2500}
2501
2502pub fn symgs_f32(
2504 op: Trans,
2505 a: &SparseMatrix<f32>,
2506 descr: &MatDescr,
2507 alpha: f32,
2508 b: &[f32],
2509 x: &mut [f32],
2510) -> Result<()> {
2511 let status = unsafe {
2512 sys::aoclsparse_ssymgs(
2513 trans_raw(op),
2514 a.as_raw(),
2515 descr.as_raw(),
2516 alpha,
2517 b.as_ptr(),
2518 x.as_mut_ptr(),
2519 )
2520 };
2521 check_status("sparse", status)
2522}
2523
2524#[allow(clippy::too_many_arguments)]
2527pub fn symgs_mv_f64(
2528 op: Trans,
2529 a: &SparseMatrix<f64>,
2530 descr: &MatDescr,
2531 alpha: f64,
2532 b: &[f64],
2533 x: &mut [f64],
2534 y: &mut [f64],
2535) -> Result<()> {
2536 let status = unsafe {
2537 sys::aoclsparse_dsymgs_mv(
2538 trans_raw(op),
2539 a.as_raw(),
2540 descr.as_raw(),
2541 alpha,
2542 b.as_ptr(),
2543 x.as_mut_ptr(),
2544 y.as_mut_ptr(),
2545 )
2546 };
2547 check_status("sparse", status)
2548}
2549
2550#[allow(clippy::too_many_arguments)]
2552pub fn symgs_mv_f32(
2553 op: Trans,
2554 a: &SparseMatrix<f32>,
2555 descr: &MatDescr,
2556 alpha: f32,
2557 b: &[f32],
2558 x: &mut [f32],
2559 y: &mut [f32],
2560) -> Result<()> {
2561 let status = unsafe {
2562 sys::aoclsparse_ssymgs_mv(
2563 trans_raw(op),
2564 a.as_raw(),
2565 descr.as_raw(),
2566 alpha,
2567 b.as_ptr(),
2568 x.as_mut_ptr(),
2569 y.as_mut_ptr(),
2570 )
2571 };
2572 check_status("sparse", status)
2573}
2574
2575pub fn set_value_f64(
2578 a: &mut SparseMatrix<f64>,
2579 row_idx: i32,
2580 col_idx: i32,
2581 val: f64,
2582) -> Result<()> {
2583 let status = unsafe {
2584 sys::aoclsparse_dset_value(
2585 a.as_raw(),
2586 row_idx as sys::aoclsparse_int,
2587 col_idx as sys::aoclsparse_int,
2588 val,
2589 )
2590 };
2591 check_status("sparse", status)
2592}
2593
2594pub fn set_value_f32(
2596 a: &mut SparseMatrix<f32>,
2597 row_idx: i32,
2598 col_idx: i32,
2599 val: f32,
2600) -> Result<()> {
2601 let status = unsafe {
2602 sys::aoclsparse_sset_value(
2603 a.as_raw(),
2604 row_idx as sys::aoclsparse_int,
2605 col_idx as sys::aoclsparse_int,
2606 val,
2607 )
2608 };
2609 check_status("sparse", status)
2610}
2611
2612pub fn update_values_f64(a: &mut SparseMatrix<f64>, val: &mut [f64]) -> Result<()> {
2616 let status = unsafe {
2617 sys::aoclsparse_dupdate_values(
2618 a.as_raw(),
2619 val.len() as sys::aoclsparse_int,
2620 val.as_mut_ptr(),
2621 )
2622 };
2623 check_status("sparse", status)
2624}
2625
2626pub fn update_values_f32(a: &mut SparseMatrix<f32>, val: &mut [f32]) -> Result<()> {
2628 let status = unsafe {
2629 sys::aoclsparse_supdate_values(
2630 a.as_raw(),
2631 val.len() as sys::aoclsparse_int,
2632 val.as_mut_ptr(),
2633 )
2634 };
2635 check_status("sparse", status)
2636}
2637
2638#[allow(clippy::too_many_arguments)]
2647pub fn dotmv_f64(
2648 op: Trans,
2649 alpha: f64,
2650 a: &SparseMatrix<f64>,
2651 descr: &MatDescr,
2652 x: &[f64],
2653 beta: f64,
2654 y: &mut [f64],
2655 d: &mut f64,
2656) -> Result<()> {
2657 let status = unsafe {
2658 sys::aoclsparse_ddotmv(
2659 trans_raw(op),
2660 alpha,
2661 a.as_raw(),
2662 descr.as_raw(),
2663 x.as_ptr(),
2664 beta,
2665 y.as_mut_ptr(),
2666 d,
2667 )
2668 };
2669 check_status("sparse", status)
2670}
2671
2672#[allow(clippy::too_many_arguments)]
2674pub fn dotmv_f32(
2675 op: Trans,
2676 alpha: f32,
2677 a: &SparseMatrix<f32>,
2678 descr: &MatDescr,
2679 x: &[f32],
2680 beta: f32,
2681 y: &mut [f32],
2682 d: &mut f32,
2683) -> Result<()> {
2684 let status = unsafe {
2685 sys::aoclsparse_sdotmv(
2686 trans_raw(op),
2687 alpha,
2688 a.as_raw(),
2689 descr.as_raw(),
2690 x.as_ptr(),
2691 beta,
2692 y.as_mut_ptr(),
2693 d,
2694 )
2695 };
2696 check_status("sparse", status)
2697}
2698
2699#[allow(clippy::too_many_arguments)]
2703pub fn syrkd_f64(
2704 op_a: Trans,
2705 a: &SparseMatrix<f64>,
2706 alpha: f64,
2707 beta: f64,
2708 c: &mut [f64],
2709 order_c: Order,
2710 ldc: usize,
2711) -> Result<()> {
2712 let status = unsafe {
2713 sys::aoclsparse_dsyrkd(
2714 trans_raw(op_a),
2715 a.as_raw(),
2716 alpha,
2717 beta,
2718 c.as_mut_ptr(),
2719 order_c.raw(),
2720 ldc as sys::aoclsparse_int,
2721 )
2722 };
2723 check_status("sparse", status)
2724}
2725
2726#[allow(clippy::too_many_arguments)]
2728pub fn syrkd_f32(
2729 op_a: Trans,
2730 a: &SparseMatrix<f32>,
2731 alpha: f32,
2732 beta: f32,
2733 c: &mut [f32],
2734 order_c: Order,
2735 ldc: usize,
2736) -> Result<()> {
2737 let status = unsafe {
2738 sys::aoclsparse_ssyrkd(
2739 trans_raw(op_a),
2740 a.as_raw(),
2741 alpha,
2742 beta,
2743 c.as_mut_ptr(),
2744 order_c.raw(),
2745 ldc as sys::aoclsparse_int,
2746 )
2747 };
2748 check_status("sparse", status)
2749}
2750
2751#[allow(clippy::too_many_arguments)]
2754pub fn syprd_f64(
2755 op_a: Trans,
2756 a: &SparseMatrix<f64>,
2757 b: &[f64],
2758 order_b: Order,
2759 ldb: usize,
2760 alpha: f64,
2761 beta: f64,
2762 c: &mut [f64],
2763 order_c: Order,
2764 ldc: usize,
2765) -> Result<()> {
2766 let status = unsafe {
2767 sys::aoclsparse_dsyprd(
2768 trans_raw(op_a),
2769 a.as_raw(),
2770 b.as_ptr(),
2771 order_b.raw(),
2772 ldb as sys::aoclsparse_int,
2773 alpha,
2774 beta,
2775 c.as_mut_ptr(),
2776 order_c.raw(),
2777 ldc as sys::aoclsparse_int,
2778 )
2779 };
2780 check_status("sparse", status)
2781}
2782
2783#[allow(clippy::too_many_arguments)]
2785pub fn syprd_f32(
2786 op_a: Trans,
2787 a: &SparseMatrix<f32>,
2788 b: &[f32],
2789 order_b: Order,
2790 ldb: usize,
2791 alpha: f32,
2792 beta: f32,
2793 c: &mut [f32],
2794 order_c: Order,
2795 ldc: usize,
2796) -> Result<()> {
2797 let status = unsafe {
2798 sys::aoclsparse_ssyprd(
2799 trans_raw(op_a),
2800 a.as_raw(),
2801 b.as_ptr(),
2802 order_b.raw(),
2803 ldb as sys::aoclsparse_int,
2804 alpha,
2805 beta,
2806 c.as_mut_ptr(),
2807 order_c.raw(),
2808 ldc as sys::aoclsparse_int,
2809 )
2810 };
2811 check_status("sparse", status)
2812}
2813
2814pub fn roti_f64(
2817 x: &mut [f64],
2818 indx: &[sys::aoclsparse_int],
2819 y: &mut [f64],
2820 c: f64,
2821 s: f64,
2822) -> Result<()> {
2823 let status = unsafe {
2824 sys::aoclsparse_droti(
2825 x.len() as sys::aoclsparse_int,
2826 x.as_mut_ptr(),
2827 indx.as_ptr(),
2828 y.as_mut_ptr(),
2829 c,
2830 s,
2831 )
2832 };
2833 check_status("sparse", status)
2834}
2835
2836pub fn roti_f32(
2838 x: &mut [f32],
2839 indx: &[sys::aoclsparse_int],
2840 y: &mut [f32],
2841 c: f32,
2842 s: f32,
2843) -> Result<()> {
2844 let status = unsafe {
2845 sys::aoclsparse_sroti(
2846 x.len() as sys::aoclsparse_int,
2847 x.as_mut_ptr(),
2848 indx.as_ptr(),
2849 y.as_mut_ptr(),
2850 c,
2851 s,
2852 )
2853 };
2854 check_status("sparse", status)
2855}
2856
2857pub fn gthrs_f64(y: &[f64], x: &mut [f64], stride: i32) -> Result<()> {
2859 let status = unsafe {
2860 sys::aoclsparse_dgthrs(
2861 x.len() as sys::aoclsparse_int,
2862 y.as_ptr(),
2863 x.as_mut_ptr(),
2864 stride as sys::aoclsparse_int,
2865 )
2866 };
2867 check_status("sparse", status)
2868}
2869pub fn gthrs_f32(y: &[f32], x: &mut [f32], stride: i32) -> Result<()> {
2871 let status = unsafe {
2872 sys::aoclsparse_sgthrs(
2873 x.len() as sys::aoclsparse_int,
2874 y.as_ptr(),
2875 x.as_mut_ptr(),
2876 stride as sys::aoclsparse_int,
2877 )
2878 };
2879 check_status("sparse", status)
2880}
2881
2882pub fn sctrs_f64(x: &[f64], y: &mut [f64], stride: i32) -> Result<()> {
2884 let status = unsafe {
2885 sys::aoclsparse_dsctrs(
2886 x.len() as sys::aoclsparse_int,
2887 x.as_ptr(),
2888 stride as sys::aoclsparse_int,
2889 y.as_mut_ptr(),
2890 )
2891 };
2892 check_status("sparse", status)
2893}
2894pub fn sctrs_f32(x: &[f32], y: &mut [f32], stride: i32) -> Result<()> {
2896 let status = unsafe {
2897 sys::aoclsparse_ssctrs(
2898 x.len() as sys::aoclsparse_int,
2899 x.as_ptr(),
2900 stride as sys::aoclsparse_int,
2901 y.as_mut_ptr(),
2902 )
2903 };
2904 check_status("sparse", status)
2905}
2906
2907pub fn gthrz_f64(y: &mut [f64], indx: &[sys::aoclsparse_int], x: &mut [f64]) -> Result<()> {
2910 let status = unsafe {
2911 sys::aoclsparse_dgthrz(
2912 x.len() as sys::aoclsparse_int,
2913 y.as_mut_ptr(),
2914 x.as_mut_ptr(),
2915 indx.as_ptr(),
2916 )
2917 };
2918 check_status("sparse", status)
2919}
2920pub fn gthrz_f32(y: &mut [f32], indx: &[sys::aoclsparse_int], x: &mut [f32]) -> Result<()> {
2922 let status = unsafe {
2923 sys::aoclsparse_sgthrz(
2924 x.len() as sys::aoclsparse_int,
2925 y.as_mut_ptr(),
2926 x.as_mut_ptr(),
2927 indx.as_ptr(),
2928 )
2929 };
2930 check_status("sparse", status)
2931}
2932
2933#[allow(clippy::too_many_arguments)]
2937pub fn sorv_f32(
2938 sor_type: SorType,
2939 descr: &MatDescr,
2940 a: &SparseMatrix<f32>,
2941 omega: f32,
2942 alpha: f32,
2943 x: &mut [f32],
2944 b: &[f32],
2945) -> Result<()> {
2946 let status = unsafe {
2947 sys::aoclsparse_ssorv(
2948 sor_type.raw(),
2949 descr.as_raw(),
2950 a.as_raw(),
2951 omega,
2952 alpha,
2953 x.as_mut_ptr(),
2954 b.as_ptr(),
2955 )
2956 };
2957 check_status("sparse", status)
2958}
2959
2960#[allow(clippy::too_many_arguments)]
2967pub fn create_tcsr_f64(
2968 base: IndexBase,
2969 m: usize,
2970 n: usize,
2971 nnz: usize,
2972 row_ptr_l: &mut [sys::aoclsparse_int],
2973 row_ptr_u: &mut [sys::aoclsparse_int],
2974 col_idx_l: &mut [sys::aoclsparse_int],
2975 col_idx_u: &mut [sys::aoclsparse_int],
2976 val_l: &mut [f64],
2977 val_u: &mut [f64],
2978) -> Result<sys::aoclsparse_matrix> {
2979 let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
2980 let status = unsafe {
2981 sys::aoclsparse_create_dtcsr(
2982 &mut raw,
2983 base.raw(),
2984 m as sys::aoclsparse_int,
2985 n as sys::aoclsparse_int,
2986 nnz as sys::aoclsparse_int,
2987 row_ptr_l.as_mut_ptr(),
2988 row_ptr_u.as_mut_ptr(),
2989 col_idx_l.as_mut_ptr(),
2990 col_idx_u.as_mut_ptr(),
2991 val_l.as_mut_ptr(),
2992 val_u.as_mut_ptr(),
2993 )
2994 };
2995 check_status("sparse", status)?;
2996 if raw.is_null() {
2997 return Err(Error::AllocationFailed("sparse"));
2998 }
2999 Ok(raw)
3000}
3001
3002#[allow(clippy::too_many_arguments)]
3004pub fn create_tcsr_f32(
3005 base: IndexBase,
3006 m: usize,
3007 n: usize,
3008 nnz: usize,
3009 row_ptr_l: &mut [sys::aoclsparse_int],
3010 row_ptr_u: &mut [sys::aoclsparse_int],
3011 col_idx_l: &mut [sys::aoclsparse_int],
3012 col_idx_u: &mut [sys::aoclsparse_int],
3013 val_l: &mut [f32],
3014 val_u: &mut [f32],
3015) -> Result<sys::aoclsparse_matrix> {
3016 let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
3017 let status = unsafe {
3018 sys::aoclsparse_create_stcsr(
3019 &mut raw,
3020 base.raw(),
3021 m as sys::aoclsparse_int,
3022 n as sys::aoclsparse_int,
3023 nnz as sys::aoclsparse_int,
3024 row_ptr_l.as_mut_ptr(),
3025 row_ptr_u.as_mut_ptr(),
3026 col_idx_l.as_mut_ptr(),
3027 col_idx_u.as_mut_ptr(),
3028 val_l.as_mut_ptr(),
3029 val_u.as_mut_ptr(),
3030 )
3031 };
3032 check_status("sparse", status)?;
3033 if raw.is_null() {
3034 return Err(Error::AllocationFailed("sparse"));
3035 }
3036 Ok(raw)
3037}
3038
3039#[allow(clippy::too_many_arguments)]
3045pub fn create_csc_f64(
3046 base: IndexBase,
3047 m: usize,
3048 n: usize,
3049 nnz: usize,
3050 col_ptr: &mut [sys::aoclsparse_int],
3051 row_idx: &mut [sys::aoclsparse_int],
3052 val: &mut [f64],
3053) -> Result<sys::aoclsparse_matrix> {
3054 let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
3055 let status = unsafe {
3056 sys::aoclsparse_create_dcsc(
3057 &mut raw,
3058 base.raw(),
3059 m as sys::aoclsparse_int,
3060 n as sys::aoclsparse_int,
3061 nnz as sys::aoclsparse_int,
3062 col_ptr.as_mut_ptr(),
3063 row_idx.as_mut_ptr(),
3064 val.as_mut_ptr(),
3065 )
3066 };
3067 check_status("sparse", status)?;
3068 if raw.is_null() {
3069 return Err(Error::AllocationFailed("sparse"));
3070 }
3071 Ok(raw)
3072}
3073#[allow(clippy::too_many_arguments)]
3075pub fn create_csc_f32(
3076 base: IndexBase,
3077 m: usize,
3078 n: usize,
3079 nnz: usize,
3080 col_ptr: &mut [sys::aoclsparse_int],
3081 row_idx: &mut [sys::aoclsparse_int],
3082 val: &mut [f32],
3083) -> Result<sys::aoclsparse_matrix> {
3084 let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
3085 let status = unsafe {
3086 sys::aoclsparse_create_scsc(
3087 &mut raw,
3088 base.raw(),
3089 m as sys::aoclsparse_int,
3090 n as sys::aoclsparse_int,
3091 nnz as sys::aoclsparse_int,
3092 col_ptr.as_mut_ptr(),
3093 row_idx.as_mut_ptr(),
3094 val.as_mut_ptr(),
3095 )
3096 };
3097 check_status("sparse", status)?;
3098 if raw.is_null() {
3099 return Err(Error::AllocationFailed("sparse"));
3100 }
3101 Ok(raw)
3102}
3103
3104#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
3106pub enum MemoryUsage {
3107 Minimal,
3109 Unrestricted,
3111}
3112
3113impl MemoryUsage {
3114 fn raw(self) -> sys::aoclsparse_memory_usage {
3115 match self {
3116 MemoryUsage::Minimal => sys::aoclsparse_memory_usage__aoclsparse_memory_usage_minimal,
3117 MemoryUsage::Unrestricted => {
3118 sys::aoclsparse_memory_usage__aoclsparse_memory_usage_unrestricted
3119 }
3120 }
3121 }
3122}
3123
3124pub fn set_memory_hint<T: Scalar>(mat: &mut SparseMatrix<T>, policy: MemoryUsage) -> Result<()> {
3127 let status = unsafe { sys::aoclsparse_set_memory_hint(mat.as_raw(), policy.raw()) };
3128 check_status("sparse", status)
3129}
3130
3131pub fn ilu_smoother<T: Scalar>(
3136 op: Trans,
3137 a: &SparseMatrix<T>,
3138 descr: &MatDescr,
3139 x: &mut [T],
3140 b: &[T],
3141) -> Result<()> {
3142 if x.len() < a.n || b.len() < a.m {
3143 return Err(Error::InvalidArgument(format!(
3144 "ilu_smoother: x.len()={}, b.len()={}, dims=({}, {})",
3145 x.len(),
3146 b.len(),
3147 a.m,
3148 a.n
3149 )));
3150 }
3151 T::ilu_smoother(op, a.raw, descr, x, b)
3152}
3153
/// Owning wrapper around an AOCL-Sparse iterative-solver handle
/// (`sys::aoclsparse_itsol_handle`) for scalar type `T`.
///
/// Build one with [`IterSolver::new`]. The handle is released by the `Drop`
/// impl via `aoclsparse_itsol_destroy`.
pub struct IterSolver<T: Scalar> {
    // Raw solver handle; `new` never returns a value whose handle is null.
    handle: sys::aoclsparse_itsol_handle,
    // Records the scalar type the handle was initialised for; no `T` is stored.
    _marker: PhantomData<T>,
}
3167
3168impl<T: Scalar> IterSolver<T> {
3169 pub fn new() -> Result<Self> {
3171 let mut handle: sys::aoclsparse_itsol_handle = std::ptr::null_mut();
3172 T::itsol_init(&mut handle)?;
3173 if handle.is_null() {
3174 return Err(Error::AllocationFailed("sparse"));
3175 }
3176 Ok(Self {
3177 handle,
3178 _marker: PhantomData,
3179 })
3180 }
3181
3182 pub fn set_option(&mut self, name: &str, value: &str) -> Result<()> {
3187 let c_name = CString::new(name)
3188 .map_err(|_| Error::InvalidArgument("set_option: name has interior NUL".into()))?;
3189 let c_value = CString::new(value)
3190 .map_err(|_| Error::InvalidArgument("set_option: value has interior NUL".into()))?;
3191 let status = unsafe {
3192 sys::aoclsparse_itsol_option_set(self.handle, c_name.as_ptr(), c_value.as_ptr())
3193 };
3194 check_status("sparse", status)
3195 }
3196
3197 pub fn solve(
3202 &mut self,
3203 mat: &SparseMatrix<T>,
3204 descr: &MatDescr,
3205 b: &[T],
3206 x: &mut [T],
3207 ) -> Result<Box<[T; 100]>>
3208 where
3209 T: Default,
3210 {
3211 let n = mat.n;
3212 if mat.m != mat.n {
3213 return Err(Error::InvalidArgument(format!(
3214 "iterative solve requires square matrix; got ({}, {})",
3215 mat.m, mat.n
3216 )));
3217 }
3218 let mut rinfo: Box<[T; 100]> = Box::new([T::default(); 100]);
3219 T::itsol_solve(self.handle, n, mat.raw, descr, b, x, &mut rinfo)?;
3220 Ok(rinfo)
3221 }
3222}
3223
3224impl<T: Scalar> Drop for IterSolver<T> {
3225 fn drop(&mut self) {
3226 if !self.handle.is_null() {
3227 unsafe {
3228 sys::aoclsparse_itsol_destroy(&mut self.handle);
3229 }
3230 self.handle = std::ptr::null_mut();
3231 }
3232 }
3233}
3234
3235impl<T: Scalar> std::fmt::Debug for IterSolver<T> {
3236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3237 f.debug_struct("IterSolver").finish_non_exhaustive()
3238 }
3239}
3240
#[cfg(test)]
mod tests {
    // Unit tests that exercise the safe wrappers against the linked AOCL-Sparse
    // library.
    use super::*;

    #[test]
    fn csrmv_2x2_identity_f64() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rowptr: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let x = [3.0_f64, 4.0];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        csrmv(1.0_f64, 2, 2, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn csrmv_simple_2x3() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rowptr: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let x = [1.0_f64; 3];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        csrmv(1.0_f64, 2, 3, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12, "got {}", y[0]);
        assert!((y[1] - 3.0).abs() < 1e-12, "got {}", y[1]);
    }

    #[test]
    fn dim_mismatch_is_error() {
        let val = [1.0_f64];
        let col: [sys::aoclsparse_int; 1] = [0];
        let rowptr: [sys::aoclsparse_int; 2] = [0, 1];
        let x = [1.0_f64];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        let err = csrmv(1.0_f64, 2, 1, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap_err();
        // `matches!` only yields a bool; without `assert!` the variant was never checked.
        assert!(
            matches!(err, Error::InvalidArgument(_)),
            "expected InvalidArgument, got {err:?}"
        );
    }

    #[test]
    fn axpyi_scatter() {
        // y[indx[i]] += 3 * x[i]; untouched entries keep their value.
        let mut y = [10.0_f64, 20.0, 30.0, 40.0];
        let x = [1.0_f64, 2.0];
        let indx: [sys::aoclsparse_int; 2] = [0, 2];
        axpyi(3.0_f64, &x, &indx, &mut y).unwrap();
        assert_eq!(y, [13.0, 20.0, 36.0, 40.0]);
    }

    #[test]
    fn gthr_scatter_round_trip() {
        let y = [10.0_f64, 20.0, 30.0, 40.0];
        let indx: [sys::aoclsparse_int; 2] = [1, 3];
        let mut x = [0.0_f64; 2];
        gthr(&y, &indx, &mut x).unwrap();
        assert_eq!(x, [20.0, 40.0]);

        let mut y2 = [0.0_f64; 4];
        sctr(&x, &indx, &mut y2).unwrap();
        assert_eq!(y2, [0.0, 20.0, 0.0, 40.0]);
    }

    #[test]
    fn add_identity_plus_identity_is_2_diag() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let c = add(Trans::No, &a, 1.0, &b).unwrap();
        let (_, _, _, val_c) = c.export_csr().unwrap();
        assert_eq!(val_c.len(), 2);
        for v in &val_c {
            assert!((v - 2.0).abs() < 1e-12, "got {v}, want 2.0");
        }
    }

    #[test]
    fn csrmm_2x2_identity_against_2x3_dense() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut c = [0.0_f64; 6];
        csrmm(
            Trans::No,
            1.0,
            &a,
            &descr,
            Order::RowMajor,
            &b,
            3,
            3,
            0.0,
            &mut c,
            3,
        )
        .unwrap();
        for (got, want) in c.iter().zip(b.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }

    #[test]
    fn spmmd_identity_squared_yields_identity_dense() {
        let val = [1.0_f64; 3];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let mut c = [0.0_f64; 9];
        spmmd(Trans::No, &a, &b, Order::RowMajor, &mut c, 3).unwrap();
        let expected = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        for (got, want) in c.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }

    #[test]
    fn ellmv_2x3_f64() {
        let val: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
        let col: [sys::aoclsparse_int; 4] = [0, 2, 0, 1];
        let descr = MatDescr::new().unwrap();
        let x = [10.0_f64, 20.0, 30.0];
        let mut y = [0.0_f64; 2];
        ellmv(
            Trans::No,
            1.0_f64,
            2,
            3,
            &val,
            &col,
            2,
            &descr,
            &x,
            0.0,
            &mut y,
        )
        .unwrap();
        assert!((y[0] - 70.0).abs() < 1e-12, "got {}", y[0]);
        assert!((y[1] - 110.0).abs() < 1e-12, "got {}", y[1]);
    }

    #[test]
    fn bsrmv_2x2_blocks_f64() {
        let val: [f64; 8] = [1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let descr = MatDescr::new().unwrap();
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let mut y = [0.0_f64; 4];
        bsrmv(
            Trans::No,
            1.0_f64,
            2,
            2,
            2,
            &val,
            &col,
            &rp,
            &descr,
            &x,
            0.0,
            &mut y,
        )
        .unwrap();
        assert!((y[0] - 1.0).abs() < 1e-12);
        assert!((y[1] - 2.0).abs() < 1e-12);
        assert!((y[2] - 6.0).abs() < 1e-12);
        assert!((y[3] - 8.0).abs() < 1e-12);
    }

    #[test]
    fn sparse_matrix_round_trip() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 2, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let mat = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 3, &rp, &col, &val).unwrap();
        assert_eq!(mat.dims(), (2, 3));
        assert_eq!(mat.nnz(), 3);
        assert_eq!(mat.base(), IndexBase::Zero);
        let (base, rp2, col2, val2) = mat.export_csr().unwrap();
        assert_eq!(base, IndexBase::Zero);
        assert_eq!(rp2, [0, 2, 3]);
        assert_eq!(col2, [0, 2, 1]);
        assert_eq!(val2, [1.0, 2.0, 3.0]);
    }

    #[test]
    fn csr2m_identity_squared_is_identity() {
        let val = [1.0_f64; 3];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        let c = csr2m(
            Trans::No,
            &descr,
            &a,
            Trans::No,
            &descr,
            &b,
            Stage::FullComputation,
        )
        .unwrap();
        assert_eq!(c.dims(), (3, 3));
        let (_, rp_c, col_c, val_c) = c.export_csr().unwrap();
        assert_eq!(rp_c, [0, 1, 2, 3]);
        assert_eq!(col_c, [0, 1, 2]);
        for v in &val_c {
            assert!((v - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn iter_solver_cg_diagonal_3x3() {
        let val = [2.0_f64, 2.0, 2.0];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let mat = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();

        let descr = MatDescr::new().unwrap();
        // SAFETY: the descriptor handle comes from `descr`, a live `MatDescr`
        // that stays in scope for the rest of the test.
        let status = unsafe {
            sys::aoclsparse_set_mat_type(
                descr.as_raw(),
                sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric,
            )
        };
        // Fail here rather than later in `solve` if the matrix type could not be set.
        check_status("sparse", status).unwrap();

        let b = [4.0_f64, 6.0, 10.0];
        let mut x = [0.0_f64; 3];
        let mut solver = IterSolver::<f64>::new().unwrap();
        solver.set_option("iterative method", "cg").unwrap();
        solver.set_option("cg rel tolerance", "1e-10").unwrap();
        solver.set_option("cg iteration limit", "200").unwrap();
        solver.solve(&mat, &descr, &b, &mut x).unwrap();
        assert!((x[0] - 2.0).abs() < 1e-6, "x[0] = {}", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-6, "x[1] = {}", x[1]);
        assert!((x[2] - 5.0).abs() < 1e-6, "x[2] = {}", x[2]);
    }

    #[test]
    fn csr_to_dense_round_trip() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 2, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let descr = MatDescr::new().unwrap();
        let mut dense = [0.0_f64; 6];
        csr_to_dense::<f64>(
            2,
            3,
            &descr,
            &val,
            &rp,
            &col,
            &mut dense,
            3,
            Order::RowMajor,
        )
        .unwrap();
        assert_eq!(dense, [1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
    }
}