onnx_helpers/builder/
model.rs1use onnx_pb::{GraphProto, ModelProto, OperatorSetIdProto, StringStringEntryProto, Version};
4
/// Opset version written by [`Model::build`] when the builder received no
/// explicit opset import.
const DEFAULT_OPSET_ID_VERSION: i64 = 11;
6
7#[derive(Default, Clone)]
9pub struct Model {
10 graph: GraphProto,
11 domain: Option<String>,
12 model_version: Option<i64>,
13 producer_name: Option<String>,
14 producer_version: Option<String>,
15 doc_string: Option<String>,
16 metadata: Vec<(String, String)>,
17 opset_imports: Option<Vec<OperatorSetIdProto>>,
18}
19
20impl Model {
21 #[inline]
23 pub fn new<G: Into<GraphProto>>(graph: G) -> Self {
24 Model {
25 graph: graph.into(),
26 ..Model::default()
27 }
28 }
29
30 #[inline]
32 pub fn domain<S: Into<String>>(mut self, domain: S) -> Self {
33 self.domain = Some(domain.into());
34 self
35 }
36
37 #[inline]
39 pub fn model_version(mut self, model_version: i64) -> Self {
40 self.model_version = Some(model_version);
41 self
42 }
43
44 #[inline]
46 pub fn producer_name<S: Into<String>>(mut self, producer_name: S) -> Self {
47 self.producer_name = Some(producer_name.into());
48 self
49 }
50
51 #[inline]
53 pub fn producer_version<S: Into<String>>(mut self, producer_version: S) -> Self {
54 self.producer_version = Some(producer_version.into());
55 self
56 }
57
58 #[inline]
60 pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
61 self.doc_string = Some(doc_string.into());
62 self
63 }
64
65 #[inline]
67 pub fn metadata<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
68 self.metadata.push((key.into(), value.into()));
69 self
70 }
71
72 #[inline]
74 pub fn opset_import(mut self, opset: OperatorSetIdProto) -> Self {
75 if let Some(opset_imports) = self.opset_imports.as_mut() {
76 opset_imports.push(opset);
77 } else {
78 self.opset_imports = Some(vec![opset]);
79 }
80 self
81 }
82
83 #[inline]
85 pub fn build(self) -> ModelProto {
86 let opset_import = self.opset_imports.unwrap_or_else(|| {
87 vec![OperatorSetIdProto {
88 version: DEFAULT_OPSET_ID_VERSION,
89 ..OperatorSetIdProto::default()
90 }]
91 });
92 let metadata_props = self
93 .metadata
94 .into_iter()
95 .map(|(k, v)| StringStringEntryProto {
96 key: k.into(),
97 value: v.into(),
98 })
99 .collect();
100 ModelProto {
101 ir_version: Version::IrVersion as i64,
102 graph: Some(self.graph),
103 domain: self.domain.unwrap_or_default(),
104 doc_string: self.doc_string.unwrap_or_default(),
105 producer_name: self.producer_name.unwrap_or_default(),
106 producer_version: self.producer_version.unwrap_or_default(),
107 model_version: self.model_version.unwrap_or_default(),
108 opset_import,
109 metadata_props,
110 ..ModelProto::default()
111 }
112 }
113}
114
115impl Into<ModelProto> for Model {
116 fn into(self) -> ModelProto {
117 self.build()
118 }
119}