Skip to main content

apple_accelerate/
bnns.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3use crate::raw_ffi;
4use core::ffi::c_void;
5use core::ptr;
6
/// Raw `BNNSGraphOptimizationPreference` values, for use with
/// `GraphCompileOptions::set_optimization_preference`.
pub mod graph_optimization_preference {
    /// Value `0`: mirrors `BNNSGraphOptimizationPreferencePerformance` (optimize for speed).
    pub const PERFORMANCE: u32 = 0;

    /// Value `1`: mirrors `BNNSGraphOptimizationPreferenceIRSize` (optimize for small IR).
    pub const IR_SIZE: u32 = 1;
}
12
13/// Thin owner for the deprecated-but-still-available BNNS filter APIs.
14pub struct Filter {
15    ptr: raw_ffi::BNNSFilter,
16}
17
18unsafe impl Send for Filter {}
19unsafe impl Sync for Filter {}
20
21impl Drop for Filter {
22    fn drop(&mut self) {
23        if !self.ptr.is_null() {
24            // SAFETY: `ptr` was returned by a BNNS filter constructor and is owned by this wrapper.
25            unsafe { raw_ffi::BNNSFilterDestroy(self.ptr) };
26            self.ptr = ptr::null_mut();
27        }
28    }
29}
30
31impl Filter {
32    /// Create a BNNS convolution filter from caller-owned raw layer/filter parameter structs.
33    ///
34    /// # Safety
35    ///
36    /// The pointers must refer to valid BNNS parameter structs for the duration of the call.
37    #[must_use]
38    pub unsafe fn from_convolution(
39        layer_params: *const c_void,
40        filter_params: *const c_void,
41    ) -> Option<Self> {
42        let ptr = unsafe { raw_ffi::BNNSFilterCreateLayerConvolution(layer_params, filter_params) };
43        if ptr.is_null() {
44            None
45        } else {
46            Some(Self { ptr })
47        }
48    }
49
50    /// Create a BNNS fully connected filter from caller-owned raw layer/filter parameter structs.
51    ///
52    /// # Safety
53    ///
54    /// The pointers must refer to valid BNNS parameter structs for the duration of the call.
55    #[must_use]
56    pub unsafe fn from_fully_connected(
57        layer_params: *const c_void,
58        filter_params: *const c_void,
59    ) -> Option<Self> {
60        let ptr =
61            unsafe { raw_ffi::BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) };
62        if ptr.is_null() {
63            None
64        } else {
65            Some(Self { ptr })
66        }
67    }
68
69    #[must_use]
70    pub const fn as_ptr(&self) -> *mut c_void {
71        self.ptr
72    }
73
74    /// Apply the filter to caller-owned raw buffers.
75    ///
76    /// # Safety
77    ///
78    /// `input` and `output` must match the layout and lengths described when the filter was created.
79    pub unsafe fn apply(&self, input: *const c_void, output: *mut c_void) -> i32 {
80        unsafe { raw_ffi::BNNSFilterApply(self.ptr, input, output) }
81    }
82}
83
/// Owning handle to a BNNS Graph compile-options object created through the Swift bridge.
///
/// The underlying object is released via `bridge::acc_release_handle` on drop.
pub struct GraphCompileOptions {
    ptr: *mut c_void,
}

// SAFETY: The handle is exclusively owned by this value, so transferring the value to another
// thread transfers sole ownership of the bridge object with it.
unsafe impl core::marker::Send for GraphCompileOptions {}

// SAFETY: Every mutating method takes `&mut self`; shared references only reach the getters.
// NOTE(review): assumes the bridge getters may be called concurrently on one handle — confirm.
unsafe impl core::marker::Sync for GraphCompileOptions {}
91
92impl Drop for GraphCompileOptions {
93    fn drop(&mut self) {
94        if !self.ptr.is_null() {
95            // SAFETY: `ptr` is an opaque Swift object retained by the bridge.
96            unsafe { bridge::acc_release_handle(self.ptr) };
97            self.ptr = ptr::null_mut();
98        }
99    }
100}
101
102fn activation_result(status: i32) -> Result<()> {
103    if status == 0 {
104        Ok(())
105    } else {
106        Err(Error::BnnsStatus(status))
107    }
108}
109
110fn graph_result(ok: bool) -> Result<()> {
111    if ok {
112        Ok(())
113    } else {
114        Err(Error::OperationFailed(
115            "BNNS Graph compile options are unavailable on this macOS version",
116        ))
117    }
118}
119
120fn apply_activation(
121    values: &[f32],
122    f: unsafe extern "C" fn(*const f32, *mut f32, usize) -> i32,
123) -> Result<Vec<f32>> {
124    let mut out = vec![0.0_f32; values.len()];
125    if values.is_empty() {
126        return Ok(out);
127    }
128
129    // SAFETY: The buffers are valid for `values.len()` contiguous `f32` elements.
130    let status = unsafe { f(values.as_ptr(), out.as_mut_ptr(), values.len()) };
131    activation_result(status)?;
132    Ok(out)
133}
134
135impl GraphCompileOptions {
136    #[must_use]
137    pub fn new() -> Option<Self> {
138        // SAFETY: Pure constructor over the current runtime environment.
139        let ptr = unsafe { bridge::acc_bnns_graph_compile_options_create() };
140        if ptr.is_null() {
141            None
142        } else {
143            Some(Self { ptr })
144        }
145    }
146
147    pub fn set_target_single_thread(&mut self, value: bool) -> Result<()> {
148        // SAFETY: `self.ptr` is a live bridge handle.
149        graph_result(unsafe {
150            bridge::acc_bnns_graph_compile_options_set_target_single_thread(self.ptr, value)
151        })
152    }
153
154    #[must_use]
155    pub fn target_single_thread(&self) -> bool {
156        // SAFETY: `self.ptr` is a live bridge handle.
157        unsafe { bridge::acc_bnns_graph_compile_options_get_target_single_thread(self.ptr) }
158    }
159
160    pub fn set_generate_debug_info(&mut self, value: bool) -> Result<()> {
161        // SAFETY: `self.ptr` is a live bridge handle.
162        graph_result(unsafe {
163            bridge::acc_bnns_graph_compile_options_set_generate_debug_info(self.ptr, value)
164        })
165    }
166
167    #[must_use]
168    pub fn generate_debug_info(&self) -> bool {
169        // SAFETY: `self.ptr` is a live bridge handle.
170        unsafe { bridge::acc_bnns_graph_compile_options_get_generate_debug_info(self.ptr) }
171    }
172
173    pub fn set_optimization_preference(&mut self, preference: u32) -> Result<()> {
174        // SAFETY: `self.ptr` is a live bridge handle.
175        graph_result(unsafe {
176            bridge::acc_bnns_graph_compile_options_set_optimization_preference(self.ptr, preference)
177        })
178    }
179
180    #[must_use]
181    pub fn optimization_preference(&self) -> u32 {
182        // SAFETY: `self.ptr` is a live bridge handle.
183        unsafe { bridge::acc_bnns_graph_compile_options_get_optimization_preference(self.ptr) }
184    }
185}
186
187/// Apply BNNS `ReLU` activation to a vector.
188pub fn relu_f32(values: &[f32]) -> Result<Vec<f32>> {
189    apply_activation(values, bridge::acc_bnns_relu_f32)
190}
191
192/// Apply BNNS sigmoid activation to a vector.
193pub fn sigmoid_f32(values: &[f32]) -> Result<Vec<f32>> {
194    apply_activation(values, bridge::acc_bnns_sigmoid_f32)
195}