Skip to main content

apple_accelerate/
sparse.rs

1use crate::blas::{blas_order, blas_transpose};
2use crate::bridge;
3use crate::error::{Error, Result};
4use core::ffi::c_void;
5use core::ptr;
6
7/// Sparse index type used by the `sparse_*_float` routines.
8pub type SparseIndex = i64;
9
10/// `sparse_matrix_property` constants.
11pub mod sparse_matrix_property {
12    /// `sparse_matrix_property` flag for upper-triangular matrices.
13    pub const UPPER_TRIANGULAR: i32 = 1;
14    /// `sparse_matrix_property` flag for lower-triangular matrices.
15    pub const LOWER_TRIANGULAR: i32 = 2;
16    /// `sparse_matrix_property` flag for upper-symmetric matrices.
17    pub const UPPER_SYMMETRIC: i32 = 4;
18    /// `sparse_matrix_property` flag for lower-symmetric matrices.
19    pub const LOWER_SYMMETRIC: i32 = 8;
20}
21
22/// `sparse_status` constants.
23pub mod sparse_status {
24    /// `sparse_status` value returned for successful sparse operations.
25    pub const SUCCESS: i32 = 0;
26    /// `sparse_status` value returned when a sparse argument is invalid.
27    pub const ILLEGAL_PARAMETER: i32 = -1000;
28    /// `sparse_status` value returned when `sparse_set_matrix_property` rejects a property.
29    pub const CANNOT_SET_PROPERTY: i32 = -1001;
30    /// `sparse_status` value returned when the sparse runtime reports a system failure.
31    pub const SYSTEM_ERROR: i32 = -1002;
32}
33
34fn u64_len(value: usize) -> Result<u64> {
35    u64::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension overflowed"))
36}
37
38fn i64_index(value: usize) -> Result<i64> {
39    i64::try_from(value).map_err(|_| Error::OperationFailed("sparse index overflowed"))
40}
41
42fn usize_dimension(value: u64) -> Result<usize> {
43    usize::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension exceeds usize"))
44}
45
46fn usize_count(value: i64) -> Result<usize> {
47    if value < 0 {
48        return Err(Error::SparseStatus(
49            i32::try_from(value).unwrap_or(sparse_status::SYSTEM_ERROR),
50        ));
51    }
52    usize::try_from(value).map_err(|_| Error::OperationFailed("sparse count exceeds usize"))
53}
54
55fn sparse_result(status: i32) -> Result<()> {
56    if status == sparse_status::SUCCESS {
57        Ok(())
58    } else {
59        Err(Error::SparseStatus(status))
60    }
61}
62
63fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
64    if values.len() != indices.len() {
65        return Err(Error::InvalidLength {
66            expected: values.len(),
67            actual: indices.len(),
68        });
69    }
70    for window in indices.windows(2) {
71        if window[0] >= window[1] {
72            return Err(Error::InvalidValue(
73                "sparse indices must be strictly increasing and unique",
74            ));
75        }
76    }
77    Ok(())
78}
79
80fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
81    if let Some(&max_index) = indices.last() {
82        let max_index = usize::try_from(max_index)
83            .map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
84        if max_index >= dense_len {
85            return Err(Error::InvalidLength {
86                expected: max_index + 1,
87                actual: dense_len,
88            });
89        }
90    }
91    Ok(())
92}
93
94/// Owned sparse single-precision matrix handle backed by the Swift bridge.
95pub struct SparseMatrixF32 {
96    ptr: *mut c_void,
97}
98
99unsafe impl Send for SparseMatrixF32 {}
100unsafe impl Sync for SparseMatrixF32 {}
101
102impl Drop for SparseMatrixF32 {
103    fn drop(&mut self) {
104        if !self.ptr.is_null() {
105            // SAFETY: `ptr` is an opaque Swift object retained by the bridge.
106            unsafe { bridge::acc_release_handle(self.ptr) };
107            self.ptr = ptr::null_mut();
108        }
109    }
110}
111
112impl SparseMatrixF32 {
113    /// Creates a sparse single-precision matrix with `sparse_matrix_create_float`.
114    #[must_use]
115    pub fn new(rows: usize, columns: usize) -> Option<Self> {
116        if rows == 0 || columns == 0 {
117            return None;
118        }
119        let rows = u64::try_from(rows).ok()?;
120        let columns = u64::try_from(columns).ok()?;
121
122        // SAFETY: Pure constructor over scalar dimensions.
123        let ptr = unsafe { bridge::acc_sparse_matrix_f32_create(rows, columns) };
124        if ptr.is_null() {
125            None
126        } else {
127            Some(Self { ptr })
128        }
129    }
130
131    /// Sets a `sparse_matrix_property` on the matrix with `sparse_set_matrix_property`.
132    pub fn set_property(&mut self, property: i32) -> Result<()> {
133        // SAFETY: `self.ptr` is a live bridge handle.
134        sparse_result(unsafe { bridge::acc_sparse_matrix_f32_set_property(self.ptr, property) })
135    }
136
137    /// Inserts one matrix entry with `sparse_insert_entry_float`.
138    pub fn insert_entry(&mut self, row: usize, column: usize, value: f32) -> Result<()> {
139        let rows = self.rows()?;
140        let columns = self.columns()?;
141        if row >= rows || column >= columns {
142            return Err(Error::InvalidValue(
143                "sparse entry coordinates must be within matrix bounds",
144            ));
145        }
146
147        let row = i64_index(row)?;
148        let column = i64_index(column)?;
149        // SAFETY: Bounds were validated and `self.ptr` is a live bridge handle.
150        sparse_result(unsafe {
151            bridge::acc_sparse_matrix_f32_insert_entry(self.ptr, value, row, column)
152        })
153    }
154
155    /// Finalizes pending sparse edits with `sparse_commit`.
156    pub fn commit(&mut self) -> Result<()> {
157        // SAFETY: `self.ptr` is a live bridge handle.
158        sparse_result(unsafe { bridge::acc_sparse_matrix_f32_commit(self.ptr) })
159    }
160
161    /// Returns the row count reported by `sparse_get_matrix_number_of_rows`.
162    pub fn rows(&self) -> Result<usize> {
163        // SAFETY: `self.ptr` is a live bridge handle.
164        usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_rows(self.ptr) })
165    }
166
167    /// Returns the column count reported by `sparse_get_matrix_number_of_columns`.
168    pub fn columns(&self) -> Result<usize> {
169        // SAFETY: `self.ptr` is a live bridge handle.
170        usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_columns(self.ptr) })
171    }
172
173    /// Returns the nonzero count reported by `sparse_get_matrix_nonzero_count`.
174    pub fn nonzero_count(&self) -> Result<usize> {
175        // SAFETY: `self.ptr` is a live bridge handle.
176        usize_count(unsafe { bridge::acc_sparse_matrix_f32_nonzero_count(self.ptr) })
177    }
178
179    /// Solves a sparse triangular system against a dense vector with `sparse_vector_triangular_solve_dense_float`.
180    pub fn triangular_solve_vector(
181        &self,
182        transpose: i32,
183        alpha: f32,
184        values: &mut [f32],
185    ) -> Result<()> {
186        let rows = self.rows()?;
187        let columns = self.columns()?;
188        if rows != columns {
189            return Err(Error::InvalidValue(
190                "sparse triangular solve requires a square matrix",
191            ));
192        }
193        if values.len() != rows {
194            return Err(Error::InvalidLength {
195                expected: rows,
196                actual: values.len(),
197            });
198        }
199
200        let len = u64_len(values.len())?;
201        // SAFETY: The matrix and dense vector satisfy the API preconditions.
202        sparse_result(unsafe {
203            bridge::acc_sparse_matrix_f32_triangular_solve_vector(
204                self.ptr,
205                transpose,
206                alpha,
207                values.as_mut_ptr(),
208                len,
209            )
210        })
211    }
212
213    /// Solves a sparse triangular system against a row-major dense matrix with `sparse_matrix_triangular_solve_dense_float`.
214    pub fn triangular_solve_matrix_row_major(
215        &self,
216        transpose: i32,
217        rhs_columns: usize,
218        alpha: f32,
219        values: &mut [f32],
220    ) -> Result<()> {
221        let rows = self.rows()?;
222        let columns = self.columns()?;
223        if rows != columns {
224            return Err(Error::InvalidValue(
225                "sparse triangular solve requires a square matrix",
226            ));
227        }
228        let expected = rows
229            .checked_mul(rhs_columns)
230            .ok_or(Error::OperationFailed("sparse rhs dimensions overflowed"))?;
231        if values.len() != expected {
232            return Err(Error::InvalidLength {
233                expected,
234                actual: values.len(),
235            });
236        }
237        if rhs_columns == 0 {
238            return Ok(());
239        }
240
241        let rhs_count = u64_len(rhs_columns)?;
242        let ldb = u64_len(rhs_columns)?;
243        // SAFETY: The matrix and dense matrix satisfy the API preconditions.
244        sparse_result(unsafe {
245            bridge::acc_sparse_matrix_f32_triangular_solve_matrix(
246                self.ptr,
247                blas_order::ROW_MAJOR,
248                transpose,
249                rhs_count,
250                alpha,
251                values.as_mut_ptr(),
252                ldb,
253            )
254        })
255    }
256}
257
258/// Wraps `sparse_inner_product_dense_float`.
259pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
260    validate_sparse_entries(values, indices)?;
261    validate_dense_coverage(indices, dense.len())?;
262    if values.is_empty() {
263        return Ok(0.0);
264    }
265
266    let nz = u64_len(values.len())?;
267    // SAFETY: Inputs are validated for length, ordering, and dense coverage.
268    Ok(unsafe {
269        bridge::acc_sparse_dot_dense_f32(nz, values.as_ptr(), indices.as_ptr(), dense.as_ptr())
270    })
271}
272
273/// Wraps `sparse_inner_product_sparse_float`.
274pub fn dot_sparse_f32(
275    lhs_values: &[f32],
276    lhs_indices: &[SparseIndex],
277    rhs_values: &[f32],
278    rhs_indices: &[SparseIndex],
279) -> Result<f32> {
280    validate_sparse_entries(lhs_values, lhs_indices)?;
281    validate_sparse_entries(rhs_values, rhs_indices)?;
282    if lhs_values.is_empty() || rhs_values.is_empty() {
283        return Ok(0.0);
284    }
285
286    let lhs_count = u64_len(lhs_values.len())?;
287    let rhs_count = u64_len(rhs_values.len())?;
288    // SAFETY: Inputs are validated for length and monotonic indices.
289    Ok(unsafe {
290        bridge::acc_sparse_dot_sparse_f32(
291            lhs_count,
292            lhs_values.as_ptr(),
293            lhs_indices.as_ptr(),
294            rhs_count,
295            rhs_values.as_ptr(),
296            rhs_indices.as_ptr(),
297        )
298    })
299}
300
301/// Wraps `sparse_vector_add_with_scale_dense_float`.
302pub fn add_to_dense_f32(
303    values: &[f32],
304    indices: &[SparseIndex],
305    alpha: f32,
306    dense: &mut [f32],
307) -> Result<()> {
308    validate_sparse_entries(values, indices)?;
309    validate_dense_coverage(indices, dense.len())?;
310    if values.is_empty() {
311        return Ok(());
312    }
313
314    let nz = u64_len(values.len())?;
315    // SAFETY: Inputs are validated for length, ordering, and dense coverage.
316    let ok = unsafe {
317        bridge::acc_sparse_add_to_dense_f32(
318            nz,
319            alpha,
320            values.as_ptr(),
321            indices.as_ptr(),
322            dense.as_mut_ptr(),
323        )
324    };
325    if ok {
326        Ok(())
327    } else {
328        Err(Error::SparseStatus(-1))
329    }
330}
331
332#[allow(dead_code)]
333const _: i32 = blas_transpose::NO_TRANS;