apple_accelerate/
sparse.rs1use crate::bridge;
2use crate::error::{Error, Result};
3
/// Integer type used for sparse-vector positions in this module.
///
/// The index slices are handed to the native bridge via `as_ptr()`, so the bridge's index
/// parameters are expected to be `*const i64`. Validation in this module rejects negative
/// indices wherever a dense vector is addressed.
pub type SparseIndex = i64;
5
6fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
7 if values.len() != indices.len() {
8 return Err(Error::InvalidLength {
9 expected: values.len(),
10 actual: indices.len(),
11 });
12 }
13 for window in indices.windows(2) {
14 if window[0] >= window[1] {
15 return Err(Error::InvalidValue(
16 "sparse indices must be strictly increasing and unique",
17 ));
18 }
19 }
20 Ok(())
21}
22
23fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
24 if let Some(&max_index) = indices.last() {
25 let max_index = usize::try_from(max_index)
26 .map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
27 if max_index >= dense_len {
28 return Err(Error::InvalidLength {
29 expected: max_index + 1,
30 actual: dense_len,
31 });
32 }
33 }
34 Ok(())
35}
36
37pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
38 validate_sparse_entries(values, indices)?;
39 validate_dense_coverage(indices, dense.len())?;
40 if values.is_empty() {
41 return Ok(0.0);
42 }
43
44 Ok(unsafe {
46 bridge::acc_sparse_dot_dense_f32(
47 values.len() as u64,
48 values.as_ptr(),
49 indices.as_ptr(),
50 dense.as_ptr(),
51 )
52 })
53}
54
55pub fn dot_sparse_f32(
56 lhs_values: &[f32],
57 lhs_indices: &[SparseIndex],
58 rhs_values: &[f32],
59 rhs_indices: &[SparseIndex],
60) -> Result<f32> {
61 validate_sparse_entries(lhs_values, lhs_indices)?;
62 validate_sparse_entries(rhs_values, rhs_indices)?;
63 if lhs_values.is_empty() || rhs_values.is_empty() {
64 return Ok(0.0);
65 }
66
67 Ok(unsafe {
69 bridge::acc_sparse_dot_sparse_f32(
70 lhs_values.len() as u64,
71 lhs_values.as_ptr(),
72 lhs_indices.as_ptr(),
73 rhs_values.len() as u64,
74 rhs_values.as_ptr(),
75 rhs_indices.as_ptr(),
76 )
77 })
78}
79
80pub fn add_to_dense_f32(
81 values: &[f32],
82 indices: &[SparseIndex],
83 alpha: f32,
84 dense: &mut [f32],
85) -> Result<()> {
86 validate_sparse_entries(values, indices)?;
87 validate_dense_coverage(indices, dense.len())?;
88 if values.is_empty() {
89 return Ok(());
90 }
91
92 let ok = unsafe {
94 bridge::acc_sparse_add_to_dense_f32(
95 values.len() as u64,
96 alpha,
97 values.as_ptr(),
98 indices.as_ptr(),
99 dense.as_mut_ptr(),
100 )
101 };
102 if ok {
103 Ok(())
104 } else {
105 Err(Error::SparseStatus(-1))
106 }
107}