Skip to main content

ntex_grpc_derive/
lib.rs

1use proc_macro::TokenStream;
2use syn::{fold::Fold, parse::Parse, parse::ParseStream, punctuated::Punctuated};
3
/// Diagnostic text used when a `#[method(...)]` attribute is malformed.
const ERR_M_MESSAGE: &str = "invalid method definition, expected: #[method(name)]";
5
6#[proc_macro_attribute]
7pub fn server(attr: TokenStream, item: TokenStream) -> TokenStream {
8    server_impl(attr, item)
9}
10
11fn server_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
12    let mut srv = syn::parse_macro_input!(attr as GrpcService);
13    let input: syn::ItemImpl = syn::parse2(item.into()).unwrap();
14
15    match input.self_ty.as_ref() {
16        syn::Type::Path(tp) => {
17            srv.self_ty = tp.path.clone();
18            if let Some(s) = tp.path.segments.last() {
19                srv.name = format!("{}", s.ident);
20            } else {
21                panic!("struct name is required");
22            }
23        }
24        _ => panic!("struct impl block is supported only"),
25    }
26
27    let input = srv.fold_item_impl(input);
28
29    let ty = srv.self_ty;
30    let srvpath = srv.service;
31    let srvname = srv.service_name;
32    let srvmod = srv.service_mod;
33    let modname = quote::format_ident!("_priv_{}", srv.name);
34    let methods_prefix = quote::format_ident!("{}Methods", srvname);
35    let mut methods_path = srvmod;
36    methods_path.segments.push(methods_prefix.into());
37
38    let mut methods = Vec::new();
39    for (m_name, fn_name, span) in srv.methods {
40        methods.push(quote::quote_spanned! {span=>
41            Some(#methods_path::#m_name(method)) => {
42                use ::ntex_grpc::MethodDef;
43                let req = ::ntex_grpc::server::Request {
44                    message: method.decode(&mut req.payload)?,
45                    name: req.name,
46                    headers: req.headers
47                };
48
49                let result = #ty::#fn_name(self, ::ntex_grpc::server::FromRequest::from(req)).await;
50
51                let res = method.server_result(result);
52                let response = ::ntex_grpc::server::Response::from(res);
53                let mut buf = ::ntex_grpc::BytesMut::new();
54                method.encode(response.message, &mut buf);
55
56                Ok(::ntex_grpc::server::ServerResponse::with_headers(buf.freeze(), response.headers))
57            }
58        });
59    }
60
61    let service = quote::quote! {
62        mod #modname {
63            use super::*;
64
65            impl ::ntex_grpc::Service<::ntex_grpc::server::ServerRequest> for #ty {
66                type Response = ::ntex_grpc::server::ServerResponse;
67                type Error = ::ntex_grpc::server::ServerError;
68
69                async fn call(&self, mut req: ::ntex_grpc::server::ServerRequest, _: ::ntex_grpc::ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
70                    use ::ntex_grpc::{ServiceDef, MethodDef};
71
72                    match #srvpath::method_by_name(&req.name) {
73                        #(#methods)*
74                        Some(_) => Err(::ntex_grpc::server::ServerError::new(
75                            ::ntex_grpc::GrpcStatus::Unimplemented,
76                            ::ntex_grpc::HeaderValue::from_shared(
77                                ::ntex_grpc::ByteString::from(format!("Service method is not implemented: {0}", req.name)).into_bytes()
78                            ).unwrap(),
79                            None
80                        )),
81                        None => Err(::ntex_grpc::server::ServerError::new(
82                            ::ntex_grpc::GrpcStatus::NotFound,
83                            ::ntex_grpc::HeaderValue::from_shared(
84                                ::ntex_grpc::ByteString::from(format!("Service method is not found: {0}", req.name)).into_bytes()
85                            ).unwrap(),
86                            None
87                        ))
88                    }
89                }
90            }
91        }
92    };
93
94    let tokens = quote::quote! {
95        #input
96        #service
97    };
98    tokens.into()
99}
100
101#[derive(Debug)]
102struct GrpcService {
103    name: String,
104    self_ty: syn::Path,
105    service: syn::Path,
106    service_mod: syn::Path,
107    service_name: syn::Ident,
108    methods: Vec<(syn::Ident, syn::Ident, proc_macro2::Span)>,
109}
110
111impl Parse for GrpcService {
112    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
113        let parsed: Punctuated<syn::Path, syn::Token![,]> = Punctuated::parse_terminated(input)?;
114        let path = parsed.first().unwrap().clone();
115        let service = parsed.first().unwrap().clone();
116        let mut service_mod = service.clone();
117        service_mod.segments.pop();
118        let service_name = path.segments.last().unwrap().ident.clone();
119        Ok(GrpcService {
120            service,
121            service_mod,
122            service_name,
123            methods: Vec::new(),
124            name: String::new(),
125            self_ty: path,
126        })
127    }
128}
129
130impl Fold for GrpcService {
131    fn fold_impl_item_fn(&mut self, mut m: syn::ImplItemFn) -> syn::ImplItemFn {
132        for idx in 0..m.attrs.len() {
133            let attr = &m.attrs[idx];
134            if attr.path().is_ident("method") {
135                let lst = if let syn::Meta::List(ref lst) = attr.meta {
136                    lst
137                } else {
138                    panic!("{}", ERR_M_MESSAGE)
139                };
140
141                let name: syn::Path = lst.parse_args().expect(ERR_M_MESSAGE);
142                let m_name = if let Some(ident) = name.get_ident() {
143                    ident.clone()
144                } else {
145                    panic!("only simple identifiers are supported: {:?}", name);
146                };
147
148                let _ = m.attrs.remove(idx);
149                self.methods
150                    .push((m_name, m.sig.ident.clone(), m.sig.fn_token.span));
151                break;
152            }
153        }
154
155        m
156    }
157}