burn_backend/backend/ops/activation.rs

use crate::element::ElementConversion;
use crate::tensor::{FloatElem, FloatTensor};
use crate::{Backend, TensorMetadata};
use core::f64::consts::SQRT_2;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The slope by which values smaller than 0 are multiplied.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: FloatElem<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope.elem());

        // Where the values are `< 0`, replace them with `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
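        // relu(x) = max(x, 0): values <= 0 are masked and filled with 0.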
        let mask = B::float_lower_equal_elem(tensor.clone(), 0.elem());

        B::float_mask_fill(tensor, mask, 0.elem())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
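        // The gradient passes through where the forward output is positive and is
        // zeroed where the output was clamped to zero (i.e. where the input was <= 0).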
        let mask = B::float_lower_equal_elem(output, 0.elem());

        B::float_mask_fill(grad, mask, 0.elem())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
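        // Exact (erf-based) formulation: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).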
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1i32.elem());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2i32.elem())
    }

    /// Applies the PReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `alpha` - The weight tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
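        // prelu(x) = x if x >= 0, otherwise alpha * x (alpha is applied element-wise).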
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh.

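        // The tanh approximation is
        //   gelu(x) ≈ 0.5 * x * (1 + tanh(u)),  u = sqrt(2/pi) * (x + 0.044715 * x^3),
        // so its derivative is
        //   0.5 * (1 + tanh(u)) + (0.398942 * x + 0.0535161 * x^3) * sech^2(u),
        // where sqrt(2/pi) ≈ 0.797885 and 0.044715 * sqrt(2/pi) ≈ 0.0356774.
        // The constants below encode these coefficients.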
        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::float_powi_scalar(x.clone(), 3.elem());

        let c1 = B::float_mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::float_mul_scalar(x3, constant_3.elem());
        let c4 = B::float_mul_scalar(x, constant_4.elem());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

        let sech = B::float_powi_scalar(tanh.clone(), 2.elem());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.elem());

        let y1 = B::float_mul_scalar(tanh, 0.5.elem());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.elem());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
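        // sigmoid(x) = 1 / (1 + exp(-x)) = exp(-log(1 + exp(-x))).
        // The computation is done in f32 and cast back to the original dtype so that the
        // intermediate `exp` does not overflow for low-precision float types.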
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.elem(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
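        // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = output * (1 - output).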
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.elem()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The value the tensor is multiplied by.
    /// * `beta` - The value added to the scaled tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn hard_sigmoid(
        tensor: FloatTensor<B>,
        alpha: FloatElem<B>,
        beta: FloatElem<B>,
    ) -> FloatTensor<B> {
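        // hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1), computed in f32 and cast back.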
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);

        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha.elem()), beta.elem()),
            0.0.elem(),
            1.0.elem(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // The derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
        // max_derive - (max_derive * exp(-max(-x, 0)) + (max_derive - 1) * exp(-x - max(-x, 0))) / z
        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        // and max_derive = d/dx -max(-x, 0), which is 1 if x < 0 and 0 if x >= 0.
        //
        // This simplifies to:
        // (z - 1) / z if x >= 0
        // 1 / z       if x < 0

        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0 or x >= 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // Derivative of -max(-x, 0) is 1 if x < 0 or 0 if x >= 0
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}