menoh/builder.rs
1use crate::Dtype;
2use crate::Error;
3use crate::Model;
4use crate::ModelBuilder;
5use crate::ModelData;
6use crate::VariableProfileTableBuilder;
7use std::path::Path;
8
9/// Helper to build `Model`.
10pub struct Builder {
11 model_data: ModelData,
12 vpt_builder: VariableProfileTableBuilder,
13}
14
15impl Builder {
16 /// Create a builder from a ONNX file.
17 ///
18 /// ```
19 /// # use menoh::*;
20 /// # fn main() -> Result<(), Error> {
21 /// let builder = Builder::from_onnx("MLP.onnx")?;
22 /// # Ok(())
23 /// # }
24 /// ```
25 pub fn from_onnx<P>(path: P) -> Result<Self, Error>
26 where
27 P: AsRef<Path>,
28 {
29 Ok(Self {
30 model_data: ModelData::from_onnx(path)?,
31 vpt_builder: VariableProfileTableBuilder::new()?,
32 })
33 }
34
35 /// Create a builder from a ONNX data.
36 ///
37 /// ```
38 /// # use menoh::*;
39 /// # fn main() -> Result<(), Error> {
40 /// # let onnx_data = include_bytes!("../MLP.onnx");
41 /// let builder = Builder::from_onnx_bytes(onnx_data)?;
42 /// # Ok(())
43 /// # }
44 /// ```
45 pub fn from_onnx_bytes(data: &[u8]) -> Result<Self, Error> {
46 Ok(Self {
47 model_data: ModelData::from_onnx_bytes(data)?,
48 vpt_builder: VariableProfileTableBuilder::new()?,
49 })
50 }
51
52 /// Register a variable as input.
53 ///
54 /// ```
55 /// # use menoh::*;
56 /// # fn main() -> Result<(), Error> {
57 /// # let builder = Builder::from_onnx("MLP.onnx")?;
58 /// let builder = builder.add_input::<f32>("input", &[2, 3])?;
59 /// # Ok(())
60 /// # }
61 /// ```
62 pub fn add_input<T>(mut self, name: &str, dims: &[usize]) -> Result<Self, Error>
63 where
64 T: Dtype,
65 {
66 self.vpt_builder.add_input::<T>(name, dims)?;
67 Ok(self)
68 }
69
70 /// Register a variable as output.
71 ///
72 /// ```
73 /// # use menoh::*;
74 /// # fn main() -> Result<(), Error> {
75 /// # let builder = Builder::from_onnx("MLP.onnx")?;
76 /// let builder = builder.add_output("fc2")?;
77 /// # Ok(())
78 /// # }
79 /// ```
80 pub fn add_output(mut self, name: &str) -> Result<Self, Error> {
81 self.vpt_builder.add_output(name)?;
82 Ok(self)
83 }
84
85 /// Build a `Model`.
86 ///
87 /// ```
88 /// # use menoh::*;
89 /// # fn main() -> Result<(), Error> {
90 /// # let builder = Builder::from_onnx("MLP.onnx")?
91 /// # .add_input::<f32>("input", &[2, 3])?
92 /// # .add_output("fc2")?;
93 /// let model = builder.build("mkldnn", "")?;
94 /// # Ok(())
95 /// # }
96 /// ```
97 pub fn build(mut self, backend: &str, backend_config: &str) -> Result<Model, Error> {
98 let vpt = self.vpt_builder.build(&self.model_data)?;
99 self.model_data.optimize(&vpt)?;
100 let model_builder = ModelBuilder::new(&vpt)?;
101 Ok(model_builder.build(self.model_data, backend, backend_config)?)
102 }
103}