bitsong_macros/
lib.rs

1#![allow(unused_imports)]
2use std::collections::HashMap;
3
4use proc_macro::TokenStream;
5use quote::ToTokens;
6use syn::{parse::{self, Parse}, punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, DataEnum, DataStruct, Error, Field, Fields, Ident, ItemEnum, ItemStruct, Meta, Token, Variant, Visibility};
7
8use syn::{parse::Parser, parse_macro_input, DeriveInput};
9use quote::quote;
10use proc_macro2::{Delimiter, Span, TokenStream as TokenStream2, TokenTree};
11use anyhow::bail;
12
13// Note, this won't work in downstream crates.
14#[proc_macro_derive(SpiError)]
15pub fn derive_spi_error(item: TokenStream) -> TokenStream {
16    let item: DeriveInput = parse_macro_input!(item);
17    let name = item.ident;
18    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
19
20    quote! {
21        impl #impl_generics embedded_hal_async::spi::ErrorType for #name #ty_generics #where_clause {
22            type Error = crate::spi::SpiError;
23        }
24    }.into()
25}
26
// TODO: the derive entry points below (and their struct/enum helpers) repeat much of the same code; consider factoring the shared parts out.
28
29#[proc_macro_derive(SongSize, attributes(song))]
30pub fn derive_song_size(tok: TokenStream) -> TokenStream {
31    let tok1 = tok.clone();
32    let item: DeriveInput = parse_macro_input!(tok1);
33
34    match item.data {
35        syn::Data::Struct(ref s) => derive_song_size_struct(&item, s).unwrap().into(),
36        syn::Data::Enum(_) => {
37            let enum_song: EnumSong = parse_macro_input!(tok);
38            derive_song_size_enum(&enum_song).unwrap().into()
39        },
40        _ => todo!()
41    }
42}
43
44#[proc_macro_derive(ToSong, attributes(song))]
45pub fn derive_to_song(tok: TokenStream) -> TokenStream {
46    let tok1 = tok.clone();
47    let item: DeriveInput = parse_macro_input!(tok1);
48
49    match item.data {
50        syn::Data::Struct(ref s) => derive_to_song_struct(&item, &s).unwrap().into(),
51        syn::Data::Enum(_) => {
52            let enum_song: EnumSong = parse_macro_input!(tok);
53            derive_to_song_enum(&enum_song).unwrap().into()
54        },
55        _ => todo!()
56    }
57}
58
59#[proc_macro_derive(FromSong, attributes(song))]
60pub fn derive_from_song(tok: TokenStream) -> TokenStream {
61    let tok1 = tok.clone();
62    let item: DeriveInput = parse_macro_input!(tok1);
63
64    match item.data {
65        syn::Data::Struct(ref s) => derive_from_song_struct(&item, s).unwrap().into(),
66        syn::Data::Enum(_) => {
67            let enum_song: EnumSong = parse_macro_input!(tok);
68            derive_from_song_enum(&enum_song).unwrap().into()
69        },
70        _ => todo!()
71    }
72}
73
74enum EnumSongDisc {
75    Enum {
76        disc_name: Ident,
77        disc_type: Ident,
78    },
79    Repr {
80        ty: Ident
81    }
82}
83
84struct EnumSong {
85    item: ItemEnum,
86    disc: EnumSongDisc
87}
88
89impl Parse for EnumSong {
90    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
91        let item: ItemEnum = input.parse()?;
92
93        let attrs: Vec<TokenStream2> = item.attrs.iter().filter_map(|a| {
94            if a.path().get_ident()?.to_string() != "song" {
95                return None;
96            }
97    
98            let Meta::List(ref list) = a.meta else {
99                return None
100            };
101    
102            Some(list.tokens.clone())
103        }).collect();
104
105        let repr = item.attrs.iter()
106            .filter(|attr| attr.path().to_token_stream().to_string() == "repr")
107            .filter_map(|attr| attr.parse_args::<Ident>().ok())
108            .next();
109
110        if let Some(repr) = repr {
111            if attrs.len() != 0 {
112                return Err(Error::new(
113                    attrs.get(0).map(|i| i.span()).unwrap_or_else(|| item.span()),
114                    "Expected no attrs"
115                ));                
116            }
117
118            return Ok(EnumSong {
119                item,
120                disc: EnumSongDisc::Repr {
121                    ty: repr
122                }
123            })
124        }
125
126        if attrs.len() != 1 {
127            return Err(Error::new(
128                attrs.get(0).map(|i| i.span()).unwrap_or_else(|| item.span()),
129                "Expected 1 attr -- discriminant"
130            ));
131        }
132
133        let mut attr = attrs[0].clone().into_iter();
134
135        let disc_ident: Ident = syn::parse2(attr.next().unwrap().to_token_stream())?;
136        if disc_ident.to_string() != "discriminant" {
137            return Err(Error::new_spanned(disc_ident, "Expected 'discriminant'"));
138        }
139
140        let group: proc_macro2::Group = syn::parse2(attr.next().unwrap().to_token_stream())?;
141        if group.delimiter() != Delimiter::Parenthesis {
142            return Err(Error::new(group.delim_span().span(), "Expected parens"));
143        }
144
145        let mut iter = group.stream().into_token_stream().into_iter();
146        let disc_name: Ident = syn::parse2(iter.next().unwrap().into_token_stream())?;
147        let _: Token![=] = syn::parse2(iter.next().unwrap().into_token_stream())?;
148        let disc_type: Ident = syn::parse2(iter.next().unwrap().into_token_stream())?;
149
150        /*if disc_type.to_string() != "u8" {
151            return Err(Error::new(disc_type.span(), "Only u8 supported rn"));
152        }*/
153
154        Ok(EnumSong {
155            item,
156            disc: EnumSongDisc::Enum {
157                disc_name,
158                disc_type,
159            }
160        })
161    }
162}
163
164fn derive_song_size_struct(item: &DeriveInput, data: &DataStruct) -> Result<TokenStream2, anyhow::Error> {
165    let mut size_out = vec![];
166
167    let mut has_song_size = quote! {()};
168
169    for field in &data.fields {
170        let Field { ident, ty, .. } = field;
171
172        size_out.push(quote! {
173            self.#ident.song_size()
174        });
175
176        has_song_size = quote! { (#has_song_size, ConstSongSizeImplFromConstSongSize<#ty>) }
177    }
178
179    let ident = &item.ident;
180    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
181    
182    Ok(quote! {
183        #[automatically_derived]
184        impl #impl_generics HasSongSize for #ident #ty_generics #where_clause {
185            type Size = #has_song_size;
186        }
187
188        #[automatically_derived]
189        impl #impl_generics SongSize for #ident #ty_generics #where_clause {
190            fn song_size(&self) -> usize {
191                0 #( + #size_out)*
192            }
193        }
194    })
195}
196
197fn derive_to_song_struct(item: &DeriveInput, data: &DataStruct) -> Result<TokenStream2, anyhow::Error> {
198    let mut fields_out = vec![];
199
200    for field in &data.fields {
201        let ident = &field.ident;
202
203        fields_out.push(quote! {
204            self.#ident.to_song(&mut buf[i..])?;
205            i += self.#ident.song_size();
206        });
207    }
208
209    let ident = &item.ident;
210    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
211
212    Ok(quote! {
213        impl #impl_generics ToSong for #ident #ty_generics #where_clause {
214            fn to_song(&self, buf: &mut [u8]) -> Result<(), ToSongError> {
215                let size = self.song_size();
216                if buf.len() >= size {
217                    let mut i = 0;
218                    #(#fields_out)*
219                    Ok(())
220                } else {
221                    Err(ToSongError::BufferOverflow)
222                }
223            }
224        }
225    })
226}
227
228fn derive_from_song_struct(item: &DeriveInput, data: &DataStruct) -> Result<TokenStream2, anyhow::Error> {
229    let mut from_song_out = vec![];
230    
231    for field in &data.fields {
232        let Field { ident, ty, .. } = field;
233
234        from_song_out.push(quote! {
235            #ident: {
236                let value = <#ty as FromSong>::from_song(&buf[i..])?;
237                i += value.song_size();
238                value
239            }
240        });
241    }
242
243    let ident = &item.ident;
244    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
245
246    Ok(quote! {
247        impl #impl_generics FromSong for #ident #ty_generics #where_clause {
248            fn from_song(buf: &[u8]) -> Result<Self, FromSongError> {
249                let mut i = 0;
250                Ok(
251                    #ident {
252                        #(#from_song_out,)*
253                    }
254                )
255            }
256        }
257    })
258}
259
260fn derive_song_size_enum(enum_song: &EnumSong) -> syn::Result<TokenStream2> {
261    let ident = &enum_song.item.ident;
262    
263    match &enum_song.disc {
264        EnumSongDisc::Repr { .. } => {
265            Ok(quote! {
266                #[automatically_derived]
267                impl HasSongSize for #ident {
268                    type Size = ConstSongSizeValue<{ core::mem::size_of::<Self>() }>;
269                }
270
271                #[automatically_derived]
272                impl SongSize for #ident {
273                    fn song_size(&self) -> usize {
274                        core::mem::size_of::<Self>()
275                    }
276                }
277            })
278        },
279        EnumSongDisc::Enum { disc_name, disc_type } => {
280            let mut out = vec![];
281            let mut disc_out = vec![];
282            let mut song_disc_out = vec![];
283
284            for var in &enum_song.item.variants {
285                let disc_var_ident = &var.ident;
286
287                match &var.discriminant {
288                    Some((eq, val)) => {
289                        disc_out.push(quote! {
290                            #disc_var_ident #eq #val
291                        });
292                    }
293                    None => {
294                        disc_out.push(quote! {
295                            #disc_var_ident
296                        });
297                    }
298                }
299
300                let mut idents = vec![];
301
302                let mut i = 0;
303                for field in &var.fields {
304                    idents.push(
305                        field.ident.clone().unwrap_or_else(
306                            || Ident::new(&format!("t{}", i).to_string(), field.span())
307                        )
308                    );
309                    i += 1;
310                }
311
312                let mut fields_out = vec![];
313
314                for ident in &idents {
315                    fields_out.push(quote! {
316                        i += #ident.song_size();
317                    });
318                }
319
320                let ident = &var.ident;
321
322                let destructure = match &var.fields {
323                    Fields::Unit => quote!(),
324                    Fields::Named(_) => quote!({ #(#idents),* }),
325                    Fields::Unnamed(_) => quote!(( #(#idents),* ))
326                };
327
328                out.push(quote! {
329                    Self::#ident #destructure => {
330                        let mut i = core::mem::size_of::<#disc_type>();
331                        #(#fields_out)*
332                        i
333                    }
334                });
335
336                song_disc_out.push(quote! {
337                    Self::#ident #destructure => #disc_name::#disc_var_ident
338                });
339            }
340
341            let vis = &enum_song.item.vis;
342
343            Ok(quote! {
344                #[derive(Clone, Copy, PartialEq, Eq, Debug, SongSize, ToSong, FromSong)]
345                #[repr(#disc_type)]
346                #vis enum #disc_name {
347                    #(#disc_out),*
348                }
349
350                #[automatically_derived]
351                impl SongDiscriminant for #ident {
352                    type Discriminant = #disc_name;
353
354                    fn song_discriminant(&self) -> Self::Discriminant {
355                        match self {
356                            #(#song_disc_out),*
357                        }
358                    }
359                }
360
361                #[automatically_derived]
362                impl SongSize for #ident {
363                    fn song_size(&self) -> usize {
364                        match self {
365                            #(#out),*
366                        }
367                    }
368                }
369            })
370        }
371    }
372}
373
374fn derive_to_song_enum(enum_song: &EnumSong) -> syn::Result<TokenStream2> {
375    let ident = &enum_song.item.ident;
376
377    match &enum_song.disc {
378        EnumSongDisc::Repr { ty } => {
379            Ok(quote! {
380                #[automatically_derived]
381                impl ToSong for #ident {
382                    fn to_song(&self, buf: &mut [u8]) -> Result<(), ToSongError> {
383                        (*self as #ty).to_song(buf)
384                    }
385                }
386            })
387        },
388        EnumSongDisc::Enum { disc_name, disc_type } => {
389            let mut out = vec![];
390
391            for var in &enum_song.item.variants {
392                let mut idents = vec![];
393
394                let mut i = 0;
395                for field in &var.fields {
396                    // println!("Debug 100: {:?}", field);
397                    idents.push(
398                        field.ident.clone().unwrap_or_else(
399                            || Ident::new(&format!("t{}", i).to_string(), field.span())
400                        )
401                    );
402                    i += 1;
403                }
404
405                let mut fields_out = vec![];
406
407                for ident in &idents {
408                    fields_out.push(quote! {
409                        #ident.to_song(&mut buf[i..])?;
410                        i += #ident.song_size();
411                    });
412                }
413
414                let ident = &var.ident;
415
416                let destructure = match &var.fields {
417                    Fields::Unit => quote!(),
418                    Fields::Named(_) => quote!({ #(#idents),* }),
419                    Fields::Unnamed(_) => quote!(( #(#idents),* ))
420                };
421
422                out.push(quote! {
423                    Self::#ident #destructure => {
424                        (#disc_name::#ident as #disc_type).to_song(buf)?;
425                        let mut i = core::mem::size_of::<#disc_type>();
426                        #(#fields_out)*
427                        Ok(())
428                    }
429                });
430            }
431            
432            Ok(quote! {
433                impl ToSong for #ident {
434                    fn to_song(&self, buf: &mut [u8]) -> Result<(), ToSongError> {
435                        match self {
436                            #(#out),*
437                        }
438                    }
439                }
440            })
441        }
442    }
443}
444
445fn derive_from_song_enum(enum_song: &EnumSong) -> syn::Result<TokenStream2> {
446    let ident = &enum_song.item.ident;
447    
448    match &enum_song.disc {
449        EnumSongDisc::Repr { ty } => {
450            let var = enum_song.item.variants.iter().map(|var| &var.ident);
451
452            Ok(quote! {
453                impl FromSong for #ident {
454                    fn from_song(buf: &[u8]) -> Result<Self, FromSongError> {
455                        if buf.len() < core::mem::size_of::<#ty>() {
456                            return Err(FromSongError::BufferOverflow);
457                        }
458
459                        match buf[0] {
460                            #(disc if disc == #ident::#var as #ty => Ok(#ident::#var),)*
461                            _ => Err(FromSongError::InvalidPacketId)
462                        }
463                    }
464                }
465            })
466        }
467        EnumSongDisc::Enum { disc_name, disc_type } => {
468            let mut out = vec![];
469            for var in &enum_song.item.variants {
470
471                let mut fields_out = vec![];
472                let mut idents = vec![];
473                
474                let mut i = 0;
475                for field in &var.fields {
476                    let ident = field.ident.clone().unwrap_or_else(
477                        || Ident::new(&format!("t{}", i).to_string(), field.span()));
478                    idents.push(ident.clone());
479                    
480                    let typ = field.ty.clone();
481                    
482                    fields_out.push(quote! {
483                        let #ident = {
484                            // avoid turbofish problem
485                            type Ty = #typ;
486                            Ty::from_song(&buf[i..])?
487                        };
488
489                        i += #ident.song_size();
490                    });
491                    i += 1;
492                }   
493
494                let var_ident = &var.ident;
495
496
497                match &var.fields {
498                    Fields::Unit => out.push(quote! {
499                            val if val == #disc_name::#var_ident as #disc_type => Ok(#ident::#var_ident)
500                        }),
501                    Fields::Unnamed(_) => out.push(quote! { 
502                        val if val == #disc_name::#var_ident as #disc_type => {
503                            let mut i = 1;
504                            #(#fields_out)*
505                            Ok(#ident::#var_ident(#(#idents,)*))
506                        }}),
507                    Fields::Named(_) => todo!()
508                }
509            }
510            
511            Ok(quote! {
512                impl FromSong for #ident {
513                    fn from_song(buf: &[u8]) -> Result<Self, FromSongError> {
514                        let Some(&disc) = buf.get(0) else {
515                            return Err(FromSongError::BufferOverflow)
516                        };
517
518                        match disc {
519                            #(#out,)*
520                            _ => Err(FromSongError::InvalidPacketId)
521                        }
522                    }
523                }
524            })
525        }
526    }
527}