1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
use crate::{Differentiable, Func};

/// Implements `Func<(T1, ..., Tn), OUT>` for any callable `Fn(T1, ..., Tn) -> OUT`,
/// so a multi-argument function can be applied to a tuple holding its arguments.
///
/// Invoke with the tuple's type-parameter names, e.g. `impl_tuple_arg_fn!(A B C);`.
/// Each name is also lower-cased through `paste` to name the local binding for the
/// corresponding tuple element.
macro_rules! impl_tuple_arg_fn {
    ($($T:tt)*) => {
        paste::paste! {
            impl<$($T,)* OUT, FUNC> Func<($($T,)*), OUT> for FUNC
            where
                FUNC: Fn($($T,)*) -> OUT,
                OUT: Differentiable,
                $($T: Differentiable,)*
            {
                // The tuple is destructured directly in the parameter pattern, and each
                // element is forwarded to the callable in its original position.
                fn apply(&self, ($([<arg_ $T:lower>],)*): ($($T,)*)) -> OUT {
                    self($([<arg_ $T:lower>],)*)
                }
            }
        }
    };
}

// Tuple-argument `Func` impls for every arity from 1 through 12 parameters.
impl_tuple_arg_fn! { A }
impl_tuple_arg_fn! { A B }
impl_tuple_arg_fn! { A B C }
impl_tuple_arg_fn! { A B C D }
impl_tuple_arg_fn! { A B C D E }
impl_tuple_arg_fn! { A B C D E F }
impl_tuple_arg_fn! { A B C D E F G }
impl_tuple_arg_fn! { A B C D E F G H }
impl_tuple_arg_fn! { A B C D E F G H I }
impl_tuple_arg_fn! { A B C D E F G H I J }
impl_tuple_arg_fn! { A B C D E F G H I J K }
impl_tuple_arg_fn! { A B C D E F G H I J K L }

/// Implements `Func<[I; S], OUT>` for any callable taking exactly `S` parameters of
/// the same type `I`, so `f.apply([x, y])` forwards to `f(x, y)`.
///
/// Invocation syntax: `impl_array_arg_fn!(S; 0 1 ... S-1; I I ... I);`
/// - `$S`: the array length.
/// - `$N`: the element indices, pasted into binding names `i0`, `i1`, ...
/// - `$T`: one parameter type per element (always `I`), forming the `Fn` signature.
macro_rules! impl_array_arg_fn {
    // `$N` is a `literal`, not an `expr`. An `expr` fragment may only be followed by
    // `=>`, `,` or `;`, so an unseparated `$($N:expr)*` repetition is rejected by
    // macro_rules. `literal` has no follow-set restriction and, unlike `tt`, cannot
    // consume the `;` that ends the index list.
    ($S:expr; $($N:literal)*; $($T:tt)*) => {
        paste::paste! {
            impl<I, OUT, FUNC> Func<[I; $S], OUT> for FUNC
            where
                I: Differentiable,
                OUT: Differentiable,
                FUNC: Fn($($T,)*) -> OUT,
            {
                fn apply(&self, input: [I; $S]) -> OUT {
                    // Exhaustive array pattern (no `..` rest): if the index list does
                    // not have exactly `$S` entries, this fails to compile instead of
                    // silently dropping trailing elements.
                    let [$([<i $N>],)*] = input;
                    self($([<i $N>],)*)
                }
            }
        }
    };
}

// Array-argument `Func` impls for array lengths 1 through 12. Each invocation gives
// the length, the element indices, and one `I` per parameter of the callable.
impl_array_arg_fn! { 1; 0; I }
impl_array_arg_fn! { 2; 0 1; I I }
impl_array_arg_fn! { 3; 0 1 2; I I I }
impl_array_arg_fn! { 4; 0 1 2 3; I I I I }
impl_array_arg_fn! { 5; 0 1 2 3 4; I I I I I }
impl_array_arg_fn! { 6; 0 1 2 3 4 5; I I I I I I }
impl_array_arg_fn! { 7; 0 1 2 3 4 5 6; I I I I I I I }
impl_array_arg_fn! { 8; 0 1 2 3 4 5 6 7; I I I I I I I I }
impl_array_arg_fn! { 9; 0 1 2 3 4 5 6 7 8; I I I I I I I I I }
impl_array_arg_fn! { 10; 0 1 2 3 4 5 6 7 8 9; I I I I I I I I I I }
impl_array_arg_fn! { 11; 0 1 2 3 4 5 6 7 8 9 10; I I I I I I I I I I I }
impl_array_arg_fn! { 12; 0 1 2 3 4 5 6 7 8 9 10 11; I I I I I I I I I I I I }