dcl_rpc/
codegen.rs

1//! Generate service code from a service definition in a `.proto` file.
2
// Guidelines for generated code:
//
// Use fully-qualified paths to reduce the chance of clashing with
// user-provided names.
7
8use proc_macro2::TokenStream;
9use prost_build::{Method, Service, ServiceGenerator};
10use quote::{format_ident, quote};
11
/// `prost_build::ServiceGenerator` that emits dcl-rpc code for every service of a `.proto`
/// file.
///
/// Which code is produced is decided at compile time: the `client` feature enables the client
/// trait and struct, the `server` feature enables the server trait and registration helper.
#[derive(Debug, Clone, Default)]
pub struct RPCServiceGenerator {}
14
15pub struct MethodSigTokensParams {
16    body: Option<TokenStream>,
17    with_context: bool,
18    is_for_client: bool,
19}
20
21impl RPCServiceGenerator {
22    pub fn new() -> RPCServiceGenerator {
23        Default::default()
24    }
25
26    fn client_stream_request(&self) -> TokenStream {
27        quote!(ClientStreamRequest)
28    }
29
30    fn server_stream_response(&self) -> TokenStream {
31        quote!(ServerStreamResponse)
32    }
33
34    fn method_sig_tokens(&self, method: &Method, params: MethodSigTokensParams) -> TokenStream {
35        let input_type = self.extract_input_token(method);
36        let output_type = self.extract_output_token(method, params.is_for_client);
37        let name = extract_name_token(method);
38        let context = extract_context_token(&params);
39        let body = extract_body_token(params);
40
41        if let Some(input_type) = input_type {
42            quote! {
43                async fn #name(&self, request: #input_type #context)
44                    #output_type #body
45            }
46        } else {
47            quote! {
48                async fn #name(&self #context)
49                    #output_type #body
50            }
51        }
52    }
53
54    fn extract_input_token(&self, method: &Method) -> Option<TokenStream> {
55        if method.input_type.to_string().eq("()") {
56            None
57        } else {
58            let input_type = format_ident!("{}", method.input_type);
59            Some(match method.client_streaming {
60                true => {
61                    let client_stream_request = self.client_stream_request();
62                    quote!(#client_stream_request<#input_type>)
63                }
64                false => quote!(#input_type),
65            })
66        }
67    }
68
69    fn extract_output_token(&self, method: &Method, is_client: bool) -> TokenStream {
70        if method.output_type.to_string().eq("()") {
71            // The unit type can not be casted to an Ident, so the empty token is needed
72            if is_client {
73                quote! { -> ClientResult<()> }
74            } else {
75                quote! { -> Result<(), Error> }
76            }
77        } else {
78            let output_type = format_ident!("{}", method.output_type);
79            match method.server_streaming {
80                true => {
81                    let server_stream_response = self.server_stream_response();
82                    if is_client {
83                        quote! {-> ClientResult<#server_stream_response<#output_type>>}
84                    } else {
85                        quote! {-> Result<#server_stream_response<#output_type>, Error>}
86                    }
87                }
88                false => {
89                    if is_client {
90                        quote! {-> ClientResult<#output_type>}
91                    } else {
92                        quote! {-> Result<#output_type, Error>}
93                    }
94                }
95            }
96        }
97    }
98
99    fn generate_stream_types(&self, buf: &mut String) {
100        buf.push('\n');
101        buf.push_str("use dcl_rpc::stream_protocol::Generator;");
102        buf.push('\n');
103        buf.push_str("pub type ServerStreamResponse<T> = Generator<T>;");
104        buf.push('\n');
105        buf.push_str("pub type ClientStreamRequest<T> = Generator<T>;");
106        buf.push('\n');
107    }
108
109    #[cfg(feature = "client")]
110    fn generate_client_trait(&self, service: &Service, buf: &mut String) {
111        // This is done with strings rather than tokens because Prost provides functions that
112        // return doc comments as strings.
113        buf.push_str("use dcl_rpc::client::ClientResult;\n");
114        buf.push('\n');
115        service.comments.append_with_indent(0, buf);
116
117        buf.push_str("#[async_trait::async_trait]\n");
118        buf.push_str(&format!(
119            "pub trait {}ClientDefinition<T: Transport + 'static>: ServiceClient<T> +  Send + Sync + 'static {{",
120            service.name
121        ));
122        for method in service.methods.iter() {
123            buf.push('\n');
124            method.comments.append_with_indent(1, buf);
125            buf.push_str(&format!(
126                "    {};\n",
127                self.method_sig_tokens(
128                    method,
129                    MethodSigTokensParams {
130                        body: None,
131                        with_context: false,
132                        is_for_client: true
133                    }
134                )
135            ));
136        }
137        buf.push_str("}\n");
138    }
139
140    fn get_server_service_name(&self, service: &Service) -> String {
141        format!("{}Server", service.name)
142    }
143
144    #[cfg(feature = "server")]
145    fn generate_server_trait(&self, service: &Service, buf: &mut String) {
146        buf.push_str("use std::sync::Arc;\n");
147        buf.push_str("use dcl_rpc::{rpc_protocol::{RemoteErrorResponse}, service_module_definition::ProcedureContext};\n");
148        // This is done with strings rather than tokens because Prost provides functions that
149        // return doc comments as strings.
150        buf.push('\n');
151        service.comments.append_with_indent(0, buf);
152
153        buf.push_str("#[async_trait::async_trait]\n");
154        buf.push_str(&format!(
155            "pub trait {}<Context, Error: RemoteErrorResponse>: Send + Sync + 'static {{",
156            self.get_server_service_name(service)
157        ));
158        for method in service.methods.iter() {
159            buf.push('\n');
160            method.comments.append_with_indent(1, buf);
161            buf.push_str(&format!(
162                "    {};\n",
163                self.method_sig_tokens(
164                    method,
165                    MethodSigTokensParams {
166                        body: None,
167                        with_context: true,
168                        is_for_client: false
169                    }
170                )
171            ));
172        }
173        buf.push_str("}\n");
174    }
175
176    #[cfg(feature = "client")]
177    fn generate_client_service(&self, service: &Service, buf: &mut String) {
178        buf.push('\n');
179        // Create struct
180
181        buf.push_str(
182            "use dcl_rpc::{client::{RpcClientModule, ServiceClient}, transports::Transport};",
183        );
184        buf.push_str(&format!(
185            "pub struct {}Client<T: Transport + 'static> {{",
186            service.name
187        ));
188        buf.push_str(&format!(
189            "    {},\n",
190            "rpc_client_module: RpcClientModule<T>"
191        ));
192        buf.push('}');
193
194        buf.push('\n');
195
196        buf.push_str(&format!(
197            "impl<T: Transport + 'static> ServiceClient<T> for {}Client<T> {{
198    fn set_client_module(rpc_client_module: RpcClientModule<T>) -> Self {{
199        Self {{ rpc_client_module }}
200    }}
201}}
202",
203            service.name
204        ));
205
206        buf.push_str("#[async_trait::async_trait]\n");
207        buf.push_str(&format!(
208            "impl<T: Transport + 'static> {}ClientDefinition<T> for {}Client<T> {{",
209            service.name, service.name
210        ));
211        for method in service.methods.iter() {
212            buf.push('\n');
213            method.comments.append_with_indent(1, buf);
214            let input_type = self.extract_input_token(method);
215            let append_request = input_type.is_some();
216            let body = match (method.client_streaming, method.server_streaming) {
217                (false, false) => self.generate_unary_call(&method.proto_name, append_request),
218                (false, true) => {
219                    self.generate_server_streams_procedure(&method.proto_name, append_request)
220                }
221                (true, false) => {
222                    self.generate_client_streams_procedure(&method.proto_name, append_request)
223                }
224                (true, true) => {
225                    self.generate_bidir_streams_procedure(&method.proto_name, append_request)
226                }
227            };
228            buf.push_str(&format!(
229                "    {}\n",
230                self.method_sig_tokens(
231                    method,
232                    MethodSigTokensParams {
233                        body: Some(body),
234                        with_context: false,
235                        is_for_client: true
236                    }
237                )
238            ));
239        }
240        buf.push_str("}\n");
241    }
242
243    #[cfg(feature = "client")]
244    fn generate_unary_call(&self, name: &str, append_request: bool) -> TokenStream {
245        let request = if append_request {
246            quote!(request)
247        } else {
248            quote! { () }
249        };
250        quote! {
251            self.rpc_client_module
252                .call_unary_procedure(#name, #request)
253                .await
254        }
255    }
256
257    #[cfg(feature = "client")]
258    fn generate_server_streams_procedure(&self, name: &str, append_request: bool) -> TokenStream {
259        let request = if append_request {
260            quote!(request)
261        } else {
262            quote! { () }
263        };
264
265        quote! {
266            self.rpc_client_module
267                .call_server_streams_procedure(#name, #request)
268                .await
269        }
270    }
271
272    #[cfg(feature = "client")]
273    fn generate_client_streams_procedure(&self, name: &str, append_request: bool) -> TokenStream {
274        let request = if append_request {
275            quote!(request)
276        } else {
277            quote! { () }
278        };
279
280        quote! {
281            self.rpc_client_module
282                .call_client_streams_procedure(#name, #request)
283                .await
284        }
285    }
286
287    #[cfg(feature = "client")]
288    fn generate_bidir_streams_procedure(&self, name: &str, append_request: bool) -> TokenStream {
289        let request = if append_request {
290            quote!(request)
291        } else {
292            quote! { () }
293        };
294
295        quote! {
296            self.rpc_client_module
297                .call_bidir_streams_procedure(#name, #request)
298                .await
299        }
300    }
301
302    #[cfg(feature = "server")]
303    fn generate_server_service(&self, service: &Service, buf: &mut String) {
304        buf.push_str("use dcl_rpc::server::RpcServerPort;\n");
305        buf.push_str("use dcl_rpc::service_module_definition::ServiceModuleDefinition;\n");
306        buf.push_str("use prost::Message;\n");
307
308        let name = format!("{}Registration", service.name);
309        buf.push('\n');
310        buf.push_str(&format!("pub struct {} {{}}\n", name));
311        buf.push('\n');
312
313        buf.push('\n');
314        buf.push_str(&format!("impl {} {{", name));
315        buf.push_str(&format!("    {}", self.generate_register_service(service)));
316        buf.push_str("}\n");
317    }
318
319    #[cfg(feature = "server")]
320    fn generate_register_service(&self, service: &Service) -> TokenStream {
321        let service_name = &service.name;
322        let name = self.get_server_service_name(service);
323        let trait_name: TokenStream = name.parse().unwrap();
324
325        let mut methods: Vec<TokenStream> = vec![];
326        for method in &service.methods {
327            methods.push(match (method.client_streaming, method.server_streaming) {
328                (false, false) => self.generate_add_unary_call(method),
329                (false, true) => self.generate_add_server_streams_procedure(method),
330                (true, false) => self.generate_add_client_streams_procedure(method),
331                (true, true) => self.generate_add_bidir_streams_procedure(method),
332            });
333        }
334        quote! {
335        pub fn register_service<
336                S: #trait_name<Context, Error> + Send + Sync + 'static,
337                Context: Send + Sync + 'static,
338                Error: RemoteErrorResponse + Send + Sync + 'static
339            >(
340                port: &mut RpcServerPort<Context>,
341                service: S
342            ) {
343                let mut service_def = ServiceModuleDefinition::new();
344                // Share service ownership
345                let shareable_service = Arc::new(service);
346
347                #(#methods)*
348
349                port.register_module(#service_name.to_string(), service_def)
350            }
351        }
352    }
353
354    #[cfg(feature = "server")]
355    fn generate_add_unary_call(&self, method: &Method) -> TokenStream {
356        let method_name: TokenStream = method.name.parse().unwrap();
357        let proto_method_name = &method.proto_name;
358        let input_type = self.extract_input_token(method);
359
360        let service_call;
361        let request;
362        if let Some(input_type) = input_type {
363            service_call = quote! {
364                service.#method_name(#input_type::decode(request.as_slice()).unwrap(), context).await
365            };
366            request = quote! {request}
367        } else {
368            service_call = quote! { service.#method_name(context).await };
369            request = quote! {_request}
370        };
371        quote! {
372            let service = Arc::clone(&shareable_service);
373            service_def.add_unary(#proto_method_name, move |#request, context| {
374                let service = service.clone();
375                Box::pin(async move {
376                    match #service_call {
377                        Ok(response) => Ok(response.encode_to_vec()),
378                        Err(err) => Err(err.into())
379                    }
380                })
381            });
382        }
383    }
384
385    #[cfg(feature = "server")]
386    fn generate_add_server_streams_procedure(&self, method: &Method) -> TokenStream {
387        let method_name: TokenStream = method.name.parse().unwrap();
388        let proto_method_name = &method.proto_name;
389        let input_type: TokenStream = method.input_type.parse().unwrap();
390        let extracted_input_type = self.extract_input_token(method);
391
392        let service_stream;
393        let request;
394        if extracted_input_type.is_some() {
395            service_stream = quote! {
396                service.#method_name(#input_type::decode(request.as_slice()).unwrap(), context).await
397            };
398            request = quote! { request };
399        } else {
400            service_stream = quote! {
401                service.#method_name(context).await
402            };
403            request = quote! { _request };
404        };
405
406        quote! {
407            let service = Arc::clone(&shareable_service);
408            service_def.add_server_streams(#proto_method_name, move |#request, context| {
409                let service = service.clone();
410                Box::pin(async move {
411                    match #service_stream {
412                        // Transforming and filling the new generator is spawned so the response is quick
413                        Ok(server_streams_generator) => Ok(Generator::from_generator(server_streams_generator, |item| Some(item.encode_to_vec()))),
414                        Err(err) => Err(err.into())
415                    }
416                })
417            });
418        }
419    }
420
421    #[cfg(feature = "server")]
422    fn generate_add_client_streams_procedure(&self, method: &Method) -> TokenStream {
423        let method_name: TokenStream = method.name.parse().unwrap();
424        let proto_method_name = &method.proto_name;
425        let input_type: TokenStream = method.input_type.parse().unwrap();
426        let extracted_input_type = self.extract_input_token(method);
427
428        let input;
429        let request;
430        if extracted_input_type.is_some() {
431            input = quote! {
432                #input_type::decode(item.as_slice()).unwrap()
433            };
434            request = quote! { request };
435        } else {
436            input = quote! { () };
437            request = quote! { _request };
438        };
439        quote! {
440            let service = Arc::clone(&shareable_service);
441            service_def.add_client_streams(#proto_method_name, move |#request, context| {
442                let service = service.clone();
443                Box::pin(async move {
444                    let generator = Generator::from_generator(request, |item| {
445                        Some(#input)
446                    });
447
448                    match service.#method_name(generator, context).await {
449                        Ok(response) => Ok(response.encode_to_vec()),
450                        Err(err) => Err(err.into())
451                    }
452                })
453            });
454        }
455    }
456
457    #[cfg(feature = "server")]
458    fn generate_add_bidir_streams_procedure(&self, method: &Method) -> TokenStream {
459        let method_name: TokenStream = method.name.parse().unwrap();
460        let proto_method_name = &method.proto_name;
461        let input_type: TokenStream = method.input_type.parse().unwrap();
462        let extracted_input_type = self.extract_input_token(method);
463
464        let input;
465        let request;
466        if extracted_input_type.is_some() {
467            input = quote! {
468                #input_type::decode(item.as_slice()).unwrap()
469            };
470            request = quote! { request };
471        } else {
472            input = quote! { () };
473            request = quote! { _request };
474        };
475
476        quote! {
477            let service = Arc::clone(&shareable_service);
478            service_def.add_bidir_streams(#proto_method_name, move |#request, context| {
479                let service = service.clone();
480                Box::pin(async move {
481                    let generator = Generator::from_generator(request, |item| {
482                        Some(#input)
483                    });
484
485                    match service.#method_name(generator, context).await {
486                        Ok(response_generator) => Ok(Generator::from_generator(response_generator, |item| Some(item.encode_to_vec()))),
487                        Err(err) => Err(err.into())
488                    }
489                })
490            });
491        }
492    }
493}
494
495fn extract_name_token(method: &Method) -> proc_macro2::Ident {
496    format_ident!("{}", method.name)
497}
498
499fn extract_context_token(params: &MethodSigTokensParams) -> TokenStream {
500    match params.with_context {
501        true => quote! {, context: ProcedureContext<Context>},
502        false => TokenStream::default(),
503    }
504}
505
506fn extract_body_token(params: MethodSigTokensParams) -> TokenStream {
507    let body = params.body;
508    match body {
509        Some(body) => quote! { { #body } },
510        None => TokenStream::default(),
511    }
512}
513
514impl ServiceGenerator for RPCServiceGenerator {
515    fn generate(&mut self, service: Service, buf: &mut String) {
516        self.generate_stream_types(buf);
517        #[cfg(feature = "client")]
518        self.generate_client_trait(&service, buf);
519        #[cfg(feature = "client")]
520        self.generate_client_service(&service, buf);
521        #[cfg(feature = "server")]
522        self.generate_server_trait(&service, buf);
523        #[cfg(feature = "server")]
524        self.generate_server_service(&service, buf);
525        println!("{}", buf);
526    }
527
528    fn finalize(&mut self, _buf: &mut String) {}
529}