//! burn-tch 0.21.0
//!
//! LibTorch backend for the Burn framework using the tch bindings.
//!
//! See the crate documentation for details.
use crate::{LibTorch, TchTensor, element::TchElement};
use burn_backend::ops::ActivationOps;

impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
    /// Rectified linear unit. Uses the in-place `relu_` when the tensor's
    /// storage is uniquely owned, otherwise the allocating `relu`.
    fn relu(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut inner| inner.relu_(), |inner| inner.relu())
    }

    /// Gaussian error linear unit with the exact ("none") approximation,
    /// applied in place when possible.
    fn gelu(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(
            |mut inner| inner.gelu_("none"),
            |inner| inner.gelu("none"),
        )
    }

    /// Gradient of GELU with respect to its input, using the same exact
    /// ("none") approximation as the forward pass.
    fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
        let original_storage = tensor.storage.clone();
        let output = tensor.tensor.gelu_backward(&grad.tensor, "none");
        TchTensor::from_existing(output, original_storage)
    }

    /// Logistic sigmoid, applied in place when the tensor is uniquely owned.
    fn sigmoid(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut inner| inner.sigmoid_(), |inner| inner.sigmoid())
    }

    /// Log of the sigmoid.
    ///
    /// NOTE: we don't override log_sigmoid_backward because Torch has a special
    /// backward formula that uses a buffer with computed values from the
    /// forward pass. There is also no in-place `log_sigmoid_` variant, so the
    /// out-of-place call is always used.
    fn log_sigmoid(tensor: TchTensor) -> TchTensor {
        let original_storage = tensor.storage.clone();
        let output = tensor.tensor.log_sigmoid();
        TchTensor::from_existing(output, original_storage)
    }

    /// Softmax along `dim`; output dtype follows the input (`None` dtype arg).
    fn softmax(tensor: TchTensor, dim: usize) -> TchTensor {
        let original_storage = tensor.storage.clone();
        let output = tensor.tensor.softmax(dim as i64, None);
        TchTensor::from_existing(output, original_storage)
    }

    /// Log-softmax along `dim`; output dtype follows the input.
    fn log_softmax(tensor: TchTensor, dim: usize) -> TchTensor {
        let original_storage = tensor.storage.clone();
        let output = tensor.tensor.log_softmax(dim as i64, None);
        TchTensor::from_existing(output, original_storage)
    }

    /// Softmin along `dim`, computed as softmax of the negated tensor.
    fn softmin(tensor: TchTensor, dim: usize) -> TchTensor {
        let original_storage = tensor.storage.clone();
        let output = tensor.tensor.neg().softmax(dim as i64, None);
        TchTensor::from_existing(output, original_storage)
    }

    /// Parametric ReLU with learnable slope(s) `alpha` for negative inputs.
    fn prelu(tensor: TchTensor, alpha: TchTensor) -> TchTensor {
        let original_storage = tensor.storage.clone();
        let output = tensor.tensor.prelu(&alpha.tensor);
        TchTensor::from_existing(output, original_storage)
    }
}