bolt_attribute_bolt_system/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{
5 parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, FnArg, GenericArgument, ItemFn,
6 ItemMod, ItemStruct, PathArguments, ReturnType, Stmt, Type, TypePath,
7};
8
/// Pre-pass visitor used by `system` to locate the component bundle of the module.
///
/// It records the accounts type named in a `Context<...>` function parameter and, once that
/// name is known, the number of fields of the struct carrying the same name.
#[derive(Default)]
struct Extractor {
    /// Field count of the bundle struct; stays `None` until that struct has been matched.
    field_count: Option<usize>,
    /// Identifier of the `Context<...>` accounts type, stored as a string.
    context_struct_name: Option<String>,
}

/// Stateless visitor that rewrites a system module (its `execute` function and helper items).
#[derive(Default)]
struct SystemTransform;
17
18#[proc_macro_attribute]
37pub fn system(_attr: TokenStream, item: TokenStream) -> TokenStream {
38 let mut ast = parse_macro_input!(item as ItemMod);
39
40 let mut extractor = Extractor::default();
42 extractor.visit_item_mod_mut(&mut ast);
43
44 if extractor.field_count.is_some() {
45 let use_super = syn::parse_quote! { use super::*; };
46 if let Some((_, ref mut items)) = ast.content {
47 items.insert(0, syn::Item::Use(use_super));
48 SystemTransform::add_variadic_execute_function(items);
49 }
50
51 let mut transform = SystemTransform;
52 transform.visit_item_mod_mut(&mut ast);
53
54 let expanded = quote! {
56 #[program]
57 #ast
58 };
59
60 TokenStream::from(expanded)
61 } else {
62 panic!(
63 "Could not find the component bundle: {} in the module",
64 extractor.context_struct_name.unwrap()
65 );
66 }
67}
68
69impl SystemTransform {
70 fn visit_stmts_mut(&mut self, stmts: &mut Vec<Stmt>) {
71 for stmt in stmts {
72 if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
73 self.visit_expr_mut(expr);
74 }
75 }
76 }
77}
78
79impl VisitMut for SystemTransform {
81 fn visit_expr_mut(&mut self, expr: &mut Expr) {
83 match expr {
84 Expr::ForLoop(for_loop_expr) => {
85 self.visit_stmts_mut(&mut for_loop_expr.body.stmts);
86 }
87 Expr::Loop(loop_expr) => {
88 self.visit_stmts_mut(&mut loop_expr.body.stmts);
89 }
90 Expr::If(if_expr) => {
91 self.visit_stmts_mut(&mut if_expr.then_branch.stmts);
92 if let Some((_, else_expr)) = &mut if_expr.else_branch {
93 self.visit_expr_mut(else_expr);
94 }
95 }
96 Expr::Block(block_expr) => {
97 self.visit_stmts_mut(&mut block_expr.block.stmts);
98 }
99 _ => (),
100 }
101 if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
102 let new_return_expr: Expr = match inner_variable {
103 Expr::Tuple(tuple_expr) => {
104 let tuple_elements = tuple_expr.elems.iter().map(|elem| {
105 quote! { (#elem).try_to_vec()? }
106 });
107 parse_quote! { Ok((#(#tuple_elements),*)) }
108 }
109 _ => {
110 parse_quote! {
111 #inner_variable.try_to_vec()
112 }
113 }
114 };
115 if let Expr::Return(return_expr) = expr {
116 return_expr.expr = Some(Box::new(new_return_expr));
117 } else {
118 *expr = new_return_expr;
119 }
120 }
121 }
122
123 fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
125 if item_fn.sig.ident == "execute" {
126 if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
128 if let Type::Path(type_path) = &**type_box {
129 if !Self::check_is_result_vec_u8(type_path) {
130 item_fn.sig.output = parse_quote! { -> Result<Vec<Vec<u8>>> };
131 let block = &mut item_fn.block;
133 self.visit_stmts_mut(&mut block.stmts);
134 }
135 }
136 }
137 Self::modify_args(item_fn);
139 }
140 }
141
142 fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
145 let content = match item_mod.content.as_mut() {
146 Some(content) => &mut content.1,
147 None => return,
148 };
149
150 let mut extra_accounts_struct_name = None;
151
152 for item in content.iter_mut() {
153 match item {
154 syn::Item::Fn(item_fn) => self.visit_item_fn_mut(item_fn),
155 syn::Item::Struct(item_struct) => {
156 if let Some(attr) = item_struct
157 .attrs
158 .iter_mut()
159 .find(|attr| attr.path.is_ident("system_input"))
160 {
161 attr.tokens.append_all(quote! { (session_key) });
162 }
163 if item_struct
164 .attrs
165 .iter()
166 .any(|attr| attr.path.is_ident("extra_accounts"))
167 {
168 extra_accounts_struct_name = Some(&item_struct.ident);
169 break;
170 }
171 }
172 _ => {}
173 }
174 }
175
176 if let Some(struct_name) = extra_accounts_struct_name {
177 let initialize_extra_accounts = quote! {
178 #[automatically_derived]
179 pub fn init_extra_accounts(_ctx: Context<#struct_name>) -> Result<()> {
180 Ok(())
181 }
182 };
183 content.push(syn::parse2(initialize_extra_accounts).unwrap());
184 }
185 }
186}
187
188impl SystemTransform {
189 fn add_variadic_execute_function(content: &mut Vec<syn::Item>) {
190 content.push(syn::parse2(quote! {
191 pub fn bolt_execute<'info>(ctx: Context<'_, '_, 'info, 'info, VariadicBoltComponents<'info>>, args: Vec<u8>) -> Result<Vec<Vec<u8>>> {
192 let mut components = Components::try_from(&ctx)?;
193 let bumps = ComponentsBumps {};
194 let context = Context::new(ctx.program_id, &mut components, ctx.remaining_accounts, bumps);
195 execute(context, args)
196 }
197 }).unwrap());
198 }
199
200 fn check_is_result_vec_u8(ty: &TypePath) -> bool {
202 if let Some(segment) = ty.path.segments.last() {
203 if segment.ident == "Result" {
204 if let PathArguments::AngleBracketed(args) = &segment.arguments {
205 if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
206 return tuple.elems.iter().all(|elem| {
207 if let Type::Path(type_path) = elem {
208 if let Some(segment) = type_path.path.segments.first() {
209 return segment.ident == "Vec" && Self::is_u8_vec(segment);
210 }
211 }
212 false
213 });
214 } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
215 args.args.first()
216 {
217 if let Some(segment) = type_path.path.segments.first() {
218 return segment.ident == "Vec" && Self::is_u8_vec(segment);
219 }
220 }
221 }
222 }
223 }
224 false
225 }
226
227 fn is_u8_vec(segment: &syn::PathSegment) -> bool {
229 if let PathArguments::AngleBracketed(args) = &segment.arguments {
230 if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
231 if let Some(segment) = path.path.segments.first() {
232 return segment.ident == "u8";
233 }
234 }
235 }
236 false
237 }
238
239 fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
241 match expr {
242 Expr::Call(expr_call) => {
243 if let Expr::Path(expr_path) = &*expr_call.func {
245 if let Some(last_segment) = expr_path.path.segments.last() {
246 if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
247 return expr_call.args.first();
249 }
250 }
251 }
252 }
253 Expr::Return(expr_return) => {
254 if let Some(expr_return_inner) = &expr_return.expr {
256 if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
257 if let Expr::Path(expr_path) = &*expr_call.func {
258 if let Some(last_segment) = expr_path.path.segments.last() {
259 if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
260 return expr_call.args.first();
262 }
263 }
264 }
265 }
266 }
267 }
268 _ => {}
269 }
270 None
271 }
272
273 fn modify_args(item_fn: &mut ItemFn) {
274 if item_fn.sig.inputs.len() >= 2 {
275 let second_arg = &mut item_fn.sig.inputs[1];
276 let is_vec_u8 = if let FnArg::Typed(syn::PatType { ty, .. }) = second_arg {
277 match &**ty {
278 Type::Path(type_path) => {
279 if let Some(segment) = type_path.path.segments.first() {
280 segment.ident == "Vec" && Self::is_u8_vec(segment)
281 } else {
282 false
283 }
284 }
285 _ => false,
286 }
287 } else {
288 false
289 };
290 if !is_vec_u8 {
291 if let FnArg::Typed(pat_type) = second_arg {
292 let original_type = pat_type.ty.to_token_stream();
293 let arg_original_name = pat_type.pat.to_token_stream();
294 if let syn::Pat::Ident(ref mut pat_ident) = *pat_type.pat {
295 let new_ident_name = format!("_{}", pat_ident.ident);
296 pat_ident.ident =
297 Ident::new(&new_ident_name, proc_macro2::Span::call_site());
298 }
299 let arg_name = pat_type.pat.to_token_stream();
300 pat_type.ty = Box::new(syn::parse_quote! { Vec<u8> });
301 let parse_stmt: Stmt = parse_quote! {
302 let #arg_original_name = parse_args::<#original_type>(&#arg_name);
303 };
304 item_fn.block.stmts.insert(0, parse_stmt);
305 }
306 }
307 }
308 }
309}
310
311impl VisitMut for Extractor {
313 fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
314 for input in &i.sig.inputs {
315 if let FnArg::Typed(pat_type) = input {
316 if let Type::Path(type_path) = &*pat_type.ty {
317 let last_segment = type_path.path.segments.last().unwrap();
318 if last_segment.ident == "Context" {
319 if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
320 if let Some(syn::GenericArgument::Type(syn::Type::Path(type_path))) =
321 args.args.first()
322 {
323 let ident = &type_path.path.segments.first().unwrap().ident;
324 self.context_struct_name = Some(ident.to_string());
325 }
326 }
327 }
328 }
329 }
330 }
331 }
332
333 fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
334 if let Some(name) = &self.context_struct_name {
335 if i.ident == name {
336 self.field_count = Some(i.fields.len());
337 }
338 }
339 }
340}