1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
/// Implements [`Differentiable`](crate::Differentiable) for a type that holds
/// no tensors.
///
/// The generated impl sets both `Tensors` and `Gradient` to `()`. Its
/// `tensors`, `grad` and `grad_map` methods do nothing.
///
/// The input is an optional generic parameter list in `<...>`, then the
/// implementing type, then an optional trailing `where` clause:
///
/// ```ignore
/// non_differentiable!(MyConfig);
/// non_differentiable!(<T> Wrapper<T>);
/// non_differentiable!(<T> Wrapper<T> where T: Clone);
/// ```
///
/// The expansion names `HashMap` and `Tensor` without qualification, so both
/// must be in scope where this macro is invoked.
#[macro_export]
macro_rules! non_differentiable {
    ($($tokens:tt)+) => {
        // All parsing happens in the hidden helper, starting from its `begin` state.
        $crate::__non_differentiable! { begin $($tokens)+ }
    };
}

/// Internal tt-muncher behind `non_differentiable!`. It is not part of the
/// public API.
///
/// The leading keyword selects the parser state:
/// - `begin`: checks whether the input opens with a generic parameter list.
/// - `generics (<collected>) (<open brackets>)`: copies the tokens of the
///   generic parameter list. Nested `<`/`>` pairs are tracked on a stack of `<`
///   tokens, so copying stops at the `>` that closes the list.
/// - `path (<generics>) (<collected>)`: copies the implementing type's tokens
///   up to an optional `where`.
/// - `impl (<generics>) (<type>) (<bounds>)`: emits the trait impl.
///
/// The emitted impl names `HashMap` and `Tensor` without qualification, so both
/// must be in scope at the invocation site.
#[doc(hidden)]
#[macro_export]
macro_rules! __non_differentiable {
    // Invocation started with `<`, parse generics.
    (begin < $($rest:tt)*) => {
        $crate::__non_differentiable!(generics () () $($rest)*);
    };

    // Invocation did not start with `<`.
    (begin $first:tt $($rest:tt)*) => {
        $crate::__non_differentiable!(path () ($first) $($rest)*);
    };

    // End of generics: a `>` with no unmatched `<` left on the bracket stack.
    (generics ($($generics:tt)*) () > $($rest:tt)*) => {
        $crate::__non_differentiable!(path ($($generics)*) () $($rest)*);
    };

    // Generics open bracket.
    (generics ($($generics:tt)*) ($($brackets:tt)*) < $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* <) ($($brackets)* <) $($rest)*);
    };

    // `<<` is lexed as one token (e.g. `Into<<T as Trait>::Out>`). Split it:
    // record one `<` here and push the second back onto the input.
    (generics ($($generics:tt)*) ($($brackets:tt)*) << $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* <) ($($brackets)* <) < $($rest)*);
    };

    // Generics close bracket.
    (generics ($($generics:tt)*) (< $($brackets:tt)*) > $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* >) ($($brackets)*) $($rest)*);
    };

    // `>>` is lexed as one token (e.g. `T: Into<Vec<u8>>`). The catch-all
    // arm below would otherwise swallow it, and the `>` that closes the generic
    // list would never be seen. Close one bracket and push the second `>` back
    // onto the input.
    (generics ($($generics:tt)*) (< $($brackets:tt)*) >> $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* >) ($($brackets)*) > $($rest)*);
    };

    // Token inside of generics.
    (generics ($($generics:tt)*) ($($brackets:tt)*) $first:tt $($rest:tt)*) => {
        $crate::__non_differentiable!(generics ($($generics)* $first) ($($brackets)*) $($rest)*);
    };

    // End with `where` clause.
    (path ($($generics:tt)*) ($($path:tt)*) where $($rest:tt)*) => {
        $crate::__non_differentiable!(impl ($($generics)*) ($($path)*) ($($rest)*));
    };

    // End without `where` clause.
    (path ($($generics:tt)*) ($($path:tt)*)) => {
        $crate::__non_differentiable!(impl ($($generics)*) ($($path)*) ());
    };

    // Token inside of path.
    (path ($($generics:tt)*) ($($path:tt)*) $first:tt $($rest:tt)*) => {
        $crate::__non_differentiable!(path ($($generics)*) ($($path)* $first) $($rest)*);
    };

    // The impl. An empty `where` clause is valid Rust, so `$bound` may be empty.
    (impl ($($generics:tt)*) ($($path:tt)*) ($($bound:tt)*)) => {
        impl<$($generics)*> $crate::Differentiable for $($path)* where $($bound)* {
            type Tensors = ();
            type Gradient = ();
            fn tensors(&self) -> Self::Tensors {}
            fn grad(_: &Self::Tensors, _: &HashMap<usize, Tensor>) -> Self::Gradient {}
            fn grad_map(_: &Self::Tensors, _: Self::Gradient, _: &mut HashMap<usize, Tensor>) {}
        }
    };
}

/// Implements [`Differentiable`](crate::Differentiable) and
/// [`DifferentiableModule`](crate::DifferentiableModule) for a module type.
///
/// The module's tensors are whatever `Module::parameters` returns, a
/// `HashMap<usize, Tensor>`. Its gradient is a map with the same keys.
///
/// `$m` may be a bare identifier or a type path (e.g. `nn::Linear`). It may
/// also be a concrete instantiation such as `Foo<f32>`. The generated impls
/// introduce no generic parameters of their own.
///
/// The expansion names `HashMap` and `Tensor` without qualification, so both
/// must be in scope at the invocation site.
///
/// # Panics
///
/// The generated `grad` panics if `grad_map` has no entry for one of the
/// module's parameter ids. The generated `grad_map` panics if `grad` lacks one.
#[macro_export]
macro_rules! differentiable_module {
    ($m:ty) => {
        impl $crate::Differentiable for $m {
            type Tensors = HashMap<usize, Tensor>;
            type Gradient = HashMap<usize, Tensor>;

            fn tensors(&self) -> Self::Tensors {
                $crate::Module::parameters(self)
            }

            fn grad(tensors: &Self::Tensors, grad_map: &HashMap<usize, Tensor>) -> Self::Gradient {
                // Keep only the entries that belong to this module's parameters.
                tensors
                    .keys()
                    .map(|id| {
                        let g = grad_map.get(id).unwrap_or_else(|| {
                            panic!(
                                "differentiable_module: no gradient recorded for parameter {}",
                                id
                            )
                        });
                        (*id, g.clone())
                    })
                    .collect()
            }

            fn grad_map(
                tensors: &Self::Tensors,
                mut grad: Self::Gradient,
                out: &mut HashMap<usize, Tensor>,
            ) {
                for id in tensors.keys() {
                    // `grad` is owned, so move each entry out instead of cloning it.
                    let g = grad.remove(id).unwrap_or_else(|| {
                        panic!(
                            "differentiable_module: gradient is missing parameter {}",
                            id
                        )
                    });
                    out.insert(*id, g);
                }
            }
        }

        impl $crate::DifferentiableModule for $m {}
    };
}