burn_backend/backend/ops/activation.rs
use crate::element::ElementConversion;
use crate::tensor::{FloatElem, FloatTensor};
use crate::{Backend, TensorMetadata};
use core::f64::consts::SQRT_2;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The slope that values smaller than 0 are multiplied by.
    ///
    /// # Returns
    ///
    /// The output tensor.
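    ///
    /// The default implementation computes `y = x` for `x >= 0` and `y = negative_slope * x`
    /// for `x < 0`.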
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: FloatElem<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope.elem());

        // Update the tensor where the values are `< 0` by `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
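    ///
    /// The default implementation computes `y = max(x, 0)` by zeroing every element where `x <= 0`.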
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_equal_elem(tensor.clone(), 0.elem());

        B::float_mask_fill(tensor, mask, 0.elem())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
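    ///
    /// The default implementation passes `grad` through where `output > 0` and zeroes it
    /// where `output <= 0`.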
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_equal_elem(output, 0.elem());

        B::float_mask_fill(grad, mask, 0.elem())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
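    ///
    /// The default implementation uses the exact (erf-based) formulation
    /// `gelu(x) = x / 2 * (1 + erf(x / sqrt(2)))`.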
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1i32.elem());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2i32.elem())
    }

    /// Applies the PReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `alpha` - The weight (slope) tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
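    ///
    /// The default implementation computes `y = x` for `x >= 0` and `y = alpha * x` for `x < 0`,
    /// multiplying `alpha` elementwise (typically broadcast against `tensor`).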
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh.

        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

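        // These constants appear to follow from the tanh approximation
        //   gelu(x) ≈ 0.5 * x * (1 + tanh(u)), with u = sqrt(2 / PI) * (x + 0.044715 * x^3):
        // constant_2 = sqrt(2 / PI), constant_1 = 0.044715 * constant_2,
        // constant_4 = 0.5 * constant_2, constant_3 = 3 * 0.044715 * constant_4.
        // The derivative computed below is
        //   0.5 * (1 + tanh(u)) + (constant_4 * x + constant_3 * x^3) * (1 - tanh(u)^2),
        // where `inner1` plays the role of `u` and `inner2` of `constant_4 * x + constant_3 * x^3`.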
        let x3 = B::float_powi_scalar(x.clone(), 3.elem());

        let c1 = B::float_mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::float_mul_scalar(x3, constant_3.elem());
        let c4 = B::float_mul_scalar(x, constant_4.elem());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

        let sech = B::float_powi_scalar(tanh.clone(), 2.elem());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.elem());

        let y1 = B::float_mul_scalar(tanh, 0.5.elem());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.elem());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        let dtype = tensor.dtype();
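        // The intermediate computation is done in f32 and cast back to the original dtype
        // afterwards, presumably to limit precision loss for low-precision float types.
        // The expression below uses the identity exp(-log(1 + exp(-x))) = 1 / (1 + exp(-x)).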
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.elem(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output tensor.
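    ///
    /// The default implementation uses `sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))`,
    /// i.e. it returns `grad * output * (1 - output)`.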
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.elem()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The alpha value that the tensor is multiplied by.
    /// * `beta` - The beta value that is added to the tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
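    ///
    /// The default implementation computes `y = clamp(alpha * x + beta, 0, 1)`.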
    fn hard_sigmoid(
        tensor: FloatTensor<B>,
        alpha: FloatElem<B>,
        beta: FloatElem<B>,
    ) -> FloatTensor<B> {
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);

        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha.elem()), beta.elem()),
            0.0.elem(),
            1.0.elem(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // The derivative of -max(-x, 0) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))) is
        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z
        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0)) and max_derive is the derivative of max(-x, 0).
        //
        // This simplifies to:
        //   (z - 1) / z if x >= 0
        //   1 / z       if x < 0
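        //
        // In both cases this equals 1 - sigmoid(x), so the returned gradient is
        // grad * (1 - sigmoid(x)), computed in the numerically stable form below.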

        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0, i.e. x > 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // The derivative of max(-x, 0) is -1 if x < 0 and 0 if x >= 0; `max_derive` holds its
        // negation (1 or 0), and `sign` is -1 where x > 0 and 1 elsewhere.
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}