use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Token, Type,
TypePath,
};
/// Parsed contents of the `#[session(...)]` attribute placed on the
/// `session_token` field: two comma-separated assignment expressions.
///
/// Only the right-hand side of each assignment is consumed by the derive; the
/// left-hand names are not inspected, so the arguments are effectively
/// positional (signer first, authority second).
struct SessionArgs {
/// First assignment; its right-hand side is spliced after `self.` to obtain
/// the session signer.
signer: syn::ExprAssign,
/// Second assignment; its right-hand side is spliced after `self.` to obtain
/// the session authority.
authority: syn::ExprAssign,
}
impl Parse for SessionArgs {
    /// Parses `<assignment> , <assignment>` in that order.
    ///
    /// The left-hand side of each assignment is accepted as-is and never
    /// checked against the names `signer` / `authority`; the first assignment
    /// is always treated as the signer and the second as the authority.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let signer: syn::ExprAssign = input.parse()?;
        let _separator: Token![,] = input.parse()?;
        let authority: syn::ExprAssign = input.parse()?;
        Ok(Self { signer, authority })
    }
}
/// Returns `true` when `attr` is written with the bare single-segment path
/// `session`, i.e. `#[session(...)]`.
fn is_session(attr: &syn::Attribute) -> bool {
    matches!(attr.path.get_ident(), Some(name) if *name == "session")
}
/// Which session-token account type the `session_token` field wraps.
#[derive(Copy, Clone, Eq, PartialEq)]
enum SessionTokenType {
/// The field's inner type path ends in `SessionToken`.
V1,
/// The field's inner type path ends in `SessionTokenV2`.
V2,
}
/// What was extracted from a field type shaped like
/// `Option<Account<'lt, SessionToken>>` or `Option<Account<'lt, SessionTokenV2>>`.
struct SessionTokenInfo {
/// Token flavour named by the second generic argument of `Account`.
token_type: SessionTokenType,
/// The lifetime written as the first generic argument of `Account`.
lifetime: syn::Lifetime,
}
fn get_session_token_info(ty: &Type) -> Option<SessionTokenInfo> {
let Type::Path(TypePath { path, .. }) = ty else {
return None;
};
let Some(seg) = path.segments.last() else {
return None;
};
if seg.ident != "Option" {
return None;
}
let PathArguments::AngleBracketed(ref opt_args) = seg.arguments else {
return None;
};
let Some(GenericArgument::Type(Type::Path(TypePath {
path: acct_path, ..
}))) = opt_args.args.first()
else {
return None;
};
let Some(acct_seg) = acct_path.segments.last() else {
return None;
};
if acct_seg.ident != "Account" {
return None;
}
let PathArguments::AngleBracketed(ref acct_args) = acct_seg.arguments else {
return None;
};
let mut args = acct_args.args.iter();
let Some(GenericArgument::Lifetime(lifetime)) = args.next() else {
return None;
};
let lifetime = lifetime.clone();
let Some(GenericArgument::Type(Type::Path(TypePath { path: st_path, .. }))) = args.next()
else {
return None;
};
if let Some(st_seg) = st_path.segments.last() {
if st_seg.ident == "SessionToken" {
return Some(SessionTokenInfo { token_type: SessionTokenType::V1, lifetime });
} else if st_seg.ident == "SessionTokenV2" {
return Some(SessionTokenInfo { token_type: SessionTokenType::V2, lifetime });
}
}
None
}
fn derive_impl(input: TokenStream, expected: Option<SessionTokenType>) -> TokenStream {
let input_parsed = parse_macro_input!(input as DeriveInput);
let fields = match input_parsed.data {
Data::Struct(data) => match data.fields {
Fields::Named(fields) => fields,
_ => panic!("Session trait can only be derived for structs with named fields"),
},
_ => panic!("Session trait can only be derived for structs"),
};
let session_token_field = fields
.named
.iter()
.find(|field| field.ident.as_ref().unwrap().to_string() == "session_token")
.expect("Session trait can only be derived for structs with a session_token field");
let session_token_type = &session_token_field.ty;
let info = get_session_token_info(session_token_type)
.expect("Session trait can only be derived for structs with a session_token field of type Option<Account<'info, SessionToken>> or Option<Account<'info, SessionTokenV2>>");
let token_type = info.token_type;
if let Some(expected) = expected {
if token_type != expected {
return syn::Error::new_spanned(
&session_token_field.ty,
"#[derive(SessionV2)] requires Option<Account<'info, SessionTokenV2>>",
)
.to_compile_error()
.into();
}
}
let session_attr = session_token_field
.attrs
.iter()
.find(|attr| is_session(attr))
.expect("Session trait can only be derived for structs with a session_token field with the #[session] attribute");
let session_args = session_attr.parse_args::<SessionArgs>().unwrap();
let session_signer = session_args.signer.right.into_token_stream();
let session_authority = session_args.authority.right.into_token_stream();
let struct_name = &input_parsed.ident;
let (impl_generics, ty_generics, where_clause) = input_parsed.generics.split_for_impl();
let info_lifetime = info.lifetime;
let output = match token_type {
SessionTokenType::V1 => quote! {
#[automatically_derived]
impl #impl_generics ::session_keys::Session<#info_lifetime> for #struct_name #ty_generics #where_clause {
fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
crate::id()
}
fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionToken>> {
self.session_token.clone()
}
fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
self.#session_authority
}
fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
self.#session_signer.clone()
}
}
},
SessionTokenType::V2 => quote! {
#[automatically_derived]
impl #impl_generics ::session_keys::SessionV2<#info_lifetime> for #struct_name #ty_generics #where_clause {
fn target_program(&self) -> ::anchor_lang::prelude::Pubkey {
crate::id()
}
fn session_token(&self) -> Option<::anchor_lang::prelude::Account<#info_lifetime, ::session_keys::SessionTokenV2>> {
self.session_token.clone()
}
fn session_authority(&self) -> ::anchor_lang::prelude::Pubkey {
self.#session_authority
}
fn session_signer(&self) -> ::anchor_lang::prelude::Signer<#info_lifetime> {
self.#session_signer.clone()
}
}
},
};
output.into()
}
/// `#[derive(Session)]`: delegates to `derive_impl` with no required token
/// type, so the generated impl (`Session` or `SessionV2`) follows whichever
/// token type the `session_token` field names.
///
/// Registers `#[session(...)]` as a helper attribute.
#[proc_macro_derive(Session, attributes(session))]
pub fn derive_session(input: TokenStream) -> TokenStream {
derive_impl(input, None)
}
/// `#[derive(SessionV2)]`: delegates to `derive_impl` requiring the
/// `session_token` field to name `SessionTokenV2`; any other token type is
/// rejected with a compile error.
///
/// Registers `#[session(...)]` as a helper attribute.
#[proc_macro_derive(SessionV2, attributes(session))]
pub fn derive_session_v2(input: TokenStream) -> TokenStream {
derive_impl(input, Some(SessionTokenType::V2))
}
/// Arguments of `#[session_auth_or(<check>, <error>)]`: field `0` is the
/// fallback authorization expression used when no session token is present,
/// field `1` is the error expression passed to the generated `require!` /
/// `require_eq!` calls.
struct SessionAuthArgs(syn::Expr, syn::Expr);
impl Parse for SessionAuthArgs {
    /// Parses `<expr> , <expr>`: the fallback check first, the error second.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let fallback_check: syn::Expr = input.parse()?;
        let _separator: Token![,] = input.parse()?;
        let failure: syn::Expr = input.parse()?;
        Ok(Self(fallback_check, failure))
    }
}
#[proc_macro_attribute]
pub fn session_auth_or(attr: TokenStream, item: TokenStream) -> TokenStream {
let SessionAuthArgs(auth_expr, error_ty) = parse_macro_input!(attr);
let input_fn = parse_macro_input!(item as syn::ItemFn);
let input_fn_name = input_fn.sig.ident;
let input_fn_vis = input_fn.vis;
let input_fn_block = input_fn.block;
let input_fn_inputs = input_fn.sig.inputs;
let input_fn_output = input_fn.sig.output;
let output = quote! {
#input_fn_vis fn #input_fn_name(#input_fn_inputs) #input_fn_output {
use ::session_keys::{Session as _, SessionV2 as _};
let session_token = ctx.accounts.session_token();
if let Some(token) = session_token {
require!(ctx.accounts.is_valid()?, SessionError::InvalidToken);
require_eq!(
ctx.accounts.session_authority(),
token.authority.key(),
#error_ty
);
} else {
require!(
#auth_expr,
#error_ty
);
}
#input_fn_block
}
};
output.into()
}