unit_enum/lib.rs
1#![doc = include_str!("lib.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Error, Expr, Fields, Type, Variant};
6
7/// Derives the `UnitEnum` trait for an enum.
8///
9/// This macro can be used on enums with unit variants (no fields) and optionally one "other" variant
10/// that can hold arbitrary discriminant values.
11///
12/// # Attributes
13/// - `#[repr(type)]`: Optional for regular enums, defaults to i32. Required when using an "other" variant.
14/// - `#[unit_enum(other)]`: Marks a variant as the catch-all for undefined discriminant values.
15/// The type of this variant must match the repr type.
16///
17/// # Requirements
18/// - The enum must contain only unit variants, except for one optional "other" variant
19/// - The "other" variant, if present, must:
20/// - Be marked with `#[unit_enum(other)]`
21/// - Have exactly one unnamed field matching the repr type
22/// - Be the only variant with the "other" attribute
23/// - Have a matching `#[repr(type)]` attribute
24///
25/// # Examples
26///
27/// Basic usage with unit variants (repr is optional):
28/// ```rust
29/// # use unit_enum::UnitEnum;
30/// #[derive(UnitEnum)]
31/// enum Example {
32/// A,
33/// B = 10,
34/// C,
35/// }
36/// ```
37///
38/// Usage with explicit repr:
39/// ```rust
40/// # use unit_enum::UnitEnum;
41/// #[derive(UnitEnum)]
42/// #[repr(u16)]
43/// enum Color {
44/// Red = 10,
45/// Green,
46/// Blue = 45654,
47/// }
48/// ```
49///
50/// Usage with an "other" variant (repr required):
51/// ```rust
52/// # use unit_enum::UnitEnum;
53/// #[derive(UnitEnum)]
54/// #[repr(u16)]
55/// enum Status {
56/// Active = 1,
57/// Inactive = 2,
58/// #[unit_enum(other)]
59/// Unknown(u16), // type must match repr
60/// }
61/// ```
62#[proc_macro_derive(UnitEnum, attributes(unit_enum))]
63pub fn unit_enum_derive(input: TokenStream) -> TokenStream {
64 let ast = parse_macro_input!(input as DeriveInput);
65
66 match validate_and_process(&ast) {
67 Ok((discriminant_type, unit_variants, other_variant)) => {
68 impl_unit_enum(&ast, &discriminant_type, &unit_variants, other_variant)
69 }
70 Err(e) => e.to_compile_error().into(),
71 }
72}
73
74struct ValidationResult<'a> {
75 unit_variants: Vec<&'a Variant>,
76 other_variant: Option<(&'a Variant, Type)>,
77}
78
79fn validate_and_process(ast: &DeriveInput) -> Result<(Type, Vec<&Variant>, Option<(&Variant, Type)>), Error> {
80 // Get discriminant type from #[repr] attribute
81 let discriminant_type = get_discriminant_type(ast)?;
82
83 let data_enum = match &ast.data {
84 Data::Enum(data_enum) => data_enum,
85 _ => return Err(Error::new_spanned(ast, "UnitEnum can only be derived for enums")),
86 };
87
88 let mut validation = ValidationResult {
89 unit_variants: Vec::new(),
90 other_variant: None,
91 };
92
93 // Validate each variant
94 for variant in &data_enum.variants {
95 match &variant.fields {
96 Fields::Unit => {
97 if has_unit_enum_attr(variant) {
98 return Err(Error::new_spanned(variant,
99 "Unit variants cannot have #[unit_enum] attributes"));
100 }
101 validation.unit_variants.push(variant);
102 }
103 Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
104 if has_unit_enum_other_attr(variant) {
105 if validation.other_variant.is_some() {
106 return Err(Error::new_spanned(variant,
107 "Multiple #[unit_enum(other)] variants found. Only one is allowed"));
108 }
109 validation.other_variant = Some((variant, fields.unnamed[0].ty.clone()));
110 } else {
111 return Err(Error::new_spanned(variant,
112 "Non-unit variant must be marked with #[unit_enum(other)] to be used as the catch-all variant"));
113 }
114 }
115 _ => return Err(Error::new_spanned(variant,
116 "Invalid variant. UnitEnum only supports unit variants and a single tuple variant marked with #[unit_enum(other)]")),
117 }
118 }
119
120 Ok((discriminant_type, validation.unit_variants, validation.other_variant))
121}
122
123fn get_discriminant_type(ast: &DeriveInput) -> Result<Type, Error> {
124 ast.attrs
125 .iter()
126 .find(|attr| attr.path().is_ident("repr"))
127 .map_or(Ok(syn::parse_quote!(i32)), |attr| {
128 attr.parse_args::<Type>()
129 .map_err(|_| Error::new_spanned(attr, "Invalid repr attribute"))
130 })
131}
132
133fn has_unit_enum_attr(variant: &Variant) -> bool {
134 variant.attrs.iter().any(|attr| attr.path().is_ident("unit_enum"))
135}
136
137fn has_unit_enum_other_attr(variant: &Variant) -> bool {
138 variant.attrs.iter().any(|attr| {
139 attr.path().is_ident("unit_enum")
140 && attr
141 .parse_nested_meta(|meta| {
142 if meta.path.is_ident("other") {
143 Ok(())
144 } else {
145 Err(meta.error("Invalid unit_enum attribute"))
146 }
147 })
148 .is_ok()
149 })
150}
151
152fn compute_discriminants(variants: &[&Variant]) -> Vec<Expr> {
153 let mut discriminants = Vec::with_capacity(variants.len());
154 let mut last_discriminant: Option<Expr> = None;
155
156 for variant in variants {
157 let discriminant = variant
158 .discriminant
159 .as_ref()
160 .map(|(_, expr)| expr.clone())
161 .or_else(|| last_discriminant.clone().map(|expr| syn::parse_quote! { #expr + 1 }))
162 .unwrap_or_else(|| syn::parse_quote! { 0 });
163
164 discriminants.push(discriminant.clone());
165 last_discriminant = Some(discriminant);
166 }
167
168 discriminants
169}
170
171fn impl_unit_enum(
172 ast: &DeriveInput, discriminant_type: &Type, unit_variants: &[&Variant], other_variant: Option<(&Variant, Type)>,
173) -> TokenStream {
174 let name = &ast.ident;
175 let num_variants = unit_variants.len();
176 let discriminants = compute_discriminants(unit_variants);
177
178 let name_impl = generate_name_impl(name, unit_variants, &other_variant);
179 let ordinal_impl = generate_ordinal_impl(name, unit_variants, &other_variant, num_variants);
180 let from_ordinal_impl = generate_from_ordinal_impl(name, unit_variants);
181 let discriminant_impl =
182 generate_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
183 let from_discriminant_impl =
184 generate_from_discriminant_impl(name, unit_variants, &other_variant, discriminant_type, &discriminants);
185 let values_impl = generate_values_impl(name, unit_variants, &discriminants, &other_variant);
186
187 quote! {
188 impl #name {
189 #name_impl
190
191 #ordinal_impl
192
193 #from_ordinal_impl
194
195 #discriminant_impl
196
197 #from_discriminant_impl
198
199 /// Returns the total number of unit variants in the enum (excluding the "other" variant if present).
200 ///
201 /// # Examples
202 ///
203 /// ```ignore
204 /// # use unit_enum::UnitEnum;
205 /// #[derive(UnitEnum)]
206 /// enum Example {
207 /// A,
208 /// B,
209 /// #[unit_enum(other)]
210 /// Other(i32),
211 /// }
212 ///
213 /// assert_eq!(Example::len(), 2);
214 /// ```
215 pub const fn len() -> usize {
216 #num_variants
217 }
218
219 #values_impl
220 }
221 }
222 .into()
223}
224
225fn generate_name_impl(
226 name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>,
227) -> proc_macro2::TokenStream {
228 let unit_match_arms = unit_variants.iter().map(|variant| {
229 let variant_name = &variant.ident;
230 quote! { #name::#variant_name => stringify!(#variant_name) }
231 });
232
233 let other_arm = other_variant.as_ref().map(|(variant, _)| {
234 let variant_name = &variant.ident;
235 quote! { #name::#variant_name(_) => stringify!(#variant_name) }
236 });
237
238 quote! {
239 /// Returns the name of the enum variant as a string.
240 ///
241 /// # Examples
242 ///
243 /// ```ignore
244 /// # use unit_enum::UnitEnum;
245 /// #[derive(UnitEnum)]
246 /// enum Example {
247 /// A,
248 /// B = 10,
249 /// C,
250 /// }
251 ///
252 /// assert_eq!(Example::A.name(), "A");
253 /// assert_eq!(Example::B.name(), "B");
254 /// assert_eq!(Example::C.name(), "C");
255 /// ```
256 pub const fn name(&self) -> &str {
257 match self {
258 #(#unit_match_arms,)*
259 #other_arm
260 }
261 }
262 }
263}
264
265fn generate_ordinal_impl(
266 name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, num_variants: usize,
267) -> proc_macro2::TokenStream {
268 let unit_match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
269 let variant_name = &variant.ident;
270 quote! { #name::#variant_name => #index }
271 });
272
273 let other_arm = other_variant.as_ref().map(|(variant, _)| {
274 let variant_name = &variant.ident;
275 quote! { #name::#variant_name(_) => #num_variants }
276 });
277
278 quote! {
279 /// Returns the zero-based ordinal of the enum variant.
280 ///
281 /// For enums with an "other" variant, it returns the position after all unit variants.
282 ///
283 /// # Examples
284 ///
285 /// ```ignore
286 /// # use unit_enum::UnitEnum;
287 /// #[derive(UnitEnum)]
288 /// enum Example {
289 /// A, // ordinal: 0
290 /// B = 10, // ordinal: 1
291 /// C, // ordinal: 2
292 /// }
293 ///
294 /// assert_eq!(Example::A.ordinal(), 0);
295 /// assert_eq!(Example::B.ordinal(), 1);
296 /// assert_eq!(Example::C.ordinal(), 2);
297 /// ```
298 pub const fn ordinal(&self) -> usize {
299 match self {
300 #(#unit_match_arms,)*
301 #other_arm
302 }
303 }
304 }
305}
306fn generate_from_ordinal_impl(name: &syn::Ident, unit_variants: &[&Variant]) -> proc_macro2::TokenStream {
307 let match_arms = unit_variants.iter().enumerate().map(|(index, variant)| {
308 let variant_name = &variant.ident;
309 quote! { #index => Some(#name::#variant_name) }
310 });
311
312 quote! {
313 /// Converts a zero-based ordinal to an enum variant, if possible.
314 ///
315 /// Returns `Some(variant)` if the ordinal corresponds to a unit variant,
316 /// or `None` if the ordinal is out of range or would correspond to the "other" variant.
317 ///
318 /// # Examples
319 ///
320 /// ```ignore
321 /// # use unit_enum::UnitEnum;
322 /// # #[derive(Debug, PartialEq)]
323 /// #[derive(UnitEnum)]
324 /// enum Example {
325 /// A,
326 /// B,
327 /// #[unit_enum(other)]
328 /// Other(i32),
329 /// }
330 ///
331 /// assert_eq!(Example::from_ordinal(0), Some(Example::A));
332 /// assert_eq!(Example::from_ordinal(2), None); // Other variant
333 /// assert_eq!(Example::from_ordinal(99), None); // Out of range
334 /// ```
335 pub const fn from_ordinal(ord: usize) -> Option<Self> {
336 match ord {
337 #(#match_arms,)*
338 _ => None
339 }
340 }
341 }
342}
343
344fn generate_discriminant_impl(
345 name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, discriminant_type: &Type,
346 discriminants: &[Expr],
347) -> proc_macro2::TokenStream {
348 let unit_match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
349 let variant_name = &variant.ident;
350 quote! { #name::#variant_name => #discriminant }
351 });
352
353 let other_arm = other_variant.as_ref().map(|(variant, _)| {
354 let variant_name = &variant.ident;
355 quote! { #name::#variant_name(val) => *val }
356 });
357
358 quote! {
359 /// Returns the discriminant value of the enum variant.
360 ///
361 /// For "other" variants, returns the contained value.
362 ///
363 /// # Examples
364 ///
365 /// ```ignore
366 /// # use unit_enum::UnitEnum;
367 /// #[derive(UnitEnum)]
368 /// enum Example {
369 /// A, // 0
370 /// B = 10, // 10
371 /// C, // 11
372 /// }
373 ///
374 /// assert_eq!(Example::A.discriminant(), 0);
375 /// assert_eq!(Example::B.discriminant(), 10);
376 /// assert_eq!(Example::C.discriminant(), 11);
377 /// ```
378 pub const fn discriminant(&self) -> #discriminant_type {
379 match self {
380 #(#unit_match_arms,)*
381 #other_arm
382 }
383 }
384 }
385}
386
387fn generate_from_discriminant_impl(
388 name: &syn::Ident, unit_variants: &[&Variant], other_variant: &Option<(&Variant, Type)>, discriminant_type: &Type,
389 discriminants: &[Expr],
390) -> proc_macro2::TokenStream {
391 if let Some((other_variant, _)) = other_variant {
392 let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
393 let variant_name = &variant.ident;
394 quote! { x if x == #discriminant => #name::#variant_name }
395 });
396
397 let other_name = &other_variant.ident;
398 quote! {
399 /// Converts a discriminant value to an enum variant.
400 ///
401 /// For enums with an "other" variant, this will always return a value,
402 /// using the "other" variant for undefined discriminants.
403 ///
404 /// # Examples
405 ///
406 /// ```ignore
407 /// # use unit_enum::UnitEnum;
408 /// #[derive(UnitEnum, PartialEq, Debug)]
409 /// #[repr(u8)]
410 /// enum Example {
411 /// A, // 0
412 /// B = 10, // 10
413 /// #[unit_enum(other)]
414 /// Other(u8),
415 /// }
416 ///
417 /// assert_eq!(Example::from_discriminant(0), Example::A);
418 /// assert_eq!(Example::from_discriminant(10), Example::B);
419 /// assert_eq!(Example::from_discriminant(42), Example::Other(42));
420 /// ```
421 pub const fn from_discriminant(discr: #discriminant_type) -> Self {
422 match discr {
423 #(#match_arms,)*
424 other => #name::#other_name(other)
425 }
426 }
427 }
428 } else {
429 let match_arms = unit_variants.iter().zip(discriminants).map(|(variant, discriminant)| {
430 let variant_name = &variant.ident;
431 quote! { x if x == #discriminant => Some(#name::#variant_name) }
432 });
433
434 quote! {
435 /// Converts a discriminant value to an enum variant, if possible.
436 ///
437 /// Returns `Some(variant)` if the discriminant corresponds to a defined variant,
438 /// or `None` if the discriminant is undefined.
439 ///
440 /// # Examples
441 ///
442 /// ```ignore
443 /// # use unit_enum::UnitEnum;
444 /// #[derive(UnitEnum, PartialEq, Debug)]
445 /// #[repr(u8)]
446 /// enum Example {
447 /// A, // 0
448 /// B = 10, // 10
449 /// C, // 11
450 /// }
451 ///
452 /// assert_eq!(Example::from_discriminant(0), Some(Example::A));
453 /// assert_eq!(Example::from_discriminant(10), Some(Example::B));
454 /// assert_eq!(Example::from_discriminant(42), None);
455 /// ```
456 pub const fn from_discriminant(discr: #discriminant_type) -> Option<Self> {
457 match discr {
458 #(#match_arms,)*
459 _ => None
460 }
461 }
462 }
463 }
464}
465
466fn generate_values_impl(
467 name: &syn::Ident, unit_variants: &[&Variant], discriminants: &[Expr], _other_variant: &Option<(&Variant, Type)>,
468) -> proc_macro2::TokenStream {
469 // Create a vector of variant expressions paired with their discriminants
470 let variant_exprs = unit_variants.iter().zip(discriminants).map(|(variant, _discriminant)| {
471 let variant_name = &variant.ident;
472 quote! {
473 #name::#variant_name // The variant
474 }
475 });
476
477 // Collect variants into a Vec to ensure consistent ordering
478 quote! {
479 /// Returns an iterator over all unit variants of the enum.
480 ///
481 /// Note: This does not include values from the "other" variant, if present.
482 ///
483 /// # Examples
484 ///
485 /// ```ignore
486 /// # use unit_enum::UnitEnum;
487 /// #[derive(UnitEnum, PartialEq, Debug)]
488 /// enum Example {
489 /// A,
490 /// B,
491 /// #[unit_enum(other)]
492 /// Other(i32),
493 /// }
494 ///
495 /// let values: Vec<_> = Example::values().collect();
496 /// assert_eq!(values, vec![Example::A, Example::B]);
497 /// ```
498 pub fn values() -> impl Iterator<Item = Self> {
499 vec![
500 #(#variant_exprs),*
501 ].into_iter()
502 }
503 }
504}