const_table/
lib.rs

1//! This crate provides an attribute macro to associate struct-type constants with enum variants.
2//!
3//! ## Syntax
4//!
5//! Place `#[const_table]` on an enum with at least two variants, where
6//!
7//! * the first has named fields and defines the type of the associated constants, and
8//! * all following have discriminant expressions of that type:
9//!
10//! ```
11//! use const_table::const_table;
12//!
13//! #[const_table]
14//! pub enum Planet {
15//!     PlanetInfo {
16//!         pub mass: f32,
17//!         pub radius: f32,
18//!     },
19//!
20//!     Mercury = PlanetInfo { mass: 3.303e+23, radius: 2.4397e6 },
21//!     Venus = PlanetInfo { mass: 4.869e+24, radius: 6.0518e6 },
22//!     Earth = PlanetInfo { mass: 5.976e+24, radius: 6.37814e6 },
23//!     Mars = PlanetInfo { mass: 6.421e+23, radius: 3.3972e6 },
24//!     Jupiter = PlanetInfo { mass: 1.9e+27, radius: 7.1492e7 },
25//!     Saturn = PlanetInfo { mass: 5.688e+26, radius: 6.0268e7 },
26//!     Uranus = PlanetInfo { mass: 8.686e+25, radius: 2.5559e7 },
27//!     Neptune = PlanetInfo { mass: 1.024e+26, radius: 2.4746e7 },
28//! }
29//! ```
30//!
31//! This expands to the following:
32//!
33//! ```
34//! #[repr(u32)]
35//! #[derive(core::marker::Copy, core::clone::Clone, core::fmt::Debug, core::hash::Hash, core::cmp::PartialEq, core::cmp::Eq)]
36//! pub enum Planet {
37//!     Mercury,
38//!     Venus,
39//!     Earth,
40//!     Mars,
41//!     Jupiter,
42//!     Saturn,
43//!     Uranus,
44//!     Neptune,
45//! }
46//!
47//! pub struct PlanetInfo {
48//!     pub mass: f32,
49//!     pub radius: f32,
50//! }
51//!
52//! impl Planet {
//!     pub const COUNT: usize = 8;
54//!     pub fn iter() -> impl core::iter::DoubleEndedIterator<Item = Self> {
55//!         // transmuting here is fine because... (see try_from)
56//!         (0..Self::COUNT).map(|i| unsafe { core::mem::transmute(i as u32) })
57//!     }
58//! }
59//!
60//! impl core::ops::Deref for Planet {
61//!     type Target = PlanetInfo;
62//!     fn deref(&self) -> &Self::Target {
63//!         use Planet::*;
64//!         const TABLE: [PlanetInfo; 8] = [
65//!             PlanetInfo { mass: 3.303e+23, radius: 2.4397e6 },
66//!             PlanetInfo { mass: 4.869e+24, radius: 6.0518e6 },
67//!             PlanetInfo { mass: 5.976e+24, radius: 6.37814e6 },
68//!             PlanetInfo { mass: 6.421e+23, radius: 3.3972e6 },
69//!             PlanetInfo { mass: 1.9e+27, radius: 7.1492e7 },
70//!             PlanetInfo { mass: 5.688e+26, radius: 6.0268e7 },
71//!             PlanetInfo { mass: 8.686e+25, radius: 2.5559e7 },
72//!             PlanetInfo { mass: 1.024e+26, radius: 2.4746e7 },
73//!         ];
74//!
75//!         &TABLE[*self as usize]
76//!     }
77//! }
78//!
79//! impl core::convert::TryFrom<u32> for Planet {
80//!     type Error = u32;
81//!     fn try_from(i: u32) -> Result<Self, Self::Error> {
82//!         if (i as usize) < Self::COUNT {
83//!             // transmuting here is fine because all values in range are valid, since
84//!             // discriminants are assigned linearly starting at 0.
85//!             Ok(unsafe { core::mem::transmute(i) })
86//!         } else {
87//!             Err(i)
88//!         }
89//!     }
90//! }
91//! ```
92//!
//! Note the automatically inserted `repr` and `derive` attributes. You may place a different `repr` attribute as normal,
//! although only `u8`, `u16`, `u32` and `u64` are supported; an implementation of `TryFrom<T>` is provided, where `T` is
//! the chosen `repr` type. You may also `derive` additional traits on the enum, except for `Copy`, `Clone`, `Debug`,
//! `Hash`, `PartialEq` and `Eq`: these are always derived by the macro, and deriving them again is reported as an error.
96//!
97//! Any attributes placed on the first variant will be placed on the corresponding struct in the expanded code.
98//!
//! Also note that the macro places the discriminant expressions inside a scope that imports all variants of your enum.
//! This lets table entries refer to other variants by name, which is convenient e.g. for graph-like structures.
101//!
102//! Because the macro implements `Deref` for your enum, you can access fields of the target type like `Planet::Earth.mass`.
103//!
104//! Finally, `Planet::iter()` gives a `DoubleEndedIterator` over all variants in declaration order, and `Planet::COUNT` is
105//! the total number of variants.
106
107extern crate quote;
108extern crate syn;
109
110use proc_macro::TokenStream;
111use proc_macro2::Span;
112
113use quote::quote;
114use syn::parse::Error;
115use syn::punctuated::Punctuated;
116use syn::spanned::Spanned;
117use syn::{parse_macro_input, Expr, Ident, ItemEnum, ItemStruct, Variant};
118
119#[proc_macro_attribute]
120pub fn const_table(_attr: TokenStream, item: TokenStream) -> TokenStream {
121    let mut errors = proc_macro2::TokenStream::new();
122
123    let input_item = parse_macro_input!(item as syn::Item);
124    let input_item = if let syn::Item::Enum(e) = input_item {
125        e
126    } else {
127        let span = input_item.span();
128        let message = "the const_table attribute may only be applied to enums";
129        return Error::new(span, message).to_compile_error().into();
130    };
131
132    if !input_item.generics.params.is_empty() {
133        let span = input_item.generics.params.span();
134        let message = "a const_table enum cannot be generic";
135        errors.extend(Error::new(span, message).to_compile_error());
136    }
137
138    let (enum_attrs, repr_type) = {
139        let mut attrs = Vec::with_capacity(input_item.attrs.len());
140        let mut repr = None;
141
142        for attr in input_item.attrs {
143            if attr.path.is_ident("derive") {
144                let mut conflict_found = false;
145                if let Ok(syn::Meta::List(derive_attr)) = attr.parse_meta() {
146                    for arg in &derive_attr.nested {
147                        if let syn::NestedMeta::Meta(syn::Meta::Path(p)) = arg {
148                            if p.is_ident("Copy") || p.is_ident("Clone") ||
149                                p.is_ident("Debug") || p.is_ident("Hash") ||
150                                p.is_ident("PartialEq") || p.is_ident("Eq")
151                            {
152                                let span = p.span();
153                                let message = format!("the {} trait is already implemented by the const_table macro", p.get_ident().unwrap());
154                                errors.extend(Error::new(span, message).to_compile_error());
155                                conflict_found = true;
156                            }
157                        }
158                    }
159                }
160
161                if conflict_found {
162                    continue;
163                }
164            }
165
166            if attr.path.is_ident("repr") {
167                let ident: Ident = attr.parse_args().unwrap();
168                if ident != "u8" && ident != "u16" && ident != "u32" && ident != "u64" {
169                    let span = attr.tokens.span();
170                    let message = "unsupported repr hint for a const_table enum: expected one of u8, u16, u32 or u64 (default is u32)";
171                    errors.extend(Error::new(span, message).to_compile_error());
172                    continue;
173                }
174
175                repr = Some(ident);
176            } else {
177                attrs.push(attr);
178            }
179        }
180
181        (attrs, repr.unwrap_or_else(|| Ident::new("u32", Span::call_site())))
182    };
183
184    let mut input_variants = input_item.variants.iter();
185    let first_variant = input_variants.next();
186
187    let (variants, value_exprs): (Punctuated<Variant, syn::token::Comma>, Vec<Expr>) = input_variants.map(|variant| {
188        if !variant.fields.is_empty() {
189            let span = variant.fields.span();
190            let message = "in a const_table enum, only the first variant should have fields";
191            errors.extend(Error::new(span, message).to_compile_error());
192        }
193
194        if let Some((_, expr)) = &variant.discriminant {
195            let v = Variant {
196                discriminant: None,
197                fields: syn::Fields::Unit,
198                ..(*variant).clone()
199            };
200
201            (v, expr.clone())
202        } else {
203            let span = variant.span();
204            let message = "in a const_table enum, all but the first variant should have a discriminant expression";
205            errors.extend(Error::new(span, message).to_compile_error());
206
207            let empty_expr = Expr::Tuple(syn::ExprTuple {
208                attrs: Vec::new(), paren_token: syn::token::Paren { span: variant.ident.span() }, elems: Punctuated::new()
209            });
210
211            (variant.clone(), empty_expr)
212        }
213    }).unzip();
214
215    if variants.is_empty() {
216        let span = input_item.brace_token.span;
217        let message = "a const_table enum needs at least one variant with a discriminant expression";
218        errors.extend(Error::new(span, message).to_compile_error());
219        return errors.into();
220    }
221
222    let struct_decl = if let Some(v) = first_variant {
223        use syn::Fields::Named;
224        if let Named(fields) = &v.fields {
225            ItemStruct {
226                attrs: v.attrs.clone(),
227                vis: input_item.vis.clone(),
228                struct_token: syn::token::Struct {
229                    span: Span::call_site(),
230                },
231                ident: v.ident.clone(),
232                generics: Default::default(),
233                fields: Named((*fields).clone()),
234                semi_token: None,
235            }
236        } else {
237            let span = v.span();
238            let message = "the first variant of a const_table enum should have named fields to specify the table layout";
239            errors.extend(Error::new(span, message).to_compile_error());
240            return errors.into();
241        }
242    } else {
243        let span = input_item.brace_token.span;
244        let message = "a const_table enum needs at least one variant with named fields to specify the table layout";
245        errors.extend(Error::new(span, message).to_compile_error());
246        return errors.into();
247    };
248    let struct_name = &struct_decl.ident;
249
250    let table_size = variants.len();
251    let enum_decl = ItemEnum {
252        attrs: enum_attrs,
253        variants,
254        ..input_item
255    };
256    let enum_name = &enum_decl.ident;
257
258    let expanded = quote! {
259        #errors
260
261        #[repr(#repr_type)]
262        #[derive(core::marker::Copy, core::clone::Clone, core::fmt::Debug, core::hash::Hash, core::cmp::PartialEq, core::cmp::Eq)]
263        #enum_decl
264
265        #struct_decl
266
267        impl #enum_name {
268            pub const COUNT: usize = #table_size;
269            pub fn iter() -> impl core::iter::DoubleEndedIterator<Item = Self> {
270                // transmuting here is fine because... (see try_from)
271                (0..Self::COUNT).map(|i| unsafe { core::mem::transmute(i as #repr_type) })
272            }
273        }
274
275        impl core::ops::Deref for #enum_name {
276            type Target = #struct_name;
277            fn deref(&self) -> &Self::Target {
278                use #enum_name::*;
279                const TABLE: [#struct_name; #table_size] = [ #(#value_exprs),* ];
280                &TABLE[*self as usize]
281            }
282        }
283
284        impl core::convert::TryFrom<#repr_type> for #enum_name {
285            type Error = #repr_type;
286            fn try_from(i: #repr_type) -> core::result::Result<Self, #repr_type> {
287                if (i as usize) < Self::COUNT {
288                    // transmuting here is fine because all values in range are valid, since
289                    // discriminants are assigned linearly starting at 0.
290                    core::result::Result::Ok(unsafe { core::mem::transmute(i) })
291                } else {
292                    core::result::Result::Err(i)
293                }
294            }
295        }
296    };
297    expanded.into()
298}