Skip to main content

burn_dispatch/ops/
activation.rs

1use burn_backend::{Scalar, ops::ActivationOps, tensor::FloatTensor};
2
3use crate::Dispatch;
4use crate::backends::*;
5
6impl ActivationOps<Self> for Dispatch {
7    fn leaky_relu(tensor: FloatTensor<Self>, negative_slope: Scalar) -> FloatTensor<Self> {
8        unary_float!(tensor, float, |tensor| B::leaky_relu(tensor, negative_slope) => Float)
9    }
10
11    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
12        unary_float!(tensor, float, |tensor| B::relu(tensor) => Float)
13    }
14
15    fn relu_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
16        binary_float!((output, float), (grad, float), |output, grad| B::relu_backward(output, grad) => Float)
17    }
18
19    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
20        unary_float!(tensor, float, |tensor| B::gelu(tensor) => Float)
21    }
22
23    fn prelu(tensor: FloatTensor<Self>, alpha: FloatTensor<Self>) -> FloatTensor<Self> {
24        binary_float!((tensor, float), (alpha, float), |tensor, alpha| B::prelu(tensor, alpha) => Float)
25    }
26
27    fn gelu_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
28        binary_float!((x, float), (grad, float), |x, grad| B::gelu_backward(x, grad) => Float)
29    }
30
31    fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
32        unary_float!(tensor, float, |tensor| B::sigmoid(tensor) => Float)
33    }
34
35    fn sigmoid_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
36        binary_float!((output, float), (grad, float), |output, grad| B::sigmoid_backward(output, grad) => Float)
37    }
38
39    fn hard_sigmoid(tensor: FloatTensor<Self>, alpha: Scalar, beta: Scalar) -> FloatTensor<Self> {
40        unary_float!(tensor, float, |tensor| B::hard_sigmoid(tensor, alpha, beta) => Float)
41    }
42
43    fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
44        unary_float!(tensor, float, |tensor| B::log_sigmoid(tensor) => Float)
45    }
46
47    fn log_sigmoid_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {
48        binary_float!((x, float), (grad, float), |x, grad| B::log_sigmoid_backward(x, grad) => Float)
49    }
50}