gpl_session_macros_attribute/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, Data, DeriveInput, Fields, Token,
7};
8
9struct SessionArgs {
10    signer: syn::ExprAssign,
11    authority: syn::ExprAssign,
12}
13
14impl Parse for SessionArgs {
15    fn parse(input: ParseStream) -> syn::Result<Self> {
16        let signer = input.parse()?;
17
18        input.parse::<Token![,]>()?;
19
20        let authority = input.parse()?;
21        Ok(SessionArgs { signer, authority })
22    }
23}
24
25fn is_session(attr: &syn::Attribute) -> bool {
26    attr.path.is_ident("session")
27}
28
29// Macro to derive Session Trait
30#[proc_macro_derive(Session, attributes(session))]
31pub fn derive(input: TokenStream) -> TokenStream {
32    let input_parsed = parse_macro_input!(input as DeriveInput);
33
34    let fields = match input_parsed.data {
35        Data::Struct(data) => match data.fields {
36            Fields::Named(fields) => fields,
37            _ => panic!("Session trait can only be derived for structs with named fields"),
38        },
39        _ => panic!("Session trait can only be derived for structs"),
40    };
41
42    // Ensure that the struct has a session_token field
43    let session_token_field = fields
44        .named
45        .iter()
46        .find(|field| field.ident.as_ref().unwrap().to_string() == "session_token")
47        .expect("Session trait can only be derived for structs with a session_token field");
48    {
49        let session_token_type = &session_token_field.ty;
50        let session_token_type_string = quote! { #session_token_type }.to_string();
51        assert!(
52        session_token_type_string == "Option < Account < 'info, SessionToken > >",
53        "Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>>"
54        );
55    };
56
57    // Session Token field must have the #[session] attribute
58    let session_attr = session_token_field
59        .attrs
60        .iter()
61        .find(|attr| is_session(attr))
62        .expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
63
64    let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
65
66    let session_signer = session_args.signer.right.into_token_stream();
67
68    // Session Authority
69    let session_authority = session_args.authority.right.into_token_stream();
70
71    let struct_name = &input_parsed.ident;
72    let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
73
74    let output = quote! {
75
76        #[automatically_derived]
77        impl #impl_generics Session #ty_generics for #struct_name #ty_generics #where_clause {
78
79            // Target Program
80            fn target_program(&self) -> Pubkey {
81                crate::id()
82            }
83
84            // Session Token
85            fn session_token(&self) -> Option<Account<'info, SessionToken>> {
86                self.session_token.clone()
87            }
88
89            // Session Authority
90            fn session_authority(&self) -> Pubkey {
91                self.#session_authority
92            }
93
94            // Session Signer
95            fn session_signer(&self) -> Signer<'info> {
96                self.#session_signer.clone()
97            }
98
99        }
100    };
101
102    output.into()
103}
104
105struct SessionAuthArgs(syn::Expr, syn::Expr);
106
107impl Parse for SessionAuthArgs {
108    fn parse(input: ParseStream) -> syn::Result<Self> {
109        let equality_expr = input.parse()?;
110        input.parse::<Token![,]>()?;
111        let error_expr = input.parse()?;
112        Ok(SessionAuthArgs(equality_expr, error_expr))
113    }
114}
115
116#[proc_macro_attribute]
117/// Macro to check if the session or the original authority is the signer
118pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
119    let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
120
121    let input_fn = parse_macro_input!(item as syn::ItemFn);
122    let input_fn_name = input_fn.sig.ident;
123    let input_fn_vis = input_fn.vis;
124    let input_fn_block = input_fn.block;
125    let input_fn_inputs = input_fn.sig.inputs;
126    let input_fn_output = input_fn.sig.output;
127
128    let output = quote! {
129        #input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
130            // Automatically generated by session_auth_or macro
131            // BEGIN SESSION AUTH
132            // Current signer is the session signer or the original authority
133            let session_token = ctx.accounts.session_token();
134            if let Some(token) = session_token {
135                require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
136                // Checks that authority of the session is the same as authority of the original account
137                require_eq!(
138                    ctx.accounts.session_authority(),
139                    token.authority.key(),
140                    #error_ty
141                );
142            } else {
143                require!(
144                    #auth_expr,
145                    #error_ty
146                );
147            }
148            // END SESSION AUTH
149            #input_fn_block
150        }
151    };
152    output.into()
153}