onnx_helpers/builder/
node.rs1use onnx_pb::{make_attribute, Attribute, NodeProto};
4
5use crate::{builder::Bag, nodes};
6
7#[derive(Default, Clone)]
9pub struct Node {
10 op_type: String,
11 inputs: Vec<String>,
12 outputs: Vec<String>,
13 name: Option<String>,
14 doc_string: Option<String>,
15 domain: Option<String>,
16 attributes: Vec<(String, Attribute)>,
17 pub(crate) bag: Option<Bag>,
18}
19
20impl Node {
21 #[inline]
23 pub fn new<S: Into<String>>(op_type: S) -> Self {
24 Node {
25 op_type: op_type.into(),
26 ..Node::default()
27 }
28 }
29
30 #[inline]
32 pub fn named<S: Into<String>>(name: S) -> Self {
33 Node {
34 name: Some(name.into()),
35 ..Node::default()
36 }
37 }
38
39 #[inline]
41 pub fn name<S: Into<String>>(mut self, name: S) -> Self {
42 self.name = Some(name.into());
43 self
44 }
45
46 #[inline]
48 pub fn op<S: Into<String>>(mut self, op: S) -> Self {
49 self.op_type = op.into();
50 self
51 }
52
53 #[inline]
55 pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
56 self.doc_string = Some(doc_string.into());
57 self
58 }
59
60 #[inline]
62 pub fn domain<S: Into<String>>(mut self, domain: S) -> Self {
63 self.domain = Some(domain.into());
64 self
65 }
66
67 #[inline]
69 pub fn input<S: Into<String>>(mut self, input: S) -> Self {
70 self.inputs.push(input.into());
71 self
72 }
73
74 #[inline]
76 pub fn output<S: Into<String>>(mut self, output: S) -> Self {
77 self.outputs.push(output.into());
78 self
79 }
80
81 #[inline]
83 pub fn inputs<I>(mut self, inputs: I) -> Self
84 where
85 I: IntoIterator,
86 I::Item: Into<String>,
87 {
88 for input in inputs {
89 self.inputs.push(input.into());
90 }
91 self
92 }
93
94 #[inline]
96 pub fn outputs<I>(mut self, outputs: I) -> Self
97 where
98 I: IntoIterator,
99 I::Item: Into<String>,
100 {
101 for output in outputs {
102 self.outputs.push(output.into());
103 }
104 self
105 }
106
107 #[inline]
109 pub fn attribute<S: Into<String>, A: Into<Attribute>>(mut self, name: S, attribute: A) -> Self {
110 self.attributes.push((name.into(), attribute.into()));
111 self
112 }
113
114 #[inline]
116 pub fn build(self) -> nodes::Node {
117 let name = if let Some(name) = self.name {
118 name
119 } else {
120 let attrs = self
121 .attributes
122 .iter()
123 .map(|(name, attr)| format!("{}_{}", name, attr))
124 .collect::<Vec<String>>()
125 .join("_");
126 if self.inputs.len() == 2 {
127 format!(
128 "{}_{}_{}_{}",
129 self.inputs.get(0).unwrap(),
130 self.op_type,
131 self.inputs.get(1).unwrap(),
132 attrs
133 )
134 } else {
135 format!(
136 "S{}_{}_{}_{}E",
137 self.op_type,
138 self.inputs.join("_"),
139 self.op_type,
140 attrs
141 )
142 }
143 };
144 let output = if self.outputs.len() > 0 {
145 self.outputs
146 } else {
147 vec![format!("{}O", name)]
148 };
149 let attributes = self
150 .attributes
151 .into_iter()
152 .map(|(name, attr)| make_attribute(name, attr))
153 .collect();
154 let proto = NodeProto {
155 name,
156 domain: self.domain.unwrap_or_default(),
157 op_type: self.op_type,
158 doc_string: self.doc_string.unwrap_or_default(),
159 input: self.inputs,
160 output: output,
161 attribute: attributes,
162 };
163 let mut node = nodes::Node::from_proto(proto);
164 nodes::maybe_bag_node(self.bag.clone(), &mut node);
165 node
166 }
167}
168
169impl Into<nodes::Node> for Node {
170 fn into(self) -> nodes::Node {
171 self.build()
172 }
173}
174
175impl Into<NodeProto> for Node {
176 fn into(self) -> NodeProto {
177 self.build().into()
178 }
179}