// sov-modules-macros 0.3.0
//
// Macros for use with the Sovereign SDK module system
use std::str::FromStr;

use proc_macro2::Ident;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
    parenthesized, Attribute, FnArg, ImplItem, Meta, MetaList, PatType, Path, PathSegment,
    Signature,
};

/// Returns an attribute with the name `rpc_method` replaced with `method`, and the index
/// into the argument array where the attribute was found.
/// Searches `attributes` for the first `#[rpc_method(...)]` attribute.
///
/// When one is found, returns a copy of it renamed to `#[method(...)]` together with the
/// position of the original inside `attributes`, so the caller can remove it. Only the
/// list form `#[rpc_method(...)]` is recognized; any other form is ignored.
fn get_method_attribute(attributes: &[Attribute]) -> Option<(Attribute, usize)> {
    attributes
        .iter()
        .enumerate()
        .find_map(|(position, candidate)| match candidate.parse_meta() {
            Ok(Meta::List(MetaList { path, .. })) if path.is_ident("rpc_method") => {
                let mut renamed = candidate.clone();
                // `is_ident` only matches single-segment paths, so the last segment is the
                // attribute's entire name.
                renamed.path.segments.last_mut().unwrap().ident = format_ident!("method");
                Some((renamed, position))
            }
            _ => None,
        })
}

/// Returns the absolute path `::jsonrpsee::proc_macros::rpc`, i.e. jsonrpsee's `rpc`
/// attribute macro, with every token spanned at the macro call site.
fn jsonrpsee_rpc_macro_path() -> Path {
    syn::parse_quote!(::jsonrpsee::proc_macros::rpc)
}

/// Finds the first argument of `sig` whose type is a reference to a `WorkingSet<..>`.
///
/// Both `&WorkingSet<..>` and `&mut WorkingSet<..>` match, with or without a leading path
/// (only the last path segment is inspected). The segment must carry non-empty path
/// arguments, e.g. `WorkingSet<C>`.
///
/// Returns the argument's index in `sig.inputs` (counting the receiver, if any) and the
/// referenced type, i.e. `WorkingSet<..>` with the reference stripped.
fn find_working_set_argument(sig: &Signature) -> Option<(usize, syn::Type)> {
    for (idx, input) in sig.inputs.iter().enumerate() {
        // Inspect each argument's type by reference; only the matching type is cloned.
        let ty = match input {
            FnArg::Typed(PatType { ty, .. }) => ty,
            _ => continue,
        };
        let elem = match ty.as_ref() {
            syn::Type::Reference(syn::TypeReference { elem, .. }) => elem,
            _ => continue,
        };
        let segment = match elem.as_ref() {
            syn::Type::Path(syn::TypePath { path, .. }) => match path.segments.last() {
                Some(segment) => segment,
                None => continue,
            },
            _ => continue,
        };
        // TODO: enforce that the working set has exactly one angle bracketed argument
        if segment.ident == "WorkingSet" && !segment.arguments.is_empty() {
            return Some((idx, (**elem).clone()));
        }
    }
    None
}

/// Information gathered from the module's `impl` block while expanding `rpc_gen`.
/// `build_rpc_impl_trait` consumes it to generate the `{TypeName}RpcImpl` trait and the
/// blanket `{TypeName}RpcServer` implementation.
struct RpcImplBlock {
    /// Name of the module type the `impl` block is for (the last segment of its path).
    pub(crate) type_name: Ident,
    /// Methods of the `impl` block that carry an `#[rpc_method(...)]` attribute, in order.
    pub(crate) methods: Vec<RpcEnabledMethod>,
    /// Referenced type of the `&WorkingSet<..>` argument taken by the RPC methods, if any.
    /// When several methods take one, the type from the last such method is kept.
    pub(crate) working_set_type: Option<syn::Type>,
    /// Generics of the `impl` block, reused on the generated traits and impls.
    pub(crate) generics: syn::Generics,
}

/// One `#[rpc_method(...)]`-annotated method of the module's `impl` block.
struct RpcEnabledMethod {
    /// Name of the method.
    pub(crate) method_name: Ident,
    /// The method's original signature, including the receiver and any working set argument.
    pub(crate) method_signature: Signature,
    /// The method's `#[doc]` attributes, re-emitted on the generated methods.
    pub(crate) docs: Vec<Attribute>,
    /// Position of the working set argument in `method_signature.inputs` (counting the
    /// receiver, if present), or `None` if the method takes no working set.
    pub(crate) idx_of_working_set_arg: Option<usize>,
}

impl RpcImplBlock {
    /// Builds the trait `_RpcImpl` That will be implemented by the runtime
    fn build_rpc_impl_trait(&self) -> proc_macro2::TokenStream {
        let type_name = &self.type_name;
        let generics = &self.generics;
        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

        let mut blanket_impl_methods = vec![];
        let mut impl_trait_methods = vec![];

        let impl_trait_name = format_ident!("{}RpcImpl", self.type_name);

        for method in self.methods.iter() {
            // Extract the names of the formal arguments
            let arg_values = method
                .method_signature
                .inputs
                .clone()
                .into_iter()
                .map(|item| {
                    if let FnArg::Typed(PatType { pat, .. }) = item {
                        if let syn::Pat::Ident(syn::PatIdent { ident, .. }) = *pat {
                            return quote! { #ident };
                        }
                        unreachable!("Expected a pattern identifier")
                    } else {
                        quote! { self }
                    }
                });

            let mut signature = method.method_signature.clone();
            let method_name = &method.method_name;
            let docs = &method.docs;

            let impl_trait_method = if let Some(idx) = method.idx_of_working_set_arg {
                // If necessary, adjust the signature to remove the working set argument and replace it with one generated by the implementer.
                // Remove the "self" argument as well
                let pre_working_set_args = arg_values
                    .clone()
                    .take(idx)
                    .filter(|arg| arg.to_string() != quote! { self }.to_string());
                let post_working_set_args = arg_values
                    .clone()
                    .skip(idx + 1)
                    .filter(|arg| arg.to_string() != quote! { self }.to_string());
                let mut inputs: Vec<syn::FnArg> = signature.inputs.clone().into_iter().collect();
                inputs.remove(idx);

                signature.inputs = inputs.into_iter().collect();

                quote! {
                    #( #docs )*
                    #signature {
                        <#type_name #ty_generics as ::std::default::Default>::default().#method_name(#(#pre_working_set_args,)* &mut Self::get_working_set(self), #(#post_working_set_args),* )
                    }
                }
            } else {
                // Remove the "self" argument, since the method is invoked on `self` using dot notation
                let arg_values = arg_values
                    .clone()
                    .filter(|arg| arg.to_string() != quote! { self }.to_string());
                quote! {
                    #signature {
                        <#type_name  #ty_generics as ::std::default::Default>::default().#method_name(#(#arg_values),*)
                    }
                }
            };

            impl_trait_methods.push(impl_trait_method);

            let blanket_impl_method = if let Some(idx) = method.idx_of_working_set_arg {
                // If necessary, adjust the signature to remove the working set argument.
                let pre_working_set_args = arg_values.clone().take(idx);
                let post_working_set_args = arg_values.clone().skip(idx + 1);
                quote! {
                    #( #docs )*
                    #signature {
                        <Self as #impl_trait_name #ty_generics >::#method_name(#(#pre_working_set_args,)* #(#post_working_set_args),* )
                    }
                }
            } else {
                quote! {
                    #( #docs )*
                    #signature {
                        <Self as #impl_trait_name #ty_generics >::#method_name(#(#arg_values),*)
                    }
                }
            };

            blanket_impl_methods.push(blanket_impl_method);
        }

        let rpc_impl_trait = if let Some(ref working_set_type) = self.working_set_type {
            quote! {
                /// Allows a Runtime to be converted into a functional RPC server by simply implementing the two required methods -
                /// `get_backing_impl(&self) -> MyModule` and `get_working_set(&self) -> ::sov_modules_api::WorkingSet<C>`
                pub trait #impl_trait_name #generics #where_clause {
                    /// Get a clean working set on top of the latest state
                    fn get_working_set(&self) -> #working_set_type;
                    #(#impl_trait_methods)*
                }
            }
        } else {
            quote! {
                /// Allows a Runtime to be converted into a functional RPC server by simply implementing the two required methods -
                /// `get_backing_impl(&self) -> MyModule` and `get_working_set(&self) -> ::sov_modules_api::WorkingSet<C>`
                pub trait #impl_trait_name #generics #where_clause {
                    #(#impl_trait_methods)*
                }
            }
        };

        let blanket_impl_generics = quote! {
            #impl_generics
        }
        .to_string();
        let blanket_impl_generics_without_braces = proc_macro2::TokenStream::from_str(
            &blanket_impl_generics[1..blanket_impl_generics.len() - 1],
        )
        .expect("Failed to parse generics without braces as token stream");
        let rpc_server_trait_name = format_ident!("{}RpcServer", self.type_name);
        let blanket_impl = quote! {
            impl <MacroGeneratedTypeWithLongNameToAvoidCollisions: #impl_trait_name #ty_generics
            + Send
            + Sync
            + 'static,  #blanket_impl_generics_without_braces > #rpc_server_trait_name #ty_generics for MacroGeneratedTypeWithLongNameToAvoidCollisions #where_clause {
                #(#blanket_impl_methods)*
            }
        };

        quote! {
            #rpc_impl_trait
            #blanket_impl
        }
    }
}

/// Appends an empty `server_bounds()` entry to the `rpc` macro arguments, unless a
/// `server_bounds(...)` list is already among them.
fn add_server_bounds_attr_if_missing(attrs: &mut Vec<syn::NestedMeta>) {
    let has_server_bounds = attrs.iter().any(|attr| {
        matches!(
            attr,
            syn::NestedMeta::Meta(syn::Meta::List(list)) if list.path.is_ident("server_bounds")
        )
    });

    if !has_server_bounds {
        let empty_bounds: syn::MetaList = syn::parse_quote! { server_bounds() };
        attrs.push(syn::NestedMeta::Meta(syn::Meta::List(empty_bounds)));
    }
}

fn build_rpc_trait(
    mut attrs: Vec<syn::NestedMeta>,
    type_name: Ident,
    mut input: syn::ItemImpl,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let intermediate_trait_name = format_ident!("{}Rpc", type_name);
    // If the user hasn't directly provided trait bounds, override jsonrpsee's defaults
    // with an empty bound. This prevents spurious compilation errors like `Context does not implement DeserializeOwned`
    add_server_bounds_attr_if_missing(&mut attrs);

    let wrapped_attr_args = quote! {
        ( #(#attrs),* )
    };
    let rpc_attribute = syn::Attribute {
        pound_token: syn::token::Pound {
            spans: [proc_macro2::Span::call_site()],
        },
        style: syn::AttrStyle::Outer,
        bracket_token: syn::token::Bracket {
            span: proc_macro2::Span::call_site(),
        },
        path: jsonrpsee_rpc_macro_path(),
        tokens: wrapped_attr_args,
    };
    // Iterate over the methods from the `impl` block, building up three lists of items as we go

    let generics = &input.generics;
    let mut rpc_info = RpcImplBlock {
        type_name: type_name.clone(),
        methods: vec![],
        working_set_type: None,
        generics: generics.clone(),
    };

    let mut intermediate_trait_items = vec![];
    let mut simplified_impl_items = vec![];
    for item in input.items.into_iter() {
        if let ImplItem::Method(ref method) = item {
            if let Some((attr, idx_of_rpc_attr)) = get_method_attribute(&method.attrs) {
                let mut intermediate_trait_inputs = method.sig.inputs.clone();
                let working_set_arg = find_working_set_argument(&method.sig);
                let idx_of_working_set_arg = if let Some((idx, ty)) = working_set_arg {
                    // Remove the working set argument from the intermediate trait signature
                    let mut inputs: Vec<syn::FnArg> =
                        intermediate_trait_inputs.into_iter().collect();
                    inputs.remove(idx);
                    intermediate_trait_inputs = inputs.into_iter().collect();

                    // Store the type of the working set argument for later reference
                    rpc_info.working_set_type = Some(ty);
                    Some(idx)
                } else {
                    None
                };
                let docs = method
                    .attrs
                    .iter()
                    .filter(|attr| attr.path.is_ident("doc"))
                    .cloned()
                    .collect::<Vec<_>>();
                rpc_info.methods.push(RpcEnabledMethod {
                    method_name: method.sig.ident.clone(),
                    method_signature: method.sig.clone(),
                    docs: docs.clone(),
                    idx_of_working_set_arg,
                });

                // Remove the working set argument from the signature
                let mut intermediate_signature = method.sig.clone();
                intermediate_signature.inputs = intermediate_trait_inputs;

                // Build the annotated signature for the intermediate trait
                let annotated_signature = quote! {
                    #( #docs )*
                    #attr
                    #intermediate_signature;
                };
                intermediate_trait_items.push(annotated_signature);

                let mut original_method = method.clone();
                original_method.attrs.remove(idx_of_rpc_attr);
                simplified_impl_items.push(ImplItem::Method(original_method));
                continue;
            }
        }
        simplified_impl_items.push(item)
    }

    let impl_rpc_trait_impl = rpc_info.build_rpc_impl_trait();

    // Replace the original impl block with a new version with the rpc_gen and related annotations removed
    input.items = simplified_impl_items;
    let simplified_impl = quote! {
        #input
    };

    let doc_string = format!("Generated RPC trait for {}", type_name);
    let (_, ty_generics, where_clause) = generics.split_for_impl();

    let rpc_output = quote! {
        #simplified_impl

        #impl_rpc_trait_impl


        #rpc_attribute
        #[doc = #doc_string]
        pub trait #intermediate_trait_name  #generics #where_clause {

            #(#intermediate_trait_items)*

            /// Check the health of the RPC server
            #[method(name = "health")]
            fn health(&self) -> ::jsonrpsee::core::RpcResult<()> {
                Ok(())
            }

            /// Get the address of this module
            #[method(name = "moduleAddress")]
            fn module_address(&self) -> ::jsonrpsee::core::RpcResult<String> {
                Ok(<#type_name #ty_generics as ::sov_modules_api::ModuleInfo>::address(&<#type_name #ty_generics as ::core::default::Default>::default()).to_string())
            }

        }
    };
    Ok(rpc_output)
}

pub(crate) fn rpc_gen(
    attrs: Vec<syn::NestedMeta>,
    input: syn::ItemImpl,
) -> Result<proc_macro2::TokenStream, syn::Error> {
    let type_name = match *input.self_ty {
        syn::Type::Path(ref type_path) => &type_path.path.segments.last().unwrap().ident,
        _ => return Err(syn::Error::new_spanned(input.self_ty, "Invalid type")),
    };

    build_rpc_trait(attrs, type_name.clone(), input)
}

/// A parenthesized, comma-separated list of types, e.g. `(u32, Vec<u8>,)`; a trailing
/// comma is accepted.
struct TypeList(pub Punctuated<syn::Type, syn::token::Comma>);

impl Parse for TypeList {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let inner;
        parenthesized!(inner in input);
        inner.parse_terminated(syn::Type::parse).map(TypeList)
    }
}