//! Composite implementations of higher-level activation and regularization
//! ops, expressed in terms of primitive tensor operations.

use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::activation::normalize_softmax_dim;
use crate::ops::traits::{
    ActivationOps, BinaryOps, CompareOps, ConditionalOps, CumulativeOps, RandomOps, ScalarOps,
    UnaryOps,
};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
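
/// Numerically stable softplus: `softplus(x) = relu(x) + ln(1 + exp(-|x|))`.
///
/// Algebraically equal to `ln(1 + exp(x))`, but this form avoids overflow in
/// `exp` for large positive `x`: the exponent `-|x|` is never positive.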
pub fn softplus_impl<R, C>(client: &C, a: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime,
    C: ActivationOps<R> + UnaryOps<R> + ScalarOps<R> + BinaryOps<R>,
{
    // relu(x) carries the large-positive branch; the log term is a correction
    // in (0, ln 2] that cannot overflow because exp(-|x|) <= 1.
    let relu_x = client.relu(a)?;
    let abs_x = client.abs(a)?;
    let neg_abs = client.neg(&abs_x)?;
    let exp_neg_abs = client.exp(&neg_abs)?;
    let one_plus = client.add_scalar(&exp_neg_abs, 1.0)?;
    let log_term = client.log(&one_plus)?;
    client.add(&relu_x, &log_term)
}
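
/// Log-softmax along `dim`: `log_softmax(x) = x - logsumexp(x, dim)`.
///
/// Subtracting the log-sum-exp directly is more precise than computing
/// `log(softmax(x))`, which underflows for small probabilities. Negative
/// `dim` values are resolved by `normalize_softmax_dim` (e.g. `-1` selects
/// the last axis).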
pub fn log_softmax_impl<R, C>(client: &C, a: &Tensor<R>, dim: isize) -> Result<Tensor<R>>
where
    R: Runtime,
    C: BinaryOps<R> + CumulativeOps<R>,
{
    let ndim = a.ndim();
    let dim_idx = normalize_softmax_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
    // Retain the reduced dimension so `lse` broadcasts against `a` below.
    let lse = client.logsumexp(a, &[dim_idx], true)?;
    client.sub(a, &lse)
}
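
/// Inverted dropout: during training each element is zeroed with probability
/// `p` and the survivors are scaled by `1 / (1 - p)`, so the expected value
/// of the output matches the input. With `training == false` (or `p == 0`)
/// the input is returned unchanged.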
pub fn dropout_impl<R, C>(client: &C, a: &Tensor<R>, p: f64, training: bool) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RandomOps<R> + CompareOps<R> + ConditionalOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    // Callers are expected to pass a probability in [0, 1].
    debug_assert!((0.0..=1.0).contains(&p), "dropout p must be in [0, 1]");
    // Identity at inference time or when nothing is dropped.
    if !training || p == 0.0 {
        return Ok(a.clone());
    }
    // p >= 1.0 drops everything; short-circuit to avoid dividing by zero below.
    if p >= 1.0 {
        return Ok(Tensor::<R>::zeros(a.shape(), a.dtype(), client.device()));
    }
    // Keep mask: a uniform sample in [0, 1) exceeds `p` with probability 1 - p.
    let rand_tensor = client.rand(a.shape(), a.dtype())?;
    let threshold = Tensor::<R>::full_scalar(a.shape(), a.dtype(), p, client.device());
    let mask = client.gt(&rand_tensor, &threshold)?;
    // Inverted dropout: scale survivors so the expected output equals the input.
    let scale = 1.0 / (1.0 - p);
    let scaled = client.mul_scalar(a, scale)?;
    let zeros = Tensor::<R>::zeros(a.shape(), a.dtype(), client.device());
    client.where_cond(&mask, &scaled, &zeros)
}
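
// Example usage (sketch only; `CpuClient` and its constructor are hypothetical
// stand-ins for a concrete `RuntimeClient` implementation — only the
// `client.rand` and `dropout_impl` calls mirror this module's actual API):
//
//     let client = CpuClient::new();
//     let x = client.rand(&[8, 16], DType::F32)?;
//     let dropped = dropout_impl(&client, &x, 0.5, true)?;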