onnx_helpers/builder/
node.rs

//! Node builder.
2
3use onnx_pb::{make_attribute, Attribute, NodeProto};
4
5use crate::{builder::Bag, nodes};
6
7/// Node builder.
8#[derive(Default, Clone)]
9pub struct Node {
10    op_type: String,
11    inputs: Vec<String>,
12    outputs: Vec<String>,
13    name: Option<String>,
14    doc_string: Option<String>,
15    domain: Option<String>,
16    attributes: Vec<(String, Attribute)>,
17    pub(crate) bag: Option<Bag>,
18}
19
20impl Node {
21    /// Creates a new builder.
22    #[inline]
23    pub fn new<S: Into<String>>(op_type: S) -> Self {
24        Node {
25            op_type: op_type.into(),
26            ..Node::default()
27        }
28    }
29
30    /// Creates a new builder.
31    #[inline]
32    pub fn named<S: Into<String>>(name: S) -> Self {
33        Node {
34            name: Some(name.into()),
35            ..Node::default()
36        }
37    }
38
39    /// Sets node name.
40    #[inline]
41    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
42        self.name = Some(name.into());
43        self
44    }
45
46    /// Sets node op type.
47    #[inline]
48    pub fn op<S: Into<String>>(mut self, op: S) -> Self {
49        self.op_type = op.into();
50        self
51    }
52
53    /// Sets node doc_string.
54    #[inline]
55    pub fn doc_string<S: Into<String>>(mut self, doc_string: S) -> Self {
56        self.doc_string = Some(doc_string.into());
57        self
58    }
59
60    /// Sets node domain.
61    #[inline]
62    pub fn domain<S: Into<String>>(mut self, domain: S) -> Self {
63        self.domain = Some(domain.into());
64        self
65    }
66
67    /// Inserts node input.
68    #[inline]
69    pub fn input<S: Into<String>>(mut self, input: S) -> Self {
70        self.inputs.push(input.into());
71        self
72    }
73
74    /// Inserts node output.
75    #[inline]
76    pub fn output<S: Into<String>>(mut self, output: S) -> Self {
77        self.outputs.push(output.into());
78        self
79    }
80
81    /// Inserts node inputs.
82    #[inline]
83    pub fn inputs<I>(mut self, inputs: I) -> Self
84    where
85        I: IntoIterator,
86        I::Item: Into<String>,
87    {
88        for input in inputs {
89            self.inputs.push(input.into());
90        }
91        self
92    }
93
94    /// Inserts node outputs.
95    #[inline]
96    pub fn outputs<I>(mut self, outputs: I) -> Self
97    where
98        I: IntoIterator,
99        I::Item: Into<String>,
100    {
101        for output in outputs {
102            self.outputs.push(output.into());
103        }
104        self
105    }
106
107    /// Inserts node attributes.
108    #[inline]
109    pub fn attribute<S: Into<String>, A: Into<Attribute>>(mut self, name: S, attribute: A) -> Self {
110        self.attributes.push((name.into(), attribute.into()));
111        self
112    }
113
114    /// Builds the node.
115    #[inline]
116    pub fn build(self) -> nodes::Node {
117        let name = if let Some(name) = self.name {
118            name
119        } else {
120            let attrs = self
121                .attributes
122                .iter()
123                .map(|(name, attr)| format!("{}_{}", name, attr))
124                .collect::<Vec<String>>()
125                .join("_");
126            if self.inputs.len() == 2 {
127                format!(
128                    "{}_{}_{}_{}",
129                    self.inputs.get(0).unwrap(),
130                    self.op_type,
131                    self.inputs.get(1).unwrap(),
132                    attrs
133                )
134            } else {
135                format!(
136                    "S{}_{}_{}_{}E",
137                    self.op_type,
138                    self.inputs.join("_"),
139                    self.op_type,
140                    attrs
141                )
142            }
143        };
144        let output = if self.outputs.len() > 0 {
145            self.outputs
146        } else {
147            vec![format!("{}O", name)]
148        };
149        let attributes = self
150            .attributes
151            .into_iter()
152            .map(|(name, attr)| make_attribute(name, attr))
153            .collect();
154        let proto = NodeProto {
155            name,
156            domain: self.domain.unwrap_or_default(),
157            op_type: self.op_type,
158            doc_string: self.doc_string.unwrap_or_default(),
159            input: self.inputs,
160            output: output,
161            attribute: attributes,
162        };
163        let mut node = nodes::Node::from_proto(proto);
164        nodes::maybe_bag_node(self.bag.clone(), &mut node);
165        node
166    }
167}
168
169impl Into<nodes::Node> for Node {
170    fn into(self) -> nodes::Node {
171        self.build()
172    }
173}
174
175impl Into<NodeProto> for Node {
176    fn into(self) -> NodeProto {
177        self.build().into()
178    }
179}