// auto_diff/op/element.rs

1#![allow(clippy::redundant_closure_call)]
2use tensor_rs::tensor::Tensor;
3use super::{OpTrait, OpHandle};
4use super::macros::new_element_op;
5
6#[cfg(feature = "use-serde")]
7use serde::{Serialize, Deserialize};
8#[cfg(feature = "use-serde")]
9use std::any::Any;
10
11
12new_element_op!(Abs,
13                "Abs",
14                abs,
15                (|input: &[Tensor],
16                 output_grad: &[Tensor],
17                 input_grad: &[Tensor]| {
18                     input_grad[0].swap(
19                         &input[0].conditional_select(
20                             &input[0].ones_like(),
21                             &input[0].ones_like().neg())
22                             .mul(&output_grad[0]));
23                 }));
24
25new_element_op!(Acos,
26                "Acos",
27                acos,
28                (|input: &[Tensor],
29                 output_grad: &[Tensor],
30                 input_grad: &[Tensor]| {
31                     let ret = input[0].ones_like().sub(&input[0].mul(&input[0])).sqrt().reciprocal().neg();
32		     input_grad[0].swap(&ret.mul(&output_grad[0]));
33                 }));
34
35new_element_op!(Asin,
36                "Asin",
37                asin,
38                (|input: &[Tensor],
39                 output_grad: &[Tensor],
40                 input_grad: &[Tensor]| {
41                     let ret = input[0].ones_like().sub(&input[0].mul(&input[0])).sqrt().reciprocal();
42		     input_grad[0].swap(&ret.mul(&output_grad[0]));
43                 }));
44
45new_element_op!(Atan,
46                "Atan",
47                atan,
48                (|input: &[Tensor],
49                 output_grad: &[Tensor],
50                 input_grad: &[Tensor]| {
51                     let ret = input[0].ones_like().add(&input[0].mul(&input[0])).reciprocal();
52		     input_grad[0].swap(&ret.mul(&output_grad[0]));
53                 }));
54
55new_element_op!(Ceil,
56                "Ceil",
57                ceil,
58                (|input: &[Tensor],
59                 output_grad: &[Tensor],
60                 input_grad: &[Tensor]| {
61		     input_grad[0].swap(&input[0].zeros_like());
62                 }));
63
64new_element_op!(Cos,
65                "Cos",
66                cos,
67                (|input: &[Tensor],
68                 output_grad: &[Tensor],
69                 input_grad: &[Tensor]| {
70		     let ret = input[0].sin().neg();
71		     input_grad[0].swap(&ret.mul(&output_grad[0]));
72                 }));
73
74new_element_op!(Cosh,
75                "Cosh",
76                cosh,
77                (|input: &[Tensor],
78                 output_grad: &[Tensor],
79                 input_grad: &[Tensor]| {
80		     let ret = input[0].sinh();
81		     input_grad[0].swap(&ret.mul(&output_grad[0]));
82                 }));
83
84new_element_op!(Exp,
85                "Exp",
86                exp,
87                (|input: &[Tensor],
88                 output_grad: &[Tensor],
89                 input_grad: &[Tensor]| {
90		     let ret = input[0].exp();
91		     input_grad[0].swap(&ret.mul(&output_grad[0]));
92                 }));
93
94
95new_element_op!(Expm1,
96                "Expm1",
97                expm1,
98                (|input: &[Tensor],
99                 output_grad: &[Tensor],
100                 input_grad: &[Tensor]| {
101		     let ret = input[0].exp();
102		     input_grad[0].swap(&ret.mul(&output_grad[0]));
103                 }));
104
105new_element_op!(Floor,
106                "Floor",
107                floor,
108                (|input: &[Tensor],
109                 output_grad: &[Tensor],
110                 input_grad: &[Tensor]| {
111                     input_grad[0].swap(&input[0].zeros_like());
112                 }));
113
114new_element_op!(Frac,
115                "Frac",
116                frac,
117                (|input: &[Tensor],
118                 output_grad: &[Tensor],
119                 input_grad: &[Tensor]| {
120                     input_grad[0].swap(&input[0].ones_like());
121                 }));
122
123new_element_op!(Log,
124                "Log",
125                log,
126                (|input: &[Tensor],
127                 output_grad: &[Tensor],
128                 input_grad: &[Tensor]| {
129		     let ret = input[0].reciprocal();
130		     input_grad[0].swap(&ret.mul(&output_grad[0]));
131                 }));
132
133new_element_op!(Log10,
134                "Log10",
135                log10,
136                (|input: &[Tensor],
137                 output_grad: &[Tensor],
138                 input_grad: &[Tensor]| {
139		     let ret = input[0].reciprocal().div(&input[0].log10_like());
140		     input_grad[0].swap(&ret.mul(&output_grad[0]));
141                 }));
142
143new_element_op!(Log1p,
144                "Log1p",
145                log1p,
146                (|input: &[Tensor],
147                 output_grad: &[Tensor],
148                 input_grad: &[Tensor]| {
149		     let ret = input[0].add(&input[0].ones_like()).reciprocal();
150		     input_grad[0].swap(&ret.mul(&output_grad[0]));
151                 }));
152
153new_element_op!(Log1pexp,
154                "Log1pexp",
155                log1pexp,
156                (|input: &[Tensor],
157                 output_grad: &[Tensor],
158                 input_grad: &[Tensor]| {
159		     let ret = input[0].neg().exp().add(&input[0].ones_like()).reciprocal();
160		     input_grad[0].swap(&ret.mul(&output_grad[0]));
161                 }));
162
163new_element_op!(Log2,
164                "Log2",
165                log2,
166                (|input: &[Tensor],
167                 output_grad: &[Tensor],
168                 input_grad: &[Tensor]| {
169		     let ret = input[0].reciprocal().div(&input[0].log2_like());
170		     input_grad[0].swap(&ret.mul(&output_grad[0]));
171                 }));
172
173new_element_op!(Neg,
174                "Neg",
175                neg,
176                (|input: &[Tensor],
177                 output_grad: &[Tensor],
178                 input_grad: &[Tensor]| {
179		     let ret = input[0].ones_like().neg();
180		     input_grad[0].swap(&ret.mul(&output_grad[0]));
181                 }));
182
183new_element_op!(Reciprocal,
184                "Reciprocal",
185                reciprocal,
186                (|input: &[Tensor],
187                 output_grad: &[Tensor],
188                 input_grad: &[Tensor]| {
189		     let ret = input[0].square().reciprocal().neg();
190		     input_grad[0].swap(&ret.mul(&output_grad[0]));
191                 }));
192
193new_element_op!(Round,
194                "Round",
195                round,
196                (|input: &[Tensor],
197                 output_grad: &[Tensor],
198                 input_grad: &[Tensor]| {
199		     let ret = input[0].zeros_like();
200		     input_grad[0].swap(&ret.mul(&output_grad[0]));
201                 }));
202
203new_element_op!(Rsqrt,
204                "Rsqrt",
205                rsqrt,
206                (|input: &[Tensor],
207                 output_grad: &[Tensor],
208                 input_grad: &[Tensor]| {
209		     let ret = input[0].sqrt().reciprocal().
210                         div(&input[0]).neg().div(
211			 &input[0].ones_like().add(&input[0].ones_like()));
212		     input_grad[0].swap(&ret.mul(&output_grad[0]));
213                 }));
214
215new_element_op!(Sigmoid,
216                "Sigmoid",
217                sigmoid,
218                (|input: &[Tensor],
219                 output_grad: &[Tensor],
220                 input_grad: &[Tensor]| {
221                     let ret = input[0].sigmoid().mul(&input[0].sigmoid().neg().add(&input[0].ones_like()));
222		     input_grad[0].swap(&ret.mul(&output_grad[0]));
223                 }));
224
225new_element_op!(Sign,
226                "Sign",
227                sign,
228                (|input: &[Tensor],
229                 output_grad: &[Tensor],
230                 input_grad: &[Tensor]| {
231		     let ret = input[0].zeros_like();
232		     input_grad[0].swap(&ret.mul(&output_grad[0]));
233                 }));
234
235new_element_op!(Sin,
236                "Sin",
237                sin,
238                (|input: &[Tensor],
239                 output_grad: &[Tensor],
240                 input_grad: &[Tensor]| {
241                     let ret = input[0].cos();
242		     input_grad[0].swap(&ret.mul(&output_grad[0]));
243                 }));
244
245new_element_op!(Sinh,
246                "Sinh",
247                sinh,
248                (|input: &[Tensor],
249                 output_grad: &[Tensor],
250                 input_grad: &[Tensor]| {
251		     let ret = input[0].cosh();
252		     input_grad[0].swap(&ret.mul(&output_grad[0]));
253                 }));
254
255new_element_op!(Sqrt,
256                "Sqrt",
257                sqrt,
258                (|input: &[Tensor],
259                 output_grad: &[Tensor],
260                 input_grad: &[Tensor]| {
261		     let ret = input[0].sqrt().reciprocal().div(
262			 &input[0].ones_like().add(&input[0].ones_like()));
263		     input_grad[0].swap(&ret.mul(&output_grad[0]));
264                 }));
265
266new_element_op!(Tan,
267                "Tan",
268                tan,
269                (|input: &[Tensor],
270                 output_grad: &[Tensor],
271                 input_grad: &[Tensor]| {
272		     let ret = input[0].tan().square().add(&input[0].ones_like());
273		     input_grad[0].swap(&ret.mul(&output_grad[0]));
274                 }));
275
276new_element_op!(Tanh,
277                "Tanh",
278                tanh,
279                (|input: &[Tensor],
280                 output_grad: &[Tensor],
281                 input_grad: &[Tensor]| {
282		     let ret = input[0].tanh().square().neg().add(&input[0].ones_like());
283		     input_grad[0].swap(&ret.mul(&output_grad[0]));
284                 }));
285
286new_element_op!(Trunc,
287                "Trunc",
288                trunc,
289                (|input: &[Tensor],
290                 output_grad: &[Tensor],
291                 input_grad: &[Tensor]| {
292		     let ret = input[0].zeros_like();
293		     input_grad[0].swap(&ret.mul(&output_grad[0]));
294                 }));
295
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::_gradient_checker;

    /// Runs the numerical gradient checker on the ten scalar inputs
    /// `start, start + 0.1, ..., start + 0.9` and asserts every check passes.
    fn check_over_range(op: &mut dyn OpTrait, start: f64) {
        for i in 0..10 {
            let x = i as f64 / 10.0 + start;
            let input = Tensor::from_vec_f64(&vec![x], &vec![1]);
            let good_grad = _gradient_checker(op, &[input], None, None, None);
            assert!(good_grad, "gradient check failed at x = {}", x);
        }
    }

    /// Samples [-0.51, 0.39], straddling zero: for ops defined on that whole
    /// interval (including acos/asin, whose domain is [-1, 1]).
    fn test_range_data(op: &mut dyn OpTrait) {
        check_over_range(op, -0.51);
    }

    /// Samples [0.51, 1.41], strictly positive: for ops whose domain or
    /// derivative requires x > 0 (log, log2, log10, reciprocal, rsqrt, sqrt).
    fn test_positive_range_data(op: &mut dyn OpTrait) {
        check_over_range(op, 0.51);
    }

    #[test]
    fn abs() {
        let mut op = Abs::new();
        test_range_data(&mut op);
    }

    #[test]
    fn acos() {
        let mut op = Acos::new();
        test_range_data(&mut op);
    }

    #[test]
    fn asin() {
        let mut op = Asin::new();
        test_range_data(&mut op);
    }

    #[test]
    fn atan() {
        let mut op = Atan::new();
        test_range_data(&mut op);
    }

    #[test]
    fn ceil() {
        let mut op = Ceil::new();
        test_range_data(&mut op);
    }

    #[test]
    fn cos() {
        let mut op = Cos::new();
        test_range_data(&mut op);
    }

    #[test]
    fn cosh() {
        let mut op = Cosh::new();
        test_range_data(&mut op);
    }

    #[test]
    fn exp() {
        let mut op = Exp::new();
        test_range_data(&mut op);
    }

    #[test]
    fn expm1() {
        let mut op = Expm1::new();
        test_range_data(&mut op);
    }

    #[test]
    fn floor() {
        let mut op = Floor::new();
        test_range_data(&mut op);
    }

    #[test]
    fn frac() {
        let mut op = Frac::new();
        test_range_data(&mut op);
    }

    #[test]
    fn log() {
        let mut op = Log::new();
        test_positive_range_data(&mut op);
    }

    #[test]
    fn log10() {
        let mut op = Log10::new();
        test_positive_range_data(&mut op);
    }

    #[test]
    fn log1p() {
        // log1p is defined for x > -1, so the zero-straddling range is valid.
        let mut op = Log1p::new();
        test_range_data(&mut op);
    }

    #[test]
    fn log1pexp() {
        let mut op = Log1pexp::new();
        test_range_data(&mut op);
    }

    #[test]
    fn log2() {
        let mut op = Log2::new();
        test_positive_range_data(&mut op);
    }

    #[test]
    fn neg() {
        let mut op = Neg::new();
        test_range_data(&mut op);
    }

    #[test]
    fn reciprocal() {
        let mut op = Reciprocal::new();
        test_positive_range_data(&mut op);
    }

    #[test]
    fn round() {
        let mut op = Round::new();
        test_range_data(&mut op);
    }

    #[test]
    fn rsqrt() {
        let mut op = Rsqrt::new();
        test_positive_range_data(&mut op);
    }

    #[test]
    fn sigmoid() {
        let mut op = Sigmoid::new();
        test_range_data(&mut op);
    }

    #[test]
    fn sign() {
        let mut op = Sign::new();
        test_range_data(&mut op);
    }

    #[test]
    fn sin() {
        let mut op = Sin::new();
        test_range_data(&mut op);
    }

    #[test]
    fn sinh() {
        let mut op = Sinh::new();
        test_range_data(&mut op);
    }

    #[test]
    fn sqrt() {
        // sqrt is NaN for negative inputs, so only the positive range gives a
        // meaningful gradient check.
        let mut op = Sqrt::new();
        test_positive_range_data(&mut op);
    }

    #[test]
    fn tan() {
        let mut op = Tan::new();
        test_range_data(&mut op);
    }

    #[test]
    fn tanh() {
        let mut op = Tanh::new();
        test_range_data(&mut op);
    }

    #[test]
    fn trunc() {
        let mut op = Trunc::new();
        test_range_data(&mut op);
    }
}