dcl_rpc_codegen/
lib.rs

1//! Generate service code from a service definition.
2
3// Guidelines for generated code:
4//
5// Use fully-qualified paths, to reduce the chance of clashing with
6// user provided names.
7
8use proc_macro2::TokenStream;
9use prost_build::{Method, Service, ServiceGenerator};
10use quote::{format_ident, quote};
11
/// prost service generator that emits dcl-rpc code for each protobuf service it receives:
/// a client trait plus its `...Client` implementation, and a `Shared...` server trait plus a
/// `...Registration` helper that registers a server implementation on an `RpcServerPort`.
///
/// The generator is stateless; construct it with [`RPCServiceGenerator::new`] or `Default`.
#[derive(Debug, Default, Clone)]
pub struct RPCServiceGenerator {}
14
15impl RPCServiceGenerator {
16    pub fn new() -> RPCServiceGenerator {
17        Default::default()
18    }
19
20    fn client_stream_request(&self) -> TokenStream {
21        quote!(ClientStreamRequest)
22    }
23
24    fn server_stream_response(&self) -> TokenStream {
25        quote!(ServerStreamResponse)
26    }
27
28    fn method_sig_tokens(&self, method: &Method, body: Option<TokenStream>) -> TokenStream {
29        let name = format_ident!("{}", method.name);
30        let input_type = format_ident!("{}", method.input_type);
31        let output_type = format_ident!("{}", method.output_type);
32
33        let input_type = if method.client_streaming {
34            let client_stream_request = self.client_stream_request();
35            quote!(#client_stream_request<#input_type>)
36        } else {
37            quote!(#input_type)
38        };
39
40        let output_type = if method.server_streaming {
41            let server_stream_response = self.server_stream_response();
42            quote!(#server_stream_response<#output_type>)
43        } else {
44            quote!(#output_type)
45        };
46
47        if let Some(body) = body {
48            quote! {
49                async fn #name(&self, request: #input_type)
50                    -> #output_type {
51                        #body
52                    }
53            }
54        } else {
55            quote! {
56                async fn #name(&self, request: #input_type)
57                    -> #output_type
58            }
59        }
60    }
61
62    fn method_sig_tokens_with_context(&self, method: &Method) -> TokenStream {
63        let name = format_ident!("{}", method.name);
64        let input_type = format_ident!("{}", method.input_type);
65        let output_type = format_ident!("{}", method.output_type);
66
67        let input_type = if method.client_streaming {
68            let client_stream_request = self.client_stream_request();
69            quote!(#client_stream_request<#input_type>)
70        } else {
71            quote!(#input_type)
72        };
73
74        let output_type = if method.server_streaming {
75            let server_stream_response = self.server_stream_response();
76            quote!(#server_stream_response<#output_type>)
77        } else {
78            quote!(#output_type)
79        };
80
81        quote! {
82            async fn #name(&self, request: #input_type, context: Arc<Context>)
83                -> #output_type
84        }
85    }
86
87    fn generate_stream_types(&self, buf: &mut String) {
88        buf.push('\n');
89        buf.push_str("use dcl_rpc::stream_protocol::Generator;");
90        buf.push('\n');
91        buf.push_str("pub type ServerStreamResponse<T> = Generator<T>;");
92        buf.push('\n');
93        buf.push_str("pub type ClientStreamRequest<T> = Generator<T>;");
94        buf.push('\n');
95    }
96
97    fn generate_client_trait(&self, service: &Service, buf: &mut String) {
98        // This is done with strings rather than tokens because Prost provides functions that
99        // return doc comments as strings.
100        buf.push('\n');
101        service.comments.append_with_indent(0, buf);
102
103        buf.push_str("#[async_trait::async_trait]\n");
104        buf.push_str(&format!(
105            "pub trait {}: Send + Sync + 'static {{",
106            service.name
107        ));
108        for method in service.methods.iter() {
109            buf.push('\n');
110            method.comments.append_with_indent(1, buf);
111            buf.push_str(&format!("    {};\n", self.method_sig_tokens(method, None)));
112        }
113        buf.push_str("}\n");
114    }
115
116    fn get_server_service_name(&self, service: &Service) -> String {
117        format!("Shared{}", service.name)
118    }
119
120    fn generate_server_trait(&self, service: &Service, buf: &mut String) {
121        buf.push_str("use std::sync::Arc;\n");
122        // This is done with strings rather than tokens because Prost provides functions that
123        // return doc comments as strings.
124        buf.push('\n');
125        service.comments.append_with_indent(0, buf);
126
127        buf.push_str("#[async_trait::async_trait]\n");
128        buf.push_str(&format!(
129            "pub trait {}<Context>: Send + Sync + 'static {{",
130            self.get_server_service_name(service)
131        ));
132        for method in service.methods.iter() {
133            buf.push('\n');
134            method.comments.append_with_indent(1, buf);
135            buf.push_str(&format!(
136                "    {};\n",
137                self.method_sig_tokens_with_context(method)
138            ));
139        }
140        buf.push_str("}\n");
141    }
142
143    fn generate_client_service(&self, service: &Service, buf: &mut String) {
144        buf.push('\n');
145        // Create struct
146
147        buf.push_str("use dcl_rpc::client::{RpcClientModule, ServiceClient};");
148        buf.push_str(&format!("pub struct {}Client {{", service.name));
149        buf.push_str(&format!("    {},\n", "rpc_client_module: RpcClientModule"));
150        buf.push_str("}");
151
152        buf.push('\n');
153
154        buf.push_str(&format!(
155            "impl ServiceClient for {}Client {{
156    fn set_client_module(rpc_client_module: RpcClientModule) -> Self {{
157        Self {{ rpc_client_module }}
158    }}
159}}
160",
161            service.name
162        ));
163
164        buf.push_str("#[async_trait::async_trait]\n");
165        buf.push_str(&format!(
166            "impl {} for {}Client {{",
167            service.name, service.name
168        ));
169        for method in service.methods.iter() {
170            buf.push('\n');
171            method.comments.append_with_indent(1, buf);
172            let body = match (method.client_streaming, method.server_streaming) {
173                (false, false) => self.generate_unary_call(&method.proto_name),
174                (false, true) => self.generate_server_streams_procedure(&method.proto_name),
175                (true, false) => self.generate_client_streams_procedure(&method.proto_name),
176                (true, true) => self.generate_bidir_streams_procedure(&method.proto_name),
177            };
178            buf.push_str(&format!(
179                "    {}\n",
180                self.method_sig_tokens(method, Some(body))
181            ));
182        }
183        buf.push_str("}\n");
184    }
185
186    fn generate_unary_call(&self, name: &str) -> TokenStream {
187        quote! {
188            self.rpc_client_module
189                .call_unary_procedure(#name, request)
190                .await
191                .unwrap()
192        }
193    }
194
195    fn generate_server_streams_procedure(&self, name: &str) -> TokenStream {
196        quote! {
197            self.rpc_client_module
198                .call_server_streams_procedure(#name, request)
199                .await
200                .unwrap()
201        }
202    }
203
204    fn generate_client_streams_procedure(&self, name: &str) -> TokenStream {
205        quote! {
206            self.rpc_client_module
207                .call_client_streams_procedure(#name, request)
208                .await
209                .unwrap()
210        }
211    }
212
213    fn generate_bidir_streams_procedure(&self, name: &str) -> TokenStream {
214        quote! {
215            self.rpc_client_module
216                .call_bidir_streams_procedure(#name, request)
217                .await
218                .unwrap()
219        }
220    }
221
222    fn generate_server_service(&self, service: &Service, buf: &mut String) {
223        buf.push_str("use dcl_rpc::server::RpcServerPort;\n");
224        buf.push_str("use dcl_rpc::service_module_definition::ServiceModuleDefinition;\n");
225        buf.push_str("use prost::Message;\n");
226
227        let name = format!("{}Registration", service.name);
228        buf.push('\n');
229        buf.push_str(&format!("pub struct {} {{}}\n", name));
230        buf.push('\n');
231
232        buf.push('\n');
233        buf.push_str(&format!("impl {} {{", name));
234        buf.push_str(&format!("    {}", self.generate_register_service(service)));
235        buf.push_str("}\n");
236    }
237
238    fn generate_register_service(&self, service: &Service) -> TokenStream {
239        let service_name = &service.name;
240        let name = self.get_server_service_name(service);
241        let trait_name: TokenStream = name.parse().unwrap();
242
243        let mut methods: Vec<TokenStream> = vec![];
244        for method in &service.methods {
245            methods.push(match (method.client_streaming, method.server_streaming) {
246                (false, false) => self.generate_add_unary_call(&method),
247                (false, true) => self.generate_add_server_streams_procedure(&method),
248                (true, false) => self.generate_add_client_streams_procedure(&method),
249                (true, true) => self.generate_add_bidir_streams_procedure(&method),
250            });
251        }
252        quote! {
253        pub fn register_service<
254                S: #trait_name<Context> + Send + Sync + 'static,
255                Context: Send + Sync + 'static
256            >(
257                port: &mut RpcServerPort<Context>,
258                service: S
259            ) {
260                let mut service_def = ServiceModuleDefinition::new();
261                // Share service ownership
262                let shareable_service = Arc::new(service);
263
264                #(#methods)*
265
266                port.register_module(#service_name.to_string(), service_def)
267            }
268        }
269    }
270
271    fn generate_add_unary_call(&self, method: &Method) -> TokenStream {
272        let method_name: TokenStream = method.name.parse().unwrap();
273        let proto_method_name = &method.proto_name;
274        let input_type: TokenStream = method.input_type.parse().unwrap();
275        quote! {
276            let service = Arc::clone(&shareable_service);
277            service_def.add_unary(#proto_method_name, move |request, context| {
278                let service = service.clone();
279                Box::pin(async move {
280                    let response = service
281                        .#method_name(#input_type::decode(request.as_slice()).unwrap(), context)
282                        .await;
283                    response.encode_to_vec()
284                })
285            });
286        }
287    }
288
289    fn generate_add_server_streams_procedure(&self, method: &Method) -> TokenStream {
290        let method_name: TokenStream = method.name.parse().unwrap();
291        let proto_method_name = &method.proto_name;
292        let input_type: TokenStream = method.input_type.parse().unwrap();
293        quote! {
294            let service = Arc::clone(&shareable_service);
295            service_def.add_server_streams(#proto_method_name, move |request, context| {
296                let service = service.clone();
297                Box::pin(async move {
298                    let server_streams = service
299                        .#method_name(#input_type::decode(request.as_slice()).unwrap(), context)
300                        .await;
301                    // Transforming and filling the new generator is spawned so the response is quick
302                    Generator::from_generator(server_streams, |item| item.encode_to_vec())
303                })
304            });
305        }
306    }
307
308    fn generate_add_client_streams_procedure(&self, method: &Method) -> TokenStream {
309        let method_name: TokenStream = method.name.parse().unwrap();
310        let proto_method_name = &method.proto_name;
311        let input_type: TokenStream = method.input_type.parse().unwrap();
312        quote! {
313            let service = Arc::clone(&shareable_service);
314            service_def.add_client_streams(#proto_method_name, move |request, context| {
315                let service = service.clone();
316                Box::pin(async move {
317                    let generator = Generator::from_generator(request, |item| {
318                        #input_type::decode(item.as_slice()).unwrap()
319                    });
320
321                    let response = service.#method_name(generator, context).await;
322                    response.encode_to_vec()
323                })
324            });
325        }
326    }
327
328    fn generate_add_bidir_streams_procedure(&self, method: &Method) -> TokenStream {
329        let method_name: TokenStream = method.name.parse().unwrap();
330        let proto_method_name = &method.proto_name;
331        let input_type: TokenStream = method.input_type.parse().unwrap();
332        quote! {
333            let service = Arc::clone(&shareable_service);
334            service_def.add_bidir_streams(#proto_method_name, move |request, context| {
335                let service = service.clone();
336                Box::pin(async move {
337                    let generator = Generator::from_generator(request, |item| {
338                        #input_type::decode(item.as_slice()).unwrap()
339                    });
340
341                    let response = service.#method_name(generator, context).await;
342                    Generator::from_generator(response, |item| item.encode_to_vec())
343                })
344            });
345        }
346    }
347}
348
349impl ServiceGenerator for RPCServiceGenerator {
350    fn generate(&mut self, service: Service, buf: &mut String) {
351        self.generate_stream_types(buf);
352        self.generate_client_trait(&service, buf);
353        self.generate_client_service(&service, buf);
354        self.generate_server_trait(&service, buf);
355        self.generate_server_service(&service, buf);
356        println!("{}", buf.to_string());
357    }
358
359    fn finalize(&mut self, _buf: &mut String) {}
360}