typhoon-context-macro 0.1.0-alpha.2

TODO
Documentation
use {
    super::{ConstraintGenerator, GeneratorResult},
    crate::{
        arguments::{Argument, Arguments},
        context::Context,
        visitor::ContextVisitor,
    },
    proc_macro2::{Span, TokenStream},
    quote::{format_ident, quote},
    syn::{parse_quote, Ident},
};

/// Generator that records a context's instruction arguments while visiting it
/// and then emits the matching `args` field and its initialization code.
///
/// All fields start as `None` (via `Default`) and are filled in by the
/// `ContextVisitor` implementation.
#[derive(Default)]
pub struct ArgumentsGenerator {
    /// Identifier of the context struct, captured in `visit_context`; used as
    /// the prefix of the generated `<Context>Args` struct name.
    context_name: Option<Ident>,
    /// Name of the argument type: either the struct named by
    /// `Arguments::Struct` or the identifier generated for `Arguments::Values`.
    arg_name: Option<Ident>,
    /// Token stream defining the generated argument struct; only set when the
    /// arguments were given as `Arguments::Values`.
    arg_struct: Option<TokenStream>,
}

impl ArgumentsGenerator {
    pub fn new() -> Self {
        ArgumentsGenerator::default()
    }
}

impl ConstraintGenerator for ArgumentsGenerator {
    /// Builds the generator output for the recorded arguments.
    ///
    /// When no argument type was recorded, the default (empty) result is
    /// returned. Otherwise the result gains a `pub args` field, an `at_init`
    /// statement that loads the arguments through `Args::from_entrypoint`, and,
    /// if an argument struct was generated, its definition in `global_outside`.
    fn generate(&self) -> Result<GeneratorResult, syn::Error> {
        let mut output = GeneratorResult::default();

        // Nothing to emit unless the visitor saw arguments.
        let Some(args_ty) = self.arg_name.as_ref() else {
            return Ok(output);
        };

        output
            .new_fields
            .push(parse_quote!(pub args: Args<'info, #args_ty>));
        output.at_init = quote! {
            let args = Args::<#args_ty>::from_entrypoint(accounts, instruction_data)?;
        };

        // Only present when the struct was synthesized from inline values.
        if let Some(definition) = self.arg_struct.as_ref() {
            output.global_outside = definition.clone();
        }

        Ok(output)
    }
}

impl ContextVisitor for ArgumentsGenerator {
    fn visit_context(&mut self, context: &Context) -> Result<(), syn::Error> {
        self.context_name = Some(context.item_struct.ident.clone());

        self.visit_accounts(&context.accounts)?;

        if let Some(args) = &context.args {
            self.visit_arguments(args)?;
        }

        Ok(())
    }

    fn visit_arguments(&mut self, arguments: &Arguments) -> Result<(), syn::Error> {
        match arguments {
            Arguments::Struct(name) => self.arg_name = Some(name.clone()),
            Arguments::Values(args) => {
                let Some(ref context_name) = self.context_name else {
                    return Err(syn::Error::new(
                        Span::call_site(),
                        "Not in a valid context.",
                    ));
                };

                let struct_name = format_ident!("{context_name}Args");
                let fields = args
                    .iter()
                    .map(|Argument { name, ty }: &Argument| quote!(pub #name: #ty));

                let generated_struct = quote! {
                    #[derive(Debug, PartialEq, bytemuck::AnyBitPattern, bytemuck::NoUninit, Copy, Clone)]
                    #[repr(C)]
                    pub struct #struct_name {
                        #(#fields),*
                    }
                };

                self.arg_name = Some(struct_name);
                self.arg_struct = Some(generated_struct);
            }
        }

        Ok(())
    }
}