//! Activation function operations (`burn_backend/backend/ops/activation.rs`).
1use crate::tensor::FloatTensor;
2use crate::{Backend, Scalar, TensorMetadata};
3use core::f64::consts::SQRT_2;
4
5/// Activation function operations.
6///
7/// This trait let backend implementations override activation functions for better performance.
/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {
        // Mask of the elements that are `< 0`.
        let mask = B::float_lower_elem(tensor.clone(), 0f32.into());
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope);

        // Update the tensor where the values are `< 0` by `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // Zero out every element that is `<= 0`, keep the rest unchanged.
        let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into());

        B::float_mask_fill(tensor, mask, 0f32.into())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the forward pass.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // The gradient passes through only where the forward output was `> 0`.
        let mask = B::float_lower_equal_elem(output, 0f32.into());

        B::float_mask_fill(grad, mask, 0.into())
    }

    /// Applies the Gelu activation function.
    ///
    /// Uses the exact erf formulation: `gelu(x) = x * (1 + erf(x / sqrt(2))) / 2`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.into());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1f32.into());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2f32.into())
    }

    /// Applies the PReLu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor
    /// * `alpha` - The weight tensor
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
        // Like LeakyReLU, but the negative slope `alpha` is a (broadcastable) tensor
        // instead of a scalar.
        let mask = B::float_lower_elem(tensor.clone(), 0f32.into());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh:
        //   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        //
        // NOTE(review): the forward `gelu` above uses the exact erf formulation while this
        // backward uses the tanh approximation's derivative — confirm the mismatch is intended.
        //
        // constant_2 ~= sqrt(2/pi), constant_1 ~= 0.044715 * sqrt(2/pi),
        // constant_4 = constant_2 / 2, constant_3 = 3 * constant_1 / 2.

        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::float_powi_scalar(x.clone(), 3.into());

        // inner1 = c1 + c2: the argument of tanh.
        // inner2 = c3 + c4: equals 0.5 * x * d(inner1)/dx.
        let c1 = B::float_mul_scalar(x3.clone(), constant_1.into());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.into());
        let c3 = B::float_mul_scalar(x3, constant_3.into());
        let c4 = B::float_mul_scalar(x, constant_4.into());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

        // `sech` holds sech^2(inner1) = 1 - tanh^2(inner1).
        let sech = B::float_powi_scalar(tanh.clone(), 2.into());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.into());

        // d(gelu)/dx = 0.5 * (1 + tanh(inner1)) + inner2 * sech^2(inner1).
        let y1 = B::float_mul_scalar(tanh, 0.5.into());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.into());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        // Compute in f32 and cast the result back to the input dtype at the end.
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
        // sigmoid(x) = 1 / (1 + exp(-x)) = exp(-log(1 + exp(-x)))
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.into(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // d(sigmoid)/dx expressed from the forward output: output * (1 - output).
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.into()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The alpha value that the tensor is multiplied with.
    /// * `beta` - The beta value that is added to the tensor
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        // Compute in f32 and cast the result back to the input dtype at the end.
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);

        // hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1)
        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta),
            0.0.into(),
            1.0.into(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        //
        // This simplifies to:
        // -max_derive - (z-1)/z if x is >= 0
        // -max_derive + (z-1)/z if x is < 0

        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0f32.into()); // -x < 0 or x >= 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());
        // `sign` is -1 where x >= 0 and +1 where x < 0, encoding the two cases above.
        let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}