1#![allow(clippy::redundant_closure_call)]
2
/// Generates a parameterised operator that maps one tensor to one tensor.
///
/// Positional arguments:
/// * `$name` - identifier of the generated `pub struct`.
/// * `$label` - expression returned by `get_name`.
/// * `$in_size`, `$out_size` - expressions returned by `get_input_size` and
///   `get_output_size`.
/// * `$method` - method called on `input[0]` with the stored parameters; its
///   result is swapped into `output[0]`.
/// * `$grad_fn` - one token tree (for example a parenthesised closure) invoked
///   as `$grad_fn(input, output_grad, input_grad)`.
/// * `field: Type, ...` - parameters stored in the struct and accepted by
///   `new`; a trailing comma is allowed.
///
/// The parameters are read by value through a borrowed `self`, so every
/// parameter type has to be `Copy`.
macro_rules! one_to_1_op_with_paras {
    (
        $name:ident, $label:expr, $in_size:expr, $out_size:expr, $method:ident, $grad_fn:tt,
        $( $field:ident : $FieldTy:ty ),* $(,)?
    ) => {
        /// Operator generated by `one_to_1_op_with_paras!`.
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $name {
            // The handle is excluded from (de)serialisation.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            $( $field: $FieldTy, )*
        }

        impl $name {
            /// Creates the operator from its parameters, with a new handle.
            pub fn new($( $field: $FieldTy ),*) -> Self {
                Self {
                    handle: OpHandle::new(),
                    $( $field, )*
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $name {
            /// Wraps a copy of this operator (same parameters, new handle) in
            /// an `Op` and hands it to `inputs[0].called_with` together with
            /// the remaining inputs.
            ///
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
                let fresh = Self {
                    handle: OpHandle::new(),
                    $( $field: self.$field, )*
                };
                let op = Op::new(Rc::new(RefCell::new(Box::new(fresh))));

                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $name {
            fn get_name(&self) -> &'static str {
                $label
            }

            fn get_input_size(&self) -> usize {
                $in_size
            }

            fn get_output_size(&self) -> usize {
                $out_size
            }

            /// Runs `$method` on the first input and stores the result in the
            /// first output.
            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                // The output slot is looked up before the method runs.
                let slot = &output[0];
                slot.swap(&input[0].$method($( self.$field ),*))
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $grad_fn(input, output_grad, input_grad)
            }

            // This operator exposes no values or gradients of its own.
            fn get_values(&self) -> Vec<Tensor> {
                vec![]
            }

            fn get_grads(&self) -> Vec<Tensor> {
                vec![]
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
69
/// Generates a parameterised operator that folds several tensors into one.
///
/// Positional arguments:
/// * `$name` - identifier of the generated `pub struct`.
/// * `$label` - expression returned by `get_name`.
/// * `$in_size`, `$out_size` - expressions returned by `get_input_size` and
///   `get_output_size`.
/// * `$method` - method called on `input[0]`; it receives the remaining inputs
///   as a slice followed by the stored parameters, and its result is swapped
///   into `output[0]`.
/// * `$grad_fn` - one token tree (for example a parenthesised closure) invoked
///   as `$grad_fn(input, output_grad, input_grad)`.
/// * `field: Type, ...` - parameters stored in the struct and accepted by
///   `new`; a trailing comma is allowed.
///
/// The parameters are read by value through a borrowed `self`, so every
/// parameter type has to be `Copy`.
macro_rules! many_to_1_op_with_paras {
    (
        $name:ident, $label:expr, $in_size:expr, $out_size:expr, $method:ident, $grad_fn:tt,
        $( $field:ident : $FieldTy:ty ),* $(,)?
    ) => {
        /// Operator generated by `many_to_1_op_with_paras!`.
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $name {
            // The handle is excluded from (de)serialisation.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            $( $field: $FieldTy, )*
        }

        impl $name {
            /// Creates the operator from its parameters, with a new handle.
            pub fn new($( $field: $FieldTy ),*) -> Self {
                Self {
                    handle: OpHandle::new(),
                    $( $field, )*
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $name {
            /// Wraps a copy of this operator (same parameters, new handle) in
            /// an `Op` and hands it to `inputs[0].called_with` together with
            /// the remaining inputs.
            ///
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
                let fresh = Self {
                    handle: OpHandle::new(),
                    $( $field: self.$field, )*
                };
                let op = Op::new(Rc::new(RefCell::new(Box::new(fresh))));

                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $name {
            fn get_name(&self) -> &'static str {
                $label
            }

            fn get_input_size(&self) -> usize {
                $in_size
            }

            fn get_output_size(&self) -> usize {
                $out_size
            }

            /// Runs `$method` on the first input, passing every other input as
            /// a slice, and stores the result in the first output.
            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                // The output slot is looked up before the method runs.
                let slot = &output[0];
                slot.swap(&input[0].$method(&input[1..], $( self.$field ),*))
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $grad_fn(input, output_grad, input_grad)
            }

            // This operator exposes no values or gradients of its own.
            fn get_values(&self) -> Vec<Tensor> {
                vec![]
            }

            fn get_grads(&self) -> Vec<Tensor> {
                vec![]
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
136
/// Generates a parameterised operator that maps one tensor to several tensors.
///
/// Positional arguments:
/// * `$name` - identifier of the generated `pub struct`.
/// * `$label` - expression returned by `get_name`.
/// * `$in_size`, `$out_size` - expressions returned by `get_input_size` and
///   `get_output_size`.
/// * `$method` - method called on `input[0]` with the stored parameters; it
///   must return something offering `.iter()` over tensors, which are swapped
///   into the outputs position by position.
/// * `$grad_fn` - one token tree (for example a parenthesised closure) invoked
///   as `$grad_fn(input, output_grad, input_grad)`.
/// * `field: Type, ...` - parameters stored in the struct and accepted by
///   `new`; a trailing comma is allowed.
///
/// The parameters are read by value through a borrowed `self`, so every
/// parameter type has to be `Copy`.
macro_rules! one_to_vec_op_with_paras {
    (
        $name:ident, $label:expr, $in_size:expr, $out_size:expr, $method:ident, $grad_fn:tt,
        $( $field:ident : $FieldTy:ty ),* $(,)?
    ) => {
        /// Operator generated by `one_to_vec_op_with_paras!`.
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $name {
            // The handle is excluded from (de)serialisation.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
            $( $field: $FieldTy, )*
        }

        impl $name {
            /// Creates the operator from its parameters, with a new handle.
            pub fn new($( $field: $FieldTy ),*) -> Self {
                Self {
                    handle: OpHandle::new(),
                    $( $field, )*
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpCall for $name {
            /// Wraps a copy of this operator (same parameters, new handle) in
            /// an `Op` and hands it to `inputs[0].called_with` together with
            /// the remaining inputs.
            ///
            /// Panics if `inputs` is empty.
            fn call(&mut self, inputs: &[&Var]) -> Result<Vec<Var>, AutoDiffError> {
                let fresh = Self {
                    handle: OpHandle::new(),
                    $( $field: self.$field, )*
                };
                let op = Op::new(Rc::new(RefCell::new(Box::new(fresh))));

                inputs[0].called_with(op, &inputs[1..])
            }
        }

        impl OpTrait for $name {
            fn get_name(&self) -> &'static str {
                $label
            }

            fn get_input_size(&self) -> usize {
                $in_size
            }

            fn get_output_size(&self) -> usize {
                $out_size
            }

            /// Runs `$method` on the first input and swaps each produced
            /// tensor into the output at the same position.
            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                let produced = input[0].$method($( self.$field ),*);
                // `zip` stops at the shorter side, so surplus outputs or
                // surplus results are left untouched.
                output
                    .iter()
                    .zip(produced.iter())
                    .for_each(|(slot, value)| slot.swap(value));
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $grad_fn(input, output_grad, input_grad)
            }

            // This operator exposes no values or gradients of its own.
            fn get_values(&self) -> Vec<Tensor> {
                vec![]
            }

            fn get_grads(&self) -> Vec<Tensor> {
                vec![]
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
206
/// Generates a parameter-free operator with two inputs and one output.
///
/// Positional arguments:
/// * `$name` - identifier of the generated `pub struct`.
/// * `$label` - expression returned by `get_name`.
/// * `$apply_fn` - one token tree (for example a parenthesised closure)
///   invoked as `$apply_fn(input, output)`.
/// * `$grad_fn` - one token tree invoked as
///   `$grad_fn(input, output_grad, input_grad)`.
///
/// The generated type also implements `Default` by delegating to `new`.
macro_rules! new_binary_op {
    ($name:ident, $label:expr, $apply_fn:tt, $grad_fn:tt) => {
        /// Operator generated by `new_binary_op!`.
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $name {
            // The handle is excluded from (de)serialisation.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
        }

        impl $name {
            /// Creates the operator with a new handle.
            pub fn new() -> Self {
                Self {
                    handle: OpHandle::new(),
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }

        impl OpTrait for $name {
            fn get_name(&self) -> &'static str {
                $label
            }

            // Two inputs, one output.
            fn get_input_size(&self) -> usize {
                2
            }

            fn get_output_size(&self) -> usize {
                1
            }

            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                $apply_fn(input, output)
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $grad_fn(input, output_grad, input_grad)
            }

            // This operator exposes no values or gradients of its own.
            fn get_values(&self) -> Vec<Tensor> {
                vec![]
            }

            fn get_grads(&self) -> Vec<Tensor> {
                vec![]
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
264
/// Generates a parameter-free element-wise operator with one input and one
/// output.
///
/// Positional arguments:
/// * `$name` - identifier of the generated `pub struct`.
/// * `$label` - expression returned by `get_name`.
/// * `$method` - zero-argument method called on `input[0]`; its result is
///   swapped into `output[0]`.
/// * `$grad_fn` - one token tree (for example a parenthesised closure) invoked
///   as `$grad_fn(input, output_grad, input_grad)`.
///
/// The generated type also implements `Default` by delegating to `new`.
macro_rules! new_element_op {
    ($name:ident, $label:expr, $method:ident, $grad_fn:tt) => {
        /// Operator generated by `new_element_op!`.
        #[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
        pub struct $name {
            // The handle is excluded from (de)serialisation.
            #[cfg_attr(feature = "use-serde", serde(skip))]
            handle: OpHandle,
        }

        impl $name {
            /// Creates the operator with a new handle.
            pub fn new() -> Self {
                Self {
                    handle: OpHandle::new(),
                }
            }

            fn get_handle(&self) -> &OpHandle {
                &self.handle
            }

            fn get_handle_mut(&mut self) -> &mut OpHandle {
                &mut self.handle
            }
        }

        impl OpTrait for $name {
            fn get_name(&self) -> &'static str {
                $label
            }

            // `apply` reads only `input[0]`, so the operator is unary. The
            // previous value of 2 did not match what `apply` consumes.
            fn get_input_size(&self) -> usize {
                1
            }

            fn get_output_size(&self) -> usize {
                1
            }

            /// Runs `$method` on the first input and stores the result in the
            /// first output.
            fn apply(&self, input: &[Tensor], output: &[Tensor]) {
                output[0].swap(&input[0].$method())
            }

            fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
                $grad_fn(input, output_grad, input_grad)
            }

            // This operator exposes no values or gradients of its own.
            fn get_values(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn get_grads(&self) -> Vec<Tensor> {
                Vec::new()
            }

            fn set_values(&self, _v: &[Tensor]) {}

            #[cfg(feature = "use-serde")]
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
322
323pub(crate) use one_to_1_op_with_paras;
324pub(crate) use many_to_1_op_with_paras;
325pub(crate) use one_to_vec_op_with_paras;
326pub(crate) use new_binary_op;
327pub(crate) use new_element_op;