compiled_nn/
compiled_nn.rs

1use std::{
2    ffi::CString,
3    path::Path,
4    slice::{from_raw_parts, from_raw_parts_mut},
5};
6
7use compiled_nn_bindings;
8
9pub struct CompiledNN {
10    core: compiled_nn_bindings::CompiledNN,
11}
12
13unsafe impl Send for CompiledNN {}
14
15impl Default for CompiledNN {
16    fn default() -> Self {
17        Self {
18            core: unsafe { compiled_nn_bindings::CompiledNN::new() },
19        }
20    }
21}
22
23impl Drop for CompiledNN {
24    fn drop(&mut self) {
25        unsafe { self.core.destruct() }
26    }
27}
28
29impl CompiledNN {
30    pub fn compile(&mut self, filename: impl AsRef<Path>) {
31        let filename =
32            CString::new(filename.as_ref().to_str().unwrap()).expect("CString::new failed");
33        unsafe { self.core.compile(filename.as_ptr()) }
34    }
35
36    pub fn input(&self, index: usize) -> Tensor {
37        unsafe {
38            let input = self.core.input(index as u64);
39            Tensor {
40                data: from_raw_parts(input.data, input.data_size as usize),
41                dimensions: from_raw_parts(input.dimensions, input.dimensions_size as usize),
42            }
43        }
44    }
45
46    pub fn input_mut(&mut self, index: usize) -> TensorMut {
47        unsafe {
48            let input = self.core.input_mut(index as u64);
49            TensorMut {
50                data: from_raw_parts_mut(input.data, input.data_size as usize),
51                dimensions: from_raw_parts(input.dimensions, input.dimensions_size as usize),
52            }
53        }
54    }
55
56    pub fn output(&self, index: usize) -> Tensor {
57        unsafe {
58            let output = self.core.output(index as u64);
59            Tensor {
60                data: from_raw_parts(output.data, output.data_size as usize),
61                dimensions: from_raw_parts(output.dimensions, output.dimensions_size as usize),
62            }
63        }
64    }
65
66    pub fn output_mut(&mut self, index: usize) -> TensorMut {
67        unsafe {
68            let output = self.core.output_mut(index as u64);
69            TensorMut {
70                data: from_raw_parts_mut(output.data, output.data_size as usize),
71                dimensions: from_raw_parts(output.dimensions, output.dimensions_size as usize),
72            }
73        }
74    }
75
76    pub fn apply(&mut self) {
77        unsafe { self.core.apply() }
78    }
79}
80
/// Read-only view of one network tensor.
///
/// Both fields are borrowed slices; in this file they are produced by `CompiledNN::input`
/// and `CompiledNN::output` from memory owned by the native object, so a `Tensor` cannot
/// outlive that borrow. Being made only of shared references, it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Tensor<'a> {
    /// Flat element buffer.
    pub data: &'a [f32],
    /// Presumably the extent of each dimension (with `data.len()` equal to their product)
    /// — confirm against the native library.
    pub dimensions: &'a [u32],
}
86
/// Writable view of one network tensor.
///
/// In this file it is produced by `CompiledNN::input_mut` and `CompiledNN::output_mut`,
/// which borrow the network mutably, so the element buffer can be written while the
/// dimension list stays read-only.
#[derive(Debug, PartialEq)]
pub struct TensorMut<'a> {
    /// Flat element buffer, writable in place.
    pub data: &'a mut [f32],
    /// Presumably the extent of each dimension — confirm against the native library.
    pub dimensions: &'a [u32],
}