burn_backend/backend/ops/activation.rs
use crate::tensor::FloatTensor;
use crate::{Backend, Scalar, TensorMetadata, get_device_settings};
use core::f64::consts::SQRT_2;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The slope by which values smaller than zero are multiplied.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let mask = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope);

        // Replace the values that are `< 0` with `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into(), bool_dtype);

        B::float_mask_fill(tensor, mask, 0f32.into())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
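        // The gradient is zero wherever the forward output was `<= 0`.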
        let bool_dtype = get_device_settings::<B>(&B::float_device(&output)).bool_dtype;
        let mask = B::float_lower_equal_elem(output, 0f32.into(), bool_dtype);

        B::float_mask_fill(grad, mask, 0.into())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
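        // Exact GELU using the error function:
        //
        // ```ignore
        // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        // ```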
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.into());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1f32.into());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2f32.into())
    }

    /// Applies the PReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `alpha` - The weight tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
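        // Where the input is `< 0`, scale it element-wise by `alpha`; otherwise pass it through.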
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let mask = B::float_lower_elem(tensor.clone(), 0f32.into(), bool_dtype);
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh.
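        //
        // ```ignore
        // gelu(x) ~= 0.5 * x * (1 + tanh(u)),  where u = sqrt(2/pi) * (x + 0.044715 * x^3)
        //
        // d/dx gelu(x) = 0.5 * (1 + tanh(u)) + (0.5 * x * du/dx) * (1 - tanh(u)^2)
        // ```
        //
        // The constants below are the coefficients of the two polynomials involved:
        //   constant_1 = sqrt(2/pi) * 0.044715        (x^3 term of `inner1`, i.e. `u`)
        //   constant_2 = sqrt(2/pi)                   (x   term of `inner1`)
        //   constant_3 = 1.5 * sqrt(2/pi) * 0.044715  (x^3 term of `inner2`, i.e. `0.5 * x * du/dx`)
        //   constant_4 = 0.5 * sqrt(2/pi)             (x   term of `inner2`)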

        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::float_powi_scalar(x.clone(), 3.into());

        let c1 = B::float_mul_scalar(x3.clone(), constant_1.into());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.into());
        let c3 = B::float_mul_scalar(x3, constant_3.into());
        let c4 = B::float_mul_scalar(x, constant_4.into());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

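        // sech(inner1)^2 = 1 - tanh(inner1)^2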
        let sech = B::float_powi_scalar(tanh.clone(), 2.into());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.into());

        let y1 = B::float_mul_scalar(tanh, 0.5.into());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.into());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
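        // Evaluated in f32 as:
        //
        // ```ignore
        // sigmoid(x) = 1 / (1 + exp(-x)) = exp(-log(1 + exp(-x)))
        // ```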
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.into(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
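        // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = output * (1 - output)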
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.into()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The value by which the tensor is multiplied.
    /// * `beta` - The value added to the scaled tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {
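        // hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1), evaluated in f32.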
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);

        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta),
            0.0.into(),
            1.0.into(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        //
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        //
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let bool_dtype = get_device_settings::<B>(&B::float_device(&tensor)).bool_dtype;
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into(), bool_dtype);
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the softmax function along the given dimension.
    ///
    /// Uses the max-shift trick for numerical stability: the per-row `max` is detached
    /// so no gradient flows back through it (the shift is a numerical-stability
    /// transformation, not part of the function).
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension along which softmax is computed.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn softmax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
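        // ```ignore
        // softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
        // ```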
        let max = B::float_max_dim(B::float_detach(tensor.clone()), dim);
        let shifted = B::float_sub(tensor, max);
        let exp = B::float_exp(shifted);
        let sum = B::float_sum_dim(exp.clone(), dim);
        B::float_div(exp, sum)
    }

    /// Applies the log-softmax function along the given dimension.
    ///
    /// Computed via the log-sum-exp trick with a detached max-shift for numerical
    /// stability.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension along which log-softmax is computed.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_softmax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
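        // ```ignore
        // log_softmax(x)_i = (x_i - max(x)) - log(sum_j exp(x_j - max(x)))
        // ```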
        let max = B::float_max_dim(B::float_detach(tensor.clone()), dim);
        let shifted = B::float_sub(tensor, max);
        let log_sum_exp = B::float_log(B::float_sum_dim(B::float_exp(shifted.clone()), dim));
        B::float_sub(shifted, log_sum_exp)
    }

    /// Applies the softmin function along the given dimension.
    ///
    /// Equivalent to `softmax(-tensor, dim)`.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `dim` - The dimension along which softmin is computed.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn softmin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {
        Self::softmax(B::float_neg(tensor), dim)
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        //
        // This simplifies to:
        // -max_derive - (z-1)/z if x is >= 0
        // -max_derive + (z-1)/z if x is < 0

        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);
        let bool_dtype = get_device_settings::<B>(&device).bool_dtype;

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0f32.into(), bool_dtype); // -x < 0 or x >= 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
343}