apple_accelerate/
sparse.rs1use crate::blas::{blas_order, blas_transpose};
2use crate::bridge;
3use crate::error::{Error, Result};
4use core::ffi::c_void;
5use core::ptr;
6
/// Integer type used for the positions of sparse entries.
///
/// Index slices handed to the sparse vector routines in this module are
/// `&[SparseIndex]`, and their raw pointers are passed straight to the bridge.
pub type SparseIndex = i64;
9
/// Property flags accepted by `SparseMatrixF32::set_property`.
///
/// Each flag occupies its own bit.
/// NOTE(review): the values look like Accelerate's `sparse_matrix_property`
/// enum — confirm against the bridge before relying on that.
pub mod sparse_matrix_property {
    /// Matrix is upper triangular.
    pub const UPPER_TRIANGULAR: i32 = 1 << 0;
    /// Matrix is lower triangular.
    pub const LOWER_TRIANGULAR: i32 = 1 << 1;
    /// Matrix is symmetric, stored in its upper triangle (presumably — confirm).
    pub const UPPER_SYMMETRIC: i32 = 1 << 2;
    /// Matrix is symmetric, stored in its lower triangle (presumably — confirm).
    pub const LOWER_SYMMETRIC: i32 = 1 << 3;
}
21
/// Status codes reported by the sparse bridge calls.
///
/// Any value other than [`SUCCESS`](self::SUCCESS) is surfaced to callers as
/// `Error::SparseStatus`.
/// NOTE(review): the values look like Accelerate's `sparse_status` enum —
/// confirm against the bridge.
pub mod sparse_status {
    /// The operation completed.
    pub const SUCCESS: i32 = 0;
    /// An argument was rejected.
    pub const ILLEGAL_PARAMETER: i32 = -1000;
    /// The requested matrix property could not be applied.
    pub const CANNOT_SET_PROPERTY: i32 = -1001;
    /// Generic failure; also used as the fallback when a negative count does
    /// not fit in an `i32` status.
    pub const SYSTEM_ERROR: i32 = -1002;
}
33
34fn u64_len(value: usize) -> Result<u64> {
35 u64::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension overflowed"))
36}
37
38fn i64_index(value: usize) -> Result<i64> {
39 i64::try_from(value).map_err(|_| Error::OperationFailed("sparse index overflowed"))
40}
41
42fn usize_dimension(value: u64) -> Result<usize> {
43 usize::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension exceeds usize"))
44}
45
46fn usize_count(value: i64) -> Result<usize> {
47 if value < 0 {
48 return Err(Error::SparseStatus(
49 i32::try_from(value).unwrap_or(sparse_status::SYSTEM_ERROR),
50 ));
51 }
52 usize::try_from(value).map_err(|_| Error::OperationFailed("sparse count exceeds usize"))
53}
54
55fn sparse_result(status: i32) -> Result<()> {
56 if status == sparse_status::SUCCESS {
57 Ok(())
58 } else {
59 Err(Error::SparseStatus(status))
60 }
61}
62
63fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
64 if values.len() != indices.len() {
65 return Err(Error::InvalidLength {
66 expected: values.len(),
67 actual: indices.len(),
68 });
69 }
70 for window in indices.windows(2) {
71 if window[0] >= window[1] {
72 return Err(Error::InvalidValue(
73 "sparse indices must be strictly increasing and unique",
74 ));
75 }
76 }
77 Ok(())
78}
79
80fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
81 if let Some(&max_index) = indices.last() {
82 let max_index = usize::try_from(max_index)
83 .map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
84 if max_index >= dense_len {
85 return Err(Error::InvalidLength {
86 expected: max_index + 1,
87 actual: dense_len,
88 });
89 }
90 }
91 Ok(())
92}
93
/// Owning wrapper around an opaque sparse `f32` matrix handle obtained from the
/// bridge.
///
/// The handle is released when the wrapper is dropped.
pub struct SparseMatrixF32 {
    /// Opaque handle returned by the bridge; only ever passed back to bridge calls.
    ptr: *mut c_void,
}

// SAFETY: NOTE(review): the raw handle makes this type `!Send`/`!Sync` by default.
// These impls assert that the underlying object may be moved to, and read from,
// other threads. Nothing visible in this file establishes that — confirm against
// the bridge's threading guarantees.
unsafe impl Send for SparseMatrixF32 {}
unsafe impl Sync for SparseMatrixF32 {}
101
102impl Drop for SparseMatrixF32 {
103 fn drop(&mut self) {
104 if !self.ptr.is_null() {
105 unsafe { bridge::acc_release_handle(self.ptr) };
107 self.ptr = ptr::null_mut();
108 }
109 }
110}
111
112impl SparseMatrixF32 {
113 #[must_use]
115 pub fn new(rows: usize, columns: usize) -> Option<Self> {
116 if rows == 0 || columns == 0 {
117 return None;
118 }
119 let rows = u64::try_from(rows).ok()?;
120 let columns = u64::try_from(columns).ok()?;
121
122 let ptr = unsafe { bridge::acc_sparse_matrix_f32_create(rows, columns) };
124 if ptr.is_null() {
125 None
126 } else {
127 Some(Self { ptr })
128 }
129 }
130
131 pub fn set_property(&mut self, property: i32) -> Result<()> {
133 sparse_result(unsafe { bridge::acc_sparse_matrix_f32_set_property(self.ptr, property) })
135 }
136
137 pub fn insert_entry(&mut self, row: usize, column: usize, value: f32) -> Result<()> {
139 let rows = self.rows()?;
140 let columns = self.columns()?;
141 if row >= rows || column >= columns {
142 return Err(Error::InvalidValue(
143 "sparse entry coordinates must be within matrix bounds",
144 ));
145 }
146
147 let row = i64_index(row)?;
148 let column = i64_index(column)?;
149 sparse_result(unsafe {
151 bridge::acc_sparse_matrix_f32_insert_entry(self.ptr, value, row, column)
152 })
153 }
154
155 pub fn commit(&mut self) -> Result<()> {
157 sparse_result(unsafe { bridge::acc_sparse_matrix_f32_commit(self.ptr) })
159 }
160
161 pub fn rows(&self) -> Result<usize> {
163 usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_rows(self.ptr) })
165 }
166
167 pub fn columns(&self) -> Result<usize> {
169 usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_columns(self.ptr) })
171 }
172
173 pub fn nonzero_count(&self) -> Result<usize> {
175 usize_count(unsafe { bridge::acc_sparse_matrix_f32_nonzero_count(self.ptr) })
177 }
178
179 pub fn triangular_solve_vector(
181 &self,
182 transpose: i32,
183 alpha: f32,
184 values: &mut [f32],
185 ) -> Result<()> {
186 let rows = self.rows()?;
187 let columns = self.columns()?;
188 if rows != columns {
189 return Err(Error::InvalidValue(
190 "sparse triangular solve requires a square matrix",
191 ));
192 }
193 if values.len() != rows {
194 return Err(Error::InvalidLength {
195 expected: rows,
196 actual: values.len(),
197 });
198 }
199
200 let len = u64_len(values.len())?;
201 sparse_result(unsafe {
203 bridge::acc_sparse_matrix_f32_triangular_solve_vector(
204 self.ptr,
205 transpose,
206 alpha,
207 values.as_mut_ptr(),
208 len,
209 )
210 })
211 }
212
213 pub fn triangular_solve_matrix_row_major(
215 &self,
216 transpose: i32,
217 rhs_columns: usize,
218 alpha: f32,
219 values: &mut [f32],
220 ) -> Result<()> {
221 let rows = self.rows()?;
222 let columns = self.columns()?;
223 if rows != columns {
224 return Err(Error::InvalidValue(
225 "sparse triangular solve requires a square matrix",
226 ));
227 }
228 let expected = rows
229 .checked_mul(rhs_columns)
230 .ok_or(Error::OperationFailed("sparse rhs dimensions overflowed"))?;
231 if values.len() != expected {
232 return Err(Error::InvalidLength {
233 expected,
234 actual: values.len(),
235 });
236 }
237 if rhs_columns == 0 {
238 return Ok(());
239 }
240
241 let rhs_count = u64_len(rhs_columns)?;
242 let ldb = u64_len(rhs_columns)?;
243 sparse_result(unsafe {
245 bridge::acc_sparse_matrix_f32_triangular_solve_matrix(
246 self.ptr,
247 blas_order::ROW_MAJOR,
248 transpose,
249 rhs_count,
250 alpha,
251 values.as_mut_ptr(),
252 ldb,
253 )
254 })
255 }
256}
257
258pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
260 validate_sparse_entries(values, indices)?;
261 validate_dense_coverage(indices, dense.len())?;
262 if values.is_empty() {
263 return Ok(0.0);
264 }
265
266 let nz = u64_len(values.len())?;
267 Ok(unsafe {
269 bridge::acc_sparse_dot_dense_f32(nz, values.as_ptr(), indices.as_ptr(), dense.as_ptr())
270 })
271}
272
273pub fn dot_sparse_f32(
275 lhs_values: &[f32],
276 lhs_indices: &[SparseIndex],
277 rhs_values: &[f32],
278 rhs_indices: &[SparseIndex],
279) -> Result<f32> {
280 validate_sparse_entries(lhs_values, lhs_indices)?;
281 validate_sparse_entries(rhs_values, rhs_indices)?;
282 if lhs_values.is_empty() || rhs_values.is_empty() {
283 return Ok(0.0);
284 }
285
286 let lhs_count = u64_len(lhs_values.len())?;
287 let rhs_count = u64_len(rhs_values.len())?;
288 Ok(unsafe {
290 bridge::acc_sparse_dot_sparse_f32(
291 lhs_count,
292 lhs_values.as_ptr(),
293 lhs_indices.as_ptr(),
294 rhs_count,
295 rhs_values.as_ptr(),
296 rhs_indices.as_ptr(),
297 )
298 })
299}
300
301pub fn add_to_dense_f32(
303 values: &[f32],
304 indices: &[SparseIndex],
305 alpha: f32,
306 dense: &mut [f32],
307) -> Result<()> {
308 validate_sparse_entries(values, indices)?;
309 validate_dense_coverage(indices, dense.len())?;
310 if values.is_empty() {
311 return Ok(());
312 }
313
314 let nz = u64_len(values.len())?;
315 let ok = unsafe {
317 bridge::acc_sparse_add_to_dense_f32(
318 nz,
319 alpha,
320 values.as_ptr(),
321 indices.as_ptr(),
322 dense.as_mut_ptr(),
323 )
324 };
325 if ok {
326 Ok(())
327 } else {
328 Err(Error::SparseStatus(-1))
329 }
330}
331
332#[allow(dead_code)]
333const _: i32 = blas_transpose::NO_TRANS;