1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use crate::{shapes::*, tensor::*};

use super::{cmp::*, BroadcastTo, ChooseFrom, Device, TryMul};

/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html):
/// `max(0, lhs) + rhs * min(0, lhs)`.
///
/// Applied element-wise, for every index `i`:
/// - when `lhs[i] < 0` the output is `lhs[i] * rhs[i]`
/// - when `lhs[i] >= 0` the output is `lhs[i]` unchanged
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let a = dev.tensor([0.05,0.05,0.05,0.05]);
/// let r = prelu(t, a);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```
///
/// # Panics
/// Panics if the fallible form, [TryPReLU::try_prelu], returns an error.
pub fn prelu<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D> + Merge<R>, R: Default>(
    lhs: Tensor<S, E, D, T>,
    rhs: Tensor<S, E, D, R>,
) -> Tensor<S, E, D, T> {
    // Same as the trait's default `prelu`: run the fallible op and unwrap.
    lhs.try_prelu(rhs).unwrap()
}

/// Computes `prelu`, but with a scalar value. `max(0, t) + a*min(0, t)`
pub fn leakyrelu<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    lhs: Tensor<S, E, D, T>,
    rhs: E,
) -> Tensor<S, E, D, T> {
    lhs.prelu(rhs)
}

/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html):
/// `max(0, lhs) + rhs * min(0, lhs)`.
///
/// Element-wise, for every index `i`:
/// - `lhs[i] < 0` yields `lhs[i] * rhs[i]`
/// - `lhs[i] >= 0` yields `lhs[i]` unchanged
///
/// The slope `rhs` may be either:
/// - a tensor with the same shape as `self`, or
/// - a single scalar `E` applied to every element.
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let a = dev.tensor([0.05,0.05,0.05,0.05]);
/// let r = t.prelu(a);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```
pub trait TryPReLU<T = Self>: HasErr {
    /// Infallible form of [TryPReLU::try_prelu].
    ///
    /// # Panics
    /// Panics if [TryPReLU::try_prelu] returns an error.
    fn prelu(self, rhs: T) -> Self {
        Self::try_prelu(self, rhs).unwrap()
    }

    /// Fallible PReLU. Errors from the underlying tensor operations are
    /// returned instead of panicking.
    fn try_prelu(self, rhs: T) -> Result<Self, Self::Err>;
}

impl<S, E, D, LhsTape, R> TryPReLU<Tensor<S, E, D, R>> for Tensor<S, E, D, LhsTape>
where
    S: Shape,
    E: Dtype,
    D: Device<E>,
    LhsTape: Tape<E, D> + Merge<R>,
{
    /// See [prelu]
    fn try_prelu(self, rhs: Tensor<S, E, D, R>) -> Result<Self, Self::Err> {
        // Negative branch: `lhs * rhs`, computed on a copy of `self` with an
        // empty tape. `self` keeps its own tape and is moved into `try_choose`
        // below as the non-negative branch.
        let negative_branch = self.with_empty_tape().try_mul(rhs)?;
        let is_negative = self.try_lt(E::default())?;
        is_negative.try_choose(negative_branch, self)
    }
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> TryPReLU<E> for Tensor<S, E, D, T> {
    /// See [prelu]
    ///
    /// Scalar-slope variant (leaky ReLU): every negative element of `self` is
    /// multiplied by the same value `rhs`, and non-negative elements pass
    /// through unchanged.
    fn try_prelu(self, rhs: E) -> Result<Self, Self::Err> {
        // Multiply by the scalar directly instead of materializing a scale
        // tensor, retaping it and broadcasting it. This has two benefits:
        // - every step stays fallible, so device errors surface as `Err`
        //   rather than panicking inside a `try_*` method;
        // - no backward work is spent on a throwaway scale tensor.
        let scaled = self.with_empty_tape().try_mul(rhs)?;
        self.try_lt(E::default())?.try_choose(scaled, self)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        tensor::*,
        tensor_ops::{prelu::TryPReLU, *},
        tests::*,
    };

    #[test]
    fn test_prelu() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let y = dev
            .tensor([0.05, 0.05, 0.05, 0.05, 0.05])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().prelu(y.clone());
        assert_close_to_literal!(r, [-0.1, -0.05, 0.0, 1.0, 2.0]);
        // NOTE: call .exp() to make sure we cover cases where .prelu() uses the result's gradient
        let g = r.exp().mean().backward();
        assert_close_to_literal!(
            g.get(&x),
            [0.00904837, 0.00951229, 0.2, 0.54365635, 1.4778112]
        );
        assert_close_to_literal!(g.get(&y), [-0.3619348, -0.1902458, 0.0, 0.0, 0.0]);
    }

    /// Covers the scalar-slope (`TryPReLU<E>`) implementation via `leakyrelu`.
    /// A uniform slope must match `test_prelu` exactly, both in values and in
    /// the gradient with respect to `x`.
    #[test]
    fn test_leakyrelu() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        // Build the scalar through a rank-0 tensor so the literal is converted
        // to whatever `TestDtype` is configured.
        let a: TestDtype = dev.tensor(0.05).to_dtype::<TestDtype>().array();
        let r = super::leakyrelu(x.leaky_trace(), a);
        assert_close_to_literal!(r, [-0.1, -0.05, 0.0, 1.0, 2.0]);
        // NOTE: call .exp() so the backward pass depends on the result's gradient
        let g = r.exp().mean().backward();
        assert_close_to_literal!(
            g.get(&x),
            [0.00904837, 0.00951229, 0.2, 0.54365635, 1.4778112]
        );
    }
}