Skip to main content

apple_accelerate/
sparse.rs

1use crate::blas::{blas_order, blas_transpose};
2use crate::bridge;
3use crate::error::{Error, Result};
4use core::ffi::c_void;
5use core::ptr;
6
/// Index type for the positions of sparse-vector entries.
///
/// Slices of this type are handed to the bridge by raw pointer alongside the matching values.
pub type SparseIndex = i64;
8
/// Integer codes of the `sparse_matrix_property` constants.
///
/// These are the values passed to `SparseMatrixF32::set_property`.
pub mod sparse_matrix_property {
    /// Code for the upper-triangular property.
    pub const UPPER_TRIANGULAR: i32 = 0x1;
    /// Code for the lower-triangular property.
    pub const LOWER_TRIANGULAR: i32 = 0x2;
    /// Code for the upper-symmetric property.
    pub const UPPER_SYMMETRIC: i32 = 0x4;
    /// Code for the lower-symmetric property.
    pub const LOWER_SYMMETRIC: i32 = 0x8;
}
16
/// Integer codes of the `sparse_status` constants.
///
/// Status codes returned by the bridge are compared against `SUCCESS`. Any other value is
/// surfaced to callers as `Error::SparseStatus`.
pub mod sparse_status {
    /// The operation completed successfully.
    pub const SUCCESS: i32 = 0;
    /// A parameter was rejected.
    pub const ILLEGAL_PARAMETER: i32 = -1_000;
    /// A matrix property could not be set.
    pub const CANNOT_SET_PROPERTY: i32 = -1_001;
    /// A system-level failure occurred.
    pub const SYSTEM_ERROR: i32 = -1_002;
}
24
25fn u64_len(value: usize) -> Result<u64> {
26    u64::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension overflowed"))
27}
28
29fn i64_index(value: usize) -> Result<i64> {
30    i64::try_from(value).map_err(|_| Error::OperationFailed("sparse index overflowed"))
31}
32
33fn usize_dimension(value: u64) -> Result<usize> {
34    usize::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension exceeds usize"))
35}
36
37fn usize_count(value: i64) -> Result<usize> {
38    if value < 0 {
39        return Err(Error::SparseStatus(
40            i32::try_from(value).unwrap_or(sparse_status::SYSTEM_ERROR),
41        ));
42    }
43    usize::try_from(value).map_err(|_| Error::OperationFailed("sparse count exceeds usize"))
44}
45
46fn sparse_result(status: i32) -> Result<()> {
47    if status == sparse_status::SUCCESS {
48        Ok(())
49    } else {
50        Err(Error::SparseStatus(status))
51    }
52}
53
54fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
55    if values.len() != indices.len() {
56        return Err(Error::InvalidLength {
57            expected: values.len(),
58            actual: indices.len(),
59        });
60    }
61    for window in indices.windows(2) {
62        if window[0] >= window[1] {
63            return Err(Error::InvalidValue(
64                "sparse indices must be strictly increasing and unique",
65            ));
66        }
67    }
68    Ok(())
69}
70
71fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
72    if let Some(&max_index) = indices.last() {
73        let max_index = usize::try_from(max_index)
74            .map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
75        if max_index >= dense_len {
76            return Err(Error::InvalidLength {
77                expected: max_index + 1,
78                actual: dense_len,
79            });
80        }
81    }
82    Ok(())
83}
84
85/// Owned sparse single-precision matrix handle backed by the Swift bridge.
86pub struct SparseMatrixF32 {
87    ptr: *mut c_void,
88}
89
90unsafe impl Send for SparseMatrixF32 {}
91unsafe impl Sync for SparseMatrixF32 {}
92
93impl Drop for SparseMatrixF32 {
94    fn drop(&mut self) {
95        if !self.ptr.is_null() {
96            // SAFETY: `ptr` is an opaque Swift object retained by the bridge.
97            unsafe { bridge::acc_release_handle(self.ptr) };
98            self.ptr = ptr::null_mut();
99        }
100    }
101}
102
103impl SparseMatrixF32 {
104    #[must_use]
105    pub fn new(rows: usize, columns: usize) -> Option<Self> {
106        if rows == 0 || columns == 0 {
107            return None;
108        }
109        let rows = u64::try_from(rows).ok()?;
110        let columns = u64::try_from(columns).ok()?;
111
112        // SAFETY: Pure constructor over scalar dimensions.
113        let ptr = unsafe { bridge::acc_sparse_matrix_f32_create(rows, columns) };
114        if ptr.is_null() {
115            None
116        } else {
117            Some(Self { ptr })
118        }
119    }
120
121    pub fn set_property(&mut self, property: i32) -> Result<()> {
122        // SAFETY: `self.ptr` is a live bridge handle.
123        sparse_result(unsafe { bridge::acc_sparse_matrix_f32_set_property(self.ptr, property) })
124    }
125
126    pub fn insert_entry(&mut self, row: usize, column: usize, value: f32) -> Result<()> {
127        let rows = self.rows()?;
128        let columns = self.columns()?;
129        if row >= rows || column >= columns {
130            return Err(Error::InvalidValue(
131                "sparse entry coordinates must be within matrix bounds",
132            ));
133        }
134
135        let row = i64_index(row)?;
136        let column = i64_index(column)?;
137        // SAFETY: Bounds were validated and `self.ptr` is a live bridge handle.
138        sparse_result(unsafe { bridge::acc_sparse_matrix_f32_insert_entry(self.ptr, value, row, column) })
139    }
140
141    pub fn commit(&mut self) -> Result<()> {
142        // SAFETY: `self.ptr` is a live bridge handle.
143        sparse_result(unsafe { bridge::acc_sparse_matrix_f32_commit(self.ptr) })
144    }
145
146    pub fn rows(&self) -> Result<usize> {
147        // SAFETY: `self.ptr` is a live bridge handle.
148        usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_rows(self.ptr) })
149    }
150
151    pub fn columns(&self) -> Result<usize> {
152        // SAFETY: `self.ptr` is a live bridge handle.
153        usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_columns(self.ptr) })
154    }
155
156    pub fn nonzero_count(&self) -> Result<usize> {
157        // SAFETY: `self.ptr` is a live bridge handle.
158        usize_count(unsafe { bridge::acc_sparse_matrix_f32_nonzero_count(self.ptr) })
159    }
160
161    pub fn triangular_solve_vector(
162        &self,
163        transpose: i32,
164        alpha: f32,
165        values: &mut [f32],
166    ) -> Result<()> {
167        let rows = self.rows()?;
168        let columns = self.columns()?;
169        if rows != columns {
170            return Err(Error::InvalidValue(
171                "sparse triangular solve requires a square matrix",
172            ));
173        }
174        if values.len() != rows {
175            return Err(Error::InvalidLength {
176                expected: rows,
177                actual: values.len(),
178            });
179        }
180
181        let len = u64_len(values.len())?;
182        // SAFETY: The matrix and dense vector satisfy the API preconditions.
183        sparse_result(unsafe {
184            bridge::acc_sparse_matrix_f32_triangular_solve_vector(
185                self.ptr,
186                transpose,
187                alpha,
188                values.as_mut_ptr(),
189                len,
190            )
191        })
192    }
193
194    pub fn triangular_solve_matrix_row_major(
195        &self,
196        transpose: i32,
197        rhs_columns: usize,
198        alpha: f32,
199        values: &mut [f32],
200    ) -> Result<()> {
201        let rows = self.rows()?;
202        let columns = self.columns()?;
203        if rows != columns {
204            return Err(Error::InvalidValue(
205                "sparse triangular solve requires a square matrix",
206            ));
207        }
208        let expected = rows
209            .checked_mul(rhs_columns)
210            .ok_or(Error::OperationFailed("sparse rhs dimensions overflowed"))?;
211        if values.len() != expected {
212            return Err(Error::InvalidLength {
213                expected,
214                actual: values.len(),
215            });
216        }
217        if rhs_columns == 0 {
218            return Ok(());
219        }
220
221        let rhs_count = u64_len(rhs_columns)?;
222        let ldb = u64_len(rhs_columns)?;
223        // SAFETY: The matrix and dense matrix satisfy the API preconditions.
224        sparse_result(unsafe {
225            bridge::acc_sparse_matrix_f32_triangular_solve_matrix(
226                self.ptr,
227                blas_order::ROW_MAJOR,
228                transpose,
229                rhs_count,
230                alpha,
231                values.as_mut_ptr(),
232                ldb,
233            )
234        })
235    }
236}
237
238pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
239    validate_sparse_entries(values, indices)?;
240    validate_dense_coverage(indices, dense.len())?;
241    if values.is_empty() {
242        return Ok(0.0);
243    }
244
245    let nz = u64_len(values.len())?;
246    // SAFETY: Inputs are validated for length, ordering, and dense coverage.
247    Ok(unsafe { bridge::acc_sparse_dot_dense_f32(nz, values.as_ptr(), indices.as_ptr(), dense.as_ptr()) })
248}
249
250pub fn dot_sparse_f32(
251    lhs_values: &[f32],
252    lhs_indices: &[SparseIndex],
253    rhs_values: &[f32],
254    rhs_indices: &[SparseIndex],
255) -> Result<f32> {
256    validate_sparse_entries(lhs_values, lhs_indices)?;
257    validate_sparse_entries(rhs_values, rhs_indices)?;
258    if lhs_values.is_empty() || rhs_values.is_empty() {
259        return Ok(0.0);
260    }
261
262    let lhs_count = u64_len(lhs_values.len())?;
263    let rhs_count = u64_len(rhs_values.len())?;
264    // SAFETY: Inputs are validated for length and monotonic indices.
265    Ok(unsafe {
266        bridge::acc_sparse_dot_sparse_f32(
267            lhs_count,
268            lhs_values.as_ptr(),
269            lhs_indices.as_ptr(),
270            rhs_count,
271            rhs_values.as_ptr(),
272            rhs_indices.as_ptr(),
273        )
274    })
275}
276
277pub fn add_to_dense_f32(
278    values: &[f32],
279    indices: &[SparseIndex],
280    alpha: f32,
281    dense: &mut [f32],
282) -> Result<()> {
283    validate_sparse_entries(values, indices)?;
284    validate_dense_coverage(indices, dense.len())?;
285    if values.is_empty() {
286        return Ok(());
287    }
288
289    let nz = u64_len(values.len())?;
290    // SAFETY: Inputs are validated for length, ordering, and dense coverage.
291    let ok = unsafe {
292        bridge::acc_sparse_add_to_dense_f32(
293            nz,
294            alpha,
295            values.as_ptr(),
296            indices.as_ptr(),
297            dense.as_mut_ptr(),
298        )
299    };
300    if ok {
301        Ok(())
302    } else {
303        Err(Error::SparseStatus(-1))
304    }
305}
306
307#[allow(dead_code)]
308const _: i32 = blas_transpose::NO_TRANS;