explicit_discriminant/
lib.rs1#![deny(warnings, unsafe_code)]
2
3use std::ops::{Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
4
5use proc_macro::TokenStream;
6use proc_macro_error::{abort, emit_error, emit_warning, proc_macro_error};
7use syn::{
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 Expr, ExprLit, ExprUnary, Lit, Pat, PatParen, RangeLimits, Result, Token, UnOp,
12};
13
14#[proc_macro_error]
37#[proc_macro_derive(ExplicitDiscriminant, attributes(pattern))]
38pub fn derive_explicit_discriminant(item: TokenStream) -> TokenStream {
39 let input = parse_macro_input!(item as syn::DeriveInput);
40 let syn::Data::Enum(data_enum) = input.data else {
41 abort!(input, "can only be derived on an enum")
42 };
43
44 let punctuated_patterns = match input
45 .attrs
46 .iter()
47 .filter(|a| a.path().is_ident("pattern"))
48 .map(|pat_attr| {
49 pat_attr.parse_args_with(Punctuated::<DisciminantPattern, Token![,]>::parse_terminated)
50 })
51 .collect::<Result<Vec<_>>>()
52 {
53 Ok(patterns) => patterns,
54 Err(err) => return err.into_compile_error().into(),
55 };
56 let patterns = punctuated_patterns
57 .iter()
58 .flat_map(|puncts| puncts.iter().map(|pat| pat.pat.clone()))
59 .collect::<Vec<_>>();
60
61 for variant in data_enum.variants {
62 if let Some((_, discriminant)) = variant.discriminant {
63 if !punctuated_patterns.is_empty()
64 && !patterns.iter().any(|pat| tok_matches(&discriminant, pat))
65 {
66 emit_error!(discriminant, "discriminant does not match any pattern")
67 }
68 } else {
69 emit_error!(variant, "no explicit discriminant")
70 }
71 }
72
73 TokenStream::new()
74}
75
76struct DisciminantPattern {
77 pat: Pat,
78}
79
80impl Parse for DisciminantPattern {
81 fn parse(input: ParseStream) -> Result<Self> {
82 Ok(Self {
83 pat: Pat::parse_multi(input)?,
84 })
85 }
86}
87
88fn tok_matches(expr: &Expr, pat: &Pat) -> bool {
90 let expr_int = expr_as_int(expr);
91
92 match pat {
93 Pat::Lit(exprlit) => expr_int == exprlit_as_int(exprlit),
94 Pat::Or(pator) => pator.cases.iter().any(|case| tok_matches(expr, case)),
95 Pat::Range(patrange) => {
96 let start = patrange.start.as_ref().map(|expr| expr_as_int(expr));
97 let end = patrange.end.as_ref().map(|expr| expr_as_int(expr));
98 match (start, patrange.limits, end) {
99 (Some(start), RangeLimits::Closed(_), Some(end)) => {
100 RangeInclusive::new(start, end).contains(&expr_int)
101 }
102 (Some(start), RangeLimits::HalfOpen(_), Some(end)) => {
103 Range { start, end }.contains(&expr_int)
104 }
105 (Some(start), RangeLimits::HalfOpen(_), None) => {
106 RangeFrom { start }.contains(&expr_int)
107 }
108 (None, RangeLimits::Closed(_), Some(end)) => {
109 RangeToInclusive { end }.contains(&expr_int)
110 }
111 (None, RangeLimits::HalfOpen(_), Some(end)) => RangeTo { end }.contains(&expr_int),
112 _ => abort!(patrange, "unsupported range type"),
113 }
114 }
115 Pat::Wild(_) => true,
116 Pat::Paren(PatParen { pat, .. }) => tok_matches(expr, pat),
117 _ => {
118 emit_warning!(
119 pat,
120 "Currently supported are: literals, ranges, or-patterns, parenthesizeds, and wilds"
121 );
122 abort!(pat, format!("pattern type not supported"));
123 }
124 }
125}
126
127fn expr_as_int(expr: &Expr) -> i128 {
128 match expr {
129 Expr::Lit(lit) => exprlit_as_int(lit),
130 Expr::Unary(ExprUnary {
131 op: UnOp::Neg(_),
132 expr,
133 ..
134 }) => -expr_as_int(expr),
135 _ => abort!(
136 expr,
137 "only literal expressions (optionally negated) are supported"
138 ),
139 }
140}
141
142fn exprlit_as_int(exprlit: &ExprLit) -> i128 {
143 match &exprlit.lit {
144 Lit::Int(litint) => litint
145 .base10_parse()
146 .unwrap_or_else(|_| abort!(litint, "could not parse token to i128")),
147 _ => abort!(exprlit.lit, "only integer literals are supported"),
148 }
149}