bolt_attribute_bolt_system/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{quote, ToTokens, TokenStreamExt};
4use syn::{
5 parse_macro_input, parse_quote, visit_mut::VisitMut, Expr, FnArg, GenericArgument, ItemFn,
6 ItemMod, ItemStruct, PathArguments, ReturnType, Stmt, Type, TypePath,
7};
8
/// Visitor that rewrites the items of a `#[system]` module in place.
///
/// It carries no state of its own: the rewriting logic lives in its `VisitMut`
/// implementation and in the associated helper functions.
#[derive(Default)]
struct SystemTransform;
11
/// Read-only visitor that locates the component bundle of a `#[system]` module.
#[derive(Default)]
struct Extractor {
    /// Name of the struct passed as the type argument of a `Context<...>` function parameter.
    context_struct_name: Option<String>,
    /// Number of fields of that struct, once its declaration has been seen in the module.
    field_count: Option<usize>,
}
17
18#[proc_macro_attribute]
37pub fn system(_attr: TokenStream, item: TokenStream) -> TokenStream {
38 let mut ast = parse_macro_input!(item as ItemMod);
39
40 let mut extractor = Extractor::default();
42 extractor.visit_item_mod_mut(&mut ast);
43
44 if extractor.field_count.is_some() {
45 let use_super = syn::parse_quote! { use super::*; };
46 if let Some((_, ref mut items)) = ast.content {
47 items.insert(0, syn::Item::Use(use_super));
48 SystemTransform::add_variadic_execute_function(items);
49 }
50
51 let mut transform = SystemTransform;
52 transform.visit_item_mod_mut(&mut ast);
53
54 let expanded = quote! {
56 #[program]
57 #ast
58 };
59
60 TokenStream::from(expanded)
61 } else {
62 panic!(
63 "Could not find the component bundle: {} in the module",
64 extractor.context_struct_name.unwrap()
65 );
66 }
67}
68
69impl SystemTransform {
70 fn visit_stmts_mut(&mut self, stmts: &mut Vec<Stmt>) {
71 for stmt in stmts {
72 if let Stmt::Expr(ref mut expr) | Stmt::Semi(ref mut expr, _) = stmt {
73 self.visit_expr_mut(expr);
74 }
75 }
76 }
77}
78
79impl VisitMut for SystemTransform {
81 fn visit_expr_mut(&mut self, expr: &mut Expr) {
83 match expr {
84 Expr::ForLoop(for_loop_expr) => {
85 self.visit_stmts_mut(&mut for_loop_expr.body.stmts);
86 }
87 Expr::Loop(loop_expr) => {
88 self.visit_stmts_mut(&mut loop_expr.body.stmts);
89 }
90 Expr::If(if_expr) => {
91 self.visit_stmts_mut(&mut if_expr.then_branch.stmts);
92 if let Some((_, else_expr)) = &mut if_expr.else_branch {
93 self.visit_expr_mut(else_expr);
94 }
95 }
96 Expr::Block(block_expr) => {
97 self.visit_stmts_mut(&mut block_expr.block.stmts);
98 }
99 _ => (),
100 }
101 if let Some(inner_variable) = Self::extract_inner_ok_expression(expr) {
102 let new_return_expr: Expr = match inner_variable {
103 Expr::Tuple(tuple_expr) => {
104 let tuple_elements = tuple_expr.elems.iter().map(|elem| {
105 quote! { (#elem).try_to_vec()? }
106 });
107 parse_quote! { Ok((#(#tuple_elements),*)) }
108 }
109 _ => {
110 parse_quote! {
111 #inner_variable.try_to_vec()
112 }
113 }
114 };
115 if let Expr::Return(return_expr) = expr {
116 return_expr.expr = Some(Box::new(new_return_expr));
117 } else {
118 *expr = new_return_expr;
119 }
120 }
121 }
122
123 fn visit_item_fn_mut(&mut self, item_fn: &mut ItemFn) {
125 if item_fn.sig.ident == "execute" {
126 Self::inject_lifetimes_and_context(item_fn);
128 if let ReturnType::Type(_, type_box) = &item_fn.sig.output {
130 if let Type::Path(type_path) = &**type_box {
131 if !Self::check_is_result_vec_u8(type_path) {
132 item_fn.sig.output = parse_quote! { -> Result<Vec<Vec<u8>>> };
133 let block = &mut item_fn.block;
135 self.visit_stmts_mut(&mut block.stmts);
136 }
137 }
138 }
139 Self::modify_args(item_fn);
141 }
142 }
143
144 fn visit_item_mod_mut(&mut self, item_mod: &mut ItemMod) {
147 let content = match item_mod.content.as_mut() {
148 Some(content) => &mut content.1,
149 None => return,
150 };
151
152 let mut extra_accounts_struct_name = None;
153
154 for item in content.iter_mut() {
155 match item {
156 syn::Item::Fn(item_fn) => self.visit_item_fn_mut(item_fn),
157 syn::Item::Struct(item_struct) => {
158 if let Some(attr) = item_struct
159 .attrs
160 .iter_mut()
161 .find(|attr| attr.path.is_ident("system_input"))
162 {
163 attr.tokens.append_all(quote! { (session_key) });
164 }
165 if item_struct
166 .attrs
167 .iter()
168 .any(|attr| attr.path.is_ident("extra_accounts"))
169 {
170 extra_accounts_struct_name = Some(&item_struct.ident);
171 break;
172 }
173 }
174 _ => {}
175 }
176 }
177
178 if let Some(struct_name) = extra_accounts_struct_name {
179 let initialize_extra_accounts = quote! {
180 #[automatically_derived]
181 pub fn init_extra_accounts(_ctx: Context<#struct_name>) -> Result<()> {
182 Ok(())
183 }
184 };
185 content.push(syn::parse2(initialize_extra_accounts).unwrap());
186 }
187 }
188}
189
190impl SystemTransform {
191 fn inject_lifetimes_and_context(item_fn: &mut ItemFn) {
192 let lifetime_idents = ["a", "b", "c", "info"];
194 for name in lifetime_idents.iter() {
195 let exists = item_fn.sig.generics.params.iter().any(|p| match p {
196 syn::GenericParam::Lifetime(l) => l.lifetime.ident == *name,
197 _ => false,
198 });
199 if !exists {
200 let lifetime: syn::Lifetime =
201 syn::parse_str(&format!("'{}", name)).expect("valid lifetime");
202 let gp: syn::GenericParam = syn::parse_quote!(#lifetime);
203 item_fn.sig.generics.params.push(gp);
204 }
205 }
206
207 if let Some(FnArg::Typed(pat_type)) = item_fn.sig.inputs.first_mut() {
209 if let Type::Path(type_path) = pat_type.ty.as_mut() {
210 if let Some(last_segment) = type_path.path.segments.last_mut() {
211 if last_segment.ident == "Context" {
212 let mut components_ty_opt: Option<Type> = None;
214 if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
215 for ga in args.args.iter() {
216 if let GenericArgument::Type(t) = ga {
217 components_ty_opt = Some(t.clone());
218 break;
219 }
220 }
221 }
222
223 if let Some(components_ty) = components_ty_opt {
225 let components_with_info: Type = match components_ty {
227 Type::Path(mut tp) => {
228 let seg = tp.path.segments.last_mut().unwrap();
229 match &mut seg.arguments {
230 PathArguments::AngleBracketed(ab) => {
231 if ab.args.is_empty() {
232 ab.args.push(GenericArgument::Lifetime(
233 syn::parse_quote!('info),
234 ));
235 }
236 }
237 _ => {
238 seg.arguments = PathArguments::AngleBracketed(
239 syn::AngleBracketedGenericArguments {
240 colon2_token: None,
241 lt_token: Default::default(),
242 args: std::iter::once(
243 GenericArgument::Lifetime(
244 syn::parse_quote!('info),
245 ),
246 )
247 .collect(),
248 gt_token: Default::default(),
249 },
250 );
251 }
252 }
253 Type::Path(tp)
254 }
255 other => other,
256 };
257
258 let new_ty: Type = syn::parse_quote! {
260 Context<'a, 'b, 'c, 'info, #components_with_info>
261 };
262 pat_type.ty = Box::new(new_ty);
263 }
264 }
265 }
266 }
267 }
268 }
269 fn add_variadic_execute_function(content: &mut Vec<syn::Item>) {
270 content.push(syn::parse2(quote! {
271 pub fn bolt_execute<'a, 'b, 'info>(ctx: Context<'a, 'b, 'info, 'info, VariadicBoltComponents<'info>>, args: Vec<u8>) -> Result<Vec<Vec<u8>>> {
272 let mut components = Components::try_from(&ctx)?;
273 let bumps = ComponentsBumps {};
274 let context = Context::new(ctx.program_id, &mut components, ctx.remaining_accounts, bumps);
275 execute(context, args)
276 }
277 }).unwrap());
278 }
279
280 fn check_is_result_vec_u8(ty: &TypePath) -> bool {
282 if let Some(segment) = ty.path.segments.last() {
283 if segment.ident == "Result" {
284 if let PathArguments::AngleBracketed(args) = &segment.arguments {
285 if let Some(GenericArgument::Type(Type::Tuple(tuple))) = args.args.first() {
286 return tuple.elems.iter().all(|elem| {
287 if let Type::Path(type_path) = elem {
288 if let Some(segment) = type_path.path.segments.first() {
289 return segment.ident == "Vec" && Self::is_u8_vec(segment);
290 }
291 }
292 false
293 });
294 } else if let Some(GenericArgument::Type(Type::Path(type_path))) =
295 args.args.first()
296 {
297 if let Some(segment) = type_path.path.segments.first() {
298 return segment.ident == "Vec" && Self::is_u8_vec(segment);
299 }
300 }
301 }
302 }
303 }
304 false
305 }
306
307 fn is_u8_vec(segment: &syn::PathSegment) -> bool {
309 if let PathArguments::AngleBracketed(args) = &segment.arguments {
310 if let Some(GenericArgument::Type(Type::Path(path))) = args.args.first() {
311 if let Some(segment) = path.path.segments.first() {
312 return segment.ident == "u8";
313 }
314 }
315 }
316 false
317 }
318
319 fn extract_inner_ok_expression(expr: &Expr) -> Option<&Expr> {
321 match expr {
322 Expr::Call(expr_call) => {
323 if let Expr::Path(expr_path) = &*expr_call.func {
325 if let Some(last_segment) = expr_path.path.segments.last() {
326 if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
327 return expr_call.args.first();
329 }
330 }
331 }
332 }
333 Expr::Return(expr_return) => {
334 if let Some(expr_return_inner) = &expr_return.expr {
336 if let Expr::Call(expr_call) = expr_return_inner.as_ref() {
337 if let Expr::Path(expr_path) = &*expr_call.func {
338 if let Some(last_segment) = expr_path.path.segments.last() {
339 if last_segment.ident == "Ok" && !expr_call.args.is_empty() {
340 return expr_call.args.first();
342 }
343 }
344 }
345 }
346 }
347 }
348 _ => {}
349 }
350 None
351 }
352
353 fn modify_args(item_fn: &mut ItemFn) {
354 if item_fn.sig.inputs.len() >= 2 {
355 let second_arg = &mut item_fn.sig.inputs[1];
356 let is_vec_u8 = if let FnArg::Typed(syn::PatType { ty, .. }) = second_arg {
357 match &**ty {
358 Type::Path(type_path) => {
359 if let Some(segment) = type_path.path.segments.first() {
360 segment.ident == "Vec" && Self::is_u8_vec(segment)
361 } else {
362 false
363 }
364 }
365 _ => false,
366 }
367 } else {
368 false
369 };
370 if !is_vec_u8 {
371 if let FnArg::Typed(pat_type) = second_arg {
372 let original_type = pat_type.ty.to_token_stream();
373 let arg_original_name = pat_type.pat.to_token_stream();
374 if let syn::Pat::Ident(ref mut pat_ident) = *pat_type.pat {
375 let new_ident_name = format!("_{}", pat_ident.ident);
376 pat_ident.ident =
377 Ident::new(&new_ident_name, proc_macro2::Span::call_site());
378 }
379 let arg_name = pat_type.pat.to_token_stream();
380 pat_type.ty = Box::new(syn::parse_quote! { Vec<u8> });
381 let parse_stmt: Stmt = parse_quote! {
382 let #arg_original_name = parse_args::<#original_type>(&#arg_name);
383 };
384 item_fn.block.stmts.insert(0, parse_stmt);
385 }
386 }
387 }
388 }
389}
390
391impl VisitMut for Extractor {
393 fn visit_item_fn_mut(&mut self, i: &mut ItemFn) {
394 for input in &i.sig.inputs {
395 if let FnArg::Typed(pat_type) = input {
396 if let Type::Path(type_path) = &*pat_type.ty {
397 let last_segment = type_path.path.segments.last().unwrap();
398 if last_segment.ident == "Context" {
399 if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
400 for ga in args.args.iter() {
402 if let syn::GenericArgument::Type(syn::Type::Path(type_path)) = ga {
403 if let Some(first_seg) = type_path.path.segments.first() {
404 self.context_struct_name =
405 Some(first_seg.ident.to_string());
406 break;
407 }
408 }
409 }
410 }
411 }
412 }
413 }
414 }
415 }
416
417 fn visit_item_struct_mut(&mut self, i: &mut ItemStruct) {
418 if let Some(name) = &self.context_struct_name {
419 if i.ident == name {
420 self.field_count = Some(i.fields.len());
421 }
422 }
423 }
424}