burn_tensor/tensor/ops/activation.rs
use crate::TensorMetadata;
use crate::{ElementConversion, backend::Backend};
use core::f64::consts::SQRT_2;

use super::FloatTensor;

/// Activation function operations.
///
/// This trait lets backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
    /// Applies the LeakyReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `negative_slope` - The slope by which values smaller than 0 are multiplied.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: super::FloatElem<B>) -> FloatTensor<B> {
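        // Element-wise, this computes leaky_relu(x) = x for x >= 0 and negative_slope * x
        // for x < 0.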
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope.elem());

        // Update the tensor where the values are `< 0` by `tensor * negative_slope`.
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the ReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {
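        // relu(x) = max(x, 0): entries where x <= 0 are filled with zero.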
        let mask = B::float_lower_equal_elem(tensor.clone(), 0.elem());

        B::float_mask_fill(tensor, mask, 0.elem())
    }

    /// Applies the ReLU activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
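        // The incoming gradient is zeroed wherever the forward output was <= 0. Masking on
        // the saved output is equivalent to masking on the input, since relu(x) > 0 exactly
        // when x > 0.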
        let mask = B::float_lower_equal_elem(output, 0.elem());

        B::float_mask_fill(grad, mask, 0.elem())
    }

    /// Applies the Gelu activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {
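        // The steps below evaluate the exact (erf-based) GELU:
        //
        // ```ignore
        // gelu(x) = x * (1 + erf(x / sqrt(2))) / 2
        // ```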
        let x = B::float_div_scalar(tensor.clone(), SQRT_2.elem());
        let x = B::float_erf(x);
        let x = B::float_add_scalar(x, 1i32.elem());
        let x = B::float_mul(tensor, x);

        B::float_div_scalar(x, 2i32.elem())
    }

    /// Applies the PReLU activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor.
    /// * `alpha` - The weight tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {
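        // prelu(x) = x for x >= 0 and alpha * x for x < 0, with the learned weight `alpha`
        // multiplied element-wise with the input.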
        let mask = B::float_lower_elem(tensor.clone(), 0.elem());
        let scaled_tensor = B::float_mul(tensor.clone(), alpha);
        B::float_mask_where(tensor, mask, scaled_tensor)
    }

    /// Applies the Gelu activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Derivative of the approximate gelu implementation based on tanh.

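        // The constants below come from the tanh-based approximation
        //
        // ```ignore
        // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        // ```
        //
        // with sqrt(2/pi) ~= 0.797885:
        //   constant_1 = 0.044715 * sqrt(2/pi)          (cubic coefficient of the inner term)
        //   constant_2 = sqrt(2/pi)                      (linear coefficient of the inner term)
        //   constant_3 = 3 * 0.044715 * sqrt(2/pi) / 2   (cubic coefficient of x/2 * d(inner)/dx)
        //   constant_4 = sqrt(2/pi) / 2                  (linear coefficient of x/2 * d(inner)/dx)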
        let constant_1 = 0.0356774;
        let constant_2 = 0.797885;
        let constant_3 = 0.0535161;
        let constant_4 = 0.398942;

        let x3 = B::float_powi_scalar(x.clone(), 3.elem());

        let c1 = B::float_mul_scalar(x3.clone(), constant_1.elem());
        let c2 = B::float_mul_scalar(x.clone(), constant_2.elem());
        let c3 = B::float_mul_scalar(x3, constant_3.elem());
        let c4 = B::float_mul_scalar(x, constant_4.elem());

        let inner1 = B::float_add(c1, c2);
        let inner2 = B::float_add(c3, c4);

        let tanh = B::float_tanh(inner1);

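        // sech^2(u) = 1 - tanh^2(u), built from the tanh computed above.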
        let sech = B::float_powi_scalar(tanh.clone(), 2.elem());
        let sech = B::float_neg(sech);
        let sech = B::float_add_scalar(sech, 1.elem());

        let y1 = B::float_mul_scalar(tanh, 0.5.elem());
        let y2 = B::float_mul(inner2, sech);
        let y2 = B::float_add_scalar(y2, 0.5.elem());
        let y = B::float_add(y1, y2);

        B::float_mul(y, grad)
    }

    /// Applies the Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
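        // Computed as exp(-log(1 + exp(-x))), which is algebraically equal to
        // 1 / (1 + exp(-x)). The intermediate values are computed in f32 and cast back to
        // the original dtype afterwards (presumably to preserve precision for
        // lower-precision float types).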
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, crate::FloatDType::F32);
        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(
            B::float_exp(B::float_neg(tensor_full)),
            1.0.elem(),
        ))));

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the Sigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `output` - The output tensor of the sigmoid function.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The gradient.
    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
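        // Uses the identity sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), evaluated from the
        // saved forward output instead of recomputing the sigmoid.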
        let value = B::float_mul(
            output.clone(),
            B::float_add_scalar(B::float_neg(output), 1.0.elem()),
        );
        B::float_mul(value, grad)
    }

    /// Applies the hard Sigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `alpha` - The value by which the tensor is multiplied.
    /// * `beta` - The value that is added to the tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn hard_sigmoid(
        tensor: FloatTensor<B>,
        alpha: super::FloatElem<B>,
        beta: super::FloatElem<B>,
    ) -> FloatTensor<B> {
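        // hard_sigmoid(x) = clamp(alpha * x + beta, 0, 1), computed in f32 and cast back to
        // the original dtype.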
        let dtype = tensor.dtype();
        let tensor_full = B::float_cast(tensor, crate::FloatDType::F32);

        let tensor_tmp = B::float_clamp(
            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha.elem()), beta.elem()),
            0.0.elem(),
            1.0.elem(),
        );

        B::float_cast(tensor_tmp, dtype.into())
    }

    /// Applies the LogSigmoid activation function.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The output tensor.
    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {
        // To avoid overflow, we use the log-sum-exp trick.
        //
        // ```ignore
        // log(sigmoid(x)) = log(1/(1 + exp(-x)))
        //                 = log(1) - log(1 + exp(-x))
        //                 = -log(1 + exp(-x))
        //                 = -log(exp(0) + exp(-x))
        // ```
        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we
        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the
        // following equivalence:
        // ```ignore
        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
        // ```
        //
        // This extends the range of values for which we obtain accurate results.

        // max(-x, 0)
        let tensor_neg = B::float_neg(tensor);
        let mask = B::float_lower_elem(tensor_neg.clone(), 0.elem());
        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0.elem());
        let max_elem_neg = B::float_neg(max_elem.clone());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(max_elem_neg.clone()),
            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),
        );

        // -max(-x, 0) - log(z)
        B::float_sub(max_elem_neg, B::float_log(z))
    }

    /// Applies the LogSigmoid activation function backward.
    ///
    /// # Arguments
    ///
    /// * `x` - The input tensor.
    /// * `grad` - The gradient.
    ///
    /// # Returns
    ///
    /// The output gradient.
    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        // Writing m = max(-x, 0) and z = exp(-m) + exp(-x - m), the forward pass is
        // log_sigmoid(x) = -m - log(z).
        //
        // Differentiating with respect to x and simplifying gives:
        //   d/dx log_sigmoid(x) = (z - 1) / z if x >= 0 (where m = 0)
        //   d/dx log_sigmoid(x) = 1 / z       if x < 0  (where m = -x)
        //
        // Both cases are computed below as max_derive - sign * (1 - 1/z), scaled by `grad`.

        let shape = x.shape();
        let dtype = x.dtype();
        let device = B::float_device(&x);

        // max(-x, 0)
        let x_neg = B::float_neg(x);
        let mask = B::float_lower_elem(x_neg.clone(), 0.elem()); // -x < 0, i.e. x > 0
        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0.elem());

        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))
        let z = B::float_add(
            B::float_exp(B::float_neg(max_elem.clone())),
            B::float_exp(B::float_sub(x_neg, max_elem)),
        );

        // `max_derive` is 1 where x <= 0 and 0 elsewhere (the magnitude of d/dx max(-x, 0));
        // `sign` is 1 where x <= 0 and -1 elsewhere.
        let ones = B::float_ones(shape, &device, dtype.into());
        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0.elem());
        let sign = B::float_mask_fill(ones.clone(), mask, (-1).elem());

        // grad * (max_derive - sign * (1 - (1 / z)))
        B::float_mul(
            grad,
            B::float_sub(
                max_derive,
                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),
            ),
        )
    }
}