osu_rs_derive/
lib.rs

1extern crate proc_macro;
2use std::collections::HashMap;
3
4use proc_macro2::{Span, TokenStream};
5use quote::{quote, ToTokens};
6use syn::{Data, DeriveInput, Fields, Ident, LitStr, Meta};
7
8fn derive_beatmap_section2(input: TokenStream) -> TokenStream {
9    let input: DeriveInput = syn::parse2(input).unwrap();
10    struct FieldInfo {
11        kind: FieldKind,
12        aliases: Vec<TokenStream>,
13    }
14    enum FieldKind {
15        None,
16        FromI32,
17    }
18    let mut fields = HashMap::<Ident, FieldInfo>::new();
19    let name = input.ident;
20    let generics = input.generics;
21    let Data::Struct(data) = input.data else {
22        panic!("#[derive(BeatmapSection)] is only allowed on structs")
23    };
24    let Fields::Named(named) = data.fields else {
25        panic!("#[derive(BeatmapSection)] is only allowed on named fields")
26    };
27    let mut extra_handler = None;
28    for attr in input.attrs {
29        if attr.meta.path().to_token_stream().to_string() != "beatmap_section" {
30            continue;
31        }
32        match attr.meta {
33            Meta::List(x) => {
34                extra_handler = Some(x.tokens);
35            }
36            _ => panic!("#[derive(BeatmapSection)]: non-Meta::List"),
37        }
38    }
39    for field in named.named {
40        let mut kind = FieldKind::None;
41        let mut aliases = vec![];
42        for attr in field.attrs {
43            match attr.meta.path().to_token_stream().to_string().as_str() {
44                "from" => match attr.meta {
45                    Meta::List(x) => {
46                        let mut tokens = x.tokens.into_iter();
47                        let ty = tokens.next();
48                        if let Some(ty) = ty {
49                            let ty = ty.to_string();
50                            assert_eq!(ty, "i32");
51                            kind = FieldKind::FromI32;
52                            assert!(tokens.next().is_none());
53                        }
54                    }
55                    _ => panic!("#[derive(BeatmapSection)]: non-Meta::List"),
56                },
57                "alias" => match attr.meta {
58                    Meta::List(x) => {
59                        aliases.push(x.tokens);
60                    }
61                    _ => panic!("#[derive(BeatmapSection)]: non-Meta::List"),
62                },
63                "doc" => {}
64                x => panic!("{x}"),
65            }
66        }
67        assert!(fields
68            .insert(
69                field.ident.expect("field ident"),
70                FieldInfo {
71                    kind: kind,
72                    aliases
73                }
74            )
75            .is_none());
76    }
77    let mut match_fields = TokenStream::new();
78    let mut valid_fields = TokenStream::new();
79    for (name, info) in fields {
80        let name_camel = name
81            .to_string()
82            .split('_')
83            .map(|x| {
84                let mut first = true;
85                if matches!(x, "id" | "hp") {
86                    x.to_uppercase()
87                } else {
88                    x.chars()
89                        .map(|x| {
90                            if first {
91                                first = false;
92                                x.to_ascii_uppercase()
93                            } else {
94                                x
95                            }
96                        })
97                        .collect::<String>()
98                }
99            })
100            .collect::<Vec<_>>()
101            .join("");
102        let lit = syn::LitStr::new(&name_camel, proc_macro2::Span::call_site()).into_token_stream();
103        for lit in [lit].into_iter().chain(info.aliases) {
104            match info.kind {
105                FieldKind::None => {
106                    match_fields.extend(quote! {
107                        #lit => {
108                            self.#name = ParseField::parse_field(#lit, ctx, value)?;
109                            return Ok(None);
110                        }
111                    });
112                }
113                FieldKind::FromI32 => {
114                    match_fields.extend(quote! {
115                        #lit => {
116                            self.#name = {
117                                TryFrom::try_from(i32::parse_field(#lit, ctx, value)?)
118                                    .map_err(ParseError::curry(#lit, value.span()))?
119                            };
120                            return Ok(None);
121                        }
122                    });
123                }
124            }
125            valid_fields.extend(quote! { #lit, });
126        }
127    }
128    let name_str = LitStr::new(&(name.to_string() + " section"), Span::call_site());
129    let default_handler = quote! {
130        {
131            return Err(ParseError::curry(#name_str, line.span())(RecordParseError {
132                valid_fields: &[#valid_fields],
133            }));
134        }
135    };
136    let extra_handler = if let Some(handler) = extra_handler {
137        quote! {
138            {
139                if let Some(ret) = #handler(key) {
140                    return Ok(Some(ret));
141                } else #default_handler
142            }
143        }
144    } else {
145        default_handler
146    };
147    quote! {
148        impl<'a> BeatmapSection<'a> for #name #generics {
149            fn consume_line(
150                &mut self,
151                ctx: &Context,
152                line: impl StaticCow<'a>,
153            ) -> Result<Option<Section>, ParseError> {
154                if let Some((key, value)) = line.split_once(':') {
155                    let key = key.trim();
156                    let value = value.trim();
157                    match key.as_ref() {
158                        #match_fields
159                        _ => #extra_handler
160                    }
161                } else {
162                    Err(ParseError::curry(#name_str, line.span())(InvalidRecordField))
163                }
164            }
165        }
166    }
167}
168
169fn derive_beatmap_enum2(input: TokenStream) -> TokenStream {
170    let input: DeriveInput = syn::parse2(input).unwrap();
171    let name = input.ident;
172    let mut ignore_case = false;
173    let mut from_char = false;
174    for attr in input.attrs {
175        if !matches!(
176            attr.meta.path().to_token_stream().to_string().as_str(),
177            "beatmap_enum"
178        ) {
179            continue;
180        }
181        match attr.meta {
182            Meta::List(x) => {
183                let mut tokens = x.tokens.into_iter();
184                let attr = tokens.next().unwrap().to_string();
185                assert!(tokens.next().is_none());
186                match attr.as_str() {
187                    "ignore_case" => {
188                        ignore_case = true;
189                    }
190                    "from_char" => {
191                        from_char = true;
192                    }
193                    x => panic!("#[derive(BeatmapEnum)]: unexpected attr contents {x}"),
194                }
195            }
196            _ => panic!("#[derive(BeatmapEnum)]: non-Meta::List"),
197        }
198    }
199    let Data::Enum(data) = input.data else {
200        panic!("#[derive(BeatmapEnum)] is only allowed on enums")
201    };
202    let mut match_fields = TokenStream::new();
203    let mut int_match_fields = TokenStream::new();
204    let mut char_match_fields = TokenStream::new();
205    let mut reverse_char_match = TokenStream::new();
206    let mut reverse_match = TokenStream::new();
207    let mut valid_variants = TokenStream::new();
208    let mut valid_int_variants = TokenStream::new();
209    let mut valid_char_variants = TokenStream::new();
210    let name0 = &name;
211    for field in data.variants {
212        assert!(field.fields.is_empty());
213        let name = field.ident;
214        let mut name1 = name.to_string();
215        if ignore_case {
216            name1.make_ascii_lowercase();
217        }
218        let lit = syn::LitStr::new(&name1, proc_macro2::Span::call_site());
219        match_fields.extend(quote! {
220            #lit => Ok(Self::#name),
221        });
222        reverse_match.extend(quote! {
223            #name0::#name => #lit,
224        });
225        valid_variants.extend(quote! {
226            #lit,
227        });
228        let char_lit = syn::LitChar::new(
229            name1.chars().next().unwrap(),
230            proc_macro2::Span::call_site(),
231        );
232        char_match_fields.extend(quote! {
233            #char_lit => Ok(Self::#name),
234        });
235        reverse_char_match.extend(quote! {
236            #name0::#name => #char_lit,
237        });
238        valid_char_variants.extend(quote! {
239            #char_lit,
240        });
241        if let Some((_, discrim)) = field.discriminant {
242            int_match_fields.extend(quote! {
243                #discrim => Ok(Self::#name),
244            });
245            valid_int_variants.extend(quote! {
246                #discrim,
247            });
248            let str_discrim = discrim.into_token_stream().to_string();
249            let lit = syn::LitStr::new(&str_discrim, proc_macro2::Span::call_site());
250            match_fields.extend(quote! {
251                #lit => Ok(Self::#name),
252            });
253        }
254    }
255    let scrutinee = if ignore_case {
256        quote! {
257            s.to_lowercase().as_str()
258        }
259    } else {
260        quote! {
261            s
262        }
263    };
264    let mut extra = quote! {};
265    if from_char {
266        extra.extend(quote! {
267            impl std::convert::TryFrom<char> for #name {
268                type Error = CharEnumParseError;
269                fn try_from(x: char) -> Result<Self, Self::Error> {
270                    match x {
271                        #char_match_fields
272                        _ => Err(CharEnumParseError {
273                            variant: x,
274                            valid_variants: &[#valid_char_variants],
275                        }),
276                    }
277                }
278            }
279            impl From<#name> for char {
280                fn from(x: #name) -> Self {
281                    match x {
282                        #reverse_char_match
283                    }
284                }
285            }
286        });
287    }
288    quote! {
289        #extra
290        impl std::fmt::Display for #name {
291            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
292                f.write_str(match self {
293                    #reverse_match
294                })
295            }
296        }
297        impl std::convert::TryFrom<i32> for #name {
298            type Error = IntEnumParseError;
299            fn try_from(x: i32) -> Result<Self, Self::Error> {
300                match x {
301                    #int_match_fields
302                    _ => Err(IntEnumParseError {
303                        variant: x,
304                        valid_variants: &[#valid_int_variants],
305                    }),
306                }
307            }
308        }
309        impl std::str::FromStr for #name {
310            type Err = EnumParseError;
311            fn from_str(s: &str) -> Result<Self, Self::Err> {
312                match #scrutinee {
313                    #match_fields
314                    _ => Err(EnumParseError {
315                        valid_variants: &[#valid_variants],
316                    }),
317                }
318            }
319        }
320        impl<'a> ParseField<'a> for #name {
321            fn parse_field(
322                name: impl Into<Cow<'static, str>>,
323                _ctx: &Context,
324                line: impl StaticCow<'a>,
325            ) -> Result<Self, ParseError> {
326                line.as_ref().parse().map_err(ParseError::curry(name, line.span()))
327            }
328        }
329    }
330}
331
332#[proc_macro_derive(BeatmapSection, attributes(from, alias, beatmap_section))]
333pub fn derive_beatmap_section(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
334    derive_beatmap_section2(input.into()).into()
335}
336
337#[proc_macro_derive(BeatmapEnum, attributes(beatmap_enum))]
338pub fn derive_beatmap_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
339    derive_beatmap_enum2(input.into()).into()
340}