open_enum_derive/
lib.rs

1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15extern crate proc_macro;
16
17mod config;
18mod discriminant;
19mod repr;
20
21use config::Config;
22
23use discriminant::Discriminant;
24use proc_macro2::{Span, TokenStream};
25use quote::{format_ident, quote, ToTokens};
26use repr::Repr;
27use std::collections::HashSet;
28use syn::Attribute;
29use syn::{
30    parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Ident, ItemEnum, Visibility,
31};
32
33/// Sets the span for every token tree in the token stream
34fn set_token_stream_span(tokens: TokenStream, span: Span) -> TokenStream {
35    tokens
36        .into_iter()
37        .map(|mut tt| {
38            tt.set_span(span);
39            tt
40        })
41        .collect()
42}
43
44/// Checks that there are no duplicate discriminant values. If all variants are literals, return an `Err` so we can have
45/// more clear error messages. Otherwise, emit a static check that ensures no duplicates.
46fn check_no_alias<'a>(
47    enum_: &ItemEnum,
48    variants: impl Iterator<Item = (&'a Ident, &'a Discriminant, Span)> + Clone,
49) -> syn::Result<TokenStream> {
50    // If they're all literals, we can give better error messages by checking at proc macro time.
51    let mut values: HashSet<i128> = HashSet::new();
52    for (_, variant, span) in variants {
53        if let &Discriminant::Literal(value) = variant {
54            if !values.insert(value) {
55                return Err(Error::new(
56                    span,
57                    format!("discriminant value `{value}` assigned more than once"),
58                ));
59            }
60        } else {
61            let mut checking_enum = syn::ItemEnum {
62                ident: format_ident!("_Check{}", enum_.ident),
63                vis: Visibility::Inherited,
64                ..enum_.clone()
65            };
66            checking_enum.attrs.retain(|attr| {
67                matches!(
68                    attr.path().to_token_stream().to_string().as_str(),
69                    "repr" | "allow" | "warn" | "deny" | "forbid"
70                )
71            });
72            return Ok(quote!(
73                #[allow(dead_code)]
74                #checking_enum
75            ));
76        }
77    }
78    Ok(TokenStream::default())
79}
80
81fn emit_debug_impl<'a>(
82    ident: &Ident,
83    variants: impl Iterator<Item = &'a Ident> + Clone,
84    attrs: impl Iterator<Item = &'a Vec<Attribute>> + Clone,
85) -> TokenStream {
86    let attrs = attrs.map(|attrs| {
87        // Only allow "#[cfg(...)]" attributes
88        let iter = attrs.iter().filter(|attr| attr.path().is_ident("cfg"));
89        quote!(#(#iter)*)
90    });
91    quote!(impl ::core::fmt::Debug for #ident {
92        fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
93            #![allow(unreachable_patterns)]
94            let s = match *self {
95                #( #attrs Self::#variants => stringify!(#variants), )*
96                _ => {
97                    return fmt.debug_tuple(stringify!(#ident)).field(&self.0).finish();
98                }
99            };
100            fmt.pad(s)
101        }
102    })
103}
104
105fn path_matches_prelude_derive(
106    got_path: &syn::Path,
107    expected_path_after_std: &[&'static str],
108) -> bool {
109    let &[a, b] = expected_path_after_std else {
110        unimplemented!("checking against stdlib paths with != 2 parts");
111    };
112    let segments: Vec<&syn::PathSegment> = got_path.segments.iter().collect();
113    if segments
114        .iter()
115        .any(|segment| !matches!(segment.arguments, syn::PathArguments::None))
116    {
117        return false;
118    }
119    match &segments[..] {
120        // `core::fmt::Debug` or `some_crate::module::Name`
121        [maybe_core_or_std, maybe_a, maybe_b] => {
122            (maybe_core_or_std.ident == "core" || maybe_core_or_std.ident == "std")
123                && maybe_a.ident == a
124                && maybe_b.ident == b
125        }
126        // `fmt::Debug` or `module::Name`
127        [maybe_a, maybe_b] => {
128            maybe_a.ident == a && maybe_b.ident == b && got_path.leading_colon.is_none()
129        }
130        // `Debug` or `Name``
131        [maybe_b] => maybe_b.ident == b && got_path.leading_colon.is_none(),
132        _ => false,
133    }
134}
135
136fn open_enum_impl(
137    enum_: ItemEnum,
138    Config {
139        allow_alias,
140        repr_visibility,
141    }: Config,
142) -> Result<TokenStream, Error> {
143    // Does the enum define a `#[repr()]`?
144    let mut struct_attrs: Vec<TokenStream> = Vec::with_capacity(enum_.attrs.len() + 5);
145    struct_attrs.push(quote!(#[allow(clippy::exhaustive_structs)]));
146
147    if !enum_.generics.params.is_empty() {
148        return Err(Error::new(enum_.generics.span(), "enum cannot be generic"));
149    }
150    let mut variants = Vec::with_capacity(enum_.variants.len());
151    let mut last_field = Discriminant::Literal(-1);
152    for variant in &enum_.variants {
153        if !matches!(variant.fields, syn::Fields::Unit) {
154            return Err(Error::new(variant.span(), "enum cannot contain fields"));
155        }
156
157        let (value, value_span) = if let Some((_, discriminant)) = &variant.discriminant {
158            let span = discriminant.span();
159            (Discriminant::new(discriminant.clone())?, span)
160        } else {
161            last_field = last_field
162                .next_value()
163                .ok_or_else(|| Error::new(variant.span(), "enum discriminant overflowed"))?;
164            (last_field.clone(), variant.ident.span())
165        };
166        last_field = value.clone();
167        variants.push((&variant.ident, value, value_span, &variant.attrs))
168    }
169
170    let mut impl_attrs: Vec<TokenStream> = vec![quote!(#[allow(non_upper_case_globals)])];
171    let mut explicit_repr: Option<Repr> = None;
172
173    // To make `match` seamless, derive(PartialEq, Eq) if they aren't already.
174    let mut extra_derives = vec![quote!(::core::cmp::PartialEq), quote!(::core::cmp::Eq)];
175
176    let mut make_custom_debug_impl = false;
177    for attr in &enum_.attrs {
178        let mut include_in_struct = true;
179        // Turns out `is_ident` does a `to_string` every time
180        match attr.path().to_token_stream().to_string().as_str() {
181            "derive" => {
182                if let Ok(derive_paths) =
183                    attr.parse_args_with(Punctuated::<syn::Path, syn::Token![,]>::parse_terminated)
184                {
185                    for derive in &derive_paths {
186                        // These derives are treated specially
187                        const PARTIAL_EQ_PATH: &[&str] = &["cmp", "PartialEq"];
188                        const EQ_PATH: &[&str] = &["cmp", "Eq"];
189                        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];
190
191                        if path_matches_prelude_derive(derive, PARTIAL_EQ_PATH)
192                            || path_matches_prelude_derive(derive, EQ_PATH)
193                        {
194                            // This derive is always included, exclude it.
195                            continue;
196                        }
197                        if path_matches_prelude_derive(derive, DEBUG_PATH) && !allow_alias {
198                            make_custom_debug_impl = true;
199                            // Don't include this derive since we're generating a special one.
200                            continue;
201                        }
202                        extra_derives.push(derive.to_token_stream());
203                    }
204                    include_in_struct = false;
205                }
206            }
207            // Copy linting attribute to the impl.
208            "allow" | "warn" | "deny" | "forbid" => impl_attrs.push(attr.to_token_stream()),
209            "repr" => {
210                assert!(explicit_repr.is_none(), "duplicate explicit repr");
211                explicit_repr = Some(attr.parse_args()?);
212                include_in_struct = false;
213            }
214            "non_exhaustive" => {
215                // technically it's exhaustive if the enum covers the full integer range
216                return Err(Error::new(attr.path().span(), "`non_exhaustive` cannot be applied to an open enum; it is already non-exhaustive"));
217            }
218            _ => {}
219        }
220        if include_in_struct {
221            struct_attrs.push(attr.to_token_stream());
222        }
223    }
224
225    // The proper repr to type-check against
226    let typecheck_repr: Repr = explicit_repr.unwrap_or(Repr::Isize);
227
228    // The actual representation of the value.
229    let inner_repr = match explicit_repr {
230        Some(explicit_repr) => {
231            // If there is an explicit repr, emit #[repr(transparent)].
232            struct_attrs.push(quote!(#[repr(transparent)]));
233            explicit_repr
234        }
235        None => {
236            // If there isn't an explicit repr, determine an appropriate sized integer that will fit.
237            // Interpret all discriminant expressions as isize.
238            repr::autodetect_inner_repr(variants.iter().map(|v| &v.1))
239        }
240    };
241
242    if !extra_derives.is_empty() {
243        struct_attrs.push(quote!(#[derive(#(#extra_derives),*)]));
244    }
245
246    let alias_check = if allow_alias {
247        TokenStream::default()
248    } else {
249        check_no_alias(&enum_, variants.iter().map(|(i, v, s, _)| (*i, v, *s)))?
250    };
251
252    let syn::ItemEnum { ident, vis, .. } = enum_;
253
254    let debug_impl = if make_custom_debug_impl {
255        emit_debug_impl(
256            &ident,
257            variants.iter().map(|(i, _, _, _)| *i),
258            variants.iter().map(|(_, _, _, a)| *a),
259        )
260    } else {
261        TokenStream::default()
262    };
263
264    let fields = variants
265        .into_iter()
266        .map(|(name, value, value_span, attrs)| {
267            let mut value = value.into_token_stream();
268            value = set_token_stream_span(value, value_span);
269            let inner = if typecheck_repr == inner_repr {
270                value
271            } else {
272                quote!(::core::convert::identity::<#typecheck_repr>(#value) as #inner_repr)
273            };
274            quote!(
275                #(#attrs)*
276                pub const #name: #ident = #ident(#inner);
277            )
278        });
279
280    Ok(quote! {
281        #(#struct_attrs)*
282        #vis struct #ident(#repr_visibility #inner_repr);
283
284        #(#impl_attrs)*
285        impl #ident {
286            #(
287                #fields
288            )*
289        }
290        #debug_impl
291        #alias_check
292    })
293}
294
295#[proc_macro_attribute]
296pub fn open_enum(
297    attrs: proc_macro::TokenStream,
298    input: proc_macro::TokenStream,
299) -> proc_macro::TokenStream {
300    let enum_ = parse_macro_input!(input as syn::ItemEnum);
301    let config = parse_macro_input!(attrs as Config);
302    open_enum_impl(enum_, config)
303        .unwrap_or_else(Error::into_compile_error)
304        .into()
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `path_matches_prelude_derive` against spellings of `Debug` that must be
    /// accepted and paths that must be rejected.
    #[test]
    fn test_path_matches_stdlib_derive() {
        const DEBUG_PATH: &[&str] = &["fmt", "Debug"];

        // Parses `text` as a path and checks it against the `fmt::Debug` expectation.
        let matches_debug = |text: &str| -> bool {
            let parsed: syn::Path = syn::parse_str(text).unwrap();
            path_matches_prelude_derive(&parsed, DEBUG_PATH)
        };

        let accepted = [
            "::core::fmt::Debug",
            "::std::fmt::Debug",
            "core::fmt::Debug",
            "std::fmt::Debug",
            "fmt::Debug",
            "Debug",
        ];
        for success_case in accepted {
            assert!(matches_debug(success_case), "{success_case}");
        }

        let rejected = [
            "::fmt::Debug",
            "::Debug",
            "zerocopy::AsBytes",
            "::zerocopy::AsBytes",
            "PartialEq",
            "core::cmp::Eq",
        ];
        for fail_case in rejected {
            assert!(!matches_debug(fail_case), "{fail_case}");
        }
    }
}