#![allow(clippy::redundant_closure_call)]
use tensor_rs::tensor::Tensor;
use super::{OpTrait, OpHandle};
use super::macros::new_element_op;

#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;

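// Element-wise operators. Each `new_element_op!` invocation below (see `super::macros`)
// is expected to define an op type whose forward pass calls the named `Tensor` method and
// whose backward pass runs the supplied closure, filling `input_grad` from `input` and
// `output_grad`.

// d/dx |x| = sign(x): select +1 or -1 element-wise and scale by the upstream gradient.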
new_element_op!(Abs,
    "Abs",
    abs,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        input_grad[0].swap(
            &input[0]
                .conditional_select(
                    &input[0].ones_like(),
                    &input[0].ones_like().neg())
                .mul(&output_grad[0]));
    }));

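// d/dx acos(x) = -1 / sqrt(1 - x^2).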
new_element_op!(Acos,
    "Acos",
    acos,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].ones_like().sub(&input[0].mul(&input[0])).sqrt().reciprocal().neg();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

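// d/dx asin(x) = 1 / sqrt(1 - x^2).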
new_element_op!(Asin,
    "Asin",
    asin,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].ones_like().sub(&input[0].mul(&input[0])).sqrt().reciprocal();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

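// d/dx atan(x) = 1 / (1 + x^2).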
new_element_op!(Atan,
    "Atan",
    atan,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].ones_like().add(&input[0].mul(&input[0])).reciprocal();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

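// ceil is piecewise constant, so its derivative is 0 almost everywhere.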
new_element_op!(Ceil,
    "Ceil",
    ceil,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        input_grad[0].swap(&input[0].zeros_like());
    }));

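// d/dx cos(x) = -sin(x).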
new_element_op!(Cos,
    "Cos",
    cos,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].sin().neg();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

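// d/dx cosh(x) = sinh(x).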
new_element_op!(Cosh,
    "Cosh",
    cosh,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].sinh();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

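// d/dx exp(x) = exp(x).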
new_element_op!(Exp,
    "Exp",
    exp,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].exp();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

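// d/dx (exp(x) - 1) = exp(x).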
new_element_op!(Expm1,
    "Expm1",
    expm1,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].exp();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

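// floor is piecewise constant, so its derivative is 0 almost everywhere.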
new_element_op!(Floor,
    "Floor",
    floor,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        input_grad[0].swap(&input[0].zeros_like());
    }));

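// frac(x) = x - trunc(x), so d/dx frac(x) = 1 almost everywhere.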
new_element_op!(Frac,
    "Frac",
    frac,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        // Scale the upstream gradient by the (unit) local derivative.
        input_grad[0].swap(&input[0].ones_like().mul(&output_grad[0]));
    }));

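// d/dx ln(x) = 1 / x.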
new_element_op!(Log,
    "Log",
    log,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].reciprocal();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

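// d/dx log10(x) = 1 / (x * ln 10).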
new_element_op!(Log10,
    "Log10",
    log10,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].reciprocal().div(&input[0].log10_like());
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

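// d/dx ln(1 + x) = 1 / (1 + x).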
new_element_op!(Log1p,
    "Log1p",
    log1p,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].add(&input[0].ones_like()).reciprocal();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

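// d/dx ln(1 + exp(x)) = 1 / (1 + exp(-x)), i.e. the logistic sigmoid.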
new_element_op!(Log1pexp,
    "Log1pexp",
    log1pexp,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].neg().exp().add(&input[0].ones_like()).reciprocal();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

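// d/dx log2(x) = 1 / (x * ln 2).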
new_element_op!(Log2,
    "Log2",
    log2,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].reciprocal().div(&input[0].log2_like());
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

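// d/dx (-x) = -1.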
new_element_op!(Neg,
    "Neg",
    neg,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].ones_like().neg();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

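// d/dx (1 / x) = -1 / x^2.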
new_element_op!(Reciprocal,
    "Reciprocal",
    reciprocal,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].square().reciprocal().neg();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

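// round is piecewise constant, so its derivative is 0 almost everywhere.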
new_element_op!(Round,
    "Round",
    round,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].zeros_like();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

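// d/dx x^(-1/2) = -1 / (2 * x^(3/2)).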
new_element_op!(Rsqrt,
    "Rsqrt",
    rsqrt,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].sqrt().reciprocal()
            .div(&input[0])
            .neg()
            .div(&input[0].ones_like().add(&input[0].ones_like()));
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

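// d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).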
new_element_op!(Sigmoid,
    "Sigmoid",
    sigmoid,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].sigmoid()
            .mul(&input[0].sigmoid().neg().add(&input[0].ones_like()));
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

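// sign is piecewise constant, so its derivative is 0 almost everywhere.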
new_element_op!(Sign,
    "Sign",
    sign,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].zeros_like();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

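// d/dx sin(x) = cos(x).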
new_element_op!(Sin,
    "Sin",
    sin,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].cos();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

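// d/dx sinh(x) = cosh(x).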
new_element_op!(Sinh,
    "Sinh",
    sinh,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].cosh();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

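// d/dx sqrt(x) = 1 / (2 * sqrt(x)).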
new_element_op!(Sqrt,
    "Sqrt",
    sqrt,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].sqrt().reciprocal()
            .div(&input[0].ones_like().add(&input[0].ones_like()));
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

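// d/dx tan(x) = 1 + tan(x)^2.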
new_element_op!(Tan,
    "Tan",
    tan,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].tan().square().add(&input[0].ones_like());
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

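// d/dx tanh(x) = 1 - tanh(x)^2.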
new_element_op!(Tanh,
    "Tanh",
    tanh,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].tanh().square().neg().add(&input[0].ones_like());
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));

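// trunc is piecewise constant, so its derivative is 0 almost everywhere.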
new_element_op!(Trunc,
    "Trunc",
    trunc,
    (|input: &[Tensor],
      output_grad: &[Tensor],
      input_grad: &[Tensor]| {
        let ret = input[0].zeros_like();
        input_grad[0].swap(&ret.mul(&output_grad[0]));
    }));


#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::_gradient_checker;

    // Run the gradient checker on ten scalar inputs spread across [-0.51, 0.39].
    fn test_range_data(op: &mut dyn OpTrait) {
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 - 0.51], &vec![1]);
            let good_grad = _gradient_checker(op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn abs() {
        let mut op = Abs::new();
        test_range_data(&mut op);
    }

    #[test]
    fn acos() {
        let mut op = Acos::new();
        test_range_data(&mut op);
    }

    #[test]
    fn asin() {
        let mut op = Asin::new();
        test_range_data(&mut op);
    }

    #[test]
    fn atan() {
        let mut op = Atan::new();
        test_range_data(&mut op);
    }

    #[test]
    fn ceil() {
        let mut op = Ceil::new();
        test_range_data(&mut op);
    }

    #[test]
    fn cos() {
        let mut op = Cos::new();
        test_range_data(&mut op);
    }

    #[test]
    fn cosh() {
        let mut op = Cosh::new();
        test_range_data(&mut op);
    }

    #[test]
    fn exp() {
        let mut op = Exp::new();
        test_range_data(&mut op);
    }

    #[test]
    fn expm1() {
        let mut op = Expm1::new();
        test_range_data(&mut op);
    }

    #[test]
    fn floor() {
        let mut op = Floor::new();
        test_range_data(&mut op);
    }

    #[test]
    fn frac() {
        let mut op = Frac::new();
        test_range_data(&mut op);
    }

    #[test]
    fn log() {
        let mut op = Log::new();
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 + 0.51], &vec![1]);
            let good_grad = _gradient_checker(&mut op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn log10() {
        let mut op = Log10::new();
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 + 0.51], &vec![1]);
            let good_grad = _gradient_checker(&mut op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn log1p() {
        // log1p is defined for x > -1, so the shared range straddling zero is fine.
        let mut op = Log1p::new();
        test_range_data(&mut op);
    }

    #[test]
    fn log1pexp() {
        let mut op = Log1pexp::new();
        test_range_data(&mut op);
    }

    #[test]
    fn log2() {
        let mut op = Log2::new();
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 + 0.51], &vec![1]);
            let good_grad = _gradient_checker(&mut op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn neg() {
        let mut op = Neg::new();
        test_range_data(&mut op);
    }

    #[test]
    fn reciprocal() {
        let mut op = Reciprocal::new();
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 + 0.51], &vec![1]);
            let good_grad = _gradient_checker(&mut op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn round() {
        let mut op = Round::new();
        test_range_data(&mut op);
    }

    #[test]
    fn rsqrt() {
        let mut op = Rsqrt::new();
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 + 0.51], &vec![1]);
            let good_grad = _gradient_checker(&mut op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn sigmoid() {
        let mut op = Sigmoid::new();
        test_range_data(&mut op);
    }

    #[test]
    fn sign() {
        let mut op = Sign::new();
        test_range_data(&mut op);
    }

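    // Exercise Sin on the same range as the other trigonometric ops.
    #[test]
    fn sin() {
        let mut op = Sin::new();
        test_range_data(&mut op);
    }
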
    #[test]
    fn sinh() {
        let mut op = Sinh::new();
        test_range_data(&mut op);
    }

    #[test]
    fn sqrt() {
        let mut op = Sqrt::new();
        // sqrt is only differentiable for positive input, so check on a positive range.
        for i in 0..10 {
            let value = Tensor::from_vec_f64(&vec![i as f64 / 10.0 + 0.51], &vec![1]);
            let good_grad = _gradient_checker(&mut op, &[value], None, None, None);
            assert!(good_grad);
        }
    }

    #[test]
    fn tan() {
        let mut op = Tan::new();
        test_range_data(&mut op);
    }

    #[test]
    fn tanh() {
        let mut op = Tanh::new();
        test_range_data(&mut op);
    }

    #[test]
    fn trunc() {
        let mut op = Trunc::new();
        test_range_data(&mut op);
    }
}