compiled_nn/
compiled_nn.rs1use std::{
2 ffi::CString,
3 path::Path,
4 slice::{from_raw_parts, from_raw_parts_mut},
5};
6
7use compiled_nn_bindings;
8
9pub struct CompiledNN {
10 core: compiled_nn_bindings::CompiledNN,
11}
12
13unsafe impl Send for CompiledNN {}
14
15impl Default for CompiledNN {
16 fn default() -> Self {
17 Self {
18 core: unsafe { compiled_nn_bindings::CompiledNN::new() },
19 }
20 }
21}
22
23impl Drop for CompiledNN {
24 fn drop(&mut self) {
25 unsafe { self.core.destruct() }
26 }
27}
28
29impl CompiledNN {
30 pub fn compile(&mut self, filename: impl AsRef<Path>) {
31 let filename =
32 CString::new(filename.as_ref().to_str().unwrap()).expect("CString::new failed");
33 unsafe { self.core.compile(filename.as_ptr()) }
34 }
35
36 pub fn input(&self, index: usize) -> Tensor {
37 unsafe {
38 let input = self.core.input(index as u64);
39 Tensor {
40 data: from_raw_parts(input.data, input.data_size as usize),
41 dimensions: from_raw_parts(input.dimensions, input.dimensions_size as usize),
42 }
43 }
44 }
45
46 pub fn input_mut(&mut self, index: usize) -> TensorMut {
47 unsafe {
48 let input = self.core.input_mut(index as u64);
49 TensorMut {
50 data: from_raw_parts_mut(input.data, input.data_size as usize),
51 dimensions: from_raw_parts(input.dimensions, input.dimensions_size as usize),
52 }
53 }
54 }
55
56 pub fn output(&self, index: usize) -> Tensor {
57 unsafe {
58 let output = self.core.output(index as u64);
59 Tensor {
60 data: from_raw_parts(output.data, output.data_size as usize),
61 dimensions: from_raw_parts(output.dimensions, output.dimensions_size as usize),
62 }
63 }
64 }
65
66 pub fn output_mut(&mut self, index: usize) -> TensorMut {
67 unsafe {
68 let output = self.core.output_mut(index as u64);
69 TensorMut {
70 data: from_raw_parts_mut(output.data, output.data_size as usize),
71 dimensions: from_raw_parts(output.dimensions, output.dimensions_size as usize),
72 }
73 }
74 }
75
76 pub fn apply(&mut self) {
77 unsafe { self.core.apply() }
78 }
79}
80
/// Read-only view of one network tensor: its flat `f32` data and its dimension list.
///
/// It only holds shared borrows, so it is cheap to copy and can be compared for equality.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Tensor<'a> {
    /// Flat element buffer.
    pub data: &'a [f32],
    /// Dimension sizes as reported by the native side.
    pub dimensions: &'a [u32],
}
86
/// Mutable view of one network tensor.
///
/// The element buffer can be written in place; the dimension list stays read-only.
#[derive(Debug)]
pub struct TensorMut<'a> {
    /// Flat element buffer, writable in place.
    pub data: &'a mut [f32],
    /// Dimension sizes as reported by the native side.
    pub dimensions: &'a [u32],
}