resuma-macros 0.3.3

Procedural macros for Resuma: view!, #[component], #[server], #[island], js!
Documentation
//! `#[server]` — exposes an async fn as a server action.
//!
//! Generates:
//!  * a wrapper that registers the action in the global registry
//!  * optional `&FlowRequest` injection as the last parameter

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::{parse2, FnArg, ItemFn, Pat, ReturnType, Type};

pub fn expand(_args: TokenStream, input: TokenStream) -> TokenStream {
    let func: ItemFn = match parse2(input) {
        Ok(f) => f,
        Err(e) => return e.to_compile_error(),
    };

    let name = func.sig.ident.clone();
    let name_str = name.to_string();
    let vis = &func.vis;
    let block = &func.block;
    let asyncness = &func.sig.asyncness;
    let output = &func.sig.output;

    if asyncness.is_none() {
        return syn::Error::new(Span::call_site(), "#[server] functions must be async")
            .to_compile_error();
    }

    let inputs = func.sig.inputs.clone();

    let mut arg_idents = Vec::new();
    let mut arg_types = Vec::new();
    for a in &inputs {
        if let FnArg::Typed(pt) = a {
            if let Pat::Ident(pi) = &*pt.pat {
                arg_idents.push(pi.ident.clone());
                arg_types.push((*pt.ty).clone());
            }
        }
    }

    let (call_idents, has_req) = split_request_arg(&arg_idents, &arg_types);

    let dispatcher_name = format_ident!("__resuma_action_dispatch_{}", name);
    let registry_ctor = format_ident!("__resuma_action_register_{}", name);
    let trampoline_name = format_ident!("__resuma_action_trampoline_{}", name);

    let json_extract = call_idents.iter().enumerate().map(|(i, id)| {
        quote! {
            let #id: _ = match args.get(#i).cloned() {
                Some(v) => match ::resuma::__private::serde_json::from_value(v) {
                    Ok(v) => v,
                    Err(e) => return Err(::resuma::__private::ResumaError::Other(format!("bad arg `{}`: {}", stringify!(#id), e))),
                },
                None => return Err(::resuma::__private::ResumaError::Other(format!("missing arg `{}`", stringify!(#id)))),
            };
        }
    });

    let call = match (call_idents.is_empty(), has_req) {
        (true, true) => quote!( #name( &req ) ),
        (true, false) => quote!( #name() ),
        (false, true) => quote!( #name( #(#call_idents),*, &req ) ),
        (false, false) => quote!( #name( #(#call_idents),* ) ),
    };

    let returns_result = return_type_is_result(output);

    let serialize_result = if returns_result {
        quote! {
            match #call.await {
                Ok(v) => ::resuma::__private::serde_json::to_value(&v)
                    .map_err(::resuma::__private::ResumaError::from),
                Err(e) => Err(e),
            }
        }
    } else {
        quote! {
            ::resuma::__private::serde_json::to_value(&#call.await)
                .map_err(::resuma::__private::ResumaError::from)
        }
    };

    quote! {
        #vis async fn #name ( #inputs ) #output #block

        #[doc(hidden)]
        pub async fn #dispatcher_name(
            args: ::std::vec::Vec<::resuma::__private::serde_json::Value>,
            req: ::resuma::FlowRequest,
        ) -> ::resuma::__private::Result<::resuma::__private::serde_json::Value> {
            #(#json_extract)*
            #serialize_result
        }

        #[doc(hidden)]
        fn #trampoline_name (
            args: ::std::vec::Vec<::resuma::__private::serde_json::Value>,
            req: ::resuma::FlowRequest,
        ) -> ::std::pin::Pin<::std::boxed::Box<
            dyn ::std::future::Future<
                Output = ::resuma::__private::Result<::resuma::__private::serde_json::Value>,
            > + ::std::marker::Send,
        >> {
            ::std::boxed::Box::pin(#dispatcher_name(args, req))
        }

        #[doc(hidden)]
        #[::resuma::__private::ctor::ctor]
        fn #registry_ctor() {
            ::resuma::__private::register_server_action(#name_str, #trampoline_name);
        }
    }
}

fn split_request_arg(idents: &[syn::Ident], types: &[Type]) -> (Vec<syn::Ident>, bool) {
    if idents.is_empty() {
        return (Vec::new(), false);
    }
    let last_ty = &types[types.len() - 1];
    if is_flow_request_ref(last_ty) {
        (idents[..idents.len() - 1].to_vec(), true)
    } else {
        (idents.to_vec(), false)
    }
}

fn is_flow_request_ref(ty: &Type) -> bool {
    let Type::Reference(r) = ty else {
        return false;
    };
    is_flow_request_type(&r.elem)
}

fn is_flow_request_type(ty: &Type) -> bool {
    match ty {
        Type::Path(p) => p
            .path
            .segments
            .last()
            .is_some_and(|s| s.ident == "FlowRequest"),
        _ => false,
    }
}

fn return_type_is_result(output: &ReturnType) -> bool {
    let ReturnType::Type(_, ty) = output else {
        return false;
    };
    let Type::Path(p) = &**ty else {
        return false;
    };
    p.path.segments.last().is_some_and(|s| s.ident == "Result")
}