Skip to main content

apple_accelerate/
simd.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3
/// A four-lane `f32` vector, handed to the Accelerate bridge by raw pointer.
pub type Float4 = [f32; 4];
5
6fn unary_simd_op(
7    input: Float4,
8    f: unsafe extern "C" fn(*const f32, *mut f32) -> bool,
9) -> Result<Float4> {
10    let mut output = [0.0_f32; 4];
11    // SAFETY: Both arrays are valid for exactly four `f32` values.
12    let ok = unsafe { f(input.as_ptr(), output.as_mut_ptr()) };
13    if ok {
14        Ok(output)
15    } else {
16        Err(Error::OperationFailed("SIMD bridge operation failed"))
17    }
18}
19
20pub fn add_f32x4(lhs: Float4, rhs: Float4) -> Result<Float4> {
21    let mut output = [0.0_f32; 4];
22    // SAFETY: All arrays are valid for exactly four `f32` values.
23    let ok = unsafe { bridge::acc_simd_add_f32x4(lhs.as_ptr(), rhs.as_ptr(), output.as_mut_ptr()) };
24    if ok {
25        Ok(output)
26    } else {
27        Err(Error::OperationFailed("SIMD add failed"))
28    }
29}
30
31pub fn dot_f32x4(lhs: Float4, rhs: Float4) -> Result<f32> {
32    let mut output = 0.0_f32;
33    // SAFETY: All arrays are valid for exactly four `f32` values.
34    let ok = unsafe { bridge::acc_simd_dot_f32x4(lhs.as_ptr(), rhs.as_ptr(), &mut output) };
35    if ok {
36        Ok(output)
37    } else {
38        Err(Error::OperationFailed("SIMD dot product failed"))
39    }
40}
41
42pub fn length_f32x4(input: Float4) -> Result<f32> {
43    let mut output = 0.0_f32;
44    // SAFETY: Both arrays are valid for exactly four `f32` values.
45    let ok = unsafe { bridge::acc_simd_length_f32x4(input.as_ptr(), &mut output) };
46    if ok {
47        Ok(output)
48    } else {
49        Err(Error::OperationFailed("SIMD length failed"))
50    }
51}
52
53pub fn normalize_f32x4(input: Float4) -> Result<Float4> {
54    if input.iter().all(|value| *value == 0.0) {
55        return Err(Error::InvalidValue("cannot normalize the zero vector"));
56    }
57    unary_simd_op(input, bridge::acc_simd_normalize_f32x4)
58}