irpc_derive/
lib.rs

1use std::collections::{BTreeMap, HashSet};
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    spanned::Spanned,
10    Data, DeriveInput, Fields, Ident, LitStr, Token, Type,
11};
12
13// Helper function for error reporting
14fn error_tokens(span: Span, message: &str) -> TokenStream {
15    syn::Error::new(span, message).to_compile_error().into()
16}
17
18/// The only attribute we care about
19const ATTR_NAME: &str = "rpc";
20/// the tx type name
21const TX_ATTR: &str = "tx";
22/// the rx type name
23const RX_ATTR: &str = "rx";
24/// Fully qualified path to the default rx type
25const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
26
27/// Generate parent span method for an enum
28fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
29    quote! {
30        impl #enum_name {
31            /// Get the parent span of the message
32            pub fn parent_span(&self) -> tracing::Span {
33                let span = match self {
34                    #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
35                };
36                span.cloned().unwrap_or_else(|| ::tracing::Span::current())
37            }
38        }
39    }
40}
41
42fn generate_channels_impl(
43    mut args: NamedTypeArgs,
44    service_name: &Ident,
45    request_type: &Type,
46    attr_span: Span,
47) -> syn::Result<TokenStream2> {
48    // Try to get rx, default to NoReceiver if not present
49    // Use unwrap_or_else for a cleaner default
50    let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| {
51        // We can safely unwrap here because this is a known valid type
52        syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
53    });
54    let tx = args.get(TX_ATTR, attr_span)?;
55
56    let res = quote! {
57        impl ::irpc::Channels<#service_name> for #request_type {
58            type Tx = #tx;
59            type Rx = #rx;
60        }
61    };
62
63    args.check_empty(attr_span)?;
64    Ok(res)
65}
66
67/// Generates From implementations for cases with rpc attributes
68fn generate_case_from_impls(
69    enum_name: &Ident,
70    variants_with_attr: &[(Ident, Type)],
71) -> TokenStream2 {
72    let mut impls = quote! {};
73
74    // Generate From implementations for each case that has an rpc attribute
75    for (variant_name, inner_type) in variants_with_attr {
76        let impl_tokens = quote! {
77            impl From<#inner_type> for #enum_name {
78                fn from(value: #inner_type) -> Self {
79                    #enum_name::#variant_name(value)
80                }
81            }
82        };
83
84        impls = quote! {
85            #impls
86            #impl_tokens
87        };
88    }
89
90    impls
91}
92
93/// Generate From implementations for message enum variants
94fn generate_message_enum_from_impls(
95    message_enum_name: &Ident,
96    variants_with_attr: &[(Ident, Type)],
97    service_name: &Ident,
98) -> TokenStream2 {
99    let mut impls = quote! {};
100
101    // Generate From<WithChannels<T, Service>> implementations for each case with an rpc attribute
102    for (variant_name, inner_type) in variants_with_attr {
103        let impl_tokens = quote! {
104            impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
105                fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
106                    #message_enum_name::#variant_name(value)
107                }
108            }
109        };
110
111        impls = quote! {
112            #impls
113            #impl_tokens
114        };
115    }
116
117    impls
118}
119
120/// Generate type aliases for WithChannels<T, Service>
121fn generate_type_aliases(
122    variants: &[(Ident, Type)],
123    service_name: &Ident,
124    suffix: &str,
125) -> TokenStream2 {
126    let mut aliases = quote! {};
127
128    for (variant_name, inner_type) in variants {
129        // Create a type name using the variant name + suffix
130        // For example: Sum + "Msg" = SumMsg
131        let type_name = format!("{}{}", variant_name, suffix);
132        let type_ident = Ident::new(&type_name, variant_name.span());
133
134        let alias = quote! {
135            /// Type alias for WithChannels<#inner_type, #service_name>
136            pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
137        };
138
139        aliases = quote! {
140            #aliases
141            #alias
142        };
143    }
144
145    aliases
146}
147
148/// Processes an RPC request enum and generates channel implementations.
149///
150/// This macro takes a protocol enum where each variant represents a different RPC request type
151/// and generates the necessary channel implementations for each request.
152///
153/// # Macro Arguments
154///
155/// * First positional argument (required): The service type that will handle these requests
156/// * `message` (optional): Generate an extended enum wrapping each type in `WithChannels<T, Service>`
157/// * `alias` (optional): Generate type aliases with the given suffix for each `WithChannels<T, Service>`
158///
159/// # Variant Attributes
160///
161/// Individual enum variants can be annotated with the `#[rpc(...)]` attribute to specify channel types:
162///
163/// * `#[rpc(tx=SomeType)]`: Specify the transmitter/sender channel type (required)
164/// * `#[rpc(tx=SomeType, rx=OtherType)]`: Also specify a receiver channel type (optional)
165///
166/// If `rx` is not specified, it defaults to `NoReceiver`.
167///
168/// # Examples
169///
170/// Basic usage:
171/// ```
172/// #[rpc_requests(ComputeService)]
173/// enum ComputeProtocol {
174///     #[rpc(tx=oneshot::Sender<u128>)]
175///     Sqr(Sqr),
176///     #[rpc(tx=oneshot::Sender<i64>)]
177///     Sum(Sum),
178/// }
179/// ```
180///
181/// With a message enum:
182/// ```
183/// #[rpc_requests(ComputeService, message = ComputeMessage)]
184/// enum ComputeProtocol {
185///     #[rpc(tx=oneshot::Sender<u128>)]
186///     Sqr(Sqr),
187///     #[rpc(tx=oneshot::Sender<i64>)]
188///     Sum(Sum),
189/// }
190/// ```
191///
192/// With type aliases:
193/// ```
194/// #[rpc_requests(ComputeService, alias = "Msg")]
195/// enum ComputeProtocol {
196///     #[rpc(tx=oneshot::Sender<u128>)]
197///     Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeService>
198///     #[rpc(tx=oneshot::Sender<i64>)]
199///     Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeService>
200/// }
201/// ```
202#[proc_macro_attribute]
203pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
204    let mut input = parse_macro_input!(item as DeriveInput);
205    let args = parse_macro_input!(attr as MacroArgs);
206
207    let service_name = args.service_name;
208    let message_enum_name = args.message_enum_name;
209    let alias_suffix = args.alias_suffix;
210
211    let enum_name = &input.ident;
212    let input_span = input.span();
213
214    let data_enum = match &mut input.data {
215        Data::Enum(data_enum) => data_enum,
216        _ => return error_tokens(input.span(), "RpcRequests can only be applied to enums"),
217    };
218
219    // Collect trait implementations
220    let mut channel_impls = Vec::new();
221    // Types to check for uniqueness
222    let mut types = HashSet::new();
223    // All variant names and types
224    let mut all_variants = Vec::new();
225    // Variants with rpc attributes (for From implementations)
226    let mut variants_with_attr = Vec::new();
227
228    for variant in &mut data_enum.variants {
229        // Check field structure for every variant
230        let request_type = match &variant.fields {
231            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0].ty,
232            _ => {
233                return error_tokens(
234                    variant.span(),
235                    "Each variant must have exactly one unnamed field",
236                )
237            }
238        };
239        all_variants.push((variant.ident.clone(), request_type.clone()));
240
241        if !types.insert(request_type.to_token_stream().to_string()) {
242            return error_tokens(input_span, "Each variant must have a unique request type");
243        }
244
245        // Find and remove the rpc attribute
246        let mut rpc_attr = None;
247        let mut multiple_rpc_attrs = false;
248
249        variant.attrs.retain(|attr| {
250            if attr.path.is_ident(ATTR_NAME) {
251                if rpc_attr.is_some() {
252                    multiple_rpc_attrs = true;
253                    true // Keep this duplicate attribute
254                } else {
255                    rpc_attr = Some(attr.clone());
256                    false // Remove this attribute
257                }
258            } else {
259                true // Keep other attributes
260            }
261        });
262
263        // Check for multiple rpc attributes
264        if multiple_rpc_attrs {
265            return error_tokens(
266                variant.span(),
267                "Each variant can only have one rpc attribute",
268            );
269        }
270
271        // Process variants with rpc attributes
272        if let Some(attr) = rpc_attr {
273            variants_with_attr.push((variant.ident.clone(), request_type.clone()));
274
275            let args = match attr.parse_args::<NamedTypeArgs>() {
276                Ok(info) => info,
277                Err(e) => return e.to_compile_error().into(),
278            };
279
280            match generate_channels_impl(args, &service_name, request_type, attr.span()) {
281                Ok(impls) => channel_impls.push(impls),
282                Err(e) => return e.to_compile_error().into(),
283            }
284        }
285    }
286
287    // Generate From implementations for the original enum (only for variants with rpc attributes)
288    let original_from_impls = generate_case_from_impls(enum_name, &variants_with_attr);
289
290    // Generate type aliases if requested
291    let type_aliases = if let Some(suffix) = alias_suffix {
292        // Use all variants for type aliases, not just those with rpc attributes
293        generate_type_aliases(&all_variants, &service_name, &suffix)
294    } else {
295        quote! {}
296    };
297
298    // Generate the extended message enum if requested
299    let extended_enum_code = if let Some(message_enum_name) = message_enum_name {
300        let message_variants = all_variants
301            .iter()
302            .map(|(variant_name, inner_type)| {
303                quote! {
304                    #variant_name(::irpc::WithChannels<#inner_type, #service_name>)
305                }
306            })
307            .collect::<Vec<_>>();
308
309        // Extract variant names for the parent_span implementation
310        let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
311
312        // Create the message enum definition
313        let message_enum = quote! {
314            #[derive(Debug)]
315            pub enum #message_enum_name {
316                #(#message_variants),*
317            }
318        };
319
320        // Generate parent_span method
321        let parent_span_impl = generate_parent_span_impl(&message_enum_name, &variant_names);
322
323        // Generate From implementations for the message enum (only for variants with rpc attributes)
324        let message_from_impls = generate_message_enum_from_impls(
325            &message_enum_name,
326            &variants_with_attr,
327            &service_name,
328        );
329
330        quote! {
331            #message_enum
332            #parent_span_impl
333            #message_from_impls
334        }
335    } else {
336        // If no message_enum_name is provided, don't generate the extended enum
337        quote! {}
338    };
339
340    // Combine everything
341    let output = quote! {
342        #input
343
344        // Channel implementations
345        #(#channel_impls)*
346
347        // From implementations for the original enum
348        #original_from_impls
349
350        // Type aliases for WithChannels
351        #type_aliases
352
353        // Extended enum and its implementations
354        #extended_enum_code
355    };
356
357    output.into()
358}
359
360// Parse arguments for the macro
361struct MacroArgs {
362    service_name: Ident,
363    message_enum_name: Option<Ident>,
364    alias_suffix: Option<String>,
365}
366
367impl Parse for MacroArgs {
368    fn parse(input: ParseStream) -> syn::Result<Self> {
369        // First argument must be the service name (positional)
370        let service_name: Ident = input.parse()?;
371
372        // Initialize optional parameters
373        let mut message_enum_name = None;
374        let mut alias_suffix = None;
375
376        // Parse any additional named parameters
377        while input.peek(Token![,]) {
378            input.parse::<Token![,]>()?;
379            let param_name: Ident = input.parse()?;
380            input.parse::<Token![=]>()?;
381
382            match param_name.to_string().as_str() {
383                "message" => {
384                    message_enum_name = Some(input.parse()?);
385                }
386                "alias" => {
387                    let lit: LitStr = input.parse()?;
388                    alias_suffix = Some(lit.value());
389                }
390                _ => {
391                    return Err(syn::Error::new(
392                        param_name.span(),
393                        format!("Unknown parameter: {}", param_name),
394                    ));
395                }
396            }
397        }
398
399        Ok(MacroArgs {
400            service_name,
401            message_enum_name,
402            alias_suffix,
403        })
404    }
405}
406
407struct NamedTypeArgs {
408    types: BTreeMap<String, Type>,
409}
410
411impl NamedTypeArgs {
412    /// Get and remove a type from the map, failing if it doesn't exist
413    fn get(&mut self, key: &str, span: Span) -> syn::Result<Type> {
414        self.types
415            .remove(key)
416            .ok_or_else(|| syn::Error::new(span, format!("rpc requires a {key} type")))
417    }
418
419    /// Fail if there are any unknown arguments remaining
420    fn check_empty(&self, span: Span) -> syn::Result<()> {
421        if self.types.is_empty() {
422            Ok(())
423        } else {
424            Err(syn::Error::new(
425                span,
426                format!(
427                    "Unknown arguments provided: {:?}",
428                    self.types.keys().collect::<Vec<_>>()
429                ),
430            ))
431        }
432    }
433}
434
435/// Parse the rpc args as a comma separated list of name=type pairs
436impl Parse for NamedTypeArgs {
437    fn parse(input: ParseStream) -> syn::Result<Self> {
438        let mut types = BTreeMap::new();
439
440        loop {
441            if input.is_empty() {
442                break;
443            }
444
445            let key: Ident = input.parse()?;
446            let _: Token![=] = input.parse()?;
447            let value: Type = input.parse()?;
448
449            types.insert(key.to_string(), value);
450
451            if !input.peek(Token![,]) {
452                break;
453            }
454            let _: Token![,] = input.parse()?;
455        }
456
457        Ok(NamedTypeArgs { types })
458    }
459}