1use crate::bridge;
2use crate::error::{Error, Result};
3use crate::raw_ffi;
4use core::ffi::c_void;
5use core::ptr;
6
7pub struct Filter {
9 ptr: raw_ffi::BNNSFilter,
10}
11
12unsafe impl Send for Filter {}
13unsafe impl Sync for Filter {}
14
15impl Drop for Filter {
16 fn drop(&mut self) {
17 if !self.ptr.is_null() {
18 unsafe { raw_ffi::BNNSFilterDestroy(self.ptr) };
20 self.ptr = ptr::null_mut();
21 }
22 }
23}
24
25impl Filter {
26 #[must_use]
32 pub unsafe fn from_convolution(
33 layer_params: *const c_void,
34 filter_params: *const c_void,
35 ) -> Option<Self> {
36 let ptr = unsafe { raw_ffi::BNNSFilterCreateLayerConvolution(layer_params, filter_params) };
37 if ptr.is_null() {
38 None
39 } else {
40 Some(Self { ptr })
41 }
42 }
43
44 #[must_use]
50 pub unsafe fn from_fully_connected(
51 layer_params: *const c_void,
52 filter_params: *const c_void,
53 ) -> Option<Self> {
54 let ptr =
55 unsafe { raw_ffi::BNNSFilterCreateLayerFullyConnected(layer_params, filter_params) };
56 if ptr.is_null() {
57 None
58 } else {
59 Some(Self { ptr })
60 }
61 }
62
63 #[must_use]
64 pub const fn as_ptr(&self) -> *mut c_void {
65 self.ptr
66 }
67
68 pub unsafe fn apply(&self, input: *const c_void, output: *mut c_void) -> i32 {
74 unsafe { raw_ffi::BNNSFilterApply(self.ptr, input, output) }
75 }
76}
77
78fn activation_result(status: i32) -> Result<()> {
79 if status == 0 {
80 Ok(())
81 } else {
82 Err(Error::BnnsStatus(status))
83 }
84}
85
86fn apply_activation(
87 values: &[f32],
88 f: unsafe extern "C" fn(*const f32, *mut f32, usize) -> i32,
89) -> Result<Vec<f32>> {
90 let mut out = vec![0.0_f32; values.len()];
91 if values.is_empty() {
92 return Ok(out);
93 }
94
95 let status = unsafe { f(values.as_ptr(), out.as_mut_ptr(), values.len()) };
97 activation_result(status)?;
98 Ok(out)
99}
100
101pub fn relu_f32(values: &[f32]) -> Result<Vec<f32>> {
103 apply_activation(values, bridge::acc_bnns_relu_f32)
104}
105
106pub fn sigmoid_f32(values: &[f32]) -> Result<Vec<f32>> {
108 apply_activation(values, bridge::acc_bnns_sigmoid_f32)
109}