burn_tch/ops/activation.rs

use crate::{LibTorch, TchTensor, element::TchElement};
use burn_tensor::ops::ActivationOps;

impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
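    // Each forward op below hands `unary_ops` two closures: the first uses
    // LibTorch's in-place kernel (trailing underscore, e.g. `relu_`), the
    // second the allocating kernel, so the backend can pick whichever is
    // safe to apply to the underlying storage.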
    fn relu(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
    }

    fn gelu(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(
            |mut tensor| tensor.gelu_("none"),
            |tensor| tensor.gelu("none"),
        )
    }

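    // The backward pass delegates to LibTorch's fused `gelu_backward` kernel
    // and registers the result against the input's storage handle.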
    fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");

        TchTensor::from_existing(tensor, storage)
    }

    fn sigmoid(tensor: TchTensor) -> TchTensor {
        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
    }

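    // Unlike the ops above, this calls the allocating `log_sigmoid` kernel
    // directly instead of going through `unary_ops`.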
    fn log_sigmoid(tensor: TchTensor) -> TchTensor {
        let storage = tensor.storage.clone();
        let tensor = tensor.tensor.log_sigmoid();

        TchTensor::from_existing(tensor, storage)
    }
}
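
Every forward op above supplies `unary_ops` with a pair of closures: an in-place LibTorch kernel (the trailing-underscore variants such as `relu_` and `sigmoid_`) and an allocating one. The sketch below is a standalone illustration of that dispatch idea under the assumption that the in-place path is only taken when the tensor's buffer is not shared; `FakeTensor`, `apply_unary`, and the `Arc<Vec<f32>>` storage are hypothetical stand-ins, not the actual burn-tch types.

use std::sync::Arc;

// Hypothetical tensor whose buffer may be shared with other tensors.
struct FakeTensor {
    storage: Arc<Vec<f32>>,
}

impl FakeTensor {
    // Apply `inplace` when we hold the only reference to the buffer,
    // otherwise produce a fresh buffer with `copy`.
    fn apply_unary(
        mut self,
        inplace: impl FnOnce(&mut Vec<f32>),
        copy: impl FnOnce(&Vec<f32>) -> Vec<f32>,
    ) -> FakeTensor {
        match Arc::get_mut(&mut self.storage) {
            Some(buf) => {
                inplace(buf);
                self
            }
            None => FakeTensor {
                storage: Arc::new(copy(self.storage.as_ref())),
            },
        }
    }
}

fn main() {
    let t = FakeTensor {
        storage: Arc::new(vec![-1.0, 0.5, 2.0]),
    };
    // ReLU expressed through the hypothetical in-place/copy split.
    let out = t.apply_unary(
        |buf| buf.iter_mut().for_each(|x| *x = x.max(0.0)),
        |buf| buf.iter().map(|x| x.max(0.0)).collect(),
    );
    println!("{:?}", out.storage);
}

Reusing a uniquely owned buffer avoids an extra allocation per activation, which is presumably why each op provides both an in-place and an allocating closure.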