onnx_helpers/nodes/
mod.rs1pub mod ops;
4
5use std::cell::RefCell;
6use std::rc::Rc;
7
8use onnx_pb::{Axes, NodeProto};
9
10use crate::builder::Bag;
11
12#[derive(Clone)]
14pub struct Node {
15 pub(crate) inner: Rc<RefCell<NodeProto>>,
16 pub(crate) bag: Option<Bag>,
17}
18
19impl Node {
20 pub fn from_proto(inner: NodeProto) -> Self {
22 Node {
23 bag: None,
24 inner: Rc::new(RefCell::new(inner)),
25 }
26 }
27
28 pub fn name(&self) -> String {
30 self.inner.borrow().name.clone()
31 }
32
33 pub fn with_name<N: Into<String>>(self, name: N) -> Self {
35 let name = name.into();
36 let mut bag = self.bag.clone();
37 {
38 let mut inner = self.inner.borrow_mut();
39 inner
40 .output
41 .iter_mut()
42 .enumerate()
43 .for_each(|(index, output)| {
44 let name = format!("{}{}", name, index);
45 maybe_bag_rename(&mut bag, &output, &name);
46 *output = name;
47 });
48 maybe_bag_rename(&mut bag, &inner.name, &name);
49 inner.name = name;
50 }
51 self
52 }
53
54 pub fn abs(&self) -> Node {
56 let mut node: Node = ops::Abs::new(self.select_output()).into();
57 maybe_bag_node(self.bag.clone(), &mut node);
58 node
59 }
60
61 pub fn sqrt(&self) -> Node {
63 let mut node: Node = ops::Sqrt::new(self.select_output()).into();
64 maybe_bag_node(self.bag.clone(), &mut node);
65 node
66 }
67
68 pub fn pow<T: Into<String>>(&self, power: T) -> Node {
70 let mut node: Node = ops::Pow::new(self.select_output(), power).into();
71 maybe_bag_node(self.bag.clone(), &mut node);
72 node
73 }
74
75 pub fn sum<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
77 let mut node: Node = ops::ReduceSum::new(self.select_output(), axes, keepdims).into();
78 maybe_bag_node(self.bag.clone(), &mut node);
79 node
80 }
81
82 pub fn max<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
84 let mut node: Node = ops::ReduceMax::new(self.select_output(), axes, keepdims).into();
85 maybe_bag_node(self.bag.clone(), &mut node);
86 node
87 }
88
89 pub fn mean<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
91 let mut node: Node = ops::ReduceMean::new(self.select_output(), axes, keepdims).into();
92 maybe_bag_node(self.bag.clone(), &mut node);
93 node
94 }
95
96 pub fn min<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
98 let mut node: Node = ops::ReduceMin::new(self.select_output(), axes, keepdims).into();
99 maybe_bag_node(self.bag.clone(), &mut node);
100 node
101 }
102
103 pub fn equal<Rhs: Into<String>>(&self, right: Rhs) -> Node {
105 let mut node: Node = ops::Equal::new(self.select_output(), right).into();
106 maybe_bag_node(self.bag.clone(), &mut node);
107 node
108 }
109
110 pub fn greater<Rhs: Into<String>>(&self, right: Rhs) -> Node {
112 let mut node: Node = ops::Greater::new(self.select_output(), right).into();
113 maybe_bag_node(self.bag.clone(), &mut node);
114 node
115 }
116
117 pub fn less<Rhs: Into<String>>(&self, right: Rhs) -> Node {
119 let mut node: Node = ops::Less::new(self.select_output(), right).into();
120 maybe_bag_node(self.bag.clone(), &mut node);
121 node
122 }
123
124 pub fn and<Rhs: Into<String>>(&self, right: Rhs) -> Node {
126 let mut node: Node = ops::And::new(self.select_output(), right).into();
127 maybe_bag_node(self.bag.clone(), &mut node);
128 node
129 }
130
131 pub fn or<Rhs: Into<String>>(&self, right: Rhs) -> Node {
133 let mut node: Node = ops::Or::new(self.select_output(), right).into();
134 maybe_bag_node(self.bag.clone(), &mut node);
135 node
136 }
137
138 pub fn relu(&self) -> Node {
140 let mut node: Node = ops::Relu::new(self.select_output()).into();
141 maybe_bag_node(self.bag.clone(), &mut node);
142 node
143 }
144
145 pub fn tanh(&self) -> Node {
147 let mut node: Node = ops::Tanh::new(self.select_output()).into();
148 maybe_bag_node(self.bag.clone(), &mut node);
149 node
150 }
151
152 pub fn size(&self) -> Node {
154 let mut node: Node = ops::Size::new(self.select_output()).into();
155 maybe_bag_node(self.bag.clone(), &mut node);
156 node
157 }
158
159 #[inline]
160 fn select_output(&self) -> String {
161 let node = self.inner.borrow();
162 if node.op_type.is_empty() {
163 node.name.clone()
164 } else {
165 node.output.first().unwrap().to_owned()
166 }
167 }
168}
169
/// Implements the binary operator trait `std::ops::$k` (method `$f`) for both `&$t`
/// and `$t`, accepting any right-hand side that implements `AsRef<Node>`.
///
/// Each generated impl builds `ops::$k::new(<lhs output>, rhs.as_ref())`, converts it
/// into a `Node`, and registers the result with the left operand's bag (if any).
///
/// The expansion names `Node`, `ops` and `maybe_bag_node` without a path, so it only
/// compiles where those items are in scope.
#[macro_export]
macro_rules! impl_nodes_op {
    ( $t:ident, $k:ident, $f:ident ) => {
        impl<Rhs: AsRef<Node>> std::ops::$k<Rhs> for &$t {
            type Output = Node;

            #[inline(always)]
            fn $f(self, rhs: Rhs) -> Self::Output {
                let lhs = self.select_output();
                let mut out: Node = ops::$k::new(lhs, rhs.as_ref()).into();
                maybe_bag_node(self.bag.clone(), &mut out);
                out
            }
        }

        impl<Rhs: AsRef<Node>> std::ops::$k<Rhs> for $t {
            type Output = Node;

            #[inline(always)]
            fn $f(self, rhs: Rhs) -> Self::Output {
                let lhs = self.select_output();
                let mut out: Node = ops::$k::new(lhs, rhs.as_ref()).into();
                maybe_bag_node(self.bag.clone(), &mut out);
                out
            }
        }
    };
}
196
/// Implements the unary operator trait `std::ops::$k` (method `$f`) for both `$t`
/// and `&$t`.
///
/// Each generated impl builds `ops::$k::new(<operand output>)`, converts it into a
/// `Node`, and registers the result with the operand's bag (if any).
///
/// The expansion names `Node`, `ops` and `maybe_bag_node` without a path, so it only
/// compiles where those items are in scope.
#[macro_export]
macro_rules! impl_node_op {
    ( $t:ident, $k:ident, $f:ident ) => {
        impl std::ops::$k for $t {
            type Output = Node;

            #[inline(always)]
            fn $f(self) -> Self::Output {
                let operand = self.select_output();
                let mut out: Node = ops::$k::new(operand).into();
                maybe_bag_node(self.bag.clone(), &mut out);
                out
            }
        }

        impl std::ops::$k for &$t {
            type Output = Node;

            #[inline(always)]
            fn $f(self) -> Self::Output {
                let operand = self.select_output();
                let mut out: Node = ops::$k::new(operand).into();
                maybe_bag_node(self.bag.clone(), &mut out);
                out
            }
        }
    };
}
223
224impl_nodes_op!(Node, Add, add);
225impl_nodes_op!(Node, Sub, sub);
226impl_nodes_op!(Node, Mul, mul);
227impl_nodes_op!(Node, Div, div);
228impl_node_op!(Node, Neg, neg);
229impl_node_op!(Node, Not, not);
230
231impl From<NodeProto> for Node {
232 fn from(inner: NodeProto) -> Self {
233 Node::from_proto(inner)
234 }
235}
236
237impl Into<NodeProto> for Node {
238 fn into(self) -> NodeProto {
239 self.inner.borrow().clone()
240 }
241}
242
243impl From<Node> for String {
244 fn from(node: Node) -> String {
245 node.select_output()
246 }
247}
248
249impl From<&Node> for String {
250 fn from(node: &Node) -> String {
251 node.select_output()
252 }
253}
254
255impl AsRef<Node> for Node {
256 #[inline(always)]
257 fn as_ref(&self) -> &Node {
258 &self
259 }
260}
261
262#[inline(always)]
263pub(crate) fn maybe_bag_node(bag: Option<Bag>, node: &mut Node) {
264 if let Some(mut bag) = bag {
265 node.bag = Some(bag.clone());
266 bag.node(node.inner.clone());
267 }
268}
269
270#[inline(always)]
271pub(crate) fn maybe_bag_rename(bag: &mut Option<Bag>, name: &str, new_name: &str) {
272 if let Some(bag) = bag.as_mut() {
273 bag.rename(name, new_name);
274 }
275}