extern crate proc_macro;
use proc_macro2::{TokenStream, TokenTree};
use proc_macro_error::{proc_macro_error, emit_error};
use quote::{quote, ToTokens};
use syn::{
    parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Field, Fields,
    Ident,
};

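/// Attribute macro that rewrites the annotated struct into a `<'a, T>`-generic
/// network struct, derives `NeuralNetwork` for it and adds a `with_device`
/// constructor. While rewriting, it checks that the output size of every
/// `Linear` layer matches the input size of the next `Linear` layer and emits
/// a compile-time error on a mismatch.
///
/// Minimal usage sketch. The concrete layer types (`Linear`, `ReLU`) are
/// assumed to be provided by the `gradients` crate and are not verified here:
///
/// ```ignore
/// #[network]
/// struct MyNet {
///     lin1: Linear<784, 128>,
///     relu1: ReLU,
///     lin2: Linear<128, 10>,
/// }
///
/// // let net = MyNet::with_device(&device);
/// ```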
#[proc_macro_attribute]
#[proc_macro_error]
pub fn network(
    _attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let input = parse_macro_input!(item as DeriveInput);
    let name = input.ident;

    let fields = match input.data {
        Data::Struct(data) => match data.fields {
            Fields::Named(fields) => fields.named,
            _ => panic!("The network attribute can only be applied to structs with named fields."),
        },
        _ => panic!("The network attribute can only be applied to structs."),
    };

    proc_macro::TokenStream::from(add_lifetimes_derive_net(name, fields))
}

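// Builds the expansion for `#[network]`: every field type gets `<'a, T>`
// injected (a `Linear` field additionally keeps its two size literals as
// const generics), adjacent `Linear` layers are checked for matching
// output/input sizes, and a `with_device` constructor is emitted that
// initializes each field via `WithDevice::with_device`.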
fn add_lifetimes_derive_net(name: Ident, fields: Punctuated<Field, Comma>) -> TokenStream {
    // Output size and field ident of the previously seen `Linear` layer,
    // used to check it against the input size of the next `Linear` layer.
    let mut prev_out_size_info: (Option<TokenStream>, Option<Ident>) = (None, None);

    let fields_with_lifetimes = fields
        .iter()
        .map(|f| {
            // Counts the size literals of a `Linear` field:
            // 0 = input size, 1 = output size.
            let mut lit_index = 0;

            let name = &f.ident;
            let t = &f.ty;
            let type_token = t.into_token_stream();

            // `Linear` layers are detected by name and keep their size literals.
            if type_token.to_string().starts_with("Linear") {
                let mut in_out_size = TokenStream::new();
                for token in type_token {
                    if let TokenTree::Literal(lit) = &token {
                        let lit_tokens = lit.to_token_stream();

                        in_out_size.extend(lit_tokens.clone());

                        if lit_index == 1 {
                            prev_out_size_info = (Some(lit_tokens.clone()), name.clone());
                        }

                        if let Some(prev_out_size) = &prev_out_size_info.0 {
                            // Input size of this layer vs. output size of the previous one.
                            if lit_index == 0
                                && prev_out_size.to_string() != lit_tokens.to_string()
                            {
                                emit_error! { lit_tokens,
                                    format!(
                                        "The output and input size of {prev_ident:?} (output size: {prev_out}) and {ident:?} (input size: {input}) do not match.",
                                        prev_ident = prev_out_size_info.1.as_ref().unwrap().to_string(),
                                        prev_out = prev_out_size.to_string(),
                                        ident = name.as_ref().unwrap().to_string(),
                                        input = lit_tokens.to_string()
                                    );
                                    note = format!(
                                        "The input size of {ident:?} must be equal to the output size of {prev_ident:?}.",
                                        ident = name.as_ref().unwrap().to_string(),
                                        prev_ident = prev_out_size_info.1.as_ref().unwrap().to_string(),
                                    );
                                    help = format!(
                                        "Set the input size of {ident:?} to {prev_out}.",
                                        ident = name.as_ref().unwrap().to_string(),
                                        prev_out = prev_out_size.to_string(),
                                    );
                                }
                            }
                        }
                        lit_index += 1;
                    }

                    // Keep the comma that separates the two size literals.
                    if let TokenTree::Punct(pun) = token {
                        if pun.as_char() == ',' {
                            in_out_size.extend(pun.to_token_stream());
                        }
                    }
                }

                quote! {#name: Linear<'a, T, #in_out_size>,}
            } else {
                quote!(#name: #t<'a, T>,)
            }
        })
        .collect::<TokenStream>();

    let with_device_chain = fields
        .iter()
        .map(|f| {
            let name = &f.ident;

            quote!(#name: WithDevice::with_device(device),)
        })
        .collect::<TokenStream>();

    quote! {
        use gradients::{NeuralNetwork, Alloc, WithDevice, number::Float, GraphReturn};

        #[derive(NeuralNetwork)]
        struct #name<'a, T> {
            #fields_with_lifetimes
        }

        impl<'a, T: Float> #name<'a, T> {
            pub fn with_device<D: Alloc<T> + GraphReturn>(device: &'a D) -> Self {
                Self { #with_device_chain }
            }
        }
    }
}

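/// Derives empty `GetParam` and `WithDevice` implementations plus an inherent
/// `with_device` constructor for layers without trainable parameters
/// (activation functions and the like). The annotated type must be generic
/// over `<'a, T>`, implement `Default`, and have `GetParam`/`WithDevice` in
/// scope.
///
/// Illustrative sketch only; the `PhantomData` field is an assumption, not
/// taken from the real layer definitions:
///
/// ```ignore
/// #[derive(NoParams, Default)]
/// struct MyActivation<'a, T> {
///     _marker: core::marker::PhantomData<&'a T>,
/// }
///
/// // `MyActivation::with_device(&device)` simply returns `Self::default()`.
/// ```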
#[proc_macro_derive(NoParams)]
pub fn derive_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = input.ident;
    proc_macro::TokenStream::from(impl_params(name))
}

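// Expansion for `NoParams`: the trait impls are left empty, so any behaviour
// comes from the traits' default methods, and `with_device` ignores the device
// and falls back to `Default` (which the annotated type therefore has to
// implement).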
fn impl_params(name: Ident) -> TokenStream {
    quote! {
        impl<'a, T> GetParam<'a, T> for #name<'a, T> {}
        impl<'a, T> WithDevice<'a, T> for #name<'a, T> {}
        impl<'a, T> #name<'a, T> {
            pub fn with_device<D>(_dev: &D) -> #name<'a, T> {
                Self::default()
            }
        }
    }
}

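/// Derives `Default` and `NeuralNetwork` for a struct of layers: `forward`
/// pipes the input through the fields in declaration order, `backward` pipes
/// the gradient through them in reverse order, and `params` collects the
/// parameters of every field that has any.
///
/// Hedged example of a struct this derive expects; the layer types are
/// assumptions, and `#[network]` normally generates such a struct for you:
///
/// ```ignore
/// #[derive(NeuralNetwork)]
/// struct MyNet<'a, T> {
///     lin1: Linear<'a, T, 784, 128>,
///     relu1: ReLU<'a, T>,
///     lin2: Linear<'a, T, 128, 10>,
/// }
/// ```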
#[proc_macro_derive(NeuralNetwork)]
pub fn derive_neural_network(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);

    let name = input.ident;
    let fields = match input.data {
        Data::Struct(data) => match data.fields {
            Fields::Named(fields) => fields.named,
            _ => panic!("NeuralNetwork can only be derived for structs with named fields."),
        },
        _ => panic!("NeuralNetwork can only be derived for structs."),
    };

    proc_macro::TokenStream::from(impl_neural_network(name, fields))
}

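// Builds the `Default` and `NeuralNetwork` impls. For fields `lin1, relu1, lin2`
// the folds below expand to nested call chains along the lines of:
//
//   forward:  self.lin2.forward(&self.relu1.forward(&self.lin1.forward(&&inputs)))
//   backward: self.lin1.backward(&self.relu1.backward(&self.lin2.backward(&&grad)))
//
// (the doubled `&` on the innermost argument is resolved by deref coercion).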
fn impl_neural_network(name: Ident, fields: Punctuated<Field, Comma>) -> TokenStream {
    // Nest the forward calls: the first field receives `inputs`,
    // every following field receives the previous field's output.
    let forward_chain = fields.iter().fold(quote!(&inputs), |acc, f| {
        let name = &f.ident;
        quote!(self.#name.forward(&#acc))
    });

    let default_chain = fields
        .iter()
        .map(|f| {
            let name = &f.ident;
            quote!(#name: Default::default(),)
        })
        .collect::<TokenStream>();

    // The backward pass runs through the fields in reverse order.
    let backward_chain = fields.iter().rev().fold(quote!(&grad), |acc, f| {
        let name = &f.ident;
        quote!(self.#name.backward(&#acc))
    });

    let vec = quote! {let mut vec = Vec::new();};

    // Collect the parameters of every field that actually has any.
    let params = fields
        .iter()
        .map(|f| {
            let name = &f.ident;
            quote!(
                if let Some(params) = self.#name.params() {
                    vec.push(params);
                }
            )
        })
        .collect::<TokenStream>();
    let return_vec = quote! {vec};

    quote! {
        use gradients::{GetParam, Param, Matrix};

        impl<'a, T> Default for #name<'a, T> {
            fn default() -> Self {
                Self { #default_chain }
            }
        }

        impl<'a, T: gradients::number::Float + gradients::CDatatype + gradients::GenericBlas + gradients::CudaTranspose>
            NeuralNetwork<'a, T> for #name<'a, T>
        {
            fn forward(&mut self, inputs: &Matrix<'a, T>) -> Matrix<'a, T> {
                #forward_chain
            }

            fn backward(&mut self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
                #backward_chain
            }

            fn params(&mut self) -> Vec<Param<'a, T>> {
                #vec
                #params
                #return_vec
            }
        }
    }
}