1use std::collections::{BTreeMap, HashSet};
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 spanned::Spanned,
10 Data, DeriveInput, Fields, Ident, LitStr, Token, Type,
11};
12
13fn error_tokens(span: Span, message: &str) -> TokenStream {
15 syn::Error::new(span, message).to_compile_error().into()
16}
17
/// Name of the per-variant attribute, i.e. `#[rpc(...)]`.
const ATTR_NAME: &str = "rpc";

/// Key inside `#[rpc(...)]` naming the `Tx` channel type; it is mandatory.
const TX_ATTR: &str = "tx";

/// Key inside `#[rpc(...)]` naming the `Rx` channel type; it may be omitted.
const RX_ATTR: &str = "rx";
/// Type path parsed and used for `Rx` whenever the `rx` key is omitted.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
26
27fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
29 quote! {
30 impl #enum_name {
31 pub fn parent_span(&self) -> tracing::Span {
33 let span = match self {
34 #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
35 };
36 span.cloned().unwrap_or_else(|| ::tracing::Span::current())
37 }
38 }
39 }
40}
41
42fn generate_channels_impl(
43 mut args: NamedTypeArgs,
44 service_name: &Ident,
45 request_type: &Type,
46 attr_span: Span,
47) -> syn::Result<TokenStream2> {
48 let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| {
51 syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
53 });
54 let tx = args.get(TX_ATTR, attr_span)?;
55
56 let res = quote! {
57 impl ::irpc::Channels<#service_name> for #request_type {
58 type Tx = #tx;
59 type Rx = #rx;
60 }
61 };
62
63 args.check_empty(attr_span)?;
64 Ok(res)
65}
66
67fn generate_case_from_impls(
69 enum_name: &Ident,
70 variants_with_attr: &[(Ident, Type)],
71) -> TokenStream2 {
72 let mut impls = quote! {};
73
74 for (variant_name, inner_type) in variants_with_attr {
76 let impl_tokens = quote! {
77 impl From<#inner_type> for #enum_name {
78 fn from(value: #inner_type) -> Self {
79 #enum_name::#variant_name(value)
80 }
81 }
82 };
83
84 impls = quote! {
85 #impls
86 #impl_tokens
87 };
88 }
89
90 impls
91}
92
93fn generate_message_enum_from_impls(
95 message_enum_name: &Ident,
96 variants_with_attr: &[(Ident, Type)],
97 service_name: &Ident,
98) -> TokenStream2 {
99 let mut impls = quote! {};
100
101 for (variant_name, inner_type) in variants_with_attr {
103 let impl_tokens = quote! {
104 impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
105 fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
106 #message_enum_name::#variant_name(value)
107 }
108 }
109 };
110
111 impls = quote! {
112 #impls
113 #impl_tokens
114 };
115 }
116
117 impls
118}
119
120fn generate_type_aliases(
122 variants: &[(Ident, Type)],
123 service_name: &Ident,
124 suffix: &str,
125) -> TokenStream2 {
126 let mut aliases = quote! {};
127
128 for (variant_name, inner_type) in variants {
129 let type_name = format!("{}{}", variant_name, suffix);
132 let type_ident = Ident::new(&type_name, variant_name.span());
133
134 let alias = quote! {
135 pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
137 };
138
139 aliases = quote! {
140 #aliases
141 #alias
142 };
143 }
144
145 aliases
146}
147
148#[proc_macro_attribute]
203pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
204 let mut input = parse_macro_input!(item as DeriveInput);
205 let args = parse_macro_input!(attr as MacroArgs);
206
207 let service_name = args.service_name;
208 let message_enum_name = args.message_enum_name;
209 let alias_suffix = args.alias_suffix;
210
211 let enum_name = &input.ident;
212 let input_span = input.span();
213
214 let data_enum = match &mut input.data {
215 Data::Enum(data_enum) => data_enum,
216 _ => return error_tokens(input.span(), "RpcRequests can only be applied to enums"),
217 };
218
219 let mut channel_impls = Vec::new();
221 let mut types = HashSet::new();
223 let mut all_variants = Vec::new();
225 let mut variants_with_attr = Vec::new();
227
228 for variant in &mut data_enum.variants {
229 let request_type = match &variant.fields {
231 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0].ty,
232 _ => {
233 return error_tokens(
234 variant.span(),
235 "Each variant must have exactly one unnamed field",
236 )
237 }
238 };
239 all_variants.push((variant.ident.clone(), request_type.clone()));
240
241 if !types.insert(request_type.to_token_stream().to_string()) {
242 return error_tokens(input_span, "Each variant must have a unique request type");
243 }
244
245 let mut rpc_attr = None;
247 let mut multiple_rpc_attrs = false;
248
249 variant.attrs.retain(|attr| {
250 if attr.path.is_ident(ATTR_NAME) {
251 if rpc_attr.is_some() {
252 multiple_rpc_attrs = true;
253 true } else {
255 rpc_attr = Some(attr.clone());
256 false }
258 } else {
259 true }
261 });
262
263 if multiple_rpc_attrs {
265 return error_tokens(
266 variant.span(),
267 "Each variant can only have one rpc attribute",
268 );
269 }
270
271 if let Some(attr) = rpc_attr {
273 variants_with_attr.push((variant.ident.clone(), request_type.clone()));
274
275 let args = match attr.parse_args::<NamedTypeArgs>() {
276 Ok(info) => info,
277 Err(e) => return e.to_compile_error().into(),
278 };
279
280 match generate_channels_impl(args, &service_name, request_type, attr.span()) {
281 Ok(impls) => channel_impls.push(impls),
282 Err(e) => return e.to_compile_error().into(),
283 }
284 }
285 }
286
287 let original_from_impls = generate_case_from_impls(enum_name, &variants_with_attr);
289
290 let type_aliases = if let Some(suffix) = alias_suffix {
292 generate_type_aliases(&all_variants, &service_name, &suffix)
294 } else {
295 quote! {}
296 };
297
298 let extended_enum_code = if let Some(message_enum_name) = message_enum_name {
300 let message_variants = all_variants
301 .iter()
302 .map(|(variant_name, inner_type)| {
303 quote! {
304 #variant_name(::irpc::WithChannels<#inner_type, #service_name>)
305 }
306 })
307 .collect::<Vec<_>>();
308
309 let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
311
312 let message_enum = quote! {
314 #[derive(Debug)]
315 pub enum #message_enum_name {
316 #(#message_variants),*
317 }
318 };
319
320 let parent_span_impl = generate_parent_span_impl(&message_enum_name, &variant_names);
322
323 let message_from_impls = generate_message_enum_from_impls(
325 &message_enum_name,
326 &variants_with_attr,
327 &service_name,
328 );
329
330 quote! {
331 #message_enum
332 #parent_span_impl
333 #message_from_impls
334 }
335 } else {
336 quote! {}
338 };
339
340 let output = quote! {
342 #input
343
344 #(#channel_impls)*
346
347 #original_from_impls
349
350 #type_aliases
352
353 #extended_enum_code
355 };
356
357 output.into()
358}
359
360struct MacroArgs {
362 service_name: Ident,
363 message_enum_name: Option<Ident>,
364 alias_suffix: Option<String>,
365}
366
367impl Parse for MacroArgs {
368 fn parse(input: ParseStream) -> syn::Result<Self> {
369 let service_name: Ident = input.parse()?;
371
372 let mut message_enum_name = None;
374 let mut alias_suffix = None;
375
376 while input.peek(Token![,]) {
378 input.parse::<Token![,]>()?;
379 let param_name: Ident = input.parse()?;
380 input.parse::<Token![=]>()?;
381
382 match param_name.to_string().as_str() {
383 "message" => {
384 message_enum_name = Some(input.parse()?);
385 }
386 "alias" => {
387 let lit: LitStr = input.parse()?;
388 alias_suffix = Some(lit.value());
389 }
390 _ => {
391 return Err(syn::Error::new(
392 param_name.span(),
393 format!("Unknown parameter: {}", param_name),
394 ));
395 }
396 }
397 }
398
399 Ok(MacroArgs {
400 service_name,
401 message_enum_name,
402 alias_suffix,
403 })
404 }
405}
406
407struct NamedTypeArgs {
408 types: BTreeMap<String, Type>,
409}
410
411impl NamedTypeArgs {
412 fn get(&mut self, key: &str, span: Span) -> syn::Result<Type> {
414 self.types
415 .remove(key)
416 .ok_or_else(|| syn::Error::new(span, format!("rpc requires a {key} type")))
417 }
418
419 fn check_empty(&self, span: Span) -> syn::Result<()> {
421 if self.types.is_empty() {
422 Ok(())
423 } else {
424 Err(syn::Error::new(
425 span,
426 format!(
427 "Unknown arguments provided: {:?}",
428 self.types.keys().collect::<Vec<_>>()
429 ),
430 ))
431 }
432 }
433}
434
435impl Parse for NamedTypeArgs {
437 fn parse(input: ParseStream) -> syn::Result<Self> {
438 let mut types = BTreeMap::new();
439
440 loop {
441 if input.is_empty() {
442 break;
443 }
444
445 let key: Ident = input.parse()?;
446 let _: Token![=] = input.parse()?;
447 let value: Type = input.parse()?;
448
449 types.insert(key.to_string(), value);
450
451 if !input.peek(Token![,]) {
452 break;
453 }
454 let _: Token![,] = input.parse()?;
455 }
456
457 Ok(NamedTypeArgs { types })
458 }
459}