
burn_tensor/tensor/activation/base.rs

use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, TensorPrimitive, check, s};

/// Applies the rectified linear unit function element-wise
/// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).
///
#[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")]
#[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")]
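///
/// # Example
///
/// A minimal usage sketch (not part of the original docs; assumes the usual
/// `burn_tensor::activation` re-export and any backend `B`):
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::relu};
///
/// fn relu_demo<B: Backend>(device: &B::Device) -> Tensor<B, 1> {
///     // Negative entries are clamped to zero; non-negative entries pass through unchanged.
///     let x = Tensor::<B, 1>::from_floats([-1.5, 0.0, 2.0], device);
///     relu(x) // ~[0.0, 0.0, 2.0]
/// }
/// ```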
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.relu()
}

/// Applies the leaky rectified linear unit function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\)
$$

or

$$
\text{LeakyReLU}(x) =
 \begin{cases}
     x & \text{if } x \geq 0 \newline
     \text{negative\\_slope} \cdot x & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x for x < 0`"
)]
pub fn leaky_relu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    negative_slope: f64,
) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu(
        tensor.primitive.tensor(),
        negative_slope.into(),
    )))
}

/// Applies the Gaussian Error Linear Units function as described in the paper
/// [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{GELU}(x)
= x \cdot \Phi(x)
= x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
$$

where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.
"#
)]
#[cfg_attr(
    not(doc),
    doc = r#"
`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`

where `Φ(x)` is the cumulative distribution function for the Gaussian distribution.
"#
)]
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor())))
}

/// Applies the tanh-based approximate GELU function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{GELU\_approx}(x)
= \frac{x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`GELU_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`"
)]
pub fn gelu_approximate<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    /// sqrt(2/π) precomputed as FRAC_2_SQRT_PI * FRAC_1_SQRT_2
    const SQRT_2_OVER_PI: f64 =
        core::f64::consts::FRAC_2_SQRT_PI * core::f64::consts::FRAC_1_SQRT_2;

    let x = tensor;
    let inner = x.clone() + x.clone().powf_scalar(3.0) * 0.044715;
    let inner = inner * SQRT_2_OVER_PI;
    (x.clone() * (inner.tanh() + 1)) * 0.5
}

/// Applies the Parametric ReLU activation function as described in the paper
/// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).
///
/// - The tensor is assumed to be of shape `[batch_size, channels, ...]`.
/// - `alpha` is assumed to be of shape `[channels]` or `[1]`.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\)
$$

or

$$
\text{PReLU}(x) =
 \begin{cases}
     x & \text{if } x \geq 0 \newline
     \alpha x & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`PReLU(x) = max(0,x) + alpha * min(0,x)`")]
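///
/// # Example
///
/// A minimal sketch (added for illustration; assumes the usual `burn_tensor::activation`
/// re-export) showing a per-channel `alpha` for an input of shape `[batch, channels, width]`:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::prelu};
///
/// fn prelu_demo<B: Backend>(device: &B::Device) -> Tensor<B, 3> {
///     // Input of shape [batch = 2, channels = 3, width = 8] with one slope per channel.
///     let x = Tensor::<B, 3>::zeros([2, 3, 8], device);
///     let alpha = Tensor::<B, 1>::from_floats([0.1, 0.2, 0.3], device);
///     prelu(x, alpha) // same shape as `x`
/// }
/// ```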
pub fn prelu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_prelu_shape::<D>(
        &tensor.shape(),
        &alpha.shape()
    ));

    let weight = if alpha.dims()[0] == 1 {
        // If there is only one weight, reshape it to [1; D] so that its rank matches the input.
        alpha.reshape([1; D])
    } else {
        // D >= 2 here: the case D == 1 with more than one weight is rejected by the check above.
        let num_weights = alpha.dims()[0];
        let mut s = [1; D];
        s[1] = num_weights;
        // Reshape the weights to [1, channels, 1, ...] so they broadcast along the channel dimension.
        alpha.reshape(s)
    };

    Tensor::from_primitive(TensorPrimitive::Float(B::prelu(
        tensor.primitive.tensor(),
        weight.primitive.tensor(),
    )))
}

/// Applies the softmax function on the input tensor along the given dimension.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")]
///
/// # Arguments
/// - `dim`: the dimension along which Softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
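///
/// # Example
///
/// A minimal sketch (added for illustration; assumes the usual `burn_tensor::activation`
/// re-export). For a rank-2 tensor of logits, `dim = 1` normalizes each row:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::softmax};
///
/// fn softmax_demo<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
///     let logits = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], device);
///     // Each row of the result sums to 1; the second row is uniform (~0.333 each).
///     softmax(logits, 1)
/// }
/// ```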
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));

    Tensor::from_primitive(TensorPrimitive::Float(B::softmax(
        tensor.primitive.tensor(),
        dim,
    )))
}

/// Applies the softmin function on the input tensor along the given dimension.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j))`")]
///
/// # Arguments
/// - `dim`: the dimension along which softmin will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmin", dim));

    Tensor::from_primitive(TensorPrimitive::Float(B::softmin(
        tensor.primitive.tensor(),
        dim,
    )))
}

/// Applies the SoftPlus function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")]
///
/// The SoftPlus function is a smooth approximation of the ReLU function.
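///
/// # Example
///
/// A minimal sketch (added for illustration; assumes the usual `burn_tensor::activation`
/// re-export). Larger `beta` values make the curve approach ReLU; smaller values make it smoother:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::softplus};
///
/// fn smooth_relu<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
///     // beta = 1.0 gives the standard softplus, log(1 + exp(x)).
///     softplus(x, 1.0)
/// }
/// ```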
pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
    let tensor = (tensor.mul_scalar(beta).exp() + 1).log();
    tensor.div_scalar(beta)
}

/// Applies the "quiet softmax" function on the input tensor along the given dimension.
///
/// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html).
///
/// This function is similar to the softmax function, but it allows for "no selection": all of
/// the outputs can be close to zero when every input is strongly negative.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`"
)]
///
/// # Arguments
/// - `dim`: the dimension along which the quiet softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
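///
/// # Example
///
/// A minimal sketch (added for illustration; assumes the usual `burn_tensor::activation`
/// re-export). Unlike `softmax`, strongly negative logits give outputs near zero rather than
/// a distribution summing to one:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::quiet_softmax};
///
/// fn quiet_demo<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
///     let logits = Tensor::<B, 2>::from_floats([[-10.0, -10.0, -10.0]], device);
///     quiet_softmax(logits, 1) // every entry is close to 0 ("no selection")
/// }
/// ```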
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));

    // Subtract the max along `dim` for numerical stability.
    let max_vals = tensor.clone().detach().max_dim(dim);
    let exp_x = (tensor - max_vals.clone()).exp();
    let sum_exp = exp_x.clone().sum_dim(dim);

    // After shifting by the max, the implicit `+ 1` in the denominator becomes `exp(-max)`.
    exp_x.div(sum_exp + max_vals.neg().exp())
}

/// Applies the log softmax function on the input tensor along the given dimension.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{log\\_softmax}\(x_i\)
= \log\left(\text{softmax}\(x_i\)\right)
= \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`"
)]
///
/// # Arguments
/// - `dim`: the dimension along which the log softmax will be computed.
///
/// # Panics
/// - If `dim` is outside [0, D)
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("log softmax", dim));

    Tensor::from_primitive(TensorPrimitive::Float(B::log_softmax(
        tensor.primitive.tensor(),
        dim,
    )))
}

/// Applies the sigmoid function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{sigmoid}\(x\)
= \sigma(x)
= \frac{1}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")]
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid(
        tensor.primitive.tensor(),
    )))
}

/// Applies the hard sigmoid function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta))
$$
"#
)]
#[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")]
pub fn hard_sigmoid<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
    beta: f64,
) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(
        tensor.primitive.tensor(),
        alpha.into(),
        beta.into(),
    )))
}

/// Applies the log sigmoid function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")]
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(
        tensor.primitive.tensor(),
    )))
}

/// Applies the SiLU function (also known as the swish function) element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")]
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().mul(sigmoid(tensor))
}

/// Applies the hard swish function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`"
)]
pub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5))
}

/// Applies the Mish function as described in the paper
/// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{Mish}\(x\)
= x \cdot \tanh(\text{Softplus}(x))
= x \cdot \tanh\left(\log\(1 + \exp\(x\)\)\right)
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`mish(x) = x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))`"
)]
pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().mul(softplus(tensor, 1.0).tanh())
}

/// Applies the tanh function element-wise.
pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.tanh()
}

/// Applies the Exponential Linear Unit function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{ELU}\(x\) =
 \begin{cases}
     x & \text{if } x > 0 \newline
     \alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
 \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`f(x) =`\n- `x for x > 0`\n- `alpha * (exp(x) - 1) for x <= 0`"
)]
pub fn elu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
    let mask = tensor.clone().lower_equal_elem(0);
    let scaled = tensor.clone().exp().sub_scalar(1).mul_scalar(alpha);
    tensor.mask_where(mask, scaled)
}

/// Applies the Continuously Differentiable Exponential Linear Unit function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{CELU}(x) =
 \begin{cases}
     x & \text{if } x \geq 0 \newline
     \alpha \cdot \left(\exp\left(\frac{x}{\alpha}\right) - 1\right) & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`"
)]
///
/// See also [CELU](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html)
///
/// # Arguments
/// - `alpha`: scaling parameter for the negative part.
pub fn celu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
    let mask = tensor.clone().lower_equal_elem(0);
    let scaled = tensor
        .clone()
        .div_scalar(alpha)
        .exp()
        .sub_scalar(1)
        .mul_scalar(alpha);
    tensor.mask_where(mask, scaled)
}

/// Applies the Scaled Exponential Linear Unit function element-wise
/// as described in the paper [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{SELU}\(x\) = \gamma \cdot
 \begin{cases}
     x & \text{if } x > 0 \newline
     \alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
 \end{cases}
$$

where $\alpha \approx 1.6733$ and $\gamma \approx 1.0507$.
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0`"
)]
pub fn selu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // Constants from the SELU paper / ONNX spec
    const ALPHA: f64 = 1.6732632423543772848170429916717_f64;
    const GAMMA: f64 = 1.0507009873554804934193349852946_f64;

    let mask = tensor.clone().greater_equal_elem(0.0);
    let positive = tensor.clone().mul_scalar(GAMMA);
    let negative = tensor.exp().sub_scalar(1.0).mul_scalar(ALPHA * GAMMA);

    negative.mask_where(mask, positive)
}

/// Applies the thresholded rectified linear unit function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{ThresholdedReLU}(x) =
 \begin{cases}
     x & \text{if } x > \alpha \newline
     0 & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`f(x) =`\n- `x if x > alpha`\n- `0 otherwise`")]
///
/// # Arguments
/// - `alpha`: threshold value (default in ONNX is 1.0).
pub fn thresholded_relu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
) -> Tensor<B, D> {
    let mask = tensor.clone().lower_equal_elem(alpha);
    tensor.mask_fill(mask, 0)
}

/// Applies the gated linear unit function.
///
/// GLU(a, b) = a ⊗ σ(b), where `a` is the first half of the input tensor along `dim` and `b` is the second half.
///
/// **Note**:
/// * The size of the input tensor along `dim` must be divisible by 2.
///
/// ### Arguments
/// * `tensor` - The input tensor.
/// * `dim` - The dimension along which the input is split into the two halves.
///
/// ### Returns
/// * A tensor with the same shape as the input, except the size along `dim` is halved.
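///
/// ### Example
///
/// A minimal sketch (added for illustration; assumes the usual `burn_tensor::activation`
/// re-export):
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::glu};
///
/// fn glu_demo<B: Backend>(device: &B::Device) -> Tensor<B, 2> {
///     let x = Tensor::<B, 2>::zeros([4, 6], device);
///     // The size along dim 1 (6) is halved: the result has shape [4, 3].
///     glu(x, 1)
/// }
/// ```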
pub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    // TODO: Handle negative indices with AsIndex for compatibility with Pytorch nn.GLU.

    assert!(
        tensor.dims()[dim].is_multiple_of(2),
        "The size of the input tensor along dimension {dim} must be even (divisible by 2)."
    );
    let new_len = tensor.dims()[dim] / 2;

    let a = tensor.clone().slice_dim(dim, s![0..new_len]);
    let b = tensor.slice_dim(dim, s![new_len..new_len * 2]);

    a.mul(sigmoid(b))
}

/// Applies the Softsign function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softsign}(x) = \frac{x}{1 + |x|}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softsign(x_i) = x_i / (1 + |x_i|)`")]
pub fn softsign<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().div(tensor.abs() + 1)
}

/// Applies the HardShrink function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{hard\_shrink}(x) =
 \begin{cases}
     x & \text{if } x > \lambda \newline
     x & \text{if } x < -\lambda \newline
     0 & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Hard Shrink formulation. Default is 0.5.
pub fn hard_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
    let mask = tensor.clone().abs().lower_equal_elem(lambda);
    tensor.mask_fill(mask, 0)
}

/// Applies the SoftShrink function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{soft\_shrink}(x) =
 \begin{cases}
     x - \lambda & \text{if } x > \lambda \newline
     x + \lambda & \text{if } x < -\lambda \newline
     0 & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Soft Shrink formulation. Default is 0.5.
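///
/// # Example
///
/// A minimal sketch (added for illustration; assumes the usual `burn_tensor::activation`
/// re-export) with `lambda = 0.5`:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, backend::Backend, activation::soft_shrink};
///
/// fn soft_shrink_demo<B: Backend>(device: &B::Device) -> Tensor<B, 1> {
///     let x = Tensor::<B, 1>::from_floats([-1.0, -0.2, 0.0, 0.3, 2.0], device);
///     // Values inside [-0.5, 0.5] are zeroed; the rest are shifted toward zero by 0.5.
///     soft_shrink(x, 0.5) // ~[-0.5, 0.0, 0.0, 0.0, 1.5]
/// }
/// ```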
pub fn soft_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
    shrink(tensor, lambda, lambda)
}

/// Applies the Shrink function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{shrink}(x) =
 \begin{cases}
     x - \text{bias} & \text{if } x > \lambda \newline
     x + \text{bias} & \text{if } x < -\lambda \newline
     0 & \text{otherwise}
 \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Shrink formulation.
/// - `bias`: the bias value for the Shrink formulation.
pub fn shrink<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    lambda: f64,
    bias: f64,
) -> Tensor<B, D> {
    let abs_tensor = tensor.clone().abs();
    let sign = tensor.clone().sign();
    let shrunk = tensor.sub(sign.mul_scalar(bias));
    let mask = abs_tensor.lower_equal_elem(lambda);
    shrunk.mask_fill(mask, 0)
}