1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use core::marker::PhantomData;

use crate::{
    ops::Permutable,
    shape::{Axes, PermutableBy},
    tensor::{Backward, GradAcc, GradientRef, Tensor, Variable},
};

/// Backward node produced by permuting a borrowed [`Variable`].
///
/// Holds a [`GradientRef`] to the source variable's gradient. Gradients arriving
/// at this node are permuted back and accumulated there (see its `Backward` impl).
#[derive(Debug, Clone)]
pub struct PermuteBackwardV<'g, G, Dims> {
    /// Reference to the gradient storage of the variable that was permuted.
    grad: GradientRef<'g, G>,
    /// Type-level record of the axes permutation applied in the forward pass.
    dims: PhantomData<Dims>,
}

/// Backward pass for a permuted [`Variable`].
///
/// The incoming gradient is permuted by `<Dims as PermutableBy<Dims>>::Output`.
/// The result is then accumulated into the variable's gradient through the stored
/// [`GradientRef`].
///
/// NOTE(review): correctness relies on `<Dims as PermutableBy<Dims>>::Output` being
/// the inverse of the forward permutation `Dims`. Confirm this against the
/// `PermutableBy` definition.
impl<Grad, G, Dims> Backward<Grad> for PermuteBackwardV<'_, G, Dims>
where
    Dims: Axes + PermutableBy<Dims>,
    <Dims as PermutableBy<Dims>>::Output: Axes,
    Grad: Permutable<<Dims as PermutableBy<Dims>>::Output>,
    G: GradAcc<<Grad as Permutable<<Dims as PermutableBy<Dims>>::Output>>::Output>,
{
    fn backward(self, res_grad: Grad) {
        // Re-order the result gradient's axes before handing it to the accumulator.
        let reordered = res_grad._permute();
        self.grad.accumulate(reordered);
    }
}

/// Permutes a borrowed [`Variable`] by the axes `Dims`.
///
/// The variable's data is cloned, then permuted. The returned [`Tensor`] carries a
/// [`PermuteBackwardV`] pointing at the variable's gradient, so that backward can
/// accumulate into it.
impl<'g, S, Dims> Permutable<Dims> for &'g Variable<S>
where
    Dims: Axes,
    S: Clone + Permutable<Dims>,
{
    type Output = Tensor<<S as Permutable<Dims>>::Output, PermuteBackwardV<'g, S, Dims>>;

    fn _permute(self) -> Self::Output {
        // The variable is only borrowed, so permute a copy of its data.
        let data = self.data.clone()._permute();
        let grad_fn = PermuteBackwardV {
            grad: GradientRef::new(&self.grad),
            dims: PhantomData,
        };
        Tensor { data, grad_fn }
    }
}

/// Backward node produced by permuting an owned [`Tensor`].
///
/// Wraps the tensor's previous backward function `F`. Gradients arriving at this
/// node are permuted back and then forwarded to `F` (see its `Backward` impl).
#[derive(Debug, Clone)]
pub struct PermuteBackwardT<F, Dims> {
    /// Backward function of the tensor that was permuted.
    grad_fn: F,
    /// Type-level record of the axes permutation applied in the forward pass.
    dims: PhantomData<Dims>,
}

/// Backward pass for a permuted [`Tensor`].
///
/// The incoming gradient is permuted by `<Dims as PermutableBy<Dims>>::Output`.
/// It is then passed on to the wrapped backward function.
///
/// NOTE(review): correctness relies on `<Dims as PermutableBy<Dims>>::Output` being
/// the inverse of the forward permutation `Dims`. Confirm this against the
/// `PermutableBy` definition.
impl<Grad, F, Dims> Backward<Grad> for PermuteBackwardT<F, Dims>
where
    Dims: Axes + PermutableBy<Dims>,
    <Dims as PermutableBy<Dims>>::Output: Axes,
    Grad: Permutable<<Dims as PermutableBy<Dims>>::Output>,
    F: Backward<<Grad as Permutable<<Dims as PermutableBy<Dims>>::Output>>::Output>,
{
    fn backward(self, res_grad: Grad) {
        // Re-order the result gradient's axes, then continue the backward chain.
        let upstream = res_grad._permute();
        self.grad_fn.backward(upstream);
    }
}

/// Permutes an owned [`Tensor`] by the axes `Dims`, consuming it.
///
/// The data is permuted in place of the original. The previous backward function
/// is wrapped in a [`PermuteBackwardT`] so the gradient is permuted on the way back.
impl<S, F, Dims> Permutable<Dims> for Tensor<S, F>
where
    Dims: Axes,
    S: Permutable<Dims>,
{
    type Output = Tensor<<S as Permutable<Dims>>::Output, PermuteBackwardT<F, Dims>>;

    fn _permute(self) -> Self::Output {
        let Tensor { data, grad_fn } = self;
        Tensor {
            data: data._permute(),
            grad_fn: PermuteBackwardT {
                grad_fn,
                dims: PhantomData,
            },
        }
    }
}