burn_core/nn/
relu.rs

use crate as burn;

use crate::module::Module;
use crate::tensor::Tensor;
use crate::tensor::backend::Backend;

/// Applies the rectified linear unit function `ReLU(x) = max(0, x)` element-wise.
///
/// See also [relu](burn::tensor::activation::relu).
#[derive(Module, Clone, Debug, Default)]
pub struct Relu;

impl Relu {
    /// Create the module.
    pub fn new() -> Self {
        Self {}
    }

    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any]`
    /// - output: `[..., any]`
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        crate::tensor::activation::relu(input)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let layer = Relu::new();

        assert_eq!(alloc::format!("{layer}"), "Relu");
    }
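
    // Not part of the original file: a minimal forward-pass sketch added for
    // illustration. It assumes the crate-level `crate::TestBackend` alias and the
    // `Tensor::from_floats` / `TensorData::assert_eq` APIs used elsewhere in
    // burn-core tests; adjust to the backend and tensor API available in your tree.
    #[test]
    fn forward_clamps_negatives() {
        let device = Default::default();
        let layer = Relu::new();
        let input = Tensor::<crate::TestBackend, 1>::from_floats([-2.0, 0.0, 3.0], &device);

        let output = layer.forward(input);

        // ReLU(x) = max(0, x): negative entries map to zero, positives pass through.
        output
            .into_data()
            .assert_eq(&crate::tensor::TensorData::from([0.0f32, 0.0, 3.0]), false);
    }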
}