use crate::backend::Backend;
use crate::check::TensorCheck;
use crate::{check, Tensor};
use crate::{ElementPrecision, Precision};
/// Applies the rectified linear unit function element-wise:
/// `relu(x) = max(0, x)`.
pub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // The tensor API already exposes relu directly; just forward to it.
    let rectified = tensor.relu();
    rectified
}
/// Applies the Gaussian error linear unit function element-wise.
///
/// Delegates to the backend's native `gelu` kernel on the raw primitive.
pub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let primitive = B::gelu(tensor.primitive);
    Tensor::from_primitive(primitive)
}
/// Applies the softmax function along the given dimension:
/// `softmax(x_i) = exp(x_i) / sum_j exp(x_j)`.
///
/// The per-slice maximum is subtracted first (a detached value, so it does
/// not affect gradients) to keep `exp` from overflowing.
///
/// # Panics
///
/// Panics if `dim` is out of bounds for a rank-`D` tensor.
pub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("softmax", dim));

    // Shift by the max for numerical stability, then normalize the exponentials.
    let shifted = tensor.clone() - tensor.detach().max_dim(dim);
    let exponentials = shifted.exp();
    let normalizer = exponentials.clone().sum_dim(dim);

    exponentials.div(normalizer)
}
/// Applies the log-softmax function along the given dimension:
/// `log_softmax(x_i) = x_i - log(sum_j exp(x_j))`.
///
/// The per-slice maximum is subtracted first (detached, so gradients are
/// unaffected) to keep the `exp` in the log-sum-exp from overflowing.
///
/// # Panics
///
/// Panics if `dim` is out of bounds for a rank-`D` tensor.
pub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {
    check!(TensorCheck::dim_ops::<D>("log softmax", dim));

    // Stabilized log-sum-exp: shift by the max, then subtract the log of the
    // summed exponentials.
    let shifted = tensor.clone() - tensor.detach().max_dim(dim);
    let log_sum_exp = shifted.clone().exp().sum_dim(dim).log();

    shifted.sub(log_sum_exp)
}
/// Applies the sigmoid function element-wise:
/// `sigmoid(x) = 1 / (1 + exp(-x))`.
///
/// Computed as `exp(log_sigmoid(x))`, reusing the log-space formulation.
pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    let log_space = log_sigmoid(tensor);
    log_space.exp()
}
/// Applies the log-sigmoid function element-wise:
/// `log_sigmoid(x) = -log(1 + exp(-x))`.
///
/// For half-precision float elements the computation is upcast to full
/// precision first and downcast afterwards, since `exp(-x)` quickly loses
/// accuracy (or overflows) in 16-bit floats.
///
/// NOTE(review): even in full precision, `exp(-x)` overflows to `inf` for
/// sufficiently negative `x` (roughly x < -88 for f32), making the result
/// `-inf` instead of approximately `x`. A split formulation such as
/// `min(x, 0) - log1p(exp(-|x|))` would avoid this — confirm whether the
/// required tensor ops are available before changing.
pub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    match B::FloatElem::precision() {
        Precision::Half => {
            // Upcast, compute -log(1 + exp(-x)), then cast back to half.
            let tensor_full = tensor.to_full_precision();
            let tensor_tmp = tensor_full.neg().exp().add_scalar(1.0_f32).log().neg();
            Tensor::from_full_precision(tensor_tmp)
        }
        // Full/double precision: compute directly in the element type.
        _ => tensor.neg().exp().add_scalar(1.0_f32).log().neg(),
    }
}
/// Applies the SiLU (sigmoid linear unit, a.k.a. swish) function
/// element-wise: `silu(x) = x * sigmoid(x)`.
pub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {
    // Gate the input by its own sigmoid activation.
    let gate = sigmoid(tensor.clone());
    tensor.mul(gate)
}