disc_derive/lib.rs

1#![feature(proc_macro_diagnostic)]
2
3use proc_macro::{Diagnostic, Level, TokenStream};
4use proc_macro2::Ident;
5use proc_macro2::TokenStream as TokenStream2;
6use proc_macro_crate::{crate_name, FoundCrate};
7use quote::{format_ident, quote};
8use syn::{parse_macro_input, Data, DeriveInput, DataEnum, Generics, spanned::Spanned};
9
10#[proc_macro_attribute]
11pub fn disc(args: TokenStream, input: TokenStream) -> TokenStream {
12    if !args.is_empty() {
13        Diagnostic::new(Level::Error, "No arguments are expected. This will be changed in the future so that other types such as u32 can be used.").emit();
14        return TokenStream::new();
15    }
16
17    let input = parse_macro_input!(input as DeriveInput);
18    
19    match input.data {
20        Data::Enum(ref data) => {
21            let implementation = generate_implementation(data, &input.ident, &input.generics);
22            return TokenStream::from(quote! {
23                #[repr(u8)]
24                #input
25
26                #implementation
27            });
28        },
29        Data::Struct(..) => Diagnostic::new(Level::Error, "Incompatible data type.")
30            .span_error(input.ident.span().unwrap(), "`struct` has no discriminant.")
31            .help("Use `enum` instead.")
32            .emit(),
33        Data::Union(..) => Diagnostic::new(Level::Error, "Incompatible data type.")
34            .span_error(input.ident.span().unwrap(), "`union` has no discriminant.")
35            .help("Use `enum` instead.")
36            .emit(),
37    }
38    TokenStream::new() 
39}
40
41fn verify_fields(data: &DataEnum) -> bool {
42    for variant in data.variants.iter() {
43        if !variant.fields.is_empty() {
44            Diagnostic::new(Level::Error, "A disc enumeration cannot have fields.").span_error(variant.fields.span().unwrap(), "Here").emit();
45            return false;
46        }
47    }
48    true
49}
50
51fn generate_implementation(data: &DataEnum, name: &Ident, generics: &Generics) -> TokenStream2 {
52    if !verify_fields(data) {
53        return TokenStream2::new();
54    }
55
56    let found_crate = crate_name("disc").expect("Couldn't find the crate `disc`.");
57
58    let from_discriminant_ty = match found_crate {
59        FoundCrate::Itself => quote!(crate::FromDiscriminant),
60        FoundCrate::Name(name) => {
61            let ident = format_ident!("{}", name);
62            quote!(#ident::FromDiscriminant)
63        }
64    };
65
66    // TODO: Add a implementation for all discriminants when `auto` is false.
67    let auto = !data.variants.iter().any(|variant| variant.discriminant.is_some());
68    let n = data.variants.len();
69
70    if n > (u8::MAX as usize) {
71        Diagnostic::new(Level::Error, "Cannot have more than `u8::MAX` (255) variants.").emit()
72    }
73    let n = n as u8;
74    
75    let body = if auto {
76        quote! {
77            if d >= #n {
78                return None;
79            }
80            Some(unsafe { ::core::mem::transmute(d) })
81        }
82    } else {
83        todo!()
84    };
85    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
86
87    let tokens = quote! {
88        impl #impl_generics #from_discriminant_ty<u8> for #name #ty_generics #where_clause {
89            fn from_discriminant(d: u8) -> Option<Self> {
90                #body
91            }
92        }
93    };
94    tokens
95}