1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use core::marker::PhantomData;
use crate::{
ops::Permutable,
shape::{Axes, PermutableBy},
tensor::{Backward, GradAcc, GradientRef, Tensor, Variable},
};
/// Backward node for a permutation applied to a borrowed [`Variable`].
///
/// Created by `Permutable::_permute` on `&Variable<S>`. During the backward
/// pass the incoming gradient is permuted by `<Dims as PermutableBy<Dims>>::Output`
/// and accumulated into the variable's gradient through `grad`.
///
/// Note: `#[derive(Debug, Clone)]` places `G`/`Dims` bounds on the derived impls,
/// so those impls are only available when `Dims: Debug`/`Dims: Clone` hold as well.
#[derive(Debug, Clone)]
pub struct PermuteBackwardV<'g, G, Dims> {
/// Handle to the source variable's gradient, borrowed for `'g`.
grad: GradientRef<'g, G>,
/// Type-level record of the forward permutation; stores no runtime data.
dims: PhantomData<Dims>,
}
/// Sends a gradient back through a permutation of a borrowed [`Variable`].
///
/// `res_grad` (the gradient w.r.t. the permuted result) is permuted by
/// `<Dims as PermutableBy<Dims>>::Output`, then accumulated into the gradient
/// referenced by this node.
///
/// NOTE(review): correctness relies on `<Dims as PermutableBy<Dims>>::Output`
/// being the inverse of the forward permutation `Dims` — confirm against the
/// `PermutableBy` impls in `shape`.
impl<S, G, Dims> Backward<S> for PermuteBackwardV<'_, G, Dims>
where
    Dims: Axes + PermutableBy<Dims>,
    <Dims as PermutableBy<Dims>>::Output: Axes,
    S: Permutable<<Dims as PermutableBy<Dims>>::Output>,
    G: GradAcc<<S as Permutable<<Dims as PermutableBy<Dims>>::Output>>::Output>,
{
    fn backward(self, res_grad: S) {
        let Self { grad, .. } = self;
        let restored =
            <S as Permutable<<Dims as PermutableBy<Dims>>::Output>>::_permute(res_grad);
        grad.accumulate(restored);
    }
}
/// Permutes a borrowed [`Variable`], yielding a new [`Tensor`] that tracks gradients.
///
/// The variable is only borrowed, so its data is cloned before being permuted
/// by `Dims`. The returned tensor's backward node holds a [`GradientRef`] to
/// the variable's gradient for the lifetime `'g`.
impl<'g, S, Dims> Permutable<Dims> for &'g Variable<S>
where
    Dims: Axes,
    S: Clone + Permutable<Dims>,
{
    type Output = Tensor<<S as Permutable<Dims>>::Output, PermuteBackwardV<'g, S, Dims>>;

    fn _permute(self) -> Self::Output {
        // `_permute` consumes its receiver, so operate on an owned copy of the data.
        let owned = self.data.clone();
        let permuted = <S as Permutable<Dims>>::_permute(owned);
        let backward = PermuteBackwardV {
            grad: GradientRef::new(&self.grad),
            dims: PhantomData,
        };
        Tensor {
            data: permuted,
            grad_fn: backward,
        }
    }
}
/// Backward node for a permutation applied to an owned [`Tensor`].
///
/// Created by `Permutable::_permute` on `Tensor<S, F>`; it wraps the upstream
/// tensor's backward function `F`. During the backward pass the incoming
/// gradient is permuted by `<Dims as PermutableBy<Dims>>::Output` and handed
/// to `grad_fn`.
///
/// Note: `#[derive(Debug, Clone)]` places `F`/`Dims` bounds on the derived impls,
/// so those impls are only available when `Dims: Debug`/`Dims: Clone` hold as well.
#[derive(Debug, Clone)]
pub struct PermuteBackwardT<F, Dims> {
/// Backward function of the tensor that was permuted.
grad_fn: F,
/// Type-level record of the forward permutation; stores no runtime data.
dims: PhantomData<Dims>,
}
/// Sends a gradient back through a permutation of an owned [`Tensor`].
///
/// `res_grad` is permuted by `<Dims as PermutableBy<Dims>>::Output` and the
/// result is forwarded to the wrapped upstream backward function.
///
/// NOTE(review): correctness relies on `<Dims as PermutableBy<Dims>>::Output`
/// being the inverse of the forward permutation `Dims` — confirm against the
/// `PermutableBy` impls in `shape`.
impl<S, F, Dims> Backward<S> for PermuteBackwardT<F, Dims>
where
    Dims: Axes + PermutableBy<Dims>,
    <Dims as PermutableBy<Dims>>::Output: Axes,
    S: Permutable<<Dims as PermutableBy<Dims>>::Output>,
    F: Backward<<S as Permutable<<Dims as PermutableBy<Dims>>::Output>>::Output>,
{
    fn backward(self, res_grad: S) {
        let Self { grad_fn, .. } = self;
        let restored =
            <S as Permutable<<Dims as PermutableBy<Dims>>::Output>>::_permute(res_grad);
        grad_fn.backward(restored);
    }
}
/// Permutes an owned [`Tensor`], chaining its backward function.
///
/// The tensor is consumed: its data is permuted by `Dims`, and its backward
/// function is moved into a [`PermuteBackwardT`] node, so the gradient will
/// flow back through the permutation before reaching the original graph.
impl<S, F, Dims> Permutable<Dims> for Tensor<S, F>
where
    Dims: Axes,
    S: Permutable<Dims>,
{
    type Output = Tensor<<S as Permutable<Dims>>::Output, PermuteBackwardT<F, Dims>>;

    fn _permute(self) -> Self::Output {
        let Tensor { data, grad_fn } = self;
        Tensor {
            data: <S as Permutable<Dims>>::_permute(data),
            grad_fn: PermuteBackwardT {
                grad_fn,
                dims: PhantomData,
            },
        }
    }
}