1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use syn::parse::{Parse, ParseStream};
4
5extern crate proc_macro;
6extern crate proc_macro2;
7extern crate quote;
8extern crate syn;
9
10use quote::quote;
11use syn::spanned::Spanned;
12use syn::token::{Comma, Mut};
13use syn::{
14 braced, parenthesized, parse_macro_input, parse_quote, Attribute, FnArg, Ident, LitStr, Pat,
15 PatType, Result, ReturnType, Token, Type, Visibility,
16};
17
/// Folds one more error into an accumulator of type `Result<(), E>`.
///
/// While the accumulator is still `Ok`, it becomes `Err($e)`. Once it already holds
/// an error, the new one is appended through `Extend`. This lets a parser keep
/// collecting problems (e.g. several `syn::Error`s) and report them all with a
/// single `?` at the end.
macro_rules! extend_errors {
    ($errors: ident, $e: expr) => {
        if let Err(ref mut collected) = $errors {
            collected.extend($e);
        } else {
            $errors = Err($e);
        }
    };
}
29
30#[allow(dead_code)]
31#[derive(Debug)]
32struct ServiceMacroInput {
33 attrs: Vec<Attribute>,
34 vis: Visibility,
35 ident: Ident,
36 methods: Vec<Method>,
37}
38
39#[allow(dead_code)]
40#[derive(Debug)]
41struct Method {
42 attrs: Vec<Attribute>,
43 ident: Ident,
44 args: Vec<PatType>,
45 output: ReturnType,
46 receiver: bool,
47 receiver_mutability: Option<Mut>,
48}
49
50impl Parse for Method {
51 fn parse(input: ParseStream) -> syn::Result<Self> {
52 let attrs = input.call(Attribute::parse_outer)?;
53
54 input.parse::<Token![async]>()?;
55
56 input.parse::<Token![fn]>()?;
57 let ident = input.parse()?;
58 let content;
59 parenthesized!(content in input);
60 let mut args = Vec::new();
61 let mut errors = Ok(());
62 let mut receiver = false;
63 let mut receiver_mutability = None;
64 for arg in content.parse_terminated::<FnArg, Comma>(FnArg::parse)? {
65 match arg {
66 FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => {
67 args.push(captured);
68 }
69 FnArg::Typed(captured) => {
70 extend_errors!(
71 errors,
72 syn::Error::new(captured.pat.span(), "patterns aren't allowed in RPC args")
73 );
74 }
75 FnArg::Receiver(r) => {
76 receiver = true;
77 receiver_mutability = r.mutability.clone();
78 }
79 }
80 }
81 errors?;
82 let output = input.parse()?;
83 input.parse::<Token![;]>()?;
84
85 Ok(Self {
86 attrs,
87 ident,
88 args,
89 output,
90 receiver,
91 receiver_mutability,
92 })
93 }
94}
95
96impl Parse for ServiceMacroInput {
97 fn parse(input: ParseStream) -> Result<Self> {
98 let attrs = input.call(Attribute::parse_outer)?;
99 let vis: Visibility = input.parse()?;
100 input.parse::<Token![trait]>()?;
101 let ident: Ident = input.parse()?;
102
103 let mut methods = Vec::<Method>::new();
104
105 let content;
106 braced!(content in input);
107 while !content.is_empty() {
108 methods.push(content.parse()?);
109 }
110
111 Ok(Self {
112 attrs,
113 vis,
114 ident,
115 methods,
116 })
117 }
118}
119
120struct AttrsInput {
121 other_side: Ident,
122 variant: String,
123}
124
125impl Parse for AttrsInput {
126 fn parse(input: ParseStream) -> Result<Self> {
127 let mut other_side: Option<Ident> = None;
128 let mut variant: Option<String> = None;
129
130 while !input.is_empty() {
131 let ident: Ident = input.parse()?;
132 if ident == "other_side" {
133 input.parse::<Token![=]>()?;
134 other_side = Some(input.parse()?);
135 } else if ident == "variant" {
136 input.parse::<Token![=]>()?;
137 let lit: LitStr = input.parse()?;
138 variant = Some(lit.value());
139 } else {
140 return Err(syn::Error::new_spanned(ident, "Unexpected identifier"));
141 }
142
143 if !input.is_empty() {
145 input.parse::<Token![,]>()?;
146 }
147 }
148
149 Ok(AttrsInput {
150 other_side: other_side.unwrap(),
151 variant: variant.unwrap(),
152 })
153 }
154}
155
156#[proc_macro_attribute]
157pub fn service(attr: TokenStream, original_input: TokenStream) -> TokenStream {
158 let attrs = parse_macro_input!(attr as AttrsInput);
159
160 let derive = quote! {
161 #[derive(Debug, utils::serde::Serialize, utils::serde::Deserialize, Clone)]
162 };
163
164 let cloned = original_input.clone();
165 let input = parse_macro_input!(cloned as ServiceMacroInput);
166 let unit_type: &Type = &parse_quote!(());
167
168 let ident = input.ident;
169 let request_ident = Ident::new(&format!("{}Request", ident), ident.span());
170 let response_ident = Ident::new(&format!("{}Response", ident), ident.span());
171 let message_ident = Ident::new(&format!("{}Message", ident), ident.span());
172 let dummy_ident = Ident::new(&format!("Dummy{}Service", ident), ident.span());
173 let client_ident = Ident::new(&format!("{}Client", ident), ident.span());
174 let mut requests_variants = Vec::new();
175 let mut requests_structs = Vec::new();
176 let mut response_variants = Vec::new();
177 let mut client_methods = Vec::new();
178 let mut service_match_arms = Vec::new();
179
180 let snake_ident = ident.to_string().to_case(Case::Snake);
181 let variant = attrs.variant;
182 #[allow(unused)]
183 let create_named_variant_ident =
184 Ident::new(&format!("create_{snake_ident}_{variant}"), ident.span());
185 let other_side = attrs.other_side;
187 let other_side_client_ident = Ident::new(
188 &format!("{}Client", other_side.to_string()),
189 other_side.span(),
190 );
191
192 let server_or_client_fn = if &variant == "server" {
193 let server_ident = Ident::new(&format!("{}Server", ident), ident.span());
194 quote! {
195 pub struct #server_ident {
196 server: utils::Server
197 }
198
199 impl #server_ident {
200 pub async fn accept<T>(&self, service: T) -> Option<<T as utils::Service<#dummy_ident>>::Client>
201 where T: utils::Service<#dummy_ident> + Clone + 'static {
202 let (sender, receiver, close_receiver) = self.server.accept().await?;
203 let client = utils::Client::new(sender, receiver, service, Some(close_receiver));
204 Some(client)
205 }
206 }
207
208 pub async fn #create_named_variant_ident<A>(addr: A)
209 -> Result<#server_ident, std::io::Error>
210 where
211 A: utils::tokio::net::ToSocketAddrs
212 {
213 let server = utils::create_server(addr).await?;
214 Ok(#server_ident { server })
215 }
216 }
217 } else {
218 let other_side_snake = other_side.to_string().to_case(Case::Snake);
219 let connect_to_ident = Ident::new(&format!("connect_to_{other_side_snake}"), ident.span());
220 quote! {
221 pub async fn #connect_to_ident<A, T>(addr: A, service: T)
222 -> Result<<T as utils::Service<#dummy_ident>>::Client, std::io::Error>
223 where
224 A: utils::tokio::net::ToSocketAddrs,
225 T: utils::Service<#dummy_ident> + Clone + 'static,
226 {
227 let (sender, mut receiver) = utils::create_client(addr).await?;
228 let client = utils::Client::new(sender, receiver, service, None);
229 Ok(client)
230 }
231 }
232 };
233
234 let mut trait_methods = Vec::new();
235
236 for method in input.methods {
237 let receiver = quote! { &self };
238
239 let pascal = method.ident.to_string().to_case(Case::Pascal);
240 let method_ident = method.ident.clone();
241 let method_request_ident =
242 Ident::new(&format!("{}{}Request", ident, pascal), method.ident.span());
243 let request_variant = quote! {
244 #method_request_ident(#method_request_ident)
245 };
246 requests_variants.push(request_variant);
247
248 let method_response_ident =
249 Ident::new(&format!("{}{}Response", ident, pascal), method.ident.span());
250 let return_ty = match method.output {
251 ReturnType::Default => unit_type,
252 ReturnType::Type(_, ref ty) => ty,
253 };
254 response_variants.push(quote! {
255 #method_response_ident(#return_ty)
256 });
257
258 let mut args = Vec::new();
259 let mut arg_names: Vec<Ident> = Vec::new();
260 for arg in method.args.clone() {
261 let ident = match *arg.pat {
262 Pat::Ident(ident) => ident.ident,
263 _ => unreachable!(),
264 };
265 arg_names.push(ident.clone());
266 let ty = arg.ty;
267 args.push(quote! {
268 #ident: #ty
269 });
270 }
271 requests_structs.push(quote! {
272 #derive
273 pub struct #method_request_ident {
274 #(#args),*
275 }
276 });
277
278 let args = method.args;
279 let output_ty = match method.output {
280 ReturnType::Type(_, ref t) => t,
281 ReturnType::Default => unit_type,
282 };
283
284 client_methods.push(quote! {
285 pub async fn #method_ident(#receiver, #(#args),*) -> anyhow::Result<#output_ty> {
287 let response = self.client
288 .request::<#request_ident, #response_ident>(#request_ident::#method_request_ident(
289 #method_request_ident { #(#arg_names),* },
290 )).await?;
291
292 Ok(match response {
293 #response_ident::#method_response_ident(r) => r,
294 _ => unreachable!()
295 })
296 }
297 });
298
299 trait_methods.push(quote! {
305 fn #method_ident(#receiver, client: #other_side_client_ident, #(#args),*) -> impl std::future::Future<Output = #output_ty> + Send;
306 });
307
308 service_match_arms.push(quote! {
309 #request_ident::#method_request_ident(request) => #response_ident::#method_response_ident(self.#method_ident(client, #(request.#arg_names),*).await),
310 });
311 }
312
313 let mut result = quote! {
314 pub trait #ident {
315 #(#trait_methods)*
316 }
317 };
318
319 let impl_service = quote! {
320 impl<T> utils::Service<#dummy_ident> for T
321 where T: #ident + Send + Sync {
322 type Request = #request_ident;
323 type Response = #response_ident;
324 type Client = #other_side_client_ident;
325
326 fn handle_request(
327 &self,
328 client: Self::Client,
329 message: Self::Request,
330 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self::Response> + Send + '_>> {
331 Box::pin(async {
332 match message {
333 #(#service_match_arms)*
334 }
335 })
336 }
337 }
338 };
339
340 let get_close_receiver = quote! {
341 pub async fn get_close_receiver(&self) -> Option<tokio::sync::oneshot::Receiver<()>> {
342 self.client.get_close_receiver().await
343 }
344 };
345
346 let generated = quote! {
347 pub struct #dummy_ident;
348
349 #derive
350 pub enum #request_ident {
351 #(#requests_variants),*
352 }
353
354 #derive
355 pub enum #response_ident {
356 #(#response_variants),*
357 }
358
359 #(#requests_structs)*
360
361 #derive
362 pub enum #message_ident {
363 Request(#request_ident),
364 Response(#response_ident),
365 }
366
367 #[derive(Clone)]
368 pub struct #client_ident {
369 client: utils::Client
371 }
372
373 impl utils::ClientTrait for #client_ident {
374 fn new(client: utils::Client) -> Self {
375 Self {
376 client
377 }
378 }
379 }
380
381 impl #client_ident {
382 pub async fn wait(&self) {
383 self.client.wait().await;
384 }
385
386 #get_close_receiver
387
388 #(#client_methods)*
389 }
390
391 #impl_service
392
393 #server_or_client_fn
394 };
395 result.extend(generated);
396 TokenStream::from(result)
397}
398
/// Parser-level tests. `syn::parse_str` runs on proc_macro2's fallback
/// implementation, so the input types can be exercised without expanding the
/// macro.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_attribute_arguments() {
        let attrs: AttrsInput =
            syn::parse_str(r#"other_side = Peer, variant = "server""#).unwrap();
        assert_eq!(attrs.other_side, "Peer");
        assert_eq!(attrs.variant, "server");
    }

    #[test]
    fn accepts_trailing_comma_and_any_key_order() {
        let attrs: AttrsInput =
            syn::parse_str(r#"variant = "client", other_side = Peer,"#).unwrap();
        assert_eq!(attrs.other_side, "Peer");
        assert_eq!(attrs.variant, "client");
    }

    #[test]
    fn rejects_unknown_attribute_key() {
        assert!(syn::parse_str::<AttrsInput>("colour = Blue").is_err());
    }

    #[test]
    fn parses_service_trait() {
        let input: ServiceMacroInput = syn::parse_str(
            "pub trait Chat { async fn send(&self, text: String) -> u32; async fn ping(&mut self); }",
        )
        .unwrap();
        assert_eq!(input.ident, "Chat");
        assert_eq!(input.methods.len(), 2);

        let send = &input.methods[0];
        assert_eq!(send.ident, "send");
        assert_eq!(send.args.len(), 1);
        assert!(send.receiver);
        assert!(send.receiver_mutability.is_none());

        let ping = &input.methods[1];
        assert!(ping.args.is_empty());
        assert!(ping.receiver_mutability.is_some());
    }

    #[test]
    fn rejects_destructuring_arguments() {
        let result = syn::parse_str::<ServiceMacroInput>(
            "trait Chat { async fn send(&self, (a, b): (u8, u8)); }",
        );
        assert!(result.is_err());
    }
}