1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use std::path;

use Dtype;
use Error;
use Model;
use ModelBuilder;
use ModelData;
use VariableProfileTableBuilder;

/// Helper to build `Model`.
///
/// Holds the model data loaded from an ONNX file together with a builder for
/// its variable profile table. Inputs and outputs are registered by chaining
/// `add_input` / `add_output`, and `build` produces the final `Model`.
///
/// Every method consumes `self` and hands the builder back. Discarding that
/// returned value silently drops the builder along with every registration
/// made on it, hence `#[must_use]`.
#[must_use = "builder methods consume `self`; use the returned `Builder`"]
pub struct Builder {
    /// Model loaded by `from_onnx`; optimized and moved into the model in `build`.
    model_data: ModelData,
    /// Collects the input/output variables registered via `add_input` / `add_output`.
    vpt_builder: VariableProfileTableBuilder,
}

impl Builder {
    /// Create a builder from a ONNX file.
    ///
    /// ```
    /// # use menoh::*;
    /// # fn main() -> Result<(), Error> {
    /// let builder = Builder::from_onnx("MLP.onnx")?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn from_onnx<P>(path: P) -> Result<Self, Error>
        where P: AsRef<path::Path>
    {
        Ok(Self {
               model_data: ModelData::from_onnx(path)?,
               vpt_builder: VariableProfileTableBuilder::new()?,
           })
    }

    /// Register a variable as input.
    ///
    /// ```
    /// # use menoh::*;
    /// # fn main() -> Result<(), Error> {
    /// # let builder = Builder::from_onnx("MLP.onnx")?;
    /// let builder = builder.add_input::<f32>("input", &[2, 3])?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn add_input<T>(mut self, name: &str, dims: &[usize]) -> Result<Self, Error>
        where T: Dtype
    {
        self.vpt_builder.add_input::<T>(name, dims)?;
        Ok(self)
    }

    /// Register a variable as output.
    ///
    /// ```
    /// # use menoh::*;
    /// # fn main() -> Result<(), Error> {
    /// # let builder = Builder::from_onnx("MLP.onnx")?;
    /// let builder = builder.add_output::<f32>("fc2")?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn add_output<T>(mut self, name: &str) -> Result<Self, Error>
        where T: Dtype
    {
        self.vpt_builder.add_output::<T>(name)?;
        Ok(self)
    }

    /// Build a `Model`.
    ///
    /// ```
    /// # use menoh::*;
    /// # fn main() -> Result<(), Error> {
    /// # let builder = Builder::from_onnx("MLP.onnx")?
    /// #     .add_input::<f32>("input", &[2, 3])?
    /// #     .add_output::<f32>("fc2")?;
    /// let model = builder.build("mkldnn", "")?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn build(mut self, backend: &str, backend_config: &str) -> Result<Model, Error> {
        let vpt = self.vpt_builder.build(&self.model_data)?;
        self.model_data.optimize(&vpt)?;
        let model_builder = ModelBuilder::new(&vpt)?;
        Ok(model_builder
               .build(self.model_data, backend, backend_config)?)
    }
}