use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
use burn_tensor::ops::ActivationOps;
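
// Activation ops for the LibTorch backend: each function delegates to the
// corresponding tch (libtorch) kernel, reusing the input buffer in place where possible.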
impl<E: TchElement, Q: QuantElement> ActivationOps<Self> for LibTorch<E, Q> {
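    // ReLU: `unary_ops` runs the in-place kernel (`relu_`) when the tensor is safe to
    // mutate, and the copying kernel (`relu`) otherwise.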
    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }
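
    // GELU with the exact (erf-based) formulation; "none" disables the tanh approximation.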
    fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }
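
    // Gradient of GELU: forwards to libtorch's `gelu_backward`, then reattaches the
    // input's storage handle via `from_existing`.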
    fn gelu_backward<const D: usize>(
        tensor: TchTensor<E, D>,
        grad: TchTensor<E, D>,
    ) -> TchTensor<E, D> {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");

        TchTensor::from_existing(tensor, storage)
    }
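
    // Sigmoid, applied in place via `sigmoid_` when the tensor can be mutated safely.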
    fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }
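
    // log(sigmoid(x)) computed by libtorch's `log_sigmoid`; the original storage
    // handle is passed along via `from_existing`.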
    fn log_sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_sigmoid();

        TchTensor::from_existing(tensor, storage)
    }
}