//! boluo-macros 0.1.1
//!
//! Macros for boluo (boluo的宏).
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::parse::{Parse, ParseStream};
use syn::{Attribute, Error, Ident, ItemFn, LitStr, Token, Visibility};

pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr = syn::parse_macro_input!(attr as RouteAttr);

    let item_fn = match syn::parse::<ItemFn>(item.clone()) {
        Ok(item_fn) => item_fn,
        Err(e) => return input_and_compile_error(item, e),
    };

    match Route::new(attr, item_fn) {
        Ok(route) => route.into_token_stream().into(),
        Err(e) => input_and_compile_error(item, e),
    }
}

/// The route path: the leading string-literal argument of `#[route("<path>", ...)]`,
/// stored as its unescaped value.
struct PathAttr(String);

impl Parse for PathAttr {
    /// Reads exactly one string literal; any other token is a parse error.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let literal: LitStr = input.parse()?;
        Ok(PathAttr(literal.value()))
    }
}

impl ToTokens for PathAttr {
    /// Re-emits the path as a fresh string literal spanned at the macro call site.
    fn to_tokens(&self, stream: &mut TokenStream2) {
        let token = LitStr::new(&self.0, Span::call_site()).token();
        stream.append(token);
    }
}

/// One HTTP method name from the `method = ...` argument of `#[route]`, kept exactly as
/// written (it is validated by `TryFrom<LitStr>`).
#[derive(PartialEq, Eq, Hash)]
struct MethodAttr(String);

impl MethodAttr {
    /// Parses a bracketed, comma-separated list of method literals, e.g. `["GET", "POST"]`.
    /// A trailing comma is accepted.
    fn parse_array(input: ParseStream) -> syn::Result<Vec<Self>> {
        let content;
        let _bracket_token = syn::bracketed!(content in input);

        let methods = content.parse_terminated(MethodAttr::parse_str, Token![,])?;
        Ok(methods.into_iter().collect())
    }

    /// Parses a single string literal and validates it as a method.
    fn parse_str(input: ParseStream) -> syn::Result<Self> {
        input.parse::<LitStr>().and_then(MethodAttr::try_from)
    }

    /// Parses `method = "GET"` or `method = ["GET", "POST"]`.
    ///
    /// # Errors
    /// Fails if the identifier is not `method` ("illegal attribute"), if `=` is missing, or
    /// if the value is not a valid literal / bracketed list of literals.
    fn parse(input: ParseStream) -> syn::Result<Vec<Self>> {
        let name = input.parse::<Ident>()?;

        if name != "method" {
            return Err(Error::new_spanned(
                &name,
                format!("illegal attribute `{}`", name),
            ));
        }

        input.parse::<Token![=]>()?;

        // Decide by the next token instead of speculatively parsing a fork. This parses the
        // list only once, and an error inside a malformed list (e.g. an invalid method) is
        // reported as-is. The fork approach would instead fall back to `parse_str` and
        // report a misleading "expected string literal" at the `[`.
        if input.peek(syn::token::Bracket) {
            MethodAttr::parse_array(input)
        } else {
            MethodAttr::parse_str(input).map(|method| vec![method])
        }
    }
}

impl ToTokens for MethodAttr {
    /// Emits the method name as a string literal spanned at the macro call site, so it can be
    /// passed straight to `Method::try_from` in generated code.
    fn to_tokens(&self, stream: &mut TokenStream2) {
        let literal = LitStr::new(&self.0, Span::call_site());
        literal.to_tokens(stream);
    }
}

impl TryFrom<LitStr> for MethodAttr {
    type Error = Error;

    /// Validates that the literal is a well-formed HTTP method.
    ///
    /// RFC 9110 §9.1 defines `method = token` with `token = 1*tchar`. Checking this here
    /// rejects bad methods at compile time, pointing at the literal. Otherwise they would
    /// reach the generated `Method::try_from(..).unwrap()` and panic when the route is built.
    ///
    /// # Errors
    /// Returns "invalid HTTP method", spanned on the literal, when the string is empty or
    /// contains a character that is not a `tchar`.
    fn try_from(value: LitStr) -> Result<Self, Self::Error> {
        let method = value.value();

        // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
        //         "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
        let is_tchar = |b: u8| b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b);

        if method.is_empty() || !method.bytes().all(is_tchar) {
            Err(Error::new_spanned(value, "invalid HTTP method"))
        } else {
            Ok(Self(method))
        }
    }
}

/// Parsed arguments of `#[route("<path>", method = ...)]`.
struct RouteAttr {
    /// The route path (first, mandatory argument).
    path: PathAttr,
    /// Methods from the optional `method = ...` argument; empty when it is omitted.
    methods: Vec<MethodAttr>,
}

impl Parse for RouteAttr {
    /// Parses `"<path>"`, optionally followed by `,` and a `method = ...` argument.
    ///
    /// # Errors
    /// A missing or non-string path yields an "invalid route definition" error. It points at
    /// the offending token (or at the call site when the attribute is empty). Errors from the
    /// `method` argument are propagated unchanged.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let path = input.parse::<PathAttr>().map_err(|e| {
            Error::new(
                e.span(),
                r#"invalid route definition, expected #[route("<path>")]"#,
            )
        })?;

        // The separating comma is consumed only if present. `#[route("/" method = "GET")]`
        // is therefore accepted as well, and a bare trailing comma is tolerated.
        if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
        }

        let methods = if input.is_empty() {
            Vec::new()
        } else {
            MethodAttr::parse(input)?
        };

        Ok(Self { path, methods })
    }
}

/// A `#[route]`-annotated handler together with the pieces needed to generate its
/// service type.
struct Route {
    /// The user's handler function, re-emitted inside the generated `call` body.
    item_fn: ItemFn,
    /// The handler's visibility, reused for the generated struct.
    vis: Visibility,
    /// The handler's name, reused as the generated struct's name.
    name: Ident,
    /// Parsed `#[route(...)]` arguments.
    attr: RouteAttr,
    /// The handler's `#[doc]` attributes, copied onto the generated struct.
    docs: Vec<Attribute>,
}

impl Route {
    /// Captures the handler's visibility, name and doc attributes alongside the parsed
    /// route arguments. Currently always returns `Ok`.
    fn new(attr: RouteAttr, item_fn: ItemFn) -> syn::Result<Self> {
        let mut docs = Vec::new();
        for attribute in &item_fn.attrs {
            if attribute.path().is_ident("doc") {
                docs.push(attribute.clone());
            }
        }

        Ok(Self {
            vis: item_fn.vis.clone(),
            name: item_fn.sig.ident.clone(),
            attr,
            docs,
            item_fn,
        })
    }
}

impl ToTokens for Route {
    /// Expands the handler into three items, appended to `tokens` in this order:
    ///
    /// 1. `#vis struct #name;`, a unit struct carrying the handler's doc attributes;
    /// 2. `impl Service<Request> for #name`. Its `call` declares the original function
    ///    locally, wraps it with `boluo::handler::handler_fn` and forwards the request;
    /// 3. `impl Into<boluo::route::Route<#name>> for #name`. It builds a route for the
    ///    configured path from `boluo::route::any(#name)`, plus one `.add(..)` per method.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        let name = &self.name;
        let vis = &self.vis;
        let docs = &self.docs;
        let item_fn = &self.item_fn;
        let path = &self.attr.path;
        let methods = self.attr.methods.iter();

        // 1. Unit struct that takes over the handler's (snake_case) name.
        tokens.extend(quote! {
            #(#docs)*
            #[allow(non_camel_case_types)]
            #[derive(Clone, Copy)]
            #vis struct #name;
        });

        // 2. Service impl. The user's function is declared as a local item inside `call`, so
        //    `#name` there resolves to the function, not to the struct.
        tokens.extend(quote! {
            impl ::boluo::service::Service<::boluo::request::Request> for #name {
                type Response = ::boluo::response::Response;
                type Error = ::boluo::BoxError;

                async fn call(
                    &self,
                    req: ::boluo::request::Request,
                ) -> ::std::result::Result<Self::Response, Self::Error> {
                    #item_fn

                    fn assert_service<S>(
                        service: S,
                    ) -> impl ::boluo::service::Service<
                        ::boluo::request::Request,
                        Response = ::boluo::response::Response,
                        Error = ::boluo::BoxError,
                    >
                    where
                        S: ::boluo::service::Service<
                            ::boluo::request::Request,
                            Response = ::boluo::response::Response,
                            Error = ::boluo::BoxError,
                        >,
                    {
                        service
                    }

                    let service = ::boluo::handler::handler_fn(#name);
                    let service = assert_service(service);

                    ::boluo::service::Service::call(&service, req).await
                }
            }
        });

        // 3. Route conversion. Each method string is turned into an `http::Method` with
        //    `try_from(..).unwrap()` when the generated code runs.
        tokens.extend(quote! {
            impl ::std::convert::Into<::boluo::route::Route<#name>> for #name {
                fn into(self) -> ::boluo::route::Route<#name> {
                    let method_route = ::boluo::route::any(#name)
                        #(.add(::boluo::http::Method::try_from(#methods).unwrap()))*;
                    ::boluo::route::Route::new(#path, method_route)
                }
            }
        });
    }
}

fn input_and_compile_error(mut item: TokenStream, err: Error) -> TokenStream {
    let compile_err = TokenStream::from(err.to_compile_error());
    item.extend(compile_err);
    item
}