1use proc_macro::TokenStream;
2use syn::{fold::Fold, parse::Parse, parse::ParseStream, punctuated::Punctuated};
3
/// Panic message used when a `#[method(...)]` attribute on an impl fn is
/// malformed: either it is not a parenthesized list, or its argument does not
/// parse as a path.
const ERR_M_MESSAGE: &str = "invalid method definition, expected: #[method(name)]";
5
/// Attribute macro placed on an inherent `impl` block to generate an
/// `ntex_grpc::Service<ntex_grpc::server::ServerRequest>` implementation for
/// the impl's self type.
///
/// The attribute argument is the path of the service definition, e.g.
/// `#[server(proto::Greeter)]`. Methods inside the impl that handle RPCs are
/// tagged with `#[method(name)]`. Each tagged method is dispatched when
/// `<service>::method_by_name(&req.name)` yields
/// `Some(<service module>::<ServiceName>Methods::name(..))`.
#[proc_macro_attribute]
pub fn server(attr: TokenStream, item: TokenStream) -> TokenStream {
    server_impl(attr, item)
}
10
11fn server_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
12 let mut srv = syn::parse_macro_input!(attr as GrpcService);
13 let input: syn::ItemImpl = syn::parse2(item.into()).unwrap();
14
15 match input.self_ty.as_ref() {
16 syn::Type::Path(tp) => {
17 srv.self_ty = tp.path.clone();
18 if let Some(s) = tp.path.segments.last() {
19 srv.name = format!("{}", s.ident);
20 } else {
21 panic!("struct name is required");
22 }
23 }
24 _ => panic!("struct impl block is supported only"),
25 }
26
27 let input = srv.fold_item_impl(input);
28
29 let ty = srv.self_ty;
30 let srvpath = srv.service;
31 let srvname = srv.service_name;
32 let srvmod = srv.service_mod;
33 let modname = quote::format_ident!("_priv_{}", srv.name);
34 let methods_prefix = quote::format_ident!("{}Methods", srvname);
35 let mut methods_path = srvmod;
36 methods_path.segments.push(methods_prefix.into());
37
38 let mut methods = Vec::new();
39 for (m_name, fn_name, span) in srv.methods {
40 methods.push(quote::quote_spanned! {span=>
41 Some(#methods_path::#m_name(method)) => {
42 use ::ntex_grpc::MethodDef;
43 let req = ::ntex_grpc::server::Request {
44 message: method.decode(&mut req.payload)?,
45 name: req.name,
46 headers: req.headers
47 };
48
49 let result = #ty::#fn_name(self, ::ntex_grpc::server::FromRequest::from(req)).await;
50
51 let res = method.server_result(result);
52 let response = ::ntex_grpc::server::Response::from(res);
53 let mut buf = ::ntex_grpc::BytesMut::new();
54 method.encode(response.message, &mut buf);
55
56 Ok(::ntex_grpc::server::ServerResponse::with_headers(buf.freeze(), response.headers))
57 }
58 });
59 }
60
61 let service = quote::quote! {
62 mod #modname {
63 use super::*;
64
65 impl ::ntex_grpc::Service<::ntex_grpc::server::ServerRequest> for #ty {
66 type Response = ::ntex_grpc::server::ServerResponse;
67 type Error = ::ntex_grpc::server::ServerError;
68
69 async fn call(&self, mut req: ::ntex_grpc::server::ServerRequest, _: ::ntex_grpc::ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
70 use ::ntex_grpc::{ServiceDef, MethodDef};
71
72 match #srvpath::method_by_name(&req.name) {
73 #(#methods)*
74 Some(_) => Err(::ntex_grpc::server::ServerError::new(
75 ::ntex_grpc::GrpcStatus::Unimplemented,
76 ::ntex_grpc::HeaderValue::from_shared(
77 ::ntex_grpc::ByteString::from(format!("Service method is not implemented: {0}", req.name)).into_bytes()
78 ).unwrap(),
79 None
80 )),
81 None => Err(::ntex_grpc::server::ServerError::new(
82 ::ntex_grpc::GrpcStatus::NotFound,
83 ::ntex_grpc::HeaderValue::from_shared(
84 ::ntex_grpc::ByteString::from(format!("Service method is not found: {0}", req.name)).into_bytes()
85 ).unwrap(),
86 None
87 ))
88 }
89 }
90 }
91 }
92 };
93
94 let tokens = quote::quote! {
95 #input
96 #service
97 };
98 tokens.into()
99}
100
/// State for one `#[server(...)]` expansion: the parsed attribute arguments
/// plus data collected from the annotated impl block.
#[derive(Debug)]
struct GrpcService {
    /// Last path segment of the impl's self type, as a string. Used to name
    /// the generated private module (`_priv_<name>`). Empty until the impl
    /// block is inspected.
    name: String,
    /// Path of the impl's self type, i.e. the type that receives the generated
    /// `Service` impl. Seeded with the service path during parsing and
    /// overwritten once the impl block is inspected.
    self_ty: syn::Path,
    /// Service definition path given as the attribute argument, e.g.
    /// `proto::Greeter`. `method_by_name` is called on it in generated code.
    service: syn::Path,
    /// `service` with its last segment popped. It is the prefix under which
    /// `<service_name>Methods` is referenced.
    service_mod: syn::Path,
    /// Last segment of `service`. It is suffixed with `Methods` to name the
    /// type matched in the generated dispatcher.
    service_name: syn::Ident,
    /// One entry per `#[method(name)]`-tagged fn, as a tuple of:
    /// - `name` from the attribute,
    /// - the Rust fn ident,
    /// - the span of the fn's `fn` keyword, applied to that method's
    ///   generated match arm.
    methods: Vec<(syn::Ident, syn::Ident, proc_macro2::Span)>,
}
110
111impl Parse for GrpcService {
112 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
113 let parsed: Punctuated<syn::Path, syn::Token![,]> = Punctuated::parse_terminated(input)?;
114 let path = parsed.first().unwrap().clone();
115 let service = parsed.first().unwrap().clone();
116 let mut service_mod = service.clone();
117 service_mod.segments.pop();
118 let service_name = path.segments.last().unwrap().ident.clone();
119 Ok(GrpcService {
120 service,
121 service_mod,
122 service_name,
123 methods: Vec::new(),
124 name: String::new(),
125 self_ty: path,
126 })
127 }
128}
129
130impl Fold for GrpcService {
131 fn fold_impl_item_fn(&mut self, mut m: syn::ImplItemFn) -> syn::ImplItemFn {
132 for idx in 0..m.attrs.len() {
133 let attr = &m.attrs[idx];
134 if attr.path().is_ident("method") {
135 let lst = if let syn::Meta::List(ref lst) = attr.meta {
136 lst
137 } else {
138 panic!("{}", ERR_M_MESSAGE)
139 };
140
141 let name: syn::Path = lst.parse_args().expect(ERR_M_MESSAGE);
142 let m_name = if let Some(ident) = name.get_ident() {
143 ident.clone()
144 } else {
145 panic!("only simple identifiers are supported: {:?}", name);
146 };
147
148 let _ = m.attrs.remove(idx);
149 self.methods
150 .push((m_name, m.sig.ident.clone(), m.sig.fn_token.span));
151 break;
152 }
153 }
154
155 m
156 }
157}