Skip to main content

burn_central_macros/
lib.rs

//! # Burn Central Macros
//! As defined in the Burn Central crate documentation, this crate provides the macros used to
//! register functions. You probably don't need more information than that, but if you
//! do, we've got you covered.
5//!
//! ## Role and responsibility
//! The role of the macros is intentionally lightweight: we don't want the register crate to be
//! the heart of our runtime. The macro simply generates a small wrapper (a builder function)
//! that hands your function to the runtime, and the runtime does the rest of the magic.
10//!
//! ## Usage
//! To use the macros, import them from this crate and use the `register` macro
//! to mark your training and inference functions (inference procedures are not
//! supported yet and are rejected at compile time). Here is an example:
14//! ```ignore
15//! # use burn_central_macros::register;
16//! #[register(training, name = "my_training_procedure")]
17//! async fn my_training_function() {
18//!  // Your training code here
19//! }
20//! ```
21
22use proc_macro::TokenStream;
23use proc_macro2::Ident;
24use quote::quote;
25
26use strum::Display;
27use syn::{Error, ItemFn, Meta, Path, parse_macro_input, punctuated::Punctuated, spanned::Spanned};
28
29#[derive(Eq, Hash, PartialEq, Display)]
30#[strum(serialize_all = "PascalCase")]
31enum ProcedureType {
32    Training,
33    Inference,
34}
35
36impl TryFrom<Path> for ProcedureType {
37    type Error = Error;
38
39    fn try_from(path: Path) -> Result<Self, Self::Error> {
40        match path.get_ident() {
41            Some(ident) => match ident.to_string().as_str() {
42                "training" => Ok(Self::Training),
43                "inference" => Ok(Self::Inference),
44                _ => Err(Error::new_spanned(
45                    path,
46                    "Expected `training` or `inference`",
47                )),
48            },
49            None => Err(Error::new_spanned(
50                path,
51                "Expected `training` or `inference`",
52            )),
53        }
54    }
55}
56
57fn compile_errors(errors: Vec<Error>) -> proc_macro2::TokenStream {
58    errors
59        .into_iter()
60        .map(|err| err.to_compile_error())
61        .collect()
62}
63
64fn generate_flag_register_stream(
65    item: &ItemFn,
66    builder_fn_ident: &Ident,
67    procedure_type: &ProcedureType,
68    routine_name: &syn::LitStr,
69) -> proc_macro2::TokenStream {
70    let fn_name = &item.sig.ident;
71    let builder_fn_name = &builder_fn_ident;
72    let proc_type_str = procedure_type.to_string().to_lowercase();
73
74    let serialized_fn_item = syn_serde::json::to_string(item);
75    let serialized_bytes = serialized_fn_item.as_bytes();
76    let byte_array_literal = syn::LitByteStr::new(serialized_bytes, item.span());
77
78    let const_name = Ident::new(
79        &format!(
80            "BURN_CENTRAL_FUNCTION_{}",
81            fn_name.to_string().to_uppercase()
82        ),
83        proc_macro2::Span::call_site(),
84    );
85
86    let ast_const_name = Ident::new(
87        &format!("_BURN_FUNCTION_AST_{}", fn_name.to_string().to_uppercase()),
88        proc_macro2::Span::call_site(),
89    );
90
91    quote! {
92        const _: () = {
93            #[allow(dead_code)]
94            const #const_name: &str = concat!(
95                "BCFN1|",
96                module_path!(),
97                "|",
98                stringify!(#fn_name),
99                "|",
100                stringify!(#builder_fn_name),
101                "|",
102                #routine_name,
103                "|",
104                #proc_type_str,
105                "|END"
106            );
107
108            #[allow(dead_code)]
109            const #ast_const_name: &[u8] = #byte_array_literal;
110        };
111    }
112}
113
114fn get_string_arg(
115    args: &Punctuated<Meta, syn::Token![,]>,
116    arg_name: &str,
117    errors: &mut Vec<Error>,
118) -> Option<syn::LitStr> {
119    args.iter()
120        .find(|meta| meta.path().is_ident(arg_name))
121        .and_then(|meta| match meta.require_name_value() {
122            Ok(value) => match &value.value {
123                syn::Expr::Lit(syn::ExprLit {
124                    lit: syn::Lit::Str(name),
125                    ..
126                }) => Some(name.clone()),
127                _ => {
128                    errors.push(Error::new(
129                        value.value.span(),
130                        format!("Expected a string literal for the `{arg_name}` argument."),
131                    ));
132                    None
133                }
134            },
135            Err(err) => {
136                errors.push(err);
137                None
138            }
139        })
140}
141
/// Checks that a user-supplied registered name is usable.
///
/// The name must be non-empty and consist only of alphanumeric characters
/// (as defined by `char::is_alphanumeric`) and underscores. Spaces get a
/// dedicated message; any other disallowed character gets the generic one.
///
/// # Errors
/// Returns a human-readable description of the first rule that is violated.
fn validate_registered_name(name: &str) -> Result<(), String> {
    let violation = if name.is_empty() {
        Some("Registered name cannot be empty.")
    } else if name.contains(' ') {
        Some("Registered name cannot contain spaces.")
    } else if name.chars().any(|ch| !(ch.is_alphanumeric() || ch == '_')) {
        Some("Registered name can only contain alphanumeric characters and underscores.")
    } else {
        None
    };

    violation.map_or(Ok(()), |message| Err(message.to_owned()))
}
156
157#[proc_macro_attribute]
158/// Macro to register your training and inference functions.
159pub fn register(args: TokenStream, item: TokenStream) -> TokenStream {
160    let mut errors = Vec::<Error>::new();
161    let args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
162    let item = parse_macro_input!(item as ItemFn);
163    let fn_name = &item.sig.ident;
164
165    if args.is_empty() {
166        errors.push(Error::new(
167            args.span(),
168            "Expected one argument for the #[register] attribute. Please provide the procedure type (training or inference) as the first argument.",
169        ));
170    }
171
172    // Determine the proc type (training or inference)
173    let procedure_type = match ProcedureType::try_from(
174        args.first()
175            .expect("Should be able to get first arg.")
176            .path()
177            .clone(),
178    ) {
179        Ok(procedure_type) => procedure_type,
180        Err(err) => {
181            return err.into_compile_error().into();
182        }
183    };
184
185    if procedure_type == ProcedureType::Inference {
186        errors.push(Error::new_spanned(
187            args.first().unwrap().path(),
188            "Inference procedures are not supported yet. Please use training procedures.",
189        ));
190    }
191
192    let maybe_registered_name = get_string_arg(&args, "name", &mut errors);
193
194    if let Some(name) = &maybe_registered_name {
195        if let Err(err) = validate_registered_name(&name.value()) {
196            errors.push(Error::new_spanned(
197                name,
198                format!("Invalid registered name: {err}"),
199            ));
200        }
201    }
202
203    let builder_fn_name = syn::Ident::new(
204        &format!("__{fn_name}_builder"),
205        proc_macro2::Span::call_site(),
206    );
207
208    let registered_name_str = {
209        let name = maybe_registered_name
210            .map(|name| name.value())
211            .unwrap_or_else(|| fn_name.to_string());
212
213        syn::LitStr::new(&name, fn_name.span())
214    };
215
216    let builder_item = match procedure_type {
217        ProcedureType::Training => {
218            quote! {
219                #[doc(hidden)]
220                pub fn #builder_fn_name<B: burn::tensor::backend::AutodiffBackend>(
221                    exec: &mut burn_central::runtime::ExecutorBuilder<B>,
222                ) {
223                    exec.train(#registered_name_str, #fn_name);
224                }
225            }
226        }
227        ProcedureType::Inference => {
228            quote! {}
229        }
230    };
231
232    let flag_register = generate_flag_register_stream(
233        &item,
234        &builder_fn_name,
235        &procedure_type,
236        &registered_name_str,
237    );
238
239    let code = quote! {
240        #[allow(dead_code)]
241        #item
242
243        #flag_register
244        #builder_item
245    };
246
247    // If there are any errors, combine them and return
248    if !errors.is_empty() {
249        compile_errors(errors).into()
250    } else {
251        code.into()
252    }
253}