burn_backend/backend/ops/activation.rs
1use crate::tensor::FloatTensor;
2use crate::{Backend, Scalar, TensorMetadata};
3use core::f64::consts::SQRT_2;
4
5/// Activation function operations.
6///
7/// This trait let backend implementations override activation functions for better performance.
8pub trait ActivationOps<B: Backend> {
9 /// Applies the LeakyReLU activation function.
10 ///
11 /// # Arguments
12 ///
13 /// * `tensor` - The tensor.
14 /// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with.
15 ///
16 /// # Returns
17 ///
18 /// The output tensor.
19 fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {
20 let mask = B::float_lower_elem(tensor.clone(), 0f32.into());
21 let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope);
22
23 // Update the tensor where the values are `< 0` by `tensor * negative_slope`.
24 B::float_mask_where(tensor, mask, scaled_tensor)
25 }
26
27 /// Applies the ReLU activation function.
28 ///
29 /// # Arguments
30 ///
31 /// * `tensor` - The tensor.
32 ///
33 /// # Returns
34 ///
35 /// The output tensor.
36 fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
37 let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into());
38
39 B::float_mask_fill(tensor, mask, 0f32.into())
40 }
41
42 /// Applies the ReLU activation function backward.
43 ///
44 /// # Arguments
45 ///
46 /// * `output` - The output tensor.
47 ///
48 /// # Returns
49 ///
50 /// The gradient.
51 fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
52 let mask = B::float_lower_equal_elem(output, 0f32.into());
53
54 B::float_mask_fill(grad, mask, 0.into())
55 }
56
57 /// Applies the Gelu activation function.
58 ///
59 /// # Arguments
60 ///
61 /// * `tensor` - The tensor.
62 ///
63 /// # Returns
64 ///
65 /// The output tensor.
66 fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
67 let x = B::float_div_scalar(tensor.clone(), SQRT_2.into());
68 let x = B::float_erf(x);
69 let x = B::float_add_scalar(x, 1f32.into());
70 let x = B::float_mul(tensor, x);
71
72 B::float_div_scalar(x, 2f32.into())
73 }
74 /// Applies the PReLu activation function.
75 /// # Arguments
76 /// * `tensor` - The input tensor
77 /// * `alpha` - The weight tensor
78 fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
79 let mask = B::float_lower_elem(tensor.clone(), 0f32.into());
80 let scaled_tensor = B::float_mul(tensor.clone(), alpha);
81 B::float_mask_where(tensor, mask, scaled_tensor)
82 }
83
84 /// Applies the Gelu activation function backward.
85 ///
86 /// # Arguments
87 ///
88 /// * `x` - The tensor.
89 /// * `grad` - The gradient.
90 ///
91 /// # Returns
92 ///
93 /// The output tensor.
94 fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
95 // Derivative of the approximate gelu implementation based on tanh.
96
97 let constant_1 = 0.0356774;
98 let constant_2 = 0.797885;
99 let constant_3 = 0.0535161;
100 let constant_4 = 0.398942;
101
102 let x3 = B::float_powi_scalar(x.clone(), 3.into());
103
104 let c1 = B::float_mul_scalar(x3.clone(), constant_1.into());
105 let c2 = B::float_mul_scalar(x.clone(), constant_2.into());
106 let c3 = B::float_mul_scalar(x3, constant_3.into());
107 let c4 = B::float_mul_scalar(x, constant_4.into());
108
109 let inner1 = B::float_add(c1, c2);
110 let inner2 = B::float_add(c3, c4);
111
112 let tanh = B::float_tanh(inner1);
113
114 let sech = B::float_powi_scalar(tanh.clone(), 2.into());
115 let sech = B::float_neg(sech);
116 let sech = B::float_add_scalar(sech, 1.into());
117
118 let y1 = B::float_mul_scalar(tanh, 0.5.into());
119 let y2 = B::float_mul(inner2, sech);
120 let y2 = B::float_add_scalar(y2, 0.5.into());
121 let y = B::float_add(y1, y2);
122
123 B::float_mul(y, grad)
124 }
125
126 /// Applies the Sigmoid activation function.
127 ///
128 /// # Arguments
129 ///
130 /// * `tensor` - The tensor.
131 ///
132 /// # Returns
133 ///
134 /// The output tensor.
135 fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
136 let dtype = tensor.dtype();
137 let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
138 let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
139 B::float_exp(B::float_neg(tensor_full)),
140 1.0.into(),
141 ))));
142
143 B::float_cast(tensor_tmp, dtype.into())
144 }
145
146 /// Applies the Sigmoid activation function backward.
147 ///
148 /// # Arguments
149 ///
150 /// * `output` - The output tensor of the sigmoid function.
151 /// * `grad` - The gradient.
152 ///
153 /// # Returns
154 ///
155 /// The output tensor.
156 fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
157 let value = B::float_mul(
158 output.clone(),
159 B::float_add_scalar(B::float_neg(output), 1.0.into()),
160 );
161 B::float_mul(value, grad)
162 }
163
164 /// Applies the hard Sigmoid activation function.
165 ///
166 /// # Arguments
167 ///
168 /// * `tensor` - The tensor.
169 /// * `alpha` - The alpha value that the tensor is multiplied with.
170 /// * `beta` - The beta value that is added to the tensor
171 ///
172 /// # Returns
173 ///
174 /// The output tensor.
175 fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {
176 let dtype = tensor.dtype();
177 let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
178
179 let tensor_tmp = B::float_clamp(
180 B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta),
181 0.0.into(),
182 1.0.into(),
183 );
184
185 B::float_cast(tensor_tmp, dtype.into())
186 }
187
188 /// Applies the LogSigmoid activation function.
189 ///
190 /// # Arguments
191 ///
192 /// * `tensor` - The tensor.
193 ///
194 /// # Returns
195 ///
196 /// The output tensor.
197 fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
198 // To avoid overflow, we use the log-sum-exp trick.
199 //
200 // ```ignore
201 // log(sigmoid(x)) = log(1/(1 + exp(-x)))
202 // = log(1) - log(1 + exp(-x))
203 // = -log(1 + exp(-x))
204 // = -log(exp(0) + exp(-x))
205 // ```
206 // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
207 // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
208 // following equivalence:
209 // ```ignore
210 // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
211 // ```
212 //
213 // This extends the range of values for which we obtain accurate results.
214
215 // max(-x, 0)
216 let tensor_neg = B::float_neg(tensor);
217 let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into());
218 let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());
219 let max_elem_neg = B::float_neg(max_elem.clone());
220
221 // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
222 let z = B::float_add(
223 B::float_exp(max_elem_neg.clone()),
224 B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
225 );
226
227 // -max(-x, 0) - log(-z)
228 B::float_sub(max_elem_neg, B::float_log(z))
229 }
230
231 /// Applies the LogSigmoid activation function backward.
232 ///
233 /// # Arguments
234 ///
235 /// * `x` - The input tensor.
236 /// * `grad` - The gradient.
237 ///
238 /// # Returns
239 ///
240 /// The output gradient.
241 fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
242 // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) - exp(-x - max(-x, 0)))) is
243 // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
244 // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
245 //
246 // This simplifies to:
247 // -max_derive - (z-1)/z if x is >= 0
248 // -max_derive + (z-1)/z if x is < 0
249
250 let shape = x.shape();
251 let dtype = x.dtype();
252 let device = B::float_device(&x);
253
254 // max(-x, 0)
255 let x_neg = B::float_neg(x);
256 let mask = B::float_lower_elem(x_neg.clone(), 0f32.into()); // -x < 0 or x >= 0
257 let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());
258
259 // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
260 let z = B::float_add(
261 B::float_exp(B::float_neg(max_elem.clone())),
262 B::float_exp(B::float_sub(x_neg, max_elem)),
263 );
264
265 // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0
266 let ones = B::float_ones(shape, &device, dtype.into());
267 let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());
268 let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());
269
270 // grad * (max_derive - sign * (1 - (1 / z)))
271 B::float_mul(
272 grad,
273 B::float_sub(
274 max_derive,
275 B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
276 ),
277 )
278 }
279}