apple_accelerate/
sparse.rs1use crate::blas::{blas_order, blas_transpose};
2use crate::bridge;
3use crate::error::{Error, Result};
4use core::ffi::c_void;
5use core::ptr;
6
/// Element type of the sparse `indices` slices accepted by the vector routines in this module
/// (`dot_dense_f32`, `dot_sparse_f32`, `add_to_dense_f32`); a signed 64-bit integer.
pub type SparseIndex = i64;
8
/// Matrix property flags intended for [`SparseMatrixF32::set_property`].
///
/// Each flag occupies its own bit. The values presumably mirror Accelerate's
/// `sparse_matrix_property` enumeration — confirm against `Sparse/BLAS.h`.
pub mod sparse_matrix_property {
    /// Upper-triangular property flag (bit 0).
    pub const UPPER_TRIANGULAR: i32 = 1 << 0;
    /// Lower-triangular property flag (bit 1).
    pub const LOWER_TRIANGULAR: i32 = 1 << 1;
    /// Upper-symmetric property flag (bit 2).
    pub const UPPER_SYMMETRIC: i32 = 1 << 2;
    /// Lower-symmetric property flag (bit 3).
    pub const LOWER_SYMMETRIC: i32 = 1 << 3;
}
16
/// Status codes compared against bridge return values (see `sparse_result`) and carried by
/// `Error::SparseStatus`.
///
/// The failure codes form a descending run starting at `ILLEGAL_PARAMETER`; they presumably
/// mirror Accelerate's `sparse_status` values — confirm against `Sparse/BLAS.h`.
pub mod sparse_status {
    /// The bridge call completed successfully.
    pub const SUCCESS: i32 = 0;
    /// Illegal-parameter failure code.
    pub const ILLEGAL_PARAMETER: i32 = -1000;
    /// Cannot-set-property failure code.
    pub const CANNOT_SET_PROPERTY: i32 = ILLEGAL_PARAMETER - 1;
    /// System-error failure code; also used locally as a fallback status.
    pub const SYSTEM_ERROR: i32 = ILLEGAL_PARAMETER - 2;
}
24
25fn u64_len(value: usize) -> Result<u64> {
26 u64::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension overflowed"))
27}
28
29fn i64_index(value: usize) -> Result<i64> {
30 i64::try_from(value).map_err(|_| Error::OperationFailed("sparse index overflowed"))
31}
32
33fn usize_dimension(value: u64) -> Result<usize> {
34 usize::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension exceeds usize"))
35}
36
37fn usize_count(value: i64) -> Result<usize> {
38 if value < 0 {
39 return Err(Error::SparseStatus(
40 i32::try_from(value).unwrap_or(sparse_status::SYSTEM_ERROR),
41 ));
42 }
43 usize::try_from(value).map_err(|_| Error::OperationFailed("sparse count exceeds usize"))
44}
45
46fn sparse_result(status: i32) -> Result<()> {
47 if status == sparse_status::SUCCESS {
48 Ok(())
49 } else {
50 Err(Error::SparseStatus(status))
51 }
52}
53
54fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
55 if values.len() != indices.len() {
56 return Err(Error::InvalidLength {
57 expected: values.len(),
58 actual: indices.len(),
59 });
60 }
61 for window in indices.windows(2) {
62 if window[0] >= window[1] {
63 return Err(Error::InvalidValue(
64 "sparse indices must be strictly increasing and unique",
65 ));
66 }
67 }
68 Ok(())
69}
70
71fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
72 if let Some(&max_index) = indices.last() {
73 let max_index = usize::try_from(max_index)
74 .map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
75 if max_index >= dense_len {
76 return Err(Error::InvalidLength {
77 expected: max_index + 1,
78 actual: dense_len,
79 });
80 }
81 }
82 Ok(())
83}
84
/// Owning wrapper around an opaque `f32` sparse matrix handle obtained from the bridge.
///
/// Values are constructed by `SparseMatrixF32::new`, which returns `None` instead of
/// storing a null handle. The handle is passed to `bridge::acc_release_handle` when the
/// value is dropped.
pub struct SparseMatrixF32 {
    // Opaque handle returned by `bridge::acc_sparse_matrix_f32_create`; every method
    // forwards it to the corresponding `bridge::acc_sparse_matrix_f32_*` call.
    ptr: *mut c_void,
}
89
// SAFETY — NOTE(review): these impls assert that the bridge handle may be moved to another
// thread (`Send`) and used through shared references from several threads at once (`Sync`).
// With `Sync`, the `&self` methods (`rows`, `columns`, `nonzero_count`,
// `triangular_solve_vector`, `triangular_solve_matrix_row_major`) may call into the bridge
// concurrently on the same handle. Confirm that the underlying Accelerate/bridge object
// tolerates concurrent read-only use; nothing in this file establishes it.
unsafe impl Send for SparseMatrixF32 {}
unsafe impl Sync for SparseMatrixF32 {}
92
93impl Drop for SparseMatrixF32 {
94 fn drop(&mut self) {
95 if !self.ptr.is_null() {
96 unsafe { bridge::acc_release_handle(self.ptr) };
98 self.ptr = ptr::null_mut();
99 }
100 }
101}
102
103impl SparseMatrixF32 {
104 #[must_use]
105 pub fn new(rows: usize, columns: usize) -> Option<Self> {
106 if rows == 0 || columns == 0 {
107 return None;
108 }
109 let rows = u64::try_from(rows).ok()?;
110 let columns = u64::try_from(columns).ok()?;
111
112 let ptr = unsafe { bridge::acc_sparse_matrix_f32_create(rows, columns) };
114 if ptr.is_null() {
115 None
116 } else {
117 Some(Self { ptr })
118 }
119 }
120
121 pub fn set_property(&mut self, property: i32) -> Result<()> {
122 sparse_result(unsafe { bridge::acc_sparse_matrix_f32_set_property(self.ptr, property) })
124 }
125
126 pub fn insert_entry(&mut self, row: usize, column: usize, value: f32) -> Result<()> {
127 let rows = self.rows()?;
128 let columns = self.columns()?;
129 if row >= rows || column >= columns {
130 return Err(Error::InvalidValue(
131 "sparse entry coordinates must be within matrix bounds",
132 ));
133 }
134
135 let row = i64_index(row)?;
136 let column = i64_index(column)?;
137 sparse_result(unsafe { bridge::acc_sparse_matrix_f32_insert_entry(self.ptr, value, row, column) })
139 }
140
141 pub fn commit(&mut self) -> Result<()> {
142 sparse_result(unsafe { bridge::acc_sparse_matrix_f32_commit(self.ptr) })
144 }
145
146 pub fn rows(&self) -> Result<usize> {
147 usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_rows(self.ptr) })
149 }
150
151 pub fn columns(&self) -> Result<usize> {
152 usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_columns(self.ptr) })
154 }
155
156 pub fn nonzero_count(&self) -> Result<usize> {
157 usize_count(unsafe { bridge::acc_sparse_matrix_f32_nonzero_count(self.ptr) })
159 }
160
161 pub fn triangular_solve_vector(
162 &self,
163 transpose: i32,
164 alpha: f32,
165 values: &mut [f32],
166 ) -> Result<()> {
167 let rows = self.rows()?;
168 let columns = self.columns()?;
169 if rows != columns {
170 return Err(Error::InvalidValue(
171 "sparse triangular solve requires a square matrix",
172 ));
173 }
174 if values.len() != rows {
175 return Err(Error::InvalidLength {
176 expected: rows,
177 actual: values.len(),
178 });
179 }
180
181 let len = u64_len(values.len())?;
182 sparse_result(unsafe {
184 bridge::acc_sparse_matrix_f32_triangular_solve_vector(
185 self.ptr,
186 transpose,
187 alpha,
188 values.as_mut_ptr(),
189 len,
190 )
191 })
192 }
193
194 pub fn triangular_solve_matrix_row_major(
195 &self,
196 transpose: i32,
197 rhs_columns: usize,
198 alpha: f32,
199 values: &mut [f32],
200 ) -> Result<()> {
201 let rows = self.rows()?;
202 let columns = self.columns()?;
203 if rows != columns {
204 return Err(Error::InvalidValue(
205 "sparse triangular solve requires a square matrix",
206 ));
207 }
208 let expected = rows
209 .checked_mul(rhs_columns)
210 .ok_or(Error::OperationFailed("sparse rhs dimensions overflowed"))?;
211 if values.len() != expected {
212 return Err(Error::InvalidLength {
213 expected,
214 actual: values.len(),
215 });
216 }
217 if rhs_columns == 0 {
218 return Ok(());
219 }
220
221 let rhs_count = u64_len(rhs_columns)?;
222 let ldb = u64_len(rhs_columns)?;
223 sparse_result(unsafe {
225 bridge::acc_sparse_matrix_f32_triangular_solve_matrix(
226 self.ptr,
227 blas_order::ROW_MAJOR,
228 transpose,
229 rhs_count,
230 alpha,
231 values.as_mut_ptr(),
232 ldb,
233 )
234 })
235 }
236}
237
238pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
239 validate_sparse_entries(values, indices)?;
240 validate_dense_coverage(indices, dense.len())?;
241 if values.is_empty() {
242 return Ok(0.0);
243 }
244
245 let nz = u64_len(values.len())?;
246 Ok(unsafe { bridge::acc_sparse_dot_dense_f32(nz, values.as_ptr(), indices.as_ptr(), dense.as_ptr()) })
248}
249
250pub fn dot_sparse_f32(
251 lhs_values: &[f32],
252 lhs_indices: &[SparseIndex],
253 rhs_values: &[f32],
254 rhs_indices: &[SparseIndex],
255) -> Result<f32> {
256 validate_sparse_entries(lhs_values, lhs_indices)?;
257 validate_sparse_entries(rhs_values, rhs_indices)?;
258 if lhs_values.is_empty() || rhs_values.is_empty() {
259 return Ok(0.0);
260 }
261
262 let lhs_count = u64_len(lhs_values.len())?;
263 let rhs_count = u64_len(rhs_values.len())?;
264 Ok(unsafe {
266 bridge::acc_sparse_dot_sparse_f32(
267 lhs_count,
268 lhs_values.as_ptr(),
269 lhs_indices.as_ptr(),
270 rhs_count,
271 rhs_values.as_ptr(),
272 rhs_indices.as_ptr(),
273 )
274 })
275}
276
277pub fn add_to_dense_f32(
278 values: &[f32],
279 indices: &[SparseIndex],
280 alpha: f32,
281 dense: &mut [f32],
282) -> Result<()> {
283 validate_sparse_entries(values, indices)?;
284 validate_dense_coverage(indices, dense.len())?;
285 if values.is_empty() {
286 return Ok(());
287 }
288
289 let nz = u64_len(values.len())?;
290 let ok = unsafe {
292 bridge::acc_sparse_add_to_dense_f32(
293 nz,
294 alpha,
295 values.as_ptr(),
296 indices.as_ptr(),
297 dense.as_mut_ptr(),
298 )
299 };
300 if ok {
301 Ok(())
302 } else {
303 Err(Error::SparseStatus(-1))
304 }
305}
306
// Within this file, this is the only reference to `blas_transpose`. It appears to exist so
// the `blas_transpose` half of the `use crate::blas::{blas_order, blas_transpose}` import
// is not reported as unused. The `transpose: i32` parameters on `SparseMatrixF32`
// presumably take `blas_transpose` values — confirm.
// NOTE(review): unnamed `const _` items are not subject to the `dead_code` lint, so the
// attribute below is likely redundant.
#[allow(dead_code)]
const _: i32 = blas_transpose::NO_TRANS;