midi_toolkit_rs_derive/
lib.rs

1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
4use proc_macro_error::{abort_call_site, proc_macro_error, ResultExt};
5use quote::{quote, ToTokens};
6use syn::{self, ext::IdentExt, Attribute, DataEnum, DataStruct, DeriveInput, Fields, Variant};
7
8fn has_attr(attrs: &[Attribute], name: &str) -> bool {
9    attrs.iter().any(|a| match a.path.get_ident() {
10        None => false,
11        Some(ident) => ident.unraw().to_string().eq(name),
12    })
13}
14
15fn find_attr_fields<'a>(fields: &'a Fields, name: &str) -> Option<&'a Ident> {
16    let fields = fields
17        .iter()
18        .filter(|f| has_attr(&f.attrs, name))
19        .collect::<Vec<_>>();
20    match fields.len() {
21        0 => None,
22        1 => fields[0].ident.as_ref(),
23        _ => abort_call_site!(format!("Multiple fields found with attribute #[{name}]")),
24    }
25}
26
27#[proc_macro_derive(MIDIEvent, attributes(key, channel, playback))]
28#[proc_macro_error]
29pub fn midi_event(input: TokenStream) -> TokenStream {
30    let ast: DeriveInput = syn::parse(input).expect_or_abort("Couldn't parse for MIDIEvent");
31
32    let name = &ast.ident;
33    let generics = &ast.generics;
34    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
35
36    // Is it a struct?
37    if let syn::Data::Struct(DataStruct { ref fields, .. }) = ast.data {
38        let key_field = find_attr_fields(fields, "key");
39        let channel_field = find_attr_fields(fields, "channel");
40
41        let playback_event = has_attr(&ast.attrs, "playback");
42
43        if key_field.is_some() && channel_field.is_none() {
44            abort_call_site!(
45                "Key events must also have a channel (use #[channel] along with #[key])!"
46            );
47        }
48
49        let mut generated_impl = Vec::new();
50        let mut generated_trait_impl = Vec::new();
51        let mut generated_traits = Vec::new();
52
53        if playback_event {
54            generated_impl.push(quote! {
55                #[inline(always)]
56                pub fn as_u32(&self) -> u32 {
57                    PlaybackEvent::as_u32(self)
58                }
59            });
60
61            generated_trait_impl.push(quote! {
62                #[inline(always)]
63                fn as_u32(&self) -> Option<u32> {
64                    Some(PlaybackEvent::as_u32(self))
65                }
66            });
67        } else {
68            generated_trait_impl.push(quote! {
69                #[inline(always)]
70                fn as_u32(&self) -> Option<u32> {
71                    None
72                }
73            });
74        }
75
76        match key_field {
77            None => {
78                generated_trait_impl.push(quote! {
79                    #[inline(always)]
80                    fn key(&self) -> Option<u8> {
81                        None
82                    }
83
84                    #[inline(always)]
85                    fn key_mut(&mut self) -> Option<&mut u8> {
86                        None
87                    }
88                });
89            }
90            Some(ident) => {
91                generated_impl.push(quote! {
92                    #[inline(always)]
93                    pub fn key(&self) -> u8 {
94                        self.#ident
95                    }
96
97                    #[inline(always)]
98                    pub fn key_mut(&mut self) -> &mut u8 {
99                        &mut self.#ident
100                    }
101                });
102
103                generated_trait_impl.push(quote! {
104                    #[inline(always)]
105                    fn key(&self) -> Option<u8> {
106                        Some(self.#ident)
107                    }
108
109                    #[inline(always)]
110                    fn key_mut(&mut self) -> Option<&mut u8> {
111                        Some(&mut self.#ident)
112                    }
113                });
114
115                generated_traits.push(quote! {
116                    impl #impl_generics KeyEvent for #name #ty_generics #where_clause {
117                        #[inline(always)]
118                        fn key(&self) -> u8 {
119                            self.#ident
120                        }
121
122                        #[inline(always)]
123                        fn key_mut(&mut self) -> &mut u8 {
124                            &mut self.#ident
125                        }
126                    }
127                });
128            }
129        }
130
131        match channel_field {
132            None => {
133                generated_trait_impl.push(quote! {
134                    #[inline(always)]
135                    fn channel(&self) -> Option<u8> {
136                        None
137                    }
138
139                    #[inline(always)]
140                    fn channel_mut(&mut self) -> Option<&mut u8> {
141                        None
142                    }
143                });
144            }
145            Some(ident) => {
146                generated_impl.push(quote! {
147                    #[inline(always)]
148                    pub fn channel(&self) -> u8 {
149                        self.#ident
150                    }
151
152                    #[inline(always)]
153                    pub fn channel_mut(&mut self) -> &mut u8 {
154                        &mut self.#ident
155                    }
156                });
157
158                generated_trait_impl.push(quote! {
159                    #[inline(always)]
160                    fn channel(&self) -> Option<u8> {
161                        Some(self.#ident)
162                    }
163
164                    #[inline(always)]
165                    fn channel_mut(&mut self) -> Option<&mut u8> {
166                        Some(&mut self.#ident)
167                    }
168                });
169
170                generated_traits.push(quote! {
171                    impl #impl_generics ChannelEvent for #name #ty_generics #where_clause {
172                        #[inline(always)]
173                        fn channel(&self) -> u8 {
174                            self.#ident
175                        }
176
177                        #[inline(always)]
178                        fn channel_mut(&mut self) -> &mut u8 {
179                            &mut self.#ident
180                        }
181                    }
182                });
183            }
184        }
185
186        let gen = quote! {
187            #(#generated_traits)*
188
189            impl MIDIEvent for #name #where_clause {
190                #(#generated_trait_impl)*
191            }
192
193            impl #impl_generics #name #ty_generics #where_clause {
194                #(#generated_impl)*
195            }
196        };
197
198        gen.into()
199    } else {
200        // Nope. This is an Enum.
201        abort_call_site!("#[derive(MIDIEvent)] is only defined for structs, not for enums!");
202    }
203}
204
205fn event_enum_from_struct(name: &Ident) -> Ident {
206    let event_name = name.unraw().to_string();
207    let event_name = &event_name[..event_name.len() - 5];
208    Ident::new(event_name, name.span())
209}
210
211fn event_struct_from_enum(name: &Ident) -> Ident {
212    let event_name = name.unraw().to_string();
213    let event_name = event_name + "Event";
214    Ident::new(&event_name[..], name.span())
215}
216
217#[proc_macro_derive(NewEvent)]
218#[proc_macro_error]
219pub fn create_new_event(input: TokenStream) -> TokenStream {
220    let ast: DeriveInput = syn::parse(input).expect_or_abort("Couldn't parse for NewEvent");
221
222    let name = &ast.ident;
223    let generics = &ast.generics;
224    let (impl_generics, _ty_generics, where_clause) = generics.split_for_impl();
225
226    // Is it a struct?
227    if let syn::Data::Struct(DataStruct { ref fields, .. }) = ast.data {
228        let mut new_args = Vec::new();
229        let mut assign = Vec::new();
230
231        let event_ident = event_enum_from_struct(name);
232        let snake_case = name.unraw().to_string()[..].to_case(Case::Snake);
233        let new_ident = Ident::new(&format!("new_{snake_case}")[..], Span::call_site());
234        let new_delta_ident = Ident::new(&format!("new_delta_{snake_case}")[..], Span::call_site());
235
236        let doc_str = &format!("Creates a new `{name}`.");
237        let doc_str2 = &format!(
238            "Creates a new [`{name}`](crate::events::{name}) wrapped in [`Event::{ident}`](crate::events::Event::{ident}).",
239            ident = event_ident.unraw(),
240        );
241        let doc_str2_delta = &format!(
242            "Creates a new [`{name}`](crate::events::{name}) wrapped in [`Event::{ident}`](crate::events::Event::{ident}).",
243            ident = event_ident.unraw(),
244        );
245
246        for field in fields.iter() {
247            match &field.ident {
248                None => {}
249                Some(ident) => {
250                    let ty = &field.ty;
251                    new_args.push(quote! {#ident: #ty,});
252                    assign.push(quote! {#ident,});
253                }
254            }
255        }
256
257        let gen = quote! {
258            impl #impl_generics #name #where_clause {
259                #[doc=#doc_str]
260                #[inline(always)]
261                pub fn new(#(#new_args)*) -> Self {
262                    Self {
263                        #(#assign)*
264                    }
265                }
266            }
267
268            impl Event {
269                #[doc=#doc_str2]
270                #[inline(always)]
271                pub fn #new_ident(#(#new_args)*) -> Event {
272                    (#name :: new(#(#assign)*)).as_event()
273                }
274
275                #[doc=#doc_str2_delta]
276                #[inline(always)]
277                pub fn #new_delta_ident<D: MIDINum>(delta: D, #(#new_args)*) -> Delta<D, Event> {
278                    Delta::new(delta, (#name :: new(#(#assign)*)).as_event())
279                }
280            }
281        };
282
283        gen.into()
284    } else {
285        // Nope. This is an Enum.
286        abort_call_site!("#[derive(MIDIEvent)] is only defined for structs, not for enums!");
287    }
288}
289
290#[proc_macro_derive(EventImpl, attributes(channel, key, playback, tempo))]
291#[proc_macro_error]
292pub fn event_impl(input: TokenStream) -> TokenStream {
293    let ast: DeriveInput = syn::parse(input).expect_or_abort("Couldn't parse for EventImpl");
294
295    let name = &ast.ident;
296    let generics = &ast.generics;
297    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
298
299    // Is it a struct?
300    if let syn::Data::Enum(DataEnum { variants, .. }) = ast.data {
301        fn has_attr(v: &Variant, name: &str) -> bool {
302            v.attrs.iter().any(|a| match a.path.get_ident() {
303                None => false,
304                Some(ident) => ident.unraw().to_string().eq(name),
305            })
306        }
307        fn is_key(v: &Variant) -> bool {
308            has_attr(v, "key")
309        }
310        fn is_channel(v: &Variant) -> bool {
311            has_attr(v, "channel")
312        }
313        fn is_playback(v: &Variant) -> bool {
314            has_attr(v, "playback")
315        }
316
317        fn match_all(lines: Vec<TokenStream2>) -> TokenStream2 {
318            quote! {
319                match self {
320                    #(#lines)*
321                }
322            }
323        }
324
325        fn make_match_line(ident: &Ident, res: TokenStream2) -> TokenStream2 {
326            quote! {
327                Event::#ident(event) => #res,
328            }
329        }
330
331        fn is_boxed(variant: &Variant) -> bool {
332            let field = variant.fields.iter().next().unwrap();
333            let mut tokens = TokenStream2::new();
334            field.ty.to_tokens(&mut tokens);
335            tokens.to_string().starts_with("Box <")
336        }
337
338        trait Mapper {
339            fn wrap(&self, tokens: TokenStream2) -> TokenStream2;
340        }
341
342        struct WrapSome;
343        impl Mapper for WrapSome {
344            fn wrap(&self, tokens: TokenStream2) -> TokenStream2 {
345                quote! { Some(#tokens) }
346            }
347        }
348
349        struct DontWrap;
350        impl Mapper for DontWrap {
351            fn wrap(&self, _: TokenStream2) -> TokenStream2 {
352                quote! { None }
353            }
354        }
355
356        struct Mappers<'a> {
357            wrap_key: &'a dyn Mapper,
358            wrap_channel: &'a dyn Mapper,
359            wrap_playback: &'a dyn Mapper,
360        }
361
362        fn create_match<T: Fn(&Variant, Mappers) -> TokenStream2>(
363            variants: &[&Variant],
364            map: T,
365        ) -> TokenStream2 {
366            let wrap_some: Box<dyn Mapper> = Box::new(WrapSome);
367            let dont_wrap: Box<dyn Mapper> = Box::new(DontWrap);
368
369            match_all(
370                variants
371                    .iter()
372                    .map(|v| {
373                        map(
374                            v,
375                            Mappers {
376                                wrap_key: if is_key(v) { &*wrap_some } else { &*dont_wrap },
377                                wrap_channel: if is_channel(v) {
378                                    &*wrap_some
379                                } else {
380                                    &*dont_wrap
381                                },
382                                wrap_playback: if is_playback(v) {
383                                    &*wrap_some
384                                } else {
385                                    &*dont_wrap
386                                },
387                            },
388                        )
389                    })
390                    .collect::<Vec<_>>(),
391            )
392        }
393
394        let variants = variants.iter().collect::<Vec<_>>();
395
396        let clone_match = create_match(&variants, |v, _mappers| {
397            let ident = &v.ident;
398            if is_boxed(v) {
399                make_match_line(
400                    ident,
401                    quote! { Event::#ident(Box::new(event.as_ref().clone())) },
402                )
403            } else {
404                make_match_line(ident, quote! { Event::#ident(event.clone()) })
405            }
406        });
407
408        macro_rules! make_map {
409            ($res:expr) => {
410                create_match(&variants, |v, _| make_match_line(&v.ident, $res))
411            };
412            (key, $res:expr) => {
413                create_match(&variants, |v, Mappers { wrap_key, .. }| {
414                    make_match_line(&v.ident, wrap_key.wrap($res))
415                })
416            };
417            (channel, $res:expr) => {
418                create_match(&variants, |v, Mappers { wrap_channel, .. }| {
419                    make_match_line(&v.ident, wrap_channel.wrap($res))
420                })
421            };
422            (playback, $res:expr) => {
423                create_match(&variants, |v, Mappers { wrap_playback, .. }| {
424                    make_match_line(&v.ident, wrap_playback.wrap($res))
425                })
426            };
427        }
428
429        let key = make_map!(key, quote! { event.key() });
430        let key_mut = make_map!(key, quote! { event.key_mut() });
431        let channel = make_map!(channel, quote! { event.channel() });
432        let channel_mut = make_map!(channel, quote! { event.channel_mut() });
433        let as_u32 = make_map!(playback, quote! { event.as_u32() });
434
435        let serialize_event = make_map!(quote! { event.serialize_event(buf) });
436
437        let mut event_wrap_impl = Vec::new();
438        for variant in variants.iter() {
439            let ident = &variant.ident;
440            let struct_ident = event_struct_from_enum(ident);
441            let doc_str = &format!(
442                "Wraps the `{}` in a `Event::{}`.",
443                struct_ident.unraw(),
444                ident.unraw()
445            );
446            if is_boxed(variant) {
447                event_wrap_impl.push(quote! {
448                    impl #impl_generics #struct_ident #ty_generics {
449                        #[doc=#doc_str]
450                        #[inline(always)]
451                        pub fn as_event(self) -> #name #ty_generics {
452                            #name::#ident(Box::new(self))
453                        }
454                    }
455                });
456            } else {
457                event_wrap_impl.push(quote! {
458                    impl #impl_generics #struct_ident #ty_generics {
459                        #[doc=#doc_str]
460                        #[inline(always)]
461                        pub fn as_event(self) -> #name #ty_generics {
462                            #name::#ident(self)
463                        }
464                    }
465                });
466            }
467        }
468
469        let gen = quote! {
470            impl #impl_generics Clone for #name #ty_generics #where_clause {
471                #[inline(always)]
472                fn clone(&self) -> #name #ty_generics {
473                    #clone_match
474                }
475            }
476
477            impl#impl_generics MIDIEvent #ty_generics for #name #ty_generics #where_clause {
478                #[inline(always)]
479                fn key(&self) -> Option<u8> {
480                    #key
481                }
482
483                #[inline(always)]
484                fn key_mut(&mut self) -> Option<&mut u8> {
485                    #key_mut
486                }
487
488                #[inline(always)]
489                fn channel(&self) -> Option<u8> {
490                    #channel
491                }
492
493                #[inline(always)]
494                fn channel_mut(&mut self) -> Option<&mut u8> {
495                    #channel_mut
496                }
497
498                #[inline(always)]
499                fn as_u32(&self) -> Option<u32> {
500                    #as_u32
501                }
502            }
503
504            #(#event_wrap_impl)*
505
506            impl SerializeEvent for Event {
507                #[inline(always)]
508                fn serialize_event<T: Write>(&self, buf: &mut T) -> Result<usize, MIDIWriteError> {
509                    #serialize_event
510                }
511            }
512        };
513
514        gen.into()
515    } else {
516        // Nope. This is an Enum.
517        abort_call_site!("#[derive(MIDIEvent)] is only defined for structs, not for enums!");
518    }
519}