use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, TensorPrimitive, check, s};
#[cfg_attr(doc, doc = "$$\\text{ReLU}\\(x\\) = \\(x\\)^+ = \\max\\(0, x\\)$$")]
#[cfg_attr(not(doc), doc = "`ReLU(x) = max(0, x)`")]
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.relu()
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \cdot \min\(0, x\)
$$
or
$$
\text{LeakyReLU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \newline
\text{negative\\_slope} \cdot x & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`f(x) =`\n- `x for x >= 0`\n- `negative_slope * x if x < 0`"
)]
/// Applies the leaky rectified linear unit: negative inputs are scaled by
/// `negative_slope` instead of being zeroed.
pub fn leaky_relu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    negative_slope: f64,
) -> Tensor<B, D> {
    // Delegate to the backend primitive so each backend can fuse/optimize the op.
    let primitive = B::leaky_relu(tensor.primitive.tensor(), negative_slope.into());
    Tensor::from_primitive(TensorPrimitive::Float(primitive))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{GELU}(x)
= x \cdot \Phi(x)
= x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
$$
where $\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.
"#
)]
#[cfg_attr(
not(doc),
doc = r#"
`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`
where `Φ(x)` is the cumulative distribution function for the Gaussian distribution.
"#
)]
/// Applies the (exact) Gaussian error linear unit element-wise.
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // The exact erf-based formulation is provided by the backend primitive.
    let primitive = B::gelu(tensor.primitive.tensor());
    Tensor::from_primitive(TensorPrimitive::Float(primitive))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{GELU\_approx}(x)
= \frac{x}{2}\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`GELU_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`"
)]
/// Applies the tanh-based GELU approximation element-wise.
pub fn gelu_approximate<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // sqrt(2/pi) expressed exactly as (2/sqrt(pi)) * (1/sqrt(2)).
    const SQRT_2_OVER_PI: f64 =
        core::f64::consts::FRAC_2_SQRT_PI * core::f64::consts::FRAC_1_SQRT_2;
    // tanh argument: sqrt(2/pi) * (x + 0.044715 * x^3)
    let cubic_term = tensor.clone().powf_scalar(3.0) * 0.044715;
    let tanh_arg = (tensor.clone() + cubic_term) * SQRT_2_OVER_PI;
    tensor * (tanh_arg.tanh() + 1) * 0.5
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{PReLU}\(x\) = \max\(0,x\) + \alpha \cdot \min\(0, x\)
$$
or
$$
\text{PReLU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \newline
\alpha x & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`PReLu(x) = max(0,x) + alpha * min(0,x)`")]
/// Applies the parametric rectified linear unit, where the negative slope
/// `alpha` is learnable: either a single shared value or one per channel.
pub fn prelu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_prelu_shape::<D>(
        &tensor.shape(),
        &alpha.shape()
    ));
    let num_weights = alpha.dims()[0];
    let weight = if num_weights == 1 {
        // Single shared slope: broadcast across every dimension.
        alpha.reshape([1; D])
    } else {
        // One slope per channel: align alpha with the channel axis (dim 1).
        let mut broadcast_shape = [1; D];
        broadcast_shape[1] = num_weights;
        alpha.reshape(broadcast_shape)
    };
    Tensor::from_primitive(TensorPrimitive::Float(B::prelu(
        tensor.primitive.tensor(),
        weight.primitive.tensor(),
    )))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`")]
/// Applies softmax along `dim`, normalizing the slice to sum to one.
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));
    // Subtract the (detached) per-slice maximum for numerical stability;
    // softmax is invariant to this shift.
    let shifted = tensor.clone() - tensor.detach().max_dim(dim);
    let exponentials = shifted.exp();
    let normalizer = exponentials.clone().sum_dim(dim);
    exponentials.div(normalizer)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j))`")]
/// Applies softmin along `dim`: smaller inputs receive larger weights.
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmin", dim));
    // softmin(x) == softmax(-x)
    softmax(tensor.neg(), dim)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`")]
/// Applies softplus, a smooth approximation of ReLU, with sharpness `beta`.
pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
    let scaled = tensor.mul_scalar(beta);
    (scaled.exp() + 1).log().div_scalar(beta)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`"
)]
/// Applies "quiet" softmax along `dim`: like softmax but with an extra `1` in
/// the denominator, allowing all outputs to approach zero simultaneously.
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    // Fix: report the correct op name in the dimension check (was "softmax"),
    // consistent with how softmin/log_softmax label their checks.
    check!(TensorCheck::dim_ops::<D>("quiet_softmax", dim));
    // Shift by the detached per-slice max `m` for numerical stability:
    // exp(x_i) / (1 + sum_j exp(x_j)) == exp(x_i - m) / (exp(-m) + sum_j exp(x_j - m))
    let max_vals = tensor.clone().detach().max_dim(dim);
    let exp_x = (tensor - max_vals.clone()).exp();
    let sum_exp = exp_x.clone().sum_dim(dim);
    exp_x.div(sum_exp + max_vals.neg().exp())
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{log\\_softmax}\(x_i\)
= \log\left(\text{softmax}\(x_i\)\right)
= \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`"
)]
/// Applies log-softmax along `dim`, computed in a numerically stable form.
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("log softmax", dim));
    // log_softmax(x) = (x - m) - log(sum_j exp(x_j - m)) with detached max m.
    let shifted = tensor.clone() - tensor.detach().max_dim(dim);
    let log_sum_exp = shifted.clone().exp().sum_dim(dim).log();
    shifted.sub(log_sum_exp)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{sigmoid}\(x\)
= \sigma(x)
= \frac{1}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`sigmoid(x) = 1 / (1 + exp(-x))`")]
/// Applies the logistic sigmoid element-wise, mapping inputs into `(0, 1)`.
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let primitive = B::sigmoid(tensor.primitive.tensor());
    Tensor::from_primitive(TensorPrimitive::Float(primitive))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha \cdot x + \beta))
$$
"#
)]
#[cfg_attr(not(doc), doc = "`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`")]
/// Applies the piecewise-linear sigmoid approximation `clamp(alpha * x + beta, 0, 1)`.
pub fn hard_sigmoid<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
    beta: f64,
) -> Tensor<B, D> {
    let primitive = B::hard_sigmoid(tensor.primitive.tensor(), alpha.into(), beta.into());
    Tensor::from_primitive(TensorPrimitive::Float(primitive))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
$$
"#
)]
#[cfg_attr(not(doc), doc = "`log_sigmoid(x) = log(1 / (1 + exp(-x)))`")]
/// Applies the log of the logistic sigmoid element-wise.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let primitive = B::log_sigmoid(tensor.primitive.tensor());
    Tensor::from_primitive(TensorPrimitive::Float(primitive))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{SiLU}\(x\) = x \cdot \sigma(x) = \frac{x}{1 + \exp(-x)}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`")]
/// Applies the sigmoid-weighted linear unit (a.k.a. swish) element-wise.
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // x * sigmoid(x); elementwise multiplication is commutative.
    sigmoid(tensor.clone()).mul(tensor)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{hard\_swish}\(x\) = x \cdot \text{hard\_sigmoid}(x) = x \cdot \max(0, \min(1, \frac{x}{6} + 0.5))
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`"
)]
/// Applies hard swish: `x` gated by the hard sigmoid with `alpha = 1/6`, `beta = 0.5`.
pub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let gate = hard_sigmoid(tensor.clone(), 1.0 / 6.0, 0.5);
    gate.mul(tensor)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{Mish}\(x\)
= x \cdot \tanh(\text{Softplus}(x))
= x \cdot \tanh\left(\log\(1 + \exp\(x\)\)\right)
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`mish(x) = x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))`"
)]
/// Applies the Mish activation element-wise.
///
/// Doc fix: the second equality previously dropped the leading `x *` factor;
/// the code has always computed `x * tanh(softplus(x))`.
pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.clone().mul(softplus(tensor, 1.0).tanh())
}
/// Applies the hyperbolic tangent activation element-wise, mapping inputs into `(-1, 1)`.
///
/// `tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
tensor.tanh()
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{ELU}\(x\) =
\begin{cases}
x & \text{if } x > 0 \newline
\alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`f(x) =`\n- `x for x > 0`\n- `alpha * (exp(x) - 1) for x <= 0`"
)]
/// Applies the exponential linear unit with saturation coefficient `alpha`.
pub fn elu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
    // Replace non-positive entries with alpha * (exp(x) - 1); keep positives as-is.
    let non_positive = tensor.clone().lower_equal_elem(0);
    let saturated = tensor.clone().exp().sub_scalar(1).mul_scalar(alpha);
    tensor.mask_where(non_positive, saturated)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{CELU}(x) =
\begin{cases}
x & \text{if } x \geq 0 \newline
\alpha \cdot \left(\exp\left(\frac{x}{\alpha}\right) - 1\right) & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`"
)]
/// Applies the continuously differentiable exponential linear unit.
///
/// NOTE(review): `alpha` is used as a divisor, so it is presumably expected to
/// be non-zero — no check is performed here.
pub fn celu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {
    // Replace non-positive entries with alpha * (exp(x / alpha) - 1).
    let non_positive = tensor.clone().lower_equal_elem(0);
    let saturated = tensor
        .clone()
        .div_scalar(alpha)
        .exp()
        .sub_scalar(1)
        .mul_scalar(alpha);
    tensor.mask_where(non_positive, saturated)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{SELU}\(x\) = \gamma \cdot
\begin{cases}
x & \text{if } x > 0 \newline
\alpha \cdot (\exp(x) - 1) & \text{if } x \leq 0
\end{cases}
$$
where $\alpha \approx 1.6733$ and $\gamma \approx 1.0507$.
"#
)]
#[cfg_attr(
not(doc),
doc = "`selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0`"
)]
/// Applies the scaled exponential linear unit with its fixed self-normalizing
/// constants `ALPHA` and `GAMMA`.
pub fn selu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    const ALPHA: f64 = 1.6732632423543772848170429916717_f64;
    const GAMMA: f64 = 1.0507009873554804934193349852946_f64;
    // Select the linear branch where x >= 0, the exponential branch elsewhere.
    let non_negative = tensor.clone().greater_equal_elem(0.0);
    let linear_branch = tensor.clone().mul_scalar(GAMMA);
    let exp_branch = tensor.exp().sub_scalar(1.0).mul_scalar(ALPHA * GAMMA);
    exp_branch.mask_where(non_negative, linear_branch)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{ThresholdedReLU}(x) =
\begin{cases}
x & \text{if } x > \alpha \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`f(x) =`\n- `x if x > alpha`\n- `0 otherwise`")]
/// Applies the thresholded rectified linear unit: values at or below `alpha`
/// are zeroed, the rest pass through unchanged.
pub fn thresholded_relu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
) -> Tensor<B, D> {
    let at_or_below_threshold = tensor.clone().lower_equal_elem(alpha);
    tensor.mask_fill(at_or_below_threshold, 0)
}
/// Applies the gated linear unit along `dim`: the tensor is split into two
/// equal halves `a` and `b`, and the result is `a * sigmoid(b)`.
///
/// # Panics
/// Panics if the size of `tensor` along `dim` is odd.
pub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    assert!(
        tensor.dims()[dim].is_multiple_of(2),
        "Input tensor along dimension {dim} must have an even size. N is divisible by 2."
    );
    let half = tensor.dims()[dim] / 2;
    let values = tensor.clone().slice_dim(dim, s![0..half]);
    let gates = tensor.slice_dim(dim, s![half..half * 2]);
    values.mul(sigmoid(gates))
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{softsign}(x) = \frac{x}{1 + |x|}
$$
"#
)]
#[cfg_attr(not(doc), doc = "`softsign(x_i) = x_i / (1 + |x_i|)`")]
/// Applies softsign element-wise, smoothly mapping inputs into `(-1, 1)`.
pub fn softsign<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let denominator = tensor.clone().abs() + 1;
    tensor.div(denominator)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{hard\_shrink}(x) =
\begin{cases}
x & \text{if } x > \lambda \newline
x & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`"
)]
/// Applies hard shrinkage: values whose magnitude is at most `lambda` are
/// zeroed, the rest pass through unchanged.
pub fn hard_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
    let inside_dead_zone = tensor.clone().abs().lower_equal_elem(lambda);
    tensor.mask_fill(inside_dead_zone, 0)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{soft\_shrink}(x) =
\begin{cases}
x - \lambda & \text{if } x > \lambda \newline
x + \lambda & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`"
)]
/// Applies soft shrinkage: the special case of [`shrink`] where the bias
/// subtracted from surviving values equals the threshold `lambda`.
pub fn soft_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
    shrink(tensor, lambda, lambda)
}
#[cfg_attr(
doc,
doc = r#"
$$
\text{shrink}(x) =
\begin{cases}
x - \text{bias} & \text{if } x > \lambda \newline
x + \text{bias} & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
not(doc),
doc = "`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`"
)]
/// Applies generalized shrinkage: values with magnitude at most `lambda` are
/// zeroed; the rest are pulled toward zero by `bias` (sign-aware).
pub fn shrink<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    lambda: f64,
    bias: f64,
) -> Tensor<B, D> {
    // |x| before shifting, used to decide which elements fall in the dead zone.
    let magnitude = tensor.clone().abs();
    // x - sign(x) * bias pulls positives down and negatives up by `bias`.
    let pulled = tensor.clone().sign().mul_scalar(bias);
    let shifted = tensor.sub(pulled);
    let inside_dead_zone = magnitude.lower_equal_elem(lambda);
    shifted.mask_fill(inside_dead_zone, 0)
}