Skip to main content

unit_enum/
lib.rs

1#![doc = include_str!("lib.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Error, Expr, Fields, Type, Variant};
6
7/// Derives the `UnitEnum` trait for an enum.
8///
9/// This macro can be used on enums with unit variants (no fields) and optionally one "other" variant
10/// that can hold arbitrary discriminant values.
11///
12/// # Attributes
13/// - `#[repr(type)]`: Optional for regular enums, defaults to i32. Required when using an "other" variant.
14/// - `#[unit_enum(other)]`: Marks a variant as the catch-all for undefined discriminant values.
15///   The type of this variant must match the repr type.
16///
17/// # Requirements
18/// - The enum must contain only unit variants, except for one optional "other" variant
19/// - The "other" variant, if present, must:
20///   - Be marked with `#[unit_enum(other)]`
21///   - Have exactly one unnamed field matching the repr type
22///   - Be the only variant with the "other" attribute
23///   - Have a matching `#[repr(type)]` attribute
24///
25/// # Examples
26///
27/// Basic usage with unit variants (repr is optional):
28/// ```rust
29/// # use unit_enum::UnitEnum;
30/// #[derive(UnitEnum)]
31/// enum Example {
32///     A,
33///     B = 10,
34///     C,
35/// }
36/// ```
37///
38/// Usage with explicit repr:
39/// ```rust
40/// # use unit_enum::UnitEnum;
41/// #[derive(UnitEnum)]
42/// #[repr(u16)]
43/// enum Color {
44///     Red = 10,
45///     Green,
46///     Blue = 45654,
47/// }
48/// ```
49///
50/// Usage with an "other" variant (repr required):
51/// ```rust
52/// # use unit_enum::UnitEnum;
53/// #[derive(UnitEnum)]
54/// #[repr(u16)]
55/// enum Status {
56///     Active = 1,
57///     Inactive = 2,
58///     #[unit_enum(other)]
59///     Unknown(u16),  // type must match repr
60/// }
61/// ```
62#[proc_macro_derive(UnitEnum, attributes(unit_enum))]
63pub fn unit_enum_derive(input: TokenStream) -> TokenStream {
64    let ast = parse_macro_input!(input as DeriveInput);
65
66    match validate_and_process(&ast) {
67        Ok((discriminant_type, unit_variants, other_variant)) => {
68            impl_unit_enum(&ast, &discriminant_type, &unit_variants, other_variant)
69        }
70        Err(e) => e.to_compile_error().into(),
71    }
72}
73
74struct ValidationResult<'a> {
75    unit_variants: Vec<&'a Variant>,
76    other_variant: Option<(&'a Variant, Type)>,
77}
78
79fn validate_and_process(ast: &DeriveInput) -> Result<(Type, Vec<&Variant>, Option<(&Variant, Type)>), Error> {
80    // Get discriminant type from #[repr] attribute
81    let discriminant_type = get_discriminant_type(ast)?;
82
83    let data_enum = match &ast.data {
84        Data::Enum(data_enum) => data_enum,
85        _ => return Err(Error::new_spanned(ast, "UnitEnum can only be derived for enums")),
86    };
87
88    let mut validation = ValidationResult {
89        unit_variants: Vec::new(),
90        other_variant: None,
91    };
92
93    // Validate each variant
94    for variant in &data_enum.variants {
95        match &variant.fields {
96            Fields::Unit => {
97                if has_unit_enum_attr(variant) {
98                    return Err(Error::new_spanned(variant,
99                                                  "Unit variants cannot have #[unit_enum] attributes"));
100                }
101                validation.unit_variants.push(variant);
102            }
103            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
104                if has_unit_enum_other_attr(variant) {
105                    if validation.other_variant.is_some() {
106                        return Err(Error::new_spanned(variant,
107                                                      "Multiple #[unit_enum(other)] variants found. Only one is allowed"));
108                    }
109                    validation.other_variant = Some((variant, fields.unnamed[0].ty.clone()));
110                } else {
111                    return Err(Error::new_spanned(variant,
112                                                  "Non-unit variant must be marked with #[unit_enum(other)] to be used as the catch-all variant"));
113                }
114            }
115            _ => return Err(Error::new_spanned(variant,
116                                               "Invalid variant. UnitEnum only supports unit variants and a single tuple variant marked with #[unit_enum(other)]")),
117        }
118    }
119
120    Ok((discriminant_type, validation.unit_variants, validation.other_variant))
121}
122
123fn get_discriminant_type(ast: &DeriveInput) -> Result<Type, Error> {
124    ast.attrs
125        .iter()
126        .find(|attr| attr.path().is_ident("repr"))
127        .map_or(Ok(syn::parse_quote!(i32)), |attr| {
128            attr.parse_args::<Type>()
129                .map_err(|_| Error::new_spanned(attr, "Invalid repr attribute"))
130        })
131}
132
133fn has_unit_enum_attr(variant: &Variant) -> bool {
134    variant.attrs.iter().any(|attr| attr.path().is_ident("unit_enum"))
135}
136
137fn has_unit_enum_other_attr(variant: &Variant) -> bool {
138    variant.attrs.iter().any(|attr| {
139        attr.path().is_ident("unit_enum")
140            && attr
141                .parse_nested_meta(|meta| {
142                    if meta.path.is_ident("other") {
143                        Ok(())
144                    } else {
145                        Err(meta.error("Invalid unit_enum attribute"))
146                    }
147                })
148                .is_ok()
149    })
150}
151
152fn compute_discriminants(variants: &[&Variant]) -> Vec<Expr> {
153    let mut discriminants = Vec::with_capacity(variants.len());
154    let mut last_discriminant: Option<Expr> = None;
155
156    for variant in variants {
157        let discriminant = variant
158            .discriminant
159            .as_ref()
160            .map(|(_, expr)| expr.clone())
161            .or_else(|| last_discriminant.clone().map(|expr| syn::parse_quote! { #expr + 1 }))
162            .unwrap_or_else(|| syn::parse_quote! { 0 });
163
164        discriminants.push(discriminant.clone());
165        last_discriminant = Some(discriminant);
166    }
167
168    discriminants
169}
170
171fn impl_unit_enum(
172    ast: &DeriveInput, discriminant_type: &Type, unit_variants: &[&Variant], other_variant: Option<(&Variant, Type)>,
173) -> TokenStream {
174    let name = &ast.ident;
175    let num_variants = unit_variants.len();
176    let discriminants = compute_discriminants(unit_variants);
177
178    let name_impl = generate_name_impl(name, unit_variants, &other_variant);
179    let ordinal_impl = generate_ordinal_impl(name, unit_variants, &other_variant, num_variants);
180    let from_ordinal_impl = generate_from_ordinal_impl(name, unit_variants);
181    let discriminant_impl =
182        generate_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
183    let from_discriminant_impl =
184        generate_from_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
185    let values_impl = generate_values_impl(name, unit_variants, &discriminants, &other_variant);
186
187    quote! {
188        impl #name {
189            #name_impl
190
191            #ordinal_impl
192
193            #from_ordinal_impl
194
195            #discriminant_impl
196
197            #from_discriminant_impl
198
199            /// Returns the total number of unit variants in the enum (excluding the "other" variant if present).
200            ///
201            /// # Examples
202            ///
203            /// ```ignore
204            /// # use unit_enum::UnitEnum;
205            /// #[derive(UnitEnum)]
206            /// enum Example {
207            ///     A,
208            ///     B,
209            ///     #[unit_enum(other)]
210            ///     Other(i32),
211            /// }
212            ///
213            /// assert_eq!(Example::len(), 2);
214            /// ```
215            pub const fn len() -> usize {
216                #num_variants
217            }
218
219            #values_impl
220        }
221    }
222    .into()
223}
224
225fn generate_name_impl(
226    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>,
227) -> proc_macro2::TokenStream {
228    let unit_match_arms = unit_variants.iter().map(|variant| {
229        let variant_name = &variant.ident;
230        quote! { #name::#variant_name => stringify!(#variant_name) }
231    });
232
233    let other_arm = other_variant.as_ref().map(|(variant, _)| {
234        let variant_name = &variant.ident;
235        quote! { #name::#variant_name(_) => stringify!(#variant_name) }
236    });
237
238    quote! {
239        /// Returns the name of the enum variant as a string.
240        ///
241        /// # Examples
242        ///
243        /// ```ignore
244        /// # use unit_enum::UnitEnum;
245        /// #[derive(UnitEnum)]
246        /// enum Example {
247        ///     A,
248        ///     B = 10,
249        ///     C,
250        /// }
251        ///
252        /// assert_eq!(Example::A.name(), "A");
253        /// assert_eq!(Example::B.name(), "B");
254        /// assert_eq!(Example::C.name(), "C");
255        /// ```
256        pub const fn name(&self) -> &str {
257            match self {
258                #(#unit_match_arms,)*
259                #other_arm
260            }
261        }
262    }
263}
264
265fn generate_ordinal_impl(
266    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, num_variants: usize,
267) -> proc_macro2::TokenStream {
268    let unit_match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
269        let variant_name = &variant.ident;
270        quote! { #name::#variant_name => #index }
271    });
272
273    let other_arm = other_variant.as_ref().map(|(variant, _)| {
274        let variant_name = &variant.ident;
275        quote! { #name::#variant_name(_) => #num_variants }
276    });
277
278    quote! {
279        /// Returns the zero-based ordinal of the enum variant.
280        ///
281        /// For enums with an "other" variant, it returns the position after all unit variants.
282        ///
283        /// # Examples
284        ///
285        /// ```ignore
286        /// # use unit_enum::UnitEnum;
287        /// #[derive(UnitEnum)]
288        /// enum Example {
289        ///     A,      // ordinal: 0
290        ///     B = 10, // ordinal: 1
291        ///     C,      // ordinal: 2
292        /// }
293        ///
294        /// assert_eq!(Example::A.ordinal(), 0);
295        /// assert_eq!(Example::B.ordinal(), 1);
296        /// assert_eq!(Example::C.ordinal(), 2);
297        /// ```
298        pub const fn ordinal(&self) -> usize {
299            match self {
300                #(#unit_match_arms,)*
301                #other_arm
302            }
303        }
304    }
305}
306fn generate_from_ordinal_impl(name: &syn::Ident, unit_variants: &[&Variant]) -> proc_macro2::TokenStream {
307    let match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
308        let variant_name = &variant.ident;
309        quote! { #index => Some(#name::#variant_name) }
310    });
311
312    quote! {
313        /// Converts a zero-based ordinal to an enum variant, if possible.
314        ///
315        /// Returns `Some(variant)` if the ordinal corresponds to a unit variant,
316        /// or `None` if the ordinal is out of range or would correspond to the "other" variant.
317        ///
318        /// # Examples
319        ///
320        /// ```ignore
321        /// # use unit_enum::UnitEnum;
322        /// # #[derive(Debug, PartialEq)]
323        /// #[derive(UnitEnum)]
324        /// enum Example {
325        ///     A,
326        ///     B,
327        ///     #[unit_enum(other)]
328        ///     Other(i32),
329        /// }
330        ///
331        /// assert_eq!(Example::from_ordinal(0), Some(Example::A));
332        /// assert_eq!(Example::from_ordinal(2), None); // Other variant
333        /// assert_eq!(Example::from_ordinal(99), None); // Out of range
334        /// ```
335        pub const fn from_ordinal(ord: usize) -> Option<Self> {
336            match ord {
337                #(#match_arms,)*
338                _ => None
339            }
340        }
341    }
342}
343
344fn generate_discriminant_impl(
345    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, discriminant_type: &Type,
346    discriminants: &[Expr],
347) -> proc_macro2::TokenStream {
348    let unit_match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
349        let variant_name = &variant.ident;
350        quote! { #name::#variant_name => #discriminant }
351    });
352
353    let other_arm = other_variant.as_ref().map(|(variant, _)| {
354        let variant_name = &variant.ident;
355        quote! { #name::#variant_name(val) => *val }
356    });
357
358    quote! {
359        /// Returns the discriminant value of the enum variant.
360        ///
361        /// For "other" variants, returns the contained value.
362        ///
363        /// # Examples
364        ///
365        /// ```ignore
366        /// # use unit_enum::UnitEnum;
367        /// #[derive(UnitEnum)]
368        /// enum Example {
369        ///     A,      // 0
370        ///     B = 10, // 10
371        ///     C,      // 11
372        /// }
373        ///
374        /// assert_eq!(Example::A.discriminant(), 0);
375        /// assert_eq!(Example::B.discriminant(), 10);
376        /// assert_eq!(Example::C.discriminant(), 11);
377        /// ```
378         pub const fn discriminant(&self) -> #discriminant_type {
379            match self {
380                #(#unit_match_arms,)*
381                #other_arm
382            }
383        }
384    }
385}
386
387fn generate_from_discriminant_impl(
388    name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, discriminant_type: &Type,
389    discriminants: &[Expr],
390) -> proc_macro2::TokenStream {
391    if let Some((other_variant, _)) = other_variant {
392        let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
393            let variant_name = &variant.ident;
394            quote! { x if x == #discriminant => #name::#variant_name }
395        });
396
397        let other_name = &other_variant.ident;
398        quote! {
399            /// Converts a discriminant value to an enum variant.
400            ///
401            /// For enums with an "other" variant, this will always return a value,
402            /// using the "other" variant for undefined discriminants.
403            ///
404            /// # Examples
405            ///
406            /// ```ignore
407            /// # use unit_enum::UnitEnum;
408            /// #[derive(UnitEnum, PartialEq, Debug)]
409            /// #[repr(u8)]
410            /// enum Example {
411            ///     A,      // 0
412            ///     B = 10, // 10
413            ///     #[unit_enum(other)]
414            ///     Other(u8),
415            /// }
416            ///
417            /// assert_eq!(Example::from_discriminant(0), Example::A);
418            /// assert_eq!(Example::from_discriminant(10), Example::B);
419            /// assert_eq!(Example::from_discriminant(42), Example::Other(42));
420            /// ```
421            pub const fn from_discriminant(discr: #discriminant_type) -> Self {
422                match discr {
423                    #(#match_arms,)*
424                    other => #name::#other_name(other)
425                }
426            }
427        }
428    } else {
429        let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
430            let variant_name = &variant.ident;
431            quote! { x if x == #discriminant => Some(#name::#variant_name) }
432        });
433
434        quote! {
435            /// Converts a discriminant value to an enum variant, if possible.
436            ///
437            /// Returns `Some(variant)` if the discriminant corresponds to a defined variant,
438            /// or `None` if the discriminant is undefined.
439            ///
440            /// # Examples
441            ///
442            /// ```ignore
443            /// # use unit_enum::UnitEnum;
444            /// #[derive(UnitEnum, PartialEq, Debug)]
445            /// #[repr(u8)]
446            /// enum Example {
447            ///     A,      // 0
448            ///     B = 10, // 10
449            ///     C,      // 11
450            /// }
451            ///
452            /// assert_eq!(Example::from_discriminant(0), Some(Example::A));
453            /// assert_eq!(Example::from_discriminant(10), Some(Example::B));
454            /// assert_eq!(Example::from_discriminant(42), None);
455            /// ```
456            pub const fn from_discriminant(discr: #discriminant_type) -> Option<Self> {
457                match discr {
458                    #(#match_arms,)*
459                    _ => None
460                }
461            }
462        }
463    }
464}
465
466fn generate_values_impl(
467    name: &syn::Ident, unit_variants: &[&Variant], discriminants: &[Expr], _other_variant: &Option<(&Variant, Type)>,
468) -> proc_macro2::TokenStream {
469    // Create a vector of variant expressions paired with their discriminants
470    let variant_exprs = unit_variants.iter().zip(discriminants).map(|(variant, _discriminant)| {
471        let variant_name = &variant.ident;
472        quote! {
473            #name::#variant_name // The variant
474        }
475    });
476
477    // Collect variants into a Vec to ensure consistent ordering
478    quote! {
479        /// Returns an iterator over all unit variants of the enum.
480        ///
481        /// Note: This does not include values from the "other" variant, if present.
482        ///
483        /// # Examples
484        ///
485        /// ```ignore
486        /// # use unit_enum::UnitEnum;
487        /// #[derive(UnitEnum, PartialEq, Debug)]
488        /// enum Example {
489        ///     A,
490        ///     B,
491        ///     #[unit_enum(other)]
492        ///     Other(i32),
493        /// }
494        ///
495        /// let values: Vec<_> = Example::values().collect();
496        /// assert_eq!(values, vec![Example::A, Example::B]);
497        /// ```
498        pub fn values() -> impl Iterator<Item = Self> {
499            vec![
500                #(#variant_exprs),*
501            ].into_iter()
502        }
503    }
504}