salvo-oapi-macros 0.91.1

Macros for salvo-oapi
Documentation
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, quote};
use syn::{Expr, Ident, ImplItem, Item, Pat, ReturnType, Signature, Type};

use crate::doc_comment::CommentAttributes;
use crate::{Array, DiagResult, InputType, Operation, omit_type_path_lifetimes, parse_input_type};

mod attr;
pub(crate) use attr::EndpointAttr;

/// Generates the OpenAPI registration glue for the endpoint type `name`.
///
/// The emitted items are:
/// - `__macro_gen_oapi_endpoint_type_id_<name>`, returning `TypeId::of::<name>()`;
/// - `__macro_gen_oapi_endpoint_creator_<name>`, which builds an `Endpoint` as follows:
///   - it runs `modifiers` followed by the modifiers derived from `attr`;
///   - it fills in an automatically assigned `operation_id` when none was set;
///   - if `attr.status_codes` is non-empty, it drops every response whose key parses as a
///     status code outside that list (keys that do not parse are kept);
/// - an `inventory::submit!` registering both functions with `EndpointRegistry`.
///
/// Each modifier is a statement that may use the `components` and `operation` bindings, which
/// are `&mut` references inside the generated creator function.
fn metadata(
    salvo: &Ident,
    oapi: &Ident,
    attr: &EndpointAttr,
    name: &Ident,
    mut modifiers: Vec<TokenStream>,
) -> DiagResult<TokenStream> {
    let tfn = Ident::new(
        &format!("__macro_gen_oapi_endpoint_type_id_{name}"),
        Span::call_site(),
    );
    let cfn = Ident::new(
        &format!("__macro_gen_oapi_endpoint_creator_{name}"),
        Span::call_site(),
    );
    let opt = Operation::new(attr);
    modifiers.append(opt.modifiers()?.as_mut());
    let status_codes = Array::from_iter(attr.status_codes.iter().map(|expr| match expr {
        // A literal is converted at runtime with `StatusCode::from_u16(..).unwrap()`; any other
        // expression is emitted as-is and must already evaluate to a `StatusCode`.
        Expr::Lit(lit) => quote! { #salvo::http::StatusCode::from_u16(#lit).unwrap() },
        _ => quote! { #expr },
    }));
    // Only emit the modifier scope when there is something to run in it.
    let modifiers = (!modifiers.is_empty()).then(|| {
        quote! {{
            let mut components = &mut components;
            let mut operation = &mut operation;
            #(#modifiers)*
        }}
    });
    // All `std` items are referenced through absolute `::std` paths so the generated code still
    // resolves when the caller has its own item named `std` in scope.
    let stream = quote! {
        fn #tfn() -> ::std::any::TypeId {
            ::std::any::TypeId::of::<#name>()
        }
        fn #cfn() -> #oapi::oapi::Endpoint {
            let mut components = #oapi::oapi::Components::new();
            let status_codes: &[#salvo::http::StatusCode] = &#status_codes;
            let mut operation = #oapi::oapi::Operation::new();
            #modifiers
            if operation.operation_id.is_none() {
                operation.operation_id = Some(#oapi::oapi::naming::assign_name::<#name>(#oapi::oapi::naming::NameRule::Auto));
            }
            if !status_codes.is_empty() {
                let responses = ::std::ops::DerefMut::deref_mut(&mut operation.responses);
                responses.retain(|k,_| {
                    if let Ok(code) = <#salvo::http::StatusCode as ::std::str::FromStr>::from_str(k) {
                        status_codes.contains(&code)
                    } else {
                        true
                    }
                });
            }
            #oapi::oapi::Endpoint{
                operation,
                components,
            }
        }
        #oapi::oapi::__private::inventory::submit! {
            #oapi::oapi::EndpointRegistry::save(#tfn, #cfn)
        }
    };
    Ok(stream)
}
pub(crate) fn generate(mut attr: EndpointAttr, input: Item) -> syn::Result<TokenStream> {
    let salvo = crate::salvo_crate();
    let oapi = crate::oapi_crate();
    match input {
        Item::Fn(mut item_fn) => {
            let attrs = item_fn
                .attrs
                .iter()
                .filter(|attr| !attr.path().is_ident("endpoint"))
                .collect::<Vec<_>>();
            let vis = &item_fn.vis;
            let sig = &mut item_fn.sig;
            let body = &item_fn.block;
            let name = &sig.ident;
            let docs = item_fn
                .attrs
                .iter()
                .filter(|attr| attr.path().is_ident("doc"))
                .cloned()
                .collect::<Vec<_>>();

            let sdef = quote! {
                #(#docs)*
                #[allow(non_camel_case_types)]
                #[derive(Debug)]
                #vis struct #name;
                impl #name {
                    #(#attrs)*
                    #sig {
                        #body
                    }
                }
            };

            attr.doc_comments = Some(CommentAttributes::from_attributes(&docs).0);
            attr.deprecated = if attrs.iter().any(|attr| attr.path().is_ident("deprecated")) {
                Some(true)
            } else {
                None
            };

            let (hfn, modifiers) = handle_fn(&salvo, &oapi, sig)?;
            let meta = metadata(&salvo, &oapi, &attr, name, modifiers)?;
            Ok(quote! {
                #sdef
                #[#salvo::async_trait]
                impl #salvo::Handler for #name {
                    #hfn
                }
                #meta
            })
        }
        Item::Impl(item_impl) => {
            let attrs = &item_impl.attrs;

            attr.doc_comments = Some(CommentAttributes::from_attributes(attrs).0);
            attr.deprecated = if attrs.iter().any(|attr| attr.path().is_ident("deprecated")) {
                Some(true)
            } else {
                None
            };

            let mut hmtd = None;
            for item in &item_impl.items {
                if let ImplItem::Fn(method) = item
                    && method.sig.ident == Ident::new("handle", Span::call_site())
                {
                    hmtd = Some(method);
                }
            }
            let Some(hmtd) = hmtd else {
                return Err(syn::Error::new_spanned(
                    item_impl.impl_token,
                    "missing handle function",
                ));
            };
            let (hfn, modifiers) = handle_fn(&salvo, &oapi, &hmtd.sig)?;
            let ty = &item_impl.self_ty;
            let (impl_generics, _, where_clause) = &item_impl.generics.split_for_impl();
            let name = Ident::new(&ty.to_token_stream().to_string(), Span::call_site());
            let meta = metadata(&salvo, &oapi, &attr, &name, modifiers)?;

            Ok(quote! {
                #item_impl
                #[#salvo::async_trait]
                impl #impl_generics #salvo::Handler for #ty #where_clause {
                    #hfn
                }
                #meta
            })
        }
        _ => Err(syn::Error::new_spanned(
            input,
            "#[handler] must added to `impl` or `fn`",
        )),
    }
}

fn handle_fn(
    salvo: &Ident,
    oapi: &Ident,
    sig: &Signature,
) -> syn::Result<(TokenStream, Vec<TokenStream>)> {
    let name = &sig.ident;
    let mut extract_ts = Vec::with_capacity(sig.inputs.len());
    let mut call_args: Vec<Ident> = Vec::with_capacity(sig.inputs.len());
    let mut modifiers = Vec::new();
    for input in &sig.inputs {
        match parse_input_type(input) {
            InputType::Request(_pat) => {
                call_args.push(Ident::new("__macro_gen_req", Span::call_site()));
            }
            InputType::Depot(_pat) => {
                call_args.push(Ident::new("__macro_gen_depot", Span::call_site()));
            }
            InputType::Response(_pat) => {
                call_args.push(Ident::new("__macro_gen_res", Span::call_site()));
            }
            InputType::FlowCtrl(_pat) => {
                call_args.push(Ident::new("__macro_gen_ctrl", Span::call_site()));
            }
            InputType::Unknown => {
                return Err(syn::Error::new_spanned(
                    &sig.inputs,
                    "the inputs parameters must be Request, Depot, Response or FlowCtrl",
                ));
            }
            InputType::NoReference(pat) => {
                if let (Pat::Ident(ident), Type::Path(ty)) = (&*pat.pat, &*pat.ty) {
                    call_args.push(ident.ident.clone());
                    let ty = omit_type_path_lifetimes(ty);
                    let idv = pat.pat.to_token_stream().to_string();
                    // If id like `mut pdata`, then idv is `pdata`;
                    let idv = idv
                        .rsplit_once(' ')
                        .map(|(_, v)| v.to_owned())
                        .unwrap_or(idv);
                    let id = Ident::new(&idv, Span::call_site());
                    let idv = idv.trim_start_matches('_');
                    extract_ts.push(quote!{
                        let #id: #ty = match <#ty as #salvo::Extractible>::extract_with_arg(__macro_gen_req, __macro_gen_depot, #idv).await {
                            Ok(data) => {
                                data
                            },
                            Err(e) => {
                                e.write(__macro_gen_req, __macro_gen_depot, __macro_gen_res).await;
                                // If status code is not set or is not error, set it to 400.
                                let status_code = __macro_gen_res.status_code.unwrap_or_default();
                                if !status_code.is_client_error() && !status_code.is_server_error() {
                                    __macro_gen_res.status_code(#salvo::http::StatusCode::BAD_REQUEST);
                                }
                                return;
                            }
                        };
                    });
                    modifiers.push(quote! {
                         <#ty as #oapi::oapi::EndpointArgRegister>::register(components, operation, #idv);
                    });
                } else {
                    return Err(syn::Error::new_spanned(pat, "invalid param definition"));
                }
            }
            InputType::Receiver(_) => {
                call_args.push(Ident::new("self", Span::call_site()));
            }
        }
    }

    let hfn = match &sig.output {
        ReturnType::Default => {
            if sig.asyncness.is_none() {
                quote! {
                    async fn handle(&self, __macro_gen_req: &mut #salvo::Request, __macro_gen_depot: &mut #salvo::Depot, __macro_gen_res: &mut #salvo::Response, __macro_gen_ctrl: &mut #salvo::FlowCtrl) {
                        #(#extract_ts)*
                        Self::#name(#(#call_args),*)
                    }
                }
            } else {
                quote! {
                    async fn handle(&self, __macro_gen_req: &mut #salvo::Request, __macro_gen_depot: &mut #salvo::Depot, __macro_gen_res: &mut #salvo::Response, __macro_gen_ctrl: &mut #salvo::FlowCtrl) {
                        #(#extract_ts)*
                        Self::#name(#(#call_args),*).await
                    }
                }
            }
        }
        ReturnType::Type(_, ty) => {
            modifiers.push(quote! {
                <#ty as #oapi::oapi::EndpointOutRegister>::register(components, operation);
            });
            if sig.asyncness.is_none() {
                quote! {
                    async fn handle(&self, __macro_gen_req: &mut #salvo::Request, __macro_gen_depot: &mut #salvo::Depot, __macro_gen_res: &mut #salvo::Response, __macro_gen_ctrl: &mut #salvo::FlowCtrl) {
                        #(#extract_ts)*
                        #salvo::Writer::write(Self::#name(#(#call_args),*), __macro_gen_req, __macro_gen_depot, __macro_gen_res).await;
                    }
                }
            } else {
                quote! {
                    async fn handle(&self, __macro_gen_req: &mut #salvo::Request, __macro_gen_depot: &mut #salvo::Depot, __macro_gen_res: &mut #salvo::Response, __macro_gen_ctrl: &mut #salvo::FlowCtrl) {
                        #(#extract_ts)*
                        #salvo::Writer::write(Self::#name(#(#call_args),*).await, __macro_gen_req, __macro_gen_depot, __macro_gen_res).await;
                    }
                }
            }
        }
    };
    Ok((hfn, modifiers))
}

#[cfg(test)]
mod tests {
    use quote::quote;
    use syn::{Ident, Signature, parse_str};

    use super::handle_fn;

    /// Creates an identifier with a call-site span, the way the crate paths are passed in.
    fn ident_at_call_site(text: &str) -> Ident {
        Ident::new(text, proc_macro2::Span::call_site())
    }

    /// A sync function with one by-value argument and no return type produces one extraction
    /// and exactly one argument-registration modifier.
    #[test]
    fn test_handle_fn() {
        let salvo_path = ident_at_call_site("salvo");
        let oapi_path = ident_at_call_site("salvo_oapi");
        let signature: Signature = parse_str("fn hello(name: String)").unwrap();

        let (handler_tokens, arg_modifiers) =
            handle_fn(&salvo_path, &oapi_path, &signature).unwrap();

        let wanted_handler = quote! {
            async fn handle(&self, __macro_gen_req: &mut salvo::Request, __macro_gen_depot: &mut salvo::Depot, __macro_gen_res: &mut salvo::Response, __macro_gen_ctrl: &mut salvo::FlowCtrl) {
                let name: String = match <String as salvo::Extractible>::extract_with_arg(__macro_gen_req, __macro_gen_depot, "name").await {
                    Ok(data) => {
                        data
                    },
                    Err(e) => {
                        e.write(__macro_gen_req, __macro_gen_depot, __macro_gen_res).await;
                        // If status code is not set or is not error, set it to 400.
                        let status_code = __macro_gen_res.status_code.unwrap_or_default();
                        if !status_code.is_client_error() && !status_code.is_server_error() {
                            __macro_gen_res.status_code(salvo::http::StatusCode::BAD_REQUEST);
                        }
                        return;
                    }
                };
                Self::hello(name)
            }
        };
        assert_eq!(handler_tokens.to_string(), wanted_handler.to_string());

        // No return type, so only the argument registration is expected.
        assert_eq!(arg_modifiers.len(), 1);
        let wanted_modifier = quote! {
            <String as salvo_oapi::oapi::EndpointArgRegister>::register(components, operation, "name");
        };
        assert_eq!(arg_modifiers[0].to_string(), wanted_modifier.to_string());
    }
}