use crate::{LibTorch, TchTensor, element::TchElement};
use burn_backend::ops::ActivationOps;

impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
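    // `unary_ops` is assumed to apply the first (in-place) closure when this
    // handle uniquely owns its storage, and the second (allocating) closure
    // when the storage is shared.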
    fn relu(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }
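
    // "none" selects the exact (erf-based) GELU rather than the tanh
    // approximation; tch's `gelu`/`gelu_` take the approximation mode as a string.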
    fn gelu(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }
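
    // Gradient of GELU; uses the same approximation mode ("none") as the
    // forward pass. Cloning the storage handle before consuming the tensor is
    // assumed to let `from_existing` track aliasing with the input's buffer.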
    fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
        TchTensor::from_existing(tensor, storage)
    }
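
    // Sigmoid, computed in place when the storage is uniquely owned.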
    fn sigmoid(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }
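
    // log(sigmoid(x)) via tch's fused `log_sigmoid`, which is numerically
    // stabler than composing `sigmoid` and `log`.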
    fn log_sigmoid(tensor: TchTensor) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_sigmoid();
        TchTensor::from_existing(tensor, storage)
    }
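
    // Softmax along `dim`; passing `None` for the dtype keeps the input's dtype.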
    fn softmax(tensor: TchTensor, dim: usize) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.softmax(dim as i64, None);
        TchTensor::from_existing(tensor, storage)
    }
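
    // Fused log-softmax, numerically stabler than `softmax` followed by `log`.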
    fn log_softmax(tensor: TchTensor, dim: usize) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_softmax(dim as i64, None);
        TchTensor::from_existing(tensor, storage)
    }
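
    // softmin(x) = softmax(-x), so negate the input before the softmax.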
    fn softmin(tensor: TchTensor, dim: usize) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.neg().softmax(dim as i64, None);
        TchTensor::from_existing(tensor, storage)
    }
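
    // PReLU with a learnable slope `alpha` for negative inputs:
    // out = max(0, x) + alpha * min(0, x).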
    fn prelu(tensor: TchTensor, alpha: TchTensor) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.prelu(&alpha.tensor);
        TchTensor::from_existing(tensor, storage)
    }
}