Skip to main content

apple_accelerate/
bnns.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3use crate::raw_ffi;
4use core::ffi::c_void;
5use core::ptr;
6
/// Named values for the `BNNSGraphOptimizationPreference` setting.
///
/// The constants are plain `u32` values, matching the parameter type of
/// `GraphCompileOptions::set_optimization_preference` and the return type of
/// `GraphCompileOptions::optimization_preference`.
pub mod graph_optimization_preference {
    /// Preference value `0`: optimize the compiled graph for runtime performance.
    pub const PERFORMANCE: u32 = 0;
    /// Preference value `1`: optimize the compiled graph for a smaller BNNS IR size.
    pub const IR_SIZE: u32 = 1;
}
14
15/// Thin owner for the deprecated-but-still-available BNNS filter APIs.
16pub struct Filter {
17    ptr: raw_ffi::BNNSFilter,
18}
19
20unsafe impl Send for Filter {}
21unsafe impl Sync for Filter {}
22
23impl Drop for Filter {
24    fn drop(&mut self) {
25        if !self.ptr.is_null() {
26            // SAFETY: `ptr` was returned by a BNNS filter constructor and is owned by this wrapper.
27            unsafe { raw_ffi::BNNSFilterDestroy(self.ptr) };
28            self.ptr = ptr::null_mut();
29        }
30    }
31}
32
33impl Filter {
34    /// Create a BNNS convolution filter from caller-owned raw layer/filter parameter structs.
35    ///
36    /// # Safety
37    ///
38    /// The pointers must refer to valid BNNS parameter structs for the duration of the call.
39    #[must_use]
40    pub unsafe fn from_convolution(
41        layer_params: *const c_void,
42        filter_params: *const c_void,
43    ) -> Option<Self> {
44        // SAFETY: Caller has verified that `layer_params` and `filter_params` are valid BNNS structs.
45        let ptr = unsafe { raw_ffi::BNNSFilterCreateLayerConvolution(layer_params, filter_params) };
46        if ptr.is_null() {
47            None
48        } else {
49            Some(Self { ptr })
50        }
51    }
52
53    /// Create a BNNS fully connected filter from caller-owned raw layer/filter parameter structs.
54    ///
55    /// # Safety
56    ///
57    /// The pointers must refer to valid BNNS parameter structs for the duration of the call.
58    #[must_use]
59    pub unsafe fn from_fully_connected(
60        layer_params: *const c_void,
61        filter_params: *const c_void,
62    ) -> Option<Self> {
63        // SAFETY: Caller has verified that `layer_params` and `filter_params` are valid BNNS structs.
64        let ptr =
65            unsafe { raw_ffi::BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) };
66        if ptr.is_null() {
67            None
68        } else {
69            Some(Self { ptr })
70        }
71    }
72
73    /// Returns the raw `BNNSFilter` pointer wrapped by this owner.
74    #[must_use]
75    pub const fn as_ptr(&self) -> *mut c_void {
76        self.ptr
77    }
78
79    /// Apply the filter to caller-owned raw buffers.
80    ///
81    /// # Safety
82    ///
83    /// `input` and `output` must match the layout and lengths described when the filter was created.
84    pub unsafe fn apply(&self, input: *const c_void, output: *mut c_void) -> i32 {
85        // SAFETY: Caller has verified that `input` and `output` satisfy the filter's documented preconditions.
86        unsafe { raw_ffi::BNNSFilterApply(self.ptr, input, output) }
87    }
88}
89
/// Owned BNNS Graph compile-options handle backed by the Swift bridge.
///
/// The handle is handed back to the bridge via `bridge::acc_release_handle` when the owner is
/// dropped.
pub struct GraphCompileOptions {
    /// Opaque bridge handle; `GraphCompileOptions::new` only stores non-null handles.
    ptr: *mut c_void,
}

// SAFETY: `GraphCompileOptions` is the sole owner of its handle (it is not `Clone`), so sending
// it to another thread moves that ownership wholesale.
// NOTE(review): this presumes the bridged Swift object has no thread affinity — confirm against
// the bridge implementation.
unsafe impl Send for GraphCompileOptions {}
// SAFETY: Through `&GraphCompileOptions` only the `&self` getters are reachable; every setter
// takes `&mut self`.
// NOTE(review): this presumes the bridge getters are safe to call concurrently on one handle —
// confirm against the bridge implementation.
unsafe impl Sync for GraphCompileOptions {}
97
98impl Drop for GraphCompileOptions {
99    fn drop(&mut self) {
100        if !self.ptr.is_null() {
101            // SAFETY: `ptr` is an opaque Swift object retained by the bridge.
102            unsafe { bridge::acc_release_handle(self.ptr) };
103            self.ptr = ptr::null_mut();
104        }
105    }
106}
107
108fn activation_result(status: i32) -> Result<()> {
109    if status == 0 {
110        Ok(())
111    } else {
112        Err(Error::BnnsStatus(status))
113    }
114}
115
116fn graph_result(ok: bool) -> Result<()> {
117    if ok {
118        Ok(())
119    } else {
120        Err(Error::OperationFailed(
121            "BNNS Graph compile options are unavailable on this macOS version",
122        ))
123    }
124}
125
126fn apply_activation(
127    values: &[f32],
128    f: unsafe extern "C" fn(*const f32, *mut f32, usize) -> i32,
129) -> Result<Vec<f32>> {
130    let mut out = vec![0.0_f32; values.len()];
131    if values.is_empty() {
132        return Ok(out);
133    }
134
135    // SAFETY: The buffers are valid for `values.len()` contiguous `f32` elements.
136    let status = unsafe { f(values.as_ptr(), out.as_mut_ptr(), values.len()) };
137    activation_result(status)?;
138    Ok(out)
139}
140
141impl GraphCompileOptions {
142    /// Creates BNNS graph compile options backed by the BNNS graph compile-options API.
143    #[must_use]
144    pub fn new() -> Option<Self> {
145        // SAFETY: Pure constructor over the current runtime environment.
146        let ptr = unsafe { bridge::acc_bnns_graph_compile_options_create() };
147        if ptr.is_null() {
148            None
149        } else {
150            Some(Self { ptr })
151        }
152    }
153
154    /// Sets the BNNS graph compile option that targets single-thread execution.
155    pub fn set_target_single_thread(&mut self, value: bool) -> Result<()> {
156        // SAFETY: `self.ptr` is a live bridge handle.
157        graph_result(unsafe {
158            bridge::acc_bnns_graph_compile_options_set_target_single_thread(self.ptr, value)
159        })
160    }
161
162    /// Returns whether BNNS graph compilation targets single-thread execution.
163    #[must_use]
164    pub fn target_single_thread(&self) -> bool {
165        // SAFETY: `self.ptr` is a live bridge handle.
166        unsafe { bridge::acc_bnns_graph_compile_options_get_target_single_thread(self.ptr) }
167    }
168
169    /// Sets the BNNS graph compile option that emits debug information.
170    pub fn set_generate_debug_info(&mut self, value: bool) -> Result<()> {
171        // SAFETY: `self.ptr` is a live bridge handle.
172        graph_result(unsafe {
173            bridge::acc_bnns_graph_compile_options_set_generate_debug_info(self.ptr, value)
174        })
175    }
176
177    /// Returns whether the BNNS graph compile option emits debug information.
178    #[must_use]
179    pub fn generate_debug_info(&self) -> bool {
180        // SAFETY: `self.ptr` is a live bridge handle.
181        unsafe { bridge::acc_bnns_graph_compile_options_get_generate_debug_info(self.ptr) }
182    }
183
184    /// Sets the `BNNSGraphOptimizationPreference` used by BNNS graph compilation.
185    pub fn set_optimization_preference(&mut self, preference: u32) -> Result<()> {
186        // SAFETY: `self.ptr` is a live bridge handle.
187        graph_result(unsafe {
188            bridge::acc_bnns_graph_compile_options_set_optimization_preference(self.ptr, preference)
189        })
190    }
191
192    /// Returns the `BNNSGraphOptimizationPreference` used by BNNS graph compilation.
193    #[must_use]
194    pub fn optimization_preference(&self) -> u32 {
195        // SAFETY: `self.ptr` is a live bridge handle.
196        unsafe { bridge::acc_bnns_graph_compile_options_get_optimization_preference(self.ptr) }
197    }
198}
199
200/// Apply BNNS `ReLU` activation to a vector.
201pub fn relu_f32(values: &[f32]) -> Result<Vec<f32>> {
202    apply_activation(values, bridge::acc_bnns_relu_f32)
203}
204
205/// Apply BNNS sigmoid activation to a vector.
206pub fn sigmoid_f32(values: &[f32]) -> Result<Vec<f32>> {
207    apply_activation(values, bridge::acc_bnns_sigmoid_f32)
208}