irpc_derive/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    punctuated::Punctuated,
10    spanned::Spanned,
11    token::Comma,
12    Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type, Visibility,
13};
14
/// Name of the variant-level attribute that configures channel types, e.g. `#[rpc(tx = ..)]`.
const RPC_ATTR_NAME: &str = "rpc";
/// Name of the variant-level attribute that moves the variant's fields into a generated struct.
const WRAP_ATTR_NAME: &str = "wrap";

/// Key inside `#[rpc(...)]` that selects the sender channel type.
const TX_ATTR: &str = "tx";
/// Key inside `#[rpc(...)]` that selects the receiver channel type.
const RX_ATTR: &str = "rx";

/// Fully qualified path of the tx type used when `#[rpc(...)]` has no `tx` key.
const DEFAULT_TX_TYPE: &str = "::irpc::channel::none::NoSender";
/// Fully qualified path of the rx type used when `#[rpc(...)]` has no `rx` key.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
27
28// See `irpc::rpc_requests` for docs.
29#[proc_macro_attribute]
30pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
31    let mut input = parse_macro_input!(item as DeriveInput);
32    let args = parse_macro_input!(attr as MacroArgs);
33
34    let enum_name = &input.ident;
35    let vis = &input.vis;
36
37    let data_enum = match &mut input.data {
38        Data::Enum(data_enum) => data_enum,
39        _ => {
40            return error_tokens(
41                input.span(),
42                "The rpc_requests macro can only be applied to enums",
43            )
44        }
45    };
46
47    let cfg_feature_rpc = match args.rpc_feature.as_ref() {
48        None => quote!(),
49        Some(feature) => quote!(#[cfg(feature = #feature)]),
50    };
51
52    // Collect trait implementations
53    let mut channel_impls = TokenStream2::new();
54    // Types to check for uniqueness
55    let mut types = HashSet::new();
56    // All variant names and types
57    let mut all_variants = Vec::new();
58    // Variants with rpc attributes (for From implementations)
59    let mut variants_with_attr = Vec::new();
60    // Wrapper types (via wrap attribute)
61    let mut wrapper_types = TokenStream2::new();
62
63    for variant in &mut data_enum.variants {
64        let rpc_attr = match VariantRpcArgs::from_attrs(&mut variant.attrs) {
65            Ok(args) => args,
66            Err(err) => return err.into_compile_error().into(),
67        };
68
69        let request_type = match rpc_attr.wrap {
70            None => match &mut variant.fields {
71                Fields::Unnamed(ref mut fields) if fields.unnamed.len() == 1 => {
72                    fields.unnamed[0].ty.clone()
73                }
74                _ => return error_tokens(
75                    variant.span(),
76                    "Each variant must either have exactly one unnamed field, or use the `wrap` argument in the `rpc` attribute.",
77                ),
78            },
79            Some(WrapArgs { ident, derive, vis }) => {
80                let vis = vis.as_ref().unwrap_or(&input.vis).clone();
81                let ty = type_from_ident(&ident);
82                let struc = struct_from_variant_fields(ident, variant.fields.clone(), variant.attrs.clone(), vis);
83                wrapper_types.extend(quote! {
84                    #[derive(::std::fmt::Debug, ::serde::Serialize, ::serde::Deserialize, #(#derive),* )]
85                    #struc
86                });
87                variant.fields = single_unnamed_field(ty.clone());
88                ty
89            }
90        };
91
92        all_variants.push((variant.ident.clone(), request_type.clone()));
93
94        if !types.insert(request_type.to_token_stream().to_string()) {
95            return error_tokens(
96                variant.span(),
97                "Each variant must have a unique request type",
98            );
99        }
100
101        if let Some(args) = rpc_attr.rpc {
102            variants_with_attr.push((variant.ident.clone(), request_type.clone()));
103            channel_impls.extend(generate_channels_impl(args, enum_name, &request_type))
104        }
105    }
106
107    // Generate From implementations for the original enum (only for variants with rpc attributes)
108    let protocol_enum_from_impls =
109        generate_protocol_enum_from_impls(enum_name, &variants_with_attr);
110
111    // Generate type aliases if requested
112    let type_aliases = if let Some(suffix) = args.alias_suffix {
113        // Use all variants for type aliases, not just those with rpc attributes
114        generate_type_aliases(&all_variants, enum_name, &suffix)
115    } else {
116        quote! {}
117    };
118
119    // Generate the extended message enum if requested
120    let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() {
121        let message_variants = all_variants
122            .iter()
123            .map(|(variant_name, inner_type)| {
124                quote! {
125                    #variant_name(::irpc::WithChannels<#inner_type, #enum_name>)
126                }
127            })
128            .collect::<Vec<_>>();
129
130        // Extract variant names for the parent_span implementation
131        let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
132
133        // Create the message enum definition
134        let doc = format!("Message enum for [`{enum_name}`]");
135        let message_enum = quote! {
136            #[doc = #doc]
137            #[allow(missing_docs)]
138            #[derive(::std::fmt::Debug)]
139            #vis enum #message_enum_name {
140                #(#message_variants),*
141            }
142        };
143
144        // Generate parent_span method
145        let parent_span_impl = if !args.no_spans {
146            generate_parent_span_impl(message_enum_name, &variant_names)
147        } else {
148            quote! {}
149        };
150
151        // Generate From implementations for the message enum (only for variants with rpc attributes)
152        let message_from_impls =
153            generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name);
154
155        let service_impl = quote! {
156            impl ::irpc::Service for #enum_name {
157                type Message = #message_enum_name;
158            }
159        };
160
161        let remote_service_impl = if !args.no_rpc {
162            let block =
163                generate_remote_service_impl(message_enum_name, enum_name, &variants_with_attr);
164            quote! {
165                #cfg_feature_rpc
166                #block
167            }
168        } else {
169            quote! {}
170        };
171
172        quote! {
173            #message_enum
174            #service_impl
175            #remote_service_impl
176            #parent_span_impl
177            #message_from_impls
178        }
179    } else {
180        quote! {}
181    };
182
183    // Combine everything
184    let output = quote! {
185        #input
186
187        // Wrapper types
188        #wrapper_types
189
190        // Channel implementations
191        #channel_impls
192
193        // From implementations for the original enum
194        #protocol_enum_from_impls
195
196        // Type aliases for WithChannels
197        #type_aliases
198
199        // Extended enum and its implementations
200        #extended_enum_code
201    };
202
203    output.into()
204}
205
206/// Generate parent span method for an enum
207fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
208    quote! {
209        impl #enum_name {
210            /// Get the parent span of the message
211            pub fn parent_span(&self) -> ::tracing::Span {
212                let span = match self {
213                    #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
214                };
215                span.cloned().unwrap_or_else(|| ::tracing::Span::current())
216            }
217        }
218    }
219}
220
221fn generate_channels_impl(
222    args: RpcArgs,
223    service_name: &Ident,
224    request_type: &Type,
225) -> TokenStream2 {
226    let rx = args.rx.unwrap_or_else(|| {
227        // We can safely unwrap here because this is a known valid type
228        syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
229    });
230    let tx = args.tx.unwrap_or_else(|| {
231        // We can safely unwrap here because this is a known valid type
232        syn::parse_str::<Type>(DEFAULT_TX_TYPE).expect("Failed to parse default tx type")
233    });
234
235    quote! {
236        impl ::irpc::Channels<#service_name> for #request_type {
237            type Tx = #tx;
238            type Rx = #rx;
239        }
240    }
241}
242
243/// Generates `From` impls for protocol enum variants with an rpc attribute.
244fn generate_protocol_enum_from_impls(
245    enum_name: &Ident,
246    variants_with_attr: &[(Ident, Type)],
247) -> TokenStream2 {
248    variants_with_attr
249        .iter()
250        .map(|(variant_name, inner_type)| {
251            quote! {
252                impl From<#inner_type> for #enum_name {
253                    fn from(value: #inner_type) -> Self {
254                        #enum_name::#variant_name(value)
255                    }
256                }
257            }
258        })
259        .collect()
260}
261
262/// Generate `From<WithChannels<T, Service>>` impls for message enum variants.
263fn generate_message_enum_from_impls(
264    message_enum_name: &Ident,
265    variants_with_attr: &[(Ident, Type)],
266    service_name: &Ident,
267) -> TokenStream2 {
268    variants_with_attr
269        .iter()
270        .map(|(variant_name, inner_type)| {
271            quote! {
272                impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
273                    fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
274                        #message_enum_name::#variant_name(value)
275                    }
276                }
277            }
278        })
279        .collect()
280}
281
282/// Generate `RemoteService` impl for message enums.
283fn generate_remote_service_impl(
284    message_enum_name: &Ident,
285    proto_enum_name: &Ident,
286    variants_with_attr: &[(Ident, Type)],
287) -> TokenStream2 {
288    let variants = variants_with_attr
289        .iter()
290        .map(|(variant_name, _inner_type)| {
291            quote! {
292                #proto_enum_name::#variant_name(msg) => {
293                    #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
294                }
295            }
296        });
297
298    quote! {
299        impl ::irpc::rpc::RemoteService for #proto_enum_name {
300            fn with_remote_channels(
301                self,
302                rx: ::irpc::rpc::quinn::RecvStream,
303                tx: ::irpc::rpc::quinn::SendStream
304            ) -> Self::Message {
305                match self {
306                    #(#variants),*
307                }
308            }
309        }
310    }
311}
312
313/// Generate type aliases for `WithChannels<T, Service>`
314fn generate_type_aliases(
315    variants: &[(Ident, Type)],
316    service_name: &Ident,
317    suffix: &str,
318) -> TokenStream2 {
319    variants
320        .iter()
321        .map(|(variant_name, inner_type)| {
322            // Create a type name using the variant name + suffix
323            // For example: Sum + "Msg" = SumMsg
324            let type_name = format!("{variant_name}{suffix}");
325            let type_ident = Ident::new(&type_name, variant_name.span());
326            quote! {
327                /// Type alias for WithChannels<#inner_type, #service_name>
328                pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
329            }
330        })
331        .collect()
332}
333
334// Parse arguments for the macro
335#[derive(Default)]
336struct MacroArgs {
337    message_enum_name: Option<Ident>,
338    alias_suffix: Option<String>,
339    rpc_feature: Option<String>,
340    no_rpc: bool,
341    no_spans: bool,
342}
343
344impl Parse for MacroArgs {
345    fn parse(input: ParseStream) -> syn::Result<Self> {
346        let mut this = Self::default();
347        loop {
348            let arg: Ident = input.parse()?;
349            match arg.to_string().as_str() {
350                "message" => {
351                    input.parse::<Token![=]>()?;
352                    let value: Ident = input.parse()?;
353                    this.message_enum_name = Some(value);
354                }
355                "alias" => {
356                    input.parse::<Token![=]>()?;
357                    let value: LitStr = input.parse()?;
358                    this.alias_suffix = Some(value.value());
359                }
360                "rpc_feature" => {
361                    input.parse::<Token![=]>()?;
362                    if this.no_rpc {
363                        return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
364                    }
365                    let value: LitStr = input.parse()?;
366                    this.rpc_feature = Some(value.value());
367                }
368                "no_rpc" => {
369                    if this.rpc_feature.is_some() {
370                        return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
371                    }
372                    this.no_rpc = true;
373                }
374                "no_spans" => {
375                    this.no_spans = true;
376                }
377                _ => {
378                    return syn_err(arg.span(), format!("Unknown parameter: {arg}"));
379                }
380            }
381
382            if input.peek(Token![,]) {
383                input.parse::<Token![,]>()?;
384            } else {
385                break;
386            }
387        }
388
389        Ok(this)
390    }
391}
392
393#[derive(Default)]
394struct VariantRpcArgs {
395    wrap: Option<WrapArgs>,
396    rpc: Option<RpcArgs>,
397}
398
399impl VariantRpcArgs {
400    fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
401        let mut this = Self::default();
402        let mut remaining_attrs = Vec::new();
403        for attr in attrs.drain(..) {
404            let ident = attr.path.get_ident().map(|ident| ident.to_string());
405            match ident.as_deref() {
406                Some(RPC_ATTR_NAME) => {
407                    if this.rpc.is_some() {
408                        syn_err(attr.span(), "Each variant can have only one rpc attribute")?;
409                    }
410                    this.rpc = Some(attr.parse_args()?);
411                }
412                Some(WRAP_ATTR_NAME) => {
413                    if this.wrap.is_some() {
414                        syn_err(attr.span(), "Each variant can have only one wrap attribute")?;
415                    }
416                    this.wrap = Some(attr.parse_args()?);
417                }
418                _ => remaining_attrs.push(attr),
419            }
420        }
421        *attrs = remaining_attrs;
422        Ok(this)
423    }
424}
425
426#[derive(Default)]
427struct RpcArgs {
428    rx: Option<Type>,
429    tx: Option<Type>,
430}
431
432/// Parse the rpc args as a comma separated list of name=type pairs
433impl Parse for RpcArgs {
434    fn parse(input: ParseStream) -> syn::Result<Self> {
435        let mut this = Self::default();
436        while !input.is_empty() {
437            let arg: Ident = input.parse()?;
438            let _: Token![=] = input.parse()?;
439            let value: Type = input.parse()?;
440            if arg == RX_ATTR {
441                this.rx = Some(value);
442            } else if arg == TX_ATTR {
443                this.tx = Some(value);
444            } else {
445                syn_err(arg.span(), "Unexpected argument in rpc attribute")?;
446            }
447            if !input.peek(Token![,]) {
448                break;
449            } else {
450                let _: Token![,] = input.parse()?;
451            }
452        }
453
454        Ok(this)
455    }
456}
457
458struct WrapArgs {
459    vis: Option<Visibility>,
460    ident: Ident,
461    derive: Vec<Type>,
462}
463
464impl Parse for WrapArgs {
465    fn parse(input: ParseStream) -> syn::Result<Self> {
466        let vis = match input.parse::<Visibility>()? {
467            Visibility::Inherited => None,
468            vis => Some(vis),
469        };
470        let ident: Ident = input.parse()?;
471        let mut this = Self {
472            ident,
473            derive: Default::default(),
474            vis,
475        };
476        while input.peek(Token![,]) {
477            let _: Token![,] = input.parse()?;
478            let arg: Ident = input.parse()?;
479            match arg.to_string().as_str() {
480                "derive" => {
481                    let content;
482                    syn::parenthesized!(content in input);
483                    let types: Punctuated<Type, Comma> = content.parse_terminated(Type::parse)?;
484                    this.derive = types.into_iter().collect();
485                }
486                _ => syn_err(arg.span(), "Unexpected argument in wrap argument")?,
487            }
488        }
489        if !input.is_empty() {
490            syn_err(input.span(), "Unexpected tokens in wrap argument")?;
491        }
492        Ok(this)
493    }
494}
495
496fn type_from_ident(ident: &Ident) -> Type {
497    Type::Path(syn::TypePath {
498        qself: None,
499        path: syn::Path {
500            leading_colon: None,
501            segments: Punctuated::from_iter([syn::PathSegment::from(ident.clone())]),
502        },
503    })
504}
505
506fn struct_from_variant_fields(
507    ident: Ident,
508    mut fields: Fields,
509    attrs: Vec<Attribute>,
510    vis: Visibility,
511) -> syn::ItemStruct {
512    set_fields_vis(&mut fields, &vis);
513    let span = ident.span();
514    syn::ItemStruct {
515        attrs,
516        vis,
517        struct_token: Token![struct](span),
518        ident,
519        generics: Default::default(),
520        semi_token: match &fields {
521            Fields::Unit => Some(Token![;](span)),
522            Fields::Unnamed(_) => Some(Token![;](span)),
523            Fields::Named(_) => None,
524        },
525        fields,
526    }
527}
528
529fn single_unnamed_field(ty: Type) -> Fields {
530    let field = syn::Field {
531        attrs: vec![],
532        vis: Visibility::Inherited,
533        ident: None,
534        colon_token: None,
535        ty,
536    };
537    Fields::Unnamed(syn::FieldsUnnamed {
538        paren_token: syn::token::Paren(Span::call_site()),
539        unnamed: Punctuated::from_iter([field]),
540    })
541}
542
543fn set_fields_vis(fields: &mut Fields, vis: &Visibility) {
544    let inner = match fields {
545        Fields::Named(ref mut named) => named.named.iter_mut(),
546        Fields::Unnamed(ref mut unnamed) => unnamed.unnamed.iter_mut(),
547        Fields::Unit => return,
548    };
549    for field in inner {
550        field.vis = vis.clone();
551    }
552}
553
554// Helper function for error reporting
555fn error_tokens(span: Span, message: &str) -> TokenStream {
556    Error::new(span, message).to_compile_error().into()
557}
558
559fn syn_err<T>(span: Span, message: impl std::fmt::Display) -> syn::Result<T> {
560    Err(Error::new(span, message))
561}