Skip to main content

apple_accelerate/
bnns.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3use crate::raw_ffi;
4use core::ffi::c_void;
5use core::ptr;
6
7/// Thin owner for the deprecated-but-still-available BNNS filter APIs.
8pub struct Filter {
9    ptr: raw_ffi::BNNSFilter,
10}
11
// SAFETY: `Filter` is the sole owner of its BNNS handle (it destroys it in `Drop`),
// so moving a `Filter` to another thread moves that ownership with it.
// NOTE(review): assumes BNNS filter objects are not bound to their creating
// thread — confirm against Apple's BNNS documentation.
unsafe impl Send for Filter {}
// SAFETY: the only `&self` operations are `as_ptr` (copies the handle value) and
// `apply` (passes the handle to `BNNSFilterApply`).
// NOTE(review): this impl is only sound if `BNNSFilterApply` may be called
// concurrently on the same filter from several threads — confirm with Apple's docs.
unsafe impl Sync for Filter {}
14
15impl Drop for Filter {
16    fn drop(&mut self) {
17        if !self.ptr.is_null() {
18            // SAFETY: `ptr` was returned by a BNNS filter constructor and is owned by this wrapper.
19            unsafe { raw_ffi::BNNSFilterDestroy(self.ptr) };
20            self.ptr = ptr::null_mut();
21        }
22    }
23}
24
25impl Filter {
26    /// Create a BNNS convolution filter from caller-owned raw layer/filter parameter structs.
27    ///
28    /// # Safety
29    ///
30    /// The pointers must refer to valid BNNS parameter structs for the duration of the call.
31    #[must_use]
32    pub unsafe fn from_convolution(
33        layer_params: *const c_void,
34        filter_params: *const c_void,
35    ) -> Option<Self> {
36        let ptr = unsafe { raw_ffi::BNNSFilterCreateLayerConvolution(layer_params, filter_params) };
37        if ptr.is_null() {
38            None
39        } else {
40            Some(Self { ptr })
41        }
42    }
43
44    /// Create a BNNS fully connected filter from caller-owned raw layer/filter parameter structs.
45    ///
46    /// # Safety
47    ///
48    /// The pointers must refer to valid BNNS parameter structs for the duration of the call.
49    #[must_use]
50    pub unsafe fn from_fully_connected(
51        layer_params: *const c_void,
52        filter_params: *const c_void,
53    ) -> Option<Self> {
54        let ptr =
55            unsafe { raw_ffi::BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) };
56        if ptr.is_null() {
57            None
58        } else {
59            Some(Self { ptr })
60        }
61    }
62
63    #[must_use]
64    pub const fn as_ptr(&self) -> *mut c_void {
65        self.ptr
66    }
67
68    /// Apply the filter to caller-owned raw buffers.
69    ///
70    /// # Safety
71    ///
72    /// `input` and `output` must match the layout and lengths described when the filter was created.
73    pub unsafe fn apply(&self, input: *const c_void, output: *mut c_void) -> i32 {
74        unsafe { raw_ffi::BNNSFilterApply(self.ptr, input, output) }
75    }
76}
77
78fn activation_result(status: i32) -> Result<()> {
79    if status == 0 {
80        Ok(())
81    } else {
82        Err(Error::BnnsStatus(status))
83    }
84}
85
86fn apply_activation(
87    values: &[f32],
88    f: unsafe extern "C" fn(*const f32, *mut f32, usize) -> i32,
89) -> Result<Vec<f32>> {
90    let mut out = vec![0.0_f32; values.len()];
91    if values.is_empty() {
92        return Ok(out);
93    }
94
95    // SAFETY: The buffers are valid for `values.len()` contiguous `f32` elements.
96    let status = unsafe { f(values.as_ptr(), out.as_mut_ptr(), values.len()) };
97    activation_result(status)?;
98    Ok(out)
99}
100
101/// Apply BNNS `ReLU` activation to a vector.
102pub fn relu_f32(values: &[f32]) -> Result<Vec<f32>> {
103    apply_activation(values, bridge::acc_bnns_relu_f32)
104}
105
106/// Apply BNNS sigmoid activation to a vector.
107pub fn sigmoid_f32(values: &[f32]) -> Result<Vec<f32>> {
108    apply_activation(values, bridge::acc_bnns_sigmoid_f32)
109}