Skip to main content

burn_backend/backend/ops/
activation.rs

1use crate::tensor::FloatTensor;
2use crate::{Backend, Scalar, TensorMetadata, get_device_settings};
3use core::f64::consts::SQRT_2;
4
5/// Activation function operations.
6///
7/// This trait let backend implementations override activation functions for better performance.
8pub trait ActivationOps<B: Backend> {
9    /// Applies the LeakyReLU activation function.
10    ///
11    /// # Arguments
12    ///
13    /// * `tensor` - The tensor.
14    /// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with.
15    ///
16    /// # Returns
17    ///
18    /// The output tensor.
19    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {
20        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
21        let mask = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
22        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope);
23
24        // Update the tensor where the values are `< 0` by `tensor * negative_slope`.
25        B::float_mask_where(tensor, mask, scaled_tensor)
26    }
27
28    /// Applies the ReLU activation function.
29    ///
30    /// # Arguments
31    ///
32    /// * `tensor` - The tensor.
33    ///
34    /// # Returns
35    ///
36    /// The output tensor.
37    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
38        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
39        let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into(), bool_dtype);
40
41        B::float_mask_fill(tensor, mask, 0f32.into())
42    }
43
44    /// Applies the ReLU activation function backward.
45    ///
46    /// # Arguments
47    ///
48    /// * `output` - The output tensor.
49    ///
50    /// # Returns
51    ///
52    /// The gradient.
53    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
54        let bool_dtype = get_device_settings::<B>(&B::float_device(&output)).bool_dtype;
55        let mask = B::float_lower_equal_elem(output, 0f32.into(), bool_dtype);
56
57        B::float_mask_fill(grad, mask, 0.into())
58    }
59
60    /// Applies the Gelu activation function.
61    ///
62    /// # Arguments
63    ///
64    /// * `tensor` - The tensor.
65    ///
66    /// # Returns
67    ///
68    /// The output tensor.
69    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
70        let x = B::float_div_scalar(tensor.clone(), SQRT_2.into());
71        let x = B::float_erf(x);
72        let x = B::float_add_scalar(x, 1f32.into());
73        let x = B::float_mul(tensor, x);
74
75        B::float_div_scalar(x, 2f32.into())
76    }
77    /// Applies the PReLu activation function.
78    /// # Arguments
79    /// * `tensor` - The input tensor
80    /// * `alpha` - The weight tensor
81    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
82        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
83        let mask = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
84        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
85        B::float_mask_where(tensor, mask, scaled_tensor)
86    }
87
88    /// Applies the Gelu activation function backward.
89    ///
90    /// # Arguments
91    ///
92    /// * `x` - The tensor.
93    /// * `grad` - The gradient.
94    ///
95    /// # Returns
96    ///
97    /// The output tensor.
98    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
99        // Derivative of the approximate gelu implementation based on tanh.
100
101        let constant_1 = 0.0356774;
102        let constant_2 = 0.797885;
103        let constant_3 = 0.0535161;
104        let constant_4 = 0.398942;
105
106        let x3 = B::float_powi_scalar(x.clone(), 3.into());
107
108        let c1 = B::float_mul_scalar(x3.clone(), constant_1.into());
109        let c2 = B::float_mul_scalar(x.clone(), constant_2.into());
110        let c3 = B::float_mul_scalar(x3, constant_3.into());
111        let c4 = B::float_mul_scalar(x, constant_4.into());
112
113        let inner1 = B::float_add(c1, c2);
114        let inner2 = B::float_add(c3, c4);
115
116        let tanh = B::float_tanh(inner1);
117
118        let sech = B::float_powi_scalar(tanh.clone(), 2.into());
119        let sech = B::float_neg(sech);
120        let sech = B::float_add_scalar(sech, 1.into());
121
122        let y1 = B::float_mul_scalar(tanh, 0.5.into());
123        let y2 = B::float_mul(inner2, sech);
124        let y2 = B::float_add_scalar(y2, 0.5.into());
125        let y = B::float_add(y1, y2);
126
127        B::float_mul(y, grad)
128    }
129
130    /// Applies the Sigmoid activation function.
131    ///
132    /// # Arguments
133    ///
134    /// * `tensor` - The tensor.
135    ///
136    /// # Returns
137    ///
138    /// The output tensor.
139    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
140        let dtype = tensor.dtype();
141        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
142        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
143            B::float_exp(B::float_neg(tensor_full)),
144            1.0.into(),
145        ))));
146
147        B::float_cast(tensor_tmp, dtype.into())
148    }
149
150    /// Applies the Sigmoid activation function backward.
151    ///
152    /// # Arguments
153    ///
154    /// * `output` - The output tensor of the sigmoid function.
155    /// * `grad` - The gradient.
156    ///
157    /// # Returns
158    ///
159    /// The output tensor.
160    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
161        let value = B::float_mul(
162            output.clone(),
163            B::float_add_scalar(B::float_neg(output), 1.0.into()),
164        );
165        B::float_mul(value, grad)
166    }
167
168    /// Applies the hard Sigmoid activation function.
169    ///
170    /// # Arguments
171    ///
172    /// * `tensor` - The tensor.
173    /// * `alpha` - The alpha value that the tensor is multiplied with.
174    /// * `beta` - The beta value that is added to the tensor
175    ///
176    /// # Returns
177    ///
178    /// The output tensor.
179    fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {
180        let dtype = tensor.dtype();
181        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
182
183        let tensor_tmp = B::float_clamp(
184            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta),
185            0.0.into(),
186            1.0.into(),
187        );
188
189        B::float_cast(tensor_tmp, dtype.into())
190    }
191
192    /// Applies the LogSigmoid activation function.
193    ///
194    /// # Arguments
195    ///
196    /// * `tensor` - The tensor.
197    ///
198    /// # Returns
199    ///
200    /// The output tensor.
201    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
202        // To avoid overflow, we use the log-sum-exp trick.
203        //
204        // ```ignore
205        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
206        //                 = log(1) - log(1 + exp(-x))
207        //                 = -log(1 + exp(-x))
208        //                 = -log(exp(0) + exp(-x))
209        // ```
210        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
211        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
212        // following equivalence:
213        // ```ignore
214        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
215        // ```
216        //
217        // This extends the range of values for which we obtain accurate results.
218
219        // max(-x, 0)
220        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
221        let tensor_neg = B::float_neg(tensor);
222        let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into(), bool_dtype);
223        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());
224        let max_elem_neg = B::float_neg(max_elem.clone());
225
226        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
227        let z = B::float_add(
228            B::float_exp(max_elem_neg.clone()),
229            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
230        );
231
232        // -max(-x, 0) - log(-z)
233        B::float_sub(max_elem_neg, B::float_log(z))
234    }
235
236    /// Applies the LogSigmoid activation function backward.
237    ///
238    /// # Arguments
239    ///
240    /// * `x` - The input tensor.
241    /// * `grad` - The gradient.
242    ///
243    /// # Returns
244    ///
245    /// The output gradient.
246    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
247        // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) - exp(-x - max(-x, 0)))) is
248        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
249        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
250        //
251        // This simplifies to:
252        // -max_derive - (z-1)/z if x is >= 0
253        // -max_derive + (z-1)/z if x is < 0
254
255        let shape = x.shape();
256        let dtype = x.dtype();
257        let device = B::float_device(&x);
258        let bool_dtype = get_device_settings::<B>(&device).bool_dtype;
259
260        // max(-x, 0)
261        let x_neg = B::float_neg(x);
262        let mask = B::float_lower_elem(x_neg.clone(), 0f32.into(), bool_dtype); // -x < 0 or x >= 0
263        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());
264
265        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
266        let z = B::float_add(
267            B::float_exp(B::float_neg(max_elem.clone())),
268            B::float_exp(B::float_sub(x_neg, max_elem)),
269        );
270
271        // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0
272        let ones = B::float_ones(shape, &device, dtype.into());
273        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());
274        let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());
275
276        // grad * (max_derive - sign * (1 - (1 / z)))
277        B::float_mul(
278            grad,
279            B::float_sub(
280                max_derive,
281                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
282            ),
283        )
284    }
285}