1use std::borrow::Cow;
2
3use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
4
5#[proc_macro_derive(IntEnum)]
6pub fn derive_int_enum(input: TokenStream) -> TokenStream {
7 let input = match parse_input(input) {
8 Ok(input) => input,
9 Err(err) => return err.to_compile_error(),
10 };
11
12 match generate_impl(&input) {
13 Ok(impls) => impls,
14 Err(err) => err.to_compile_error(),
15 }
16}
17
18fn parse_input(input: TokenStream) -> Result<DeriveInput, Error> {
19 let mut tt_iter = input.into_iter().peekable();
20
21 let mut repr = None;
22 loop {
23 match tt_iter.next().expect("`#` in `#[repr(_)]`") {
24 TokenTree::Punct(punct) => {
25 assert_eq!(punct.as_char(), '#');
26 }
27 TokenTree::Ident(ident) => match ident.to_string().as_str() {
28 "pub" => {
29 if let Some(TokenTree::Group(_)) = tt_iter.peek() {
30 tt_iter.next().expect("vis path");
31 }
32 continue;
33 }
34 "enum" => {
35 break;
36 }
37 _ => {
38 return Err(Error::new(
39 ident.span(),
40 "unsupported type: try using an enum instead",
41 ));
42 }
43 },
44 _ => panic!("expect outer meta"),
45 }
46
47 let stream = match tt_iter.next().expect("`repr` in #[repr(_)]`") {
48 TokenTree::Group(group) => {
49 assert_eq!(group.delimiter(), Delimiter::Bracket);
50 group.stream()
51 }
52 _ => panic!("expect attr in #[repr(_)]"),
53 };
54
55 let mut tt_iter = stream.into_iter();
56 let repr_group_stream = match tt_iter.next().expect("attr") {
57 TokenTree::Ident(ident) => {
58 if ident.to_string() == "repr" {
59 match tt_iter.next().expect("attr list") {
60 TokenTree::Group(group) => {
61 assert_eq!(group.delimiter(), Delimiter::Parenthesis);
62 group.stream()
63 }
64 _ => {
65 panic!("repr attr child should be group");
66 }
67 }
68 } else {
69 continue;
70 }
71 }
72 _ => continue,
73 };
74
75 let mut tt_iter = repr_group_stream.into_iter();
76 match tt_iter.next().expect("repr type") {
77 TokenTree::Ident(ident) => {
78 if let Some(_) = repr.replace(ident) {
79 panic!("more than repr found");
80 }
81 }
82 _ => {
83 panic!("repr child shuld be ident");
84 }
85 }
86 assert!(tt_iter.next().is_none(), "repr should have only child");
87 }
88
89 let repr = match repr {
90 Some(repr) => repr,
91 None => {
92 return Err(Error::new(
93 Span::call_site(),
94 "no #[repr(_)] found: try adding one to specify the type for `IntEnum::Int`",
95 ));
96 }
97 };
98
99 let name = match tt_iter.next().expect("enum name") {
100 TokenTree::Ident(ident) => ident,
101 tt => return Err(Error::new(tt.span(), "expect enum name")),
102 };
103
104 let enum_item_tt = tt_iter.next().expect("enum definition body");
105 let enum_item_group = match enum_item_tt {
106 TokenTree::Group(group) => {
107 assert_eq!(group.delimiter(), Delimiter::Brace);
108 group.stream()
109 }
110 _ => panic!("enum items should reside in a group"),
111 };
112
113 let variants = {
114 let mut variants = Vec::new();
115 let mut tt_iter = enum_item_group.into_iter();
116 loop {
117 let variant_name = match tt_iter.by_ref().find_map(|tt| match tt {
118 TokenTree::Ident(ident) => Some(ident),
119 _ => None,
120 }) {
121 Some(ident) => ident,
122 None => break,
123 };
124
125 match tt_iter.next().expect("`=`") {
126 TokenTree::Punct(punct) => match punct.as_char() {
127 '=' => (),
128 _ => return Err(Error::new(punct.span(), "expect discriminant")),
129 },
130 tt => return Err(Error::new(tt.span(), "expect Punct(_) after variant name")),
131 }
132
133 let variant_value = match tt_iter.next().expect("variant value") {
134 TokenTree::Literal(literal) => literal,
135 tt => return Err(Error::new(tt.span(), "expect discriminant value after `=`")),
136 };
137
138 variants.push((variant_name, variant_value));
139 }
140 variants
141 };
142
143 if variants.is_empty() {
144 Err(Error::new(Span::call_site(), "no variants"))
145 } else {
146 Ok(DeriveInput {
147 int: repr,
148 name,
149 variants,
150 })
151 }
152}
153
154fn generate_impl(input: &DeriveInput) -> Result<TokenStream, Error> {
155 use std::fmt::Write;
156
157 let mut ts = String::new();
158
159 writeln!(ts, "impl finte::IntEnum for {} {{", input.name).unwrap();
160
161 writeln!(ts, " type Int = {};", input.int).unwrap();
162
163 {
164 writeln!(
165 ts,
166 " fn try_from_int(value: Self::Int) -> std::result::Result::<Self, finte::TryFromIntError<Self>> {{"
167 )
168 .unwrap();
169 writeln!(ts, " match value {{").unwrap();
170 for (name, value) in input.variants.iter() {
171 writeln!(ts, " {} => Ok(Self::{}),", value, name).unwrap();
172 }
173 writeln!(
174 ts,
175 " _ => Err(finte::TryFromIntError::new(value)),"
176 )
177 .unwrap();
178 writeln!(ts, " }}").unwrap();
179 writeln!(ts, " }}").unwrap();
180 }
181
182 {
183 writeln!(ts, " fn int_value(&self) -> Self::Int {{").unwrap();
184 writeln!(ts, " match self {{").unwrap();
185 for (name, value) in input.variants.iter() {
186 writeln!(ts, " Self::{} => {},", name, value).unwrap();
187 }
188 writeln!(ts, " }}").unwrap();
189 writeln!(ts, " }}").unwrap();
190 }
191
192 writeln!(ts, "}}").unwrap();
193
194 Ok(ts.parse().unwrap())
195}
196
/// The parsed shape of an enum deriving `IntEnum`, as produced by `parse_input`.
#[derive(Debug)]
struct DeriveInput {
    /// The type named in the enum's `#[repr(_)]`; emitted as `type Int = ...`.
    int: Ident,
    /// The enum's name; the type the impl is generated for.
    name: Ident,
    /// `(variant name, discriminant literal)` pairs, in declaration order.
    variants: Vec<(Ident, Literal)>,
}
203
/// A diagnostic produced while parsing the derive input or generating the impl.
struct Error {
    /// Where the diagnostic should point in the user's source.
    span: Span,
    /// The text shown to the user.
    message: Cow<'static, str>,
}

impl Error {
    /// Builds an error located at `span` carrying `message`.
    fn new(span: Span, message: impl Into<Cow<'static, str>>) -> Self {
        let message = message.into();
        Error { span, message }
    }

    /// Expands to `compile_error!("<message>");`.
    ///
    /// Each top-level token of the invocation carries `self.span`, so rustc
    /// reports the message at that location. The string literal inside the
    /// parentheses keeps its default (call-site) span.
    fn to_compile_error(&self) -> TokenStream {
        let span = self.span;
        let args: TokenStream = TokenTree::from(Literal::string(&self.message)).into();

        let tokens: Vec<TokenTree> = vec![
            TokenTree::Ident(Ident::new("compile_error", span)),
            TokenTree::Punct(Punct::new('!', Spacing::Alone)),
            TokenTree::Group(Group::new(Delimiter::Parenthesis, args)),
            TokenTree::Punct(Punct::new(';', Spacing::Alone)),
        ];

        tokens
            .into_iter()
            .map(|mut tt| {
                tt.set_span(span);
                tt
            })
            .collect()
    }
}