burn_tensor/tensor/activation/base.rs

use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, TensorPrimitive, check, s};

#[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")]
#[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")]
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.relu()
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\)
$$

or

$$
\text{LeakyReLU}(x) =
  \begin{cases}
    x & \text{if } x \geq 0 \newline
    \text{negative\\_slope} \cdot x & \text{otherwise}
  \end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x if x < 0`"
)]
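/// # Example
///
/// A minimal usage sketch, generic over a backend `B`; the `activation` re-export path and the
/// helper name are illustrative, not part of this file:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn leaky<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
///     // Negative inputs are scaled by 0.01 instead of being zeroed out.
///     activation::leaky_relu(x, 0.01)
/// }
/// ```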
pub fn leaky_relu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    negative_slope: f64,
) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu(
        tensor.primitive.tensor(),
        crate::ElementConversion::elem(negative_slope),
    )))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{GELU}(x)
= x \cdot \Phi(x)
= x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
$$

where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.
"#
)]
#[cfg_attr(
    not(doc),
    doc = r#"
`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`

where `Φ(x)` is the cumulative distribution function for the Gaussian distribution.
"#
)]
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor())))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\)
$$

or

$$
\text{PReLU}(x) =
  \begin{cases}
    x & \text{if } x \geq 0 \newline
    \alpha x & \text{otherwise}
  \end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`PReLU(x) = max(0,x) + alpha * min(0,x)`")]
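/// # Example
///
/// A minimal sketch of the two supported weight shapes (backend-generic; the `activation`
/// re-export path is assumed):
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn prelu_example<B: Backend>(x: Tensor<B, 4>, alpha: Tensor<B, 1>) -> Tensor<B, 4> {
///     // `alpha` holds either a single weight shared by all channels,
///     // or one weight per channel of `x` (dimension 1).
///     activation::prelu(x, alpha)
/// }
/// ```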
pub fn prelu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_prelu_shape::<D>(
        &tensor.shape(),
        &alpha.shape()
    ));

    let weight = if alpha.dims()[0] == 1 {
        // A single weight is shared across all channels: broadcast it over every dimension.
        alpha.reshape([1; D])
    } else {
        // One weight per channel: place the weights along the channel dimension (dimension 1)
        // so they broadcast over the remaining dimensions.
        let num_weights = alpha.dims()[0];
        let mut s = [1; D];
        s[1] = num_weights;
        alpha.reshape(s)
    };

    Tensor::from_primitive(TensorPrimitive::Float(B::prelu(
        tensor.primitive.tensor(),
        weight.primitive.tensor(),
    )))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")]
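/// # Example
///
/// A minimal sketch, generic over a backend `B` (re-export path assumed):
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn probabilities<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 2> {
///     // Normalizes along dimension 1, so each row sums to 1.
///     activation::softmax(logits, 1)
/// }
/// ```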
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));

    // Subtract the (detached) maximum along `dim` for numerical stability before exponentiating.
    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
    let tensor = tensor.exp();
    let tensor_tmp = tensor.clone().sum_dim(dim);

    tensor.div(tensor_tmp)
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j))`")]
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmin", dim));
    softmax(tensor.neg(), dim)
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")]
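/// # Example
///
/// A minimal sketch (backend-generic; re-export path assumed). With `beta = 1.0` the formula
/// reduces to `log(1 + exp(x))`, a smooth approximation of ReLU:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn smooth_relu<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
///     activation::softplus(x, 1.0)
/// }
/// ```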
pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
    let tensor = (tensor.mul_scalar(beta).exp() + 1).log();
    tensor.div_scalar(beta)
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`"
)]
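/// # Example
///
/// A minimal sketch (backend-generic; re-export path assumed). Because the denominator includes
/// an extra `1`, the weights along `dim` may sum to less than one, which lets e.g. an attention
/// head avoid committing to any position:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn attention_weights<B: Backend>(scores: Tensor<B, 3>) -> Tensor<B, 3> {
///     activation::quiet_softmax(scores, 2)
/// }
/// ```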
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));

    // Same max-subtraction trick as `softmax`, but the denominator gains an extra 1 so the
    // outputs are not forced to sum to one.
    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
    let tensor = tensor.exp();
    let tensor_tmp = tensor.clone().sum_dim(dim);

    tensor.div(tensor_tmp + 1)
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{log\\_softmax}\(x_i\)
= \log\left(\text{softmax}\(x_i\)\right)
= \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`"
)]
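/// # Example
///
/// A minimal sketch (backend-generic; re-export path assumed). Computing log-probabilities
/// directly is more numerically stable than calling `softmax` and then taking the logarithm:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn log_probabilities<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 2> {
///     activation::log_softmax(logits, 1)
/// }
/// ```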
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("log softmax", dim));

    // Work in log space: (x - max) - log(sum(exp(x - max))), which avoids overflow in `exp`.
    let tensor = tensor.clone() - tensor.detach().max_dim(dim);
    let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();

    tensor.sub(tensor_tmp)
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{sigmoid}\(x\)
= \sigma(x)
= \frac{1}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")]
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid(
        tensor.primitive.tensor(),
    )))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta))
$$
"#
)]
#[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")]
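/// # Example
///
/// A minimal sketch (backend-generic; re-export path assumed). `alpha = 0.2` and `beta = 0.5`
/// give a commonly used piecewise-linear approximation of the sigmoid:
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn hard_sig<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
///     // Clamps 0.2 * x + 0.5 to the range [0, 1].
///     activation::hard_sigmoid(x, 0.2, 0.5)
/// }
/// ```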
pub fn hard_sigmoid<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
    beta: f64,
) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(
        tensor.primitive.tensor(),
        crate::ElementConversion::elem(alpha),
        crate::ElementConversion::elem(beta),
    )))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")]
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(
        tensor.primitive.tensor(),
    )))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")]
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().mul(sigmoid(tensor))
}

#[cfg_attr(
    doc,
    doc = r#"
$$
\text{Mish}\(x\)
= x \cdot \tanh(\text{Softplus}(x))
= x \cdot \tanh\left(\log\(1 + \exp\(x\)\)\right)
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`mish(x) = x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))`"
)]
pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().mul(softplus(tensor, 1.0).tanh())
}

/// Applies the hyperbolic tangent (tanh) activation function element-wise.
pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.tanh()
}

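/// Applies the gated linear unit function.
///
/// The input is split in half along `dim` into a value `a` and a gate `b`, and the result is
/// `a * sigmoid(b)`, so the size of `dim` must be even and is halved in the output.
///
/// # Example
///
/// A minimal usage sketch, generic over a backend `B` (the `activation` re-export path is
/// assumed):
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, activation, backend::Backend};
///
/// fn gated<B: Backend>(x: Tensor<B, 3>) -> Tensor<B, 3> {
///     // Splits the last dimension in half: [batch, seq, 2 * d] -> [batch, seq, d].
///     activation::glu(x, 2)
/// }
/// ```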
pub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    assert!(
        tensor.dims()[dim].is_multiple_of(2),
        "The size of the input tensor along dimension {dim} must be even so it can be split into two halves."
    );
    // Split the input in half along `dim`: the first half is the value, the second half is the gate.
    let new_len = tensor.dims()[dim] / 2;
    let a = tensor.clone().slice(s![dim, 0..new_len]);
    let b = tensor.slice(s![dim, new_len..new_len * 2]);

    a.mul(sigmoid(b))
}