enum_group_macros_impl/
lib.rs

1//! enum-group-macros-impl - Procedural macros for enum grouping
2//!
3//! This is the proc-macro companion crate for `enum-group-macros`.
4//! You should depend on `enum-group-macros` instead of this crate directly.
5//!
6//! See the `enum-group-macros` crate for documentation.
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{format_ident, quote};
11use syn::parse::{Parse, ParseStream};
12use syn::{braced, parse_macro_input, Attribute, Ident, Token, Type, Visibility};
13
14// =============================================================================
15// Custom Syntax Parser
16// =============================================================================
17
18/// Parsed representation of a single variant within a group
19#[derive(Debug)]
20struct ParsedVariant {
21  attrs: Vec<Attribute>,
22  name: Ident,
23  ty: Type,
24}
25
26/// Parsed representation of a group (e.g., `SupportMessage { ... }`)
27#[derive(Debug)]
28struct ParsedGroup {
29  name: Ident,
30  variants: Vec<ParsedVariant>,
31}
32
33/// Parsed input for `define_enum_group!`
34#[derive(Debug)]
35struct EnumGroupInput {
36  attrs: Vec<Attribute>,
37  vis: Visibility,
38  name: Ident,
39  groups: Vec<ParsedGroup>,
40}
41
42impl Parse for ParsedVariant {
43  fn parse(input: ParseStream) -> syn::Result<Self> {
44    let attrs = input.call(Attribute::parse_outer)?;
45    let name: Ident = input.parse()?;
46
47    // Parse (Type)
48    let content;
49    syn::parenthesized!(content in input);
50    let ty: Type = content.parse()?;
51
52    Ok(ParsedVariant { attrs, name, ty })
53  }
54}
55
56impl Parse for ParsedGroup {
57  fn parse(input: ParseStream) -> syn::Result<Self> {
58    let name: Ident = input.parse()?;
59
60    let content;
61    braced!(content in input);
62
63    let mut variants = Vec::new();
64    while !content.is_empty() {
65      variants.push(content.parse::<ParsedVariant>()?);
66      // Optional trailing comma
67      if content.peek(Token![,]) {
68        content.parse::<Token![,]>()?;
69      }
70    }
71
72    Ok(ParsedGroup { name, variants })
73  }
74}
75
76impl Parse for EnumGroupInput {
77  fn parse(input: ParseStream) -> syn::Result<Self> {
78    // Parse outer attributes (like #[derive(...)])
79    let attrs = input.call(Attribute::parse_outer)?;
80
81    // Parse visibility and enum keyword
82    let vis: Visibility = input.parse()?;
83    input.parse::<Token![enum]>()?;
84    let name: Ident = input.parse()?;
85
86    // Parse the groups inside braces
87    let content;
88    braced!(content in input);
89
90    let mut groups = Vec::new();
91    while !content.is_empty() {
92      groups.push(content.parse::<ParsedGroup>()?);
93      // Handle optional comma between groups
94      if content.peek(Token![,]) {
95        content.parse::<Token![,]>()?;
96      }
97    }
98
99    Ok(EnumGroupInput { attrs, vis, name, groups })
100  }
101}
102
103// =============================================================================
104// Code Generator
105// =============================================================================
106
107fn generate_enum_group(input: EnumGroupInput) -> TokenStream2 {
108  let EnumGroupInput { attrs, vis, name: wire_name, groups } = input;
109
110  let group_enum_name = format_ident!("{}Group", wire_name);
111
112  // Collect all variants for the flat wire enum
113  let mut all_variants = Vec::new();
114  let mut group_enum_variants = Vec::new();
115  let mut into_group_arms = Vec::new();
116
117  // Generate group enums and collect info
118  let group_enums: Vec<TokenStream2> = groups
119    .iter()
120    .map(|group| {
121      let group_name = &group.name;
122
123      // Variants for this group enum
124      let variants: Vec<TokenStream2> = group
125        .variants
126        .iter()
127        .map(|v| {
128          let v_attrs = &v.attrs;
129          let v_name = &v.name;
130          let v_ty = &v.ty;
131          quote! {
132              #(#v_attrs)*
133              #v_name(#v_ty)
134          }
135        })
136        .collect();
137
138      // Add to all_variants for wire enum
139      for v in &group.variants {
140        let v_attrs = &v.attrs;
141        let v_name = &v.name;
142        let v_ty = &v.ty;
143        all_variants.push(quote! {
144            #(#v_attrs)*
145            #v_name(#v_ty)
146        });
147
148        // Generate into_group arm
149        into_group_arms.push(quote! {
150            Self::#v_name(v) => #group_enum_name::#group_name(#group_name::#v_name(v))
151        });
152      }
153
154      // Add to group enum variants
155      group_enum_variants.push(quote! {
156          #group_name(#group_name)
157      });
158
159      // Generate the group enum
160      quote! {
161          #(#attrs)*
162          #vis enum #group_name {
163              #(#variants),*
164          }
165      }
166    })
167    .collect();
168
169  // Generate the flat wire enum
170  let wire_enum = quote! {
171      #(#attrs)*
172      #vis enum #wire_name {
173          #(#all_variants),*
174      }
175  };
176
177  // Generate the group dispatch enum
178  let group_dispatch_enum = quote! {
179      #[derive(Debug, Clone)]
180      #vis enum #group_enum_name {
181          #(#group_enum_variants),*
182      }
183  };
184
185  // Generate an inherent into_group method (doesn't require trait import)
186  let inherent_impl = quote! {
187      impl #wire_name {
188          /// Convert this enum into its grouped representation.
189          #vis fn into_group(self) -> #group_enum_name {
190              match self {
191                  #(#into_group_arms),*
192              }
193          }
194      }
195  };
196
197  // Generate the EnumGroup trait impl (for users who want trait-based access)
198  let trait_impl = quote! {
199      impl ::enum_group_macros::EnumGroup for #wire_name {
200          type Group = #group_enum_name;
201
202          fn into_group(self) -> Self::Group {
203              // Delegate to inherent method
204              #wire_name::into_group(self)
205          }
206      }
207  };
208
209  // Combine all generated code
210  quote! {
211      #(#group_enums)*
212
213      #wire_enum
214
215      #group_dispatch_enum
216
217      #inherent_impl
218
219      #trait_impl
220  }
221}
222
223// =============================================================================
224// Procedural Macro Entry Point
225// =============================================================================
226
227/// Defines a flat wire enum and multiple specialized categorical enums.
228///
229/// This macro generates:
230/// 1. A set of categorical enums, each containing a subset of variants.
231/// 2. A single flat "wire" enum containing all variants from all groups.
232/// 3. A `Group` enum for dispatch between groups.
233/// 4. An `EnumGroup` trait implementation for converting wire → group.
234///
235/// # Example
236///
237/// ```ignore
238/// use enum_group_macros::define_enum_group;
239/// use serde::{Deserialize, Serialize};
240///
241/// define_enum_group! {
242///     #[derive(Debug, Clone, Serialize, Deserialize)]
243///     #[serde(tag = "type", content = "payload")]
244///     pub enum WireMsg {
245///         Protocol {
246///             A(MsgA),
247///             B(MsgB),
248///         },
249///         Business {
250///             C(MsgC),
251///         }
252///     }
253/// }
254/// ```
255///
256/// This generates:
257/// - `enum Protocol { A(MsgA), B(MsgB) }` - categorical enum
258/// - `enum Business { C(MsgC) }` - categorical enum
259/// - `enum WireMsg { A(MsgA), B(MsgB), C(MsgC) }` - flat wire enum
260/// - `enum WireMsgGroup { Protocol(Protocol), Business(Business) }` - dispatch enum
261/// - `impl EnumGroup for WireMsg` - conversion trait
262#[proc_macro]
263pub fn define_enum_group(input: TokenStream) -> TokenStream {
264  let input = parse_macro_input!(input as EnumGroupInput);
265  generate_enum_group(input).into()
266}
267
268// =============================================================================
269// match_enum_group! Macro
270// =============================================================================
271
272/// Matches on a grouped enum using ergonomic syntax.
273///
274/// This macro allows you to match on the group level without manually calling
275/// `into_group()` or importing the `Group` enum.
276///
277/// # Example
278///
279/// ```ignore
280/// use enum_group_macros::match_enum_group;
281///
282/// match_enum_group!(msg, BrokerToCosignerMessage, {
283///     SupportMessage(s) => {
284///         // s is SupportMessage enum
285///         match s {
286///             SupportMessage::ReportResponse(r) => { /* ... */ }
287///             SupportMessage::HeartbeatResponse(r) => { /* ... */ }
288///         }
289///     },
290///     BusinessMessage(b) => handle_business(b),
291/// })
292/// ```
293#[proc_macro]
294pub fn match_enum_group(input: TokenStream) -> TokenStream {
295  let input2: TokenStream2 = input.into();
296
297  let result = parse_match_enum_group(input2);
298
299  match result {
300    Ok(tokens) => tokens.into(),
301    Err(e) => e.to_compile_error().into(),
302  }
303}
304
305/// Parsed match arm for match_enum_group!
306struct MatchArm {
307  group_name: Ident,
308  binding: proc_macro2::TokenStream,
309  body: TokenStream2,
310}
311
312fn parse_match_enum_group(input: TokenStream2) -> syn::Result<TokenStream2> {
313  use syn::parse::Parser;
314
315  let parser = |input: ParseStream| -> syn::Result<(syn::Expr, Ident, Vec<MatchArm>)> {
316    // Parse value expression
317    let val: syn::Expr = input.parse()?;
318    input.parse::<Token![,]>()?;
319
320    // Parse wire enum type (just the identifier)
321    let wire: Ident = input.parse()?;
322    input.parse::<Token![,]>()?;
323
324    // Parse arms block
325    let content;
326    braced!(content in input);
327
328    let mut arms = Vec::new();
329    while !content.is_empty() {
330      // Parse: GroupName(binding) => body
331      let group_name: Ident = content.parse()?;
332
333      let paren_content;
334      syn::parenthesized!(paren_content in content);
335      // Parse the binding pattern (can be complex like `s` or `_`)
336      let binding: proc_macro2::TokenStream = paren_content.parse()?;
337
338      content.parse::<Token![=>]>()?;
339
340      // Parse the body (could be a block or expression)
341      let body: syn::Expr = content.parse()?;
342
343      arms.push(MatchArm { group_name, binding, body: quote! { #body } });
344
345      // Optional trailing comma
346      if content.peek(Token![,]) {
347        content.parse::<Token![,]>()?;
348      }
349    }
350
351    Ok((val, wire, arms))
352  };
353
354  let (val, wire, arms) = parser.parse2(input)?;
355
356  // Generate match arms using the local type alias
357  let match_arms: Vec<TokenStream2> = arms
358    .iter()
359    .map(|arm| {
360      let group_name = &arm.group_name;
361      let binding = &arm.binding;
362      let body = &arm.body;
363
364      quote! {
365          __EnumGroup__::#group_name(#binding) => #body
366      }
367    })
368    .collect();
369
370  // Generate expansion with local type alias
371  // This avoids requiring users to import the Group type
372  Ok(quote! {
373      {
374          #[allow(non_camel_case_types)]
375          type __EnumGroup__ = <#wire as ::enum_group_macros::EnumGroup>::Group;
376
377          match <#wire as ::enum_group_macros::EnumGroup>::into_group(#val) {
378              #(#match_arms),*
379          }
380      }
381  })
382}