use super::ops::*;
use crate::autograd::Var;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{ActivationOps, CompareOps, ReduceOps, ScalarOps, TensorOps, UnaryOps};
use crate::runtime::{Runtime, RuntimeClient};
use std::sync::Arc;
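/// ReLU applied to a `Var`: `relu(x) = max(x, 0)`.
///
/// When `a` requires gradients, a `ReluBackward` node is recorded that keeps a
/// clone of the forward input; the standard ReLU gradient passes the incoming
/// gradient through where the input was positive and zeroes it elsewhere.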
pub fn var_relu<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R> + CompareOps<R>,
    R::Client: TensorOps<R> + CompareOps<R>,
{
    let output = client.relu(a.tensor())?;
    if a.requires_grad() {
        let grad_fn = ReluBackward::<R>::new(a.id(), a.tensor().clone(), a.grad_fn().cloned());
        Ok(Var::from_op(output, Arc::new(grad_fn)))
    } else {
        Ok(Var::new(output, false))
    }
}
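/// Sigmoid applied to a `Var`: `sigmoid(x) = 1 / (1 + exp(-x))`.
///
/// When `a` requires gradients, a `SigmoidBackward` node is recorded that keeps
/// the forward output, since the gradient can be written purely in terms of it:
/// `d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x))`.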
pub fn var_sigmoid<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R>,
    R::Client: TensorOps<R>,
{
    let output = client.sigmoid(a.tensor())?;
    if a.requires_grad() {
        let grad_fn = SigmoidBackward::<R>::new(a.id(), output.clone(), a.grad_fn().cloned());
        Ok(Var::from_op(output, Arc::new(grad_fn)))
    } else {
        Ok(Var::new(output, false))
    }
}
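/// SiLU (a.k.a. swish) applied to a `Var`: `silu(x) = x * sigmoid(x)`.
///
/// When `a` requires gradients, a `SiluBackward` node is recorded that keeps
/// both the forward input and output; the standard gradient is
/// `sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))`.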
pub fn var_silu<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R> + ActivationOps<R> + ScalarOps<R>,
    R::Client: TensorOps<R> + ActivationOps<R> + ScalarOps<R>,
{
    let output = client.silu(a.tensor())?;
    if a.requires_grad() {
        let grad_fn = SiluBackward::<R>::new(
            a.id(),
            a.tensor().clone(),
            output.clone(),
            a.grad_fn().cloned(),
        );
        Ok(Var::from_op(output, Arc::new(grad_fn)))
    } else {
        Ok(Var::new(output, false))
    }
}
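/// Softplus applied to a `Var`: `softplus(x) = ln(1 + exp(x))`.
///
/// When `a` requires gradients, a `SoftplusBackward` node is recorded that keeps
/// the forward input; the gradient of softplus is `sigmoid(x)`.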
pub fn var_softplus<R, C>(a: &Var<R>, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + ActivationOps<R>,
    R::Client: TensorOps<R> + ActivationOps<R>,
{
    let output = client.softplus(a.tensor())?;
    if a.requires_grad() {
        let grad_fn = SoftplusBackward::<R>::new(a.id(), a.tensor().clone(), a.grad_fn().cloned());
        Ok(Var::from_op(output, Arc::new(grad_fn)))
    } else {
        Ok(Var::new(output, false))
    }
}
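/// Softmax along `dim` applied to a `Var`:
/// `softmax(x)_i = exp(x_i) / sum_j exp(x_j)` over the chosen dimension.
///
/// When `a` requires gradients, a `SoftmaxBackward` node is recorded that keeps
/// the forward output and `dim`; the usual vector-Jacobian product is
/// `y * (g - sum(g * y, dim))`, where `y` is the output and `g` the incoming gradient.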
pub fn var_softmax<R, C>(a: &Var<R>, dim: isize, client: &C) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + TensorOps<R>,
    R::Client: TensorOps<R> + ReduceOps<R> + ScalarOps<R>,
{
    let output = client.softmax(a.tensor(), dim)?;
    if a.requires_grad() {
        let grad_fn = SoftmaxBackward::<R>::new(a.id(), output.clone(), dim, a.grad_fn().cloned());
        Ok(Var::from_op(output, Arc::new(grad_fn)))
    } else {
        Ok(Var::new(output, false))
    }
}
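/// Log-softmax along `dim` applied to a `Var`:
/// `log_softmax(x) = x - ln(sum_j exp(x_j))` over the chosen dimension.
///
/// When `a` requires gradients, a `LogSoftmaxBackward` node is recorded that keeps
/// the forward output and `dim`; the usual vector-Jacobian product is
/// `g - exp(y) * sum(g, dim)`, where `y` is the output and `g` the incoming gradient.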
pub fn var_log_softmax<R, C>(a: &Var<R>, dim: isize, client: &C) -> Result<Var<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + TensorOps<R> + ActivationOps<R>,
    R::Client: TensorOps<R> + UnaryOps<R> + ReduceOps<R> + ScalarOps<R>,
{
    let output = client.log_softmax(a.tensor(), dim)?;
    if a.requires_grad() {
        let grad_fn =
            LogSoftmaxBackward::<R>::new(a.id(), output.clone(), dim, a.grad_fn().cloned());
        Ok(Var::from_op(output, Arc::new(grad_fn)))
    } else {
        Ok(Var::new(output, false))
    }
}