1use crate::bridge;
2use crate::error::{Error, Result};
3use crate::raw_ffi;
4use core::ffi::c_void;
5use core::ptr;
6
/// Numeric preference values passed to
/// `GraphCompileOptions::set_optimization_preference`.
///
/// NOTE(review): these appear to mirror Apple's `BNNSGraphOptimizationPreference`
/// cases; confirm against the bridge before relying on the exact numbers.
pub mod graph_optimization_preference {
    /// Preference value `0`, named for favouring runtime performance.
    pub const PERFORMANCE: u32 = 0;

    /// Preference value `1`, named for favouring a smaller compiled IR.
    pub const IR_SIZE: u32 = 1;
}
14
15pub struct Filter {
17 ptr: raw_ffi::BNNSFilter,
18}
19
20unsafe impl Send for Filter {}
21unsafe impl Sync for Filter {}
22
23impl Drop for Filter {
24 fn drop(&mut self) {
25 if !self.ptr.is_null() {
26 unsafe { raw_ffi::BNNSFilterDestroy(self.ptr) };
28 self.ptr = ptr::null_mut();
29 }
30 }
31}
32
33impl Filter {
34 #[must_use]
40 pub unsafe fn from_convolution(
41 layer_params: *const c_void,
42 filter_params: *const c_void,
43 ) -> Option<Self> {
44 let ptr = unsafe { raw_ffi::BNNSFilterCreateLayerConvolution(layer_params, filter_params) };
46 if ptr.is_null() {
47 None
48 } else {
49 Some(Self { ptr })
50 }
51 }
52
53 #[must_use]
59 pub unsafe fn from_fully_connected(
60 layer_params: *const c_void,
61 filter_params: *const c_void,
62 ) -> Option<Self> {
63 let ptr =
65 unsafe { raw_ffi::BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) };
66 if ptr.is_null() {
67 None
68 } else {
69 Some(Self { ptr })
70 }
71 }
72
73 #[must_use]
75 pub const fn as_ptr(&self) -> *mut c_void {
76 self.ptr
77 }
78
79 pub unsafe fn apply(&self, input: *const c_void, output: *mut c_void) -> i32 {
85 unsafe { raw_ffi::BNNSFilterApply(self.ptr, input, output) }
87 }
88}
89
90pub struct GraphCompileOptions {
92 ptr: *mut c_void,
93}
94
95unsafe impl Send for GraphCompileOptions {}
96unsafe impl Sync for GraphCompileOptions {}
97
98impl Drop for GraphCompileOptions {
99 fn drop(&mut self) {
100 if !self.ptr.is_null() {
101 unsafe { bridge::acc_release_handle(self.ptr) };
103 self.ptr = ptr::null_mut();
104 }
105 }
106}
107
108fn activation_result(status: i32) -> Result<()> {
109 if status == 0 {
110 Ok(())
111 } else {
112 Err(Error::BnnsStatus(status))
113 }
114}
115
116fn graph_result(ok: bool) -> Result<()> {
117 if ok {
118 Ok(())
119 } else {
120 Err(Error::OperationFailed(
121 "BNNS Graph compile options are unavailable on this macOS version",
122 ))
123 }
124}
125
126fn apply_activation(
127 values: &[f32],
128 f: unsafe extern "C" fn(*const f32, *mut f32, usize) -> i32,
129) -> Result<Vec<f32>> {
130 let mut out = vec![0.0_f32; values.len()];
131 if values.is_empty() {
132 return Ok(out);
133 }
134
135 let status = unsafe { f(values.as_ptr(), out.as_mut_ptr(), values.len()) };
137 activation_result(status)?;
138 Ok(out)
139}
140
141impl GraphCompileOptions {
142 #[must_use]
144 pub fn new() -> Option<Self> {
145 let ptr = unsafe { bridge::acc_bnns_graph_compile_options_create() };
147 if ptr.is_null() {
148 None
149 } else {
150 Some(Self { ptr })
151 }
152 }
153
154 pub fn set_target_single_thread(&mut self, value: bool) -> Result<()> {
156 graph_result(unsafe {
158 bridge::acc_bnns_graph_compile_options_set_target_single_thread(self.ptr, value)
159 })
160 }
161
162 #[must_use]
164 pub fn target_single_thread(&self) -> bool {
165 unsafe { bridge::acc_bnns_graph_compile_options_get_target_single_thread(self.ptr) }
167 }
168
169 pub fn set_generate_debug_info(&mut self, value: bool) -> Result<()> {
171 graph_result(unsafe {
173 bridge::acc_bnns_graph_compile_options_set_generate_debug_info(self.ptr, value)
174 })
175 }
176
177 #[must_use]
179 pub fn generate_debug_info(&self) -> bool {
180 unsafe { bridge::acc_bnns_graph_compile_options_get_generate_debug_info(self.ptr) }
182 }
183
184 pub fn set_optimization_preference(&mut self, preference: u32) -> Result<()> {
186 graph_result(unsafe {
188 bridge::acc_bnns_graph_compile_options_set_optimization_preference(self.ptr, preference)
189 })
190 }
191
192 #[must_use]
194 pub fn optimization_preference(&self) -> u32 {
195 unsafe { bridge::acc_bnns_graph_compile_options_get_optimization_preference(self.ptr) }
197 }
198}
199
200pub fn relu_f32(values: &[f32]) -> Result<Vec<f32>> {
202 apply_activation(values, bridge::acc_bnns_relu_f32)
203}
204
205pub fn sigmoid_f32(values: &[f32]) -> Result<Vec<f32>> {
207 apply_activation(values, bridge::acc_bnns_sigmoid_f32)
208}