beamcode_derive/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4use syn::{parse_macro_input, Data, DeriveInput, Fields};
5
6#[proc_macro_derive(Opcode, attributes(opcode))]
7pub fn derive_opcode_trait(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    let name = input.ident;
10    let code = &input.attrs.last().expect("missing `#[opcode(N)]`").tokens;
11    let expanded = quote! {
12        impl crate::instruction::Opcode for #name {
13            const CODE: u8 = #code;
14        }
15    };
16    proc_macro::TokenStream::from(expanded)
17}
18
19#[proc_macro_derive(Decode)]
20pub fn derive_decode_trait(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
21    let input = parse_macro_input!(input as DeriveInput);
22    let name = input.ident;
23    let decode = generate_decode_fun_body(&input.data);
24    let expanded = quote! {
25        impl crate::Decode for #name {
26            fn decode_with_tag<R: std::io::Read>(reader: &mut R, tag: u8) -> Result<Self, crate::DecodeError> {
27                #decode
28            }
29        }
30    };
31    proc_macro::TokenStream::from(expanded)
32}
33
34fn generate_decode_fun_body(data: &Data) -> TokenStream {
35    match *data {
36        Data::Enum(ref data) => {
37            let arms = data.variants.iter().map(|variant| {
38                let name = &variant.ident;
39                let op =
40                    if let Fields::Unnamed(fields) = &variant.fields {
41                        assert_eq!(fields.unnamed.len(), 1);
42                        &fields.unnamed.iter().next().expect("unreachable").ty
43                    } else {
44                        unimplemented!()
45                    };
46                quote_spanned! { variant.span() => #op::CODE => crate::Decode::decode_with_tag(reader, tag).map(Self::#name), }
47            });
48            quote! {
49                match tag {
50                    #(#arms)*
51                    opcode => Err(crate::DecodeError::UnknownOpcode{ opcode })
52                }
53            }
54        }
55        Data::Struct(ref data) => match data.fields {
56            Fields::Named(ref fields) => {
57                let decode = fields.named.iter().map(|f| {
58                    let name = &f.ident;
59                    quote_spanned! { f.span() => #name: crate::Decode::decode(reader)? }
60                });
61                quote! {
62                    if tag != Self::CODE {
63                        return Err(crate::DecodeError::UnknownOpcode{ opcode: tag });
64                    }
65                    Ok(Self{
66                        #(#decode ,)*
67                    })
68                }
69            }
70            _ => unimplemented!(),
71        },
72        _ => unimplemented!(),
73    }
74}
75
76#[proc_macro_derive(Encode)]
77pub fn derive_encode_trait(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
78    let input = parse_macro_input!(input as DeriveInput);
79    let name = input.ident;
80    let encode = generate_encode_fun_body(&input.data);
81    let expanded = quote! {
82        impl crate::Encode for #name {
83            fn encode<W: std::io::Write>(&self, writer: &mut W) -> Result<(), crate::EncodeError> {
84                #encode
85            }
86        }
87    };
88    proc_macro::TokenStream::from(expanded)
89}
90
91fn generate_encode_fun_body(data: &Data) -> TokenStream {
92    match *data {
93        Data::Enum(ref data) => {
94            let arms = data.variants.iter().map(|variant| {
95                let name = &variant.ident;
96                if let Fields::Unnamed(fields) = &variant.fields {
97                    assert_eq!(fields.unnamed.len(), 1);
98                } else {
99                    unimplemented!();
100                }
101                quote_spanned! { variant.span() => Self::#name(x) => x.encode(writer), }
102            });
103            quote! {
104                match self {
105                    #(#arms)*
106                }
107            }
108        }
109        Data::Struct(ref data) => match data.fields {
110            Fields::Named(ref fields) => {
111                let encode = fields.named.iter().map(|f| {
112                    let name = &f.ident;
113                    quote_spanned! { f.span() => self.#name.encode(writer)? }
114                });
115                quote! {
116                    writer.write_all(&[Self::CODE])?;
117                    #(#encode ;)*;
118                    Ok(())
119                }
120            }
121            _ => unimplemented!(),
122        },
123        _ => unimplemented!(),
124    }
125}