1use proc_macro2::TokenStream;
9use prost_build::{Method, Service, ServiceGenerator};
10use quote::{format_ident, quote};
11
/// prost-build [`ServiceGenerator`] that emits `dcl_rpc` client and/or server bindings
/// (selected by the `client` / `server` crate features) for every protobuf service.
///
/// The generator carries no state of its own; create it with `RPCServiceGenerator::new()`
/// or `RPCServiceGenerator::default()`.
#[derive(Default)]
pub struct RPCServiceGenerator {}
14
15pub struct MethodSigTokensParams {
16 body: Option<TokenStream>,
17 with_context: bool,
18 is_for_client: bool,
19}
20
21impl RPCServiceGenerator {
22 pub fn new() -> RPCServiceGenerator {
23 Default::default()
24 }
25
26 fn client_stream_request(&self) -> TokenStream {
27 quote!(ClientStreamRequest)
28 }
29
30 fn server_stream_response(&self) -> TokenStream {
31 quote!(ServerStreamResponse)
32 }
33
34 fn method_sig_tokens(&self, method: &Method, params: MethodSigTokensParams) -> TokenStream {
35 let input_type = self.extract_input_token(method);
36 let output_type = self.extract_output_token(method, params.is_for_client);
37 let name = extract_name_token(method);
38 let context = extract_context_token(¶ms);
39 let body = extract_body_token(params);
40
41 if let Some(input_type) = input_type {
42 quote! {
43 async fn #name(&self, request: #input_type #context)
44 #output_type #body
45 }
46 } else {
47 quote! {
48 async fn #name(&self #context)
49 #output_type #body
50 }
51 }
52 }
53
54 fn extract_input_token(&self, method: &Method) -> Option<TokenStream> {
55 if method.input_type.to_string().eq("()") {
56 None
57 } else {
58 let input_type = format_ident!("{}", method.input_type);
59 Some(match method.client_streaming {
60 true => {
61 let client_stream_request = self.client_stream_request();
62 quote!(#client_stream_request<#input_type>)
63 }
64 false => quote!(#input_type),
65 })
66 }
67 }
68
69 fn extract_output_token(&self, method: &Method, is_client: bool) -> TokenStream {
70 if method.output_type.to_string().eq("()") {
71 if is_client {
73 quote! { -> ClientResult<()> }
74 } else {
75 quote! { -> Result<(), Error> }
76 }
77 } else {
78 let output_type = format_ident!("{}", method.output_type);
79 match method.server_streaming {
80 true => {
81 let server_stream_response = self.server_stream_response();
82 if is_client {
83 quote! {-> ClientResult<#server_stream_response<#output_type>>}
84 } else {
85 quote! {-> Result<#server_stream_response<#output_type>, Error>}
86 }
87 }
88 false => {
89 if is_client {
90 quote! {-> ClientResult<#output_type>}
91 } else {
92 quote! {-> Result<#output_type, Error>}
93 }
94 }
95 }
96 }
97 }
98
99 fn generate_stream_types(&self, buf: &mut String) {
100 buf.push('\n');
101 buf.push_str("use dcl_rpc::stream_protocol::Generator;");
102 buf.push('\n');
103 buf.push_str("pub type ServerStreamResponse<T> = Generator<T>;");
104 buf.push('\n');
105 buf.push_str("pub type ClientStreamRequest<T> = Generator<T>;");
106 buf.push('\n');
107 }
108
109 #[cfg(feature = "client")]
110 fn generate_client_trait(&self, service: &Service, buf: &mut String) {
111 buf.push_str("use dcl_rpc::client::ClientResult;\n");
114 buf.push('\n');
115 service.comments.append_with_indent(0, buf);
116
117 buf.push_str("#[async_trait::async_trait]\n");
118 buf.push_str(&format!(
119 "pub trait {}ClientDefinition<T: Transport + 'static>: ServiceClient<T> + Send + Sync + 'static {{",
120 service.name
121 ));
122 for method in service.methods.iter() {
123 buf.push('\n');
124 method.comments.append_with_indent(1, buf);
125 buf.push_str(&format!(
126 " {};\n",
127 self.method_sig_tokens(
128 method,
129 MethodSigTokensParams {
130 body: None,
131 with_context: false,
132 is_for_client: true
133 }
134 )
135 ));
136 }
137 buf.push_str("}\n");
138 }
139
140 fn get_server_service_name(&self, service: &Service) -> String {
141 format!("{}Server", service.name)
142 }
143
144 #[cfg(feature = "server")]
145 fn generate_server_trait(&self, service: &Service, buf: &mut String) {
146 buf.push_str("use std::sync::Arc;\n");
147 buf.push_str("use dcl_rpc::{rpc_protocol::{RemoteErrorResponse}, service_module_definition::ProcedureContext};\n");
148 buf.push('\n');
151 service.comments.append_with_indent(0, buf);
152
153 buf.push_str("#[async_trait::async_trait]\n");
154 buf.push_str(&format!(
155 "pub trait {}<Context, Error: RemoteErrorResponse>: Send + Sync + 'static {{",
156 self.get_server_service_name(service)
157 ));
158 for method in service.methods.iter() {
159 buf.push('\n');
160 method.comments.append_with_indent(1, buf);
161 buf.push_str(&format!(
162 " {};\n",
163 self.method_sig_tokens(
164 method,
165 MethodSigTokensParams {
166 body: None,
167 with_context: true,
168 is_for_client: false
169 }
170 )
171 ));
172 }
173 buf.push_str("}\n");
174 }
175
176 #[cfg(feature = "client")]
177 fn generate_client_service(&self, service: &Service, buf: &mut String) {
178 buf.push('\n');
179 buf.push_str(
182 "use dcl_rpc::{client::{RpcClientModule, ServiceClient}, transports::Transport};",
183 );
184 buf.push_str(&format!(
185 "pub struct {}Client<T: Transport + 'static> {{",
186 service.name
187 ));
188 buf.push_str(&format!(
189 " {},\n",
190 "rpc_client_module: RpcClientModule<T>"
191 ));
192 buf.push('}');
193
194 buf.push('\n');
195
196 buf.push_str(&format!(
197 "impl<T: Transport + 'static> ServiceClient<T> for {}Client<T> {{
198 fn set_client_module(rpc_client_module: RpcClientModule<T>) -> Self {{
199 Self {{ rpc_client_module }}
200 }}
201}}
202",
203 service.name
204 ));
205
206 buf.push_str("#[async_trait::async_trait]\n");
207 buf.push_str(&format!(
208 "impl<T: Transport + 'static> {}ClientDefinition<T> for {}Client<T> {{",
209 service.name, service.name
210 ));
211 for method in service.methods.iter() {
212 buf.push('\n');
213 method.comments.append_with_indent(1, buf);
214 let input_type = self.extract_input_token(method);
215 let append_request = input_type.is_some();
216 let body = match (method.client_streaming, method.server_streaming) {
217 (false, false) => self.generate_unary_call(&method.proto_name, append_request),
218 (false, true) => {
219 self.generate_server_streams_procedure(&method.proto_name, append_request)
220 }
221 (true, false) => {
222 self.generate_client_streams_procedure(&method.proto_name, append_request)
223 }
224 (true, true) => {
225 self.generate_bidir_streams_procedure(&method.proto_name, append_request)
226 }
227 };
228 buf.push_str(&format!(
229 " {}\n",
230 self.method_sig_tokens(
231 method,
232 MethodSigTokensParams {
233 body: Some(body),
234 with_context: false,
235 is_for_client: true
236 }
237 )
238 ));
239 }
240 buf.push_str("}\n");
241 }
242
243 #[cfg(feature = "client")]
244 fn generate_unary_call(&self, name: &str, append_request: bool) -> TokenStream {
245 let request = if append_request {
246 quote!(request)
247 } else {
248 quote! { () }
249 };
250 quote! {
251 self.rpc_client_module
252 .call_unary_procedure(#name, #request)
253 .await
254 }
255 }
256
257 #[cfg(feature = "client")]
258 fn generate_server_streams_procedure(&self, name: &str, append_request: bool) -> TokenStream {
259 let request = if append_request {
260 quote!(request)
261 } else {
262 quote! { () }
263 };
264
265 quote! {
266 self.rpc_client_module
267 .call_server_streams_procedure(#name, #request)
268 .await
269 }
270 }
271
272 #[cfg(feature = "client")]
273 fn generate_client_streams_procedure(&self, name: &str, append_request: bool) -> TokenStream {
274 let request = if append_request {
275 quote!(request)
276 } else {
277 quote! { () }
278 };
279
280 quote! {
281 self.rpc_client_module
282 .call_client_streams_procedure(#name, #request)
283 .await
284 }
285 }
286
287 #[cfg(feature = "client")]
288 fn generate_bidir_streams_procedure(&self, name: &str, append_request: bool) -> TokenStream {
289 let request = if append_request {
290 quote!(request)
291 } else {
292 quote! { () }
293 };
294
295 quote! {
296 self.rpc_client_module
297 .call_bidir_streams_procedure(#name, #request)
298 .await
299 }
300 }
301
302 #[cfg(feature = "server")]
303 fn generate_server_service(&self, service: &Service, buf: &mut String) {
304 buf.push_str("use dcl_rpc::server::RpcServerPort;\n");
305 buf.push_str("use dcl_rpc::service_module_definition::ServiceModuleDefinition;\n");
306 buf.push_str("use prost::Message;\n");
307
308 let name = format!("{}Registration", service.name);
309 buf.push('\n');
310 buf.push_str(&format!("pub struct {} {{}}\n", name));
311 buf.push('\n');
312
313 buf.push('\n');
314 buf.push_str(&format!("impl {} {{", name));
315 buf.push_str(&format!(" {}", self.generate_register_service(service)));
316 buf.push_str("}\n");
317 }
318
319 #[cfg(feature = "server")]
320 fn generate_register_service(&self, service: &Service) -> TokenStream {
321 let service_name = &service.name;
322 let name = self.get_server_service_name(service);
323 let trait_name: TokenStream = name.parse().unwrap();
324
325 let mut methods: Vec<TokenStream> = vec![];
326 for method in &service.methods {
327 methods.push(match (method.client_streaming, method.server_streaming) {
328 (false, false) => self.generate_add_unary_call(method),
329 (false, true) => self.generate_add_server_streams_procedure(method),
330 (true, false) => self.generate_add_client_streams_procedure(method),
331 (true, true) => self.generate_add_bidir_streams_procedure(method),
332 });
333 }
334 quote! {
335 pub fn register_service<
336 S: #trait_name<Context, Error> + Send + Sync + 'static,
337 Context: Send + Sync + 'static,
338 Error: RemoteErrorResponse + Send + Sync + 'static
339 >(
340 port: &mut RpcServerPort<Context>,
341 service: S
342 ) {
343 let mut service_def = ServiceModuleDefinition::new();
344 let shareable_service = Arc::new(service);
346
347 #(#methods)*
348
349 port.register_module(#service_name.to_string(), service_def)
350 }
351 }
352 }
353
354 #[cfg(feature = "server")]
355 fn generate_add_unary_call(&self, method: &Method) -> TokenStream {
356 let method_name: TokenStream = method.name.parse().unwrap();
357 let proto_method_name = &method.proto_name;
358 let input_type = self.extract_input_token(method);
359
360 let service_call;
361 let request;
362 if let Some(input_type) = input_type {
363 service_call = quote! {
364 service.#method_name(#input_type::decode(request.as_slice()).unwrap(), context).await
365 };
366 request = quote! {request}
367 } else {
368 service_call = quote! { service.#method_name(context).await };
369 request = quote! {_request}
370 };
371 quote! {
372 let service = Arc::clone(&shareable_service);
373 service_def.add_unary(#proto_method_name, move |#request, context| {
374 let service = service.clone();
375 Box::pin(async move {
376 match #service_call {
377 Ok(response) => Ok(response.encode_to_vec()),
378 Err(err) => Err(err.into())
379 }
380 })
381 });
382 }
383 }
384
385 #[cfg(feature = "server")]
386 fn generate_add_server_streams_procedure(&self, method: &Method) -> TokenStream {
387 let method_name: TokenStream = method.name.parse().unwrap();
388 let proto_method_name = &method.proto_name;
389 let input_type: TokenStream = method.input_type.parse().unwrap();
390 let extracted_input_type = self.extract_input_token(method);
391
392 let service_stream;
393 let request;
394 if extracted_input_type.is_some() {
395 service_stream = quote! {
396 service.#method_name(#input_type::decode(request.as_slice()).unwrap(), context).await
397 };
398 request = quote! { request };
399 } else {
400 service_stream = quote! {
401 service.#method_name(context).await
402 };
403 request = quote! { _request };
404 };
405
406 quote! {
407 let service = Arc::clone(&shareable_service);
408 service_def.add_server_streams(#proto_method_name, move |#request, context| {
409 let service = service.clone();
410 Box::pin(async move {
411 match #service_stream {
412 Ok(server_streams_generator) => Ok(Generator::from_generator(server_streams_generator, |item| Some(item.encode_to_vec()))),
414 Err(err) => Err(err.into())
415 }
416 })
417 });
418 }
419 }
420
421 #[cfg(feature = "server")]
422 fn generate_add_client_streams_procedure(&self, method: &Method) -> TokenStream {
423 let method_name: TokenStream = method.name.parse().unwrap();
424 let proto_method_name = &method.proto_name;
425 let input_type: TokenStream = method.input_type.parse().unwrap();
426 let extracted_input_type = self.extract_input_token(method);
427
428 let input;
429 let request;
430 if extracted_input_type.is_some() {
431 input = quote! {
432 #input_type::decode(item.as_slice()).unwrap()
433 };
434 request = quote! { request };
435 } else {
436 input = quote! { () };
437 request = quote! { _request };
438 };
439 quote! {
440 let service = Arc::clone(&shareable_service);
441 service_def.add_client_streams(#proto_method_name, move |#request, context| {
442 let service = service.clone();
443 Box::pin(async move {
444 let generator = Generator::from_generator(request, |item| {
445 Some(#input)
446 });
447
448 match service.#method_name(generator, context).await {
449 Ok(response) => Ok(response.encode_to_vec()),
450 Err(err) => Err(err.into())
451 }
452 })
453 });
454 }
455 }
456
457 #[cfg(feature = "server")]
458 fn generate_add_bidir_streams_procedure(&self, method: &Method) -> TokenStream {
459 let method_name: TokenStream = method.name.parse().unwrap();
460 let proto_method_name = &method.proto_name;
461 let input_type: TokenStream = method.input_type.parse().unwrap();
462 let extracted_input_type = self.extract_input_token(method);
463
464 let input;
465 let request;
466 if extracted_input_type.is_some() {
467 input = quote! {
468 #input_type::decode(item.as_slice()).unwrap()
469 };
470 request = quote! { request };
471 } else {
472 input = quote! { () };
473 request = quote! { _request };
474 };
475
476 quote! {
477 let service = Arc::clone(&shareable_service);
478 service_def.add_bidir_streams(#proto_method_name, move |#request, context| {
479 let service = service.clone();
480 Box::pin(async move {
481 let generator = Generator::from_generator(request, |item| {
482 Some(#input)
483 });
484
485 match service.#method_name(generator, context).await {
486 Ok(response_generator) => Ok(Generator::from_generator(response_generator, |item| Some(item.encode_to_vec()))),
487 Err(err) => Err(err.into())
488 }
489 })
490 });
491 }
492 }
493}
494
495fn extract_name_token(method: &Method) -> proc_macro2::Ident {
496 format_ident!("{}", method.name)
497}
498
499fn extract_context_token(params: &MethodSigTokensParams) -> TokenStream {
500 match params.with_context {
501 true => quote! {, context: ProcedureContext<Context>},
502 false => TokenStream::default(),
503 }
504}
505
506fn extract_body_token(params: MethodSigTokensParams) -> TokenStream {
507 let body = params.body;
508 match body {
509 Some(body) => quote! { { #body } },
510 None => TokenStream::default(),
511 }
512}
513
514impl ServiceGenerator for RPCServiceGenerator {
515 fn generate(&mut self, service: Service, buf: &mut String) {
516 self.generate_stream_types(buf);
517 #[cfg(feature = "client")]
518 self.generate_client_trait(&service, buf);
519 #[cfg(feature = "client")]
520 self.generate_client_service(&service, buf);
521 #[cfg(feature = "server")]
522 self.generate_server_trait(&service, buf);
523 #[cfg(feature = "server")]
524 self.generate_server_service(&service, buf);
525 println!("{}", buf);
526 }
527
528 fn finalize(&mut self, _buf: &mut String) {}
529}