burn_dispatch/ops/activation.rs

//! Activation ops for the `Dispatch` backend. Every method forwards to the
//! corresponding op on the dispatched backend `B` through the `unary_float!`
//! and `binary_float!` macros.

use burn_backend::{Scalar, ops::ActivationOps, tensor::FloatTensor};

use crate::Dispatch;
use crate::backends::*;
impl ActivationOps<Self> for Dispatch {
    fn leaky_relu(tensor: FloatTensor<Self>, negative_slope: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::leaky_relu(tensor, negative_slope) => Float)
    }

    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::relu(tensor) => Float)
    }

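    // A minimal sketch of what a `unary_float!` invocation is assumed to
    // expand to. `DispatchTensor`, `Cpu`, `CpuBackend`, `Cuda`, and
    // `CudaBackend` are hypothetical names; the real macro matches every
    // enabled backend and re-wraps the result as a Float tensor (`=> Float`):
    //
    //     match tensor {
    //         // Unwrap the backend variant, run the op, re-wrap the result.
    //         DispatchTensor::Cpu(inner) => DispatchTensor::Cpu(CpuBackend::relu(inner)),
    //         DispatchTensor::Cuda(inner) => DispatchTensor::Cuda(CudaBackend::relu(inner)),
    //     }
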
    // Backward ops receive the saved forward value (or input) together with
    // the upstream gradient, and dispatch both through `binary_float!`.
    fn relu_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((output, float), (grad, float), |output, grad| B::relu_backward(output, grad) => Float)
    }

    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::gelu(tensor) => Float)
    }

    fn prelu(tensor: FloatTensor<Self>, alpha: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((tensor, float), (alpha, float), |tensor, alpha| B::prelu(tensor, alpha) => Float)
    }

    fn gelu_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((x, float), (grad, float), |x, grad| B::gelu_backward(x, grad) => Float)
    }

    fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::sigmoid(tensor) => Float)
    }

    fn sigmoid_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((output, float), (grad, float), |output, grad| B::sigmoid_backward(output, grad) => Float)
    }

    fn hard_sigmoid(tensor: FloatTensor<Self>, alpha: Scalar, beta: Scalar) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::hard_sigmoid(tensor, alpha, beta) => Float)
    }

    fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        unary_float!(tensor, float, |tensor| B::log_sigmoid(tensor) => Float)
    }

    fn log_sigmoid_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
        binary_float!((x, float), (grad, float), |x, grad| B::log_sigmoid_backward(x, grad) => Float)
    }
}
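
// `binary_float!` is assumed to follow the same dispatch shape with two
// operands: unwrap both tensors, require them to sit on the same backend,
// call the backend op, and re-wrap the result as Float. A hypothetical
// expansion for `relu_backward` (names as in the sketch above):
//
//     match (output, grad) {
//         (DispatchTensor::Cpu(out), DispatchTensor::Cpu(grad)) => {
//             DispatchTensor::Cpu(CpuBackend::relu_backward(out, grad))
//         }
//         // Mixed-backend operands are presumably an error.
//         _ => panic!("operands dispatched to different backends"),
//     }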