// menoh/model.rs

1use crate::error::check;
2use crate::Dtype;
3use crate::Error;
4use std::ffi::CString;
5use std::mem;
6use std::ptr;
7use std::slice;
8
9/// Model, which executes computation.
10pub struct Model {
11    pub(crate) handle: menoh_sys::menoh_model_handle,
12}
13
14impl Model {
15    /// Fetch the shape of a variable.
16    ///
17    /// ```
18    /// # use menoh::*;
19    /// # fn main() -> Result<(), Error> {
20    /// # let model = Builder::from_onnx("MLP.onnx")?
21    /// #     .add_input::<f32>("input", &[2, 3])?
22    /// #     .add_output("fc2")?
23    /// #     .build("mkldnn", "")?;
24    /// let dims = model.get_variable_dims("fc2")?;
25    /// # assert_eq!(dims, &[2, 5]);
26    /// # Ok(())
27    /// # }
28    /// ```
29    pub fn get_variable_dims(&self, name: &str) -> Result<Vec<usize>, Error> {
30        let name = CString::new(name)?;
31        unsafe {
32            let mut size = 0;
33            check(menoh_sys::menoh_model_get_variable_dims_size(
34                self.handle,
35                name.as_ptr(),
36                &mut size,
37            ))?;
38            let mut dims = Vec::with_capacity(size as _);
39            for index in 0..size {
40                let mut dim = 0;
41                check(menoh_sys::menoh_model_get_variable_dims_at(
42                    self.handle,
43                    name.as_ptr(),
44                    index,
45                    &mut dim,
46                ))?;
47                dims.push(dim as _);
48            }
49            Ok(dims)
50        }
51    }
52
53    fn get_variable_dtype(&self, name: &str) -> Result<menoh_sys::menoh_dtype, Error> {
54        let name = CString::new(name)?;
55        unsafe {
56            let mut dtype = mem::uninitialized();
57            check(menoh_sys::menoh_model_get_variable_dtype(
58                self.handle,
59                name.as_ptr(),
60                &mut dtype,
61            ))?;
62            Ok(dtype)
63        }
64    }
65
66    /// Fetch the shape and read-only view of a variable.
67    ///
68    /// ```
69    /// # use menoh::*;
70    /// # fn main() -> Result<(), Error> {
71    /// # let model = Builder::from_onnx("MLP.onnx")?
72    /// #     .add_input::<f32>("input", &[2, 3])?
73    /// #     .add_output("fc2")?
74    /// #     .build("mkldnn", "")?;
75    /// let (dims, buf) = model.get_variable::<f32>("fc2")?;
76    /// # assert_eq!(dims, &[2, 5]);
77    /// # Ok(())
78    /// # }
79    /// ```
80    pub fn get_variable<T>(&self, name: &str) -> Result<(Vec<usize>, &[T]), Error>
81    where
82        T: Dtype,
83    {
84        T::check(self.get_variable_dtype(name)?)?;
85        let dims = self.get_variable_dims(name)?;
86
87        let name = CString::new(name)?;
88        let mut buffer = ptr::null_mut();
89        unsafe {
90            check(menoh_sys::menoh_model_get_variable_buffer_handle(
91                self.handle,
92                name.as_ptr(),
93                &mut buffer,
94            ))?;
95            let buffer = slice::from_raw_parts(buffer as _, dims.iter().product());
96            Ok((dims, buffer))
97        }
98    }
99
100    /// Fetch the shape and read/write view of a variable.
101    ///
102    /// ```
103    /// # use menoh::*;
104    /// # fn main() -> Result<(), Error> {
105    /// # let mut model = Builder::from_onnx("MLP.onnx")?
106    /// #     .add_input::<f32>("input", &[2, 3])?
107    /// #     .add_output("fc2")?
108    /// #     .build("mkldnn", "")?;
109    /// let (dims, buf) = model.get_variable_mut::<f32>("fc2")?;
110    /// # assert_eq!(dims, &[2, 5]);
111    /// # Ok(())
112    /// # }
113    /// ```
114    pub fn get_variable_mut<T>(&mut self, name: &str) -> Result<(Vec<usize>, &mut [T]), Error>
115    where
116        T: Dtype,
117    {
118        T::check(self.get_variable_dtype(name)?)?;
119        let dims = self.get_variable_dims(name)?;
120
121        let name = CString::new(name)?;
122        let mut buffer = ptr::null_mut();
123        unsafe {
124            check(menoh_sys::menoh_model_get_variable_buffer_handle(
125                self.handle,
126                name.as_ptr(),
127                &mut buffer,
128            ))?;
129            let buffer = slice::from_raw_parts_mut(buffer as _, dims.iter().product());
130            Ok((dims, buffer))
131        }
132    }
133
134    pub fn run(&mut self) -> Result<(), Error> {
135        unsafe { check(menoh_sys::menoh_model_run(self.handle)) }
136    }
137}
138
139impl Drop for Model {
140    fn drop(&mut self) {
141        unsafe { menoh_sys::menoh_delete_model(self.handle) }
142    }
143}