rai_derive/
lib.rs

1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::ToTokens;
5use syn::{spanned::Spanned, DeriveInput, Ident, Path, Type};
6
/// Options parsed from the container-level `#[module(...)]` attribute.
///
/// The attribute is extracted from the derive input via deluxe; any option that is not given
/// falls back to `ContainerOpts::default()`.
#[derive(Debug, deluxe::ExtractAttributes)]
#[deluxe(attributes(module))]
#[deluxe(default)]
struct ContainerOpts {
    /// `input = Type`: the generated `Module::Input` type (`None` means the crate's `Tensor`).
    #[deluxe(rename = input)]
    input_ty: Option<Type>,

    /// `output = Type`: the generated `Module::Output` type (`None` means the crate's `Tensor`).
    #[deluxe(rename = output)]
    output_ty: Option<Type>,

    /// `crate = path`: path of the `rai` crate used as the prefix of every generated path.
    #[deluxe(rename = crate)]
    crate_root: Path,

    /// `trainable = bool`: when true (and the struct is not a unit struct), parameter
    /// gathering/updating is forwarded to the fields and `TrainableModule` is implemented;
    /// otherwise `NonTrainableModule` is implemented.
    trainable: bool,
}
22impl Default for ContainerOpts {
23    fn default() -> Self {
24        Self {
25            input_ty: None,
26            output_ty: None,
27            crate_root: syn::parse_quote!(rai),
28            trainable: true,
29        }
30    }
31}
32
/// Options parsed from a field-level `#[param(...)]` attribute.
#[derive(Debug, deluxe::ParseAttributes)]
#[deluxe(attributes(param))]
struct FieldOpts<'t> {
    /// The field the attribute was parsed from (supplied by deluxe as the container).
    #[deluxe(container)]
    field: &'t syn::Field,
    /// `rename = "name"`: parameter name used by the named-parameter methods instead of the
    /// field's identifier.
    #[deluxe(default)]
    rename: Option<String>,
    /// `skip`: exclude this field from all generated parameter gathering/updating.
    #[deluxe(default)]
    skip: bool,
}
43
44#[proc_macro_derive(Module, attributes(module, param))]
45pub fn module(item: TokenStream) -> TokenStream {
46    let mut input: DeriveInput = syn::parse(item).expect("syn::parse ok");
47
48    let errors = deluxe::Errors::new();
49    let ContainerOpts {
50        input_ty,
51        output_ty,
52        crate_root,
53        trainable,
54    } = deluxe::extract_attributes_optional(&mut input, &errors);
55
56    let mut field_opts: Vec<FieldOpts> = Vec::new();
57    let mut is_unit_struct = false;
58    if let syn::Data::Struct(s) = &mut input.data {
59        match &mut s.fields {
60            syn::Fields::Named(fields) => {
61                for field in fields.named.iter_mut() {
62                    match deluxe::parse_attributes(field) {
63                        Ok(f_opts) => field_opts.push(f_opts),
64                        Err(e) => errors.push_syn(e),
65                    }
66                }
67            }
68            syn::Fields::Unit => is_unit_struct = true,
69            syn::Fields::Unnamed(_) => errors.push(Span::call_site(), "tuple is not supported"),
70        }
71    }
72    if !errors.is_empty() {
73        return errors.into_token_stream().into();
74    }
75
76    let receiver_name = &input.ident;
77    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
78
79    let input_ty = input_ty.unwrap_or_else(|| {
80        syn::parse_quote! {
81            ::#crate_root::Tensor
82        }
83    });
84    let output_ty = output_ty.unwrap_or_else(|| {
85        syn::parse_quote! {
86            ::#crate_root::Tensor
87        }
88    });
89
90    let call_fwd = match &input_ty {
91        Type::Path(_) | Type::Array(_) => {
92            quote::quote! {
93                self.fwd(input)
94            }
95        }
96        Type::Tuple(tuple) => {
97            let args: Vec<_> = tuple
98                .elems
99                .iter()
100                .enumerate()
101                .map(|(i, t)| {
102                    let arg = Ident::new(&format!("a{i}"), t.span());
103                    quote::quote! {
104                        #arg
105                    }
106                })
107                .collect();
108
109            quote::quote! {
110                let (#(#args,)*) = input;
111                self.fwd(#(::#crate_root::nn::ToApplyArg::to_arg(#args),)*)
112            }
113        }
114        _ => panic!("unsupported module input type"),
115    };
116
117    let module_impls = if is_unit_struct || !trainable {
118        quote::quote! {
119            impl #impl_generics ::#crate_root::nn::Module for #receiver_name #type_generics #where_clause {
120                type Input = #input_ty;
121                type Output = #output_ty;
122
123                #[inline]
124                fn forward(&self, input: &Self::Input) -> Self::Output {
125                    #call_fwd
126                }
127                fn gather_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {}
128                fn update_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {}
129                fn gather_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {}
130                fn update_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {}
131            }
132
133            impl #impl_generics ::#crate_root::ValueSpec for #receiver_name #type_generics #where_clause {
134                type Kind = ::#crate_root::ty_kind::Module;
135                type Tensors = ();
136                type Gradient = ();
137            }
138
139            impl #impl_generics ::#crate_root::nn::NonTrainableModule for #receiver_name #type_generics #where_clause {}
140        }
141    } else {
142        let update_params: Vec<_> = field_opts
143            .iter()
144            .filter(|f| !f.skip)
145            .map(|f| {
146                let field_name = f.field.ident.as_ref().unwrap();
147                quote::quote! {
148                    ::#crate_root::nn::WithParams::update_by_id(&self.#field_name, params);
149                }
150            })
151            .collect();
152
153        let gather_params: Vec<_> = field_opts
154            .iter()
155            .filter(|f| !f.skip)
156            .map(|f| {
157                let field_name = f.field.ident.as_ref().unwrap();
158                quote::quote! {
159                    ::#crate_root::nn::WithParams::gather_by_id(&self.#field_name, params);
160                }
161            })
162            .collect();
163
164        let update_named_params: Vec<_> = field_opts
165            .iter()
166            .filter(|f| !f.skip)
167            .map(|f| {
168                let field_name = f.field.ident.as_ref().unwrap();
169                let f_name = field_name.to_string();
170                let param_name = f.rename.as_ref().unwrap_or(&f_name);
171                quote::quote! {
172                    ::#crate_root::nn::WithParams::update_by_name(&self.#field_name, params, prefix, #param_name);
173                }
174            })
175            .collect();
176
177        let gather_named_params: Vec<_> = field_opts
178            .iter()
179            .filter(|f| !f.skip)
180            .map(|f| {
181                let field_name = f.field.ident.as_ref().unwrap();
182                let f_name = field_name.to_string();
183                let param_name = f.rename.as_ref().unwrap_or(&f_name);
184                quote::quote! {
185                    ::#crate_root::nn::WithParams::gather_by_name(&self.#field_name, params, prefix, #param_name);
186                }
187            })
188            .collect();
189
190        quote::quote! {
191            impl #impl_generics ::#crate_root::nn::Module for #receiver_name #type_generics #where_clause {
192                type Input = #input_ty;
193                type Output = #output_ty;
194
195                #[inline]
196                fn forward(&self, input: &Self::Input) -> Self::Output {
197                    #call_fwd
198                }
199
200                fn gather_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {
201                    #(#gather_params)*
202                }
203
204                fn update_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {
205                    #(#update_params)*
206                }
207
208                fn gather_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {
209                    #(#gather_named_params)*
210                }
211
212                fn update_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {
213                    #(#update_named_params)*
214                }
215            }
216
217            impl #impl_generics ::#crate_root::ValueSpec for #receiver_name #type_generics #where_clause {
218                type Kind = ::#crate_root::ty_kind::Module;
219                type Tensors = std::collections::HashMap<usize, Tensor>;
220                type Gradient = std::collections::HashMap<usize, Tensor>;
221            }
222
223            impl #impl_generics ::#crate_root::nn::TrainableModule for #receiver_name #type_generics #where_clause {}
224        }
225    };
226
227    module_impls.into()
228}