burn_central_macros/
lib.rs1use proc_macro::TokenStream;
23use proc_macro2::Ident;
24use quote::quote;
25
26use strum::Display;
27use syn::{Error, ItemFn, Meta, Path, parse_macro_input, punctuated::Punctuated, spanned::Spanned};
28
29#[derive(Eq, Hash, PartialEq, Display)]
30#[strum(serialize_all = "PascalCase")]
31enum ProcedureType {
32 Training,
33 Inference,
34}
35
36impl TryFrom<Path> for ProcedureType {
37 type Error = Error;
38
39 fn try_from(path: Path) -> Result<Self, Self::Error> {
40 match path.get_ident() {
41 Some(ident) => match ident.to_string().as_str() {
42 "training" => Ok(Self::Training),
43 "inference" => Ok(Self::Inference),
44 _ => Err(Error::new_spanned(
45 path,
46 "Expected `training` or `inference`",
47 )),
48 },
49 None => Err(Error::new_spanned(
50 path,
51 "Expected `training` or `inference`",
52 )),
53 }
54 }
55}
56
57fn compile_errors(errors: Vec<Error>) -> proc_macro2::TokenStream {
58 errors
59 .into_iter()
60 .map(|err| err.to_compile_error())
61 .collect()
62}
63
64fn generate_flag_register_stream(
65 item: &ItemFn,
66 builder_fn_ident: &Ident,
67 procedure_type: &ProcedureType,
68 routine_name: &syn::LitStr,
69) -> proc_macro2::TokenStream {
70 let fn_name = &item.sig.ident;
71 let builder_fn_name = &builder_fn_ident;
72 let proc_type_str = procedure_type.to_string().to_lowercase();
73
74 let serialized_fn_item = syn_serde::json::to_string(item);
75 let serialized_bytes = serialized_fn_item.as_bytes();
76 let byte_array_literal = syn::LitByteStr::new(serialized_bytes, item.span());
77
78 let const_name = Ident::new(
79 &format!(
80 "BURN_CENTRAL_FUNCTION_{}",
81 fn_name.to_string().to_uppercase()
82 ),
83 proc_macro2::Span::call_site(),
84 );
85
86 let ast_const_name = Ident::new(
87 &format!("_BURN_FUNCTION_AST_{}", fn_name.to_string().to_uppercase()),
88 proc_macro2::Span::call_site(),
89 );
90
91 quote! {
92 const _: () = {
93 #[allow(dead_code)]
94 const #const_name: &str = concat!(
95 "BCFN1|",
96 module_path!(),
97 "|",
98 stringify!(#fn_name),
99 "|",
100 stringify!(#builder_fn_name),
101 "|",
102 #routine_name,
103 "|",
104 #proc_type_str,
105 "|END"
106 );
107
108 #[allow(dead_code)]
109 const #ast_const_name: &[u8] = #byte_array_literal;
110 };
111 }
112}
113
114fn get_string_arg(
115 args: &Punctuated<Meta, syn::Token![,]>,
116 arg_name: &str,
117 errors: &mut Vec<Error>,
118) -> Option<syn::LitStr> {
119 args.iter()
120 .find(|meta| meta.path().is_ident(arg_name))
121 .and_then(|meta| match meta.require_name_value() {
122 Ok(value) => match &value.value {
123 syn::Expr::Lit(syn::ExprLit {
124 lit: syn::Lit::Str(name),
125 ..
126 }) => Some(name.clone()),
127 _ => {
128 errors.push(Error::new(
129 value.value.span(),
130 format!("Expected a string literal for the `{arg_name}` argument."),
131 ));
132 None
133 }
134 },
135 Err(err) => {
136 errors.push(err);
137 None
138 }
139 })
140}
141
/// Checks that a user-supplied registered name is usable.
///
/// A valid name is non-empty and consists only of characters for which `char::is_alphanumeric`
/// holds (Unicode-aware) or `_`. Spaces get their own, more specific message. On failure the
/// returned `String` is a human-readable reason.
fn validate_registered_name(name: &str) -> Result<(), String> {
    let reason = if name.is_empty() {
        "Registered name cannot be empty."
    } else if name.contains(' ') {
        "Registered name cannot contain spaces."
    } else if name.chars().any(|c| !(c.is_alphanumeric() || c == '_')) {
        "Registered name can only contain alphanumeric characters and underscores."
    } else {
        return Ok(());
    };

    Err(reason.to_string())
}
156
157#[proc_macro_attribute]
158pub fn register(args: TokenStream, item: TokenStream) -> TokenStream {
160 let mut errors = Vec::<Error>::new();
161 let args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
162 let item = parse_macro_input!(item as ItemFn);
163 let fn_name = &item.sig.ident;
164
165 if args.is_empty() {
166 errors.push(Error::new(
167 args.span(),
168 "Expected one argument for the #[register] attribute. Please provide the procedure type (training or inference) as the first argument.",
169 ));
170 }
171
172 let procedure_type = match ProcedureType::try_from(
174 args.first()
175 .expect("Should be able to get first arg.")
176 .path()
177 .clone(),
178 ) {
179 Ok(procedure_type) => procedure_type,
180 Err(err) => {
181 return err.into_compile_error().into();
182 }
183 };
184
185 if procedure_type == ProcedureType::Inference {
186 errors.push(Error::new_spanned(
187 args.first().unwrap().path(),
188 "Inference procedures are not supported yet. Please use training procedures.",
189 ));
190 }
191
192 let maybe_registered_name = get_string_arg(&args, "name", &mut errors);
193
194 if let Some(name) = &maybe_registered_name {
195 if let Err(err) = validate_registered_name(&name.value()) {
196 errors.push(Error::new_spanned(
197 name,
198 format!("Invalid registered name: {err}"),
199 ));
200 }
201 }
202
203 let builder_fn_name = syn::Ident::new(
204 &format!("__{fn_name}_builder"),
205 proc_macro2::Span::call_site(),
206 );
207
208 let registered_name_str = {
209 let name = maybe_registered_name
210 .map(|name| name.value())
211 .unwrap_or_else(|| fn_name.to_string());
212
213 syn::LitStr::new(&name, fn_name.span())
214 };
215
216 let builder_item = match procedure_type {
217 ProcedureType::Training => {
218 quote! {
219 #[doc(hidden)]
220 pub fn #builder_fn_name<B: burn::tensor::backend::AutodiffBackend>(
221 exec: &mut burn_central::runtime::ExecutorBuilder<B>,
222 ) {
223 exec.train(#registered_name_str, #fn_name);
224 }
225 }
226 }
227 ProcedureType::Inference => {
228 quote! {}
229 }
230 };
231
232 let flag_register = generate_flag_register_stream(
233 &item,
234 &builder_fn_name,
235 &procedure_type,
236 ®istered_name_str,
237 );
238
239 let code = quote! {
240 #[allow(dead_code)]
241 #item
242
243 #flag_register
244 #builder_item
245 };
246
247 if !errors.is_empty() {
249 compile_errors(errors).into()
250 } else {
251 code.into()
252 }
253}