1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use crate::{shapes::*, tensor::*};
use super::{cmp::*, BroadcastTo, ChooseFrom, Device, TryMul};
/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
///
/// Elementwise: negative entries of `lhs` are scaled by the matching entry of
/// `rhs`, while non-negative entries pass through unchanged:
/// - `lhs[i] < 0`  => `lhs[i] * rhs[i]`
/// - `lhs[i] >= 0` => `lhs[i]`
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let a = dev.tensor([0.05,0.05,0.05,0.05]);
/// let r = t.prelu(a);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```
pub fn prelu<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D> + Merge<R>, R: Default>(
    lhs: Tensor<S, E, D, T>,
    rhs: Tensor<S, E, D, R>,
) -> Tensor<S, E, D, T> {
    // Free-function form of [TryPReLU::prelu]; panics if the op fails.
    lhs.try_prelu(rhs).unwrap()
}
/// Leaky ReLU: [prelu] with one scalar slope `a` applied to every element.
/// `max(0, t) + a*min(0, t)`
pub fn leakyrelu<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    lhs: Tensor<S, E, D, T>,
    rhs: E,
) -> Tensor<S, E, D, T> {
    // Free-function form of the scalar [TryPReLU] impl; panics if the op fails.
    lhs.try_prelu(rhs).unwrap()
}
/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
///
/// Elementwise: negative entries of `lhs` are scaled by the matching entry of
/// `rhs`, while non-negative entries pass through unchanged:
/// - `lhs[i] < 0`  => `lhs[i] * rhs[i]`
/// - `lhs[i] >= 0` => `lhs[i]`
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let a = dev.tensor([0.05,0.05,0.05,0.05]);
/// let r = prelu(t, a);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```
pub trait TryPReLU<T = Self>: HasErr {
    /// Infallible version of [TryPReLU::try_prelu]; panics if the op fails.
    fn prelu(self, rhs: T) -> Self {
        self.try_prelu(rhs).unwrap()
    }

    /// Fallible PReLU; returns the implementor's error type on failure.
    fn try_prelu(self, rhs: T) -> Result<Self, Self::Err>;
}
impl<S: Shape, E: Dtype, D, LhsTape: Tape<E, D>, R> TryPReLU<Tensor<S, E, D, R>>
    for Tensor<S, E, D, LhsTape>
where
    D: Device<E>,
    LhsTape: Merge<R>,
{
    /// See [prelu]
    fn try_prelu(self, rhs: Tensor<S, E, D, R>) -> Result<Self, Self::Err> {
        // Branch used where the input is negative: `self * rhs`.
        // `with_empty_tape` lets `self` participate in both branches while
        // keeping a single merged gradient tape in the final result.
        let negative_branch = self.with_empty_tape().try_mul(rhs)?;
        // Elementwise mask of strictly-negative inputs picks the branch:
        // true => scaled value, false => input unchanged.
        let is_negative = self.try_lt(E::default())?;
        is_negative.try_choose(negative_branch, self)
    }
}
impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> TryPReLU<E> for Tensor<S, E, D, T> {
    /// See [prelu]
    fn try_prelu(self, rhs: E) -> Result<Self, Self::Err> {
        // Lift the scalar slope into a tensor of the input's shape so the
        // elementwise multiply below is well-formed.
        let device = self.device.clone();
        let slope = device.tensor(rhs).retaped::<T>().broadcast_like(self.shape());
        // Branch used where the input is negative: `self * slope`.
        let negative_branch = self.with_empty_tape().try_mul(slope)?;
        // Select scaled values for negative inputs, pass-through otherwise.
        self.try_lt(E::default())?.try_choose(negative_branch, self)
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        tensor::*,
        tensor_ops::{prelu::TryPReLU, *},
        tests::*,
    };

    #[test]
    fn test_prelu() {
        let dev: TestDevice = Default::default();
        let input = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let slope = dev
            .tensor([0.05, 0.05, 0.05, 0.05, 0.05])
            .to_dtype::<TestDtype>();

        // Forward pass: negatives scaled by 0.05, non-negatives unchanged.
        let out = input.leaky_trace().prelu(slope.clone());
        assert_close_to_literal!(out, [-0.1, -0.05, 0.0, 1.0, 2.0]);

        // NOTE: call .exp() to make sure we cover cases where .prelu() uses the result's gradient
        let grads = out.exp().mean().backward();
        assert_close_to_literal!(
            grads.get(&input),
            [0.00904837, 0.00951229, 0.2, 0.54365635, 1.4778112]
        );
        // Slope only receives gradient where the input was negative.
        assert_close_to_literal!(grads.get(&slope), [-0.3619348, -0.1902458, 0.0, 0.0, 0.0]);
    }
}