burn_tensor/tensor/ops/activation.rs

use crate::TensorMetadata;
use crate::{ElementConversion, backend::Backend};
use core::f64::consts::SQRT_2;

use super::FloatTensor;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The slope that values smaller than 0 are multiplied with.
    ///
    /// # Returns
    ///
    /// The output tensor.
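    ///
    /// # Example
    ///
    /// A scalar sketch of the same formula, for illustration only (the value
    /// `negative_slope = 0.01` is an assumption, not a default of this trait):
    ///
    /// ```ignore
    /// fn leaky_relu_scalar(x: f32, negative_slope: f32) -> f32 {
    ///     // Values below zero are scaled by the slope; the rest pass through unchanged.
    ///     if x < 0.0 { x * negative_slope } else { x }
    /// }
    ///
    /// leaky_relu_scalar(-2.0, 0.01); // -0.02
    /// leaky_relu_scalar(3.0, 0.01); // 3.0
    /// ```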
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: super::FloatElem<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope.elem());

        // Update the tensor where the values are `< 0` with `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
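    ///
    /// # Example
    ///
    /// A scalar sketch of the element-wise behaviour, for illustration only:
    ///
    /// ```ignore
    /// fn relu_scalar(x: f32) -> f32 {
    ///     // max(x, 0): negative values are clamped to zero.
    ///     x.max(0.0)
    /// }
    ///
    /// relu_scalar(-1.5); // 0.0
    /// relu_scalar(2.0); // 2.0
    /// ```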
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_equal_elem(tensor.clone(), 0.elem());

        B::float_mask_fill(tensor, mask, 0.elem())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_equal_elem(output, 0.elem());

        B::float_mask_fill(grad, mask, 0.elem())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
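    ///
    /// # Example
    ///
    /// The default implementation uses the exact, erf-based formulation
    /// `gelu(x) = x * (1 + erf(x / sqrt(2))) / 2`. A scalar sketch, for illustration
    /// only (the `erf` closure is an assumed helper, since `f32` has no built-in
    /// error function):
    ///
    /// ```ignore
    /// fn gelu_scalar(x: f32, erf: impl Fn(f32) -> f32) -> f32 {
    ///     x * (1.0 + erf(x / core::f32::consts::SQRT_2)) / 2.0
    /// }
    /// ```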
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1i32.elem());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2i32.elem())
    }

    /// Applies the PReLu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `alpha` - The weight tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
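    ///
    /// # Example
    ///
    /// A scalar sketch of the element-wise behaviour, for illustration only (here
    /// `alpha` is a single scalar, whereas the trait method takes a weight tensor):
    ///
    /// ```ignore
    /// fn prelu_scalar(x: f32, alpha: f32) -> f32 {
    ///     // Like leaky ReLU, but the negative slope is a learned parameter.
    ///     if x < 0.0 { x * alpha } else { x }
    /// }
    /// ```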
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh.

        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;
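        // These constants come from differentiating the tanh approximation
        //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        // With u = constant_2 * x + constant_1 * x^3:
        //   constant_2 ~= sqrt(2/pi)          constant_1 ~= 0.044715 * sqrt(2/pi)
        //   constant_4 ~= 0.5 * sqrt(2/pi)    constant_3 ~= 3 * 0.044715 * 0.5 * sqrt(2/pi)
        // and the expression evaluated below is
        //   0.5 * (1 + tanh(u)) + (constant_3 * x^3 + constant_4 * x) * sech^2(u)
        // where sech^2(u) = 1 - tanh^2(u) (the `sech` variable).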

        let x3 = B::float_powi_scalar(x.clone(), 3.elem());

        let c1 = B::float_mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::float_mul_scalar(x3, constant_3.elem());
        let c4 = B::float_mul_scalar(x, constant_4.elem());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

        let sech = B::float_powi_scalar(tanh.clone(), 2.elem());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.elem());

        let y1 = B::float_mul_scalar(tanh, 0.5.elem());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.elem());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
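    ///
    /// # Example
    ///
    /// The default implementation evaluates `sigmoid(x) = 1 / (1 + exp(-x))` as
    /// `exp(-log(1 + exp(-x)))` in `f32` for numerical stability. A scalar sketch of
    /// the plain formula, for illustration only:
    ///
    /// ```ignore
    /// fn sigmoid_scalar(x: f32) -> f32 {
    ///     1.0 / (1.0 + (-x).exp())
    /// }
    ///
    /// sigmoid_scalar(0.0); // 0.5
    /// ```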
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, crate::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.elem(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.elem()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The alpha value that the tensor is multiplied with.
    /// * `beta` - The beta value that is added to the tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
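    ///
    /// # Example
    ///
    /// A scalar sketch of the element-wise formula `clamp(alpha * x + beta, 0, 1)`,
    /// for illustration only (`alpha = 0.2` and `beta = 0.5` are assumed values):
    ///
    /// ```ignore
    /// fn hard_sigmoid_scalar(x: f32, alpha: f32, beta: f32) -> f32 {
    ///     (alpha * x + beta).clamp(0.0, 1.0)
    /// }
    ///
    /// hard_sigmoid_scalar(0.0, 0.2, 0.5); // 0.5
    /// hard_sigmoid_scalar(10.0, 0.2, 0.5); // 1.0
    /// ```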
    fn hard_sigmoid(
        tensor: FloatTensor<B>,
        alpha: super::FloatElem<B>,
        beta: super::FloatElem<B>,
    ) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, crate::FloatDType::F32);

        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha.elem()), beta.elem()),
            0.0.elem(),
            1.0.elem(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
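    ///
    /// # Example
    ///
    /// A scalar sketch of the numerically stable formulation used by the default
    /// implementation, for illustration only:
    ///
    /// ```ignore
    /// fn log_sigmoid_scalar(x: f32) -> f32 {
    ///     // -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
    ///     let m = (-x).max(0.0);
    ///     -m - ((-m).exp() + (-x - m).exp()).ln()
    /// }
    ///
    /// log_sigmoid_scalar(0.0); // ln(0.5) = -0.6931...
    /// ```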
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // The derivative of log_sigmoid(x) = -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
        // is sigmoid(-x) = 1 - sigmoid(x). With z = exp(-max(-x, 0)) + exp(-x - max(-x, 0)),
        // this evaluates to:
        //
        // (z - 1) / z if x > 0
        // 1 / z       if x <= 0
        //
        // which is what the masked expression `max_derive - sign * (1 - 1 / z)` below computes.

        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0, i.e. x > 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // `max_derive` is 1 where x <= 0 (where max(-x, 0) = -x) and 0 where x > 0;
        // `sign` selects the sign of the correction term.
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}