1use std::collections::{BTreeMap, HashSet};
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 spanned::Spanned,
10 Data, DeriveInput, Fields, Ident, LitStr, Token, Type,
11};
12
13fn error_tokens(span: Span, message: &str) -> TokenStream {
15 syn::Error::new(span, message).to_compile_error().into()
16}
17
/// Name of the per-variant attribute (`#[rpc(...)]`) that `rpc_requests` looks for and strips.
const ATTR_NAME: &str = "rpc";
/// Key inside `#[rpc(...)]` whose value becomes the `Tx` associated type; required.
const TX_ATTR: &str = "tx";
/// Key inside `#[rpc(...)]` whose value becomes the `Rx` associated type; optional.
const RX_ATTR: &str = "rx";
/// Type path used for `Rx` when the `rx` key is not given.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
26
27fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
29 quote! {
30 impl #enum_name {
31 pub fn parent_span(&self) -> tracing::Span {
33 let span = match self {
34 #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
35 };
36 span.cloned().unwrap_or_else(|| ::tracing::Span::current())
37 }
38 }
39 }
40}
41
42fn generate_channels_impl(
43 mut args: NamedTypeArgs,
44 service_name: &Ident,
45 request_type: &Type,
46 attr_span: Span,
47) -> syn::Result<TokenStream2> {
48 let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| {
51 syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
53 });
54 let tx = args.get(TX_ATTR, attr_span)?;
55
56 let res = quote! {
57 impl ::irpc::Channels<#service_name> for #request_type {
58 type Tx = #tx;
59 type Rx = #rx;
60 }
61 };
62
63 args.check_empty(attr_span)?;
64 Ok(res)
65}
66
67fn generate_case_from_impls(
69 enum_name: &Ident,
70 variants_with_attr: &[(Ident, Type)],
71) -> TokenStream2 {
72 let mut impls = quote! {};
73
74 for (variant_name, inner_type) in variants_with_attr {
76 let impl_tokens = quote! {
77 impl From<#inner_type> for #enum_name {
78 fn from(value: #inner_type) -> Self {
79 #enum_name::#variant_name(value)
80 }
81 }
82 };
83
84 impls = quote! {
85 #impls
86 #impl_tokens
87 };
88 }
89
90 impls
91}
92
93fn generate_message_enum_from_impls(
95 message_enum_name: &Ident,
96 variants_with_attr: &[(Ident, Type)],
97 service_name: &Ident,
98) -> TokenStream2 {
99 let mut impls = quote! {};
100
101 for (variant_name, inner_type) in variants_with_attr {
103 let impl_tokens = quote! {
104 impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
105 fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
106 #message_enum_name::#variant_name(value)
107 }
108 }
109 };
110
111 impls = quote! {
112 #impls
113 #impl_tokens
114 };
115 }
116
117 impls
118}
119
120fn generate_type_aliases(
122 variants: &[(Ident, Type)],
123 service_name: &Ident,
124 suffix: &str,
125) -> TokenStream2 {
126 let mut aliases = quote! {};
127
128 for (variant_name, inner_type) in variants {
129 let type_name = format!("{}{}", variant_name, suffix);
132 let type_ident = Ident::new(&type_name, variant_name.span());
133
134 let alias = quote! {
135 pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
137 };
138
139 aliases = quote! {
140 #aliases
141 #alias
142 };
143 }
144
145 aliases
146}
147
148#[proc_macro_attribute]
203pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
204 let mut input = parse_macro_input!(item as DeriveInput);
205 let args = parse_macro_input!(attr as MacroArgs);
206
207 let service_name = args.service_name;
208 let message_enum_name = args.message_enum_name;
209 let alias_suffix = args.alias_suffix;
210
211 let enum_name = &input.ident;
212 let input_span = input.span();
213
214 let data_enum = match &mut input.data {
215 Data::Enum(data_enum) => data_enum,
216 _ => return error_tokens(input.span(), "RpcRequests can only be applied to enums"),
217 };
218
219 let mut channel_impls = Vec::new();
221 let mut types = HashSet::new();
223 let mut all_variants = Vec::new();
225 let mut variants_with_attr = Vec::new();
227
228 for variant in &mut data_enum.variants {
229 let request_type = match &variant.fields {
231 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0].ty,
232 _ => {
233 return error_tokens(
234 variant.span(),
235 "Each variant must have exactly one unnamed field",
236 )
237 }
238 };
239 all_variants.push((variant.ident.clone(), request_type.clone()));
240
241 if !types.insert(request_type.to_token_stream().to_string()) {
242 return error_tokens(input_span, "Each variant must have a unique request type");
243 }
244
245 let mut rpc_attr = None;
247 let mut multiple_rpc_attrs = false;
248
249 variant.attrs.retain(|attr| {
250 if attr.path.is_ident(ATTR_NAME) {
251 if rpc_attr.is_some() {
252 multiple_rpc_attrs = true;
253 true } else {
255 rpc_attr = Some(attr.clone());
256 false }
258 } else {
259 true }
261 });
262
263 if multiple_rpc_attrs {
265 return error_tokens(
266 variant.span(),
267 "Each variant can only have one rpc attribute",
268 );
269 }
270
271 if let Some(attr) = rpc_attr {
273 variants_with_attr.push((variant.ident.clone(), request_type.clone()));
274
275 let args = match attr.parse_args::<NamedTypeArgs>() {
276 Ok(info) => info,
277 Err(e) => return e.to_compile_error().into(),
278 };
279
280 match generate_channels_impl(args, &service_name, request_type, attr.span()) {
281 Ok(impls) => channel_impls.push(impls),
282 Err(e) => return e.to_compile_error().into(),
283 }
284 }
285 }
286
287 let original_from_impls = generate_case_from_impls(enum_name, &variants_with_attr);
289
290 let type_aliases = if let Some(suffix) = alias_suffix {
292 generate_type_aliases(&all_variants, &service_name, &suffix)
294 } else {
295 quote! {}
296 };
297
298 let extended_enum_code = if let Some(message_enum_name) = message_enum_name {
300 let message_variants = all_variants
301 .iter()
302 .map(|(variant_name, inner_type)| {
303 quote! {
304 #[allow(missing_docs)]
305 #variant_name(::irpc::WithChannels<#inner_type, #service_name>)
306 }
307 })
308 .collect::<Vec<_>>();
309
310 let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
312
313 let message_enum = quote! {
315 #[allow(missing_docs)]
316 #[derive(Debug)]
317 pub enum #message_enum_name {
318 #(#message_variants),*
319 }
320 };
321
322 let parent_span_impl = generate_parent_span_impl(&message_enum_name, &variant_names);
324
325 let message_from_impls = generate_message_enum_from_impls(
327 &message_enum_name,
328 &variants_with_attr,
329 &service_name,
330 );
331
332 quote! {
333 #message_enum
334 #parent_span_impl
335 #message_from_impls
336 }
337 } else {
338 quote! {}
340 };
341
342 let output = quote! {
344 #input
345
346 #(#channel_impls)*
348
349 #original_from_impls
351
352 #type_aliases
354
355 #extended_enum_code
357 };
358
359 output.into()
360}
361
/// Arguments of the `rpc_requests` attribute itself.
struct MacroArgs {
    /// First positional argument: the service type used in the generated impls.
    service_name: Ident,
    /// Value of the optional `message = Ident` parameter: name of the message enum to generate.
    message_enum_name: Option<Ident>,
    /// Value of the optional `alias = "..."` parameter: suffix appended to variant names
    /// to form type alias names.
    alias_suffix: Option<String>,
}
368
369impl Parse for MacroArgs {
370 fn parse(input: ParseStream) -> syn::Result<Self> {
371 let service_name: Ident = input.parse()?;
373
374 let mut message_enum_name = None;
376 let mut alias_suffix = None;
377
378 while input.peek(Token![,]) {
380 input.parse::<Token![,]>()?;
381 let param_name: Ident = input.parse()?;
382 input.parse::<Token![=]>()?;
383
384 match param_name.to_string().as_str() {
385 "message" => {
386 message_enum_name = Some(input.parse()?);
387 }
388 "alias" => {
389 let lit: LitStr = input.parse()?;
390 alias_suffix = Some(lit.value());
391 }
392 _ => {
393 return Err(syn::Error::new(
394 param_name.span(),
395 format!("Unknown parameter: {}", param_name),
396 ));
397 }
398 }
399 }
400
401 Ok(MacroArgs {
402 service_name,
403 message_enum_name,
404 alias_suffix,
405 })
406 }
407}
408
/// Arguments of a per-variant `#[rpc(...)]` attribute, parsed as `key = Type` pairs.
struct NamedTypeArgs {
    /// Parsed types keyed by argument name; entries are removed as they are consumed.
    types: BTreeMap<String, Type>,
}
412
413impl NamedTypeArgs {
414 fn get(&mut self, key: &str, span: Span) -> syn::Result<Type> {
416 self.types
417 .remove(key)
418 .ok_or_else(|| syn::Error::new(span, format!("rpc requires a {key} type")))
419 }
420
421 fn check_empty(&self, span: Span) -> syn::Result<()> {
423 if self.types.is_empty() {
424 Ok(())
425 } else {
426 Err(syn::Error::new(
427 span,
428 format!(
429 "Unknown arguments provided: {:?}",
430 self.types.keys().collect::<Vec<_>>()
431 ),
432 ))
433 }
434 }
435}
436
437impl Parse for NamedTypeArgs {
439 fn parse(input: ParseStream) -> syn::Result<Self> {
440 let mut types = BTreeMap::new();
441
442 loop {
443 if input.is_empty() {
444 break;
445 }
446
447 let key: Ident = input.parse()?;
448 let _: Token![=] = input.parse()?;
449 let value: Type = input.parse()?;
450
451 types.insert(key.to_string(), value);
452
453 if !input.peek(Token![,]) {
454 break;
455 }
456 let _: Token![,] = input.parse()?;
457 }
458
459 Ok(NamedTypeArgs { types })
460 }
461}