use crate::{shapes::*, tensor::*};
use super::{cmp::*, BroadcastTo, ChooseFrom, Device, TryMul};

/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
///
/// In other words, for each element i:
/// - if `lhs[i] < 0`, use `lhs[i] * rhs[i]`
/// - if `lhs[i] >= 0`, use `lhs[i]`
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let a = dev.tensor([0.05, 0.05, 0.05, 0.05]);
/// let r = prelu(t, a);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```
pub fn prelu<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D> + Merge<R>, R: Default>(
    lhs: Tensor<S, E, D, T>,
    rhs: Tensor<S, E, D, R>,
) -> Tensor<S, E, D, T> {
    lhs.prelu(rhs)
}

/// Computes `prelu`, but with a single scalar slope shared by every element: `max(0, t) + a*min(0, t)`.
/// This is the Leaky ReLU activation.
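///
/// Example (mirrors the `prelu` docs above; assumes `leakyrelu` is re-exported
/// through `dfdx::prelude` like the other tensor ops in this module):
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let r = leakyrelu(t, 0.05);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```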
pub fn leakyrelu<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    lhs: Tensor<S, E, D, T>,
    rhs: E,
) -> Tensor<S, E, D, T> {
    lhs.prelu(rhs)
}

/// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). `max(0, lhs) + rhs*min(0, lhs)`
///
/// In other words, for each element i:
/// - if `lhs[i] < 0`, use `lhs[i] * rhs[i]`
/// - if `lhs[i] >= 0`, use `lhs[i]`
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0]);
/// let a = dev.tensor([0.05, 0.05, 0.05, 0.05]);
/// let r = t.prelu(a);
/// assert_eq!(r.array(), [-0.05, 0.0, 1.0, 2.0]);
/// ```
pub trait TryPReLU<T = Self>: HasErr {
    /// Applies PReLU, panicking on error. See [TryPReLU::try_prelu] for the fallible version.
    fn prelu(self, rhs: T) -> Self {
        self.try_prelu(rhs).unwrap()
    }
    /// Fallible version of [TryPReLU::prelu].
    fn try_prelu(self, rhs: T) -> Result<Self, Self::Err>;
}

impl<S: Shape, E: Dtype, D, LhsTape: Tape<E, D>, R> TryPReLU<Tensor<S, E, D, R>>
    for Tensor<S, E, D, LhsTape>
where
    D: Device<E>,
    LhsTape: Merge<R>,
{
    /// See [prelu]
    fn try_prelu(self, rhs: Tensor<S, E, D, R>) -> Result<Self, Self::Err> {
        // Negative branch: lhs * rhs. `with_empty_tape` lets `self` be reused
        // below; the tapes are merged again by `try_choose`.
        let scaled = self.with_empty_tape().try_mul(rhs)?;
        // Where `lhs < 0`, pick the scaled value; otherwise keep `lhs` unchanged.
        self.try_lt(E::default())?.try_choose(scaled, self)
    }
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> TryPReLU<E> for Tensor<S, E, D, T> {
    /// See [prelu]
    fn try_prelu(self, rhs: E) -> Result<Self, Self::Err> {
        // Lift the scalar slope into a tensor broadcast to `self`'s shape, then
        // apply the same mask-and-choose construction as the tensor-slope impl above.
        let dev = self.device.clone();
        let scale = dev.tensor(rhs).retaped::<T>().broadcast_like(self.shape());
        let scaled = self.with_empty_tape().try_mul(scale)?;
        self.try_lt(E::default())?.try_choose(scaled, self)
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        tensor::*,
        tensor_ops::{prelu::TryPReLU, *},
        tests::*,
    };

    #[test]
    fn test_prelu() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let y = dev
            .tensor([0.05, 0.05, 0.05, 0.05, 0.05])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().prelu(y.clone());
        assert_close_to_literal!(r, [-0.1, -0.05, 0.0, 1.0, 2.0]);
        // NOTE: call .exp() to make sure we cover cases where .prelu() uses the result's gradient
        let g = r.exp().mean().backward();
        assert_close_to_literal!(
            g.get(&x),
            [0.00904837, 0.00951229, 0.2, 0.54365635, 1.4778112]
        );
        assert_close_to_literal!(g.get(&y), [-0.3619348, -0.1902458, 0.0, 0.0, 0.0]);
    }
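
    // A minimal sketch of a test for the scalar-slope path (the `TryPReLU<E>` impl
    // above), mirroring test_prelu. Assumes `TestDtype` is a primitive float (the
    // default `f32`) so the bare `0.05` literal infers to it; expected values follow
    // from max(0, t) + a*min(0, t) with a = 0.05.
    #[test]
    fn test_prelu_scalar() {
        let dev: TestDevice = Default::default();
        let x = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let r = x.leaky_trace().prelu(0.05);
        assert_close_to_literal!(r, [-0.1, -0.05, 0.0, 1.0, 2.0]);
    }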
}