1use crate::bridge;
2use crate::error::{Error, Result};
3use crate::raw_ffi;
4use core::ffi::c_void;
5use core::ptr;
6
pub mod graph_optimization_preference {
    //! Numeric values accepted by `GraphCompileOptions::set_optimization_preference`.
    //!
    //! NOTE(review): these appear to mirror BNNS's graph optimization-preference
    //! enumeration — confirm against the Accelerate headers for the targeted SDK.

    /// Preference value `0`: optimize the compiled graph for performance.
    pub const PERFORMANCE: u32 = 0;

    /// Preference value `1`: optimize the compiled graph for IR size.
    pub const IR_SIZE: u32 = 1;
}
12
13pub struct Filter {
15 ptr: raw_ffi::BNNSFilter,
16}
17
18unsafe impl Send for Filter {}
19unsafe impl Sync for Filter {}
20
21impl Drop for Filter {
22 fn drop(&mut self) {
23 if !self.ptr.is_null() {
24 unsafe { raw_ffi::BNNSFilterDestroy(self.ptr) };
26 self.ptr = ptr::null_mut();
27 }
28 }
29}
30
31impl Filter {
32 #[must_use]
38 pub unsafe fn from_convolution(
39 layer_params: *const c_void,
40 filter_params: *const c_void,
41 ) -> Option<Self> {
42 let ptr = unsafe { raw_ffi::BNNSFilterCreateLayerConvolution(layer_params, filter_params) };
43 if ptr.is_null() {
44 None
45 } else {
46 Some(Self { ptr })
47 }
48 }
49
50 #[must_use]
56 pub unsafe fn from_fully_connected(
57 layer_params: *const c_void,
58 filter_params: *const c_void,
59 ) -> Option<Self> {
60 let ptr =
61 unsafe { raw_ffi::BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) };
62 if ptr.is_null() {
63 None
64 } else {
65 Some(Self { ptr })
66 }
67 }
68
69 #[must_use]
70 pub const fn as_ptr(&self) -> *mut c_void {
71 self.ptr
72 }
73
74 pub unsafe fn apply(&self, input: *const c_void, output: *mut c_void) -> i32 {
80 unsafe { raw_ffi::BNNSFilterApply(self.ptr, input, output) }
81 }
82}
83
/// Owning handle to a BNNS Graph compile-options object created through the `bridge` layer.
///
/// The handle is released with `bridge::acc_release_handle` when the wrapper is dropped.
pub struct GraphCompileOptions {
    /// Opaque bridge handle; non-null for instances built by `GraphCompileOptions::new`.
    ptr: *mut c_void,
}

// SAFETY: through `&self` only the getters are reachable. Mutation requires `&mut self`, so
// Rust's aliasing rules already serialise writers. NOTE(review): this relies on the bridge
// getters being safe to call concurrently — confirm in the bridge implementation.
unsafe impl Sync for GraphCompileOptions {}
// SAFETY: the wrapper exclusively owns its handle. NOTE(review): assumes the bridge object is
// not bound to the thread that created it.
unsafe impl Send for GraphCompileOptions {}
91
92impl Drop for GraphCompileOptions {
93 fn drop(&mut self) {
94 if !self.ptr.is_null() {
95 unsafe { bridge::acc_release_handle(self.ptr) };
97 self.ptr = ptr::null_mut();
98 }
99 }
100}
101
102fn activation_result(status: i32) -> Result<()> {
103 if status == 0 {
104 Ok(())
105 } else {
106 Err(Error::BnnsStatus(status))
107 }
108}
109
110fn graph_result(ok: bool) -> Result<()> {
111 if ok {
112 Ok(())
113 } else {
114 Err(Error::OperationFailed(
115 "BNNS Graph compile options are unavailable on this macOS version",
116 ))
117 }
118}
119
120fn apply_activation(
121 values: &[f32],
122 f: unsafe extern "C" fn(*const f32, *mut f32, usize) -> i32,
123) -> Result<Vec<f32>> {
124 let mut out = vec![0.0_f32; values.len()];
125 if values.is_empty() {
126 return Ok(out);
127 }
128
129 let status = unsafe { f(values.as_ptr(), out.as_mut_ptr(), values.len()) };
131 activation_result(status)?;
132 Ok(out)
133}
134
135impl GraphCompileOptions {
136 #[must_use]
137 pub fn new() -> Option<Self> {
138 let ptr = unsafe { bridge::acc_bnns_graph_compile_options_create() };
140 if ptr.is_null() {
141 None
142 } else {
143 Some(Self { ptr })
144 }
145 }
146
147 pub fn set_target_single_thread(&mut self, value: bool) -> Result<()> {
148 graph_result(unsafe {
150 bridge::acc_bnns_graph_compile_options_set_target_single_thread(self.ptr, value)
151 })
152 }
153
154 #[must_use]
155 pub fn target_single_thread(&self) -> bool {
156 unsafe { bridge::acc_bnns_graph_compile_options_get_target_single_thread(self.ptr) }
158 }
159
160 pub fn set_generate_debug_info(&mut self, value: bool) -> Result<()> {
161 graph_result(unsafe {
163 bridge::acc_bnns_graph_compile_options_set_generate_debug_info(self.ptr, value)
164 })
165 }
166
167 #[must_use]
168 pub fn generate_debug_info(&self) -> bool {
169 unsafe { bridge::acc_bnns_graph_compile_options_get_generate_debug_info(self.ptr) }
171 }
172
173 pub fn set_optimization_preference(&mut self, preference: u32) -> Result<()> {
174 graph_result(unsafe {
176 bridge::acc_bnns_graph_compile_options_set_optimization_preference(self.ptr, preference)
177 })
178 }
179
180 #[must_use]
181 pub fn optimization_preference(&self) -> u32 {
182 unsafe { bridge::acc_bnns_graph_compile_options_get_optimization_preference(self.ptr) }
184 }
185}
186
187pub fn relu_f32(values: &[f32]) -> Result<Vec<f32>> {
189 apply_activation(values, bridge::acc_bnns_relu_f32)
190}
191
192pub fn sigmoid_f32(values: &[f32]) -> Result<Vec<f32>> {
194 apply_activation(values, bridge::acc_bnns_sigmoid_f32)
195}