zenu_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, parse_quote,
7    token::Comma,
8    Attribute, Data, DeriveInput, Field, Ident, Token, Type,
9};
10
11#[proc_macro_derive(Parameters, attributes(zenu, parameters))]
12pub fn zenu_derive_parameters(input: TokenStream) -> TokenStream {
13    let input = parse_macro_input!(input as DeriveInput);
14
15    let parameters_impl = impl_parameters(&input);
16
17    TokenStream::from(parameters_impl)
18}
19
20fn impl_parameters(input: &DeriveInput) -> TokenStream2 {
21    let name = &input.ident;
22    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
23
24    let fields = match &input.data {
25        Data::Struct(data) => &data.fields,
26        _ => panic!("ZenuModel only supports structs"),
27    };
28
29    let fields = fields.iter().filter(|field| !has_zenu_skip_attr(field));
30
31    let weights_code = fields.clone().map(|field| {
32        let field_name = &field.ident;
33        quote! {
34            for (name, variable) in &self.#field_name.weights() {
35                let name = format!("{}.{}", stringify!(#field_name), name);
36                params.insert(name.clone(), variable.clone());
37            }
38        }
39    });
40
41    let biases_code = fields.clone().map(|field| {
42        let field_name = &field.ident;
43        quote! {
44            for (name, variable) in &self.#field_name.biases() {
45                let name = format!("{}.{}", stringify!(#field_name), name);
46                params.insert(name.clone(), variable.clone());
47            }
48        }
49    });
50
51    let (num_type, device_type) = parse_parameters_attr(&input.attrs);
52
53    quote!(
54        impl #impl_generics ::zenu::layer::Parameters #ty_generics for #name #ty_generics #where_clause {
55            fn weights(&self) -> std::collections::HashMap<String, ::zenu::autograd::Variable<#num_type, #device_type>> {
56                let mut params = std::collections::HashMap::new();
57                #(
58                    #weights_code
59                )*
60                params
61            }
62
63            fn biases(&self) -> std::collections::HashMap<String, ::zenu::autograd::Variable<#num_type, #device_type>> {
64                let mut params = std::collections::HashMap::new();
65                #(
66                    #biases_code
67                )*
68                params
69            }
70        }
71    )
72}
73
74fn has_zenu_skip_attr(field: &Field) -> bool {
75    field
76        .attrs
77        .iter()
78        .any(|attr| attr.path.is_ident("zenu") && attr.tokens.to_string().contains("skip"))
79}
80
81fn parse_parameters_attr(attrs: &[Attribute]) -> (Type, Type) {
82    let mut num_type: Type = parse_quote!(f32);
83    let mut device_type: Type = parse_quote!(Cpu);
84
85    for attr in attrs {
86        if attr.path.is_ident("parameters") {
87            let args = syn::parse2::<ParametersArgs>(attr.tokens.clone())
88                .expect("Failed to parse parameters attribute");
89            if let Some(ty) = args.num {
90                num_type = ty;
91            }
92            if let Some(ty) = args.device {
93                device_type = ty;
94            }
95        }
96    }
97
98    (num_type, device_type)
99}
100
101struct ParametersArgs {
102    num: Option<Type>,
103    device: Option<Type>,
104}
105
106impl Parse for ParametersArgs {
107    fn parse(input: ParseStream) -> syn::Result<Self> {
108        let content;
109        syn::parenthesized!(content in input);
110
111        let mut num = None;
112        let mut device = None;
113
114        while !content.is_empty() {
115            let ident: Ident = content.parse()?;
116            let _: Token![=] = content.parse()?;
117            let ty: Type = content.parse()?;
118
119            if ident == "num" {
120                num = Some(ty);
121            } else if ident == "device" {
122                device = Some(ty);
123                // } else {
124                //     return Err(syn::Error::new(
125                //         ident.span(),
126                //         "Expected 'num' or 'device' in parameters attribute",
127                //     ));
128            }
129
130            if content.peek(Comma) {
131                let _: Comma = content.parse()?;
132            } else {
133                break;
134            }
135        }
136
137        Ok(ParametersArgs { num, device })
138    }
139}