Skip to main content

irpc_derive/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{ToTokens, quote};
6use syn::{
7    Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type, Visibility,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    spanned::Spanned,
12    token::Comma,
13};
14
/// Name of the `rpc` attribute that is recognized on enum variants.
const RPC_ATTR_NAME: &str = "rpc";
/// Name of the `wrap` attribute that turns a variant's fields into a generated struct.
const WRAP_ATTR_NAME: &str = "wrap";

/// Key inside `rpc(...)` that names the tx type.
const TX_ATTR: &str = "tx";
/// Fully qualified path of the tx type used when `tx` is not given.
const DEFAULT_TX_TYPE: &str = "::irpc::channel::none::NoSender";

/// Key inside `rpc(...)` that names the rx type.
const RX_ATTR: &str = "rx";
/// Fully qualified path of the rx type used when `rx` is not given.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
27
28// See `irpc::rpc_requests` for docs.
29#[proc_macro_attribute]
30pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
31    let mut input = parse_macro_input!(item as DeriveInput);
32    let args = parse_macro_input!(attr as MacroArgs);
33
34    let enum_name = &input.ident;
35    let vis = &input.vis;
36
37    let data_enum = match &mut input.data {
38        Data::Enum(data_enum) => data_enum,
39        _ => {
40            return error_tokens(
41                input.span(),
42                "The rpc_requests macro can only be applied to enums",
43            );
44        }
45    };
46
47    let cfg_feature_rpc = match args.rpc_feature.as_ref() {
48        None => quote!(),
49        Some(feature) => quote!(#[cfg(feature = #feature)]),
50    };
51
52    // Collect trait implementations
53    let mut channel_impls = TokenStream2::new();
54    // Types to check for uniqueness
55    let mut types = HashSet::new();
56    // All variant names and types
57    let mut all_variants = Vec::new();
58    // Variants with rpc attributes (for From implementations)
59    let mut variants_with_attr = Vec::new();
60    // Wrapper types (via wrap attribute)
61    let mut wrapper_types = TokenStream2::new();
62
63    for variant in &mut data_enum.variants {
64        let rpc_attr = match VariantRpcArgs::from_attrs(&mut variant.attrs) {
65            Ok(args) => args,
66            Err(err) => return err.into_compile_error().into(),
67        };
68
69        let request_type = match rpc_attr.wrap {
70            None => match &mut variant.fields {
71                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
72                    fields.unnamed[0].ty.clone()
73                }
74                _ => {
75                    return error_tokens(
76                        variant.span(),
77                        "Each variant must either have exactly one unnamed field, or use the `wrap` argument in the `rpc` attribute.",
78                    );
79                }
80            },
81            Some(WrapArgs { ident, derive, vis }) => {
82                let vis = vis.as_ref().unwrap_or(&input.vis).clone();
83                let ty = type_from_ident(&ident);
84                let struc = struct_from_variant_fields(
85                    ident,
86                    variant.fields.clone(),
87                    variant.attrs.clone(),
88                    vis,
89                );
90                wrapper_types.extend(quote! {
91                    #[derive(::std::fmt::Debug, ::serde::Serialize, ::serde::Deserialize, #(#derive),* )]
92                    #struc
93                });
94                variant.fields = single_unnamed_field(ty.clone());
95                ty
96            }
97        };
98
99        all_variants.push((variant.ident.clone(), request_type.clone()));
100
101        if !types.insert(request_type.to_token_stream().to_string()) {
102            return error_tokens(
103                variant.span(),
104                "Each variant must have a unique request type",
105            );
106        }
107
108        if let Some(args) = rpc_attr.rpc {
109            variants_with_attr.push((variant.ident.clone(), request_type.clone()));
110            channel_impls.extend(generate_channels_impl(args, enum_name, &request_type))
111        }
112    }
113
114    // Generate From implementations for the original enum (only for variants with rpc attributes)
115    let protocol_enum_from_impls =
116        generate_protocol_enum_from_impls(enum_name, &variants_with_attr);
117
118    // Generate type aliases if requested
119    let type_aliases = if let Some(suffix) = args.alias_suffix {
120        // Use all variants for type aliases, not just those with rpc attributes
121        generate_type_aliases(&all_variants, enum_name, &suffix)
122    } else {
123        quote! {}
124    };
125
126    // Generate the extended message enum if requested
127    let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() {
128        let message_variants = all_variants
129            .iter()
130            .map(|(variant_name, inner_type)| {
131                quote! {
132                    #variant_name(::irpc::WithChannels<#inner_type, #enum_name>)
133                }
134            })
135            .collect::<Vec<_>>();
136
137        // Extract variant names for the parent_span implementation
138        let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
139
140        // Create the message enum definition
141        let doc = format!("Message enum for [`{enum_name}`]");
142        let message_enum = quote! {
143            #[doc = #doc]
144            #[allow(missing_docs)]
145            #[derive(::std::fmt::Debug)]
146            #vis enum #message_enum_name {
147                #(#message_variants),*
148            }
149        };
150
151        // Generate parent_span method
152        let parent_span_impl = if !args.no_spans {
153            generate_parent_span_impl(message_enum_name, &variant_names)
154        } else {
155            quote! {}
156        };
157
158        // Generate From implementations for the message enum (only for variants with rpc attributes)
159        let message_from_impls =
160            generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name);
161
162        let span_propagation = args.span_propagation;
163        let service_impl = quote! {
164            impl ::irpc::Service for #enum_name {
165                type Message = #message_enum_name;
166                const SPAN_PROPAGATION: bool = #span_propagation;
167            }
168        };
169
170        let remote_service_impl = if !args.no_rpc {
171            let block = generate_remote_service_impl(
172                message_enum_name,
173                enum_name,
174                &variants_with_attr,
175                args.span_propagation,
176            );
177            quote! {
178                #cfg_feature_rpc
179                #block
180            }
181        } else {
182            quote! {}
183        };
184
185        quote! {
186            #message_enum
187            #service_impl
188            #remote_service_impl
189            #parent_span_impl
190            #message_from_impls
191        }
192    } else {
193        quote! {}
194    };
195
196    // Combine everything
197    let output = quote! {
198        #input
199
200        // Wrapper types
201        #wrapper_types
202
203        // Channel implementations
204        #channel_impls
205
206        // From implementations for the original enum
207        #protocol_enum_from_impls
208
209        // Type aliases for WithChannels
210        #type_aliases
211
212        // Extended enum and its implementations
213        #extended_enum_code
214    };
215
216    output.into()
217}
218
219/// Generate parent span method for an enum
220fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
221    quote! {
222        impl #enum_name {
223            /// Get the parent span of the message
224            pub fn parent_span(&self) -> ::tracing::Span {
225                let span = match self {
226                    #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
227                };
228                span.cloned().unwrap_or_else(|| ::tracing::Span::current())
229            }
230        }
231    }
232}
233
234fn generate_channels_impl(
235    args: RpcArgs,
236    service_name: &Ident,
237    request_type: &Type,
238) -> TokenStream2 {
239    let rx = args.rx.unwrap_or_else(|| {
240        // We can safely unwrap here because this is a known valid type
241        syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
242    });
243    let tx = args.tx.unwrap_or_else(|| {
244        // We can safely unwrap here because this is a known valid type
245        syn::parse_str::<Type>(DEFAULT_TX_TYPE).expect("Failed to parse default tx type")
246    });
247
248    quote! {
249        impl ::irpc::Channels<#service_name> for #request_type {
250            type Tx = #tx;
251            type Rx = #rx;
252        }
253    }
254}
255
256/// Generates `From` impls for protocol enum variants with an rpc attribute.
257fn generate_protocol_enum_from_impls(
258    enum_name: &Ident,
259    variants_with_attr: &[(Ident, Type)],
260) -> TokenStream2 {
261    variants_with_attr
262        .iter()
263        .map(|(variant_name, inner_type)| {
264            quote! {
265                impl From<#inner_type> for #enum_name {
266                    fn from(value: #inner_type) -> Self {
267                        #enum_name::#variant_name(value)
268                    }
269                }
270            }
271        })
272        .collect()
273}
274
275/// Generate `From<WithChannels<T, Service>>` impls for message enum variants.
276fn generate_message_enum_from_impls(
277    message_enum_name: &Ident,
278    variants_with_attr: &[(Ident, Type)],
279    service_name: &Ident,
280) -> TokenStream2 {
281    variants_with_attr
282        .iter()
283        .map(|(variant_name, inner_type)| {
284            quote! {
285                impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
286                    fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
287                        #message_enum_name::#variant_name(value)
288                    }
289                }
290            }
291        })
292        .collect()
293}
294
295/// Generate `RemoteService` impl for message enums.
296///
297/// When `span_propagation` is true, the generated code will create spans for each
298/// request and set their parent from the propagated remote context.
299fn generate_remote_service_impl(
300    message_enum_name: &Ident,
301    proto_enum_name: &Ident,
302    variants_with_attr: &[(Ident, Type)],
303    span_propagation: bool,
304) -> TokenStream2 {
305    let variants = variants_with_attr
306        .iter()
307        .map(|(variant_name, _inner_type)| {
308            let span_name = variant_name.to_string();
309
310            if span_propagation {
311                // When span_propagation is enabled, create spans and set parent from remote context
312                quote! {
313                    #proto_enum_name::#variant_name(msg) => {
314                        // Create a span for this specific RPC operation
315                        let span = ::tracing::info_span!(#span_name);
316                        // Set its parent to the propagated remote context if available
317                        ::irpc::span_propagation::set_span_parent_from_remote(&span);
318                        let _guard = span.enter();
319                        #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
320                    }
321                }
322            } else {
323                quote! {
324                    #proto_enum_name::#variant_name(msg) => {
325                        #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx)))
326                    }
327                }
328            }
329        });
330
331    quote! {
332        impl ::irpc::rpc::RemoteService for #proto_enum_name {
333            fn with_remote_channels(
334                self,
335                rx: ::irpc::rpc::noq::RecvStream,
336                tx: ::irpc::rpc::noq::SendStream
337            ) -> Self::Message {
338                match self {
339                    #(#variants),*
340                }
341            }
342        }
343    }
344}
345
346/// Generate type aliases for `WithChannels<T, Service>`
347fn generate_type_aliases(
348    variants: &[(Ident, Type)],
349    service_name: &Ident,
350    suffix: &str,
351) -> TokenStream2 {
352    variants
353        .iter()
354        .map(|(variant_name, inner_type)| {
355            // Create a type name using the variant name + suffix
356            // For example: Sum + "Msg" = SumMsg
357            let type_name = format!("{variant_name}{suffix}");
358            let type_ident = Ident::new(&type_name, variant_name.span());
359            quote! {
360                /// Type alias for WithChannels<#inner_type, #service_name>
361                pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
362            }
363        })
364        .collect()
365}
366
367// Parse arguments for the macro
368#[derive(Default)]
369struct MacroArgs {
370    message_enum_name: Option<Ident>,
371    alias_suffix: Option<String>,
372    rpc_feature: Option<String>,
373    no_rpc: bool,
374    no_spans: bool,
375    /// When true, includes span context in the wire format and enables span propagation.
376    span_propagation: bool,
377}
378
379impl Parse for MacroArgs {
380    fn parse(input: ParseStream) -> syn::Result<Self> {
381        let mut this = Self::default();
382        loop {
383            let arg: Ident = input.parse()?;
384            match arg.to_string().as_str() {
385                "message" => {
386                    input.parse::<Token![=]>()?;
387                    let value: Ident = input.parse()?;
388                    this.message_enum_name = Some(value);
389                }
390                "alias" => {
391                    input.parse::<Token![=]>()?;
392                    let value: LitStr = input.parse()?;
393                    this.alias_suffix = Some(value.value());
394                }
395                "rpc_feature" => {
396                    input.parse::<Token![=]>()?;
397                    if this.no_rpc {
398                        return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
399                    }
400                    let value: LitStr = input.parse()?;
401                    this.rpc_feature = Some(value.value());
402                }
403                "no_rpc" => {
404                    if this.rpc_feature.is_some() {
405                        return syn_err(arg.span(), "rpc_feature is incompatible with no_rpc");
406                    }
407                    this.no_rpc = true;
408                }
409                "no_spans" => {
410                    this.no_spans = true;
411                }
412                "span_propagation" => {
413                    this.span_propagation = true;
414                }
415                _ => {
416                    return syn_err(arg.span(), format!("Unknown parameter: {arg}"));
417                }
418            }
419
420            if input.peek(Token![,]) {
421                input.parse::<Token![,]>()?;
422            } else {
423                break;
424            }
425        }
426
427        Ok(this)
428    }
429}
430
431#[derive(Default)]
432struct VariantRpcArgs {
433    wrap: Option<WrapArgs>,
434    rpc: Option<RpcArgs>,
435}
436
437impl VariantRpcArgs {
438    fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
439        let mut this = Self::default();
440        let mut remaining_attrs = Vec::new();
441        for attr in attrs.drain(..) {
442            let ident = attr.path().get_ident().map(|ident| ident.to_string());
443            match ident.as_deref() {
444                Some(RPC_ATTR_NAME) => {
445                    if this.rpc.is_some() {
446                        syn_err(attr.span(), "Each variant can have only one rpc attribute")?;
447                    }
448                    this.rpc = Some(attr.parse_args()?);
449                }
450                Some(WRAP_ATTR_NAME) => {
451                    if this.wrap.is_some() {
452                        syn_err(attr.span(), "Each variant can have only one wrap attribute")?;
453                    }
454                    this.wrap = Some(attr.parse_args()?);
455                }
456                _ => remaining_attrs.push(attr),
457            }
458        }
459        *attrs = remaining_attrs;
460        Ok(this)
461    }
462}
463
464#[derive(Default)]
465struct RpcArgs {
466    rx: Option<Type>,
467    tx: Option<Type>,
468}
469
470/// Parse the rpc args as a comma separated list of name=type pairs
471impl Parse for RpcArgs {
472    fn parse(input: ParseStream) -> syn::Result<Self> {
473        let mut this = Self::default();
474        while !input.is_empty() {
475            let arg: Ident = input.parse()?;
476            let _: Token![=] = input.parse()?;
477            let value: Type = input.parse()?;
478            if arg == RX_ATTR {
479                this.rx = Some(value);
480            } else if arg == TX_ATTR {
481                this.tx = Some(value);
482            } else {
483                syn_err(arg.span(), "Unexpected argument in rpc attribute")?;
484            }
485            if !input.peek(Token![,]) {
486                break;
487            } else {
488                let _: Token![,] = input.parse()?;
489            }
490        }
491
492        Ok(this)
493    }
494}
495
496struct WrapArgs {
497    vis: Option<Visibility>,
498    ident: Ident,
499    derive: Vec<Type>,
500}
501
502impl Parse for WrapArgs {
503    fn parse(input: ParseStream) -> syn::Result<Self> {
504        let vis = match input.parse::<Visibility>()? {
505            Visibility::Inherited => None,
506            vis => Some(vis),
507        };
508        let ident: Ident = input.parse()?;
509        let mut this = Self {
510            ident,
511            derive: Default::default(),
512            vis,
513        };
514        while input.peek(Token![,]) {
515            let _: Token![,] = input.parse()?;
516            let arg: Ident = input.parse()?;
517            match arg.to_string().as_str() {
518                "derive" => {
519                    let content;
520                    syn::parenthesized!(content in input);
521                    let types: Punctuated<Type, Comma> = Punctuated::parse_terminated(&content)?;
522                    this.derive = types.into_iter().collect();
523                }
524                _ => syn_err(arg.span(), "Unexpected argument in wrap argument")?,
525            }
526        }
527        if !input.is_empty() {
528            syn_err(input.span(), "Unexpected tokens in wrap argument")?;
529        }
530        Ok(this)
531    }
532}
533
534fn type_from_ident(ident: &Ident) -> Type {
535    Type::Path(syn::TypePath {
536        qself: None,
537        path: syn::Path {
538            leading_colon: None,
539            segments: Punctuated::from_iter([syn::PathSegment::from(ident.clone())]),
540        },
541    })
542}
543
544fn struct_from_variant_fields(
545    ident: Ident,
546    mut fields: Fields,
547    attrs: Vec<Attribute>,
548    vis: Visibility,
549) -> syn::ItemStruct {
550    set_fields_vis(&mut fields, &vis);
551    let span = ident.span();
552    syn::ItemStruct {
553        attrs,
554        vis,
555        struct_token: Token![struct](span),
556        ident,
557        generics: Default::default(),
558        semi_token: match &fields {
559            Fields::Unit => Some(Token![;](span)),
560            Fields::Unnamed(_) => Some(Token![;](span)),
561            Fields::Named(_) => None,
562        },
563        fields,
564    }
565}
566
567fn single_unnamed_field(ty: Type) -> Fields {
568    let field = syn::Field {
569        attrs: vec![],
570        vis: Visibility::Inherited,
571        ident: None,
572        colon_token: None,
573        mutability: syn::FieldMutability::None,
574        ty,
575    };
576    Fields::Unnamed(syn::FieldsUnnamed {
577        paren_token: syn::token::Paren(Span::call_site()),
578        unnamed: Punctuated::from_iter([field]),
579    })
580}
581
582fn set_fields_vis(fields: &mut Fields, vis: &Visibility) {
583    let inner = match fields {
584        Fields::Named(named) => named.named.iter_mut(),
585        Fields::Unnamed(unnamed) => unnamed.unnamed.iter_mut(),
586        Fields::Unit => return,
587    };
588    for field in inner {
589        field.vis = vis.clone();
590    }
591}
592
593// Helper function for error reporting
594fn error_tokens(span: Span, message: &str) -> TokenStream {
595    Error::new(span, message).to_compile_error().into()
596}
597
598fn syn_err<T>(span: Span, message: impl std::fmt::Display) -> syn::Result<T> {
599    Err(Error::new(span, message))
600}