burn_tch/ops/
activation.rs1use crate::{LibTorch, TchTensor, element::TchElement};
2use burn_backend::ops::ActivationOps;
3
4impl<E: TchElement> ActivationOps<Self> for LibTorch<E> {
5 fn relu(tensor: TchTensor) -> TchTensor {
6 tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
7 }
8
9 fn gelu(tensor: TchTensor) -> TchTensor {
10 tensor.unary_ops(
11 |mut tensor| tensor.gelu_("none"),
12 |tensor| tensor.gelu("none"),
13 )
14 }
15
16 fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
17 let storage = tensor.storage.clone();
18 let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
19
20 TchTensor::from_existing(tensor, storage)
21 }
22
23 fn sigmoid(tensor: TchTensor) -> TchTensor {
24 tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
25 }
26
27 fn log_sigmoid(tensor: TchTensor) -> TchTensor {
28 let storage = tensor.storage.clone();
33 let tensor = tensor.tensor.log_sigmoid();
34
35 TchTensor::from_existing(tensor, storage)
36 }
37
38 fn softmax(tensor: TchTensor, dim: usize) -> TchTensor {
39 let storage = tensor.storage.clone();
40 let tensor = tensor.tensor.softmax(dim as i64, None);
41 TchTensor::from_existing(tensor, storage)
42 }
43
44 fn log_softmax(tensor: TchTensor, dim: usize) -> TchTensor {
45 let storage = tensor.storage.clone();
46 let tensor = tensor.tensor.log_softmax(dim as i64, None);
47 TchTensor::from_existing(tensor, storage)
48 }
49
50 fn softmin(tensor: TchTensor, dim: usize) -> TchTensor {
51 let storage = tensor.storage.clone();
52 let tensor = tensor.tensor.neg().softmax(dim as i64, None);
53 TchTensor::from_existing(tensor, storage)
54 }
55
56 fn prelu(tensor: TchTensor, alpha: TchTensor) -> TchTensor {
57 let storage = tensor.storage.clone();
58 let tensor = tensor.tensor.prelu(&alpha.tensor);
59 TchTensor::from_existing(tensor, storage)
60 }
61}