onnx_helpers/builder/model.rs

//! Model builder: assembles an ONNX `ModelProto` from a graph and optional model metadata.

3use onnx_pb::{GraphProto, ModelProto, OperatorSetIdProto, StringStringEntryProto, Version};
4
/// Opset version of the single fallback opset import that `Model::build`
/// emits when the builder was given no explicit opset imports.
const DEFAULT_OPSET_ID_VERSION: i64 = 11;

7/// Model builder.
8#[derive(Default, Clone)]
9pub struct Model {
10    graph: GraphProto,
11    domain: Option<String>,
12    model_version: Option<i64>,
13    producer_name: Option<String>,
14    producer_version: Option<String>,
15    doc_string: Option<String>,
16    metadata: Vec<(String, String)>,
17    opset_imports: Option<Vec<OperatorSetIdProto>>,
18}
19
20impl Model {
21    /// Creates a new builder.
22    #[inline]
23    pub fn new<G: Into<GraphProto>>(graph: G) -> Self {
24        Model {
25            graph: graph.into(),
26            ..Model::default()
27        }
28    }
29
30    /// Sets model doc_string.
31    #[inline]
32    pub fn domain<S: Into<String>>(mut self, domain: S) -> Self {
33        self.domain = Some(domain.into());
34        self
35    }
36
37    /// Sets model doc_string.
38    #[inline]
39    pub fn model_version(mut self, model_version: i64) -> Self {
40        self.model_version = Some(model_version);
41        self
42    }
43
44    /// Sets model doc_string.
45    #[inline]
46    pub fn producer_name<S: Into<String>>(mut self, producer_name: S) -> Self {
47        self.producer_name = Some(producer_name.into());
48        self
49    }
50
51    /// Sets model doc_string.
52    #[inline]
53    pub fn producer_version<S: Into<String>>(mut self, producer_version: S) -> Self {
54        self.producer_version = Some(producer_version.into());
55        self
56    }
57
58    /// Sets model doc_string.
59    #[inline]
60    pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
61        self.doc_string = Some(doc_string.into());
62        self
63    }
64
65    /// Inserts model metadata.
66    #[inline]
67    pub fn metadata<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
68        self.metadata.push((key.into(), value.into()));
69        self
70    }
71
72    /// Inserts operator set import.
73    #[inline]
74    pub fn opset_import(mut self, opset: OperatorSetIdProto) -> Self {
75        if let Some(opset_imports) = self.opset_imports.as_mut() {
76            opset_imports.push(opset);
77        } else {
78            self.opset_imports = Some(vec![opset]);
79        }
80        self
81    }
82
83    /// Builds the model.
84    #[inline]
85    pub fn build(self) -> ModelProto {
86        let opset_import = self.opset_imports.unwrap_or_else(|| {
87            vec![OperatorSetIdProto {
88                version: DEFAULT_OPSET_ID_VERSION,
89                ..OperatorSetIdProto::default()
90            }]
91        });
92        let metadata_props = self
93            .metadata
94            .into_iter()
95            .map(|(k, v)| StringStringEntryProto {
96                key: k.into(),
97                value: v.into(),
98            })
99            .collect();
100        ModelProto {
101            ir_version: Version::IrVersion as i64,
102            graph: Some(self.graph),
103            domain: self.domain.unwrap_or_default(),
104            doc_string: self.doc_string.unwrap_or_default(),
105            producer_name: self.producer_name.unwrap_or_default(),
106            producer_version: self.producer_version.unwrap_or_default(),
107            model_version: self.model_version.unwrap_or_default(),
108            opset_import,
109            metadata_props,
110            ..ModelProto::default()
111        }
112    }
113}
114
115impl Into<ModelProto> for Model {
116    fn into(self) -> ModelProto {
117        self.build()
118    }
119}