bolt_attribute_bolt_system/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{
5    parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, FnArg, GenericArgument, ItemFn,
6    ItemMod, ItemStruct, PathArguments, ReturnType, Stmt, Type, TypePath,
7};
8
/// Stateless syntax-tree rewriter run over a `#[system]` module.
///
/// It adapts the `execute` function (signature, return type, returned values) and appends
/// generated items such as `init_extra_accounts`. Being a unit struct, it is built simply
/// as `SystemTransform`.
#[derive(Default)]
struct SystemTransform;
11
/// Information collected while scanning a `#[system]` module for its component bundle.
#[derive(Default)]
struct Extractor {
    /// Name taken from the type argument of a `Context<...>` parameter, once one is seen.
    context_struct_name: Option<String>,
    /// Field count of the struct called `context_struct_name`, once that struct is visited.
    field_count: Option<usize>,
}
17
18/// This macro attribute is used to define a BOLT system.
19///
20/// Bolt components are themselves programs. The macro adds parsing and serialization
21///
22/// # Example
23/// ```ignore
24/// #[system]
25/// pub mod system_fly {
26///     pub fn execute(ctx: Context<Component>, _args: Vec<u8>) -> Result<Position> {
27///         let pos = Position {
28///             x: ctx.accounts.position.x,
29///             y: ctx.accounts.position.y,
30///             z: ctx.accounts.position.z + 1,
31///         };
32///         Ok(pos)
33///     }
34/// }
35/// ```
36#[proc_macro_attribute]
37pub fn system(_attr: TokenStream, item: TokenStream) -> TokenStream {
38    let mut ast = parse_macro_input!(item as ItemMod);
39
40    // Extract the number of components from the module
41    let mut extractor = Extractor::default();
42    extractor.visit_item_mod_mut(&mut ast);
43
44    if extractor.field_count.is_some() {
45        let use_super = syn::parse_quote! { use super::*; };
46        if let Some((_, ref mut items)) = ast.content {
47            items.insert(0, syn::Item::Use(use_super));
48            SystemTransform::add_variadic_execute_function(items);
49        }
50
51        let mut transform = SystemTransform;
52        transform.visit_item_mod_mut(&mut ast);
53
54        // Add `#[program]` macro and try_to_vec implementation
55        let expanded = quote! {
56            #[program]
57            #ast
58        };
59
60        TokenStream::from(expanded)
61    } else {
62        panic!(
63            "Could not find the component bundle: {} in the module",
64            extractor.context_struct_name.unwrap()
65        );
66    }
67}
68
69impl SystemTransform {
70    fn visit_stmts_mut(&mut self, stmts: &mut Vec<Stmt>) {
71        for stmt in stmts {
72            if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
73                self.visit_expr_mut(expr);
74            }
75        }
76    }
77}
78
79/// Visits the AST and modifies the system function
80impl VisitMut for SystemTransform {
81    // Modify the return instruction to return Result<Vec<u8>>
82    fn visit_expr_mut(&mut self, expr: &mut Expr) {
83        match expr {
84            Expr::ForLoop(for_loop_expr) => {
85                self.visit_stmts_mut(&mut for_loop_expr.body.stmts);
86            }
87            Expr::Loop(loop_expr) => {
88                self.visit_stmts_mut(&mut loop_expr.body.stmts);
89            }
90            Expr::If(if_expr) => {
91                self.visit_stmts_mut(&mut if_expr.then_branch.stmts);
92                if let Some((_, else_expr)) = &mut if_expr.else_branch {
93                    self.visit_expr_mut(else_expr);
94                }
95            }
96            Expr::Block(block_expr) => {
97                self.visit_stmts_mut(&mut block_expr.block.stmts);
98            }
99            _ => (),
100        }
101        if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
102            let new_return_expr: Expr = match inner_variable {
103                Expr::Tuple(tuple_expr) => {
104                    let tuple_elements = tuple_expr.elems.iter().map(|elem| {
105                        quote! { (#elem).try_to_vec()? }
106                    });
107                    parse_quote! { Ok((#(#tuple_elements),*)) }
108                }
109                _ => {
110                    parse_quote! {
111                        #inner_variable.try_to_vec()
112                    }
113                }
114            };
115            if let Expr::Return(return_expr) = expr {
116                return_expr.expr = Some(Box::new(new_return_expr));
117            } else {
118                *expr = new_return_expr;
119            }
120        }
121    }
122
123    // Modify the return type of the system function to Result<Vec<u8>,*>
124    fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
125        if item_fn.sig.ident == "execute" {
126            // Ensure execute has lifetimes and a fully-qualified Context
127            Self::inject_lifetimes_and_context(item_fn);
128            // Modify the return type to Result<Vec<u8>> if necessary
129            if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
130                if let Type::Path(type_path) = &**type_box {
131                    if !Self::check_is_result_vec_u8(type_path) {
132                        item_fn.sig.output = parse_quote! { -> Result<Vec<Vec<u8>>> };
133                        // Modify the return statement inside the function body
134                        let block = &mut item_fn.block;
135                        self.visit_stmts_mut(&mut block.stmts);
136                    }
137                }
138            }
139            // If second argument is not Vec<u8>, modify it to be so and use parse_args
140            Self::modify_args(item_fn);
141        }
142    }
143
144    // Visit all the functions inside the system module and inject the init_extra_accounts function
145    // if the module contains a struct with the `extra_accounts` attribute
146    fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
147        let content = match item_mod.content.as_mut() {
148            Some(content) => &mut content.1,
149            None => return,
150        };
151
152        let mut extra_accounts_struct_name = None;
153
154        for item in content.iter_mut() {
155            match item {
156                syn::Item::Fn(item_fn) => self.visit_item_fn_mut(item_fn),
157                syn::Item::Struct(item_struct) => {
158                    if let Some(attr) = item_struct
159                        .attrs
160                        .iter_mut()
161                        .find(|attr| attr.path.is_ident("system_input"))
162                    {
163                        attr.tokens.append_all(quote! { (session_key) });
164                    }
165                    if item_struct
166                        .attrs
167                        .iter()
168                        .any(|attr| attr.path.is_ident("extra_accounts"))
169                    {
170                        extra_accounts_struct_name = Some(&item_struct.ident);
171                        break;
172                    }
173                }
174                _ => {}
175            }
176        }
177
178        if let Some(struct_name) = extra_accounts_struct_name {
179            let initialize_extra_accounts = quote! {
180            #[automatically_derived]
181                pub fn init_extra_accounts(_ctx: Context<#struct_name>) -> Result<()> {
182                    Ok(())
183                }
184            };
185            content.push(syn::parse2(initialize_extra_accounts).unwrap());
186        }
187    }
188}
189
190impl SystemTransform {
191    fn inject_lifetimes_and_context(item_fn: &mut ItemFn) {
192        // Add lifetimes <'a, 'b, 'c, 'info> if missing
193        let lifetime_idents = ["a", "b", "c", "info"];
194        for name in lifetime_idents.iter() {
195            let exists = item_fn.sig.generics.params.iter().any(|p| match p {
196                syn::GenericParam::Lifetime(l) => l.lifetime.ident == *name,
197                _ => false,
198            });
199            if !exists {
200                let lifetime: syn::Lifetime =
201                    syn::parse_str(&format!("'{}", name)).expect("valid lifetime");
202                let gp: syn::GenericParam = syn::parse_quote!(#lifetime);
203                item_fn.sig.generics.params.push(gp);
204            }
205        }
206
207        // Update the first argument type from Context<Components> to Context<'a, 'b, 'c, 'info, Components<'info>>
208        if let Some(FnArg::Typed(pat_type)) = item_fn.sig.inputs.first_mut() {
209            if let Type::Path(type_path) = pat_type.ty.as_mut() {
210                if let Some(last_segment) = type_path.path.segments.last_mut() {
211                    if last_segment.ident == "Context" {
212                        // Extract Components path from existing generic args (if any)
213                        let mut components_ty_opt: Option<Type> = None;
214                        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
215                            for ga in args.args.iter() {
216                                if let GenericArgument::Type(t) = ga {
217                                    components_ty_opt = Some(t.clone());
218                                    break;
219                                }
220                            }
221                        }
222
223                        // If not found, leave early
224                        if let Some(components_ty) = components_ty_opt {
225                            // Ensure Components<'info>
226                            let components_with_info: Type = match components_ty {
227                                Type::Path(mut tp) => {
228                                    let seg = tp.path.segments.last_mut().unwrap();
229                                    match &mut seg.arguments {
230                                        PathArguments::AngleBracketed(ab) => {
231                                            if ab.args.is_empty() {
232                                                ab.args.push(GenericArgument::Lifetime(
233                                                    syn::parse_quote!('info),
234                                                ));
235                                            }
236                                        }
237                                        _ => {
238                                            seg.arguments = PathArguments::AngleBracketed(
239                                                syn::AngleBracketedGenericArguments {
240                                                    colon2_token: None,
241                                                    lt_token: Default::default(),
242                                                    args: std::iter::once(
243                                                        GenericArgument::Lifetime(
244                                                            syn::parse_quote!('info),
245                                                        ),
246                                                    )
247                                                    .collect(),
248                                                    gt_token: Default::default(),
249                                                },
250                                            );
251                                        }
252                                    }
253                                    Type::Path(tp)
254                                }
255                                other => other,
256                            };
257
258                            // Build new Context<'a, 'b, 'c, 'info, Components<'info>> type
259                            let new_ty: Type = syn::parse_quote! {
260                                Context<'a, 'b, 'c, 'info, #components_with_info>
261                            };
262                            pat_type.ty = Box::new(new_ty);
263                        }
264                    }
265                }
266            }
267        }
268    }
269    fn add_variadic_execute_function(content: &mut Vec<syn::Item>) {
270        content.push(syn::parse2(quote! {
271            pub fn bolt_execute<'a, 'b, 'info>(ctx: Context<'a, 'b, 'info, 'info, VariadicBoltComponents<'info>>, args: Vec<u8>) -> Result<Vec<Vec<u8>>> {
272                let mut components = Components::try_from(&ctx)?;
273                let bumps = ComponentsBumps {};
274                let context = Context::new(ctx.program_id, &mut components, ctx.remaining_accounts, bumps);
275                execute(context, args)
276            }
277        }).unwrap());
278    }
279
280    // Helper function to check if a type is `Vec<u8>` or `(Vec<u8>, Vec<u8>, ...)`
281    fn check_is_result_vec_u8(ty: &TypePath) -> bool {
282        if let Some(segment) = ty.path.segments.last() {
283            if segment.ident == "Result" {
284                if let PathArguments::AngleBracketed(args) = &segment.arguments {
285                    if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
286                        return tuple.elems.iter().all(|elem| {
287                            if let Type::Path(type_path) = elem {
288                                if let Some(segment) = type_path.path.segments.first() {
289                                    return segment.ident == "Vec" && Self::is_u8_vec(segment);
290                                }
291                            }
292                            false
293                        });
294                    } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
295                        args.args.first()
296                    {
297                        if let Some(segment) = type_path.path.segments.first() {
298                            return segment.ident == "Vec" && Self::is_u8_vec(segment);
299                        }
300                    }
301                }
302            }
303        }
304        false
305    }
306
307    // Helper function to check if a type is Vec<u8>
308    fn is_u8_vec(segment: &syn::PathSegment) -> bool {
309        if let PathArguments::AngleBracketed(args) = &segment.arguments {
310            if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
311                if let Some(segment) = path.path.segments.first() {
312                    return segment.ident == "u8";
313                }
314            }
315        }
316        false
317    }
318
319    // Helper function to check if an expression is an `Ok(...)` or `return Ok(...);` variant
320    fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
321        match expr {
322            Expr::Call(expr_call) => {
323                // Direct `Ok(...)` call
324                if let Expr::Path(expr_path) = &*expr_call.func {
325                    if let Some(last_segment) = expr_path.path.segments.last() {
326                        if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
327                            // Return the first argument of the Ok(...) call
328                            return expr_call.args.first();
329                        }
330                    }
331                }
332            }
333            Expr::Return(expr_return) => {
334                // `return Ok(...);`
335                if let Some(expr_return_inner) = &expr_return.expr {
336                    if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
337                        if let Expr::Path(expr_path) = &*expr_call.func {
338                            if let Some(last_segment) = expr_path.path.segments.last() {
339                                if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
340                                    // Return the first argument of the return Ok(...) call
341                                    return expr_call.args.first();
342                                }
343                            }
344                        }
345                    }
346                }
347            }
348            _ => {}
349        }
350        None
351    }
352
353    fn modify_args(item_fn: &mut ItemFn) {
354        if item_fn.sig.inputs.len() >= 2 {
355            let second_arg = &mut item_fn.sig.inputs[1];
356            let is_vec_u8 = if let FnArg::Typed(syn::PatType { ty, .. }) = second_arg {
357                match &**ty {
358                    Type::Path(type_path) => {
359                        if let Some(segment) = type_path.path.segments.first() {
360                            segment.ident == "Vec" && Self::is_u8_vec(segment)
361                        } else {
362                            false
363                        }
364                    }
365                    _ => false,
366                }
367            } else {
368                false
369            };
370            if !is_vec_u8 {
371                if let FnArg::Typed(pat_type) = second_arg {
372                    let original_type = pat_type.ty.to_token_stream();
373                    let arg_original_name = pat_type.pat.to_token_stream();
374                    if let syn::Pat::Ident(ref mut pat_ident) = *pat_type.pat {
375                        let new_ident_name = format!("_{}", pat_ident.ident);
376                        pat_ident.ident =
377                            Ident::new(&new_ident_name, proc_macro2::Span::call_site());
378                    }
379                    let arg_name = pat_type.pat.to_token_stream();
380                    pat_type.ty = Box::new(syn::parse_quote! { Vec<u8> });
381                    let parse_stmt: Stmt = parse_quote! {
382                        let #arg_original_name = parse_args::<#original_type>(&#arg_name);
383                    };
384                    item_fn.block.stmts.insert(0, parse_stmt);
385                }
386            }
387        }
388    }
389}
390
391/// Visits the AST to extract the number of input components
392impl VisitMut for Extractor {
393    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
394        for input in &i.sig.inputs {
395            if let FnArg::Typed(pat_type) = input {
396                if let Type::Path(type_path) = &*pat_type.ty {
397                    let last_segment = type_path.path.segments.last().unwrap();
398                    if last_segment.ident == "Context" {
399                        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
400                            // Find the first generic argument that is a Type::Path (e.g., Components)
401                            for ga in args.args.iter() {
402                                if let syn::GenericArgument::Type(syn::Type::Path(type_path)) = ga {
403                                    if let Some(first_seg) = type_path.path.segments.first() {
404                                        self.context_struct_name =
405                                            Some(first_seg.ident.to_string());
406                                        break;
407                                    }
408                                }
409                            }
410                        }
411                    }
412                }
413            }
414        }
415    }
416
417    fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
418        if let Some(name) = &self.context_struct_name {
419            if i.ident == name {
420                self.field_count = Some(i.fields.len());
421            }
422        }
423    }
424}