Skip to main content

apple_accelerate/
sparse.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3
/// Index type for the entries of a sparse vector: a signed 64-bit integer.
///
/// Slices of this type are passed to the `bridge` functions as raw pointers,
/// so the width presumably has to match what the bridge side expects —
/// confirm against the bridge declarations before changing it.
pub type SparseIndex = i64;
5
6fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
7    if values.len() != indices.len() {
8        return Err(Error::InvalidLength {
9            expected: values.len(),
10            actual: indices.len(),
11        });
12    }
13    for window in indices.windows(2) {
14        if window[0] >= window[1] {
15            return Err(Error::InvalidValue(
16                "sparse indices must be strictly increasing and unique",
17            ));
18        }
19    }
20    Ok(())
21}
22
23fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
24    if let Some(&max_index) = indices.last() {
25        let max_index = usize::try_from(max_index)
26            .map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
27        if max_index >= dense_len {
28            return Err(Error::InvalidLength {
29                expected: max_index + 1,
30                actual: dense_len,
31            });
32        }
33    }
34    Ok(())
35}
36
37pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
38    validate_sparse_entries(values, indices)?;
39    validate_dense_coverage(indices, dense.len())?;
40    if values.is_empty() {
41        return Ok(0.0);
42    }
43
44    // SAFETY: Inputs are validated for length, ordering, and dense coverage.
45    Ok(unsafe {
46        bridge::acc_sparse_dot_dense_f32(
47            values.len() as u64,
48            values.as_ptr(),
49            indices.as_ptr(),
50            dense.as_ptr(),
51        )
52    })
53}
54
55pub fn dot_sparse_f32(
56    lhs_values: &[f32],
57    lhs_indices: &[SparseIndex],
58    rhs_values: &[f32],
59    rhs_indices: &[SparseIndex],
60) -> Result<f32> {
61    validate_sparse_entries(lhs_values, lhs_indices)?;
62    validate_sparse_entries(rhs_values, rhs_indices)?;
63    if lhs_values.is_empty() || rhs_values.is_empty() {
64        return Ok(0.0);
65    }
66
67    // SAFETY: Inputs are validated for length and monotonic indices.
68    Ok(unsafe {
69        bridge::acc_sparse_dot_sparse_f32(
70            lhs_values.len() as u64,
71            lhs_values.as_ptr(),
72            lhs_indices.as_ptr(),
73            rhs_values.len() as u64,
74            rhs_values.as_ptr(),
75            rhs_indices.as_ptr(),
76        )
77    })
78}
79
80pub fn add_to_dense_f32(
81    values: &[f32],
82    indices: &[SparseIndex],
83    alpha: f32,
84    dense: &mut [f32],
85) -> Result<()> {
86    validate_sparse_entries(values, indices)?;
87    validate_dense_coverage(indices, dense.len())?;
88    if values.is_empty() {
89        return Ok(());
90    }
91
92    // SAFETY: Inputs are validated for length, ordering, and dense coverage.
93    let ok = unsafe {
94        bridge::acc_sparse_add_to_dense_f32(
95            values.len() as u64,
96            alpha,
97            values.as_ptr(),
98            indices.as_ptr(),
99            dense.as_mut_ptr(),
100        )
101    };
102    if ok {
103        Ok(())
104    } else {
105        Err(Error::SparseStatus(-1))
106    }
107}