// rai_core/transforms/fn_impls.rs

use crate::{nn::Module, ty_kind, Func};

// Blanket `Func` impl for every single-argument callable: `invoke` hands
// `input` to the callable unchanged and returns whatever it produces.
// NOTE(review): `ty_kind::Basic` presumably acts as a marker that keeps this
// impl distinct from the other blanket impls — confirm against `Func`.
impl<In, Out, Callable: Fn(In) -> Out> Func<ty_kind::Basic, In, Out> for Callable {
    fn invoke(&self, input: In) -> Out {
        let callable = self;
        callable(input)
    }
}

// Blanket `Func` impl for single-argument callables whose argument type
// implements `Module`: `invoke` forwards `input` to the callable as-is.
// NOTE(review): `ty_kind::Module` presumably acts as a marker that keeps this
// impl distinct from the `ty_kind::Basic` one — confirm against `Func`.
impl<In: Module, Out, Callable: Fn(In) -> Out> Func<ty_kind::Module, In, Out> for Callable {
    fn invoke(&self, input: In) -> Out {
        let callable = self;
        callable(input)
    }
}

// Generates a `Func` impl for callables that take one positional argument per
// listed type parameter. The impl accepts those arguments packed in a tuple;
// `invoke` unpacks the tuple and forwards its elements in order.
macro_rules! impl_tuple_arg_fn {
    ($($T:tt)*) => {
        // `OUT` and `FUNC` are spelled out in full so they cannot collide with
        // the single-letter type parameters passed in by the caller.
        impl<$($T,)* OUT, FUNC> Func<ty_kind::Tuple<($($T,)*)>, ($($T,)*), OUT> for FUNC
        where
            FUNC: Fn($($T,)*) -> OUT,
        {
            // Each tuple element is bound to a local named after its own type
            // parameter, hence the upper-case variable names.
            #[allow(non_snake_case)]
            fn invoke(&self, input: ($($T,)*)) -> OUT {
                let ($($T,)*) = input;
                self($($T,)*)
            }
        }
    };
}

// Instantiate the tuple-argument `Func` impls for arities 1 through 12,
// using the type parameters `A` through `L`.
impl_tuple_arg_fn!(A);
impl_tuple_arg_fn!(A B);
impl_tuple_arg_fn!(A B C);
impl_tuple_arg_fn!(A B C D);
impl_tuple_arg_fn!(A B C D E);
impl_tuple_arg_fn!(A B C D E F);
impl_tuple_arg_fn!(A B C D E F G);
impl_tuple_arg_fn!(A B C D E F G H);
impl_tuple_arg_fn!(A B C D E F G H I);
impl_tuple_arg_fn!(A B C D E F G H I J);
impl_tuple_arg_fn!(A B C D E F G H I J K);
impl_tuple_arg_fn!(A B C D E F G H I J K L);

// Generates a `Func` impl for callables whose arguments all share the single
// type `I`, accepting those arguments packed in a fixed-size array.
//
// Invocation shape: `LEN; idx idx ...; I I ...`
// - `$S`: the array length.
// - `$N`: one index per element, pasted into the binding names `i0`, `i1`, ...
// - `$T`: the element type, repeated once per callable argument.
macro_rules! impl_array_arg_fn {
    ($S:expr; $($N:expr)*; $($T:tt)*) => {
        paste::paste! {
            impl<I, OUT, FUNC> Func<ty_kind::Array<[I; $S]>, [I; $S], OUT> for FUNC
            where
                FUNC: Fn($($T,)*) -> OUT,
            {
                fn invoke(&self, input: [I; $S]) -> OUT {
                    // The pattern is exhaustive on purpose (no trailing `..`):
                    // if the index list and `$S` ever disagree, compilation
                    // fails instead of silently discarding array elements.
                    let [$([<i $N>],)*] = input;
                    self($([<i $N>],)*)
                }
            }
        }
    };
}

// Instantiate the array-argument `Func` impls for lengths 1 through 12.
// Each invocation passes, separated by `;`: the array length, the element
// indices, and one `I` per argument of the callable.
impl_array_arg_fn!(1; 0; I);
impl_array_arg_fn!(2; 0 1; I I);
impl_array_arg_fn!(3; 0 1 2; I I I);
impl_array_arg_fn!(4; 0 1 2 3; I I I I);
impl_array_arg_fn!(5; 0 1 2 3 4; I I I I I);
impl_array_arg_fn!(6; 0 1 2 3 4 5; I I I I I I);
impl_array_arg_fn!(7; 0 1 2 3 4 5 6; I I I I I I I);
impl_array_arg_fn!(8; 0 1 2 3 4 5 6 7; I I I I I I I I);
impl_array_arg_fn!(9; 0 1 2 3 4 5 6 7 8; I I I I I I I I I);
impl_array_arg_fn!(10; 0 1 2 3 4 5 6 7 8 9; I I I I I I I I I I);
impl_array_arg_fn!(11; 0 1 2 3 4 5 6 7 8 9 10; I I I I I I I I I I I);
impl_array_arg_fn!(12; 0 1 2 3 4 5 6 7 8 9 10 11; I I I I I I I I I I I I);