prpc_build/
server.rs

1use super::{Method, Service};
2use crate::{generate_doc_comment, generate_doc_comments, naive_snake_case, Builder};
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{Ident, Lit, LitStr};
6
7/// Generate service for Server.
8///
9/// This takes some `Service` and will generate a `TokenStream` that contains
10/// a public module containing the server service and handler trait.
11pub fn generate<T: Service>(service: &T, config: &Builder) -> TokenStream {
12    let attributes = &config.server_attributes;
13    let methods = generate_methods(service, config, false);
14    let json_methods = generate_methods(service, config, true);
15
16    let server_service = quote::format_ident!("{}Server", service.name());
17    let server_trait = quote::format_ident!("{}Rpc", service.name());
18    let server_mod = quote::format_ident!("{}_server", naive_snake_case(service.name()));
19    let service_name = Lit::Str(LitStr::new(service.name(), Span::call_site()));
20    let supported_methods = generate_supported_methods(service, config);
21    let method_enum = generate_methods_enum(service, config);
22    let generated_trait = generate_trait(service, config, server_trait.clone());
23    let service_doc = generate_doc_comments(service.comment());
24    let mod_attributes = attributes.for_mod(service.package());
25    let struct_attributes = attributes.for_struct(service.identifier());
26
27    quote! {
28        /// Generated server implementations.
29        #(#mod_attributes)*
30        pub mod #server_mod {
31            use alloc::vec::Vec;
32
33            #method_enum
34
35            #generated_trait
36
37            #service_doc
38            #(#struct_attributes)*
39            #[derive(Debug)]
40            pub struct #server_service<T: #server_trait> {
41                inner: T,
42            }
43
44            impl<T: #server_trait> #server_service<T> {
45                pub fn new(inner: T) -> Self {
46                    Self {
47                        inner,
48                    }
49                }
50
51                pub async fn dispatch_request(self, path: &str, _data: impl AsRef<[u8]>) -> Result<Vec<u8>, ::prpc::server::Error> {
52                    #![allow(clippy::let_unit_value)]
53                    match path {
54                        #methods
55                        _ => anyhow::bail!("Service not found: {path}"),
56                    }
57                }
58
59                pub async fn dispatch_json_request(self, path: &str, _data: impl AsRef<[u8]>, _query: bool) -> Result<Vec<u8>, ::prpc::server::Error> {
60                    #![allow(clippy::let_unit_value)]
61                    match path {
62                        #json_methods
63                        _ => anyhow::bail!("Service not found: {path}"),
64                    }
65                }
66                #supported_methods
67            }
68
69            impl<T: #server_trait> ::prpc::server::NamedService for #server_service<T> {
70                const NAME: &'static str = #service_name;
71            }
72            impl<T: #server_trait> ::prpc::server::Service for #server_service<T> {
73                type Methods = &'static [&'static str];
74                fn methods() -> Self::Methods {
75                    Self::supported_methods()
76                }
77                async fn dispatch_request(self, path: &str, data: impl AsRef<[u8]>, json: bool, query: bool) -> Result<Vec<u8>, ::prpc::server::Error> {
78                    if json {
79                        self.dispatch_json_request(path, data, query).await
80                    } else {
81                        self.dispatch_request(path, data).await
82                    }
83                }
84            }
85            impl<T: #server_trait> From<T> for #server_service<T> {
86                fn from(inner: T) -> Self {
87                    Self::new(inner)
88                }
89            }
90        }
91    }
92}
93
94fn generate_trait<T: Service>(service: &T, config: &Builder, server_trait: Ident) -> TokenStream {
95    let methods =
96        generate_trait_methods(service, &config.proto_path, config.compile_well_known_types);
97    let trait_doc = generate_doc_comment(format!(
98        "Generated trait containing RPC methods that should be implemented for use with {}Server.",
99        service.name()
100    ));
101
102    quote! {
103        #trait_doc
104        pub trait #server_trait {
105            #methods
106        }
107    }
108}
109
110fn generate_trait_methods<T: Service>(
111    service: &T,
112    proto_path: &str,
113    compile_well_known_types: bool,
114) -> TokenStream {
115    let mut stream = TokenStream::new();
116
117    for method in service.methods() {
118        let name = quote::format_ident!("{}", method.name());
119
120        let (req_message, res_message) =
121            method.request_response_name(proto_path, compile_well_known_types);
122
123        let method_doc = generate_doc_comments(method.comment());
124
125        let method = match (method.client_streaming(), method.server_streaming()) {
126            (false, false) => {
127                template_quote::quote! {
128                    #method_doc
129                    async fn #name(self
130                        #(if req_message.is_some()) {
131                            , request: #req_message
132                        }
133                    ) -> ::anyhow::Result<#res_message>;
134                }
135            }
136            _ => {
137                panic!("Streaming RPC not supported");
138            }
139        };
140
141        stream.extend(method);
142    }
143
144    stream
145}
146
147fn generate_supported_methods<T: Service>(service: &T, config: &Builder) -> TokenStream {
148    let mut all_methods = TokenStream::new();
149    for method in service.methods() {
150        let path = crate::join_path(
151            config,
152            service.package(),
153            service.identifier(),
154            method.identifier(),
155        );
156
157        let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
158        all_methods.extend(quote! {
159            #method_path,
160        });
161    }
162
163    quote! {
164        pub fn supported_methods()
165            -> &'static [&'static str] {
166                &[
167                    #all_methods
168                ]
169            }
170    }
171}
172
173fn generate_methods_enum<T: Service>(service: &T, config: &Builder) -> TokenStream {
174    let mut paths = vec![];
175    let mut variants = vec![];
176    for method in service.methods() {
177        let path = crate::join_path(
178            config,
179            service.package(),
180            service.identifier(),
181            method.identifier(),
182        );
183
184        let variant = Ident::new(method.identifier(), Span::call_site());
185        variants.push(variant);
186
187        let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
188        paths.push(method_path);
189    }
190
191    let enum_name = Ident::new(
192        &format!("{}Method", service.identifier()),
193        Span::call_site(),
194    );
195    quote! {
196        pub enum #enum_name {
197            #(#variants,)*
198        }
199
200        impl #enum_name {
201            #[allow(clippy::should_implement_trait)]
202            pub fn from_str(path: &str) -> Option<Self> {
203                match path {
204                    #(#paths => Some(Self::#variants),)*
205                    _ => None,
206                }
207            }
208        }
209    }
210}
211
212fn generate_methods<T: Service>(service: &T, config: &Builder, json: bool) -> TokenStream {
213    let mut stream = TokenStream::new();
214
215    for method in service.methods() {
216        let path = crate::join_path(
217            config,
218            service.package(),
219            service.identifier(),
220            method.identifier(),
221        );
222        let method_path = Lit::Str(LitStr::new(&path, Span::call_site()));
223        let method_ident = quote::format_ident!("{}", method.name());
224
225        let method_stream = match (method.client_streaming(), method.server_streaming()) {
226            (false, false) => generate_unary(method, config, method_ident, json),
227            _ => {
228                panic!("Streaming RPC not supported");
229            }
230        };
231
232        let method = quote! {
233            #method_path => {
234                #method_stream
235            }
236        };
237        stream.extend(method);
238    }
239
240    stream
241}
242
243fn generate_unary<T: Method>(
244    method: &T,
245    config: &Builder,
246    method_ident: Ident,
247    json: bool,
248) -> TokenStream {
249    let (request, _response) =
250        method.request_response_name(&config.proto_path, config.compile_well_known_types);
251
252    if json {
253        template_quote::quote! {
254            #(if request.is_none()) {
255                let response = self.inner.#method_ident().await?;
256            }
257            #(else) {
258                let data = _data.as_ref();
259                let input: #request = if data.is_empty() {
260                    Default::default()
261                } else if _query {
262                    ::prpc::serde_qs::from_bytes(data)?
263                } else {
264                    ::prpc::serde_json::from_slice(data)?
265                };
266                let response = self.inner.#method_ident(input).await?;
267            }
268            Ok(serde_json::to_vec(&response)?)
269        }
270    } else {
271        template_quote::quote! {
272            #(if request.is_none()) {
273                let response = self.inner.#method_ident().await?;
274            }
275            #(else) {
276                let input: #request = ::prpc::Message::decode(_data.as_ref())?;
277                let response = self.inner.#method_ident(input).await?;
278            }
279            Ok(::prpc::codec::encode_message_to_vec(&response))
280        }
281    }
282}