burn_tch/ops/activation.rs

use crate::{LibTorch, TchTensor, element::TchElement};
use burn_tensor::ops::ActivationOps;

/// Activation ops for the LibTorch backend, delegating to tch (libtorch) kernels.
impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
    fn relu(tensor: TchTensor) -> TchTensor {
        // `unary_ops` runs the first (in-place) closure when the tensor can safely
        // mutate its storage, and the second (allocating) closure otherwise.
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }

    fn gelu(tensor: TchTensor) -> TchTensor {
        // "none" selects the exact (erf-based) GELU rather than the tanh approximation.
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }

    fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
        // The backward kernel produces a new tensor; keep the input's storage handle
        // so the result can be registered against it.
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");

        TchTensor::from_existing(tensor, storage)
    }

    fn sigmoid(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }

    fn log_sigmoid(tensor: TchTensor) -> TchTensor {
        // NOTE: we don't override log_sigmoid_backward because Torch has a special backward
        // formula that uses a buffer with computed values from the forward pass.

        // There is no in-place log_sigmoid_ in tch, so always go through the allocating path.
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_sigmoid();

        TchTensor::from_existing(tensor, storage)
    }
}
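
Every element-wise activation above funnels through `TchTensor::unary_ops`, which takes an in-place closure and an allocating closure so the backend can reuse the input buffer whenever it is not shared. The standalone sketch below illustrates that copy-on-write dispatch with plain standard-library types; the `Tensor` struct, its `Arc<Vec<f32>>` storage, and the ReLU closures are simplified stand-ins for illustration only, not burn-tch's actual types or implementation.

use std::sync::Arc;

// Stand-in tensor type: a buffer behind an `Arc`, loosely mirroring how a
// backend tensor tracks its storage.
struct Tensor {
    data: Arc<Vec<f32>>,
}

impl Tensor {
    // Dispatch in the spirit of `unary_ops`: mutate the buffer when this handle
    // is the sole owner of the storage, otherwise allocate a new one.
    fn unary_ops(
        self,
        f_inplace: impl FnOnce(&mut Vec<f32>),
        f_copy: impl FnOnce(&[f32]) -> Vec<f32>,
    ) -> Tensor {
        match Arc::try_unwrap(self.data) {
            // Exclusive ownership: reuse the existing buffer.
            Ok(mut buf) => {
                f_inplace(&mut buf);
                Tensor { data: Arc::new(buf) }
            }
            // Shared storage: run the allocating variant instead.
            Err(shared) => Tensor {
                data: Arc::new(f_copy(shared.as_slice())),
            },
        }
    }
}

fn main() {
    let input = Tensor {
        data: Arc::new(vec![-1.0, 0.5, 2.0]),
    };
    let alias = Arc::clone(&input.data); // a second owner forces the copying path

    // ReLU expressed with the same two-closure pattern as the backend code.
    let output = input.unary_ops(
        |buf| {
            for x in buf.iter_mut() {
                *x = (*x).max(0.0);
            }
        },
        |buf| buf.iter().map(|x| x.max(0.0)).collect(),
    );

    assert_eq!(*output.data, vec![0.0, 0.5, 2.0]);
    assert_eq!(*alias, vec![-1.0, 0.5, 2.0]); // the shared buffer is untouched
}

Had `alias` not been created, `Arc::try_unwrap` would have succeeded and the in-place closure would have rewritten the original buffer, which is the same saving the backend gets from calling `relu_` instead of `relu`.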