// gradients_derive/lib.rs

1extern crate proc_macro;
2use proc_macro2::{TokenStream, TokenTree};
3use proc_macro_error::{proc_macro_error, emit_error};
4use quote::{quote, ToTokens};
5use syn::{
6    parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Field, Fields,
7    Ident,
8};
9
10#[proc_macro_attribute]
11#[proc_macro_error]
12pub fn network(
13    _attr: proc_macro::TokenStream,
14    item: proc_macro::TokenStream,
15) -> proc_macro::TokenStream {
16    let input = parse_macro_input!(item as DeriveInput);
17    let name = input.ident;
18
19    let fields = match input.data {
20        Data::Struct(data) => match data.fields {
21            Fields::Named(fields) => fields.named,
22            _ => panic!("The network attribute can be applied on structs only."),
23        },
24        _ => panic!("The network attribute can be applied on structs only."),
25    };
26
27    proc_macro::TokenStream::from(add_lifetimes_derive_net(name, fields))
28}
29
/// Builds the token stream emitted by `#[network]`: the original struct with
/// `'a, T` generics injected into every field type, wrapped in
/// `#[derive(NeuralNetwork)]`, plus an inherent `with_device` constructor.
///
/// While walking the fields it also validates that the output size of each
/// `Linear<IN, OUT>` layer textually matches the input size of the next
/// `Linear` layer, reporting mismatches via `emit_error!`.
fn add_lifetimes_derive_net(name: Ident, fields: Punctuated<Field, Comma>) -> TokenStream {
    // (output-size literal, field ident) of the most recent Linear layer.
    // Lives outside the per-field closure so non-Linear layers (activations)
    // between two Linear layers do not reset the check.
    let mut prev_out_size_info = (None, None);

    let fields_with_lifetimes = fields
        .iter()
        .map(|f| {
            // Counts literals seen inside this field's type:
            // 0 => input size, 1 => output size of `Linear<IN, OUT>`.
            let mut in_or_out_size = 0;

            let name = &f.ident;
            let t = &f.ty;
            let type_token = t.into_token_stream();

            // NOTE(review): prefix match — any type whose printed form starts
            // with "Linear" is treated as a Linear layer; confirm no other
            // layer type shares that prefix.
            if type_token.to_string().starts_with("Linear") {
                // Re-collected `IN , OUT` tokens, re-emitted below inside
                // `Linear<'a, T, IN, OUT>`.
                let mut in_out_size = TokenStream::new();
                for token in type_token {
                    if let TokenTree::Literal(lit) = &token {
                        let lit_tokens = lit.to_token_stream();

                        in_out_size.extend(lit_tokens.clone());
                        
                        // Second literal: remember this layer's output size
                        // (and its field name) for the next layer's check.
                        if in_or_out_size == 1 {
                            prev_out_size_info = (Some(lit_tokens.clone()), name.clone());
                        }
                        
                        // comparing the output size with the next input size of the linear layer
                        if let Some(prev_out_size) = &prev_out_size_info.0 {
                            // First literal of this layer (its input size) must
                            // equal the previous Linear layer's output size;
                            // literals are compared by their string form.
                            if in_or_out_size == 0 && prev_out_size.to_string() != lit_tokens.to_string() {
                                emit_error! { lit_tokens,
                                    format!("The output and input size of {prev_ident:?} (output size: {prev_out}) and {ident:?} (input size: {input}) do not match.",
                                 
                                        prev_ident=prev_out_size_info.1.as_ref().unwrap().to_string(), 
                                        prev_out=prev_out_size.to_string(), 
                                        ident=name.as_ref().unwrap().to_string(), 
                                        input=lit_tokens.to_string()
                                    );                              
                                    note=format!("The input size of {ident:?} must be equal to the output size of {prev_ident:?}.",
                                            ident=name.as_ref().unwrap().to_string(), 
                                            prev_ident=prev_out_size_info.1.as_ref().unwrap().to_string(), 
                                    );
                                    help=format!("Set the input size of {ident:?} to {prev_out}.",
                                        ident=name.as_ref().unwrap().to_string(),
                                        prev_out=prev_out_size.to_string(), 
                                    );
                                }
                            }
                        }
                        in_or_out_size += 1;
                    }

                    // Keep only the comma separating IN and OUT; other
                    // punctuation (`<`, `>`, …) is dropped.
                    if let TokenTree::Punct(pun) = token {
                        if pun.as_char() != ',' {
                            continue;
                        }
                        in_out_size.extend(pun.to_token_stream());
                    }
                }

                // Linear layers get `'a, T` prepended to their size parameters.
                quote! {#name: Linear<'a, T, #in_out_size>,}
            } else {
                // NOTE(review): every non-Linear layer type is assumed to take
                // exactly `<'a, T>` generics — confirm against the layer types.
                quote!(#name: #t<'a, T>,)
            }
        })
        .collect::<TokenStream>();

    // `field: WithDevice::with_device(device),` initializer for every field.
    let with_device_chain = fields
        .iter()
        .map(|f| {
            let name = &f.ident;

            quote!(#name: WithDevice::with_device(device),)
        })
        .collect::<TokenStream>();

    // NOTE(review): the emitted `use` lands at the macro call site and may
    // collide with the user's own imports — confirm this is acceptable.
    quote! {
        use gradients::{NeuralNetwork, Alloc, WithDevice, number::Float, GraphReturn};
        #[derive(NeuralNetwork)]
        struct #name<'a, T> {
            #fields_with_lifetimes
        }

        impl<'a, T: Float> #name<'a, T> {
            pub fn with_device<D: Alloc<T>+GraphReturn>(device: &'a D) -> Self {
                Self { #with_device_chain }
            }
        }
    }
}
117
118#[proc_macro_derive(NoParams)]
119pub fn derive_params(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
120    let input = parse_macro_input!(input as DeriveInput);
121
122    let name = input.ident;
123    proc_macro::TokenStream::from(impl_params(name))
124}
125
126fn impl_params(name: Ident) -> TokenStream {
127    quote! {
128        impl<'a, T> GetParam<'a,T> for #name<'a, T> {}
129        impl<'a, T> WithDevice<'a, T> for #name<'a, T> {}
130        impl<'a, T> #name<'a, T> {
131            pub fn with_device<D>(_dev: &D) -> #name<'a, T> {
132                Self::default()
133            }
134        }
135    }
136}
137
138#[proc_macro_derive(NeuralNetwork)]
139pub fn derive_neural_network(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
140    let input = parse_macro_input!(input as DeriveInput);
141
142    let name = input.ident;
143    let fields = match input.data {
144        Data::Struct(data) => match data.fields {
145            Fields::Named(fields) => fields.named,
146            _ => panic!("Structs only"),
147        },
148        _ => panic!("Structs only"),
149    };
150
151    proc_macro::TokenStream::from(impl_neural_network(name, fields))
152}
153
/// Generates the `Default` and `NeuralNetwork` impls behind
/// `#[derive(NeuralNetwork)]`, chaining every field (layer) of the struct.
fn impl_neural_network(name: Ident, fields: Punctuated<Field, Comma>) -> TokenStream {
    // forward: layers applied first-to-last, each fed the previous output, e.g.
    // `self.l2.forward(&self.l1.forward(&&inputs))`.
    // NOTE(review): the fold stacks an extra `&` per layer — presumably
    // collapsed by deref coercion at the call site; confirm it compiles for
    // deep networks.
    let forward_chain = fields.iter().fold(quote!(&inputs), |acc, f| {
        let name = &f.ident;
        quote!(self.#name.forward(&#acc))
    });

    // `field: Default::default(),` initializer for every field.
    let default_chain = fields
        .iter()
        .map(|f| {
            let name = &f.ident;
            quote!(#name: Default::default(),)
        })
        .collect::<TokenStream>();

    // backward: same chaining, but over the fields in REVERSE order, starting
    // from the incoming gradient.
    let backward_chain = fields.iter().rev().fold(quote!(&grad), |acc, f| {
        let name = &f.ident;
        quote!(self.#name.backward(&#acc))
    });

    // Pieces of the generated `params` body: declare the vec, push each
    // layer's params (if any), return the vec.
    let vec = quote! {let mut vec = Vec::new();};

    let params = fields
        .iter()
        .map(|f| {
            let name = &f.ident;
            quote!(
               if let Some(params) = self.#name.params() {
                   vec.push(params);
               }
            )
        })
        .collect::<TokenStream>();
    let return_vec = quote! {vec};

    // NOTE(review): the emitted `use` is required — the generated method-call
    // syntax (`.forward()`, `.params()`) needs those traits in scope — but it
    // lands at the derive site and may collide with user imports.
    quote! {
        use gradients::{GetParam, Param, Matrix};


        impl<'a, T> Default for #name<'a, T> {
            fn default() -> Self {
                Self { #default_chain }
            }
        }
        impl<'a, T: gradients::number::Float+gradients::CDatatype+gradients::GenericBlas + gradients::CudaTranspose> NeuralNetwork<'a, T> for #name<'a, T> {
            fn forward(&mut self, inputs: &Matrix<'a, T>) -> Matrix<'a, T> {
                #forward_chain
            }

            fn backward(&mut self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
                #backward_chain
            }

            fn params(&mut self) -> Vec<Param<'a, T>> {
                #vec
                #params
                #return_vec
            }
        }
    }
}