1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7 parse::{Parse, ParseStream},
8 parse_macro_input,
9 punctuated::Punctuated,
10 spanned::Spanned,
11 token::Comma,
12 Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type, Visibility,
13};
14
// ---- Helper attribute names recognised on enum variants ----

/// Per-variant attribute declaring the channel types: `#[rpc(tx = .., rx = ..)]`.
const RPC_ATTR_NAME: &str = "rpc";
/// Per-variant attribute asking for a request struct generated from the variant's fields.
const WRAP_ATTR_NAME: &str = "wrap";

// ---- Keys accepted inside `#[rpc(...)]` ----

/// Key naming the sending channel type.
const TX_ATTR: &str = "tx";
/// Key naming the receiving channel type.
const RX_ATTR: &str = "rx";

// ---- Channel types used when a key is omitted ----

/// Receiving channel type used when `rx` is not given.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
/// Sending channel type used when `tx` is not given.
const DEFAULT_TX_TYPE: &str = "::irpc::channel::none::NoSender";
27
28#[proc_macro_attribute]
30pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
31 let mut input = parse_macro_input!(item as DeriveInput);
32 let args = parse_macro_input!(attr as MacroArgs);
33
34 let enum_name = &input.ident;
35 let vis = &input.vis;
36
37 let data_enum = match &mut input.data {
38 Data::Enum(data_enum) => data_enum,
39 _ => {
40 return error_tokens(
41 input.span(),
42 "The rpc_requests macro can only be applied to enums",
43 )
44 }
45 };
46
47 let cfg_feature_rpc = match args.rpc_feature.as_ref() {
48 None => quote!(),
49 Some(feature) => quote!(#[cfg(feature = #feature)]),
50 };
51
52 let mut channel_impls = TokenStream2::new();
54 let mut types = HashSet::new();
56 let mut all_variants = Vec::new();
58 let mut variants_with_attr = Vec::new();
60 let mut wrapper_types = TokenStream2::new();
62
63 for variant in &mut data_enum.variants {
64 let rpc_attr = match VariantRpcArgs::from_attrs(&mut variant.attrs) {
65 Ok(args) => args,
66 Err(err) => return err.into_compile_error().into(),
67 };
68
69 let request_type = match rpc_attr.wrap {
70 None => match &mut variant.fields {
71 Fields::Unnamed(ref mut fields) if fields.unnamed.len() == 1 => {
72 fields.unnamed[0].ty.clone()
73 }
74 _ => return error_tokens(
75 variant.span(),
76 "Each variant must either have exactly one unnamed field, or use the `wrap` argument in the `rpc` attribute.",
77 ),
78 },
79 Some(WrapArgs { ident, derive, vis }) => {
80 let vis = vis.as_ref().unwrap_or(&input.vis).clone();
81 let ty = type_from_ident(&ident);
82 let struc = struct_from_variant_fields(ident, variant.fields.clone(), variant.attrs.clone(), vis);
83 wrapper_types.extend(quote! {
84 #[derive(::std::fmt::Debug, ::serde::Serialize, ::serde::Deserialize, #(#derive),* )]
85 #struc
86 });
87 variant.fields = single_unnamed_field(ty.clone());
88 ty
89 }
90 };
91
92 all_variants.push((variant.ident.clone(), request_type.clone()));
93
94 if !types.insert(request_type.to_token_stream().to_string()) {
95 return error_tokens(
96 variant.span(),
97 "Each variant must have a unique request type",
98 );
99 }
100
101 if let Some(args) = rpc_attr.rpc {
102 variants_with_attr.push((variant.ident.clone(), request_type.clone()));
103 channel_impls.extend(generate_channels_impl(args, enum_name, &request_type))
104 }
105 }
106
107 let protocol_enum_from_impls =
109 generate_protocol_enum_from_impls(enum_name, &variants_with_attr);
110
111 let type_aliases = if let Some(suffix) = args.alias_suffix {
113 generate_type_aliases(&all_variants, enum_name, &suffix)
115 } else {
116 quote! {}
117 };
118
119 let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() {
121 let message_variants = all_variants
122 .iter()
123 .map(|(variant_name, inner_type)| {
124 quote! {
125 #variant_name(::irpc::WithChannels<#inner_type, #enum_name>)
126 }
127 })
128 .collect::<Vec<_>>();
129
130 let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
132
133 let doc = format!("Message enum for [`{enum_name}`]");
135 let message_enum = quote! {
136 #[doc = #doc]
137 #[allow(missing_docs)]
138 #[derive(::std::fmt::Debug)]
139 #vis enum #message_enum_name {
140 #(#message_variants),*
141 }
142 };
143
144 let parent_span_impl = if !args.no_spans {
146 generate_parent_span_impl(message_enum_name, &variant_names)
147 } else {
148 quote! {}
149 };
150
151 let message_from_impls =
153 generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name);
154
155 let service_impl = quote! {
156 impl ::irpc::Service for #enum_name {
157 type Message = #message_enum_name;
158 }
159 };
160
161 let remote_service_impl = if !args.no_rpc {
162 let block =
163 generate_remote_service_impl(message_enum_name, enum_name, &variants_with_attr);
164 quote! {
165 #cfg_feature_rpc
166 #block
167 }
168 } else {
169 quote! {}
170 };
171
172 quote! {
173 #message_enum
174 #service_impl
175 #remote_service_impl
176 #parent_span_impl
177 #message_from_impls
178 }
179 } else {
180 quote! {}
181 };
182
183 let output = quote! {
185 #input
186
187 #wrapper_types
189
190 #channel_impls
192
193 #protocol_enum_from_impls
195
196 #type_aliases
198
199 #extended_enum_code
201 };
202
203 output.into()
204}
205
206fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
208 quote! {
209 impl #enum_name {
210 pub fn parent_span(&self) -> ::tracing::Span {
212 let span = match self {
213 #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
214 };
215 span.cloned().unwrap_or_else(|| ::tracing::Span::current())
216 }
217 }
218 }
219}
220
221fn generate_channels_impl(
222 args: RpcArgs,
223 service_name: &Ident,
224 request_type: &Type,
225) -> TokenStream2 {
226 let rx = args.rx.unwrap_or_else(|| {
227 syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
229 });
230 let tx = args.tx.unwrap_or_else(|| {
231 syn::parse_str::<Type>(DEFAULT_TX_TYPE).expect("Failed to parse default tx type")
233 });
234
235 quote! {
236 impl ::irpc::Channels<#service_name> for #request_type {
237 type Tx = #tx;
238 type Rx = #rx;
239 }
240 }
241}
242
243fn generate_protocol_enum_from_impls(
245 enum_name: &Ident,
246 variants_with_attr: &[(Ident, Type)],
247) -> TokenStream2 {
248 variants_with_attr
249 .iter()
250 .map(|(variant_name, inner_type)| {
251 quote! {
252 impl From<#inner_type> for #enum_name {
253 fn from(value: #inner_type) -> Self {
254 #enum_name::#variant_name(value)
255 }
256 }
257 }
258 })
259 .collect()
260}
261
262fn generate_message_enum_from_impls(
264 message_enum_name: &Ident,
265 variants_with_attr: &[(Ident, Type)],
266 service_name: &Ident,
267) -> TokenStream2 {
268 variants_with_attr
269 .iter()
270 .map(|(variant_name, inner_type)| {
271 quote! {
272 impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
273 fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
274 #message_enum_name::#variant_name(value)
275 }
276 }
277 }
278 })
279 .collect()
280}
281
282fn generate_remote_service_impl(
284 message_enum_name: &Ident,
285 proto_enum_name: &Ident,
286 variants_with_attr: &[(Ident, Type)],
287) -> TokenStream2 {
288 let variants = variants_with_attr
289 .iter()
290 .map(|(variant_name, _inner_type)| {
291 quote! {
292 #proto_enum_name::#variant_name(msg) => {
293 #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
294 }
295 }
296 });
297
298 quote! {
299 impl ::irpc::rpc::RemoteService for #proto_enum_name {
300 fn with_remote_channels(
301 self,
302 rx: ::irpc::rpc::quinn::RecvStream,
303 tx: ::irpc::rpc::quinn::SendStream
304 ) -> Self::Message {
305 match self {
306 #(#variants),*
307 }
308 }
309 }
310 }
311}
312
313fn generate_type_aliases(
315 variants: &[(Ident, Type)],
316 service_name: &Ident,
317 suffix: &str,
318) -> TokenStream2 {
319 variants
320 .iter()
321 .map(|(variant_name, inner_type)| {
322 let type_name = format!("{variant_name}{suffix}");
325 let type_ident = Ident::new(&type_name, variant_name.span());
326 quote! {
327 pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
329 }
330 })
331 .collect()
332}
333
334#[derive(Default)]
336struct MacroArgs {
337 message_enum_name: Option<Ident>,
338 alias_suffix: Option<String>,
339 rpc_feature: Option<String>,
340 no_rpc: bool,
341 no_spans: bool,
342}
343
344impl Parse for MacroArgs {
345 fn parse(input: ParseStream) -> syn::Result<Self> {
346 let mut this = Self::default();
347 loop {
348 let arg: Ident = input.parse()?;
349 match arg.to_string().as_str() {
350 "message" => {
351 input.parse::<Token![=]>()?;
352 let value: Ident = input.parse()?;
353 this.message_enum_name = Some(value);
354 }
355 "alias" => {
356 input.parse::<Token![=]>()?;
357 let value: LitStr = input.parse()?;
358 this.alias_suffix = Some(value.value());
359 }
360 "rpc_feature" => {
361 input.parse::<Token![=]>()?;
362 if this.no_rpc {
363 return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
364 }
365 let value: LitStr = input.parse()?;
366 this.rpc_feature = Some(value.value());
367 }
368 "no_rpc" => {
369 if this.rpc_feature.is_some() {
370 return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
371 }
372 this.no_rpc = true;
373 }
374 "no_spans" => {
375 this.no_spans = true;
376 }
377 _ => {
378 return syn_err(arg.span(), format!("Unknown parameter: {arg}"));
379 }
380 }
381
382 if input.peek(Token![,]) {
383 input.parse::<Token![,]>()?;
384 } else {
385 break;
386 }
387 }
388
389 Ok(this)
390 }
391}
392
393#[derive(Default)]
394struct VariantRpcArgs {
395 wrap: Option<WrapArgs>,
396 rpc: Option<RpcArgs>,
397}
398
399impl VariantRpcArgs {
400 fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
401 let mut this = Self::default();
402 let mut remaining_attrs = Vec::new();
403 for attr in attrs.drain(..) {
404 let ident = attr.path.get_ident().map(|ident| ident.to_string());
405 match ident.as_deref() {
406 Some(RPC_ATTR_NAME) => {
407 if this.rpc.is_some() {
408 syn_err(attr.span(), "Each variant can have only one rpc attribute")?;
409 }
410 this.rpc = Some(attr.parse_args()?);
411 }
412 Some(WRAP_ATTR_NAME) => {
413 if this.wrap.is_some() {
414 syn_err(attr.span(), "Each variant can have only one wrap attribute")?;
415 }
416 this.wrap = Some(attr.parse_args()?);
417 }
418 _ => remaining_attrs.push(attr),
419 }
420 }
421 *attrs = remaining_attrs;
422 Ok(this)
423 }
424}
425
426#[derive(Default)]
427struct RpcArgs {
428 rx: Option<Type>,
429 tx: Option<Type>,
430}
431
432impl Parse for RpcArgs {
434 fn parse(input: ParseStream) -> syn::Result<Self> {
435 let mut this = Self::default();
436 while !input.is_empty() {
437 let arg: Ident = input.parse()?;
438 let _: Token![=] = input.parse()?;
439 let value: Type = input.parse()?;
440 if arg == RX_ATTR {
441 this.rx = Some(value);
442 } else if arg == TX_ATTR {
443 this.tx = Some(value);
444 } else {
445 syn_err(arg.span(), "Unexpected argument in rpc attribute")?;
446 }
447 if !input.peek(Token![,]) {
448 break;
449 } else {
450 let _: Token![,] = input.parse()?;
451 }
452 }
453
454 Ok(this)
455 }
456}
457
458struct WrapArgs {
459 vis: Option<Visibility>,
460 ident: Ident,
461 derive: Vec<Type>,
462}
463
464impl Parse for WrapArgs {
465 fn parse(input: ParseStream) -> syn::Result<Self> {
466 let vis = match input.parse::<Visibility>()? {
467 Visibility::Inherited => None,
468 vis => Some(vis),
469 };
470 let ident: Ident = input.parse()?;
471 let mut this = Self {
472 ident,
473 derive: Default::default(),
474 vis,
475 };
476 while input.peek(Token![,]) {
477 let _: Token![,] = input.parse()?;
478 let arg: Ident = input.parse()?;
479 match arg.to_string().as_str() {
480 "derive" => {
481 let content;
482 syn::parenthesized!(content in input);
483 let types: Punctuated<Type, Comma> = content.parse_terminated(Type::parse)?;
484 this.derive = types.into_iter().collect();
485 }
486 _ => syn_err(arg.span(), "Unexpected argument in wrap argument")?,
487 }
488 }
489 if !input.is_empty() {
490 syn_err(input.span(), "Unexpected tokens in wrap argument")?;
491 }
492 Ok(this)
493 }
494}
495
496fn type_from_ident(ident: &Ident) -> Type {
497 Type::Path(syn::TypePath {
498 qself: None,
499 path: syn::Path {
500 leading_colon: None,
501 segments: Punctuated::from_iter([syn::PathSegment::from(ident.clone())]),
502 },
503 })
504}
505
506fn struct_from_variant_fields(
507 ident: Ident,
508 mut fields: Fields,
509 attrs: Vec<Attribute>,
510 vis: Visibility,
511) -> syn::ItemStruct {
512 set_fields_vis(&mut fields, &vis);
513 let span = ident.span();
514 syn::ItemStruct {
515 attrs,
516 vis,
517 struct_token: Token,
518 ident,
519 generics: Default::default(),
520 semi_token: match &fields {
521 Fields::Unit => Some(Token),
522 Fields::Unnamed(_) => Some(Token),
523 Fields::Named(_) => None,
524 },
525 fields,
526 }
527}
528
529fn single_unnamed_field(ty: Type) -> Fields {
530 let field = syn::Field {
531 attrs: vec![],
532 vis: Visibility::Inherited,
533 ident: None,
534 colon_token: None,
535 ty,
536 };
537 Fields::Unnamed(syn::FieldsUnnamed {
538 paren_token: syn::token::Paren(Span::call_site()),
539 unnamed: Punctuated::from_iter([field]),
540 })
541}
542
543fn set_fields_vis(fields: &mut Fields, vis: &Visibility) {
544 let inner = match fields {
545 Fields::Named(ref mut named) => named.named.iter_mut(),
546 Fields::Unnamed(ref mut unnamed) => unnamed.unnamed.iter_mut(),
547 Fields::Unit => return,
548 };
549 for field in inner {
550 field.vis = vis.clone();
551 }
552}
553
554fn error_tokens(span: Span, message: &str) -> TokenStream {
556 Error::new(span, message).to_compile_error().into()
557}
558
559fn syn_err<T>(span: Span, message: impl std::fmt::Display) -> syn::Result<T> {
560 Err(Error::new(span, message))
561}