//! typhoon-account-macro 0.3.0
//!
//! Procedural macro for deriving account state implementations.
use {
    keys::PrimaryKeys,
    quote::{quote, ToTokens},
    syn::{parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Item, Path, Token},
    typhoon_discriminator::DiscriminatorBuilder,
};

mod keys;

/// Reports whether any `#[derive(...)]` attribute in `attrs` lists a path whose
/// final segment is `derive_name`.
///
/// Only the last path segment is compared, so `Foo` and `some_crate::Foo` both
/// match a `derive_name` of `"Foo"`. A `derive` attribute whose arguments do not
/// parse as a comma-separated list of paths is skipped rather than reported.
fn has_derive(attrs: &[syn::Attribute], derive_name: &str) -> bool {
    for attr in attrs {
        if !attr.path().is_ident("derive") {
            continue;
        }

        // Unparseable derive lists are ignored; the compiler reports them itself.
        let Ok(derived) = attr.parse_args_with(Punctuated::<Path, Token![,]>::parse_terminated)
        else {
            continue;
        };

        let listed = derived.iter().any(|path| match path.segments.last() {
            Some(segment) => segment.ident == derive_name,
            None => false,
        });
        if listed {
            return true;
        }
    }

    false
}

/// Derive macro generating the account-state trait implementations for a struct.
///
/// The expansion contains:
/// - `CheckOwner`, whose `owned_by` compares the given owner against `crate::ID`;
/// - `Discriminator`, with the bytes produced by `DiscriminatorBuilder` from the
///   struct's name;
/// - `DataStrategy`, selecting `WincodeStrategy` when the struct derives
///   `SchemaRead`, `BorshStrategy` when it derives `BorshDeserialize`, and
///   `BytemuckStrategy` otherwise;
/// - an inherent `SPACE` constant (discriminator length plus `size_of` of the
///   struct), unless the struct carries `#[no_space]`;
/// - the tokens returned by `PrimaryKeys::split_for_impl` for the struct's fields.
///
/// The generated code names `CheckOwner`, `Discriminator`, `DataStrategy`,
/// `Address`, `address_eq` and the strategy types unqualified, so those must be
/// in scope where the derive is used.
///
/// Applying the derive to anything other than a struct produces the compile
/// error "Invalid account type"; errors from parsing the key fields are
/// forwarded as compile errors as well.
#[proc_macro_derive(AccountState, attributes(key, no_space))]
pub fn derive_account(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let item = parse_macro_input!(item as Item);
    let Item::Struct(ref item_struct) = item else {
        return Error::new(item.span(), "Invalid account type")
            .into_compile_error()
            .into();
    };
    let attrs = &item_struct.attrs;
    let name = &item_struct.ident;
    let fields = &item_struct.fields;

    // All three parts are required: `impl_generics` declares the parameters that
    // `ty_generics` and `where_clause` refer to. Without it, impls for a struct
    // with lifetime/type/const parameters reference undeclared names.
    let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();

    let skip_space = attrs.iter().any(|a| a.path().is_ident("no_space"));
    let space_token = (!skip_space).then(|| {
        quote! {
            impl #impl_generics #name #ty_generics #where_clause {
                pub const SPACE: usize = <Self as Discriminator>::DISCRIMINATOR.len() + core::mem::size_of::<Self>();
            }
        }
    });

    let keys = match PrimaryKeys::try_from(fields) {
        Ok(keys) => keys,
        Err(err) => return err.to_compile_error().into(),
    };
    let seeded_trait = keys.split_for_impl(name);
    let discriminator = DiscriminatorBuilder::new(&name.to_string()).build();

    // The serialization derive present on the struct decides the data strategy;
    // `SchemaRead` takes precedence over `BorshDeserialize`.
    let account_strategy = if has_derive(attrs, "SchemaRead") {
        quote!(
            WincodeStrategy<
                {
                    matches!(
                <Self as wincode::SchemaRead<'static, wincode::config::DefaultConfig>>::TYPE_META,
                wincode::TypeMeta::Static { zero_copy: true, .. }
            )
                },
            >
        )
    } else if has_derive(attrs, "BorshDeserialize") {
        quote!(BorshStrategy)
    } else {
        quote!(BytemuckStrategy)
    };

    quote! {
        impl #impl_generics CheckOwner for #name #ty_generics #where_clause {
            #[inline(always)]
            fn owned_by(owner: &Address) -> bool {
                address_eq(owner, &crate::ID)
            }
        }

        impl #impl_generics Discriminator for #name #ty_generics #where_clause {
            const DISCRIMINATOR: &'static [u8] = &[#(#discriminator),*];
        }

        impl #impl_generics DataStrategy for #name #ty_generics #where_clause {
            type Strategy = #account_strategy;
        }

        #space_token

        #seeded_trait
    }
    .into_token_stream()
    .into()
}