recordkeeper_macros/
lib.rs

1use proc_macro2::TokenStream;
2
3use quote::{quote, ToTokens};
4use syn::punctuated::Punctuated;
5use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Expr, Field, Meta, Token};
6
7struct FieldVisitor<'ast> {
8    field: &'ast Field,
9    location: Option<TokenStream>,
10    assert_value: Option<TokenStream>,
11    assert_error: Option<TokenStream>,
12}
13
14impl<'ast> FieldVisitor<'ast> {
15    fn parser_tokens(&self) -> TokenStream {
16        let var_name = &self.field.ident;
17        let type_ident = &self.field.ty;
18
19        let loc_code = self.location.as_ref().map(|loc| {
20            quote! {
21                {
22                    #[cfg(debug_assertions)]
23                    {
24                        let current = __POS;
25                        if #loc < current {
26                            panic!("New location 0x{:x} is lower than current location 0x{:x} for field {}",
27                                #loc, current, stringify!(#var_name));
28                        }
29                    }
30                    __POS = #loc;
31                }
32            }
33        });
34
35        let assert_error = self.assert_error.clone().unwrap_or_else(|| {
36            quote! {
37                crate::error::SaveError::AssertionError(format!("(Actual) {:?} != (Expected) {:?}",
38                    ACTUAL, EXPECTED))
39            }
40        });
41
42        let assert_code = self.assert_value.as_ref().map(|assert_value| {
43            let field_type = self.field.ty.to_token_stream();
44            quote! {
45                let EXPECTED: #field_type = #assert_value;
46                // __OUT_PTR points to valid memory if the read succeeded
47                if EXPECTED != std::ptr::read(__OUT_PTR) {
48                    let ACTUAL = std::ptr::read(__OUT_PTR);
49                    return Err(#assert_error)
50                }
51            }
52        });
53
54        quote! {
55            #loc_code
56            {
57                let __OUT_PTR = addr_of_mut!((*__BUILDING). #var_name);
58                <#type_ident as crate::io::SaveBin>::read_into(&__IN_BYTES[__POS..], __OUT_PTR)?;
59                #assert_code
60            }
61            let __SIZE = <#type_ident as crate::io::SaveBin>::size();
62            __POS += __SIZE;
63        }
64    }
65
66    fn writer_tokens(&self) -> TokenStream {
67        let name = &self.field.ident;
68        let field_type = self.field.ty.to_token_stream();
69
70        let loc_code = self.location.as_ref().map(|loc| {
71            quote! {
72                __POS = #loc;
73            }
74        });
75
76        quote! {
77            #loc_code
78            let __TMP_BYTES = &mut __OUT_BYTES[__POS..];
79            self. #name .write(__TMP_BYTES)?;
80            __POS += <#field_type as crate::io::SaveBin>::size();
81        }
82    }
83
84    fn size_calc_tokens(&self) -> TokenStream {
85        let type_ident = &self.field.ty;
86        let field_name = self.field.ident.to_token_stream();
87
88        match &self.location {
89            Some(loc) => quote! {
90                #[cfg(debug_assertions)]
91                if #loc < current_loc {
92                    panic!("New location 0x{:x} is lower than current location 0x{:x} for field {}",
93                        #loc, current_loc, stringify!(#field_name));
94                }
95                let _size = <#type_ident as crate::io::SaveBin>::size();
96                size += _size + #loc - current_loc;
97                current_loc = #loc + _size;
98            },
99            None => quote! {
100                let _size = <#type_ident as crate::io::SaveBin>::size();
101                size += _size;
102                current_loc += _size;
103            },
104        }
105    }
106}
107
108#[proc_macro_derive(SaveBin, attributes(loc, assert, size))]
109pub fn derive_save_deserialize(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
110    let item = parse_macro_input!(item as DeriveInput);
111
112    let name = &item.ident;
113
114    let mut generics = item.generics.clone();
115    // add lifetime param for SaveBin, but only to impl generics
116    generics.params.insert(0, parse_quote!('__SRC));
117    let (impl_generics, _, _) = generics.split_for_impl();
118
119    let (_, ty_generics, where_clause) = item.generics.split_for_impl();
120
121    let item_struct = match item.data {
122        Data::Struct(str) => str,
123        _ => panic!("SaveBin can only be derived on structs"),
124    };
125
126    let expected_size = item
127        .attrs
128        .iter()
129        .find(|a| a.path().is_ident("size"))
130        .map(|a| match &a.meta {
131            Meta::List(l) => l.tokens.clone(),
132            _ => panic!("syntax: #[size(N)]"),
133        });
134
135    let field_visitors = item_struct
136        .fields
137        .iter()
138        .map(|f| {
139            let mut loc = None;
140            let mut assert = None;
141            let mut assert_error = None;
142
143            for attr in &f.attrs {
144                let path = attr.path();
145                let list = match &attr.meta {
146                    Meta::List(list) => list,
147                    _ => continue,
148                };
149                if path.is_ident("loc") {
150                    loc = Some(list.tokens.clone());
151                } else if path.is_ident("assert") {
152                    let parts: Punctuated<Expr, Token!(,)> = list.parse_args_with(Punctuated::parse_terminated)
153                    .expect(
154                        "syntax: #[assert(EXPECTED_VALUE)], or #[assert(EXPECTED, custom_error)",
155                    );
156                    let mut parts = parts.into_iter();
157                    assert = Some(parts.next().unwrap().into_token_stream());
158                    assert_error = parts.next().map(ToTokens::into_token_stream);
159                }
160            }
161
162            FieldVisitor {
163                field: f,
164                location: loc,
165                assert_value: assert,
166                assert_error,
167            }
168        })
169        .collect::<Vec<_>>();
170
171    let parsers = field_visitors
172        .iter()
173        .flat_map(|v| v.parser_tokens())
174        .collect::<TokenStream>();
175
176    let writers = field_visitors
177        .iter()
178        .flat_map(|v| v.writer_tokens())
179        .collect::<TokenStream>();
180
181    let size_calc = field_visitors
182        .iter()
183        .flat_map(|v| v.size_calc_tokens())
184        .collect::<TokenStream>();
185
186    let extra_size = expected_size.map(|size| {
187        quote! {
188            #[cfg(debug_assertions)]
189            if size > #size {
190                panic!("Struct {} too large, can't add padding. Expected max {} bytes, found {}.",
191                    stringify!(#name), #size, size);
192            }
193            size = #size;
194        }
195    });
196
197    let out = quote! {
198        impl #impl_generics crate::io::SaveBin<'__SRC> for #name #ty_generics #where_clause {
199            type ReadError = crate::error::SaveError;
200            type WriteError = crate::error::SaveError;
201
202            unsafe fn read_into(mut __IN_BYTES: &'__SRC [u8], __BUILDING: *mut Self) -> Result<(), Self::ReadError> {
203                use std::ptr::addr_of_mut;
204
205                // Set up relative positions for start of struct
206                if __IN_BYTES.len() < Self::size() {
207                    return Err(crate::error::SaveError::UnexpectedEof);
208                }
209
210                let mut __POS = 0;
211                #parsers
212                Ok(())
213            }
214
215            fn write(&self, mut __OUT_BYTES: &mut [u8]) -> Result<(), Self::WriteError> {
216                let mut __POS = 0;
217                #writers
218                Ok(())
219            }
220
221            fn size() -> usize { // TODO: const?
222                let mut current_loc: usize = 0;
223                let mut size: usize = 0;
224
225                #size_calc
226                #extra_size
227
228                size
229            }
230        }
231    };
232
233    out.into()
234}