1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use onnx_pb::{GraphProto, ModelProto, OperatorSetIdProto, StringStringEntryProto, Version};
/// Opset version used by `Model::build` for the single fallback opset import
/// that is emitted when no opset import was added explicitly.
const DEFAULT_OPSET_ID_VERSION: i64 = 11;
/// Builder for an ONNX `ModelProto`.
///
/// Wraps a `GraphProto` and collects the optional model-level fields. Call
/// `Model::build` (or `.into()`) to produce the final `ModelProto`. Optional
/// fields left unset become their `Default` values in the built message.
#[derive(Debug, Default, Clone)]
pub struct Model {
    /// The graph stored in the built model's `graph` field.
    graph: GraphProto,
    /// Model `domain`; `None` is built as an empty string.
    domain: Option<String>,
    /// Model `model_version`; `None` is built as `0`.
    model_version: Option<i64>,
    /// `producer_name`; `None` is built as an empty string.
    producer_name: Option<String>,
    /// `producer_version`; `None` is built as an empty string.
    producer_version: Option<String>,
    /// `doc_string`; `None` is built as an empty string.
    doc_string: Option<String>,
    /// Metadata key/value pairs in insertion order (duplicate keys are kept).
    metadata: Vec<(String, String)>,
    /// Explicit opset imports; `None` means `build` falls back to a single
    /// default import.
    opset_imports: Option<Vec<OperatorSetIdProto>>,
}
impl Model {
    /// Starts a new builder around `graph`, with every optional field unset
    /// and no metadata or opset imports.
    #[inline]
    #[must_use]
    pub fn new<G: Into<GraphProto>>(graph: G) -> Self {
        Model {
            graph: graph.into(),
            ..Model::default()
        }
    }

    /// Sets the model's `domain` field, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn domain<S: Into<String>>(mut self, domain: S) -> Self {
        self.domain = Some(domain.into());
        self
    }

    /// Sets the model's `model_version` field, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn model_version(mut self, model_version: i64) -> Self {
        self.model_version = Some(model_version);
        self
    }

    /// Sets the model's `producer_name` field, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn producer_name<S: Into<String>>(mut self, producer_name: S) -> Self {
        self.producer_name = Some(producer_name.into());
        self
    }

    /// Sets the model's `producer_version` field, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn producer_version<S: Into<String>>(mut self, producer_version: S) -> Self {
        self.producer_version = Some(producer_version.into());
        self
    }

    /// Sets the model's `doc_string` field, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
        self.doc_string = Some(doc_string.into());
        self
    }

    /// Appends one metadata key/value pair.
    ///
    /// May be called repeatedly. Pairs are emitted in call order, and
    /// duplicate keys are not merged.
    #[inline]
    #[must_use]
    pub fn metadata<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
        self.metadata.push((key.into(), value.into()));
        self
    }

    /// Appends an operator-set import.
    ///
    /// Once at least one import has been added, `build` no longer inserts its
    /// fallback import.
    #[inline]
    #[must_use]
    pub fn opset_import(mut self, opset: OperatorSetIdProto) -> Self {
        self.opset_imports.get_or_insert_with(Vec::new).push(opset);
        self
    }

    /// Consumes the builder and produces the `ModelProto`.
    ///
    /// - `ir_version` is taken from `Version::IrVersion`.
    /// - Unset optional fields become their `Default` values (empty string or `0`).
    /// - If no opset import was added, a single import with version
    ///   `DEFAULT_OPSET_ID_VERSION` (all other fields defaulted) is used.
    /// - Metadata pairs become `metadata_props` entries in insertion order.
    /// - All remaining `ModelProto` fields are left at their defaults.
    #[inline]
    #[must_use]
    pub fn build(self) -> ModelProto {
        let opset_import = self.opset_imports.unwrap_or_else(|| {
            vec![OperatorSetIdProto {
                version: DEFAULT_OPSET_ID_VERSION,
                ..OperatorSetIdProto::default()
            }]
        });
        // Keys and values are already owned `String`s; move them straight in
        // instead of going through a no-op `.into()`.
        let metadata_props = self
            .metadata
            .into_iter()
            .map(|(key, value)| StringStringEntryProto { key, value })
            .collect();
        ModelProto {
            ir_version: Version::IrVersion as i64,
            graph: Some(self.graph),
            domain: self.domain.unwrap_or_default(),
            doc_string: self.doc_string.unwrap_or_default(),
            producer_name: self.producer_name.unwrap_or_default(),
            producer_version: self.producer_version.unwrap_or_default(),
            model_version: self.model_version.unwrap_or_default(),
            opset_import,
            metadata_props,
            ..ModelProto::default()
        }
    }
}
/// Converts a finished builder into its `ModelProto` by calling `Model::build`.
///
/// Implementing `From` (rather than `Into`) also provides
/// `Into<ModelProto> for Model` through the standard blanket impl, so
/// `model.into()` keeps working.
impl From<Model> for ModelProto {
    fn from(model: Model) -> Self {
        model.build()
    }
}