1use proc_macro2::TokenStream;
9use prost_build::{Method, Service, ServiceGenerator};
10use quote::{format_ident, quote};
11
/// A `prost_build::ServiceGenerator` that emits `dcl_rpc` bindings for every
/// protobuf service: a client trait with an `RpcClientModule`-backed
/// implementation, and a server trait with a `register_service` helper.
///
/// Install it via `prost_build::Config::service_generator`.
#[derive(Debug, Default, Clone)]
pub struct RPCServiceGenerator {}
14
15impl RPCServiceGenerator {
16 pub fn new() -> RPCServiceGenerator {
17 Default::default()
18 }
19
20 fn client_stream_request(&self) -> TokenStream {
21 quote!(ClientStreamRequest)
22 }
23
24 fn server_stream_response(&self) -> TokenStream {
25 quote!(ServerStreamResponse)
26 }
27
28 fn method_sig_tokens(&self, method: &Method, body: Option<TokenStream>) -> TokenStream {
29 let name = format_ident!("{}", method.name);
30 let input_type = format_ident!("{}", method.input_type);
31 let output_type = format_ident!("{}", method.output_type);
32
33 let input_type = if method.client_streaming {
34 let client_stream_request = self.client_stream_request();
35 quote!(#client_stream_request<#input_type>)
36 } else {
37 quote!(#input_type)
38 };
39
40 let output_type = if method.server_streaming {
41 let server_stream_response = self.server_stream_response();
42 quote!(#server_stream_response<#output_type>)
43 } else {
44 quote!(#output_type)
45 };
46
47 if let Some(body) = body {
48 quote! {
49 async fn #name(&self, request: #input_type)
50 -> #output_type {
51 #body
52 }
53 }
54 } else {
55 quote! {
56 async fn #name(&self, request: #input_type)
57 -> #output_type
58 }
59 }
60 }
61
62 fn method_sig_tokens_with_context(&self, method: &Method) -> TokenStream {
63 let name = format_ident!("{}", method.name);
64 let input_type = format_ident!("{}", method.input_type);
65 let output_type = format_ident!("{}", method.output_type);
66
67 let input_type = if method.client_streaming {
68 let client_stream_request = self.client_stream_request();
69 quote!(#client_stream_request<#input_type>)
70 } else {
71 quote!(#input_type)
72 };
73
74 let output_type = if method.server_streaming {
75 let server_stream_response = self.server_stream_response();
76 quote!(#server_stream_response<#output_type>)
77 } else {
78 quote!(#output_type)
79 };
80
81 quote! {
82 async fn #name(&self, request: #input_type, context: Arc<Context>)
83 -> #output_type
84 }
85 }
86
87 fn generate_stream_types(&self, buf: &mut String) {
88 buf.push('\n');
89 buf.push_str("use dcl_rpc::stream_protocol::Generator;");
90 buf.push('\n');
91 buf.push_str("pub type ServerStreamResponse<T> = Generator<T>;");
92 buf.push('\n');
93 buf.push_str("pub type ClientStreamRequest<T> = Generator<T>;");
94 buf.push('\n');
95 }
96
97 fn generate_client_trait(&self, service: &Service, buf: &mut String) {
98 buf.push('\n');
101 service.comments.append_with_indent(0, buf);
102
103 buf.push_str("#[async_trait::async_trait]\n");
104 buf.push_str(&format!(
105 "pub trait {}: Send + Sync + 'static {{",
106 service.name
107 ));
108 for method in service.methods.iter() {
109 buf.push('\n');
110 method.comments.append_with_indent(1, buf);
111 buf.push_str(&format!(" {};\n", self.method_sig_tokens(method, None)));
112 }
113 buf.push_str("}\n");
114 }
115
116 fn get_server_service_name(&self, service: &Service) -> String {
117 format!("Shared{}", service.name)
118 }
119
120 fn generate_server_trait(&self, service: &Service, buf: &mut String) {
121 buf.push_str("use std::sync::Arc;\n");
122 buf.push('\n');
125 service.comments.append_with_indent(0, buf);
126
127 buf.push_str("#[async_trait::async_trait]\n");
128 buf.push_str(&format!(
129 "pub trait {}<Context>: Send + Sync + 'static {{",
130 self.get_server_service_name(service)
131 ));
132 for method in service.methods.iter() {
133 buf.push('\n');
134 method.comments.append_with_indent(1, buf);
135 buf.push_str(&format!(
136 " {};\n",
137 self.method_sig_tokens_with_context(method)
138 ));
139 }
140 buf.push_str("}\n");
141 }
142
143 fn generate_client_service(&self, service: &Service, buf: &mut String) {
144 buf.push('\n');
145 buf.push_str("use dcl_rpc::client::{RpcClientModule, ServiceClient};");
148 buf.push_str(&format!("pub struct {}Client {{", service.name));
149 buf.push_str(&format!(" {},\n", "rpc_client_module: RpcClientModule"));
150 buf.push_str("}");
151
152 buf.push('\n');
153
154 buf.push_str(&format!(
155 "impl ServiceClient for {}Client {{
156 fn set_client_module(rpc_client_module: RpcClientModule) -> Self {{
157 Self {{ rpc_client_module }}
158 }}
159}}
160",
161 service.name
162 ));
163
164 buf.push_str("#[async_trait::async_trait]\n");
165 buf.push_str(&format!(
166 "impl {} for {}Client {{",
167 service.name, service.name
168 ));
169 for method in service.methods.iter() {
170 buf.push('\n');
171 method.comments.append_with_indent(1, buf);
172 let body = match (method.client_streaming, method.server_streaming) {
173 (false, false) => self.generate_unary_call(&method.proto_name),
174 (false, true) => self.generate_server_streams_procedure(&method.proto_name),
175 (true, false) => self.generate_client_streams_procedure(&method.proto_name),
176 (true, true) => self.generate_bidir_streams_procedure(&method.proto_name),
177 };
178 buf.push_str(&format!(
179 " {}\n",
180 self.method_sig_tokens(method, Some(body))
181 ));
182 }
183 buf.push_str("}\n");
184 }
185
186 fn generate_unary_call(&self, name: &str) -> TokenStream {
187 quote! {
188 self.rpc_client_module
189 .call_unary_procedure(#name, request)
190 .await
191 .unwrap()
192 }
193 }
194
195 fn generate_server_streams_procedure(&self, name: &str) -> TokenStream {
196 quote! {
197 self.rpc_client_module
198 .call_server_streams_procedure(#name, request)
199 .await
200 .unwrap()
201 }
202 }
203
204 fn generate_client_streams_procedure(&self, name: &str) -> TokenStream {
205 quote! {
206 self.rpc_client_module
207 .call_client_streams_procedure(#name, request)
208 .await
209 .unwrap()
210 }
211 }
212
213 fn generate_bidir_streams_procedure(&self, name: &str) -> TokenStream {
214 quote! {
215 self.rpc_client_module
216 .call_bidir_streams_procedure(#name, request)
217 .await
218 .unwrap()
219 }
220 }
221
222 fn generate_server_service(&self, service: &Service, buf: &mut String) {
223 buf.push_str("use dcl_rpc::server::RpcServerPort;\n");
224 buf.push_str("use dcl_rpc::service_module_definition::ServiceModuleDefinition;\n");
225 buf.push_str("use prost::Message;\n");
226
227 let name = format!("{}Registration", service.name);
228 buf.push('\n');
229 buf.push_str(&format!("pub struct {} {{}}\n", name));
230 buf.push('\n');
231
232 buf.push('\n');
233 buf.push_str(&format!("impl {} {{", name));
234 buf.push_str(&format!(" {}", self.generate_register_service(service)));
235 buf.push_str("}\n");
236 }
237
238 fn generate_register_service(&self, service: &Service) -> TokenStream {
239 let service_name = &service.name;
240 let name = self.get_server_service_name(service);
241 let trait_name: TokenStream = name.parse().unwrap();
242
243 let mut methods: Vec<TokenStream> = vec![];
244 for method in &service.methods {
245 methods.push(match (method.client_streaming, method.server_streaming) {
246 (false, false) => self.generate_add_unary_call(&method),
247 (false, true) => self.generate_add_server_streams_procedure(&method),
248 (true, false) => self.generate_add_client_streams_procedure(&method),
249 (true, true) => self.generate_add_bidir_streams_procedure(&method),
250 });
251 }
252 quote! {
253 pub fn register_service<
254 S: #trait_name<Context> + Send + Sync + 'static,
255 Context: Send + Sync + 'static
256 >(
257 port: &mut RpcServerPort<Context>,
258 service: S
259 ) {
260 let mut service_def = ServiceModuleDefinition::new();
261 let shareable_service = Arc::new(service);
263
264 #(#methods)*
265
266 port.register_module(#service_name.to_string(), service_def)
267 }
268 }
269 }
270
271 fn generate_add_unary_call(&self, method: &Method) -> TokenStream {
272 let method_name: TokenStream = method.name.parse().unwrap();
273 let proto_method_name = &method.proto_name;
274 let input_type: TokenStream = method.input_type.parse().unwrap();
275 quote! {
276 let service = Arc::clone(&shareable_service);
277 service_def.add_unary(#proto_method_name, move |request, context| {
278 let service = service.clone();
279 Box::pin(async move {
280 let response = service
281 .#method_name(#input_type::decode(request.as_slice()).unwrap(), context)
282 .await;
283 response.encode_to_vec()
284 })
285 });
286 }
287 }
288
289 fn generate_add_server_streams_procedure(&self, method: &Method) -> TokenStream {
290 let method_name: TokenStream = method.name.parse().unwrap();
291 let proto_method_name = &method.proto_name;
292 let input_type: TokenStream = method.input_type.parse().unwrap();
293 quote! {
294 let service = Arc::clone(&shareable_service);
295 service_def.add_server_streams(#proto_method_name, move |request, context| {
296 let service = service.clone();
297 Box::pin(async move {
298 let server_streams = service
299 .#method_name(#input_type::decode(request.as_slice()).unwrap(), context)
300 .await;
301 Generator::from_generator(server_streams, |item| item.encode_to_vec())
303 })
304 });
305 }
306 }
307
308 fn generate_add_client_streams_procedure(&self, method: &Method) -> TokenStream {
309 let method_name: TokenStream = method.name.parse().unwrap();
310 let proto_method_name = &method.proto_name;
311 let input_type: TokenStream = method.input_type.parse().unwrap();
312 quote! {
313 let service = Arc::clone(&shareable_service);
314 service_def.add_client_streams(#proto_method_name, move |request, context| {
315 let service = service.clone();
316 Box::pin(async move {
317 let generator = Generator::from_generator(request, |item| {
318 #input_type::decode(item.as_slice()).unwrap()
319 });
320
321 let response = service.#method_name(generator, context).await;
322 response.encode_to_vec()
323 })
324 });
325 }
326 }
327
328 fn generate_add_bidir_streams_procedure(&self, method: &Method) -> TokenStream {
329 let method_name: TokenStream = method.name.parse().unwrap();
330 let proto_method_name = &method.proto_name;
331 let input_type: TokenStream = method.input_type.parse().unwrap();
332 quote! {
333 let service = Arc::clone(&shareable_service);
334 service_def.add_bidir_streams(#proto_method_name, move |request, context| {
335 let service = service.clone();
336 Box::pin(async move {
337 let generator = Generator::from_generator(request, |item| {
338 #input_type::decode(item.as_slice()).unwrap()
339 });
340
341 let response = service.#method_name(generator, context).await;
342 Generator::from_generator(response, |item| item.encode_to_vec())
343 })
344 });
345 }
346 }
347}
348
349impl ServiceGenerator for RPCServiceGenerator {
350 fn generate(&mut self, service: Service, buf: &mut String) {
351 self.generate_stream_types(buf);
352 self.generate_client_trait(&service, buf);
353 self.generate_client_service(&service, buf);
354 self.generate_server_trait(&service, buf);
355 self.generate_server_service(&service, buf);
356 println!("{}", buf.to_string());
357 }
358
359 fn finalize(&mut self, _buf: &mut String) {}
360}