use super::*;
use crate::{shapes::*, tensor::*};

/// Reduction along multiple axes using [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp).
pub trait LogSumExpTo: HasErr + HasShape {
    /// [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp) reduction.
    ///
    /// **Pytorch equivalent**: `t.exp().sum(Axes).log()`
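    ///
    /// For numerical stability this is computed as `max + (t - max).exp().sum(Axes).ln()`,
    /// where `max` is the maximum over the reduced axes.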
    ///
    /// **Related functions**: [ln()], [exp()], [log_softmax()], [softmax()]
    ///
    /// Example:
    /// ```rust
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let t: Tensor<Rank3<2, 4, 6>, f32, _> = dev.zeros();
    /// let _ = t.logsumexp::<Rank2<2, 4>, _>(); // or `logsumexp::<_, Axis<2>>()`
    /// ```
    ///
    /// Multi axis logsumexp:
    /// ```rust
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// # let t: Tensor<Rank3<2, 4, 6>, f32, _> = dev.zeros();
    /// let _ = t.logsumexp::<Rank1<4>, _>(); // or `logsumexp::<_, Axes2<0, 2>>()`
    /// ```
    fn logsumexp<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: ReduceShapeTo<Dst, Ax>,
    {
        self.try_logsumexp().unwrap()
    }

    /// Fallible version of [LogSumExpTo::logsumexp]
    fn try_logsumexp<Dst: Shape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: ReduceShapeTo<Dst, Ax>;
}

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> LogSumExpTo for Tensor<S, E, D, T> {
    fn try_logsumexp<Dst: Shape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: ReduceShapeTo<Dst, Ax>,
    {
        let shape = *self.shape();
        let (t, tape) = self.split_tape();
        let max: Tensor<Dst, E, D> = t.clone().try_max()?;
        let t = {
            // does the max-normalization outside of the backprop graph.
            // since try_sub creates a tensor with a new id, reset the id
            // back to the one `t` had before the subtraction.
            let keep_id = t.id;
            let mut t = t.try_sub(max.clone().try_broadcast_like::<_, Ax>(&shape)?)?;
            t.id = keep_id;
            t
        };
        let dst = t.put_tape(tape).try_exp()?.try_sum::<Dst, Ax>()?.try_ln()?;
        {
            // adds the max back (undoing the normalization), also outside of the backprop graph
            let (dst, tape) = dst.split_tape();
            let keep_id = dst.id;
            let mut dst = dst.try_add(max)?;
            dst.id = keep_id;
            Ok(dst.put_tape(tape))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;

    #[test]
    fn test_logsumexp_1d() {
        let dev: TestDevice = Default::default();
        let a = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();
        let r = a.leaky_trace().logsumexp();
        assert_close_to_literal!(r, 2.4519143);
        let g = r.backward();
        assert_close_to_literal!(
            g.get(&a),
            [0.011656231, 0.03168492, 0.08612854, 0.23412165, 0.6364086]
        );
    }

    #[test]
    fn test_logsumexp_2d() {
        let dev: TestDevice = Default::default();
        let a = dev
            .tensor([[-2.0, -1.0, 0.0], [1.0, 4.0, 7.0]])
            .to_dtype::<TestDtype>();
        let r = a.leaky_trace().logsumexp::<Rank1<2>, _>();
        assert_close_to_literal!(r, [0.40760595, 7.0509458]);
        let g = r.mean().backward();
        assert_close_to_literal!(
            g.get(&a),
            [
                [0.045015287, 0.12236424, 0.33262047],
                [0.0011778167, 0.023657078, 0.47516513],
            ]
        );
    }
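
    // Illustrative sanity check (mirrors the tests above): for `n` equal entries
    // `x`, logsumexp reduces to `x + ln(n)`, and its gradient is the uniform
    // softmax `1 / n`. Here logsumexp([0.0; 4]) = ln(4) ≈ 1.3862944.
    #[test]
    fn test_logsumexp_uniform() {
        let dev: TestDevice = Default::default();
        let a = dev
            .tensor([0.0, 0.0, 0.0, 0.0])
            .to_dtype::<TestDtype>();
        let r = a.leaky_trace().logsumexp();
        assert_close_to_literal!(r, 1.3862944);
        let g = r.backward();
        assert_close_to_literal!(g.get(&a), [0.25, 0.25, 0.25, 0.25]);
    }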
}