// burn-tensor 0.18.0
//
// Tensor library with user-friendly APIs and automatic differentiation support
use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{Tensor, TensorPrimitive, check};

/// Applies the rectified linear unit function element-wise
/// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).
///
/// $$\text{ReLU}\(x\) = \(x\)^+ = \max\(0, x\)$$
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.relu()
}

/// Leaky rectified linear unit, evaluated independently on every element of `tensor`.
///
/// $$
/// \text{LeakyReLU}\(x\) = \max\(0,x\) + \text{negative\\_slope} \times \min\(0, x\)
/// $$
///
/// which is the same as
///
/// $$
/// \text{LeakyReLU}(x) =
/// \begin{cases}
///     x & \text{if } x \geq 0 \newline
///     \text{negative\\_slope} \times x & \text{otherwise}
/// \end{cases}
/// $$
///
/// # Arguments
/// - `tensor`: the input tensor.
/// - `negative_slope`: the factor applied to negative inputs; it is converted to the
///   backend's float element type before being handed to the backend kernel.
pub fn leaky_relu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    negative_slope: f64,
) -> Tensor<B, D> {
    let input = tensor.primitive.tensor();
    let slope = crate::ElementConversion::elem(negative_slope);
    let activated = B::leaky_relu(input, slope);

    Tensor::from_primitive(TensorPrimitive::Float(activated))
}

/// Gaussian Error Linear Unit (GELU), evaluated independently on every element of `tensor`.
///
/// Reference: [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
///
/// $$
/// \text{GELU}\(x\)
/// = x \cdot \Phi\(x\)
/// = x \cdot \frac{1}{2}\left(1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right)
/// $$
///
/// Here $\Phi\(x\)$ denotes the cumulative distribution function of the Gaussian distribution.
///
/// The computation itself is delegated to the backend's `gelu` operation.
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let activated = B::gelu(tensor.primitive.tensor());

    Tensor::from_primitive(TensorPrimitive::Float(activated))
}

/// Parametric ReLU (PReLU) activation, as introduced in
/// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).
///
/// `tensor` is assumed to have shape \[batch_size, channels, ...\], and
/// `alpha` is assumed to have shape \[channels\] or \[1\].
///
/// $$
/// \text{PReLU}\(x\) = \max\(0,x\) + \alpha * \min\(0, x\)
/// $$
///
/// which is the same as
///
/// $$
/// \text{PReLU}(x) =
/// \begin{cases}
///     x & \text{if } x \geq 0 \newline
///     \alpha x & \text{otherwise}
/// \end{cases}
/// $$
///
/// # Arguments
/// - `tensor`: the input tensor.
/// - `alpha`: the learnable slope(s) for negative inputs. It is reshaped to rank `D`
///   (either all ones, or `channels` on axis 1) before being passed to the backend.
pub fn prelu<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: Tensor<B, 1>,
) -> Tensor<B, D> {
    check!(TensorCheck::check_prelu_shape::<D>(
        &tensor.shape(),
        &alpha.shape()
    ));

    // Start from (1, 1, ..., 1) so the weight always ends up with rank D.
    let channels = alpha.dims()[0];
    let mut weight_shape = [1; D];
    if channels != 1 {
        // One weight per channel: lay them out as (1, channels, 1, ...).
        // Axis 1 exists here because the D == 1 case with several weights is
        // handled by the shape check above.
        weight_shape[1] = channels;
    }
    let weight = alpha.reshape(weight_shape);

    let activated = B::prelu(tensor.primitive.tensor(), weight.primitive.tensor());
    Tensor::from_primitive(TensorPrimitive::Float(activated))
}

/// Softmax of the input tensor, normalised along dimension `dim`.
///
/// $$
/// \text{softmax}\(x_i\) = \frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}
/// $$
///
/// The (detached) maximum along `dim` is subtracted before exponentiating. This leaves
/// the result mathematically unchanged while keeping every argument of `exp` non-positive.
///
/// # Arguments
/// - `dim`: the dimension along which softmax is computed.
///
/// # Panics
/// - If `dim` is outside `[0, D)`
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));

    // The maximum is detached so that it acts as a constant shift in the autodiff graph.
    let peak = tensor.clone().detach().max_dim(dim);
    let exponentials = tensor.sub(peak).exp();
    let normaliser = exponentials.clone().sum_dim(dim);

    exponentials / normaliser
}

/// Softmin of the input tensor, normalised along dimension `dim`.
///
/// $$
/// \text{softmin}\(x_i\) = \frac{\exp\(-x_i\)}{\sum_j \exp\(-x_j\)}
/// $$
///
/// Computed as the softmax of the negated input.
///
/// # Arguments
/// - `dim`: the dimension along which softmin is computed.
///
/// # Panics
/// - If `dim` is outside `[0, D)`
pub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    // Validate here as well so that a failure is reported under the "softmin" name.
    check!(TensorCheck::dim_ops::<D>("softmin", dim));

    let negated = tensor.neg();
    softmax(negated, dim)
}

/// SoftPlus, a smooth approximation of ReLU, evaluated independently on every element.
///
/// $$
/// \text{softplus}\(x\) = \frac{1}{\beta}\log\(1 + \exp\(\beta x\)\)
/// $$
///
/// # Arguments
/// - `tensor`: the input tensor.
/// - `beta`: the scaling factor from the formula above. The result is divided by it,
///   so it should be non-zero.
///
/// The formula is evaluated literally (`exp`, then `+ 1`, then `log`), without any
/// thresholding of large inputs.
pub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {
    let scaled = tensor.mul_scalar(beta);
    let log_term = scaled.exp().add_scalar(1).log();

    log_term.div_scalar(beta)
}

/// Applies the "quiet softmax" function on the input tensor along the given dimension.
///
/// This function is similar to the softmax function, but it allows for "no selection" when
/// all the inputs are very negative: the outputs along `dim` then sum to less than one and
/// may all be close to zero.
///
/// $$
/// \text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i\)}{1 + \sum_j \exp\(x_j\)}
/// $$
///
/// This is the softmax over the inputs plus one implicit extra logit fixed at `0`.
/// For numerical stability every logit, including the implicit one, is shifted by
/// `s = max(0, max_j x_j)`:
///
/// $$
/// \text{quiet\\_softmax}\(x_i\) = \frac{\exp\(x_i - s\)}{\exp\(-s\) + \sum_j \exp\(x_j - s\)}
/// $$
///
/// which is mathematically identical to the first formula, while every argument of `exp`
/// is non-positive.
///
/// # Arguments
/// - `dim`: the dimension along which quiet softmax will be computed.
///
/// # Panics
/// - If `dim` is outside `[0, D)`
pub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("quiet softmax", dim));

    // Shift by the largest logit, counting the implicit zero logit (hence the clamp at 0).
    // The shift is detached: it is a constant that cancels out of the formula.
    let shift = tensor.clone().detach().max_dim(dim).clamp_min(0);
    let tensor = (tensor - shift.clone()).exp();
    let tensor_tmp = tensor.clone().sum_dim(dim);

    // After shifting, the `1` of the formula (i.e. exp(0)) becomes exp(0 - shift).
    // Adding a literal 1 here would instead compute exp(x_i) / (exp(max) + sum_j exp(x_j)).
    let implicit_term = shift.neg().exp();

    tensor.div(tensor_tmp + implicit_term)
}

/// Logarithm of the softmax of the input tensor, taken along dimension `dim`.
///
/// $$
/// \text{log\\_softmax}\(x_i\)
/// = \log\left(\text{softmax}\(x_i\)\right)
/// = \log\left(\frac{\exp\(x_i\)}{\sum_j \exp\(x_j\)}\right)
/// $$
///
/// Evaluated as `(x - max) - log(sum(exp(x - max)))`, where `max` is the (detached)
/// maximum along `dim`; this is mathematically the same quantity as above.
///
/// # Arguments
/// - `dim`: the dimension along which log softmax is computed.
///
/// # Panics
/// - If `dim` is outside `[0, D)`
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("log softmax", dim));

    // The maximum is detached so that it acts as a constant shift in the autodiff graph.
    let peak = tensor.clone().detach().max_dim(dim);
    let shifted = tensor.sub(peak);
    let log_sum_exp = shifted.clone().exp().sum_dim(dim).log();

    shifted - log_sum_exp
}

/// Logistic sigmoid, evaluated independently on every element of `tensor`.
///
/// $$
/// \text{sigmoid}\(x\)
/// = \sigma(x)
/// = \frac{1}{1 + \exp(-x)}
/// $$
///
/// The computation itself is delegated to the backend's `sigmoid` operation.
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let activated = B::sigmoid(tensor.primitive.tensor());

    Tensor::from_primitive(TensorPrimitive::Float(activated))
}

/// Hard sigmoid, evaluated independently on every element of `tensor`.
///
/// $$
/// \text{hard\\_sigmoid}\(x\) = \max(0, \min(1, \alpha x + \beta))
/// $$
///
/// # Arguments
/// - `tensor`: the input tensor.
/// - `alpha`: the slope $\alpha$ of the linear segment.
/// - `beta`: the offset $\beta$ of the linear segment.
///
/// Both scalars are converted to the backend's float element type before being handed
/// to the backend's `hard_sigmoid` operation.
pub fn hard_sigmoid<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    alpha: f64,
    beta: f64,
) -> Tensor<B, D> {
    let input = tensor.primitive.tensor();
    let slope = crate::ElementConversion::elem(alpha);
    let offset = crate::ElementConversion::elem(beta);
    let activated = B::hard_sigmoid(input, slope, offset);

    Tensor::from_primitive(TensorPrimitive::Float(activated))
}

/// Logarithm of the sigmoid, evaluated independently on every element of `tensor`.
///
/// $$
/// \text{log\\_sigmoid}\(x\) = \log\left(\frac{1}{1 + \exp(-x)}\right)
/// $$
///
/// The computation itself is delegated to the backend's `log_sigmoid` operation.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let activated = B::log_sigmoid(tensor.primitive.tensor());

    Tensor::from_primitive(TensorPrimitive::Float(activated))
}

/// SiLU (also known as swish), evaluated independently on every element of `tensor`.
///
/// $$
/// \text{SiLU}\(x\) = x * \sigma(x) = \frac{x}{1 + \exp(-x)}
/// $$
///
/// Computed as the input multiplied element-wise by its own sigmoid.
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let gate = sigmoid(tensor.clone());

    tensor.mul(gate)
}

/// Applies the Mish function element-wise, as described in the paper
/// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
///
/// $$
/// \text{Mish}\(x\)
/// = x * \tanh(\text{Softplus}(x))
/// = x * \tanh\left(\log\(1 + \exp\(x\)\)\right)
/// $$
///
/// Computed as the input multiplied element-wise by `tanh(softplus(x))`, using
/// softplus with `beta = 1`.
pub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let gate = softplus(tensor.clone(), 1.0).tanh();

    tensor.mul(gate)
}

/// Applies the tanh function element-wise.
pub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    tensor.tanh()
}