bolt_attribute_bolt_system/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{
5    parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, FnArg, GenericArgument, ItemFn,
6    ItemMod, ItemStruct, PathArguments, ReturnType, Stmt, Type, TypePath,
7};
8
/// Stateless visitor that rewrites the contents of a `#[system]` module: it adapts the
/// `execute` function (return type, `Ok(...)` returns, second argument), appends
/// `(session_key)` to `#[system_input]` attributes, and injects `init_extra_accounts` when an
/// `#[extra_accounts]` struct is present.
#[derive(Default)]
struct SystemTransform;
11
/// Read-only visitor that locates the component bundle of a `#[system]` module.
///
/// It remembers which type a function takes as its `Context<...>` argument and, once a
/// struct with that name is visited, how many fields the struct declares.
#[derive(Default)]
struct Extractor {
    /// Field count of the bundle struct; stays `None` until a struct whose name matches
    /// `context_struct_name` has been visited.
    field_count: Option<usize>,
    /// Name of the type passed as the generic argument of a `Context<...>` parameter.
    context_struct_name: Option<String>,
}
17
18/// This macro attribute is used to define a BOLT system.
19///
20/// Bolt components are themselves programs. The macro adds parsing and serialization
21///
22/// # Example
23/// ```ignore
24/// #[system]
25/// pub mod system_fly {
26///     pub fn execute(ctx: Context<Component>, _args: Vec<u8>) -> Result<Position> {
27///         let pos = Position {
28///             x: ctx.accounts.position.x,
29///             y: ctx.accounts.position.y,
30///             z: ctx.accounts.position.z + 1,
31///         };
32///         Ok(pos)
33///     }
34/// }
35/// ```
36#[proc_macro_attribute]
37pub fn system(_attr: TokenStream, item: TokenStream) -> TokenStream {
38    let mut ast = parse_macro_input!(item as ItemMod);
39
40    // Extract the number of components from the module
41    let mut extractor = Extractor::default();
42    extractor.visit_item_mod_mut(&mut ast);
43
44    if extractor.field_count.is_some() {
45        let use_super = syn::parse_quote! { use super::*; };
46        if let Some((_, ref mut items)) = ast.content {
47            items.insert(0, syn::Item::Use(use_super));
48            SystemTransform::add_variadic_execute_function(items);
49        }
50
51        let mut transform = SystemTransform;
52        transform.visit_item_mod_mut(&mut ast);
53
54        // Add `#[program]` macro and try_to_vec implementation
55        let expanded = quote! {
56            #[program]
57            #ast
58        };
59
60        TokenStream::from(expanded)
61    } else {
62        panic!(
63            "Could not find the component bundle: {} in the module",
64            extractor.context_struct_name.unwrap()
65        );
66    }
67}
68
69impl SystemTransform {
70    fn visit_stmts_mut(&mut self, stmts: &mut Vec<Stmt>) {
71        for stmt in stmts {
72            if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
73                self.visit_expr_mut(expr);
74            }
75        }
76    }
77}
78
79/// Visits the AST and modifies the system function
80impl VisitMut for SystemTransform {
81    // Modify the return instruction to return Result<Vec<u8>>
82    fn visit_expr_mut(&mut self, expr: &mut Expr) {
83        match expr {
84            Expr::ForLoop(for_loop_expr) => {
85                self.visit_stmts_mut(&mut for_loop_expr.body.stmts);
86            }
87            Expr::Loop(loop_expr) => {
88                self.visit_stmts_mut(&mut loop_expr.body.stmts);
89            }
90            Expr::If(if_expr) => {
91                self.visit_stmts_mut(&mut if_expr.then_branch.stmts);
92                if let Some((_, else_expr)) = &mut if_expr.else_branch {
93                    self.visit_expr_mut(else_expr);
94                }
95            }
96            Expr::Block(block_expr) => {
97                self.visit_stmts_mut(&mut block_expr.block.stmts);
98            }
99            _ => (),
100        }
101        if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
102            let new_return_expr: Expr = match inner_variable {
103                Expr::Tuple(tuple_expr) => {
104                    let tuple_elements = tuple_expr.elems.iter().map(|elem| {
105                        quote! { (#elem).try_to_vec()? }
106                    });
107                    parse_quote! { Ok((#(#tuple_elements),*)) }
108                }
109                _ => {
110                    parse_quote! {
111                        #inner_variable.try_to_vec()
112                    }
113                }
114            };
115            if let Expr::Return(return_expr) = expr {
116                return_expr.expr = Some(Box::new(new_return_expr));
117            } else {
118                *expr = new_return_expr;
119            }
120        }
121    }
122
123    // Modify the return type of the system function to Result<Vec<u8>,*>
124    fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
125        if item_fn.sig.ident == "execute" {
126            // Modify the return type to Result<Vec<u8>> if necessary
127            if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
128                if let Type::Path(type_path) = &**type_box {
129                    if !Self::check_is_result_vec_u8(type_path) {
130                        item_fn.sig.output = parse_quote! { -> Result<Vec<Vec<u8>>> };
131                        // Modify the return statement inside the function body
132                        let block = &mut item_fn.block;
133                        self.visit_stmts_mut(&mut block.stmts);
134                    }
135                }
136            }
137            // If second argument is not Vec<u8>, modify it to be so and use parse_args
138            Self::modify_args(item_fn);
139        }
140    }
141
142    // Visit all the functions inside the system module and inject the init_extra_accounts function
143    // if the module contains a struct with the `extra_accounts` attribute
144    fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
145        let content = match item_mod.content.as_mut() {
146            Some(content) => &mut content.1,
147            None => return,
148        };
149
150        let mut extra_accounts_struct_name = None;
151
152        for item in content.iter_mut() {
153            match item {
154                syn::Item::Fn(item_fn) => self.visit_item_fn_mut(item_fn),
155                syn::Item::Struct(item_struct) => {
156                    if let Some(attr) = item_struct
157                        .attrs
158                        .iter_mut()
159                        .find(|attr| attr.path.is_ident("system_input"))
160                    {
161                        attr.tokens.append_all(quote! { (session_key) });
162                    }
163                    if item_struct
164                        .attrs
165                        .iter()
166                        .any(|attr| attr.path.is_ident("extra_accounts"))
167                    {
168                        extra_accounts_struct_name = Some(&item_struct.ident);
169                        break;
170                    }
171                }
172                _ => {}
173            }
174        }
175
176        if let Some(struct_name) = extra_accounts_struct_name {
177            let initialize_extra_accounts = quote! {
178            #[automatically_derived]
179                pub fn init_extra_accounts(_ctx: Context<#struct_name>) -> Result<()> {
180                    Ok(())
181                }
182            };
183            content.push(syn::parse2(initialize_extra_accounts).unwrap());
184        }
185    }
186}
187
188impl SystemTransform {
189    fn add_variadic_execute_function(content: &mut Vec<syn::Item>) {
190        content.push(syn::parse2(quote! {
191            pub fn bolt_execute<'info>(ctx: Context<'_, '_, 'info, 'info, VariadicBoltComponents<'info>>, args: Vec<u8>) -> Result<Vec<Vec<u8>>> {
192                let mut components = Components::try_from(&ctx)?;
193                let bumps = ComponentsBumps {};
194                let context = Context::new(ctx.program_id, &mut components, ctx.remaining_accounts, bumps);
195                execute(context, args)
196            }
197        }).unwrap());
198    }
199
200    // Helper function to check if a type is `Vec<u8>` or `(Vec<u8>, Vec<u8>, ...)`
201    fn check_is_result_vec_u8(ty: &TypePath) -> bool {
202        if let Some(segment) = ty.path.segments.last() {
203            if segment.ident == "Result" {
204                if let PathArguments::AngleBracketed(args) = &segment.arguments {
205                    if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
206                        return tuple.elems.iter().all(|elem| {
207                            if let Type::Path(type_path) = elem {
208                                if let Some(segment) = type_path.path.segments.first() {
209                                    return segment.ident == "Vec" && Self::is_u8_vec(segment);
210                                }
211                            }
212                            false
213                        });
214                    } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
215                        args.args.first()
216                    {
217                        if let Some(segment) = type_path.path.segments.first() {
218                            return segment.ident == "Vec" && Self::is_u8_vec(segment);
219                        }
220                    }
221                }
222            }
223        }
224        false
225    }
226
227    // Helper function to check if a type is Vec<u8>
228    fn is_u8_vec(segment: &syn::PathSegment) -> bool {
229        if let PathArguments::AngleBracketed(args) = &segment.arguments {
230            if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
231                if let Some(segment) = path.path.segments.first() {
232                    return segment.ident == "u8";
233                }
234            }
235        }
236        false
237    }
238
239    // Helper function to check if an expression is an `Ok(...)` or `return Ok(...);` variant
240    fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
241        match expr {
242            Expr::Call(expr_call) => {
243                // Direct `Ok(...)` call
244                if let Expr::Path(expr_path) = &*expr_call.func {
245                    if let Some(last_segment) = expr_path.path.segments.last() {
246                        if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
247                            // Return the first argument of the Ok(...) call
248                            return expr_call.args.first();
249                        }
250                    }
251                }
252            }
253            Expr::Return(expr_return) => {
254                // `return Ok(...);`
255                if let Some(expr_return_inner) = &expr_return.expr {
256                    if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
257                        if let Expr::Path(expr_path) = &*expr_call.func {
258                            if let Some(last_segment) = expr_path.path.segments.last() {
259                                if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
260                                    // Return the first argument of the return Ok(...) call
261                                    return expr_call.args.first();
262                                }
263                            }
264                        }
265                    }
266                }
267            }
268            _ => {}
269        }
270        None
271    }
272
273    fn modify_args(item_fn: &mut ItemFn) {
274        if item_fn.sig.inputs.len() >= 2 {
275            let second_arg = &mut item_fn.sig.inputs[1];
276            let is_vec_u8 = if let FnArg::Typed(syn::PatType { ty, .. }) = second_arg {
277                match &**ty {
278                    Type::Path(type_path) => {
279                        if let Some(segment) = type_path.path.segments.first() {
280                            segment.ident == "Vec" && Self::is_u8_vec(segment)
281                        } else {
282                            false
283                        }
284                    }
285                    _ => false,
286                }
287            } else {
288                false
289            };
290            if !is_vec_u8 {
291                if let FnArg::Typed(pat_type) = second_arg {
292                    let original_type = pat_type.ty.to_token_stream();
293                    let arg_original_name = pat_type.pat.to_token_stream();
294                    if let syn::Pat::Ident(ref mut pat_ident) = *pat_type.pat {
295                        let new_ident_name = format!("_{}", pat_ident.ident);
296                        pat_ident.ident =
297                            Ident::new(&new_ident_name, proc_macro2::Span::call_site());
298                    }
299                    let arg_name = pat_type.pat.to_token_stream();
300                    pat_type.ty = Box::new(syn::parse_quote! { Vec<u8> });
301                    let parse_stmt: Stmt = parse_quote! {
302                        let #arg_original_name = parse_args::<#original_type>(&#arg_name);
303                    };
304                    item_fn.block.stmts.insert(0, parse_stmt);
305                }
306            }
307        }
308    }
309}
310
311/// Visits the AST to extract the number of input components
312impl VisitMut for Extractor {
313    fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
314        for input in &i.sig.inputs {
315            if let FnArg::Typed(pat_type) = input {
316                if let Type::Path(type_path) = &*pat_type.ty {
317                    let last_segment = type_path.path.segments.last().unwrap();
318                    if last_segment.ident == "Context" {
319                        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
320                            if let Some(syn::GenericArgument::Type(syn::Type::Path(type_path))) =
321                                args.args.first()
322                            {
323                                let ident = &type_path.path.segments.first().unwrap().ident;
324                                self.context_struct_name = Some(ident.to_string());
325                            }
326                        }
327                    }
328                }
329            }
330        }
331    }
332
333    fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
334        if let Some(name) = &self.context_struct_name {
335            if i.ident == name {
336                self.field_count = Some(i.fields.len());
337            }
338        }
339    }
340}