// apple_accelerate/simd.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3
/// Four-lane `f32` vector used by the SIMD helpers backed by `simd_float4`.
///
/// This is a plain array alias, so values are `Copy` and their lanes are
/// contiguous; the helpers hand the lanes to the bridge through `as_ptr` /
/// `as_mut_ptr`.
pub type Float4 = [f32; 4];
6
7fn unary_simd_op(
8    input: Float4,
9    f: unsafe extern "C" fn(*const f32, *mut f32) -> bool,
10) -> Result<Float4> {
11    let mut output = [0.0_f32; 4];
12    // SAFETY: Both arrays are valid for exactly four `f32` values.
13    let ok = unsafe { f(input.as_ptr(), output.as_mut_ptr()) };
14    if ok {
15        Ok(output)
16    } else {
17        Err(Error::OperationFailed("SIMD bridge operation failed"))
18    }
19}
20
21/// Wraps four-lane vector addition backed by `simd_float4`.
22pub fn add_f32x4(lhs: Float4, rhs: Float4) -> Result<Float4> {
23    let mut output = [0.0_f32; 4];
24    // SAFETY: All arrays are valid for exactly four `f32` values.
25    let ok = unsafe { bridge::acc_simd_add_f32x4(lhs.as_ptr(), rhs.as_ptr(), output.as_mut_ptr()) };
26    if ok {
27        Ok(output)
28    } else {
29        Err(Error::OperationFailed("SIMD add failed"))
30    }
31}
32
33/// Wraps four-lane dot-product evaluation backed by `simd_float4`.
34pub fn dot_f32x4(lhs: Float4, rhs: Float4) -> Result<f32> {
35    let mut output = 0.0_f32;
36    // SAFETY: All arrays are valid for exactly four `f32` values.
37    let ok = unsafe { bridge::acc_simd_dot_f32x4(lhs.as_ptr(), rhs.as_ptr(), &mut output) };
38    if ok {
39        Ok(output)
40    } else {
41        Err(Error::OperationFailed("SIMD dot product failed"))
42    }
43}
44
45/// Wraps four-lane vector length computation backed by `simd_float4`.
46pub fn length_f32x4(input: Float4) -> Result<f32> {
47    let mut output = 0.0_f32;
48    // SAFETY: Both arrays are valid for exactly four `f32` values.
49    let ok = unsafe { bridge::acc_simd_length_f32x4(input.as_ptr(), &mut output) };
50    if ok {
51        Ok(output)
52    } else {
53        Err(Error::OperationFailed("SIMD length failed"))
54    }
55}
56
57/// Wraps four-lane vector normalization backed by `simd_float4`.
58pub fn normalize_f32x4(input: Float4) -> Result<Float4> {
59    if input.iter().all(|value| *value == 0.0) {
60        return Err(Error::InvalidValue("cannot normalize the zero vector"));
61    }
62    unary_simd_op(input, bridge::acc_simd_normalize_f32x4)
63}