mininn_derive/
lib.rs

1use proc_macro::TokenStream;
2
3fn impl_msg_pack_format_trait(ast: syn::DeriveInput) -> TokenStream {
4    let ident = ast.ident;
5
6    quote::quote! {
7        impl MSGPackFormatting for #ident {
8            fn to_msgpack(&self) -> NNResult<Vec<u8>> {
9                Ok(rmp_serde::to_vec(&self)?)
10            }
11
12            fn from_msgpack(buff: &[u8]) -> NNResult<Box<Self>>
13            where
14                Self: Sized,
15            {
16                Ok(Box::new(rmp_serde::from_slice::<Self>(buff)?))
17            }
18        }
19    }
20    .into()
21}
22
23fn impl_layer_trait(ast: syn::DeriveInput) -> TokenStream {
24    let ident = ast.ident;
25
26    quote::quote! {
27        impl Layer for #ident {
28            fn layer_type(&self) -> &str {
29               stringify!(#ident)
30            }
31
32            fn as_any(&self) -> &dyn std::any::Any {
33                self
34            }
35        }
36
37        impl MSGPackFormatting for #ident {
38            fn to_msgpack(&self) -> NNResult<Vec<u8>> {
39                Ok(rmp_serde::to_vec(&self)?)
40            }
41
42            fn from_msgpack(buff: &[u8]) -> NNResult<Box<Self>>
43            where
44                Self: Sized,
45            {
46                Ok(Box::new(rmp_serde::from_slice::<Self>(buff)?))
47            }
48        }
49    }
50    .into()
51}
52
53fn impl_activation_trait(ast: syn::DeriveInput) -> TokenStream {
54    let ident = ast.ident;
55
56    quote::quote! {
57        impl ActivationFunction for #ident {}
58
59        impl NNUtil for #ident {
60            #[inline]
61            fn name(&self) -> &str {
62                stringify!(#ident)
63            }
64
65            #[inline]
66            fn from_name(name: &str) -> NNResult<Box<Self>>
67            where
68                Self: Sized,
69            {
70                Ok(Box::new(#ident))
71            }
72        }
73    }
74    .into()
75}
76
77fn impl_cost_trait(ast: syn::DeriveInput) -> TokenStream {
78    let ident = ast.ident;
79
80    quote::quote! {
81        impl CostFunction for #ident {}
82
83        impl NNUtil for #ident {
84            #[inline]
85            fn name(&self) -> &str {
86                stringify!(#ident)
87            }
88
89            #[inline]
90            fn from_name(name: &str) -> NNResult<Box<Self>>
91            where
92                Self: Sized,
93            {
94                Ok(Box::new(#ident))
95            }
96        }
97    }
98    .into()
99}
100
101#[proc_macro_derive(MSGPackFormatting)]
102pub fn msg_pack_format_derive_macro(item: TokenStream) -> TokenStream {
103    let ast = syn::parse(item).unwrap();
104
105    impl_msg_pack_format_trait(ast)
106}
107
108#[proc_macro_derive(Layer)]
109pub fn layer_derive_macro(item: TokenStream) -> TokenStream {
110    let ast = syn::parse(item).unwrap();
111
112    impl_layer_trait(ast)
113}
114
115#[proc_macro_derive(ActivationFunction)]
116pub fn activation_derive_macro(item: TokenStream) -> TokenStream {
117    let ast = syn::parse(item).unwrap();
118
119    impl_activation_trait(ast)
120}
121
122#[proc_macro_derive(CostFunction)]
123pub fn cost_derive_macro(item: TokenStream) -> TokenStream {
124    let ast = syn::parse(item).unwrap();
125
126    impl_cost_trait(ast)
127}