//! Activation functions.
//!
//! Source: `burn_tensor/tensor/activation/base.rs`

use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, TensorPrimitive, check, s};

5/// Applies the rectified linear unit function element-wise
6/// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).
7///
8#[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")]
9#[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")]
10pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
11    tensor.relu()
12}
13
14/// Applies the leaky rectified linear unit function element-wise.
15///
16#[cfg_attr(
17    doc,
18    doc = r#"
19$$
20\text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\)
21$$
22
23or
24
25$$
26\text{LeakyReLU}(x) =
27 \begin{cases}
28     x & \text{if } x \geq 0 \newline
29     \text{negative\\_slope} \cdot x & \text{otherwise}
30 \end{cases}
31$$
32"#
33)]
34#[cfg_attr(
35    not(doc),
36    doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x if x < 0`"
37)]
38pub fn leaky_relu<const D: usize, B: Backend>(
39    tensor: Tensor<B, D>,
40    negative_slope: f64,
41) -> Tensor<B, D> {
42    Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu(
43        tensor.primitive.tensor(),
44        crate::ElementConversion::elem(negative_slope),
45    )))
46}
47
48/// Applies the Gaussian Error Linear Units function as described in the paper
49/// [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
50///
51#[cfg_attr(
52    doc,
53    doc = r#"
54$$
55\text{GELU}(x)
56= x \cdot \Phi(x)
57= x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
58$$
59
60where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.
61"#
62)]
63#[cfg_attr(
64    not(doc),
65    doc = r#"
66`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`
67
68where `Φ(x)` is the cumulative distribution function for the Gaussian distribution.
69"#
70)]
71pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
72    Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor())))
73}
74
75/// Applies Parametric ReLu activation function as described in the paper
76/// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).
77///
78/// - The tensor is assumed to be of shape `[batch_size, channels, ...]`.
79/// - `alpha` is assumed to be of shape `[channels]` or `[1]`.
80///
81#[cfg_attr(
82    doc,
83    doc = r#"
84$$
85\text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\)
86$$
87
88or
89
90$$
91\text{PReLU}(x) =
92 \begin{cases}
93     x & \text{if } x \geq 0 \newline
94     \alpha x & \text{otherwise}
95 \end{cases}
96$$
97"#
98)]
99#[cfg_attr(not(doc), doc = "`PReLu(x) = max(0,x) + alpha * min(0,x)`")]
100pub fn prelu<const D: usize, B: Backend>(
101    tensor: Tensor<B, D>,
102    alpha: Tensor<B, 1>,
103) -> Tensor<B, D> {
104    check!(TensorCheck::check_prelu_shape::<D>(
105        &tensor.shape(),
106        &alpha.shape()
107    ));
108
109    let weight = if alpha.dims()[0] == 1 {
110        // if there is only 1 weight, then reshape it to (1,1,1... D times) so that the rank is D
111        alpha.reshape([1; D])
112    } else {
113        // D>=2 because the case where D==1 and num_weights >1 is handled by check function
114        // there is more than 1 weight and rank is more than 2
115        let num_weights = alpha.dims()[0];
116        let mut s = [1; D];
117        s[1] = num_weights;
118        // reshape the weights to (1, channels,1 ...)
119        alpha.reshape(s)
120    };
121
122    Tensor::from_primitive(TensorPrimitive::Float(B::prelu(
123        tensor.primitive.tensor(),
124        weight.primitive.tensor(),
125    )))
126}
127
128/// Applies the softmax function on the input tensor along the given dimension.
129///
130#[cfg_attr(
131    doc,
132    doc = r#"
133$$
134\text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
135$$
136"#
137)]
138#[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")]
139///
140/// # Arguments
141/// - `dim`: the dimension along which Softmax will be computed.
142///
143/// # Panics
144/// - If `dim` is outside [0, D)
145pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
146    check!(TensorCheck::dim_ops::<D>("softmax", dim));
147
148    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
149    let tensor = tensor.exp();
150    let tensor_tmp = tensor.clone().sum_dim(dim);
151
152    tensor.div(tensor_tmp)
153}
154
155/// Applies the softmin function on the input tensor along the given dimension.
156///
157#[cfg_attr(
158    doc,
159    doc = r#"
160$$
161\text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
162$$
163"#
164)]
165#[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j)`")]
166///
167/// # Arguments
168/// - `dim`: the dimension along which Softmax will be computed.
169///
170/// # Panics
171/// - If `dim` is outside [0, D)
172pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
173    check!(TensorCheck::dim_ops::<D>("softmin", dim));
174    softmax(tensor.neg(), dim)
175}
176
177/// Applies the SoftPlus function element-wise.
178///
179#[cfg_attr(
180    doc,
181    doc = r#"
182$$
183\text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
184$$
185"#
186)]
187#[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")]
188///
189/// The SoftPlus function is a smooth approximation of the ReLU function.
190pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
191    let tensor = (tensor.mul_scalar(beta).exp() + 1).log();
192    tensor.div_scalar(beta)
193}
194
195/// Applies the "quiet softmax" function on the input tensor along the given dimension.
196///
197/// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html).
198///
199/// This function is similar to the softmax function, but it allows for "no selection" when
200/// all the outputs are close to zero.
201///
202#[cfg_attr(
203    doc,
204    doc = r#"
205$$
206\text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
207$$
208"#
209)]
210#[cfg_attr(
211    not(doc),
212    doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`"
213)]
214///
215/// # Arguments
216/// - `dim`: the dimension along which Softmax will be computed.
217///
218/// # Panics
219/// - If `dim` is outside [0, D)
220pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
221    check!(TensorCheck::dim_ops::<D>("softmax", dim));
222
223    let max_vals = tensor.clone().detach().max_dim(dim);
224    let exp_x = (tensor - max_vals.clone()).exp();
225    let sum_exp = exp_x.clone().sum_dim(dim);
226
227    exp_x.div(sum_exp + max_vals.neg().exp())
228}
229
230/// Applies the log softmax function on the input tensor along the given dimension.
231///
232#[cfg_attr(
233    doc,
234    doc = r#"
235$$
236\text{log\\_softmax}\(x_i\)
237= \log\left(\text{softmax}\(x_i\)\right)
238= \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
239$$
240"#
241)]
242#[cfg_attr(
243    not(doc),
244    doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`"
245)]
246///
247/// # Arguments
248/// - `dim`: the dimension along which Softmax will be computed.
249///
250/// # Panics
251/// - If `dim` is outside [0, D)
252pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
253    check!(TensorCheck::dim_ops::<D>("log softmax", dim));
254
255    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
256    let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();
257
258    tensor.sub(tensor_tmp)
259}
260
261/// Applies the sigmoid function element-wise.
262///
263#[cfg_attr(
264    doc,
265    doc = r#"
266$$
267\text{sigmoid}\(x\)
268= \sigma(x)
269= \frac{1}{1 + \exp(-x)}
270$$
271"#
272)]
273#[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")]
274pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
275    Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid(
276        tensor.primitive.tensor(),
277    )))
278}
279
280/// Applies the hard sigmoid function element-wise.
281///
282#[cfg_attr(
283    doc,
284    doc = r#"
285$$
286\text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta))
287$$
288"#
289)]
290#[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")]
291pub fn hard_sigmoid<const D: usize, B: Backend>(
292    tensor: Tensor<B, D>,
293    alpha: f64,
294    beta: f64,
295) -> Tensor<B, D> {
296    Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(
297        tensor.primitive.tensor(),
298        crate::ElementConversion::elem(alpha),
299        crate::ElementConversion::elem(beta),
300    )))
301}
302
303/// Applies the log sigmoid function element-wise.
304///
305#[cfg_attr(
306    doc,
307    doc = r#"
308$$
309\text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
310$$
311"#
312)]
313#[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")]
314pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
315    Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(
316        tensor.primitive.tensor(),
317    )))
318}
319
320/// Applies the SiLU function (also known as the swish function) element-wise.
321///
322#[cfg_attr(
323    doc,
324    doc = r#"
325$$
326\text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)}
327$$
328"#
329)]
330#[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")]
331pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
332    tensor.clone().mul(sigmoid(tensor))
333}
334
335/// Applies the hard swish function element-wise.
336///
337#[cfg_attr(
338    doc,
339    doc = r#"
340$$
341\text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
342$$
343"#
344)]
345#[cfg_attr(
346    not(doc),
347    doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`"
348)]
349pub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
350    tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5))
351}
352
353/// Applies the Mish function as described in the paper in
354/// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
355///
356#[cfg_attr(
357    doc,
358    doc = r#"
359$$
360\text{Mish}\(x\)
361= x \cdot \tanh(\text{Softplus}(x))
362= \tanh\left(\log\(1 + \exp\(x\)\)\right)
363$$
364"#
365)]
366#[cfg_attr(
367    not(doc),
368    doc = "`mish(x) = x * tanh(softplus(x)) = tanh(log(1 + exp(x)))`"
369)]
370pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
371    tensor.clone().mul(softplus(tensor, 1.0).tanh())
372}
373
374/// Applies the tanh function element-wise.
375pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
376    tensor.tanh()
377}
378
379/// Applies the gated linear unit function.
380///
381/// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.
382///
383/// **Note**:
384/// * The size of the input tensor along `dim` must be divisible by 2.
385///
386/// ### Arguments
387/// * `tensor` - The input tensor.
388///
389/// ### Returns
390/// * A tensor with the same shape as the input, except the size along `dim` is halved.
391pub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
392    // TODO: Handle negative indices with AsIndex for compatibility with Pytorch nn.GLU.
393
394    assert!(
395        tensor.dims()[dim].is_multiple_of(2),
396        "Input tensor along dimension {dim} must have an even size. N is divisible by 2."
397    );
398    let new_len = tensor.dims()[dim] / 2;
399
400    let a = tensor.clone().slice_dim(dim, s![0..new_len]);
401    let b = tensor.slice_dim(dim, s![new_len..new_len * 2]);
402
403    a.mul(sigmoid(b))
404}