1extern crate proc_macro;
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::ToTokens;
5use syn::{spanned::Spanned, DeriveInput, Ident, Path, Type};
6
/// Options read from the container-level `#[module(...)]` attribute of a type deriving `Module`.
///
/// The attribute is extracted (removed) from the derive input by `module`. Because of the
/// container-level `#[deluxe(default)]`, keys the user omits take their values from
/// `ContainerOpts::default`.
#[derive(Debug, deluxe::ExtractAttributes)]
#[deluxe(attributes(module))]
#[deluxe(default)]
struct ContainerOpts {
    /// `input = Type`: the generated `Module::Input` type.
    /// `None` makes the derive fall back to `<crate>::Tensor`.
    #[deluxe(rename = input)]
    input_ty: Option<Type>,

    /// `output = Type`: the generated `Module::Output` type.
    /// `None` makes the derive fall back to `<crate>::Tensor`.
    #[deluxe(rename = output)]
    output_ty: Option<Type>,

    /// `crate = path`: path to the runtime crate that generated code refers to.
    #[deluxe(rename = crate)]
    crate_root: Path,

    /// `trainable = bool`: when `false`, the derive emits a non-trainable module with empty
    /// parameter methods, exactly as it does for unit structs.
    trainable: bool,
}
22impl Default for ContainerOpts {
23 fn default() -> Self {
24 Self {
25 input_ty: None,
26 output_ty: None,
27 crate_root: syn::parse_quote!(rai),
28 trainable: true,
29 }
30 }
31}
32
/// Options parsed from a field's `#[param(...)]` attribute.
///
/// The options borrow the field they were parsed from, so the derive can read the field's
/// identifier when generating parameter-handling code.
#[derive(Debug, deluxe::ParseAttributes)]
#[deluxe(attributes(param))]
struct FieldOpts<'t> {
    /// The field these options belong to (supplied via `#[deluxe(container)]`, not parsed
    /// from the attribute itself).
    #[deluxe(container)]
    field: &'t syn::Field,
    /// `rename = "name"`: key used by the name-based parameter methods instead of the field name.
    #[deluxe(default)]
    rename: Option<String>,
    /// `skip`: leave this field out of parameter gathering/updating.
    #[deluxe(default)]
    skip: bool,
}
43
44#[proc_macro_derive(Module, attributes(module, param))]
45pub fn module(item: TokenStream) -> TokenStream {
46 let mut input: DeriveInput = syn::parse(item).expect("syn::parse ok");
47
48 let errors = deluxe::Errors::new();
49 let ContainerOpts {
50 input_ty,
51 output_ty,
52 crate_root,
53 trainable,
54 } = deluxe::extract_attributes_optional(&mut input, &errors);
55
56 let mut field_opts: Vec<FieldOpts> = Vec::new();
57 let mut is_unit_struct = false;
58 if let syn::Data::Struct(s) = &mut input.data {
59 match &mut s.fields {
60 syn::Fields::Named(fields) => {
61 for field in fields.named.iter_mut() {
62 match deluxe::parse_attributes(field) {
63 Ok(f_opts) => field_opts.push(f_opts),
64 Err(e) => errors.push_syn(e),
65 }
66 }
67 }
68 syn::Fields::Unit => is_unit_struct = true,
69 syn::Fields::Unnamed(_) => errors.push(Span::call_site(), "tuple is not supported"),
70 }
71 }
72 if !errors.is_empty() {
73 return errors.into_token_stream().into();
74 }
75
76 let receiver_name = &input.ident;
77 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
78
79 let input_ty = input_ty.unwrap_or_else(|| {
80 syn::parse_quote! {
81 ::#crate_root::Tensor
82 }
83 });
84 let output_ty = output_ty.unwrap_or_else(|| {
85 syn::parse_quote! {
86 ::#crate_root::Tensor
87 }
88 });
89
90 let call_fwd = match &input_ty {
91 Type::Path(_) | Type::Array(_) => {
92 quote::quote! {
93 self.fwd(input)
94 }
95 }
96 Type::Tuple(tuple) => {
97 let args: Vec<_> = tuple
98 .elems
99 .iter()
100 .enumerate()
101 .map(|(i, t)| {
102 let arg = Ident::new(&format!("a{i}"), t.span());
103 quote::quote! {
104 #arg
105 }
106 })
107 .collect();
108
109 quote::quote! {
110 let (#(#args,)*) = input;
111 self.fwd(#(::#crate_root::nn::ToApplyArg::to_arg(#args),)*)
112 }
113 }
114 _ => panic!("unsupported module input type"),
115 };
116
117 let module_impls = if is_unit_struct || !trainable {
118 quote::quote! {
119 impl #impl_generics ::#crate_root::nn::Module for #receiver_name #type_generics #where_clause {
120 type Input = #input_ty;
121 type Output = #output_ty;
122
123 #[inline]
124 fn forward(&self, input: &Self::Input) -> Self::Output {
125 #call_fwd
126 }
127 fn gather_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {}
128 fn update_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {}
129 fn gather_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {}
130 fn update_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {}
131 }
132
133 impl #impl_generics ::#crate_root::ValueSpec for #receiver_name #type_generics #where_clause {
134 type Kind = ::#crate_root::ty_kind::Module;
135 type Tensors = ();
136 type Gradient = ();
137 }
138
139 impl #impl_generics ::#crate_root::nn::NonTrainableModule for #receiver_name #type_generics #where_clause {}
140 }
141 } else {
142 let update_params: Vec<_> = field_opts
143 .iter()
144 .filter(|f| !f.skip)
145 .map(|f| {
146 let field_name = f.field.ident.as_ref().unwrap();
147 quote::quote! {
148 ::#crate_root::nn::WithParams::update_by_id(&self.#field_name, params);
149 }
150 })
151 .collect();
152
153 let gather_params: Vec<_> = field_opts
154 .iter()
155 .filter(|f| !f.skip)
156 .map(|f| {
157 let field_name = f.field.ident.as_ref().unwrap();
158 quote::quote! {
159 ::#crate_root::nn::WithParams::gather_by_id(&self.#field_name, params);
160 }
161 })
162 .collect();
163
164 let update_named_params: Vec<_> = field_opts
165 .iter()
166 .filter(|f| !f.skip)
167 .map(|f| {
168 let field_name = f.field.ident.as_ref().unwrap();
169 let f_name = field_name.to_string();
170 let param_name = f.rename.as_ref().unwrap_or(&f_name);
171 quote::quote! {
172 ::#crate_root::nn::WithParams::update_by_name(&self.#field_name, params, prefix, #param_name);
173 }
174 })
175 .collect();
176
177 let gather_named_params: Vec<_> = field_opts
178 .iter()
179 .filter(|f| !f.skip)
180 .map(|f| {
181 let field_name = f.field.ident.as_ref().unwrap();
182 let f_name = field_name.to_string();
183 let param_name = f.rename.as_ref().unwrap_or(&f_name);
184 quote::quote! {
185 ::#crate_root::nn::WithParams::gather_by_name(&self.#field_name, params, prefix, #param_name);
186 }
187 })
188 .collect();
189
190 quote::quote! {
191 impl #impl_generics ::#crate_root::nn::Module for #receiver_name #type_generics #where_clause {
192 type Input = #input_ty;
193 type Output = #output_ty;
194
195 #[inline]
196 fn forward(&self, input: &Self::Input) -> Self::Output {
197 #call_fwd
198 }
199
200 fn gather_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {
201 #(#gather_params)*
202 }
203
204 fn update_params(&self, params: &mut std::collections::HashMap<usize, ::#crate_root::Tensor>) {
205 #(#update_params)*
206 }
207
208 fn gather_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {
209 #(#gather_named_params)*
210 }
211
212 fn update_named_params(&self, prefix: &str, params: &mut std::collections::HashMap<String, ::#crate_root::Tensor>) {
213 #(#update_named_params)*
214 }
215 }
216
217 impl #impl_generics ::#crate_root::ValueSpec for #receiver_name #type_generics #where_clause {
218 type Kind = ::#crate_root::ty_kind::Module;
219 type Tensors = std::collections::HashMap<usize, Tensor>;
220 type Gradient = std::collections::HashMap<usize, Tensor>;
221 }
222
223 impl #impl_generics ::#crate_root::nn::TrainableModule for #receiver_name #type_generics #where_clause {}
224 }
225 };
226
227 module_impls.into()
228}