menoh/
builder.rs

1use crate::Dtype;
2use crate::Error;
3use crate::Model;
4use crate::ModelBuilder;
5use crate::ModelData;
6use crate::VariableProfileTableBuilder;
7use std::path::Path;
8
9/// Helper to build `Model`.
10pub struct Builder {
11    model_data: ModelData,
12    vpt_builder: VariableProfileTableBuilder,
13}
14
15impl Builder {
16    /// Create a builder from a ONNX file.
17    ///
18    /// ```
19    /// # use menoh::*;
20    /// # fn main() -> Result<(), Error> {
21    /// let builder = Builder::from_onnx("MLP.onnx")?;
22    /// # Ok(())
23    /// # }
24    /// ```
25    pub fn from_onnx<P>(path: P) -> Result<Self, Error>
26    where
27        P: AsRef<Path>,
28    {
29        Ok(Self {
30            model_data: ModelData::from_onnx(path)?,
31            vpt_builder: VariableProfileTableBuilder::new()?,
32        })
33    }
34
35    /// Create a builder from a ONNX data.
36    ///
37    /// ```
38    /// # use menoh::*;
39    /// # fn main() -> Result<(), Error> {
40    /// # let onnx_data = include_bytes!("../MLP.onnx");
41    /// let builder = Builder::from_onnx_bytes(onnx_data)?;
42    /// # Ok(())
43    /// # }
44    /// ```
45    pub fn from_onnx_bytes(data: &[u8]) -> Result<Self, Error> {
46        Ok(Self {
47            model_data: ModelData::from_onnx_bytes(data)?,
48            vpt_builder: VariableProfileTableBuilder::new()?,
49        })
50    }
51
52    /// Register a variable as input.
53    ///
54    /// ```
55    /// # use menoh::*;
56    /// # fn main() -> Result<(), Error> {
57    /// # let builder = Builder::from_onnx("MLP.onnx")?;
58    /// let builder = builder.add_input::<f32>("input", &[2, 3])?;
59    /// # Ok(())
60    /// # }
61    /// ```
62    pub fn add_input<T>(mut self, name: &str, dims: &[usize]) -> Result<Self, Error>
63    where
64        T: Dtype,
65    {
66        self.vpt_builder.add_input::<T>(name, dims)?;
67        Ok(self)
68    }
69
70    /// Register a variable as output.
71    ///
72    /// ```
73    /// # use menoh::*;
74    /// # fn main() -> Result<(), Error> {
75    /// # let builder = Builder::from_onnx("MLP.onnx")?;
76    /// let builder = builder.add_output("fc2")?;
77    /// # Ok(())
78    /// # }
79    /// ```
80    pub fn add_output(mut self, name: &str) -> Result<Self, Error> {
81        self.vpt_builder.add_output(name)?;
82        Ok(self)
83    }
84
85    /// Build a `Model`.
86    ///
87    /// ```
88    /// # use menoh::*;
89    /// # fn main() -> Result<(), Error> {
90    /// # let builder = Builder::from_onnx("MLP.onnx")?
91    /// #     .add_input::<f32>("input", &[2, 3])?
92    /// #     .add_output("fc2")?;
93    /// let model = builder.build("mkldnn", "")?;
94    /// # Ok(())
95    /// # }
96    /// ```
97    pub fn build(mut self, backend: &str, backend_config: &str) -> Result<Model, Error> {
98        let vpt = self.vpt_builder.build(&self.model_data)?;
99        self.model_data.optimize(&vpt)?;
100        let model_builder = ModelBuilder::new(&vpt)?;
101        Ok(model_builder.build(self.model_data, backend, backend_config)?)
102    }
103}