// auto_diff/op/nonlinear.rs — element-wise nonlinear activation ops.
nonlinear.rs1#![allow(clippy::new_without_default)]
2use tensor_rs::tensor::Tensor;
3use super::{OpTrait, OpHandle};
4
5#[cfg(feature = "use-serde")]
6use serde::{Serialize, Deserialize};
7#[cfg(feature = "use-serde")]
8use std::any::Any;
9
/// Exponential Linear Unit activation op:
/// f(x) = x for x >= 0, alpha * (exp(x) - 1) for x < 0.
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct ELU {
    /// Coefficient scaling the negative branch (stored as a tensor).
    alpha: Tensor,
    /// Graph bookkeeping handle; not part of the serialized state.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
17impl ELU {
18
19 pub fn new(alpha: Tensor) -> ELU {
20 ELU {
21 alpha,
22 handle: OpHandle::new(),
23 }
24 }
25
26
27 handle_method!();
28}
29impl OpTrait for ELU {
30 fn get_name(&self) -> &'static str {
31 "ELU"
32 }
33 fn get_input_size(&self) -> usize {
34 1
35 }
36 fn get_output_size(&self) -> usize {
37 1
38 }
39
40 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
42 let positive = input[0].max_pair(&input[0].zeros_like());
43 let negative = input[0].expm1().mul(&Tensor::fill(&input[0].size(), &self.alpha)).min_pair(&input[0].zeros_like());
44 let ret = positive.add(&negative);
45 output[0].swap(&ret);
46 }
47
48 fn grad(&self, input: &[Tensor],
52 output_grad: &[Tensor],
53 input_grad: &[Tensor]) {
54 let positive = input[0].ge(&input[0].zeros_like());
55 let negative = input[0].lt(&input[0].zeros_like()).mul(&Tensor::fill(&input[0].size(), &self.alpha)).mul(&input[0].exp());
56 let g = positive.add(&negative);
57 input_grad[0].swap(&g.mul(&output_grad[0]));
58 }
59
60 fn get_values(&self) -> Vec<Tensor> {
62 Vec::new()
63 }
64 fn set_values(&self, _v: &[Tensor]) {
65 }
66 fn get_grads(&self) -> Vec<Tensor> {
68 Vec::new()
69 }
70 #[cfg(feature = "use-serde")]
71 fn as_any(&self) -> &dyn Any {
72 self
73 }
74}
75
/// Rectified Linear Unit activation op: f(x) = max(x, 0).
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct ReLU {
    /// Graph bookkeeping handle; not part of the serialized state.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
89impl ReLU {
90 pub fn new() -> ReLU {
91 ReLU {
92 handle: OpHandle::new(),
93 }
94 }
95
96 handle_method!();
97}
98impl OpTrait for ReLU {
99 fn get_name(&self) -> &'static str {
100 "ReLU"
101 }
102 fn get_input_size(&self) -> usize {
103 1
104 }
105 fn get_output_size(&self) -> usize {
106 1
107 }
108 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
110 let ret = input[0].max_pair(&input[0].zeros_like());
111 output[0].swap(&ret);
112 }
113
114 fn grad(&self, input: &[Tensor],
118 output_grad: &[Tensor],
119 input_grad: &[Tensor]) {
120 let ret = input[0].ge(&input[0].zeros_like()); input_grad[0].swap(&ret.mul(&output_grad[0]));
122 }
123
124 fn get_values(&self) -> Vec<Tensor> {
126 Vec::new()
127 }
128 fn set_values(&self, _v: &[Tensor]) {
129 }
130 fn get_grads(&self) -> Vec<Tensor> {
132 Vec::new()
133 }
134 #[cfg(feature = "use-serde")]
135 fn as_any(&self) -> &dyn Any {
136 self
137 }
138}
139
140
/// Logistic sigmoid activation op: f(x) = 1 / (1 + exp(-x)).
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Sigmoid {
    /// Graph bookkeeping handle; not part of the serialized state.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
151impl Sigmoid {
152 pub fn new() -> Sigmoid {
153 Sigmoid {
154 handle: OpHandle::new(),
155 }
156 }
157 handle_method!();
158}
159impl OpTrait for Sigmoid {
160
161 fn get_name(&self) -> &'static str {
162 "Sigmoid"
163 }
164 fn get_input_size(&self) -> usize {
165 1
166 }
167 fn get_output_size(&self) -> usize {
168 1
169 }
170 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
172 if input.is_empty() {
173 panic!("{} expect two input, get {}", self.get_name(), input.len());
174 }
175 output[0].swap(&input[0].sigmoid());
176 }
177
178 fn grad(&self, input: &[Tensor],
182 output_grad: &[Tensor],
183 input_grad: &[Tensor]) {
184 let tmp1 = input[0].sigmoid().mul(&input[0].neg().sigmoid());
185 let tmp2 = tmp1.mul(&output_grad[0]);
186 input_grad[0].swap(&tmp2);
187 }
188
189 fn get_values(&self) -> Vec<Tensor> {
191 Vec::new()
192 }
193
194 fn set_values(&self, _v: &[Tensor]) {
195
196 }
197
198 fn get_grads(&self) -> Vec<Tensor> {
200 Vec::new()
201 }
202 #[cfg(feature = "use-serde")]
203 fn as_any(&self) -> &dyn Any {
204 self
205 }
206}
207
/// Sine activation op: f(x) = sin(x).
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct Sine {
    /// Graph bookkeeping handle; not part of the serialized state.
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}
219impl Sine {
220 pub fn new() -> Sine {
221 Sine {
222 handle: OpHandle::new(),
223 }
224 }
225 handle_method!();
226}
227impl OpTrait for Sine {
228 fn get_name(&self) -> &'static str {
229 "Sine"
230 }
231 fn get_input_size(&self) -> usize {
232 1
233 }
234 fn get_output_size(&self) -> usize {
235 1
236 }
237 fn apply(&self, input: &[Tensor], output: &[Tensor]) {
239 let ret = input[0].sin();
240 output[0].swap(&ret);
241 }
242
243 fn grad(&self, input: &[Tensor],
247 output_grad: &[Tensor],
248 input_grad: &[Tensor]) {
249 let ret = input[0].cos();
250 input_grad[0].swap(&ret.mul(&output_grad[0]));
251 }
252
253 fn get_values(&self) -> Vec<Tensor> {
255 Vec::new()
256 }
257 fn set_values(&self, _v: &[Tensor]) {
258 }
259 fn get_grads(&self) -> Vec<Tensor> {
261 Vec::new()
262 }
263 #[cfg(feature = "use-serde")]
264 fn as_any(&self) -> &dyn Any {
265 self
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::_gradient_checker;

    // Each test numerically verifies the op's analytic gradient at the
    // integer points -5..=4, covering negative inputs, zero, and positive
    // inputs. Fixes vs. the old version: `assert!(..)` instead of
    // `assert_eq!(.., true)`, slice literals instead of needless `vec!`
    // allocations, and an accurate local name (the old `zero` held -5..4).

    #[test]
    fn elu() {
        let mut op = ELU::new(Tensor::from_vec_f64(&[1.], &[1]));

        for i in 0..10 {
            let input = Tensor::from_vec_f64(&[(i - 5) as f64], &[1]);
            assert!(_gradient_checker(&mut op, &[input], None, None, None));
        }
    }

    #[test]
    fn relu() {
        let mut op = ReLU::new();

        for i in 0..10 {
            let input = Tensor::from_vec_f64(&[(i - 5) as f64], &[1]);
            assert!(_gradient_checker(&mut op, &[input], None, None, None));
        }
    }

    #[test]
    fn sigmoid() {
        let mut op = Sigmoid::new();

        for i in 0..10 {
            let input = Tensor::from_vec_f64(&[(i - 5) as f64], &[1]);
            assert!(_gradient_checker(&mut op, &[input], None, None, None));
        }
    }

    #[test]
    fn sine() {
        let mut op = Sine::new();

        for i in 0..10 {
            let input = Tensor::from_vec_f64(&[(i - 5) as f64], &[1]);
            assert!(_gradient_checker(&mut op, &[input], None, None, None));
        }
    }
}