use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
use burn_tensor::ops::ActivationOps;

impl<E: TchElement, Q: QuantElement> ActivationOps<Self> for LibTorch<E, Q> {
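    // Note on the pattern below: `TchTensor::unary_ops` takes two closures, the first used
    // when the underlying storage is not shared (so the in-place, `_`-suffixed Torch kernel
    // can run), the second as the out-of-place fallback.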
    fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }

    fn gelu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
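        // "none" selects PyTorch's exact (erf-based) GELU rather than the tanh approximation.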
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }

    fn gelu_backward<const D: usize>(
        tensor: TchTensor<E, D>,
        grad: TchTensor<E, D>,
    ) -> TchTensor<E, D> {
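        // Clone the input's storage handle so `from_existing` can track whether the result
        // aliases the input buffer (descriptive note; see `TchTensor::from_existing`).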
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");

        TchTensor::from_existing(tensor, storage)
    }

    fn sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }

    fn log_sigmoid<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
        // NOTE: we don't override `log_sigmoid_backward` because Torch's backward formula
        // relies on a buffer of values computed during the forward pass.

        // There is no in-place `log_sigmoid_`, so the out-of-place op is used while keeping
        // the original storage handle.
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_sigmoid();

        TchTensor::from_existing(tensor, storage)
    }
}