// gpl_session_macros_attribute/lib.rs
use proc_macro::TokenStream;
use quote::{quote, ToTokens};

use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input, Data, DeriveInput, Fields, Token,
};
8
9struct SessionArgs {
10 signer: syn::ExprAssign,
11 authority: syn::ExprAssign,
12}
13
14impl Parse for SessionArgs {
15 fn parse(input: ParseStream) -> syn::Result<Self> {
16 let signer = input.parse()?;
17
18 input.parse::<Token![,]>()?;
19
20 let authority = input.parse()?;
21 Ok(SessionArgs { signer, authority })
22 }
23}
24
25fn is_session(attr: &syn::Attribute) -> bool {
26 attr.path.is_ident("session")
27}
28
29#[proc_macro_derive(Session, attributes(session))]
31pub fn derive(input: TokenStream) -> TokenStream {
32 let input_parsed = parse_macro_input!(input as DeriveInput);
33
34 let fields = match input_parsed.data {
35 Data::Struct(data) => match data.fields {
36 Fields::Named(fields) => fields,
37 _ => panic!("Session trait can only be derived for structs with named fields"),
38 },
39 _ => panic!("Session trait can only be derived for structs"),
40 };
41
42 let session_token_field = fields
44 .named
45 .iter()
46 .find(|field| field.ident.as_ref().unwrap().to_string() == "session_token")
47 .expect("Session trait can only be derived for structs with a session_token field");
48 {
49 let session_token_type = &session_token_field.ty;
50 let session_token_type_string = quote! { #session_token_type }.to_string();
51 assert!(
52 session_token_type_string == "Option < Account < 'info, SessionToken > >",
53 "Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>>"
54 );
55 };
56
57 let session_attr = session_token_field
59 .attrs
60 .iter()
61 .find(|attr| is_session(attr))
62 .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
63
64 let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
65
66 let session_signer = session_args.signer.right.into_token_stream();
67
68 let session_authority = session_args.authority.right.into_token_stream();
70
71 let struct_name = &input_parsed.ident;
72 let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
73
74 let output = quote! {
75
76 #[automatically_derived]
77 impl #impl_generics Session #ty_generics for #struct_name #ty_generics #where_clause {
78
79 fn target_program(&self) -> Pubkey {
81 crate::id()
82 }
83
84 fn session_token(&self) -> Option<Account<'info, SessionToken>> {
86 self.session_token.clone()
87 }
88
89 fn session_authority(&self) -> Pubkey {
91 self.#session_authority
92 }
93
94 fn session_signer(&self) -> Signer<'info> {
96 self.#session_signer.clone()
97 }
98
99 }
100 };
101
102 output.into()
103}
104
105struct SessionAuthArgs(syn::Expr, syn::Expr);
106
107impl Parse for SessionAuthArgs {
108 fn parse(input: ParseStream) -> syn::Result<Self> {
109 let equality_expr = input.parse()?;
110 input.parse::<Token![,]>()?;
111 let error_expr = input.parse()?;
112 Ok(SessionAuthArgs(equality_expr, error_expr))
113 }
114}
115
116#[proc_macro_attribute]
117pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
119 let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
120
121 let input_fn = parse_macro_input!(item as syn::ItemFn);
122 let input_fn_name = input_fn.sig.ident;
123 let input_fn_vis = input_fn.vis;
124 let input_fn_block = input_fn.block;
125 let input_fn_inputs = input_fn.sig.inputs;
126 let input_fn_output = input_fn.sig.output;
127
128 let output = quote! {
129 #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
130 let session_token = ctx.accounts.session_token();
134 if let Some(token) = session_token {
135 require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
136 require_eq!(
138 ctx.accounts.session_authority(),
139 token.authority.key(),
140 #error_ty
141 );
142 } else {
143 require!(
144 #auth_expr,
145 #error_ty
146 );
147 }
148 #input_fn_block
150 }
151 };
152 output.into()
153}