// auto_diff/op/macros.rs

1#![allow(clippy::redundant_closure_call)]
2
/// Defines a parameterised op with one tensor result written to `output[0]`.
///
/// Macro arguments, in order:
/// * `$a`  - name of the generated `pub struct`.
/// * `$b`  - value returned by `OpTrait::get_name`.
/// * `$is` - value returned by `OpTrait::get_input_size`.
/// * `$os` - value returned by `OpTrait::get_output_size`.
/// * `$c`  - name of the `Tensor` method that `apply` calls on `input[0]`,
///   passing the stored parameters in declaration order; its result is
///   swapped into `output[0]`.
/// * `$d`  - a single token tree (e.g. a parenthesised closure literal) that
///   `grad` calls as `$d(input, output_grad, input_grad)`.
/// * `name: Type, ...` - parameters stored on the struct and taken by `new`.
///
/// `OpCall::call` and `apply` both read every parameter by value through a
/// reference to `self`, so each parameter type must be `Copy`.
macro_rules! one_to_1_op_with_paras {
    (
        $a:ident, $b:expr, $is:expr, $os:expr, $c:ident, $d:tt,
        $( $arg_name:ident : $ArgTy:ty ),* $(,)?
    ) => {
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            // Excluded from the serialized form when `use-serde` is enabled.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            $( $arg_name: $ArgTy ),*
        }

        impl $a {
            /// Creates the op with a fresh `OpHandle` and the given parameters.
            pub fn new($( $arg_name: $ArgTy ),*) -> Self {
                Self {
                    handle: OpHandle::new(),
                    $( $arg_name ),*
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $a {
            /// Builds a new instance of this op (fresh handle, same parameters)
            /// and hands it to `inputs[0].called_with`, passing `inputs[1..]`
            /// as the remaining inputs.
            ///
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
                let recorded = Self {
                    handle: OpHandle::new(),
                    $( $arg_name: self.$arg_name ),*
                };
                let op = Op::new(Rc::new(RefCell::new(Box::new(recorded))));
                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }

            fn get_input_size(&self) -> usize {
                $is
            }

            fn get_output_size(&self) -> usize {
                $os
            }

            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                output[0].swap(&input[0].$c($( self.$arg_name ),*))
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }

            // `get_values`/`get_grads` always return empty vectors and
            // `set_values` ignores its argument.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
69
/// Defines a parameterised op that combines `input[0]` with the remaining
/// inputs into a single result written to `output[0]`.
///
/// Macro arguments, in order:
/// * `$a`  - name of the generated `pub struct`.
/// * `$b`  - value returned by `OpTrait::get_name`.
/// * `$is` - value returned by `OpTrait::get_input_size`.
/// * `$os` - value returned by `OpTrait::get_output_size`.
/// * `$c`  - name of the `Tensor` method that `apply` calls on `input[0]`,
///   passing `&input[1..]` first and then the stored parameters in
///   declaration order; its result is swapped into `output[0]`.
/// * `$d`  - a single token tree (e.g. a parenthesised closure literal) that
///   `grad` calls as `$d(input, output_grad, input_grad)`.
/// * `name: Type, ...` - parameters stored on the struct and taken by `new`.
///
/// `OpCall::call` and `apply` both read every parameter by value through a
/// reference to `self`, so each parameter type must be `Copy`.
macro_rules! many_to_1_op_with_paras {
    (
        $a:ident, $b:expr, $is:expr, $os:expr, $c:ident, $d:tt,
        $( $arg_name:ident : $ArgTy:ty ),* $(,)?
    ) => {
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            // Excluded from the serialized form when `use-serde` is enabled.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            $( $arg_name: $ArgTy ),*
        }

        impl $a {
            /// Creates the op with a fresh `OpHandle` and the given parameters.
            pub fn new($( $arg_name: $ArgTy ),*) -> Self {
                Self {
                    handle: OpHandle::new(),
                    $( $arg_name ),*
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $a {
            /// Builds a new instance of this op (fresh handle, same parameters)
            /// and hands it to `inputs[0].called_with`, passing `inputs[1..]`
            /// as the remaining inputs.
            ///
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
                let recorded = Self {
                    handle: OpHandle::new(),
                    $( $arg_name: self.$arg_name ),*
                };
                let op = Op::new(Rc::new(RefCell::new(Box::new(recorded))));
                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }

            fn get_input_size(&self) -> usize {
                $is
            }

            fn get_output_size(&self) -> usize {
                $os
            }

            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                // The first input is the receiver; all others are passed as a slice.
                output[0].swap(&input[0].$c(&input[1..], $( self.$arg_name ),*))
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }

            // `get_values`/`get_grads` always return empty vectors and
            // `set_values` ignores its argument.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
136
/// Defines a parameterised op whose forward pass produces several tensors
/// from `input[0]`.
///
/// Macro arguments, in order:
/// * `$a`  - name of the generated `pub struct`.
/// * `$b`  - value returned by `OpTrait::get_name`.
/// * `$is` - value returned by `OpTrait::get_input_size`.
/// * `$os` - value returned by `OpTrait::get_output_size`.
/// * `$c`  - name of the `Tensor` method that `apply` calls on `input[0]`,
///   passing the stored parameters in declaration order; its result must
///   provide `.iter()` over tensors, which are swapped into `output` pairwise.
/// * `$d`  - a single token tree (e.g. a parenthesised closure literal) that
///   `grad` calls as `$d(input, output_grad, input_grad)`.
/// * `name: Type, ...` - parameters stored on the struct and taken by `new`.
///
/// `OpCall::call` and `apply` both read every parameter by value through a
/// reference to `self`, so each parameter type must be `Copy`.
macro_rules! one_to_vec_op_with_paras {
    (
        $a:ident, $b:expr, $is:expr, $os:expr, $c:ident, $d:tt,
        $( $arg_name:ident : $ArgTy:ty ),* $(,)?
    ) => {
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            // Excluded from the serialized form when `use-serde` is enabled.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            $( $arg_name: $ArgTy ),*
        }

        impl $a {
            /// Creates the op with a fresh `OpHandle` and the given parameters.
            pub fn new($( $arg_name: $ArgTy ),*) -> Self {
                Self {
                    handle: OpHandle::new(),
                    $( $arg_name ),*
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $a {
            /// Builds a new instance of this op (fresh handle, same parameters)
            /// and hands it to `inputs[0].called_with`, passing `inputs[1..]`
            /// as the remaining inputs.
            ///
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
                let recorded = Self {
                    handle: OpHandle::new(),
                    $( $arg_name: self.$arg_name ),*
                };
                let op = Op::new(Rc::new(RefCell::new(Box::new(recorded))));
                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }

            fn get_input_size(&self) -> usize {
                $is
            }

            fn get_output_size(&self) -> usize {
                $os
            }

            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                let results = input[0].$c($( self.$arg_name ),*);
                // `zip` stops at the shorter side: surplus results are dropped
                // and surplus output slots are left untouched.
                for (dst, src) in output.iter().zip(results.iter()) {
                    dst.swap(src);
                }
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }

            // `get_values`/`get_grads` always return empty vectors and
            // `set_values` ignores its argument.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
206
/// Defines a parameterless op reporting two inputs and one output.
///
/// Macro arguments, in order:
/// * `$a` - name of the generated `pub struct`.
/// * `$b` - value returned by `OpTrait::get_name`.
/// * `$c` - a single token tree (e.g. a parenthesised closure literal) that
///   `apply` calls as `$c(input, output)`.
/// * `$d` - a single token tree (e.g. a parenthesised closure literal) that
///   `grad` calls as `$d(input, output_grad, input_grad)`.
///
/// No `OpCall` impl is generated here. `Default` forwards to `new`.
macro_rules! new_binary_op {
    ($a:ident, $b:expr, $c:tt, $d:tt) => {
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            // Excluded from the serialized form when `use-serde` is enabled.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
        }

        impl $a {
            /// Creates the op with a fresh `OpHandle`.
            pub fn new() -> Self {
                Self {
                    handle: OpHandle::new(),
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }

            fn get_input_size(&self) -> usize {
                2
            }

            fn get_output_size(&self) -> usize {
                1
            }

            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                $c(input, output)
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }

            // `get_values`/`get_grads` always return empty vectors and
            // `set_values` ignores its argument.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        impl Default for $a {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
264
/// Defines a parameterless single-input, single-output op whose forward pass
/// calls a zero-argument `Tensor` method on `input[0]`.
///
/// Macro arguments, in order:
/// * `$a` - name of the generated `pub struct`.
/// * `$b` - value returned by `OpTrait::get_name`.
/// * `$c` - name of the `Tensor` method that `apply` calls on `input[0]`; its
///   result is swapped into `output[0]`.
/// * `$d` - a single token tree (e.g. a parenthesised closure literal) that
///   `grad` calls as `$d(input, output_grad, input_grad)`.
///
/// No `OpCall` impl is generated here. `Default` forwards to `new`.
macro_rules! new_element_op {
    ($a:ident, $b:expr, $c:ident, $d:tt) => {
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $a {
            // Excluded from the serialized form when `use-serde` is enabled.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
        }

        impl $a {
            /// Creates the op with a fresh `OpHandle`.
            pub fn new() -> Self {
                Self {
                    handle: OpHandle::new(),
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> &'static str {
                $b
            }

            fn get_input_size(&self) -> usize {
                // `apply` reads only `input[0]`, so the op consumes exactly
                // one input.
                1
            }

            fn get_output_size(&self) -> usize {
                1
            }

            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                output[0].swap(&input[0].$c())
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $d(input, output_grad, input_grad)
            }

            // `get_values`/`get_grads` always return empty vectors and
            // `set_values` ignores its argument.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        impl Default for $a {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
322
323pub(crate) use one_to_1_op_with_paras;
324pub(crate) use many_to_1_op_with_paras;
325pub(crate) use one_to_vec_op_with_paras;
326pub(crate) use new_binary_op;
327pub(crate) use new_element_op;