finte_derive/
lib.rs

1use std::borrow::Cow;
2
3use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
4
5#[proc_macro_derive(IntEnum)]
6pub fn derive_int_enum(input: TokenStream) -> TokenStream {
7    let input = match parse_input(input) {
8        Ok(input) => input,
9        Err(err) => return err.to_compile_error(),
10    };
11
12    match generate_impl(&input) {
13        Ok(impls) => impls,
14        Err(err) => err.to_compile_error(),
15    }
16}
17
18fn parse_input(input: TokenStream) -> Result<DeriveInput, Error> {
19    let mut tt_iter = input.into_iter().peekable();
20
21    let mut repr = None;
22    loop {
23        match tt_iter.next().expect("`#` in `#[repr(_)]`") {
24            TokenTree::Punct(punct) => {
25                assert_eq!(punct.as_char(), '#');
26            }
27            TokenTree::Ident(ident) => match ident.to_string().as_str() {
28                "pub" => {
29                    if let Some(TokenTree::Group(_)) = tt_iter.peek() {
30                        tt_iter.next().expect("vis path");
31                    }
32                    continue;
33                }
34                "enum" => {
35                    break;
36                }
37                _ => {
38                    return Err(Error::new(
39                        ident.span(),
40                        "unsupported type: try using an enum instead",
41                    ));
42                }
43            },
44            _ => panic!("expect outer meta"),
45        }
46
47        let stream = match tt_iter.next().expect("`repr` in #[repr(_)]`") {
48            TokenTree::Group(group) => {
49                assert_eq!(group.delimiter(), Delimiter::Bracket);
50                group.stream()
51            }
52            _ => panic!("expect attr in #[repr(_)]"),
53        };
54
55        let mut tt_iter = stream.into_iter();
56        let repr_group_stream = match tt_iter.next().expect("attr") {
57            TokenTree::Ident(ident) => {
58                if ident.to_string() == "repr" {
59                    match tt_iter.next().expect("attr list") {
60                        TokenTree::Group(group) => {
61                            assert_eq!(group.delimiter(), Delimiter::Parenthesis);
62                            group.stream()
63                        }
64                        _ => {
65                            panic!("repr attr child should be group");
66                        }
67                    }
68                } else {
69                    continue;
70                }
71            }
72            _ => continue,
73        };
74
75        let mut tt_iter = repr_group_stream.into_iter();
76        match tt_iter.next().expect("repr type") {
77            TokenTree::Ident(ident) => {
78                if let Some(_) = repr.replace(ident) {
79                    panic!("more than repr found");
80                }
81            }
82            _ => {
83                panic!("repr child shuld be ident");
84            }
85        }
86        assert!(tt_iter.next().is_none(), "repr should have only child");
87    }
88
89    let repr = match repr {
90        Some(repr) => repr,
91        None => {
92            return Err(Error::new(
93                Span::call_site(),
94                "no #[repr(_)] found: try adding one to specify the type for `IntEnum::Int`",
95            ));
96        }
97    };
98
99    let name = match tt_iter.next().expect("enum name") {
100        TokenTree::Ident(ident) => ident,
101        tt => return Err(Error::new(tt.span(), "expect enum name")),
102    };
103
104    let enum_item_tt = tt_iter.next().expect("enum definition body");
105    let enum_item_group = match enum_item_tt {
106        TokenTree::Group(group) => {
107            assert_eq!(group.delimiter(), Delimiter::Brace);
108            group.stream()
109        }
110        _ => panic!("enum items should reside in a group"),
111    };
112
113    let variants = {
114        let mut variants = Vec::new();
115        let mut tt_iter = enum_item_group.into_iter();
116        loop {
117            let variant_name = match tt_iter.by_ref().find_map(|tt| match tt {
118                TokenTree::Ident(ident) => Some(ident),
119                _ => None,
120            }) {
121                Some(ident) => ident,
122                None => break,
123            };
124
125            match tt_iter.next().expect("`=`") {
126                TokenTree::Punct(punct) => match punct.as_char() {
127                    '=' => (),
128                    _ => return Err(Error::new(punct.span(), "expect discriminant")),
129                },
130                tt => return Err(Error::new(tt.span(), "expect Punct(_) after variant name")),
131            }
132
133            let variant_value = match tt_iter.next().expect("variant value") {
134                TokenTree::Literal(literal) => literal,
135                tt => return Err(Error::new(tt.span(), "expect discriminant value after `=`")),
136            };
137
138            variants.push((variant_name, variant_value));
139        }
140        variants
141    };
142
143    if variants.is_empty() {
144        Err(Error::new(Span::call_site(), "no variants"))
145    } else {
146        Ok(DeriveInput {
147            int: repr,
148            name,
149            variants,
150        })
151    }
152}
153
154fn generate_impl(input: &DeriveInput) -> Result<TokenStream, Error> {
155    use std::fmt::Write;
156
157    let mut ts = String::new();
158
159    writeln!(ts, "impl finte::IntEnum for {} {{", input.name).unwrap();
160
161    writeln!(ts, "    type Int = {};", input.int).unwrap();
162
163    {
164        writeln!(
165            ts,
166            "    fn try_from_int(value: Self::Int) -> std::result::Result::<Self, finte::TryFromIntError<Self>> {{"
167        )
168        .unwrap();
169        writeln!(ts, "        match value {{").unwrap();
170        for (name, value) in input.variants.iter() {
171            writeln!(ts, "            {} => Ok(Self::{}),", value, name).unwrap();
172        }
173        writeln!(
174            ts,
175            "            _ => Err(finte::TryFromIntError::new(value)),"
176        )
177        .unwrap();
178        writeln!(ts, "        }}").unwrap();
179        writeln!(ts, "    }}").unwrap();
180    }
181
182    {
183        writeln!(ts, "    fn int_value(&self) -> Self::Int {{").unwrap();
184        writeln!(ts, "        match self {{").unwrap();
185        for (name, value) in input.variants.iter() {
186            writeln!(ts, "            Self::{} => {},", name, value).unwrap();
187        }
188        writeln!(ts, "        }}").unwrap();
189        writeln!(ts, "    }}").unwrap();
190    }
191
192    writeln!(ts, "}}").unwrap();
193
194    Ok(ts.parse().unwrap())
195}
196
/// The pieces of a `#[derive(IntEnum)]` item needed to generate its `IntEnum` impl.
#[derive(Debug)]
struct DeriveInput {
    /// Type named in the enum's `#[repr(_)]` attribute; emitted as `type Int = ...`.
    int: Ident,
    /// The enum's identifier, used as the target of the generated `impl`.
    name: Ident,
    /// `(variant name, discriminant literal)` pairs, in the order the variants appear.
    variants: Vec<(Ident, Literal)>,
}
203
/// A diagnostic for the derive's user: a message anchored at a source span.
struct Error {
    span: Span,
    message: Cow<'static, str>,
}

impl Error {
    /// Creates an error reported at `span`.
    ///
    /// `message` may be borrowed (`&'static str`) or owned (`String`).
    fn new(span: Span, message: impl Into<Cow<'static, str>>) -> Self {
        let message = message.into();
        Error { span, message }
    }

    /// Renders the error as the tokens `compile_error!("<message>");`.
    ///
    /// Every top-level token, including the parenthesized group, carries `self.span`, so the
    /// diagnostic is attributed to that location.
    fn to_compile_error(&self) -> TokenStream {
        let args = TokenStream::from(TokenTree::from(Literal::string(&self.message)));
        let mut tokens = vec![
            TokenTree::from(Ident::new("compile_error", self.span)),
            TokenTree::from(Punct::new('!', Spacing::Alone)),
            TokenTree::from(Group::new(Delimiter::Parenthesis, args)),
            TokenTree::from(Punct::new(';', Spacing::Alone)),
        ];
        for token in tokens.iter_mut() {
            token.set_span(self.span);
        }
        tokens.into_iter().collect()
    }
}