onnx_helpers/nodes/
mod.rs

1//! Onnx node helpers.
2
3pub mod ops;
4
5use std::cell::RefCell;
6use std::rc::Rc;
7
8use onnx_pb::{Axes, NodeProto};
9
10use crate::builder::Bag;
11
12/// Node wrapper.
13#[derive(Clone)]
14pub struct Node {
15    pub(crate) inner: Rc<RefCell<NodeProto>>,
16    pub(crate) bag: Option<Bag>,
17}
18
19impl Node {
20    /// Creates new node from proto.
21    pub fn from_proto(inner: NodeProto) -> Self {
22        Node {
23            bag: None,
24            inner: Rc::new(RefCell::new(inner)),
25        }
26    }
27
28    /// Returns node name.
29    pub fn name(&self) -> String {
30        self.inner.borrow().name.clone()
31    }
32
33    /// Renames output names accordingly.
34    pub fn with_name<N: Into<String>>(self, name: N) -> Self {
35        let name = name.into();
36        let mut bag = self.bag.clone();
37        {
38            let mut inner = self.inner.borrow_mut();
39            inner
40                .output
41                .iter_mut()
42                .enumerate()
43                .for_each(|(index, output)| {
44                    let name = format!("{}{}", name, index);
45                    maybe_bag_rename(&mut bag, &output, &name);
46                    *output = name;
47                });
48            maybe_bag_rename(&mut bag, &inner.name, &name);
49            inner.name = name;
50        }
51        self
52    }
53
54    /// Creates new absolute operation.
55    pub fn abs(&self) -> Node {
56        let mut node: Node = ops::Abs::new(self.select_output()).into();
57        maybe_bag_node(self.bag.clone(), &mut node);
58        node
59    }
60
61    /// Creates new square root operation.
62    pub fn sqrt(&self) -> Node {
63        let mut node: Node = ops::Sqrt::new(self.select_output()).into();
64        maybe_bag_node(self.bag.clone(), &mut node);
65        node
66    }
67
68    /// Creates new power operation.
69    pub fn pow<T: Into<String>>(&self, power: T) -> Node {
70        let mut node: Node = ops::Pow::new(self.select_output(), power).into();
71        maybe_bag_node(self.bag.clone(), &mut node);
72        node
73    }
74
75    /// Creates new reduce sum operation.
76    pub fn sum<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
77        let mut node: Node = ops::ReduceSum::new(self.select_output(), axes, keepdims).into();
78        maybe_bag_node(self.bag.clone(), &mut node);
79        node
80    }
81
82    /// Creates new reduce max operation.
83    pub fn max<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
84        let mut node: Node = ops::ReduceMax::new(self.select_output(), axes, keepdims).into();
85        maybe_bag_node(self.bag.clone(), &mut node);
86        node
87    }
88
89    /// Creates new reduce mean operation.
90    pub fn mean<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
91        let mut node: Node = ops::ReduceMean::new(self.select_output(), axes, keepdims).into();
92        maybe_bag_node(self.bag.clone(), &mut node);
93        node
94    }
95
96    /// Creates new reduce min operation.
97    pub fn min<A: Into<Axes>>(&self, axes: A, keepdims: bool) -> Node {
98        let mut node: Node = ops::ReduceMin::new(self.select_output(), axes, keepdims).into();
99        maybe_bag_node(self.bag.clone(), &mut node);
100        node
101    }
102
103    /// Creates new equal comparison operation.
104    pub fn equal<Rhs: Into<String>>(&self, right: Rhs) -> Node {
105        let mut node: Node = ops::Equal::new(self.select_output(), right).into();
106        maybe_bag_node(self.bag.clone(), &mut node);
107        node
108    }
109
110    /// Creates new greater comparison operation.
111    pub fn greater<Rhs: Into<String>>(&self, right: Rhs) -> Node {
112        let mut node: Node = ops::Greater::new(self.select_output(), right).into();
113        maybe_bag_node(self.bag.clone(), &mut node);
114        node
115    }
116
117    /// Creates new less comparison operation.
118    pub fn less<Rhs: Into<String>>(&self, right: Rhs) -> Node {
119        let mut node: Node = ops::Less::new(self.select_output(), right).into();
120        maybe_bag_node(self.bag.clone(), &mut node);
121        node
122    }
123
124    /// Creates new logical and operation.
125    pub fn and<Rhs: Into<String>>(&self, right: Rhs) -> Node {
126        let mut node: Node = ops::And::new(self.select_output(), right).into();
127        maybe_bag_node(self.bag.clone(), &mut node);
128        node
129    }
130
131    /// Creates new logical or operation.
132    pub fn or<Rhs: Into<String>>(&self, right: Rhs) -> Node {
133        let mut node: Node = ops::Or::new(self.select_output(), right).into();
134        maybe_bag_node(self.bag.clone(), &mut node);
135        node
136    }
137
138    /// Creates new relu activation operation.
139    pub fn relu(&self) -> Node {
140        let mut node: Node = ops::Relu::new(self.select_output()).into();
141        maybe_bag_node(self.bag.clone(), &mut node);
142        node
143    }
144
145    /// Creates new tanh activation operation.
146    pub fn tanh(&self) -> Node {
147        let mut node: Node = ops::Tanh::new(self.select_output()).into();
148        maybe_bag_node(self.bag.clone(), &mut node);
149        node
150    }
151
152    /// Creates new size operation.
153    pub fn size(&self) -> Node {
154        let mut node: Node = ops::Size::new(self.select_output()).into();
155        maybe_bag_node(self.bag.clone(), &mut node);
156        node
157    }
158
159    #[inline]
160    fn select_output(&self) -> String {
161        let node = self.inner.borrow();
162        if node.op_type.is_empty() {
163            node.name.clone()
164        } else {
165            node.output.first().unwrap().to_owned()
166        }
167    }
168}
169
/// Implements the binary operator trait `std::ops::$k` (method `$f`) for both `$t` and
/// `&$t`, accepting any right-hand side that is `AsRef<Node>`.
///
/// The expansion refers to `ops`, `Node` and `maybe_bag_node` unqualified and calls the
/// module-private `select_output`, so it only expands correctly where those are in scope.
#[macro_export]
macro_rules! impl_nodes_op {
    ( $t:ident, $k:ident, $f:ident ) => {
        impl<Rhs: AsRef<Node>> std::ops::$k<Rhs> for &$t {
            type Output = Node;

            #[inline(always)]
            fn $f(self, rhs: Rhs) -> Self::Output {
                let lhs = self.select_output();
                let mut result: Node = ops::$k::new(lhs, rhs.as_ref()).into();
                maybe_bag_node(self.bag.clone(), &mut result);
                result
            }
        }

        impl<Rhs: AsRef<Node>> std::ops::$k<Rhs> for $t {
            type Output = Node;

            // The by-value form simply forwards to the by-reference implementation.
            #[inline(always)]
            fn $f(self, rhs: Rhs) -> Self::Output {
                <&$t as std::ops::$k<Rhs>>::$f(&self, rhs)
            }
        }
    };
}
196
/// Implements the unary operator trait `std::ops::$k` (method `$f`) for both `$t` and
/// `&$t`.
///
/// The expansion refers to `ops`, `Node` and `maybe_bag_node` unqualified and calls the
/// module-private `select_output`, so it only expands correctly where those are in scope.
#[macro_export]
macro_rules! impl_node_op {
    ( $t:ident, $k:ident, $f:ident ) => {
        impl std::ops::$k for &$t {
            type Output = Node;

            #[inline(always)]
            fn $f(self) -> Self::Output {
                let operand = self.select_output();
                let mut result: Node = ops::$k::new(operand).into();
                maybe_bag_node(self.bag.clone(), &mut result);
                result
            }
        }

        impl std::ops::$k for $t {
            type Output = Node;

            // The by-value form simply forwards to the by-reference implementation.
            #[inline(always)]
            fn $f(self) -> Self::Output {
                <&$t as std::ops::$k>::$f(&self)
            }
        }
    };
}
223
224impl_nodes_op!(Node, Add, add);
225impl_nodes_op!(Node, Sub, sub);
226impl_nodes_op!(Node, Mul, mul);
227impl_nodes_op!(Node, Div, div);
228impl_node_op!(Node, Neg, neg);
229impl_node_op!(Node, Not, not);
230
231impl From<NodeProto> for Node {
232    fn from(inner: NodeProto) -> Self {
233        Node::from_proto(inner)
234    }
235}
236
237impl Into<NodeProto> for Node {
238    fn into(self) -> NodeProto {
239        self.inner.borrow().clone()
240    }
241}
242
243impl From<Node> for String {
244    fn from(node: Node) -> String {
245        node.select_output()
246    }
247}
248
249impl From<&Node> for String {
250    fn from(node: &Node) -> String {
251        node.select_output()
252    }
253}
254
255impl AsRef<Node> for Node {
256    #[inline(always)]
257    fn as_ref(&self) -> &Node {
258        &self
259    }
260}
261
262#[inline(always)]
263pub(crate) fn maybe_bag_node(bag: Option<Bag>, node: &mut Node) {
264    if let Some(mut bag) = bag {
265        node.bag = Some(bag.clone());
266        bag.node(node.inner.clone());
267    }
268}
269
270#[inline(always)]
271pub(crate) fn maybe_bag_rename(bag: &mut Option<Bag>, name: &str, new_name: &str) {
272    if let Some(bag) = bag.as_mut() {
273        bag.rename(name, new_name);
274    }
275}