recordkeeper_macros/
lib.rs1use proc_macro2::TokenStream;
2
3use quote::{quote, ToTokens};
4use syn::punctuated::Punctuated;
5use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Expr, Field, Meta, Token};
6
7struct FieldVisitor<'ast> {
8 field: &'ast Field,
9 location: Option<TokenStream>,
10 assert_value: Option<TokenStream>,
11 assert_error: Option<TokenStream>,
12}
13
14impl<'ast> FieldVisitor<'ast> {
15 fn parser_tokens(&self) -> TokenStream {
16 let var_name = &self.field.ident;
17 let type_ident = &self.field.ty;
18
19 let loc_code = self.location.as_ref().map(|loc| {
20 quote! {
21 {
22 #[cfg(debug_assertions)]
23 {
24 let current = __POS;
25 if #loc < current {
26 panic!("New location 0x{:x} is lower than current location 0x{:x} for field {}",
27 #loc, current, stringify!(#var_name));
28 }
29 }
30 __POS = #loc;
31 }
32 }
33 });
34
35 let assert_error = self.assert_error.clone().unwrap_or_else(|| {
36 quote! {
37 crate::error::SaveError::AssertionError(format!("(Actual) {:?} != (Expected) {:?}",
38 ACTUAL, EXPECTED))
39 }
40 });
41
42 let assert_code = self.assert_value.as_ref().map(|assert_value| {
43 let field_type = self.field.ty.to_token_stream();
44 quote! {
45 let EXPECTED: #field_type = #assert_value;
46 if EXPECTED != std::ptr::read(__OUT_PTR) {
48 let ACTUAL = std::ptr::read(__OUT_PTR);
49 return Err(#assert_error)
50 }
51 }
52 });
53
54 quote! {
55 #loc_code
56 {
57 let __OUT_PTR = addr_of_mut!((*__BUILDING). #var_name);
58 <#type_ident as crate::io::SaveBin>::read_into(&__IN_BYTES[__POS..], __OUT_PTR)?;
59 #assert_code
60 }
61 let __SIZE = <#type_ident as crate::io::SaveBin>::size();
62 __POS += __SIZE;
63 }
64 }
65
66 fn writer_tokens(&self) -> TokenStream {
67 let name = &self.field.ident;
68 let field_type = self.field.ty.to_token_stream();
69
70 let loc_code = self.location.as_ref().map(|loc| {
71 quote! {
72 __POS = #loc;
73 }
74 });
75
76 quote! {
77 #loc_code
78 let __TMP_BYTES = &mut __OUT_BYTES[__POS..];
79 self. #name .write(__TMP_BYTES)?;
80 __POS += <#field_type as crate::io::SaveBin>::size();
81 }
82 }
83
84 fn size_calc_tokens(&self) -> TokenStream {
85 let type_ident = &self.field.ty;
86 let field_name = self.field.ident.to_token_stream();
87
88 match &self.location {
89 Some(loc) => quote! {
90 #[cfg(debug_assertions)]
91 if #loc < current_loc {
92 panic!("New location 0x{:x} is lower than current location 0x{:x} for field {}",
93 #loc, current_loc, stringify!(#field_name));
94 }
95 let _size = <#type_ident as crate::io::SaveBin>::size();
96 size += _size + #loc - current_loc;
97 current_loc = #loc + _size;
98 },
99 None => quote! {
100 let _size = <#type_ident as crate::io::SaveBin>::size();
101 size += _size;
102 current_loc += _size;
103 },
104 }
105 }
106}
107
108#[proc_macro_derive(SaveBin, attributes(loc, assert, size))]
109pub fn derive_save_deserialize(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
110 let item = parse_macro_input!(item as DeriveInput);
111
112 let name = &item.ident;
113
114 let mut generics = item.generics.clone();
115 generics.params.insert(0, parse_quote!('__SRC));
117 let (impl_generics, _, _) = generics.split_for_impl();
118
119 let (_, ty_generics, where_clause) = item.generics.split_for_impl();
120
121 let item_struct = match item.data {
122 Data::Struct(str) => str,
123 _ => panic!("SaveBin can only be derived on structs"),
124 };
125
126 let expected_size = item
127 .attrs
128 .iter()
129 .find(|a| a.path().is_ident("size"))
130 .map(|a| match &a.meta {
131 Meta::List(l) => l.tokens.clone(),
132 _ => panic!("syntax: #[size(N)]"),
133 });
134
135 let field_visitors = item_struct
136 .fields
137 .iter()
138 .map(|f| {
139 let mut loc = None;
140 let mut assert = None;
141 let mut assert_error = None;
142
143 for attr in &f.attrs {
144 let path = attr.path();
145 let list = match &attr.meta {
146 Meta::List(list) => list,
147 _ => continue,
148 };
149 if path.is_ident("loc") {
150 loc = Some(list.tokens.clone());
151 } else if path.is_ident("assert") {
152 let parts: Punctuated<Expr, Token!(,)> = list.parse_args_with(Punctuated::parse_terminated)
153 .expect(
154 "syntax: #[assert(EXPECTED_VALUE)], or #[assert(EXPECTED, custom_error)",
155 );
156 let mut parts = parts.into_iter();
157 assert = Some(parts.next().unwrap().into_token_stream());
158 assert_error = parts.next().map(ToTokens::into_token_stream);
159 }
160 }
161
162 FieldVisitor {
163 field: f,
164 location: loc,
165 assert_value: assert,
166 assert_error,
167 }
168 })
169 .collect::<Vec<_>>();
170
171 let parsers = field_visitors
172 .iter()
173 .flat_map(|v| v.parser_tokens())
174 .collect::<TokenStream>();
175
176 let writers = field_visitors
177 .iter()
178 .flat_map(|v| v.writer_tokens())
179 .collect::<TokenStream>();
180
181 let size_calc = field_visitors
182 .iter()
183 .flat_map(|v| v.size_calc_tokens())
184 .collect::<TokenStream>();
185
186 let extra_size = expected_size.map(|size| {
187 quote! {
188 #[cfg(debug_assertions)]
189 if size > #size {
190 panic!("Struct {} too large, can't add padding. Expected max {} bytes, found {}.",
191 stringify!(#name), #size, size);
192 }
193 size = #size;
194 }
195 });
196
197 let out = quote! {
198 impl #impl_generics crate::io::SaveBin<'__SRC> for #name #ty_generics #where_clause {
199 type ReadError = crate::error::SaveError;
200 type WriteError = crate::error::SaveError;
201
202 unsafe fn read_into(mut __IN_BYTES: &'__SRC [u8], __BUILDING: *mut Self) -> Result<(), Self::ReadError> {
203 use std::ptr::addr_of_mut;
204
205 if __IN_BYTES.len() < Self::size() {
207 return Err(crate::error::SaveError::UnexpectedEof);
208 }
209
210 let mut __POS = 0;
211 #parsers
212 Ok(())
213 }
214
215 fn write(&self, mut __OUT_BYTES: &mut [u8]) -> Result<(), Self::WriteError> {
216 let mut __POS = 0;
217 #writers
218 Ok(())
219 }
220
221 fn size() -> usize { let mut current_loc: usize = 0;
223 let mut size: usize = 0;
224
225 #size_calc
226 #extra_size
227
228 size
229 }
230 }
231 };
232
233 out.into()
234}