explicit_discriminant/
lib.rs

1#![deny(warnings, unsafe_code)]
2
3use std::ops::{Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
4
5use proc_macro::TokenStream;
6use proc_macro_error::{abort, emit_error, emit_warning, proc_macro_error};
7use syn::{
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    Expr, ExprLit, ExprUnary, Lit, Pat, PatParen, RangeLimits, Result, Token, UnOp,
12};
13
14/// Require that all variants of an enum have an explicit discriminant defined.
15///
16/// Example:
17/// ```
18/// # use explicit_discriminant::ExplicitDiscriminant;
19/// // This works
20/// #[derive(ExplicitDiscriminant)]
21/// pub enum MyEnum {
22///     One = 1,
23///     Two = 2,
24///     Three = 3,
25/// }
26/// ```
27/// ```compule_fail
28/// // But this won't compile
29/// #[derive(ExplicitDiscriminant)]
30/// pub enum MyOtherEnum {
31///     One = 1,
32///     Two,
33///     Three = 3,
34/// }
35/// ```
36#[proc_macro_error]
37#[proc_macro_derive(ExplicitDiscriminant, attributes(pattern))]
38pub fn derive_explicit_discriminant(item: TokenStream) -> TokenStream {
39    let input = parse_macro_input!(item as syn::DeriveInput);
40    let syn::Data::Enum(data_enum) = input.data else {
41        abort!(input, "can only be derived on an enum")
42    };
43
44    let punctuated_patterns = match input
45        .attrs
46        .iter()
47        .filter(|a| a.path().is_ident("pattern"))
48        .map(|pat_attr| {
49            pat_attr.parse_args_with(Punctuated::<DisciminantPattern, Token![,]>::parse_terminated)
50        })
51        .collect::<Result<Vec<_>>>()
52    {
53        Ok(patterns) => patterns,
54        Err(err) => return err.into_compile_error().into(),
55    };
56    let patterns = punctuated_patterns
57        .iter()
58        .flat_map(|puncts| puncts.iter().map(|pat| pat.pat.clone()))
59        .collect::<Vec<_>>();
60
61    for variant in data_enum.variants {
62        if let Some((_, discriminant)) = variant.discriminant {
63            if !punctuated_patterns.is_empty()
64                && !patterns.iter().any(|pat| tok_matches(&discriminant, pat))
65            {
66                emit_error!(discriminant, "discriminant does not match any pattern")
67            }
68        } else {
69            emit_error!(variant, "no explicit discriminant")
70        }
71    }
72
73    TokenStream::new()
74}
75
76struct DisciminantPattern {
77    pat: Pat,
78}
79
80impl Parse for DisciminantPattern {
81    fn parse(input: ParseStream) -> Result<Self> {
82        Ok(Self {
83            pat: Pat::parse_multi(input)?,
84        })
85    }
86}
87
88/// Recreating part of the behavior for `matches!`, but for syn Tokens
89fn tok_matches(expr: &Expr, pat: &Pat) -> bool {
90    let expr_int = expr_as_int(expr);
91
92    match pat {
93        Pat::Lit(exprlit) => expr_int == exprlit_as_int(exprlit),
94        Pat::Or(pator) => pator.cases.iter().any(|case| tok_matches(expr, case)),
95        Pat::Range(patrange) => {
96            let start = patrange.start.as_ref().map(|expr| expr_as_int(expr));
97            let end = patrange.end.as_ref().map(|expr| expr_as_int(expr));
98            match (start, patrange.limits, end) {
99                (Some(start), RangeLimits::Closed(_), Some(end)) => {
100                    RangeInclusive::new(start, end).contains(&expr_int)
101                }
102                (Some(start), RangeLimits::HalfOpen(_), Some(end)) => {
103                    Range { start, end }.contains(&expr_int)
104                }
105                (Some(start), RangeLimits::HalfOpen(_), None) => {
106                    RangeFrom { start }.contains(&expr_int)
107                }
108                (None, RangeLimits::Closed(_), Some(end)) => {
109                    RangeToInclusive { end }.contains(&expr_int)
110                }
111                (None, RangeLimits::HalfOpen(_), Some(end)) => RangeTo { end }.contains(&expr_int),
112                _ => abort!(patrange, "unsupported range type"),
113            }
114        }
115        Pat::Wild(_) => true,
116        Pat::Paren(PatParen { pat, .. }) => tok_matches(expr, pat),
117        _ => {
118            emit_warning!(
119                pat,
120                "Currently supported are: literals, ranges, or-patterns, parenthesizeds, and wilds"
121            );
122            abort!(pat, format!("pattern type not supported"));
123        }
124    }
125}
126
127fn expr_as_int(expr: &Expr) -> i128 {
128    match expr {
129        Expr::Lit(lit) => exprlit_as_int(lit),
130        Expr::Unary(ExprUnary {
131            op: UnOp::Neg(_),
132            expr,
133            ..
134        }) => -expr_as_int(expr),
135        _ => abort!(
136            expr,
137            "only literal expressions (optionally negated) are supported"
138        ),
139    }
140}
141
142fn exprlit_as_int(exprlit: &ExprLit) -> i128 {
143    match &exprlit.lit {
144        Lit::Int(litint) => litint
145            .base10_parse()
146            .unwrap_or_else(|_| abort!(litint, "could not parse token to i128")),
147        _ => abort!(exprlit.lit, "only integer literals are supported"),
148    }
149}