1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use super::*;
use crate::{shapes::*, tensor::*};
/// Reduces a tensor over one or more axes with
/// [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp).
pub trait LogSumExpTo: HasErr + HasShape {
    /// Applies the [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp) reduction over
    /// the axes `Ax`, yielding a value of shape `Dst`.
    ///
    /// This is the infallible wrapper around [LogSumExpTo::try_logsumexp].
    ///
    /// **Pytorch equivalent**: `t.exp().sum(Axes).log()`
    ///
    /// **Related functions**: [ln()], [exp()], [log_softmax()], [softmax()]
    ///
    /// # Panics
    ///
    /// Panics if [LogSumExpTo::try_logsumexp] returns an error.
    ///
    /// # Examples
    ///
    /// Reducing a single axis:
    /// ```rust
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// let t: Tensor<Rank3<2, 4, 6>, f32, _> = dev.zeros();
    /// let _ = t.logsumexp::<Rank2<2, 4>, _>(); // or `logsumexp::<_, Axis<2>>()`
    /// ```
    ///
    /// Reducing several axes at once:
    /// ```rust
    /// # use dfdx::prelude::*;
    /// # let dev: Cpu = Default::default();
    /// # let t: Tensor<Rank3<2, 4, 6>, f32, _> = dev.zeros();
    /// let _ = t.logsumexp::<Rank1<4>, _>(); // or `logsumexp::<_, Axes2<0, 2>>()`
    /// ```
    fn logsumexp<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: ReduceShapeTo<Dst, Ax>,
    {
        self.try_logsumexp::<Dst, Ax>().unwrap()
    }

    /// Fallible version of [LogSumExpTo::logsumexp].
    ///
    /// Returns the reduced value of shape `Dst`, or `Self::Err` if the
    /// implementation reports a failure.
    fn try_logsumexp<Dst: Shape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: ReduceShapeTo<Dst, Ax>;
}
impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> LogSumExpTo for Tensor<S, E, D, T> {
    /// Computes `max + ln(sum(exp(x - max)))` over the axes `Ax`, where `max` is
    /// the maximum of `x` over those same axes.
    ///
    /// Both the subtraction and the addition of `max` are performed while the
    /// tape is detached, so only `exp`, `sum` and `ln` are applied to a tensor
    /// that carries the tape.
    ///
    /// Any error from an intermediate tensor operation is returned as `Err`.
    fn try_logsumexp<Dst: Shape, Ax: Axes>(self) -> Result<Self::WithShape<Dst>, Self::Err>
    where
        Self::Shape: ReduceShapeTo<Dst, Ax>,
    {
        let src_shape = *self.shape();

        // Detach the tape so the shift by the maximum is not recorded.
        let (input, input_tape) = self.split_tape();
        let input_id = input.id;

        let peak: Tensor<Dst, E, D> = input.clone().try_max()?;

        // Shift by the maximum. `try_sub` hands back a tensor with a new id, so
        // the input's id is restored before the tape is attached again.
        let peak_bcast = peak.clone().try_broadcast_like::<_, Ax>(&src_shape)?;
        let mut shifted = input.try_sub(peak_bcast)?;
        shifted.id = input_id;

        // These three operations run with the tape attached.
        let exponentiated = shifted.put_tape(input_tape).try_exp()?;
        let summed = exponentiated.try_sum::<Dst, Ax>()?;
        let logged = summed.try_ln()?;

        // Undo the shift, again with the tape detached and the id preserved.
        let (reduced, output_tape) = logged.split_tape();
        let reduced_id = reduced.id;
        let mut restored = reduced.try_add(peak)?;
        restored.id = reduced_id;

        Ok(restored.put_tape(output_tape))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;

    /// Reduces a 1d tensor to a single value and checks both the value and the
    /// gradient with respect to the input.
    #[test]
    fn test_logsumexp_1d() {
        let dev = TestDevice::default();
        let input = dev
            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
            .to_dtype::<TestDtype>();

        let reduced = input.leaky_trace().logsumexp();
        assert_close_to_literal!(reduced, 2.4519143);

        let grads = reduced.backward();
        assert_close_to_literal!(
            grads.get(&input),
            [0.011656231, 0.03168492, 0.08612854, 0.23412165, 0.6364086]
        );
    }

    /// Reduces the last axis of a 2d tensor, then backpropagates through the
    /// mean of the reduced values and checks the gradient of the input.
    #[test]
    fn test_logsumexp_2d() {
        let dev = TestDevice::default();
        let input = dev
            .tensor([[-2.0, -1.0, 0.0], [1.0, 4.0, 7.0]])
            .to_dtype::<TestDtype>();

        let reduced = input.leaky_trace().logsumexp::<Rank1<2>, _>();
        assert_close_to_literal!(reduced, [0.40760595, 7.0509458]);

        let grads = reduced.mean().backward();
        assert_close_to_literal!(
            grads.get(&input),
            [
                [0.045015287, 0.12236424, 0.33262047],
                [0.0011778167, 0.023657078, 0.47516513],
            ]
        );
    }
}