Skip to main content

aocl_sparse/
lib.rs

1//! Safe wrappers for AOCL-Sparse.
2
3#![warn(missing_debug_implementations)]
4#![cfg_attr(docsrs, feature(doc_cfg))]
5// The Scalar / ComplexScalar traits in this crate are sealed (require
6// `aocl_types::sealed::Sealed`), so they cannot be implemented from
7// outside this crate. The raw-pointer arguments on a few trait methods
8// are FFI plumbing called only by the safe wrappers in this same
9// crate; clippy's lint targets the case where a downstream
10// implementer might deref a hostile pointer, which can't happen here.
11#![allow(clippy::not_unsafe_ptr_arg_deref)]
12// Some `(IndexBase, Vec, Vec, Vec)` return tuples are intentional —
13// they mirror what the C API hands back. Aliasing them would obscure
14// rather than clarify the underlying data layout.
15#![allow(clippy::type_complexity)]
16
17use std::ffi::CString;
18use std::marker::PhantomData;
19
20pub use aocl_error::{Error, Result};
21use aocl_sparse_sys as sys;
22use aocl_types::sealed::Sealed;
23pub use aocl_types::{Complex32, Complex64, Trans};
24
25pub mod complex;
26
27fn trans_raw(t: Trans) -> sys::aoclsparse_operation {
28    match t {
29        Trans::No => sys::aoclsparse_operation__aoclsparse_operation_none,
30        Trans::T => sys::aoclsparse_operation__aoclsparse_operation_transpose,
31        Trans::C => sys::aoclsparse_operation__aoclsparse_operation_conjugate_transpose,
32    }
33}
34
35fn check_status(component: &'static str, status: sys::aoclsparse_status) -> Result<()> {
36    if status == sys::aoclsparse_status__aoclsparse_status_success {
37        return Ok(());
38    }
39    let message = match status {
40        s if s == sys::aoclsparse_status__aoclsparse_status_not_implemented => "not implemented",
41        s if s == sys::aoclsparse_status__aoclsparse_status_invalid_pointer => "invalid pointer",
42        s if s == sys::aoclsparse_status__aoclsparse_status_invalid_size => "invalid size",
43        s if s == sys::aoclsparse_status__aoclsparse_status_internal_error => "internal error",
44        s if s == sys::aoclsparse_status__aoclsparse_status_invalid_value => "invalid value",
45        s if s == sys::aoclsparse_status__aoclsparse_status_invalid_index_value => {
46            "invalid index value"
47        }
48        s if s == sys::aoclsparse_status__aoclsparse_status_maxit => "max iterations reached",
49        s if s == sys::aoclsparse_status__aoclsparse_status_user_stop => "user stop",
50        s if s == sys::aoclsparse_status__aoclsparse_status_wrong_type => "wrong type",
51        s if s == sys::aoclsparse_status__aoclsparse_status_memory_error => "memory error",
52        _ => "unknown sparse status",
53    }
54    .to_string();
55    Err(Error::Status {
56        component,
57        code: status as i64,
58        message,
59    })
60}
61
62/// RAII wrapper for `aoclsparse_mat_descr`.
63pub struct MatDescr {
64    raw: sys::aoclsparse_mat_descr,
65}
66
67impl MatDescr {
68    /// Create a fresh descriptor with library defaults.
69    pub fn new() -> Result<Self> {
70        let mut raw: sys::aoclsparse_mat_descr = std::ptr::null_mut();
71        let status = unsafe { sys::aoclsparse_create_mat_descr(&mut raw) };
72        check_status("sparse", status)?;
73        if raw.is_null() {
74            return Err(Error::AllocationFailed("sparse"));
75        }
76        Ok(MatDescr { raw })
77    }
78
79    /// Borrow the underlying handle for raw FFI calls.
80    ///
81    /// # Safety
82    /// The returned pointer is valid only for the lifetime of `self`.
83    /// Do not call `aoclsparse_destroy_mat_descr` on it.
84    pub fn as_raw(&self) -> sys::aoclsparse_mat_descr {
85        self.raw
86    }
87
88    /// Set the matrix-type hint (general / symmetric / hermitian /
89    /// triangular). Affects which fast paths the library can take.
90    pub fn set_type(&mut self, ty: MatType) -> Result<()> {
91        let status = unsafe { sys::aoclsparse_set_mat_type(self.raw, ty.raw()) };
92        check_status("sparse", status)
93    }
94
95    /// Set the index base (zero- or one-based) for column / row arrays.
96    pub fn set_index_base(&mut self, base: IndexBase) -> Result<()> {
97        let status = unsafe { sys::aoclsparse_set_mat_index_base(self.raw, base.raw()) };
98        check_status("sparse", status)
99    }
100
101    /// For triangular / symmetric matrices, declare which triangle is
102    /// stored (upper or lower).
103    pub fn set_fill_mode(&mut self, fill: FillMode) -> Result<()> {
104        let status = unsafe { sys::aoclsparse_set_mat_fill_mode(self.raw, fill.raw()) };
105        check_status("sparse", status)
106    }
107
108    /// For triangular matrices, declare whether the diagonal is unit
109    /// (implicit) or non-unit (explicitly stored).
110    pub fn set_diag_type(&mut self, diag: DiagType) -> Result<()> {
111        let status = unsafe { sys::aoclsparse_set_mat_diag_type(self.raw, diag.raw()) };
112        check_status("sparse", status)
113    }
114
115    /// Read back the matrix-type hint.
116    pub fn ty(&self) -> MatType {
117        let raw = unsafe { sys::aoclsparse_get_mat_type(self.raw) };
118        MatType::from_raw(raw).unwrap_or(MatType::General)
119    }
120
121    /// Read back the index base.
122    pub fn index_base(&self) -> IndexBase {
123        let raw = unsafe { sys::aoclsparse_get_mat_index_base(self.raw) };
124        if raw == sys::aoclsparse_index_base__aoclsparse_index_base_one {
125            IndexBase::One
126        } else {
127            IndexBase::Zero
128        }
129    }
130
131    /// Read back the fill mode.
132    pub fn fill_mode(&self) -> FillMode {
133        let raw = unsafe { sys::aoclsparse_get_mat_fill_mode(self.raw) };
134        FillMode::from_raw(raw).unwrap_or(FillMode::Lower)
135    }
136
137    /// Read back the diagonal-type hint.
138    pub fn diag_type(&self) -> DiagType {
139        let raw = unsafe { sys::aoclsparse_get_mat_diag_type(self.raw) };
140        DiagType::from_raw(raw).unwrap_or(DiagType::NonUnit)
141    }
142}
143
144/// Storage type of a sparse matrix's nonzero pattern.
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
146pub enum MatType {
147    General,
148    Symmetric,
149    Hermitian,
150    Triangular,
151}
152
153impl MatType {
154    fn raw(self) -> sys::aoclsparse_matrix_type {
155        match self {
156            MatType::General => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_general,
157            MatType::Symmetric => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric,
158            MatType::Hermitian => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_hermitian,
159            MatType::Triangular => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_triangular,
160        }
161    }
162
163    fn from_raw(raw: sys::aoclsparse_matrix_type) -> Option<Self> {
164        Some(match raw {
165            r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_general => {
166                MatType::General
167            }
168            r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric => {
169                MatType::Symmetric
170            }
171            r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_hermitian => {
172                MatType::Hermitian
173            }
174            r if r == sys::aoclsparse_matrix_type__aoclsparse_matrix_type_triangular => {
175                MatType::Triangular
176            }
177            _ => return None,
178        })
179    }
180}
181
182/// Which triangle of a symmetric / triangular matrix is stored.
183#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
184pub enum FillMode {
185    Lower,
186    Upper,
187}
188
189impl FillMode {
190    fn raw(self) -> sys::aoclsparse_fill_mode {
191        match self {
192            FillMode::Lower => sys::aoclsparse_fill_mode__aoclsparse_fill_mode_lower,
193            FillMode::Upper => sys::aoclsparse_fill_mode__aoclsparse_fill_mode_upper,
194        }
195    }
196    fn from_raw(raw: sys::aoclsparse_fill_mode) -> Option<Self> {
197        Some(match raw {
198            r if r == sys::aoclsparse_fill_mode__aoclsparse_fill_mode_lower => FillMode::Lower,
199            r if r == sys::aoclsparse_fill_mode__aoclsparse_fill_mode_upper => FillMode::Upper,
200            _ => return None,
201        })
202    }
203}
204
205/// Whether the diagonal of a triangular matrix is implicitly unit
206/// (`Unit`) or explicitly stored (`NonUnit`).
207#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
208pub enum DiagType {
209    Unit,
210    NonUnit,
211}
212
213impl DiagType {
214    fn raw(self) -> sys::aoclsparse_diag_type {
215        match self {
216            DiagType::Unit => sys::aoclsparse_diag_type__aoclsparse_diag_type_unit,
217            DiagType::NonUnit => sys::aoclsparse_diag_type__aoclsparse_diag_type_non_unit,
218        }
219    }
220    fn from_raw(raw: sys::aoclsparse_diag_type) -> Option<Self> {
221        Some(match raw {
222            r if r == sys::aoclsparse_diag_type__aoclsparse_diag_type_unit => DiagType::Unit,
223            r if r == sys::aoclsparse_diag_type__aoclsparse_diag_type_non_unit => DiagType::NonUnit,
224            _ => return None,
225        })
226    }
227}
228
229/// Make a deep copy of a `MatDescr` with the same options.
230pub fn copy_mat_descr(src: &MatDescr) -> Result<MatDescr> {
231    let dest = MatDescr::new()?;
232    let status = unsafe { sys::aoclsparse_copy_mat_descr(dest.raw, src.raw) };
233    check_status("sparse", status)?;
234    Ok(dest)
235}
236
237/// Run AOCL-Sparse's analysis / optimization step on a matrix handle.
238/// Use this after declaring hints (e.g. `set_mv_hint`) to let the
239/// library choose specialised kernels for repeated operations.
240pub fn optimize<T: Scalar>(mat: &mut SparseMatrix<T>) -> Result<()> {
241    let status = unsafe { sys::aoclsparse_optimize(mat.as_raw()) };
242    check_status("sparse", status)
243}
244
245/// AOCL-Sparse library version string (e.g. `"AOCL-Sparse 5.1.0"`).
246pub fn version() -> &'static str {
247    unsafe {
248        let p = sys::aoclsparse_get_version();
249        if p.is_null() {
250            return "";
251        }
252        std::ffi::CStr::from_ptr(p).to_str().unwrap_or("")
253    }
254}
255
256impl Drop for MatDescr {
257    fn drop(&mut self) {
258        if !self.raw.is_null() {
259            unsafe {
260                let _ = sys::aoclsparse_destroy_mat_descr(self.raw);
261            }
262            self.raw = std::ptr::null_mut();
263        }
264    }
265}
266
267impl std::fmt::Debug for MatDescr {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        f.debug_struct("MatDescr").finish_non_exhaustive()
270    }
271}
272
273/// Scalar element type usable with the wrapped sparse routines.
274pub trait Scalar: Copy + Sized + Sealed {
275    /// `y := α · op(A) · x + β · y` where `A` is in CSR format.
276    #[allow(clippy::too_many_arguments)]
277    fn csrmv(
278        op: Trans,
279        alpha: Self,
280        m: usize,
281        n: usize,
282        csr_val: &[Self],
283        csr_col_ind: &[sys::aoclsparse_int],
284        csr_row_ptr: &[sys::aoclsparse_int],
285        descr: &MatDescr,
286        x: &[Self],
287        beta: Self,
288        y: &mut [Self],
289    ) -> Result<()>;
290
291    /// Sparse `y[indx] := α·x + y[indx]`.
292    fn axpyi(alpha: Self, x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()>;
293
294    /// Sparse gather: `x[i] := y[indx[i]]`.
295    fn gthr(y: &[Self], indx: &[sys::aoclsparse_int], x: &mut [Self]) -> Result<()>;
296
297    /// Sparse scatter: `y[indx[i]] := x[i]`.
298    fn sctr(x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()>;
299
300    /// Solve `op(A) · y = α · x` where `A` is sparse triangular (CSR).
301    #[allow(clippy::too_many_arguments)]
302    fn csrsv(
303        op: Trans,
304        alpha: Self,
305        m: usize,
306        csr_val: &[Self],
307        csr_col_ind: &[sys::aoclsparse_int],
308        csr_row_ptr: &[sys::aoclsparse_int],
309        descr: &MatDescr,
310        x: &[Self],
311        y: &mut [Self],
312    ) -> Result<()>;
313
314    /// Convert a CSR matrix to a dense `m × n` matrix.
315    #[allow(clippy::too_many_arguments)]
316    fn csr_to_dense(
317        m: usize,
318        n: usize,
319        descr: &MatDescr,
320        csr_val: &[Self],
321        csr_row_ptr: &[sys::aoclsparse_int],
322        csr_col_ind: &[sys::aoclsparse_int],
323        a: &mut [Self],
324        ld: usize,
325        order: Order,
326    ) -> Result<()>;
327
328    /// Convert CSR → CSC.
329    #[allow(clippy::too_many_arguments)]
330    fn csr_to_csc(
331        m: usize,
332        n: usize,
333        descr: &MatDescr,
334        base_csc: IndexBase,
335        csr_row_ptr: &[sys::aoclsparse_int],
336        csr_col_ind: &[sys::aoclsparse_int],
337        csr_val: &[Self],
338        csc_row_ind: &mut [sys::aoclsparse_int],
339        csc_col_ptr: &mut [sys::aoclsparse_int],
340        csc_val: &mut [Self],
341    ) -> Result<()>;
342
343    /// `y := α · op(A) · x + β · y` where `A` is in ELLPACK format.
344    /// Library only supports `op = aoclsparse_operation_none`.
345    #[allow(clippy::too_many_arguments)]
346    fn ellmv(
347        op: Trans,
348        alpha: Self,
349        m: usize,
350        n: usize,
351        ell_val: &[Self],
352        ell_col_ind: &[sys::aoclsparse_int],
353        ell_width: usize,
354        descr: &MatDescr,
355        x: &[Self],
356        beta: Self,
357        y: &mut [Self],
358    ) -> Result<()>;
359
360    /// `y := α · op(A) · x + β · y` where `A` is in BSR format.
361    /// `mb`/`nb` are block-row / block-column counts and `bsr_dim` is the
362    /// block edge length. Library only supports `op = aoclsparse_operation_none`.
363    #[allow(clippy::too_many_arguments)]
364    fn bsrmv(
365        op: Trans,
366        alpha: Self,
367        mb: usize,
368        nb: usize,
369        bsr_dim: usize,
370        bsr_val: &[Self],
371        bsr_col_ind: &[sys::aoclsparse_int],
372        bsr_row_ptr: &[sys::aoclsparse_int],
373        descr: &MatDescr,
374        x: &[Self],
375        beta: Self,
376        y: &mut [Self],
377    ) -> Result<()>;
378
379    /// Wrap raw CSR pointers in a fresh `aoclsparse_matrix` handle. The
380    /// library does not copy the underlying arrays — keep them alive for as
381    /// long as the returned handle is in use.
382    #[allow(clippy::too_many_arguments)]
383    fn create_csr(
384        base: IndexBase,
385        m: usize,
386        n: usize,
387        nnz: usize,
388        row_ptr: *mut sys::aoclsparse_int,
389        col_idx: *mut sys::aoclsparse_int,
390        val: *mut Self,
391    ) -> Result<sys::aoclsparse_matrix>;
392
393    /// Read out a library-owned CSR matrix's metadata and array pointers.
394    /// Pointers reference internal storage; do not modify or free them.
395    fn export_csr(
396        mat: sys::aoclsparse_matrix,
397    ) -> Result<(
398        IndexBase,
399        usize,
400        usize,
401        usize,
402        *mut sys::aoclsparse_int,
403        *mut sys::aoclsparse_int,
404        *mut Self,
405    )>;
406
407    /// ILU(0) smoother: applies one ILU step in place to `x`.
408    fn ilu_smoother(
409        op: Trans,
410        a: sys::aoclsparse_matrix,
411        descr: &MatDescr,
412        x: &mut [Self],
413        b: &[Self],
414    ) -> Result<()>;
415
416    /// Initialise an iterative-solver handle for this scalar type.
417    fn itsol_init(handle: &mut sys::aoclsparse_itsol_handle) -> Result<()>;
418
419    /// Run the iterative solver's forward (direct) interface.
420    #[allow(clippy::too_many_arguments)]
421    fn itsol_solve(
422        handle: sys::aoclsparse_itsol_handle,
423        n: usize,
424        mat: sys::aoclsparse_matrix,
425        descr: &MatDescr,
426        b: &[Self],
427        x: &mut [Self],
428        rinfo: &mut [Self; 100],
429    ) -> Result<()>;
430
431    /// Sparse-sparse matrix product producing a new CSR matrix. The output
432    /// pointer is written through `*out`.
433    ///
434    /// # Safety
435    /// Caller is responsible for the validity of all `aoclsparse_*` handle
436    /// arguments and for adopting `*out` (e.g. via
437    /// [`SparseMatrix::from_library_owned`]) on success.
438    #[allow(clippy::too_many_arguments)]
439    unsafe fn csr2m_ffi(
440        op_a: sys::aoclsparse_operation,
441        descr_a: sys::aoclsparse_mat_descr,
442        a: sys::aoclsparse_matrix,
443        op_b: sys::aoclsparse_operation,
444        descr_b: sys::aoclsparse_mat_descr,
445        b: sys::aoclsparse_matrix,
446        request: sys::aoclsparse_request,
447        out: *mut sys::aoclsparse_matrix,
448    ) -> sys::aoclsparse_status;
449
450    /// `C := α · op(A) · B + β · C` where `A` is sparse (CSR via the
451    /// matrix handle) and `B`, `C` are dense.
452    #[allow(clippy::too_many_arguments)]
453    fn csrmm(
454        op: Trans,
455        alpha: Self,
456        a: sys::aoclsparse_matrix,
457        descr: &MatDescr,
458        order: Order,
459        b: &[Self],
460        n: usize,
461        ldb: usize,
462        beta: Self,
463        c: &mut [Self],
464        ldc: usize,
465    ) -> Result<()>;
466
467    /// `C := op(A) · B` where both `A` and `B` are sparse and `C` is
468    /// dense.
469    #[allow(clippy::too_many_arguments)]
470    fn spmmd(
471        op: Trans,
472        a: sys::aoclsparse_matrix,
473        b: sys::aoclsparse_matrix,
474        layout: Order,
475        c: &mut [Self],
476        ldc: usize,
477    ) -> Result<()>;
478
479    /// `C := α · op_A(A) · op_B(B) + β · C` where both `A` and `B` are
480    /// sparse and `C` is dense.
481    #[allow(clippy::too_many_arguments)]
482    fn sp2md(
483        op_a: Trans,
484        descr_a: &MatDescr,
485        a: sys::aoclsparse_matrix,
486        op_b: Trans,
487        descr_b: &MatDescr,
488        b: sys::aoclsparse_matrix,
489        alpha: Self,
490        beta: Self,
491        c: &mut [Self],
492        layout: Order,
493        ldc: usize,
494    ) -> Result<()>;
495
496    /// Sparse-sparse `C := α · op(A) + B` returning a fresh CSR handle.
497    ///
498    /// # Safety
499    /// Caller must adopt the resulting `*out` (e.g. via
500    /// [`SparseMatrix::from_library_owned`]).
501    unsafe fn add_ffi(
502        op: sys::aoclsparse_operation,
503        a: sys::aoclsparse_matrix,
504        alpha: Self,
505        b: sys::aoclsparse_matrix,
506        out: *mut sys::aoclsparse_matrix,
507    ) -> sys::aoclsparse_status;
508
509    /// One step of the (S)SOR / forward / backward Gauss-Seidel
510    /// relaxation: `x_new := (1 − ω) x + ω · A⁻¹ (α · b)`.
511    #[allow(clippy::too_many_arguments)]
512    fn sorv(
513        sor_type: SorType,
514        descr: &MatDescr,
515        a: sys::aoclsparse_matrix,
516        omega: Self,
517        alpha: Self,
518        x: &mut [Self],
519        b: &[Self],
520    ) -> Result<()>;
521}
522
523/// Sweep direction for [`sorv`].
524#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
525pub enum SorType {
526    /// Forward Gauss-Seidel sweep.
527    Forward,
528    /// Backward Gauss-Seidel sweep.
529    Backward,
530    /// Symmetric (forward + backward) sweep.
531    Symmetric,
532}
533
534impl SorType {
535    pub(crate) fn raw(self) -> sys::aoclsparse_sor_type {
536        match self {
537            SorType::Forward => sys::aoclsparse_sor_type__aoclsparse_sor_forward,
538            SorType::Backward => sys::aoclsparse_sor_type__aoclsparse_sor_backward,
539            SorType::Symmetric => sys::aoclsparse_sor_type__aoclsparse_sor_symmetric,
540        }
541    }
542}
543
544macro_rules! impl_scalar {
545    (
546        $t:ty,
547        csrmv = $csrmv:ident,
548        axpyi = $axpyi:ident,
549        gthr = $gthr:ident,
550        sctr = $sctr:ident,
551        csrsv = $csrsv:ident,
552        csr2dense = $csr2dense:ident,
553        csr2csc = $csr2csc:ident,
554        ellmv = $ellmv:ident,
555        bsrmv = $bsrmv:ident,
556        create_csr = $create_csr:ident,
557        export_csr = $export_csr:ident,
558        ilu_smoother = $ilu_smoother:ident,
559        itsol_init = $itsol_init:ident,
560        itsol_solve = $itsol_solve:ident,
561        csr2m = $csr2m:ident,
562        csrmm = $csrmm:ident,
563        spmmd = $spmmd:ident,
564        sp2md = $sp2md:ident,
565        add = $add:ident,
566        sorv = $sorv:ident
567    ) => {
568        impl Scalar for $t {
569            fn csrmv(
570                op: Trans,
571                alpha: Self,
572                m: usize,
573                n: usize,
574                csr_val: &[Self],
575                csr_col_ind: &[sys::aoclsparse_int],
576                csr_row_ptr: &[sys::aoclsparse_int],
577                descr: &MatDescr,
578                x: &[Self],
579                beta: Self,
580                y: &mut [Self],
581            ) -> Result<()> {
582                if csr_row_ptr.len() != m + 1 {
583                    return Err(Error::InvalidArgument(format!(
584                        "csrmv: csr_row_ptr length {} != m+1 = {}",
585                        csr_row_ptr.len(),
586                        m + 1
587                    )));
588                }
589                let nnz = csr_val.len();
590                if csr_col_ind.len() != nnz {
591                    return Err(Error::InvalidArgument(format!(
592                        "csrmv: csr_col_ind length {} != csr_val length {}",
593                        csr_col_ind.len(),
594                        nnz
595                    )));
596                }
597                let (x_len, y_len) = match op {
598                    Trans::No => (n, m),
599                    Trans::T | Trans::C => (m, n),
600                };
601                if x.len() < x_len {
602                    return Err(Error::InvalidArgument(format!(
603                        "csrmv: x length {} < expected {x_len}",
604                        x.len()
605                    )));
606                }
607                if y.len() < y_len {
608                    return Err(Error::InvalidArgument(format!(
609                        "csrmv: y length {} < expected {y_len}",
610                        y.len()
611                    )));
612                }
613
614                let status = unsafe {
615                    sys::$csrmv(
616                        trans_raw(op),
617                        &alpha,
618                        m as sys::aoclsparse_int,
619                        n as sys::aoclsparse_int,
620                        nnz as sys::aoclsparse_int,
621                        csr_val.as_ptr(),
622                        csr_col_ind.as_ptr(),
623                        csr_row_ptr.as_ptr(),
624                        descr.as_raw(),
625                        x.as_ptr(),
626                        &beta,
627                        y.as_mut_ptr(),
628                    )
629                };
630                check_status("sparse", status)
631            }
632
633            fn axpyi(
634                alpha: Self,
635                x: &[Self],
636                indx: &[sys::aoclsparse_int],
637                y: &mut [Self],
638            ) -> Result<()> {
639                let status = unsafe {
640                    sys::$axpyi(
641                        x.len() as sys::aoclsparse_int,
642                        alpha,
643                        x.as_ptr(),
644                        indx.as_ptr(),
645                        y.as_mut_ptr(),
646                    )
647                };
648                check_status("sparse", status)
649            }
650
651            fn gthr(y: &[Self], indx: &[sys::aoclsparse_int], x: &mut [Self]) -> Result<()> {
652                let status = unsafe {
653                    sys::$gthr(
654                        x.len() as sys::aoclsparse_int,
655                        y.as_ptr(),
656                        x.as_mut_ptr(),
657                        indx.as_ptr(),
658                    )
659                };
660                check_status("sparse", status)
661            }
662
663            fn sctr(x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()> {
664                let status = unsafe {
665                    sys::$sctr(
666                        x.len() as sys::aoclsparse_int,
667                        x.as_ptr(),
668                        indx.as_ptr(),
669                        y.as_mut_ptr(),
670                    )
671                };
672                check_status("sparse", status)
673            }
674
675            #[allow(clippy::too_many_arguments)]
676            fn csrsv(
677                op: Trans,
678                alpha: Self,
679                m: usize,
680                csr_val: &[Self],
681                csr_col_ind: &[sys::aoclsparse_int],
682                csr_row_ptr: &[sys::aoclsparse_int],
683                descr: &MatDescr,
684                x: &[Self],
685                y: &mut [Self],
686            ) -> Result<()> {
687                if csr_row_ptr.len() != m + 1 {
688                    return Err(Error::InvalidArgument(format!(
689                        "csrsv: csr_row_ptr length {} != m+1 = {}",
690                        csr_row_ptr.len(),
691                        m + 1
692                    )));
693                }
694                if x.len() < m || y.len() < m {
695                    return Err(Error::InvalidArgument(format!(
696                        "csrsv: x.len()={}, y.len()={}, m={m}",
697                        x.len(),
698                        y.len()
699                    )));
700                }
701                let status = unsafe {
702                    sys::$csrsv(
703                        trans_raw(op),
704                        &alpha,
705                        m as sys::aoclsparse_int,
706                        csr_val.as_ptr(),
707                        csr_col_ind.as_ptr(),
708                        csr_row_ptr.as_ptr(),
709                        descr.as_raw(),
710                        x.as_ptr(),
711                        y.as_mut_ptr(),
712                    )
713                };
714                check_status("sparse", status)
715            }
716
717            #[allow(clippy::too_many_arguments)]
718            fn csr_to_dense(
719                m: usize,
720                n: usize,
721                descr: &MatDescr,
722                csr_val: &[Self],
723                csr_row_ptr: &[sys::aoclsparse_int],
724                csr_col_ind: &[sys::aoclsparse_int],
725                a: &mut [Self],
726                ld: usize,
727                order: Order,
728            ) -> Result<()> {
729                let status = unsafe {
730                    sys::$csr2dense(
731                        m as sys::aoclsparse_int,
732                        n as sys::aoclsparse_int,
733                        descr.as_raw(),
734                        csr_val.as_ptr(),
735                        csr_row_ptr.as_ptr(),
736                        csr_col_ind.as_ptr(),
737                        a.as_mut_ptr(),
738                        ld as sys::aoclsparse_int,
739                        order.raw(),
740                    )
741                };
742                check_status("sparse", status)
743            }
744
745            #[allow(clippy::too_many_arguments)]
746            fn csr_to_csc(
747                m: usize,
748                n: usize,
749                descr: &MatDescr,
750                base_csc: IndexBase,
751                csr_row_ptr: &[sys::aoclsparse_int],
752                csr_col_ind: &[sys::aoclsparse_int],
753                csr_val: &[Self],
754                csc_row_ind: &mut [sys::aoclsparse_int],
755                csc_col_ptr: &mut [sys::aoclsparse_int],
756                csc_val: &mut [Self],
757            ) -> Result<()> {
758                let nnz = csr_val.len();
759                let status = unsafe {
760                    sys::$csr2csc(
761                        m as sys::aoclsparse_int,
762                        n as sys::aoclsparse_int,
763                        nnz as sys::aoclsparse_int,
764                        descr.as_raw(),
765                        base_csc.raw(),
766                        csr_row_ptr.as_ptr(),
767                        csr_col_ind.as_ptr(),
768                        csr_val.as_ptr(),
769                        csc_row_ind.as_mut_ptr(),
770                        csc_col_ptr.as_mut_ptr(),
771                        csc_val.as_mut_ptr(),
772                    )
773                };
774                check_status("sparse", status)
775            }
776
777            #[allow(clippy::too_many_arguments)]
778            fn ellmv(
779                op: Trans,
780                alpha: Self,
781                m: usize,
782                n: usize,
783                ell_val: &[Self],
784                ell_col_ind: &[sys::aoclsparse_int],
785                ell_width: usize,
786                descr: &MatDescr,
787                x: &[Self],
788                beta: Self,
789                y: &mut [Self],
790            ) -> Result<()> {
791                let nnz = ell_val.len();
792                if ell_col_ind.len() != nnz {
793                    return Err(Error::InvalidArgument(format!(
794                        "ellmv: ell_col_ind length {} != ell_val length {nnz}",
795                        ell_col_ind.len()
796                    )));
797                }
798                let needed = m.checked_mul(ell_width).ok_or_else(|| {
799                    Error::InvalidArgument("ellmv: m * ell_width overflows".into())
800                })?;
801                if nnz < needed {
802                    return Err(Error::InvalidArgument(format!(
803                        "ellmv: ell_val length {nnz} < m*ell_width = {needed}"
804                    )));
805                }
806                let (x_len, y_len) = match op {
807                    Trans::No => (n, m),
808                    Trans::T | Trans::C => (m, n),
809                };
810                if x.len() < x_len || y.len() < y_len {
811                    return Err(Error::InvalidArgument(format!(
812                        "ellmv: x.len()={}, y.len()={}, expected ({x_len}, {y_len})",
813                        x.len(),
814                        y.len()
815                    )));
816                }
817                let status = unsafe {
818                    sys::$ellmv(
819                        trans_raw(op),
820                        &alpha,
821                        m as sys::aoclsparse_int,
822                        n as sys::aoclsparse_int,
823                        nnz as sys::aoclsparse_int,
824                        ell_val.as_ptr(),
825                        ell_col_ind.as_ptr(),
826                        ell_width as sys::aoclsparse_int,
827                        descr.as_raw(),
828                        x.as_ptr(),
829                        &beta,
830                        y.as_mut_ptr(),
831                    )
832                };
833                check_status("sparse", status)
834            }
835
836            #[allow(clippy::too_many_arguments)]
837            fn bsrmv(
838                op: Trans,
839                alpha: Self,
840                mb: usize,
841                nb: usize,
842                bsr_dim: usize,
843                bsr_val: &[Self],
844                bsr_col_ind: &[sys::aoclsparse_int],
845                bsr_row_ptr: &[sys::aoclsparse_int],
846                descr: &MatDescr,
847                x: &[Self],
848                beta: Self,
849                y: &mut [Self],
850            ) -> Result<()> {
851                if bsr_row_ptr.len() != mb + 1 {
852                    return Err(Error::InvalidArgument(format!(
853                        "bsrmv: bsr_row_ptr length {} != mb+1 = {}",
854                        bsr_row_ptr.len(),
855                        mb + 1
856                    )));
857                }
858                let block_area = bsr_dim.checked_mul(bsr_dim).ok_or_else(|| {
859                    Error::InvalidArgument("bsrmv: bsr_dim*bsr_dim overflows".into())
860                })?;
861                let nnzb = bsr_col_ind.len();
862                if bsr_val.len() < nnzb * block_area {
863                    return Err(Error::InvalidArgument(format!(
864                        "bsrmv: bsr_val length {} < nnzb*bsr_dim^2 = {}",
865                        bsr_val.len(),
866                        nnzb * block_area
867                    )));
868                }
869                let (x_len, y_len) = match op {
870                    Trans::No => (nb * bsr_dim, mb * bsr_dim),
871                    Trans::T | Trans::C => (mb * bsr_dim, nb * bsr_dim),
872                };
873                if x.len() < x_len || y.len() < y_len {
874                    return Err(Error::InvalidArgument(format!(
875                        "bsrmv: x.len()={}, y.len()={}, expected ({x_len}, {y_len})",
876                        x.len(),
877                        y.len()
878                    )));
879                }
880                let status = unsafe {
881                    sys::$bsrmv(
882                        trans_raw(op),
883                        &alpha,
884                        mb as sys::aoclsparse_int,
885                        nb as sys::aoclsparse_int,
886                        bsr_dim as sys::aoclsparse_int,
887                        bsr_val.as_ptr(),
888                        bsr_col_ind.as_ptr(),
889                        bsr_row_ptr.as_ptr(),
890                        descr.as_raw(),
891                        x.as_ptr(),
892                        &beta,
893                        y.as_mut_ptr(),
894                    )
895                };
896                check_status("sparse", status)
897            }
898
899            fn create_csr(
900                base: IndexBase,
901                m: usize,
902                n: usize,
903                nnz: usize,
904                row_ptr: *mut sys::aoclsparse_int,
905                col_idx: *mut sys::aoclsparse_int,
906                val: *mut Self,
907            ) -> Result<sys::aoclsparse_matrix> {
908                let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
909                let status = unsafe {
910                    sys::$create_csr(
911                        &mut raw,
912                        base.raw(),
913                        m as sys::aoclsparse_int,
914                        n as sys::aoclsparse_int,
915                        nnz as sys::aoclsparse_int,
916                        row_ptr,
917                        col_idx,
918                        val,
919                    )
920                };
921                check_status("sparse", status)?;
922                if raw.is_null() {
923                    return Err(Error::AllocationFailed("sparse"));
924                }
925                Ok(raw)
926            }
927
928            fn export_csr(
929                mat: sys::aoclsparse_matrix,
930            ) -> Result<(
931                IndexBase,
932                usize,
933                usize,
934                usize,
935                *mut sys::aoclsparse_int,
936                *mut sys::aoclsparse_int,
937                *mut Self,
938            )> {
939                let mut base: sys::aoclsparse_index_base = 0;
940                let mut m: sys::aoclsparse_int = 0;
941                let mut n: sys::aoclsparse_int = 0;
942                let mut nnz: sys::aoclsparse_int = 0;
943                let mut row_ptr: *mut sys::aoclsparse_int = std::ptr::null_mut();
944                let mut col_ind: *mut sys::aoclsparse_int = std::ptr::null_mut();
945                let mut val: *mut Self = std::ptr::null_mut();
946                let status = unsafe {
947                    sys::$export_csr(
948                        mat,
949                        &mut base,
950                        &mut m,
951                        &mut n,
952                        &mut nnz,
953                        &mut row_ptr,
954                        &mut col_ind,
955                        &mut val,
956                    )
957                };
958                check_status("sparse", status)?;
959                let base_e = if base == sys::aoclsparse_index_base__aoclsparse_index_base_one {
960                    IndexBase::One
961                } else {
962                    IndexBase::Zero
963                };
964                Ok((
965                    base_e,
966                    m as usize,
967                    n as usize,
968                    nnz as usize,
969                    row_ptr,
970                    col_ind,
971                    val,
972                ))
973            }
974
975            fn ilu_smoother(
976                op: Trans,
977                a: sys::aoclsparse_matrix,
978                descr: &MatDescr,
979                x: &mut [Self],
980                b: &[Self],
981            ) -> Result<()> {
982                let mut precond_csr_val: *mut Self = std::ptr::null_mut();
983                let status = unsafe {
984                    sys::$ilu_smoother(
985                        trans_raw(op),
986                        a,
987                        descr.as_raw(),
988                        &mut precond_csr_val,
989                        std::ptr::null(),
990                        x.as_mut_ptr(),
991                        b.as_ptr(),
992                    )
993                };
994                check_status("sparse", status)
995            }
996
997            fn itsol_init(handle: &mut sys::aoclsparse_itsol_handle) -> Result<()> {
998                let status = unsafe { sys::$itsol_init(handle) };
999                check_status("sparse", status)
1000            }
1001
1002            fn itsol_solve(
1003                handle: sys::aoclsparse_itsol_handle,
1004                n: usize,
1005                mat: sys::aoclsparse_matrix,
1006                descr: &MatDescr,
1007                b: &[Self],
1008                x: &mut [Self],
1009                rinfo: &mut [Self; 100],
1010            ) -> Result<()> {
1011                if b.len() < n || x.len() < n {
1012                    return Err(Error::InvalidArgument(format!(
1013                        "itsol_solve: b.len()={}, x.len()={}, n={n}",
1014                        b.len(),
1015                        x.len()
1016                    )));
1017                }
1018                let status = unsafe {
1019                    sys::$itsol_solve(
1020                        handle,
1021                        n as sys::aoclsparse_int,
1022                        mat,
1023                        descr.as_raw(),
1024                        b.as_ptr(),
1025                        x.as_mut_ptr(),
1026                        rinfo.as_mut_ptr(),
1027                        None,
1028                        None,
1029                        std::ptr::null_mut(),
1030                    )
1031                };
1032                check_status("sparse", status)
1033            }
1034
1035            unsafe fn csr2m_ffi(
1036                op_a: sys::aoclsparse_operation,
1037                descr_a: sys::aoclsparse_mat_descr,
1038                a: sys::aoclsparse_matrix,
1039                op_b: sys::aoclsparse_operation,
1040                descr_b: sys::aoclsparse_mat_descr,
1041                b: sys::aoclsparse_matrix,
1042                request: sys::aoclsparse_request,
1043                out: *mut sys::aoclsparse_matrix,
1044            ) -> sys::aoclsparse_status {
1045                sys::$csr2m(op_a, descr_a, a, op_b, descr_b, b, request, out)
1046            }
1047
1048            #[allow(clippy::too_many_arguments)]
1049            fn csrmm(
1050                op: Trans,
1051                alpha: Self,
1052                a: sys::aoclsparse_matrix,
1053                descr: &MatDescr,
1054                order: Order,
1055                b: &[Self],
1056                n: usize,
1057                ldb: usize,
1058                beta: Self,
1059                c: &mut [Self],
1060                ldc: usize,
1061            ) -> Result<()> {
1062                let status = unsafe {
1063                    sys::$csrmm(
1064                        trans_raw(op),
1065                        alpha,
1066                        a,
1067                        descr.as_raw(),
1068                        order.raw(),
1069                        b.as_ptr(),
1070                        n as sys::aoclsparse_int,
1071                        ldb as sys::aoclsparse_int,
1072                        beta,
1073                        c.as_mut_ptr(),
1074                        ldc as sys::aoclsparse_int,
1075                    )
1076                };
1077                check_status("sparse", status)
1078            }
1079
1080            fn spmmd(
1081                op: Trans,
1082                a: sys::aoclsparse_matrix,
1083                b: sys::aoclsparse_matrix,
1084                layout: Order,
1085                c: &mut [Self],
1086                ldc: usize,
1087            ) -> Result<()> {
1088                let status = unsafe {
1089                    sys::$spmmd(
1090                        trans_raw(op),
1091                        a,
1092                        b,
1093                        layout.raw(),
1094                        c.as_mut_ptr(),
1095                        ldc as sys::aoclsparse_int,
1096                    )
1097                };
1098                check_status("sparse", status)
1099            }
1100
1101            #[allow(clippy::too_many_arguments)]
1102            fn sp2md(
1103                op_a: Trans,
1104                descr_a: &MatDescr,
1105                a: sys::aoclsparse_matrix,
1106                op_b: Trans,
1107                descr_b: &MatDescr,
1108                b: sys::aoclsparse_matrix,
1109                alpha: Self,
1110                beta: Self,
1111                c: &mut [Self],
1112                layout: Order,
1113                ldc: usize,
1114            ) -> Result<()> {
1115                let status = unsafe {
1116                    sys::$sp2md(
1117                        trans_raw(op_a),
1118                        descr_a.as_raw(),
1119                        a,
1120                        trans_raw(op_b),
1121                        descr_b.as_raw(),
1122                        b,
1123                        alpha,
1124                        beta,
1125                        c.as_mut_ptr(),
1126                        layout.raw(),
1127                        ldc as sys::aoclsparse_int,
1128                    )
1129                };
1130                check_status("sparse", status)
1131            }
1132
1133            unsafe fn add_ffi(
1134                op: sys::aoclsparse_operation,
1135                a: sys::aoclsparse_matrix,
1136                alpha: Self,
1137                b: sys::aoclsparse_matrix,
1138                out: *mut sys::aoclsparse_matrix,
1139            ) -> sys::aoclsparse_status {
1140                sys::$add(op, a, alpha, b, out)
1141            }
1142
1143            #[allow(clippy::too_many_arguments)]
1144            fn sorv(
1145                sor_type: SorType,
1146                descr: &MatDescr,
1147                a: sys::aoclsparse_matrix,
1148                omega: Self,
1149                alpha: Self,
1150                x: &mut [Self],
1151                b: &[Self],
1152            ) -> Result<()> {
1153                let status = unsafe {
1154                    sys::$sorv(
1155                        sor_type.raw(),
1156                        descr.as_raw(),
1157                        a,
1158                        omega,
1159                        alpha,
1160                        x.as_mut_ptr(),
1161                        b.as_ptr(),
1162                    )
1163                };
1164                check_status("sparse", status)
1165            }
1166        }
1167    };
1168}
1169
// Instantiate `impl_scalar!` for `f32`, binding each operation name to its
// single-precision (`s`-variant) AOCL-Sparse entry point in `sys`.
impl_scalar!(
    f32,
    csrmv = aoclsparse_scsrmv,
    axpyi = aoclsparse_saxpyi,
    gthr = aoclsparse_sgthr,
    sctr = aoclsparse_ssctr,
    csrsv = aoclsparse_scsrsv,
    csr2dense = aoclsparse_scsr2dense,
    csr2csc = aoclsparse_scsr2csc,
    ellmv = aoclsparse_sellmv,
    bsrmv = aoclsparse_sbsrmv,
    create_csr = aoclsparse_create_scsr,
    export_csr = aoclsparse_export_scsr,
    ilu_smoother = aoclsparse_silu_smoother,
    itsol_init = aoclsparse_itsol_s_init,
    itsol_solve = aoclsparse_itsol_s_solve,
    csr2m = aoclsparse_scsr2m,
    csrmm = aoclsparse_scsrmm,
    spmmd = aoclsparse_sspmmd,
    sp2md = aoclsparse_ssp2md,
    add = aoclsparse_sadd,
    sorv = aoclsparse_ssorv
);
// Instantiate `impl_scalar!` for `f64`, binding each operation name to its
// double-precision (`d`-variant) AOCL-Sparse entry point in `sys`.
impl_scalar!(
    f64,
    csrmv = aoclsparse_dcsrmv,
    axpyi = aoclsparse_daxpyi,
    gthr = aoclsparse_dgthr,
    sctr = aoclsparse_dsctr,
    csrsv = aoclsparse_dcsrsv,
    csr2dense = aoclsparse_dcsr2dense,
    csr2csc = aoclsparse_dcsr2csc,
    ellmv = aoclsparse_dellmv,
    bsrmv = aoclsparse_dbsrmv,
    create_csr = aoclsparse_create_dcsr,
    export_csr = aoclsparse_export_dcsr,
    ilu_smoother = aoclsparse_dilu_smoother,
    itsol_init = aoclsparse_itsol_d_init,
    itsol_solve = aoclsparse_itsol_d_solve,
    csr2m = aoclsparse_dcsr2m,
    csrmm = aoclsparse_dcsrmm,
    spmmd = aoclsparse_dspmmd,
    sp2md = aoclsparse_dsp2md,
    add = aoclsparse_dadd,
    sorv = aoclsparse_dsorv
);
1216
1217/// Compute `y := α · A · x + β · y` for a CSR matrix `A`.
1218#[allow(clippy::too_many_arguments)]
1219pub fn csrmv<T: Scalar>(
1220    alpha: T,
1221    m: usize,
1222    n: usize,
1223    csr_val: &[T],
1224    csr_col_ind: &[sys::aoclsparse_int],
1225    csr_row_ptr: &[sys::aoclsparse_int],
1226    descr: &MatDescr,
1227    x: &[T],
1228    beta: T,
1229    y: &mut [T],
1230) -> Result<()> {
1231    T::csrmv(
1232        Trans::No,
1233        alpha,
1234        m,
1235        n,
1236        csr_val,
1237        csr_col_ind,
1238        csr_row_ptr,
1239        descr,
1240        x,
1241        beta,
1242        y,
1243    )
1244}
1245
1246// =========================================================================
1247//   Sparse vector operations (axpyi, gather/scatter)
1248// =========================================================================
1249
1250/// Sparse `y[indx] := α·x + y[indx]` (sparse vector AXPY).
1251///
1252/// `x` and `indx` must have equal length (`nnz`); each `indx[i]` indexes
1253/// into `y`.
1254pub fn axpyi<T: Scalar>(
1255    alpha: T,
1256    x: &[T],
1257    indx: &[sys::aoclsparse_int],
1258    y: &mut [T],
1259) -> Result<()> {
1260    if x.len() != indx.len() {
1261        return Err(Error::InvalidArgument(format!(
1262            "axpyi: x.len()={}, indx.len()={}",
1263            x.len(),
1264            indx.len()
1265        )));
1266    }
1267    T::axpyi(alpha, x, indx, y)
1268}
1269
1270/// Sparse gather: `x[i] := y[indx[i]]` for `i ∈ [0, nnz)`.
1271pub fn gthr<T: Scalar>(y: &[T], indx: &[sys::aoclsparse_int], x: &mut [T]) -> Result<()> {
1272    if x.len() != indx.len() {
1273        return Err(Error::InvalidArgument(format!(
1274            "gthr: x.len()={}, indx.len()={}",
1275            x.len(),
1276            indx.len()
1277        )));
1278    }
1279    T::gthr(y, indx, x)
1280}
1281
1282/// Sparse scatter: `y[indx[i]] := x[i]` for `i ∈ [0, nnz)`.
1283pub fn sctr<T: Scalar>(x: &[T], indx: &[sys::aoclsparse_int], y: &mut [T]) -> Result<()> {
1284    if x.len() != indx.len() {
1285        return Err(Error::InvalidArgument(format!(
1286            "sctr: x.len()={}, indx.len()={}",
1287            x.len(),
1288            indx.len()
1289        )));
1290    }
1291    T::sctr(x, indx, y)
1292}
1293
1294// =========================================================================
1295//   Sparse triangular solve (csrsv)
1296// =========================================================================
1297
1298/// Solve `op(A) · y = α · x` where `A` is sparse triangular in CSR
1299/// format. The triangle is determined by the `descr`'s fill mode (set
1300/// via the `aoclsparse_set_mat_*` C-API; defaults to upper, non-unit).
1301#[allow(clippy::too_many_arguments)]
1302pub fn csrsv<T: Scalar>(
1303    op: Trans,
1304    alpha: T,
1305    m: usize,
1306    csr_val: &[T],
1307    csr_col_ind: &[sys::aoclsparse_int],
1308    csr_row_ptr: &[sys::aoclsparse_int],
1309    descr: &MatDescr,
1310    x: &[T],
1311    y: &mut [T],
1312) -> Result<()> {
1313    T::csrsv(op, alpha, m, csr_val, csr_col_ind, csr_row_ptr, descr, x, y)
1314}
1315
1316// =========================================================================
1317//   Format conversion: csr → dense, csr → csc
1318// =========================================================================
1319
/// Memory layout of a dense matrix operand, e.g. the output of a CSR → dense
/// conversion.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Order {
    /// Row-major layout (maps to `aoclsparse_order_row`).
    RowMajor,
    /// Column-major layout (maps to `aoclsparse_order_column`).
    ColMajor,
}
1326
1327impl Order {
1328    pub(crate) fn raw(self) -> sys::aoclsparse_order {
1329        match self {
1330            Order::RowMajor => sys::aoclsparse_order__aoclsparse_order_row,
1331            Order::ColMajor => sys::aoclsparse_order__aoclsparse_order_column,
1332        }
1333    }
1334}
1335
1336/// Convert a CSR sparse matrix to a dense `m × n` matrix in `out`.
1337#[allow(clippy::too_many_arguments)]
1338pub fn csr_to_dense<T: Scalar>(
1339    m: usize,
1340    n: usize,
1341    descr: &MatDescr,
1342    csr_val: &[T],
1343    csr_row_ptr: &[sys::aoclsparse_int],
1344    csr_col_ind: &[sys::aoclsparse_int],
1345    a: &mut [T],
1346    ld: usize,
1347    order: Order,
1348) -> Result<()> {
1349    if csr_row_ptr.len() != m + 1 {
1350        return Err(Error::InvalidArgument(format!(
1351            "csr_to_dense: csr_row_ptr length {} != m+1 = {}",
1352            csr_row_ptr.len(),
1353            m + 1
1354        )));
1355    }
1356    let needed = match order {
1357        Order::RowMajor => m.saturating_sub(1) * ld + n,
1358        Order::ColMajor => n.saturating_sub(1) * ld + m,
1359    };
1360    if a.len() < needed {
1361        return Err(Error::InvalidArgument(format!(
1362            "csr_to_dense: A length {} < needed {needed}",
1363            a.len()
1364        )));
1365    }
1366    T::csr_to_dense(m, n, descr, csr_val, csr_row_ptr, csr_col_ind, a, ld, order)
1367}
1368
1369/// Convert a CSR matrix to CSC. Output arrays must be pre-sized.
1370#[allow(clippy::too_many_arguments)]
1371pub fn csr_to_csc<T: Scalar>(
1372    m: usize,
1373    n: usize,
1374    descr: &MatDescr,
1375    base_csc: IndexBase,
1376    csr_row_ptr: &[sys::aoclsparse_int],
1377    csr_col_ind: &[sys::aoclsparse_int],
1378    csr_val: &[T],
1379    csc_row_ind: &mut [sys::aoclsparse_int],
1380    csc_col_ptr: &mut [sys::aoclsparse_int],
1381    csc_val: &mut [T],
1382) -> Result<()> {
1383    let nnz = csr_val.len();
1384    if csr_col_ind.len() != nnz || csc_row_ind.len() < nnz || csc_val.len() < nnz {
1385        return Err(Error::InvalidArgument(format!(
1386            "csr_to_csc: nnz mismatch (csr_val={}, csr_col_ind={}, csc_row_ind={}, csc_val={})",
1387            nnz,
1388            csr_col_ind.len(),
1389            csc_row_ind.len(),
1390            csc_val.len()
1391        )));
1392    }
1393    if csr_row_ptr.len() != m + 1 || csc_col_ptr.len() != n + 1 {
1394        return Err(Error::InvalidArgument(format!(
1395            "csr_to_csc: row_ptr length {} != m+1 = {} or col_ptr length {} != n+1 = {}",
1396            csr_row_ptr.len(),
1397            m + 1,
1398            csc_col_ptr.len(),
1399            n + 1
1400        )));
1401    }
1402    T::csr_to_csc(
1403        m,
1404        n,
1405        descr,
1406        base_csc,
1407        csr_row_ptr,
1408        csr_col_ind,
1409        csr_val,
1410        csc_row_ind,
1411        csc_col_ptr,
1412        csc_val,
1413    )
1414}
1415
/// Whether the index arrays of a sparse format start counting at 0 or at 1.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IndexBase {
    /// Indices start at 0 (C convention).
    Zero,
    /// Indices start at 1 (Fortran convention).
    One,
}
1424
1425impl IndexBase {
1426    fn raw(self) -> sys::aoclsparse_index_base {
1427        match self {
1428            IndexBase::Zero => sys::aoclsparse_index_base__aoclsparse_index_base_zero,
1429            IndexBase::One => sys::aoclsparse_index_base__aoclsparse_index_base_one,
1430        }
1431    }
1432}
1433
1434// =========================================================================
1435//   ELLPACK and BSR mat-vec
1436// =========================================================================
1437
1438/// Compute `y := α · op(A) · x + β · y` for an ELLPACK-format `A`.
1439///
1440/// `ell_val` and `ell_col_ind` each have length `m * ell_width`, where rows
1441/// shorter than `ell_width` are padded.
1442///
1443/// AOCL only implements `op = Trans::No` for ELLPACK at present.
1444#[allow(clippy::too_many_arguments)]
1445pub fn ellmv<T: Scalar>(
1446    op: Trans,
1447    alpha: T,
1448    m: usize,
1449    n: usize,
1450    ell_val: &[T],
1451    ell_col_ind: &[sys::aoclsparse_int],
1452    ell_width: usize,
1453    descr: &MatDescr,
1454    x: &[T],
1455    beta: T,
1456    y: &mut [T],
1457) -> Result<()> {
1458    T::ellmv(
1459        op,
1460        alpha,
1461        m,
1462        n,
1463        ell_val,
1464        ell_col_ind,
1465        ell_width,
1466        descr,
1467        x,
1468        beta,
1469        y,
1470    )
1471}
1472
1473/// Compute `y := α · op(A) · x + β · y` for a BSR-format `A`.
1474///
1475/// `mb` and `nb` count blocks (so `A` is `mb·bsr_dim × nb·bsr_dim`).
1476/// `bsr_val` is laid out as `nnzb` consecutive `bsr_dim × bsr_dim` blocks.
1477///
1478/// AOCL only implements `op = Trans::No` for BSR at present.
1479#[allow(clippy::too_many_arguments)]
1480pub fn bsrmv<T: Scalar>(
1481    op: Trans,
1482    alpha: T,
1483    mb: usize,
1484    nb: usize,
1485    bsr_dim: usize,
1486    bsr_val: &[T],
1487    bsr_col_ind: &[sys::aoclsparse_int],
1488    bsr_row_ptr: &[sys::aoclsparse_int],
1489    descr: &MatDescr,
1490    x: &[T],
1491    beta: T,
1492    y: &mut [T],
1493) -> Result<()> {
1494    T::bsrmv(
1495        op,
1496        alpha,
1497        mb,
1498        nb,
1499        bsr_dim,
1500        bsr_val,
1501        bsr_col_ind,
1502        bsr_row_ptr,
1503        descr,
1504        x,
1505        beta,
1506        y,
1507    )
1508}
1509
1510// =========================================================================
1511//   High-level matrix handle (aoclsparse_matrix)
1512// =========================================================================
1513
/// Which pass of a multi-pass sparse-sparse matrix product to run (see
/// [`csr2m`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Stage {
    /// Analysis pass: determine the non-zero count without writing values.
    NnzCount,
    /// Value pass: compute the values after an earlier `NnzCount` call.
    Finalize,
    /// Do the whole computation in a single call.
    FullComputation,
}
1524
1525impl Stage {
1526    fn raw(self) -> sys::aoclsparse_request {
1527        match self {
1528            Stage::NnzCount => sys::aoclsparse_request__aoclsparse_stage_nnz_count,
1529            Stage::Finalize => sys::aoclsparse_request__aoclsparse_stage_finalize,
1530            Stage::FullComputation => sys::aoclsparse_request__aoclsparse_stage_full_computation,
1531        }
1532    }
1533}
1534
1535enum CsrStorage<T: Scalar> {
1536    /// Arrays are owned by the Rust side (we keep them alive until drop).
1537    Owned {
1538        _row_ptr: Vec<sys::aoclsparse_int>,
1539        _col_ind: Vec<sys::aoclsparse_int>,
1540        _val: Vec<T>,
1541    },
1542    /// Arrays are owned by the library; only the handle needs destroying.
1543    LibraryOwned,
1544}
1545
1546/// RAII wrapper around an `aoclsparse_matrix` handle holding a CSR matrix.
1547///
1548/// Construct from raw CSR vectors with [`SparseMatrix::from_csr`] (the
1549/// values are copied into the wrapper) or as the result of an operation
1550/// like [`csr2m`].
1551pub struct SparseMatrix<T: Scalar> {
1552    raw: sys::aoclsparse_matrix,
1553    #[allow(dead_code)] // kept alive so the library can keep its pointers
1554    storage: CsrStorage<T>,
1555    base: IndexBase,
1556    m: usize,
1557    n: usize,
1558    nnz: usize,
1559}
1560
1561impl<T: Scalar> SparseMatrix<T> {
1562    /// Build a new matrix handle from CSR arrays. The arrays are copied
1563    /// into the wrapper; the caller's slices are not retained.
1564    pub fn from_csr(
1565        base: IndexBase,
1566        m: usize,
1567        n: usize,
1568        row_ptr: &[sys::aoclsparse_int],
1569        col_ind: &[sys::aoclsparse_int],
1570        val: &[T],
1571    ) -> Result<Self> {
1572        if row_ptr.len() != m + 1 {
1573            return Err(Error::InvalidArgument(format!(
1574                "from_csr: row_ptr length {} != m+1 = {}",
1575                row_ptr.len(),
1576                m + 1
1577            )));
1578        }
1579        let nnz = val.len();
1580        if col_ind.len() != nnz {
1581            return Err(Error::InvalidArgument(format!(
1582                "from_csr: col_ind length {} != val length {nnz}",
1583                col_ind.len()
1584            )));
1585        }
1586        let mut row_ptr = row_ptr.to_vec();
1587        let mut col_ind = col_ind.to_vec();
1588        let mut val = val.to_vec();
1589        let raw = T::create_csr(
1590            base,
1591            m,
1592            n,
1593            nnz,
1594            row_ptr.as_mut_ptr(),
1595            col_ind.as_mut_ptr(),
1596            val.as_mut_ptr(),
1597        )?;
1598        Ok(Self {
1599            raw,
1600            storage: CsrStorage::Owned {
1601                _row_ptr: row_ptr,
1602                _col_ind: col_ind,
1603                _val: val,
1604            },
1605            base,
1606            m,
1607            n,
1608            nnz,
1609        })
1610    }
1611
1612    /// Adopt a handle returned by an AOCL routine that allocates its own
1613    /// CSR storage (e.g. `aoclsparse_dcsr2m`). The library will free the
1614    /// arrays when the handle is destroyed.
1615    ///
1616    /// # Safety
1617    /// `raw` must be a valid `aoclsparse_matrix` whose internal storage is
1618    /// owned by the AOCL library and whose precision matches `T`.
1619    pub unsafe fn from_library_owned(raw: sys::aoclsparse_matrix) -> Result<Self> {
1620        if raw.is_null() {
1621            return Err(Error::AllocationFailed("sparse"));
1622        }
1623        let (base, m, n, nnz, _, _, _) = T::export_csr(raw)?;
1624        Ok(Self {
1625            raw,
1626            storage: CsrStorage::LibraryOwned,
1627            base,
1628            m,
1629            n,
1630            nnz,
1631        })
1632    }
1633
1634    /// `(m, n)` dimensions of the matrix.
1635    pub fn dims(&self) -> (usize, usize) {
1636        (self.m, self.n)
1637    }
1638
1639    /// Number of explicitly stored non-zeros.
1640    pub fn nnz(&self) -> usize {
1641        self.nnz
1642    }
1643
1644    /// Index base used by this matrix's row-pointer and column-index arrays.
1645    pub fn base(&self) -> IndexBase {
1646        self.base
1647    }
1648
1649    /// Borrow the raw underlying handle for raw FFI calls.
1650    ///
1651    /// # Safety
1652    /// The returned pointer is valid only for the lifetime of `self`.
1653    /// Do not call `aoclsparse_destroy` on it.
1654    pub fn as_raw(&self) -> sys::aoclsparse_matrix {
1655        self.raw
1656    }
1657
1658    /// Read out the CSR contents as freshly allocated `Vec`s.
1659    pub fn export_csr(
1660        &self,
1661    ) -> Result<(
1662        IndexBase,
1663        Vec<sys::aoclsparse_int>,
1664        Vec<sys::aoclsparse_int>,
1665        Vec<T>,
1666    )> {
1667        let (base, m, _, nnz, row_ptr, col_ind, val) = T::export_csr(self.raw)?;
1668        let row_ptr = unsafe { std::slice::from_raw_parts(row_ptr, m + 1).to_vec() };
1669        let col_ind = unsafe { std::slice::from_raw_parts(col_ind, nnz).to_vec() };
1670        let val = unsafe { std::slice::from_raw_parts(val, nnz).to_vec() };
1671        Ok((base, row_ptr, col_ind, val))
1672    }
1673
1674    // --- Hints: tell the analysis pass how the matrix will be used.
1675    // Pair with `optimize()` to actually run the analysis.
1676
1677    /// Hint that mat-vec (`mv`) will be called `expected_calls` times in
1678    /// the given orientation `op`.
1679    pub fn set_mv_hint(
1680        &mut self,
1681        op: Trans,
1682        descr: &MatDescr,
1683        expected_calls: usize,
1684    ) -> Result<()> {
1685        let status = unsafe {
1686            sys::aoclsparse_set_mv_hint(
1687                self.raw,
1688                trans_raw(op),
1689                descr.as_raw(),
1690                expected_calls as sys::aoclsparse_int,
1691            )
1692        };
1693        check_status("sparse", status)
1694    }
1695
1696    /// Hint for triangular solve (`trsv`).
1697    pub fn set_sv_hint(
1698        &mut self,
1699        op: Trans,
1700        descr: &MatDescr,
1701        expected_calls: usize,
1702    ) -> Result<()> {
1703        let status = unsafe {
1704            sys::aoclsparse_set_sv_hint(
1705                self.raw,
1706                trans_raw(op),
1707                descr.as_raw(),
1708                expected_calls as sys::aoclsparse_int,
1709            )
1710        };
1711        check_status("sparse", status)
1712    }
1713
1714    /// Hint for sparse-dense matrix multiply (`csrmm`).
1715    pub fn set_mm_hint(
1716        &mut self,
1717        op: Trans,
1718        descr: &MatDescr,
1719        expected_calls: usize,
1720    ) -> Result<()> {
1721        let status = unsafe {
1722            sys::aoclsparse_set_mm_hint(
1723                self.raw,
1724                trans_raw(op),
1725                descr.as_raw(),
1726                expected_calls as sys::aoclsparse_int,
1727            )
1728        };
1729        check_status("sparse", status)
1730    }
1731
1732    /// Hint for sparse-sparse matrix multiply (`csr2m`).
1733    pub fn set_2m_hint(
1734        &mut self,
1735        op: Trans,
1736        descr: &MatDescr,
1737        expected_calls: usize,
1738    ) -> Result<()> {
1739        let status = unsafe {
1740            sys::aoclsparse_set_2m_hint(
1741                self.raw,
1742                trans_raw(op),
1743                descr.as_raw(),
1744                expected_calls as sys::aoclsparse_int,
1745            )
1746        };
1747        check_status("sparse", status)
1748    }
1749
1750    /// Hint for sparse triangular solve with multiple right-hand sides
1751    /// (`trsm`). `order` is the layout of the dense `B`/`X` matrices.
1752    pub fn set_sm_hint(
1753        &mut self,
1754        op: Trans,
1755        descr: &MatDescr,
1756        order: Order,
1757        expected_calls: usize,
1758    ) -> Result<()> {
1759        let status = unsafe {
1760            sys::aoclsparse_set_sm_hint(
1761                self.raw,
1762                trans_raw(op),
1763                descr.as_raw(),
1764                order.raw(),
1765                expected_calls as sys::aoclsparse_int,
1766            )
1767        };
1768        check_status("sparse", status)
1769    }
1770
1771    /// Hint for ILU smoothing.
1772    pub fn set_lu_smoother_hint(
1773        &mut self,
1774        op: Trans,
1775        descr: &MatDescr,
1776        expected_calls: usize,
1777    ) -> Result<()> {
1778        let status = unsafe {
1779            sys::aoclsparse_set_lu_smoother_hint(
1780                self.raw,
1781                trans_raw(op),
1782                descr.as_raw(),
1783                expected_calls as sys::aoclsparse_int,
1784            )
1785        };
1786        check_status("sparse", status)
1787    }
1788
1789    /// Hint for symmetric Gauss-Seidel.
1790    pub fn set_symgs_hint(
1791        &mut self,
1792        op: Trans,
1793        descr: &MatDescr,
1794        expected_calls: usize,
1795    ) -> Result<()> {
1796        let status = unsafe {
1797            sys::aoclsparse_set_symgs_hint(
1798                self.raw,
1799                trans_raw(op),
1800                descr.as_raw(),
1801                expected_calls as sys::aoclsparse_int,
1802            )
1803        };
1804        check_status("sparse", status)
1805    }
1806
1807    /// Hint for the fused dot-mat-vec routine.
1808    pub fn set_dotmv_hint(
1809        &mut self,
1810        op: Trans,
1811        descr: &MatDescr,
1812        expected_calls: usize,
1813    ) -> Result<()> {
1814        let status = unsafe {
1815            sys::aoclsparse_set_dotmv_hint(
1816                self.raw,
1817                trans_raw(op),
1818                descr.as_raw(),
1819                expected_calls as sys::aoclsparse_int,
1820            )
1821        };
1822        check_status("sparse", status)
1823    }
1824
1825    /// Hint for SOR / Gauss-Seidel sweeps.
1826    pub fn set_sorv_hint(
1827        &mut self,
1828        sor_type: SorType,
1829        descr: &MatDescr,
1830        expected_calls: usize,
1831    ) -> Result<()> {
1832        let status = unsafe {
1833            sys::aoclsparse_set_sorv_hint(
1834                self.raw,
1835                descr.as_raw(),
1836                sor_type.raw(),
1837                expected_calls as sys::aoclsparse_int,
1838            )
1839        };
1840        check_status("sparse", status)
1841    }
1842}
1843
1844impl<T: Scalar> Drop for SparseMatrix<T> {
1845    fn drop(&mut self) {
1846        if !self.raw.is_null() {
1847            unsafe {
1848                let _ = sys::aoclsparse_destroy(&mut self.raw);
1849            }
1850            self.raw = std::ptr::null_mut();
1851        }
1852    }
1853}
1854
1855impl<T: Scalar> std::fmt::Debug for SparseMatrix<T> {
1856    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1857        f.debug_struct("SparseMatrix")
1858            .field("m", &self.m)
1859            .field("n", &self.n)
1860            .field("nnz", &self.nnz)
1861            .field("base", &self.base)
1862            .finish()
1863    }
1864}
1865
1866// =========================================================================
1867//   Sparse-sparse matrix product (csr2m)
1868// =========================================================================
1869
1870/// Compute `C := op_A(A) · op_B(B)` between two CSR matrices, returning a
1871/// freshly allocated CSR result. Both inputs must have matching scalar
1872/// precision; complex precisions are not yet exposed by these wrappers.
1873///
1874/// Use `Stage::FullComputation` for a one-shot call. The two-stage form
1875/// (`NnzCount` followed by `Finalize`) is for repeated multiplies sharing
1876/// the same sparsity pattern.
1877#[allow(clippy::too_many_arguments)]
1878pub fn csr2m<T: Scalar>(
1879    op_a: Trans,
1880    descr_a: &MatDescr,
1881    a: &SparseMatrix<T>,
1882    op_b: Trans,
1883    descr_b: &MatDescr,
1884    b: &SparseMatrix<T>,
1885    stage: Stage,
1886) -> Result<SparseMatrix<T>> {
1887    let mut c_raw: sys::aoclsparse_matrix = std::ptr::null_mut();
1888    let status = unsafe {
1889        // We need a per-precision dispatch. Encode by calling the matching
1890        // FFI symbol via a small helper trait.
1891        T::csr2m_ffi(
1892            trans_raw(op_a),
1893            descr_a.as_raw(),
1894            a.raw,
1895            trans_raw(op_b),
1896            descr_b.as_raw(),
1897            b.raw,
1898            stage.raw(),
1899            &mut c_raw,
1900        )
1901    };
1902    check_status("sparse", status)?;
1903    unsafe { SparseMatrix::from_library_owned(c_raw) }
1904}
1905
1906// =========================================================================
1907//   Sparse-dense matrix products (csrmm, spmmd, sp2md)
1908// =========================================================================
1909
1910/// Compute `C := α · op(A) · B + β · C` where `A` is sparse (CSR via the
1911/// matrix handle) and `B`, `C` are dense matrices laid out per `order`.
1912///
1913/// `n` is the number of columns of `B` (and `C`). `ldb` and `ldc` are
1914/// the leading dimensions of `B` and `C` respectively.
1915#[allow(clippy::too_many_arguments)]
1916pub fn csrmm<T: Scalar>(
1917    op: Trans,
1918    alpha: T,
1919    a: &SparseMatrix<T>,
1920    descr: &MatDescr,
1921    order: Order,
1922    b: &[T],
1923    n: usize,
1924    ldb: usize,
1925    beta: T,
1926    c: &mut [T],
1927    ldc: usize,
1928) -> Result<()> {
1929    T::csrmm(op, alpha, a.as_raw(), descr, order, b, n, ldb, beta, c, ldc)
1930}
1931
1932/// Compute `C := op(A) · B` where both `A` and `B` are sparse and `C` is
1933/// a dense matrix. `ldc` is the leading dimension of `C` per `layout`.
1934pub fn spmmd<T: Scalar>(
1935    op: Trans,
1936    a: &SparseMatrix<T>,
1937    b: &SparseMatrix<T>,
1938    layout: Order,
1939    c: &mut [T],
1940    ldc: usize,
1941) -> Result<()> {
1942    T::spmmd(op, a.as_raw(), b.as_raw(), layout, c, ldc)
1943}
1944
1945/// Compute `C := α · op_A(A) · op_B(B) + β · C` where both `A` and `B`
1946/// are sparse and `C` is dense.
1947#[allow(clippy::too_many_arguments)]
1948pub fn sp2md<T: Scalar>(
1949    op_a: Trans,
1950    descr_a: &MatDescr,
1951    a: &SparseMatrix<T>,
1952    op_b: Trans,
1953    descr_b: &MatDescr,
1954    b: &SparseMatrix<T>,
1955    alpha: T,
1956    beta: T,
1957    c: &mut [T],
1958    layout: Order,
1959    ldc: usize,
1960) -> Result<()> {
1961    T::sp2md(
1962        op_a,
1963        descr_a,
1964        a.as_raw(),
1965        op_b,
1966        descr_b,
1967        b.as_raw(),
1968        alpha,
1969        beta,
1970        c,
1971        layout,
1972        ldc,
1973    )
1974}
1975
1976// =========================================================================
1977//   Sparse-sparse addition and SOR sweep
1978// =========================================================================
1979
1980/// Compute `C := α · op(A) + B` returning a fresh sparse CSR matrix.
1981pub fn add<T: Scalar>(
1982    op: Trans,
1983    a: &SparseMatrix<T>,
1984    alpha: T,
1985    b: &SparseMatrix<T>,
1986) -> Result<SparseMatrix<T>> {
1987    let mut c_raw: sys::aoclsparse_matrix = std::ptr::null_mut();
1988    let status = unsafe { T::add_ffi(trans_raw(op), a.as_raw(), alpha, b.as_raw(), &mut c_raw) };
1989    check_status("sparse", status)?;
1990    unsafe { SparseMatrix::from_library_owned(c_raw) }
1991}
1992
1993/// One step of (S)SOR / forward / backward Gauss-Seidel relaxation:
1994/// `x_new := (1 − ω) x + ω · A⁻¹ (α · b)`. Useful as a smoother
1995/// inside iterative-solver loops.
1996pub fn sorv<T: Scalar>(
1997    sor_type: SorType,
1998    descr: &MatDescr,
1999    a: &SparseMatrix<T>,
2000    omega: T,
2001    alpha: T,
2002    x: &mut [T],
2003    b: &[T],
2004) -> Result<()> {
2005    if x.len() < a.dims().1 || b.len() < a.dims().0 {
2006        return Err(Error::InvalidArgument(format!(
2007            "sorv: x.len()={}, b.len()={}, dims=({}, {})",
2008            x.len(),
2009            b.len(),
2010            a.dims().0,
2011            a.dims().1
2012        )));
2013    }
2014    T::sorv(sor_type, descr, a.as_raw(), omega, alpha, x, b)
2015}
2016
2017// =========================================================================
2018//   ILU smoother
2019// =========================================================================
2020
2021// =========================================================================
2022//   Real high-level mat-vec / triangular solve / triangular multi-RHS solve
2023// =========================================================================
2024
2025/// Real high-level mat-vec `y := α · op(A) · x + β · y` against a
2026/// `SparseMatrix<f64>` handle (per-precision; complex variants live in
2027/// the [`complex`] submodule).
2028#[allow(clippy::too_many_arguments)]
2029pub fn mv_f64(
2030    op: Trans,
2031    alpha: f64,
2032    a: &SparseMatrix<f64>,
2033    descr: &MatDescr,
2034    x: &[f64],
2035    beta: f64,
2036    y: &mut [f64],
2037) -> Result<()> {
2038    let status = unsafe {
2039        sys::aoclsparse_dmv(
2040            trans_raw(op),
2041            &alpha,
2042            a.as_raw(),
2043            descr.as_raw(),
2044            x.as_ptr(),
2045            &beta,
2046            y.as_mut_ptr(),
2047        )
2048    };
2049    check_status("sparse", status)
2050}
2051
2052/// `f32` mat-vec via the matrix-handle interface. See [`mv_f64`].
2053#[allow(clippy::too_many_arguments)]
2054pub fn mv_f32(
2055    op: Trans,
2056    alpha: f32,
2057    a: &SparseMatrix<f32>,
2058    descr: &MatDescr,
2059    x: &[f32],
2060    beta: f32,
2061    y: &mut [f32],
2062) -> Result<()> {
2063    let status = unsafe {
2064        sys::aoclsparse_smv(
2065            trans_raw(op),
2066            &alpha,
2067            a.as_raw(),
2068            descr.as_raw(),
2069            x.as_ptr(),
2070            &beta,
2071            y.as_mut_ptr(),
2072        )
2073    };
2074    check_status("sparse", status)
2075}
2076
2077/// Solve `op(A) · x = α · b` for sparse triangular `A` against a
2078/// `SparseMatrix<f64>` handle. The fill mode and unit/non-unit diag
2079/// must be set on `descr` first.
2080pub fn trsv_f64(
2081    op: Trans,
2082    alpha: f64,
2083    a: &SparseMatrix<f64>,
2084    descr: &MatDescr,
2085    b: &[f64],
2086    x: &mut [f64],
2087) -> Result<()> {
2088    let status = unsafe {
2089        sys::aoclsparse_dtrsv(
2090            trans_raw(op),
2091            alpha,
2092            a.as_raw(),
2093            descr.as_raw(),
2094            b.as_ptr(),
2095            x.as_mut_ptr(),
2096        )
2097    };
2098    check_status("sparse", status)
2099}
2100
2101/// `f32` triangular solve. See [`trsv_f64`].
2102pub fn trsv_f32(
2103    op: Trans,
2104    alpha: f32,
2105    a: &SparseMatrix<f32>,
2106    descr: &MatDescr,
2107    b: &[f32],
2108    x: &mut [f32],
2109) -> Result<()> {
2110    let status = unsafe {
2111        sys::aoclsparse_strsv(
2112            trans_raw(op),
2113            alpha,
2114            a.as_raw(),
2115            descr.as_raw(),
2116            b.as_ptr(),
2117            x.as_mut_ptr(),
2118        )
2119    };
2120    check_status("sparse", status)
2121}
2122
2123/// Multi-RHS triangular solve: `op(A) · X = α · B` for `n_rhs`
2124/// right-hand sides, `B`/`X` dense `m × n_rhs`.
2125#[allow(clippy::too_many_arguments)]
2126pub fn trsm_f64(
2127    op: Trans,
2128    alpha: f64,
2129    a: &SparseMatrix<f64>,
2130    descr: &MatDescr,
2131    order: Order,
2132    b: &[f64],
2133    n_rhs: usize,
2134    ldb: usize,
2135    x: &mut [f64],
2136    ldx: usize,
2137) -> Result<()> {
2138    let status = unsafe {
2139        sys::aoclsparse_dtrsm(
2140            trans_raw(op),
2141            alpha,
2142            a.as_raw(),
2143            descr.as_raw(),
2144            order.raw(),
2145            b.as_ptr(),
2146            n_rhs as sys::aoclsparse_int,
2147            ldb as sys::aoclsparse_int,
2148            x.as_mut_ptr(),
2149            ldx as sys::aoclsparse_int,
2150        )
2151    };
2152    check_status("sparse", status)
2153}
2154
2155/// `f32` multi-RHS triangular solve. See [`trsm_f64`].
2156#[allow(clippy::too_many_arguments)]
2157pub fn trsm_f32(
2158    op: Trans,
2159    alpha: f32,
2160    a: &SparseMatrix<f32>,
2161    descr: &MatDescr,
2162    order: Order,
2163    b: &[f32],
2164    n_rhs: usize,
2165    ldb: usize,
2166    x: &mut [f32],
2167    ldx: usize,
2168) -> Result<()> {
2169    let status = unsafe {
2170        sys::aoclsparse_strsm(
2171            trans_raw(op),
2172            alpha,
2173            a.as_raw(),
2174            descr.as_raw(),
2175            order.raw(),
2176            b.as_ptr(),
2177            n_rhs as sys::aoclsparse_int,
2178            ldb as sys::aoclsparse_int,
2179            x.as_mut_ptr(),
2180            ldx as sys::aoclsparse_int,
2181        )
2182    };
2183    check_status("sparse", status)
2184}
2185
2186/// Sparse vector dot product: `Σᵢ x[i] · y[indx[i]]`. `x` and `indx`
2187/// must have the same length (the nnz count).
2188pub fn doti_f64(x: &[f64], indx: &[sys::aoclsparse_int], y: &[f64]) -> Result<f64> {
2189    if x.len() != indx.len() {
2190        return Err(Error::InvalidArgument(format!(
2191            "doti: x.len()={} != indx.len()={}",
2192            x.len(),
2193            indx.len()
2194        )));
2195    }
2196    let r = unsafe {
2197        sys::aoclsparse_ddoti(
2198            x.len() as sys::aoclsparse_int,
2199            x.as_ptr(),
2200            indx.as_ptr(),
2201            y.as_ptr(),
2202        )
2203    };
2204    Ok(r)
2205}
2206
2207/// `f32` sparse dot. See [`doti_f64`].
2208pub fn doti_f32(x: &[f32], indx: &[sys::aoclsparse_int], y: &[f32]) -> Result<f32> {
2209    if x.len() != indx.len() {
2210        return Err(Error::InvalidArgument(format!(
2211            "doti: x.len()={} != indx.len()={}",
2212            x.len(),
2213            indx.len()
2214        )));
2215    }
2216    let r = unsafe {
2217        sys::aoclsparse_sdoti(
2218            x.len() as sys::aoclsparse_int,
2219            x.as_ptr(),
2220            indx.as_ptr(),
2221            y.as_ptr(),
2222        )
2223    };
2224    Ok(r)
2225}
2226
2227// =========================================================================
2228//   CSR ↔ ELL / DIA / BSR conversions, COO / CSC creators, value mutators,
2229//   block-CSR mat-vec, symmetric Gauss-Seidel
2230// =========================================================================
2231
2232/// Convert CSR → ELLPACK. `ell_width` is the chosen padded row length;
2233/// caller must allocate `ell_col_ind` and `ell_val` of size
2234/// `m * ell_width`. (`f64`)
2235#[allow(clippy::too_many_arguments)]
2236pub fn csr2ell_f64(
2237    m: usize,
2238    descr: &MatDescr,
2239    csr_row_ptr: &[sys::aoclsparse_int],
2240    csr_col_ind: &[sys::aoclsparse_int],
2241    csr_val: &[f64],
2242    ell_col_ind: &mut [sys::aoclsparse_int],
2243    ell_val: &mut [f64],
2244    ell_width: usize,
2245) -> Result<()> {
2246    let status = unsafe {
2247        sys::aoclsparse_dcsr2ell(
2248            m as sys::aoclsparse_int,
2249            descr.as_raw(),
2250            csr_row_ptr.as_ptr(),
2251            csr_col_ind.as_ptr(),
2252            csr_val.as_ptr(),
2253            ell_col_ind.as_mut_ptr(),
2254            ell_val.as_mut_ptr(),
2255            ell_width as sys::aoclsparse_int,
2256        )
2257    };
2258    check_status("sparse", status)
2259}
2260
2261/// `f32` CSR → ELLPACK. See [`csr2ell_f64`].
2262#[allow(clippy::too_many_arguments)]
2263pub fn csr2ell_f32(
2264    m: usize,
2265    descr: &MatDescr,
2266    csr_row_ptr: &[sys::aoclsparse_int],
2267    csr_col_ind: &[sys::aoclsparse_int],
2268    csr_val: &[f32],
2269    ell_col_ind: &mut [sys::aoclsparse_int],
2270    ell_val: &mut [f32],
2271    ell_width: usize,
2272) -> Result<()> {
2273    let status = unsafe {
2274        sys::aoclsparse_scsr2ell(
2275            m as sys::aoclsparse_int,
2276            descr.as_raw(),
2277            csr_row_ptr.as_ptr(),
2278            csr_col_ind.as_ptr(),
2279            csr_val.as_ptr(),
2280            ell_col_ind.as_mut_ptr(),
2281            ell_val.as_mut_ptr(),
2282            ell_width as sys::aoclsparse_int,
2283        )
2284    };
2285    check_status("sparse", status)
2286}
2287
2288/// Convert CSR → DIA (diagonal storage). `dia_offset` and `dia_val`
2289/// must be pre-allocated; `dia_num_diag` is the number of distinct
2290/// diagonals. (`f64`)
2291#[allow(clippy::too_many_arguments)]
2292pub fn csr2dia_f64(
2293    m: usize,
2294    n: usize,
2295    descr: &MatDescr,
2296    csr_row_ptr: &[sys::aoclsparse_int],
2297    csr_col_ind: &[sys::aoclsparse_int],
2298    csr_val: &[f64],
2299    dia_num_diag: usize,
2300    dia_offset: &mut [sys::aoclsparse_int],
2301    dia_val: &mut [f64],
2302) -> Result<()> {
2303    let status = unsafe {
2304        sys::aoclsparse_dcsr2dia(
2305            m as sys::aoclsparse_int,
2306            n as sys::aoclsparse_int,
2307            descr.as_raw(),
2308            csr_row_ptr.as_ptr(),
2309            csr_col_ind.as_ptr(),
2310            csr_val.as_ptr(),
2311            dia_num_diag as sys::aoclsparse_int,
2312            dia_offset.as_mut_ptr(),
2313            dia_val.as_mut_ptr(),
2314        )
2315    };
2316    check_status("sparse", status)
2317}
2318
2319/// `f32` CSR → DIA. See [`csr2dia_f64`].
2320#[allow(clippy::too_many_arguments)]
2321pub fn csr2dia_f32(
2322    m: usize,
2323    n: usize,
2324    descr: &MatDescr,
2325    csr_row_ptr: &[sys::aoclsparse_int],
2326    csr_col_ind: &[sys::aoclsparse_int],
2327    csr_val: &[f32],
2328    dia_num_diag: usize,
2329    dia_offset: &mut [sys::aoclsparse_int],
2330    dia_val: &mut [f32],
2331) -> Result<()> {
2332    let status = unsafe {
2333        sys::aoclsparse_scsr2dia(
2334            m as sys::aoclsparse_int,
2335            n as sys::aoclsparse_int,
2336            descr.as_raw(),
2337            csr_row_ptr.as_ptr(),
2338            csr_col_ind.as_ptr(),
2339            csr_val.as_ptr(),
2340            dia_num_diag as sys::aoclsparse_int,
2341            dia_offset.as_mut_ptr(),
2342            dia_val.as_mut_ptr(),
2343        )
2344    };
2345    check_status("sparse", status)
2346}
2347
2348/// Compute the number of nonzero blocks (`bsr_nnz`) and `bsr_row_ptr`
2349/// for a CSR → BSR conversion with the given `block_dim`.
2350pub fn csr2bsr_nnz(
2351    m: usize,
2352    n: usize,
2353    descr: &MatDescr,
2354    csr_row_ptr: &[sys::aoclsparse_int],
2355    csr_col_ind: &[sys::aoclsparse_int],
2356    block_dim: usize,
2357    bsr_row_ptr: &mut [sys::aoclsparse_int],
2358) -> Result<usize> {
2359    let mut bsr_nnz: sys::aoclsparse_int = 0;
2360    let status = unsafe {
2361        sys::aoclsparse_csr2bsr_nnz(
2362            m as sys::aoclsparse_int,
2363            n as sys::aoclsparse_int,
2364            descr.as_raw(),
2365            csr_row_ptr.as_ptr(),
2366            csr_col_ind.as_ptr(),
2367            block_dim as sys::aoclsparse_int,
2368            bsr_row_ptr.as_mut_ptr(),
2369            &mut bsr_nnz,
2370        )
2371    };
2372    check_status("sparse", status)?;
2373    Ok(bsr_nnz as usize)
2374}
2375
2376/// Convert CSR → BSR. Run [`csr2bsr_nnz`] first to size the output. (`f64`)
2377#[allow(clippy::too_many_arguments)]
2378pub fn csr2bsr_f64(
2379    m: usize,
2380    n: usize,
2381    descr: &MatDescr,
2382    csr_val: &[f64],
2383    csr_row_ptr: &[sys::aoclsparse_int],
2384    csr_col_ind: &[sys::aoclsparse_int],
2385    block_dim: usize,
2386    bsr_val: &mut [f64],
2387    bsr_row_ptr: &mut [sys::aoclsparse_int],
2388    bsr_col_ind: &mut [sys::aoclsparse_int],
2389) -> Result<()> {
2390    let status = unsafe {
2391        sys::aoclsparse_dcsr2bsr(
2392            m as sys::aoclsparse_int,
2393            n as sys::aoclsparse_int,
2394            descr.as_raw(),
2395            csr_val.as_ptr(),
2396            csr_row_ptr.as_ptr(),
2397            csr_col_ind.as_ptr(),
2398            block_dim as sys::aoclsparse_int,
2399            bsr_val.as_mut_ptr(),
2400            bsr_row_ptr.as_mut_ptr(),
2401            bsr_col_ind.as_mut_ptr(),
2402        )
2403    };
2404    check_status("sparse", status)
2405}
2406
2407/// `f32` CSR → BSR. See [`csr2bsr_f64`].
2408#[allow(clippy::too_many_arguments)]
2409pub fn csr2bsr_f32(
2410    m: usize,
2411    n: usize,
2412    descr: &MatDescr,
2413    csr_val: &[f32],
2414    csr_row_ptr: &[sys::aoclsparse_int],
2415    csr_col_ind: &[sys::aoclsparse_int],
2416    block_dim: usize,
2417    bsr_val: &mut [f32],
2418    bsr_row_ptr: &mut [sys::aoclsparse_int],
2419    bsr_col_ind: &mut [sys::aoclsparse_int],
2420) -> Result<()> {
2421    let status = unsafe {
2422        sys::aoclsparse_scsr2bsr(
2423            m as sys::aoclsparse_int,
2424            n as sys::aoclsparse_int,
2425            descr.as_raw(),
2426            csr_val.as_ptr(),
2427            csr_row_ptr.as_ptr(),
2428            csr_col_ind.as_ptr(),
2429            block_dim as sys::aoclsparse_int,
2430            bsr_val.as_mut_ptr(),
2431            bsr_row_ptr.as_mut_ptr(),
2432            bsr_col_ind.as_mut_ptr(),
2433        )
2434    };
2435    check_status("sparse", status)
2436}
2437
2438/// Block-CSR mat-vec with a per-block bitmask. `masks[i]` is a u8 bit
2439/// pattern indicating which rows of block `i` participate. `n_rows_blk`
2440/// is the block height. (`f64`)
2441#[allow(clippy::too_many_arguments)]
2442pub fn blkcsrmv_f64(
2443    op: Trans,
2444    alpha: f64,
2445    m: usize,
2446    n: usize,
2447    masks: &[u8],
2448    csr_val: &[f64],
2449    csr_col_ind: &[sys::aoclsparse_int],
2450    csr_row_ptr: &[sys::aoclsparse_int],
2451    descr: &MatDescr,
2452    x: &[f64],
2453    beta: f64,
2454    y: &mut [f64],
2455    n_rows_blk: usize,
2456) -> Result<()> {
2457    let nnz = csr_val.len();
2458    let status = unsafe {
2459        sys::aoclsparse_dblkcsrmv(
2460            trans_raw(op),
2461            &alpha,
2462            m as sys::aoclsparse_int,
2463            n as sys::aoclsparse_int,
2464            nnz as sys::aoclsparse_int,
2465            masks.as_ptr(),
2466            csr_val.as_ptr(),
2467            csr_col_ind.as_ptr(),
2468            csr_row_ptr.as_ptr(),
2469            descr.as_raw(),
2470            x.as_ptr(),
2471            &beta,
2472            y.as_mut_ptr(),
2473            n_rows_blk as sys::aoclsparse_int,
2474        )
2475    };
2476    check_status("sparse", status)
2477}
2478
2479/// Symmetric Gauss-Seidel sweep `x := x + α · D⁻¹ (b − A·x)` against a
2480/// symmetric matrix handle. (`f64`)
2481pub fn symgs_f64(
2482    op: Trans,
2483    a: &SparseMatrix<f64>,
2484    descr: &MatDescr,
2485    alpha: f64,
2486    b: &[f64],
2487    x: &mut [f64],
2488) -> Result<()> {
2489    let status = unsafe {
2490        sys::aoclsparse_dsymgs(
2491            trans_raw(op),
2492            a.as_raw(),
2493            descr.as_raw(),
2494            alpha,
2495            b.as_ptr(),
2496            x.as_mut_ptr(),
2497        )
2498    };
2499    check_status("sparse", status)
2500}
2501
2502/// `f32` symmetric Gauss-Seidel. See [`symgs_f64`].
2503pub fn symgs_f32(
2504    op: Trans,
2505    a: &SparseMatrix<f32>,
2506    descr: &MatDescr,
2507    alpha: f32,
2508    b: &[f32],
2509    x: &mut [f32],
2510) -> Result<()> {
2511    let status = unsafe {
2512        sys::aoclsparse_ssymgs(
2513            trans_raw(op),
2514            a.as_raw(),
2515            descr.as_raw(),
2516            alpha,
2517            b.as_ptr(),
2518            x.as_mut_ptr(),
2519        )
2520    };
2521    check_status("sparse", status)
2522}
2523
2524/// Fused Gauss-Seidel + matrix-vector: produces both the updated `x`
2525/// and `y := A·x` in one pass. (`f64`)
2526#[allow(clippy::too_many_arguments)]
2527pub fn symgs_mv_f64(
2528    op: Trans,
2529    a: &SparseMatrix<f64>,
2530    descr: &MatDescr,
2531    alpha: f64,
2532    b: &[f64],
2533    x: &mut [f64],
2534    y: &mut [f64],
2535) -> Result<()> {
2536    let status = unsafe {
2537        sys::aoclsparse_dsymgs_mv(
2538            trans_raw(op),
2539            a.as_raw(),
2540            descr.as_raw(),
2541            alpha,
2542            b.as_ptr(),
2543            x.as_mut_ptr(),
2544            y.as_mut_ptr(),
2545        )
2546    };
2547    check_status("sparse", status)
2548}
2549
2550/// `f32` fused symmetric Gauss-Seidel + mat-vec. See [`symgs_mv_f64`].
2551#[allow(clippy::too_many_arguments)]
2552pub fn symgs_mv_f32(
2553    op: Trans,
2554    a: &SparseMatrix<f32>,
2555    descr: &MatDescr,
2556    alpha: f32,
2557    b: &[f32],
2558    x: &mut [f32],
2559    y: &mut [f32],
2560) -> Result<()> {
2561    let status = unsafe {
2562        sys::aoclsparse_ssymgs_mv(
2563            trans_raw(op),
2564            a.as_raw(),
2565            descr.as_raw(),
2566            alpha,
2567            b.as_ptr(),
2568            x.as_mut_ptr(),
2569            y.as_mut_ptr(),
2570        )
2571    };
2572    check_status("sparse", status)
2573}
2574
2575/// Set a single value at `(row_idx, col_idx)` in a CSR matrix handle.
2576/// Element must already exist in the sparsity pattern. (`f64`)
2577pub fn set_value_f64(
2578    a: &mut SparseMatrix<f64>,
2579    row_idx: i32,
2580    col_idx: i32,
2581    val: f64,
2582) -> Result<()> {
2583    let status = unsafe {
2584        sys::aoclsparse_dset_value(
2585            a.as_raw(),
2586            row_idx as sys::aoclsparse_int,
2587            col_idx as sys::aoclsparse_int,
2588            val,
2589        )
2590    };
2591    check_status("sparse", status)
2592}
2593
2594/// `f32` set_value. See [`set_value_f64`].
2595pub fn set_value_f32(
2596    a: &mut SparseMatrix<f32>,
2597    row_idx: i32,
2598    col_idx: i32,
2599    val: f32,
2600) -> Result<()> {
2601    let status = unsafe {
2602        sys::aoclsparse_sset_value(
2603            a.as_raw(),
2604            row_idx as sys::aoclsparse_int,
2605            col_idx as sys::aoclsparse_int,
2606            val,
2607        )
2608    };
2609    check_status("sparse", status)
2610}
2611
2612/// Update the values array of a CSR handle without rebuilding.
2613/// `len = nnz`; the new values must respect the existing sparsity
2614/// pattern. (`f64`)
2615pub fn update_values_f64(a: &mut SparseMatrix<f64>, val: &mut [f64]) -> Result<()> {
2616    let status = unsafe {
2617        sys::aoclsparse_dupdate_values(
2618            a.as_raw(),
2619            val.len() as sys::aoclsparse_int,
2620            val.as_mut_ptr(),
2621        )
2622    };
2623    check_status("sparse", status)
2624}
2625
2626/// `f32` update_values. See [`update_values_f64`].
2627pub fn update_values_f32(a: &mut SparseMatrix<f32>, val: &mut [f32]) -> Result<()> {
2628    let status = unsafe {
2629        sys::aoclsparse_supdate_values(
2630            a.as_raw(),
2631            val.len() as sys::aoclsparse_int,
2632            val.as_mut_ptr(),
2633        )
2634    };
2635    check_status("sparse", status)
2636}
2637
2638// =========================================================================
2639//   Fused dot + mat-vec, sparse rotations, strided gather/scatter,
2640//   symmetric rank-k, symmetric triple product
2641// =========================================================================
2642
2643/// Fused dot product and mat-vec: `y := α · op(A) · x + β · y` and
2644/// `d := xᵀ · y` (or `d := xᴴ · y` for complex). Saves one pass over
2645/// the matrix versus calling `mv` then `dot` separately. (`f64`)
2646#[allow(clippy::too_many_arguments)]
2647pub fn dotmv_f64(
2648    op: Trans,
2649    alpha: f64,
2650    a: &SparseMatrix<f64>,
2651    descr: &MatDescr,
2652    x: &[f64],
2653    beta: f64,
2654    y: &mut [f64],
2655    d: &mut f64,
2656) -> Result<()> {
2657    let status = unsafe {
2658        sys::aoclsparse_ddotmv(
2659            trans_raw(op),
2660            alpha,
2661            a.as_raw(),
2662            descr.as_raw(),
2663            x.as_ptr(),
2664            beta,
2665            y.as_mut_ptr(),
2666            d,
2667        )
2668    };
2669    check_status("sparse", status)
2670}
2671
2672/// `f32` fused dot + mat-vec. See [`dotmv_f64`].
2673#[allow(clippy::too_many_arguments)]
2674pub fn dotmv_f32(
2675    op: Trans,
2676    alpha: f32,
2677    a: &SparseMatrix<f32>,
2678    descr: &MatDescr,
2679    x: &[f32],
2680    beta: f32,
2681    y: &mut [f32],
2682    d: &mut f32,
2683) -> Result<()> {
2684    let status = unsafe {
2685        sys::aoclsparse_sdotmv(
2686            trans_raw(op),
2687            alpha,
2688            a.as_raw(),
2689            descr.as_raw(),
2690            x.as_ptr(),
2691            beta,
2692            y.as_mut_ptr(),
2693            d,
2694        )
2695    };
2696    check_status("sparse", status)
2697}
2698
2699/// Symmetric rank-k update from a sparse matrix:
2700/// `C := α · op(A) · op(A)ᵀ + β · C` (real) /
2701/// `C := α · op(A) · op(A)ᴴ + β · C` (complex). `C` is dense. (`f64`)
2702#[allow(clippy::too_many_arguments)]
2703pub fn syrkd_f64(
2704    op_a: Trans,
2705    a: &SparseMatrix<f64>,
2706    alpha: f64,
2707    beta: f64,
2708    c: &mut [f64],
2709    order_c: Order,
2710    ldc: usize,
2711) -> Result<()> {
2712    let status = unsafe {
2713        sys::aoclsparse_dsyrkd(
2714            trans_raw(op_a),
2715            a.as_raw(),
2716            alpha,
2717            beta,
2718            c.as_mut_ptr(),
2719            order_c.raw(),
2720            ldc as sys::aoclsparse_int,
2721        )
2722    };
2723    check_status("sparse", status)
2724}
2725
2726/// `f32` symmetric rank-k update from sparse `A`. See [`syrkd_f64`].
2727#[allow(clippy::too_many_arguments)]
2728pub fn syrkd_f32(
2729    op_a: Trans,
2730    a: &SparseMatrix<f32>,
2731    alpha: f32,
2732    beta: f32,
2733    c: &mut [f32],
2734    order_c: Order,
2735    ldc: usize,
2736) -> Result<()> {
2737    let status = unsafe {
2738        sys::aoclsparse_ssyrkd(
2739            trans_raw(op_a),
2740            a.as_raw(),
2741            alpha,
2742            beta,
2743            c.as_mut_ptr(),
2744            order_c.raw(),
2745            ldc as sys::aoclsparse_int,
2746        )
2747    };
2748    check_status("sparse", status)
2749}
2750
2751/// Symmetric triple product from a sparse `A` and dense `B`:
2752/// `C := α · op(A) · B · op(A)ᵀ + β · C`. (`f64`)
2753#[allow(clippy::too_many_arguments)]
2754pub fn syprd_f64(
2755    op_a: Trans,
2756    a: &SparseMatrix<f64>,
2757    b: &[f64],
2758    order_b: Order,
2759    ldb: usize,
2760    alpha: f64,
2761    beta: f64,
2762    c: &mut [f64],
2763    order_c: Order,
2764    ldc: usize,
2765) -> Result<()> {
2766    let status = unsafe {
2767        sys::aoclsparse_dsyprd(
2768            trans_raw(op_a),
2769            a.as_raw(),
2770            b.as_ptr(),
2771            order_b.raw(),
2772            ldb as sys::aoclsparse_int,
2773            alpha,
2774            beta,
2775            c.as_mut_ptr(),
2776            order_c.raw(),
2777            ldc as sys::aoclsparse_int,
2778        )
2779    };
2780    check_status("sparse", status)
2781}
2782
2783/// `f32` symmetric triple product. See [`syprd_f64`].
2784#[allow(clippy::too_many_arguments)]
2785pub fn syprd_f32(
2786    op_a: Trans,
2787    a: &SparseMatrix<f32>,
2788    b: &[f32],
2789    order_b: Order,
2790    ldb: usize,
2791    alpha: f32,
2792    beta: f32,
2793    c: &mut [f32],
2794    order_c: Order,
2795    ldc: usize,
2796) -> Result<()> {
2797    let status = unsafe {
2798        sys::aoclsparse_ssyprd(
2799            trans_raw(op_a),
2800            a.as_raw(),
2801            b.as_ptr(),
2802            order_b.raw(),
2803            ldb as sys::aoclsparse_int,
2804            alpha,
2805            beta,
2806            c.as_mut_ptr(),
2807            order_c.raw(),
2808            ldc as sys::aoclsparse_int,
2809        )
2810    };
2811    check_status("sparse", status)
2812}
2813
2814/// Sparse rotation: simultaneously rotate `(x, y)` at indices `indx`
2815/// by `(c, s)`. `x` is sparse with `indx` indices into the dense `y`. (`f64`)
2816pub fn roti_f64(
2817    x: &mut [f64],
2818    indx: &[sys::aoclsparse_int],
2819    y: &mut [f64],
2820    c: f64,
2821    s: f64,
2822) -> Result<()> {
2823    let status = unsafe {
2824        sys::aoclsparse_droti(
2825            x.len() as sys::aoclsparse_int,
2826            x.as_mut_ptr(),
2827            indx.as_ptr(),
2828            y.as_mut_ptr(),
2829            c,
2830            s,
2831        )
2832    };
2833    check_status("sparse", status)
2834}
2835
2836/// `f32` sparse rotation. See [`roti_f64`].
2837pub fn roti_f32(
2838    x: &mut [f32],
2839    indx: &[sys::aoclsparse_int],
2840    y: &mut [f32],
2841    c: f32,
2842    s: f32,
2843) -> Result<()> {
2844    let status = unsafe {
2845        sys::aoclsparse_sroti(
2846            x.len() as sys::aoclsparse_int,
2847            x.as_mut_ptr(),
2848            indx.as_ptr(),
2849            y.as_mut_ptr(),
2850            c,
2851            s,
2852        )
2853    };
2854    check_status("sparse", status)
2855}
2856
2857/// Strided sparse gather: `x[i] := y[i · stride]` for `i ∈ [0, nnz)`. (`f64`)
2858pub fn gthrs_f64(y: &[f64], x: &mut [f64], stride: i32) -> Result<()> {
2859    let status = unsafe {
2860        sys::aoclsparse_dgthrs(
2861            x.len() as sys::aoclsparse_int,
2862            y.as_ptr(),
2863            x.as_mut_ptr(),
2864            stride as sys::aoclsparse_int,
2865        )
2866    };
2867    check_status("sparse", status)
2868}
2869/// `f32` strided gather. See [`gthrs_f64`].
2870pub fn gthrs_f32(y: &[f32], x: &mut [f32], stride: i32) -> Result<()> {
2871    let status = unsafe {
2872        sys::aoclsparse_sgthrs(
2873            x.len() as sys::aoclsparse_int,
2874            y.as_ptr(),
2875            x.as_mut_ptr(),
2876            stride as sys::aoclsparse_int,
2877        )
2878    };
2879    check_status("sparse", status)
2880}
2881
2882/// Strided sparse scatter: `y[i · stride] := x[i]`. (`f64`)
2883pub fn sctrs_f64(x: &[f64], y: &mut [f64], stride: i32) -> Result<()> {
2884    let status = unsafe {
2885        sys::aoclsparse_dsctrs(
2886            x.len() as sys::aoclsparse_int,
2887            x.as_ptr(),
2888            stride as sys::aoclsparse_int,
2889            y.as_mut_ptr(),
2890        )
2891    };
2892    check_status("sparse", status)
2893}
2894/// `f32` strided scatter. See [`sctrs_f64`].
2895pub fn sctrs_f32(x: &[f32], y: &mut [f32], stride: i32) -> Result<()> {
2896    let status = unsafe {
2897        sys::aoclsparse_ssctrs(
2898            x.len() as sys::aoclsparse_int,
2899            x.as_ptr(),
2900            stride as sys::aoclsparse_int,
2901            y.as_mut_ptr(),
2902        )
2903    };
2904    check_status("sparse", status)
2905}
2906
2907/// Gather-and-zero: copy `y[indx[i]]` into `x[i]` and zero the source.
2908/// (`f64`)
2909pub fn gthrz_f64(y: &mut [f64], indx: &[sys::aoclsparse_int], x: &mut [f64]) -> Result<()> {
2910    let status = unsafe {
2911        sys::aoclsparse_dgthrz(
2912            x.len() as sys::aoclsparse_int,
2913            y.as_mut_ptr(),
2914            x.as_mut_ptr(),
2915            indx.as_ptr(),
2916        )
2917    };
2918    check_status("sparse", status)
2919}
2920/// `f32` gather-and-zero. See [`gthrz_f64`].
2921pub fn gthrz_f32(y: &mut [f32], indx: &[sys::aoclsparse_int], x: &mut [f32]) -> Result<()> {
2922    let status = unsafe {
2923        sys::aoclsparse_sgthrz(
2924            x.len() as sys::aoclsparse_int,
2925            y.as_mut_ptr(),
2926            x.as_mut_ptr(),
2927            indx.as_ptr(),
2928        )
2929    };
2930    check_status("sparse", status)
2931}
2932
2933/// `f32` SOR / forward / backward Gauss-Seidel sweep. The existing
2934/// generic [`sorv`] dispatches to this via the Scalar trait; this is
2935/// a per-precision direct alias.
2936#[allow(clippy::too_many_arguments)]
2937pub fn sorv_f32(
2938    sor_type: SorType,
2939    descr: &MatDescr,
2940    a: &SparseMatrix<f32>,
2941    omega: f32,
2942    alpha: f32,
2943    x: &mut [f32],
2944    b: &[f32],
2945) -> Result<()> {
2946    let status = unsafe {
2947        sys::aoclsparse_ssorv(
2948            sor_type.raw(),
2949            descr.as_raw(),
2950            a.as_raw(),
2951            omega,
2952            alpha,
2953            x.as_mut_ptr(),
2954            b.as_ptr(),
2955        )
2956    };
2957    check_status("sparse", status)
2958}
2959
2960// =========================================================================
2961//   TCSR (triangular CSR): two parallel CSRs, one for L, one for U
2962// =========================================================================
2963
2964/// Build a TCSR matrix handle from two CSR sub-arrays describing the
2965/// strict-lower (`*_L`) and strict-upper (`*_U`) parts. (`f64`)
2966#[allow(clippy::too_many_arguments)]
2967pub fn create_tcsr_f64(
2968    base: IndexBase,
2969    m: usize,
2970    n: usize,
2971    nnz: usize,
2972    row_ptr_l: &mut [sys::aoclsparse_int],
2973    row_ptr_u: &mut [sys::aoclsparse_int],
2974    col_idx_l: &mut [sys::aoclsparse_int],
2975    col_idx_u: &mut [sys::aoclsparse_int],
2976    val_l: &mut [f64],
2977    val_u: &mut [f64],
2978) -> Result<sys::aoclsparse_matrix> {
2979    let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
2980    let status = unsafe {
2981        sys::aoclsparse_create_dtcsr(
2982            &mut raw,
2983            base.raw(),
2984            m as sys::aoclsparse_int,
2985            n as sys::aoclsparse_int,
2986            nnz as sys::aoclsparse_int,
2987            row_ptr_l.as_mut_ptr(),
2988            row_ptr_u.as_mut_ptr(),
2989            col_idx_l.as_mut_ptr(),
2990            col_idx_u.as_mut_ptr(),
2991            val_l.as_mut_ptr(),
2992            val_u.as_mut_ptr(),
2993        )
2994    };
2995    check_status("sparse", status)?;
2996    if raw.is_null() {
2997        return Err(Error::AllocationFailed("sparse"));
2998    }
2999    Ok(raw)
3000}
3001
3002/// `f32` TCSR creator. See [`create_tcsr_f64`].
3003#[allow(clippy::too_many_arguments)]
3004pub fn create_tcsr_f32(
3005    base: IndexBase,
3006    m: usize,
3007    n: usize,
3008    nnz: usize,
3009    row_ptr_l: &mut [sys::aoclsparse_int],
3010    row_ptr_u: &mut [sys::aoclsparse_int],
3011    col_idx_l: &mut [sys::aoclsparse_int],
3012    col_idx_u: &mut [sys::aoclsparse_int],
3013    val_l: &mut [f32],
3014    val_u: &mut [f32],
3015) -> Result<sys::aoclsparse_matrix> {
3016    let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
3017    let status = unsafe {
3018        sys::aoclsparse_create_stcsr(
3019            &mut raw,
3020            base.raw(),
3021            m as sys::aoclsparse_int,
3022            n as sys::aoclsparse_int,
3023            nnz as sys::aoclsparse_int,
3024            row_ptr_l.as_mut_ptr(),
3025            row_ptr_u.as_mut_ptr(),
3026            col_idx_l.as_mut_ptr(),
3027            col_idx_u.as_mut_ptr(),
3028            val_l.as_mut_ptr(),
3029            val_u.as_mut_ptr(),
3030        )
3031    };
3032    check_status("sparse", status)?;
3033    if raw.is_null() {
3034        return Err(Error::AllocationFailed("sparse"));
3035    }
3036    Ok(raw)
3037}
3038
3039// =========================================================================
3040//   CSC creators (real)
3041// =========================================================================
3042
3043/// Build a CSC matrix handle. Column-major analogue of CSR. (`f64`)
3044#[allow(clippy::too_many_arguments)]
3045pub fn create_csc_f64(
3046    base: IndexBase,
3047    m: usize,
3048    n: usize,
3049    nnz: usize,
3050    col_ptr: &mut [sys::aoclsparse_int],
3051    row_idx: &mut [sys::aoclsparse_int],
3052    val: &mut [f64],
3053) -> Result<sys::aoclsparse_matrix> {
3054    let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
3055    let status = unsafe {
3056        sys::aoclsparse_create_dcsc(
3057            &mut raw,
3058            base.raw(),
3059            m as sys::aoclsparse_int,
3060            n as sys::aoclsparse_int,
3061            nnz as sys::aoclsparse_int,
3062            col_ptr.as_mut_ptr(),
3063            row_idx.as_mut_ptr(),
3064            val.as_mut_ptr(),
3065        )
3066    };
3067    check_status("sparse", status)?;
3068    if raw.is_null() {
3069        return Err(Error::AllocationFailed("sparse"));
3070    }
3071    Ok(raw)
3072}
3073/// `f32` CSC creator. See [`create_csc_f64`].
3074#[allow(clippy::too_many_arguments)]
3075pub fn create_csc_f32(
3076    base: IndexBase,
3077    m: usize,
3078    n: usize,
3079    nnz: usize,
3080    col_ptr: &mut [sys::aoclsparse_int],
3081    row_idx: &mut [sys::aoclsparse_int],
3082    val: &mut [f32],
3083) -> Result<sys::aoclsparse_matrix> {
3084    let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
3085    let status = unsafe {
3086        sys::aoclsparse_create_scsc(
3087            &mut raw,
3088            base.raw(),
3089            m as sys::aoclsparse_int,
3090            n as sys::aoclsparse_int,
3091            nnz as sys::aoclsparse_int,
3092            col_ptr.as_mut_ptr(),
3093            row_idx.as_mut_ptr(),
3094            val.as_mut_ptr(),
3095        )
3096    };
3097    check_status("sparse", status)?;
3098    if raw.is_null() {
3099        return Err(Error::AllocationFailed("sparse"));
3100    }
3101    Ok(raw)
3102}
3103
/// How much additional memory the library may spend on a matrix; passed
/// to the analysis pass as a hint.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MemoryUsage {
    /// Keep extra allocations to a minimum, preferring small working sets.
    Minimal,
    /// Permit extra memory use when the library can trade it for speed.
    Unrestricted,
}
3112
3113impl MemoryUsage {
3114    fn raw(self) -> sys::aoclsparse_memory_usage {
3115        match self {
3116            MemoryUsage::Minimal => sys::aoclsparse_memory_usage__aoclsparse_memory_usage_minimal,
3117            MemoryUsage::Unrestricted => {
3118                sys::aoclsparse_memory_usage__aoclsparse_memory_usage_unrestricted
3119            }
3120        }
3121    }
3122}
3123
3124/// Tell the library how much extra memory it may consume on this
3125/// matrix. Pair with [`optimize`].
3126pub fn set_memory_hint<T: Scalar>(mat: &mut SparseMatrix<T>, policy: MemoryUsage) -> Result<()> {
3127    let status = unsafe { sys::aoclsparse_set_memory_hint(mat.as_raw(), policy.raw()) };
3128    check_status("sparse", status)
3129}
3130
3131/// Apply one ILU(0) smoothing step in-place to `x`, with right-hand side `b`.
3132///
3133/// On the first call against a matrix this also factorizes; subsequent
3134/// calls re-use the cached factors stored on the matrix handle.
3135pub fn ilu_smoother<T: Scalar>(
3136    op: Trans,
3137    a: &SparseMatrix<T>,
3138    descr: &MatDescr,
3139    x: &mut [T],
3140    b: &[T],
3141) -> Result<()> {
3142    if x.len() < a.n || b.len() < a.m {
3143        return Err(Error::InvalidArgument(format!(
3144            "ilu_smoother: x.len()={}, b.len()={}, dims=({}, {})",
3145            x.len(),
3146            b.len(),
3147            a.m,
3148            a.n
3149        )));
3150    }
3151    T::ilu_smoother(op, a.raw, descr, x, b)
3152}
3153
3154// =========================================================================
3155//   Iterative solver (CG / GMRES) — direct interface
3156// =========================================================================
3157
/// RAII handle for the AOCL-Sparse iterative-solver suite.
///
/// Configure the solver type and tolerances with [`IterSolver::set_option`]
/// (e.g. `set_option("iterative method", "cg")` or `"gmres"`), then call
/// [`IterSolver::solve`] with the system matrix and right-hand side.
///
/// The wrapped handle is created in [`IterSolver::new`] and released by
/// the `Drop` implementation via `aoclsparse_itsol_destroy`.
pub struct IterSolver<T: Scalar> {
    // Raw solver handle owned by this value; non-null once `new` succeeds.
    handle: sys::aoclsparse_itsol_handle,
    // Ties the handle to the scalar type it was initialised for without
    // storing a `T`.
    _marker: PhantomData<T>,
}
3167
3168impl<T: Scalar> IterSolver<T> {
3169    /// Initialise a new iterative-solver handle for this scalar type.
3170    pub fn new() -> Result<Self> {
3171        let mut handle: sys::aoclsparse_itsol_handle = std::ptr::null_mut();
3172        T::itsol_init(&mut handle)?;
3173        if handle.is_null() {
3174            return Err(Error::AllocationFailed("sparse"));
3175        }
3176        Ok(Self {
3177            handle,
3178            _marker: PhantomData,
3179        })
3180    }
3181
3182    /// Set a string-keyed solver option. See AOCL-Sparse's
3183    /// `aoclsparse_itsol_option_set` for the full option list. Common
3184    /// keys: `"iterative method"` (`"cg"`/`"gmres"`/`"pcg"`), `"cg
3185    /// iteration limit"`, `"cg rel tolerance"`, `"gmres preconditioner"`.
3186    pub fn set_option(&mut self, name: &str, value: &str) -> Result<()> {
3187        let c_name = CString::new(name)
3188            .map_err(|_| Error::InvalidArgument("set_option: name has interior NUL".into()))?;
3189        let c_value = CString::new(value)
3190            .map_err(|_| Error::InvalidArgument("set_option: value has interior NUL".into()))?;
3191        let status = unsafe {
3192            sys::aoclsparse_itsol_option_set(self.handle, c_name.as_ptr(), c_value.as_ptr())
3193        };
3194        check_status("sparse", status)
3195    }
3196
3197    /// Solve `A · x = b`. On entry `x` should hold an initial guess (zero
3198    /// is fine if you have nothing better); on success it contains the
3199    /// approximate solution. Returns the solver's `rinfo[100]` array of
3200    /// statistics (iteration counts, residual norms, etc.).
3201    pub fn solve(
3202        &mut self,
3203        mat: &SparseMatrix<T>,
3204        descr: &MatDescr,
3205        b: &[T],
3206        x: &mut [T],
3207    ) -> Result<Box<[T; 100]>>
3208    where
3209        T: Default,
3210    {
3211        let n = mat.n;
3212        if mat.m != mat.n {
3213            return Err(Error::InvalidArgument(format!(
3214                "iterative solve requires square matrix; got ({}, {})",
3215                mat.m, mat.n
3216            )));
3217        }
3218        let mut rinfo: Box<[T; 100]> = Box::new([T::default(); 100]);
3219        T::itsol_solve(self.handle, n, mat.raw, descr, b, x, &mut rinfo)?;
3220        Ok(rinfo)
3221    }
3222}
3223
3224impl<T: Scalar> Drop for IterSolver<T> {
3225    fn drop(&mut self) {
3226        if !self.handle.is_null() {
3227            unsafe {
3228                sys::aoclsparse_itsol_destroy(&mut self.handle);
3229            }
3230            self.handle = std::ptr::null_mut();
3231        }
3232    }
3233}
3234
3235impl<T: Scalar> std::fmt::Debug for IterSolver<T> {
3236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3237        f.debug_struct("IterSolver").finish_non_exhaustive()
3238    }
3239}
3240
#[cfg(test)]
mod tests {
    //! Smoke tests that drive the safe wrappers with tiny hand-checked
    //! matrices.

    use super::*;

    /// `y = I · x` for a 2×2 identity in CSR.
    #[test]
    fn csrmv_2x2_identity_f64() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rowptr: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let x = [3.0_f64, 4.0];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        csrmv(1.0_f64, 2, 2, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 4.0).abs() < 1e-12);
    }

    /// Non-square 2×3 CSR product against a vector of ones.
    #[test]
    fn csrmv_simple_2x3() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rowptr: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let x = [1.0_f64; 3];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        csrmv(1.0_f64, 2, 3, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12, "got {}", y[0]);
        assert!((y[1] - 3.0).abs() < 1e-12, "got {}", y[1]);
    }

    /// Inconsistent dimensions (`m = 2` with a two-entry row pointer) must
    /// be rejected as an invalid argument.
    #[test]
    fn dim_mismatch_is_error() {
        let val = [1.0_f64];
        let col: [sys::aoclsparse_int; 1] = [0];
        let rowptr: [sys::aoclsparse_int; 2] = [0, 1];
        let x = [1.0_f64];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        let err = csrmv(1.0_f64, 2, 1, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap_err();
        // `matches!` only yields a bool; it has to be asserted or the
        // error variant is never actually checked.
        assert!(
            matches!(err, Error::InvalidArgument(_)),
            "expected InvalidArgument, got {err:?}"
        );
    }

    #[test]
    fn axpyi_scatter() {
        // y = [10, 20, 30, 40], x = [1, 2], indx = [0, 2], α = 3
        // → y[0] += 3·1 = 13; y[2] += 3·2 = 36
        let mut y = [10.0_f64, 20.0, 30.0, 40.0];
        let x = [1.0_f64, 2.0];
        let indx: [sys::aoclsparse_int; 2] = [0, 2];
        axpyi(3.0_f64, &x, &indx, &mut y).unwrap();
        assert_eq!(y, [13.0, 20.0, 36.0, 40.0]);
    }

    #[test]
    fn gthr_scatter_round_trip() {
        // Gather from y at indx → x; scatter x at indx → into a fresh y2.
        let y = [10.0_f64, 20.0, 30.0, 40.0];
        let indx: [sys::aoclsparse_int; 2] = [1, 3];
        let mut x = [0.0_f64; 2];
        gthr(&y, &indx, &mut x).unwrap();
        assert_eq!(x, [20.0, 40.0]);

        let mut y2 = [0.0_f64; 4];
        sctr(&x, &indx, &mut y2).unwrap();
        assert_eq!(y2, [0.0, 20.0, 0.0, 40.0]);
    }

    #[test]
    fn add_identity_plus_identity_is_2_diag() {
        // A = B = 2x2 identity; C = 1·A + B = 2·I.
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let c = add(Trans::No, &a, 1.0, &b).unwrap();
        let (_, _, _, val_c) = c.export_csr().unwrap();
        // Both diagonal entries should equal 2.
        assert_eq!(val_c.len(), 2);
        for v in &val_c {
            assert!((v - 2.0).abs() < 1e-12, "got {v}, want 2.0");
        }
    }

    #[test]
    fn csrmm_2x2_identity_against_2x3_dense() {
        // 2×2 identity sparse A; 2×3 dense B; C = A · B should equal B.
        // Row-major: ldb = ldc = 3.
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut c = [0.0_f64; 6];
        csrmm(
            Trans::No,
            1.0,
            &a,
            &descr,
            Order::RowMajor,
            &b,
            3,
            3,
            0.0,
            &mut c,
            3,
        )
        .unwrap();
        for (got, want) in c.iter().zip(b.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }

    #[test]
    fn spmmd_identity_squared_yields_identity_dense() {
        // 3×3 identity * 3×3 identity = 3×3 identity dense.
        let val = [1.0_f64; 3];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let mut c = [0.0_f64; 9];
        spmmd(Trans::No, &a, &b, Order::RowMajor, &mut c, 3).unwrap();
        // Row-major identity matrix.
        let expected = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        for (got, want) in c.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }

    #[test]
    fn ellmv_2x3_f64() {
        // 2×3: [[1, 0, 2], [3, 4, 0]]; ell_width = 2 (max nnz/row).
        // Padding indices/values for the "missing" slot use index 0, value 0.
        let val: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
        let col: [sys::aoclsparse_int; 4] = [0, 2, 0, 1];
        let descr = MatDescr::new().unwrap();
        let x = [10.0_f64, 20.0, 30.0];
        let mut y = [0.0_f64; 2];
        ellmv(
            Trans::No,
            1.0_f64,
            2,
            3,
            &val,
            &col,
            2,
            &descr,
            &x,
            0.0,
            &mut y,
        )
        .unwrap();
        // y[0] = 1*10 + 2*30 = 70; y[1] = 3*10 + 4*20 = 110
        assert!((y[0] - 70.0).abs() < 1e-12, "got {}", y[0]);
        assert!((y[1] - 110.0).abs() < 1e-12, "got {}", y[1]);
    }

    #[test]
    fn bsrmv_2x2_blocks_f64() {
        // 4×4 matrix laid out as 2×2 blocks of size 2×2:
        // block (0,0) = [[1,0],[0,1]] (identity), block (1,1) = [[2,0],[0,2]]
        // bsr_dim = 2, mb = 2, nb = 2, nnzb = 2
        let val: [f64; 8] = [1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let descr = MatDescr::new().unwrap();
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let mut y = [0.0_f64; 4];
        bsrmv(
            Trans::No,
            1.0_f64,
            2,
            2,
            2,
            &val,
            &col,
            &rp,
            &descr,
            &x,
            0.0,
            &mut y,
        )
        .unwrap();
        // y = diag([1,1,2,2]) * x = [1, 2, 6, 8]
        assert!((y[0] - 1.0).abs() < 1e-12);
        assert!((y[1] - 2.0).abs() < 1e-12);
        assert!((y[2] - 6.0).abs() < 1e-12);
        assert!((y[3] - 8.0).abs() < 1e-12);
    }

    #[test]
    fn sparse_matrix_round_trip() {
        // 2×3 matrix [[1,0,2],[0,3,0]]
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 2, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let mat = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 3, &rp, &col, &val).unwrap();
        assert_eq!(mat.dims(), (2, 3));
        assert_eq!(mat.nnz(), 3);
        assert_eq!(mat.base(), IndexBase::Zero);
        let (base, rp2, col2, val2) = mat.export_csr().unwrap();
        assert_eq!(base, IndexBase::Zero);
        assert_eq!(rp2, [0, 2, 3]);
        assert_eq!(col2, [0, 2, 1]);
        assert_eq!(val2, [1.0, 2.0, 3.0]);
    }

    #[test]
    fn csr2m_identity_squared_is_identity() {
        // 3×3 identity in CSR.
        let val = [1.0_f64; 3];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        let c = csr2m(
            Trans::No,
            &descr,
            &a,
            Trans::No,
            &descr,
            &b,
            Stage::FullComputation,
        )
        .unwrap();
        assert_eq!(c.dims(), (3, 3));
        let (_, rp_c, col_c, val_c) = c.export_csr().unwrap();
        assert_eq!(rp_c, [0, 1, 2, 3]);
        assert_eq!(col_c, [0, 1, 2]);
        for v in &val_c {
            assert!((v - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn iter_solver_cg_diagonal_3x3() {
        // Solve diag(2,2,2) · x = b, with b = [4, 6, 10]; expected x = [2, 3, 5].
        let val = [2.0_f64, 2.0, 2.0];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let mat = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();

        // Mark the descriptor as symmetric before handing it to CG.
        let descr = MatDescr::new().unwrap();
        unsafe {
            sys::aoclsparse_set_mat_type(
                descr.as_raw(),
                sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric,
            );
        }

        let b = [4.0_f64, 6.0, 10.0];
        let mut x = [0.0_f64; 3];
        let mut solver = IterSolver::<f64>::new().unwrap();
        solver.set_option("iterative method", "cg").unwrap();
        solver.set_option("cg rel tolerance", "1e-10").unwrap();
        solver.set_option("cg iteration limit", "200").unwrap();
        solver.solve(&mat, &descr, &b, &mut x).unwrap();
        assert!((x[0] - 2.0).abs() < 1e-6, "x[0] = {}", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-6, "x[1] = {}", x[1]);
        assert!((x[2] - 5.0).abs() < 1e-6, "x[2] = {}", x[2]);
    }

    #[test]
    fn csr_to_dense_round_trip() {
        // 2×3 CSR: [[1,0,2],[0,3,0]] → val=[1,2,3], col=[0,2,1], rp=[0,2,3]
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 2, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let descr = MatDescr::new().unwrap();
        let mut dense = [0.0_f64; 6];
        csr_to_dense::<f64>(
            2,
            3,
            &descr,
            &val,
            &rp,
            &col,
            &mut dense,
            3,
            Order::RowMajor,
        )
        .unwrap();
        assert_eq!(dense, [1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
    }
}