irpc_derive/
lib.rs

1use std::collections::{BTreeMap, HashSet};
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    spanned::Spanned,
10    Data, DeriveInput, Fields, Ident, LitStr, Token, Type,
11};
12
13// Helper function for error reporting
14fn error_tokens(span: Span, message: &str) -> TokenStream {
15    syn::Error::new(span, message).to_compile_error().into()
16}
17
/// Name of the per-variant attribute (`#[rpc(...)]`) consumed by `rpc_requests`;
/// every other variant attribute is passed through untouched.
const ATTR_NAME: &str = "rpc";

/// Key inside `#[rpc(...)]` naming the `Tx` channel type (required).
const TX_ATTR: &str = "tx";

/// Key inside `#[rpc(...)]` naming the `Rx` channel type (optional).
const RX_ATTR: &str = "rx";

/// `Rx` type used when a variant's `#[rpc(...)]` omits `rx`, written as an absolute
/// (`::`-prefixed) path so it does not depend on what the user has in scope.
const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver";
26
27/// Generate parent span method for an enum
28fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> TokenStream2 {
29    quote! {
30        impl #enum_name {
31            /// Get the parent span of the message
32            pub fn parent_span(&self) -> tracing::Span {
33                let span = match self {
34                    #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),*
35                };
36                span.cloned().unwrap_or_else(|| ::tracing::Span::current())
37            }
38        }
39    }
40}
41
42fn generate_channels_impl(
43    mut args: NamedTypeArgs,
44    service_name: &Ident,
45    request_type: &Type,
46    attr_span: Span,
47) -> syn::Result<TokenStream2> {
48    // Try to get rx, default to NoReceiver if not present
49    // Use unwrap_or_else for a cleaner default
50    let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| {
51        // We can safely unwrap here because this is a known valid type
52        syn::parse_str::<Type>(DEFAULT_RX_TYPE).expect("Failed to parse default rx type")
53    });
54    let tx = args.get(TX_ATTR, attr_span)?;
55
56    let res = quote! {
57        impl ::irpc::Channels<#service_name> for #request_type {
58            type Tx = #tx;
59            type Rx = #rx;
60        }
61    };
62
63    args.check_empty(attr_span)?;
64    Ok(res)
65}
66
67/// Generates From implementations for cases with rpc attributes
68fn generate_case_from_impls(
69    enum_name: &Ident,
70    variants_with_attr: &[(Ident, Type)],
71) -> TokenStream2 {
72    let mut impls = quote! {};
73
74    // Generate From implementations for each case that has an rpc attribute
75    for (variant_name, inner_type) in variants_with_attr {
76        let impl_tokens = quote! {
77            impl From<#inner_type> for #enum_name {
78                fn from(value: #inner_type) -> Self {
79                    #enum_name::#variant_name(value)
80                }
81            }
82        };
83
84        impls = quote! {
85            #impls
86            #impl_tokens
87        };
88    }
89
90    impls
91}
92
93/// Generate From implementations for message enum variants
94fn generate_message_enum_from_impls(
95    message_enum_name: &Ident,
96    variants_with_attr: &[(Ident, Type)],
97    service_name: &Ident,
98) -> TokenStream2 {
99    let mut impls = quote! {};
100
101    // Generate From<WithChannels<T, Service>> implementations for each case with an rpc attribute
102    for (variant_name, inner_type) in variants_with_attr {
103        let impl_tokens = quote! {
104            impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name {
105                fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self {
106                    #message_enum_name::#variant_name(value)
107                }
108            }
109        };
110
111        impls = quote! {
112            #impls
113            #impl_tokens
114        };
115    }
116
117    impls
118}
119
120/// Generate type aliases for WithChannels<T, Service>
121fn generate_type_aliases(
122    variants: &[(Ident, Type)],
123    service_name: &Ident,
124    suffix: &str,
125) -> TokenStream2 {
126    let mut aliases = quote! {};
127
128    for (variant_name, inner_type) in variants {
129        // Create a type name using the variant name + suffix
130        // For example: Sum + "Msg" = SumMsg
131        let type_name = format!("{}{}", variant_name, suffix);
132        let type_ident = Ident::new(&type_name, variant_name.span());
133
134        let alias = quote! {
135            /// Type alias for WithChannels<#inner_type, #service_name>
136            pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>;
137        };
138
139        aliases = quote! {
140            #aliases
141            #alias
142        };
143    }
144
145    aliases
146}
147
148/// Processes an RPC request enum and generates channel implementations.
149///
150/// This macro takes a protocol enum where each variant represents a different RPC request type
151/// and generates the necessary channel implementations for each request.
152///
153/// # Macro Arguments
154///
155/// * First positional argument (required): The service type that will handle these requests
156/// * `message` (optional): Generate an extended enum wrapping each type in `WithChannels<T, Service>`
157/// * `alias` (optional): Generate type aliases with the given suffix for each `WithChannels<T, Service>`
158///
159/// # Variant Attributes
160///
161/// Individual enum variants can be annotated with the `#[rpc(...)]` attribute to specify channel types:
162///
163/// * `#[rpc(tx=SomeType)]`: Specify the transmitter/sender channel type (required)
164/// * `#[rpc(tx=SomeType, rx=OtherType)]`: Also specify a receiver channel type (optional)
165///
166/// If `rx` is not specified, it defaults to `NoReceiver`.
167///
168/// # Examples
169///
170/// Basic usage:
171/// ```
172/// #[rpc_requests(ComputeService)]
173/// enum ComputeProtocol {
174///     #[rpc(tx=oneshot::Sender<u128>)]
175///     Sqr(Sqr),
176///     #[rpc(tx=oneshot::Sender<i64>)]
177///     Sum(Sum),
178/// }
179/// ```
180///
181/// With a message enum:
182/// ```
183/// #[rpc_requests(ComputeService, message = ComputeMessage)]
184/// enum ComputeProtocol {
185///     #[rpc(tx=oneshot::Sender<u128>)]
186///     Sqr(Sqr),
187///     #[rpc(tx=oneshot::Sender<i64>)]
188///     Sum(Sum),
189/// }
190/// ```
191///
192/// With type aliases:
193/// ```
194/// #[rpc_requests(ComputeService, alias = "Msg")]
195/// enum ComputeProtocol {
196///     #[rpc(tx=oneshot::Sender<u128>)]
197///     Sqr(Sqr), // Generates type SqrMsg = WithChannels<Sqr, ComputeService>
198///     #[rpc(tx=oneshot::Sender<i64>)]
199///     Sum(Sum), // Generates type SumMsg = WithChannels<Sum, ComputeService>
200/// }
201/// ```
202#[proc_macro_attribute]
203pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
204    let mut input = parse_macro_input!(item as DeriveInput);
205    let args = parse_macro_input!(attr as MacroArgs);
206
207    let service_name = args.service_name;
208    let message_enum_name = args.message_enum_name;
209    let alias_suffix = args.alias_suffix;
210
211    let enum_name = &input.ident;
212    let input_span = input.span();
213
214    let data_enum = match &mut input.data {
215        Data::Enum(data_enum) => data_enum,
216        _ => return error_tokens(input.span(), "RpcRequests can only be applied to enums"),
217    };
218
219    // Collect trait implementations
220    let mut channel_impls = Vec::new();
221    // Types to check for uniqueness
222    let mut types = HashSet::new();
223    // All variant names and types
224    let mut all_variants = Vec::new();
225    // Variants with rpc attributes (for From implementations)
226    let mut variants_with_attr = Vec::new();
227
228    for variant in &mut data_enum.variants {
229        // Check field structure for every variant
230        let request_type = match &variant.fields {
231            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0].ty,
232            _ => {
233                return error_tokens(
234                    variant.span(),
235                    "Each variant must have exactly one unnamed field",
236                )
237            }
238        };
239        all_variants.push((variant.ident.clone(), request_type.clone()));
240
241        if !types.insert(request_type.to_token_stream().to_string()) {
242            return error_tokens(input_span, "Each variant must have a unique request type");
243        }
244
245        // Find and remove the rpc attribute
246        let mut rpc_attr = None;
247        let mut multiple_rpc_attrs = false;
248
249        variant.attrs.retain(|attr| {
250            if attr.path.is_ident(ATTR_NAME) {
251                if rpc_attr.is_some() {
252                    multiple_rpc_attrs = true;
253                    true // Keep this duplicate attribute
254                } else {
255                    rpc_attr = Some(attr.clone());
256                    false // Remove this attribute
257                }
258            } else {
259                true // Keep other attributes
260            }
261        });
262
263        // Check for multiple rpc attributes
264        if multiple_rpc_attrs {
265            return error_tokens(
266                variant.span(),
267                "Each variant can only have one rpc attribute",
268            );
269        }
270
271        // Process variants with rpc attributes
272        if let Some(attr) = rpc_attr {
273            variants_with_attr.push((variant.ident.clone(), request_type.clone()));
274
275            let args = match attr.parse_args::<NamedTypeArgs>() {
276                Ok(info) => info,
277                Err(e) => return e.to_compile_error().into(),
278            };
279
280            match generate_channels_impl(args, &service_name, request_type, attr.span()) {
281                Ok(impls) => channel_impls.push(impls),
282                Err(e) => return e.to_compile_error().into(),
283            }
284        }
285    }
286
287    // Generate From implementations for the original enum (only for variants with rpc attributes)
288    let original_from_impls = generate_case_from_impls(enum_name, &variants_with_attr);
289
290    // Generate type aliases if requested
291    let type_aliases = if let Some(suffix) = alias_suffix {
292        // Use all variants for type aliases, not just those with rpc attributes
293        generate_type_aliases(&all_variants, &service_name, &suffix)
294    } else {
295        quote! {}
296    };
297
298    // Generate the extended message enum if requested
299    let extended_enum_code = if let Some(message_enum_name) = message_enum_name {
300        let message_variants = all_variants
301            .iter()
302            .map(|(variant_name, inner_type)| {
303                quote! {
304                    #[allow(missing_docs)]
305                    #variant_name(::irpc::WithChannels<#inner_type, #service_name>)
306                }
307            })
308            .collect::<Vec<_>>();
309
310        // Extract variant names for the parent_span implementation
311        let variant_names: Vec<&Ident> = all_variants.iter().map(|(name, _)| name).collect();
312
313        // Create the message enum definition
314        let message_enum = quote! {
315            #[allow(missing_docs)]
316            #[derive(Debug)]
317            pub enum #message_enum_name {
318                #(#message_variants),*
319            }
320        };
321
322        // Generate parent_span method
323        let parent_span_impl = generate_parent_span_impl(&message_enum_name, &variant_names);
324
325        // Generate From implementations for the message enum (only for variants with rpc attributes)
326        let message_from_impls = generate_message_enum_from_impls(
327            &message_enum_name,
328            &variants_with_attr,
329            &service_name,
330        );
331
332        quote! {
333            #message_enum
334            #parent_span_impl
335            #message_from_impls
336        }
337    } else {
338        // If no message_enum_name is provided, don't generate the extended enum
339        quote! {}
340    };
341
342    // Combine everything
343    let output = quote! {
344        #input
345
346        // Channel implementations
347        #(#channel_impls)*
348
349        // From implementations for the original enum
350        #original_from_impls
351
352        // Type aliases for WithChannels
353        #type_aliases
354
355        // Extended enum and its implementations
356        #extended_enum_code
357    };
358
359    output.into()
360}
361
362// Parse arguments for the macro
363struct MacroArgs {
364    service_name: Ident,
365    message_enum_name: Option<Ident>,
366    alias_suffix: Option<String>,
367}
368
369impl Parse for MacroArgs {
370    fn parse(input: ParseStream) -> syn::Result<Self> {
371        // First argument must be the service name (positional)
372        let service_name: Ident = input.parse()?;
373
374        // Initialize optional parameters
375        let mut message_enum_name = None;
376        let mut alias_suffix = None;
377
378        // Parse any additional named parameters
379        while input.peek(Token![,]) {
380            input.parse::<Token![,]>()?;
381            let param_name: Ident = input.parse()?;
382            input.parse::<Token![=]>()?;
383
384            match param_name.to_string().as_str() {
385                "message" => {
386                    message_enum_name = Some(input.parse()?);
387                }
388                "alias" => {
389                    let lit: LitStr = input.parse()?;
390                    alias_suffix = Some(lit.value());
391                }
392                _ => {
393                    return Err(syn::Error::new(
394                        param_name.span(),
395                        format!("Unknown parameter: {}", param_name),
396                    ));
397                }
398            }
399        }
400
401        Ok(MacroArgs {
402            service_name,
403            message_enum_name,
404            alias_suffix,
405        })
406    }
407}
408
409struct NamedTypeArgs {
410    types: BTreeMap<String, Type>,
411}
412
413impl NamedTypeArgs {
414    /// Get and remove a type from the map, failing if it doesn't exist
415    fn get(&mut self, key: &str, span: Span) -> syn::Result<Type> {
416        self.types
417            .remove(key)
418            .ok_or_else(|| syn::Error::new(span, format!("rpc requires a {key} type")))
419    }
420
421    /// Fail if there are any unknown arguments remaining
422    fn check_empty(&self, span: Span) -> syn::Result<()> {
423        if self.types.is_empty() {
424            Ok(())
425        } else {
426            Err(syn::Error::new(
427                span,
428                format!(
429                    "Unknown arguments provided: {:?}",
430                    self.types.keys().collect::<Vec<_>>()
431                ),
432            ))
433        }
434    }
435}
436
437/// Parse the rpc args as a comma separated list of name=type pairs
438impl Parse for NamedTypeArgs {
439    fn parse(input: ParseStream) -> syn::Result<Self> {
440        let mut types = BTreeMap::new();
441
442        loop {
443            if input.is_empty() {
444                break;
445            }
446
447            let key: Ident = input.parse()?;
448            let _: Token![=] = input.parse()?;
449            let value: Type = input.parse()?;
450
451            types.insert(key.to_string(), value);
452
453            if !input.peek(Token![,]) {
454                break;
455            }
456            let _: Token![,] = input.parse()?;
457        }
458
459        Ok(NamedTypeArgs { types })
460    }
461}