1use proc_macro2::{Span, TokenStream};
9use quote::{quote, quote_spanned};
10use syn::{
11 parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam, Generics, Index,
12};
13
14#[proc_macro_derive(TryFromForSizedBytes, attributes(ErrorType))]
22pub fn try_from_for_sized_bytes(source: proc_macro::TokenStream) -> proc_macro::TokenStream {
23 let ast: DeriveInput = syn::parse(source).expect("Incorrect macro input");
24 let name = &ast.ident;
25
26 let error_type = get_type_from_attrs(&ast.attrs, "ErrorType").unwrap();
27
28 let generics = add_basic_bound(ast.generics);
29 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
30
31 let gen = quote! {
32 impl #impl_generics ::std::convert::TryFrom<&[u8]> for #name #ty_generics #where_clause {
33 type Error = #error_type;
34
35 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
36 let expected_len = <<Self as ::generic_bytes::SizedBytes>::Len as generic_array::typenum::Unsigned>::to_usize();
37 if bytes.len() != expected_len {
38 return Err(#error_type::default());
39 }
40 let arr = GenericArray::from_slice(bytes);
41 <Self as ::generic_bytes::SizedBytes>::from_arr(arr).map_err(|_| #error_type::default())
42 }
43 }
44 };
45 gen.into()
46}
47
48fn get_type_from_attrs(attrs: &[syn::Attribute], attr_name: &str) -> syn::Result<syn::Type> {
49 attrs
50 .iter()
51 .find(|attr| attr.path.is_ident(attr_name))
52 .map_or_else(
53 || {
54 Err(syn::Error::new(
55 proc_macro2::Span::call_site(),
56 format!("Could not find attribute {}", attr_name),
57 ))
58 },
59 |attr| match attr.parse_meta()? {
60 syn::Meta::NameValue(meta) => {
61 if let syn::Lit::Str(lit) = &meta.lit {
62 Ok(lit.clone())
63 } else {
64 Err(syn::Error::new_spanned(
65 meta,
66 &format!("Could not parse {} attribute", attr_name)[..],
67 ))
68 }
69 }
70 bad => Err(syn::Error::new_spanned(
71 bad,
72 &format!("Could not parse {} attribute", attr_name)[..],
73 )),
74 },
75 )
76 .and_then(|str| str.parse())
77}
78
79fn add_basic_bound(mut generics: Generics) -> Generics {
81 for param in &mut generics.params {
82 if let GenericParam::Type(ref mut type_param) = *param {
83 type_param
84 .bounds
85 .push(parse_quote!(::generic_bytes::SizedBytes));
86 }
87 }
88 generics
89}
90
91fn add_trait_bounds(
97 generics: &mut Generics,
98 data: &syn::Data,
99 bound: syn::Path,
100) -> Result<(), syn::Error> {
101 if generics.params.is_empty() {
102 return Ok(());
103 }
104
105 let types = collect_types(data)?;
106 if !types.is_empty() {
107 let where_clause = generics.make_where_clause();
108
109 types
110 .into_iter()
111 .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #bound)));
112 bounds_sum(data, where_clause)?;
113 }
114
115 Ok(())
116}
117
118fn collect_types(data: &syn::Data) -> Result<Vec<syn::Type>, syn::Error> {
119 use syn::*;
120
121 let types = match *data {
122 Data::Struct(ref data) => match &data.fields {
123 Fields::Named(FieldsNamed { named: fields, .. })
124 | Fields::Unnamed(FieldsUnnamed {
125 unnamed: fields, ..
126 }) => fields.iter().map(|f| f.ty.clone()).collect(),
127
128 Fields::Unit => Vec::new(),
129 },
130
131 Data::Enum(ref data) => data
132 .variants
133 .iter()
134 .flat_map(|variant| match &variant.fields {
135 Fields::Named(FieldsNamed { named: fields, .. })
136 | Fields::Unnamed(FieldsUnnamed {
137 unnamed: fields, ..
138 }) => fields.iter().map(|f| f.ty.clone()).collect(),
139
140 Fields::Unit => Vec::new(),
141 })
142 .collect(),
143
144 Data::Union(_) => {
145 return Err(Error::new(
146 Span::call_site(),
147 "Union types are not supported.",
148 ))
149 }
150 };
151
152 Ok(types)
153}
154
155fn extract_size_type_from_generic_array(ty: &syn::Type) -> Option<&syn::Type> {
156 fn path_is_generic_array(path: &syn::Path) -> Option<&syn::GenericArgument> {
157 path.segments.iter().find_map(|pt| {
158 if pt.ident == "GenericArray" {
159 match &pt.arguments {
161 syn::PathArguments::AngleBracketed(params) if params.args.len() == 2 => {
162 params.args.last()
163 }
164 _ => None,
165 }
166 } else {
167 None
168 }
169 })
170 }
171
172 match ty {
173 syn::Type::Path(typepath)
174 if typepath.qself.is_none()
175 && typepath
176 .path
177 .segments
178 .iter()
179 .any(|pt| pt.ident == "GenericArray") =>
180 {
181 let type_param = path_is_generic_array(&typepath.path);
183 if let Some(syn::GenericArgument::Type(ty)) = type_param {
185 Some(ty)
186 } else {
187 None
188 }
189 }
190 _ => None,
191 }
192}
193
194fn bounds_sum(data: &Data, where_clause: &mut syn::WhereClause) -> Result<(), syn::Error> {
195 match *data {
196 Data::Struct(ref data) => {
197 match data.fields {
198 Fields::Named(ref fields) => {
199 let mut quote = None;
200 for f in fields.named.iter() {
201 let ty = &f.ty;
202 let res =
203 if let Some(unsigned_ty) = extract_size_type_from_generic_array(ty) {
204 quote_spanned! {f.span()=>
205 #unsigned_ty
206 }
207 } else {
208 quote_spanned! {f.span()=>
209 <#ty as ::generic_bytes::SizedBytes>::Len
210 }
211 };
212 if let Some(ih) = quote {
213 quote = Some(quote! {
214 ::generic_array::typenum::Sum<#ih, #res>
215 });
216 where_clause
217 .predicates
218 .push(parse_quote!(#ih: ::core::ops::Add<#res>));
219 where_clause
220 .predicates
221 .push(parse_quote!(::generic_array::typenum::Sum<#ih, #res> : ::generic_array::ArrayLength<u8> + ::core::ops::Sub<#ih, Output = #res>));
222 where_clause
223 .predicates
224 .push(parse_quote!(::generic_array::typenum::Diff<::generic_array::typenum::Sum<#ih, #res>, #ih> : ::generic_array::ArrayLength<u8>));
225 } else {
226 quote = Some(res);
227 }
228 }
229 Ok(())
230 }
231 Fields::Unnamed(ref fields) => {
232 let mut quote = None;
233 for f in fields.unnamed.iter() {
234 let ty = &f.ty;
235 let res =
236 if let Some(unsigned_ty) = extract_size_type_from_generic_array(ty) {
237 quote_spanned! {f.span()=>
238 #unsigned_ty
239 }
240 } else {
241 quote_spanned! {f.span()=>
242 <#ty as ::generic_bytes::SizedBytes>::Len
243 }
244 };
245 if let Some(ih) = quote {
246 quote = Some(quote! {
247 ::generic_array::typenum::Sum<#ih, #res>
248 });
249 where_clause
250 .predicates
251 .push(parse_quote!(#ih : ::core::ops::Add<#res>));
252 where_clause
253 .predicates
254 .push(parse_quote!(::generic_array::typenum::Sum<#ih, #res> : ::generic_array::ArrayLength<u8> + ::core::ops::Sub<#ih, Output = #res>));
255 where_clause
256 .predicates
257 .push(parse_quote!(::generic_array::typenum::Diff<::generic_array::typenum::Sum<#ih, #res>, #ih> : ::generic_array::ArrayLength<u8>));
258 } else {
259 quote = Some(res);
260 }
261 }
262 Ok(())
263 }
264 Fields::Unit => {
265 unimplemented!()
267 }
268 }
269 }
270 Data::Enum(_) | Data::Union(_) => unimplemented!(),
271 }
272}
273
274fn sum(data: &Data) -> TokenStream {
276 match *data {
277 Data::Struct(ref data) => {
278 match data.fields {
279 Fields::Named(ref fields) => {
280 let mut quote = None;
281 for f in fields.named.iter() {
282 let ty = &f.ty;
283 let res = quote_spanned! {f.span()=>
284 <#ty as ::generic_bytes::SizedBytes>::Len
285 };
286 if let Some(ih) = quote {
287 quote = Some(quote! {
288 ::generic_array::typenum::Sum<#ih, #res>
289 });
290 } else {
291 quote = Some(res);
292 }
293 }
294 quote! {
295 #quote
296 }
297 }
298 Fields::Unnamed(ref fields) => {
299 let mut quote = None;
300 for f in fields.unnamed.iter() {
301 let ty = &f.ty;
302 let res = quote_spanned! {f.span()=>
303 <#ty as ::generic_bytes::SizedBytes>::Len
304 };
305 if let Some(ih) = quote {
306 quote = Some(quote! {
307 ::generic_array::typenum::Sum<#ih, #res>
308 });
309 } else {
310 quote = Some(res);
311 }
312 }
313 quote! {
314 #quote
315 }
316 }
317 Fields::Unit => {
318 unimplemented!()
320 }
321 }
322 }
323 Data::Enum(_) | Data::Union(_) => unimplemented!(),
324 }
325}
326
327fn byte_concatenation(data: &Data) -> TokenStream {
329 match *data {
330 Data::Struct(ref data) => {
331 match data.fields {
332 Fields::Named(ref fields) => {
333 let mut quote = None;
334 for f in fields.named.iter() {
335 let name = &f.ident;
336 let res = quote_spanned! {f.span()=>
337 ::generic_bytes::SizedBytes::to_arr(&self.#name)
338 };
339 if let Some(ih) = quote {
340 quote = Some(quote! {
341 ::generic_array::sequence::Concat::concat(#ih, #res)
342 });
343 } else {
344 quote = Some(res);
345 }
346 }
347 quote! {
348 #quote
349 }
350 }
351 Fields::Unnamed(ref fields) => {
352 let mut quote = None;
353 for (i, f) in fields.unnamed.iter().enumerate() {
354 let index = Index::from(i);
355 let res = quote_spanned! {f.span()=>
356 ::generic_bytes::SizedBytes::to_arr(&self.#index)
357 };
358 if let Some(ih) = quote {
359 quote = Some(quote! {
360 ::generic_array::sequence::Concat::concat(#ih, #res)
361 });
362 } else {
363 quote = Some(res);
364 }
365 }
366 quote! {
367 #quote
368 }
369 }
370 Fields::Unit => {
371 quote!(0)
373 }
374 }
375 }
376 Data::Enum(_) | Data::Union(_) => unimplemented!(),
377 }
378}
379
380fn byte_splitting(constr: &proc_macro2::Ident, data: &Data) -> TokenStream {
382 match *data {
383 Data::Struct(ref data) => {
384 match data.fields {
385 Fields::Named(ref fields) => {
386 let l = fields.named.len();
387 let setup: TokenStream = fields
388 .named
389 .iter().enumerate()
390 .map(|(i, f)| {
391 let name = &f.ident;
392 let ty = &f.ty;
393
394 if i < (l-1) {
395 quote_spanned! {f.span()=>
396 let (head, _tail): (&GenericArray<u8, <#ty as ::generic_bytes::SizedBytes>::Len>, &GenericArray<u8, _>) =
397 generic_array::sequence::Split::split(_tail);
398 let #name: #ty = ::generic_bytes::SizedBytes::from_arr(head)?;
399 }
400 } else {
401 quote_spanned!{f.span() =>
402 let #name: #ty = ::generic_bytes::SizedBytes::from_arr(_tail)?;
403 }
404 }
405 })
406 .collect();
407
408 let conclude: TokenStream = fields
409 .named
410 .iter()
411 .map(|f| {
412 let name = &f.ident;
413 quote_spanned! {f.span()=>
414 #name,
415 }
416 })
417 .collect();
418 quote! {
419 let _tail = arr;
420 #setup
421 Ok(#constr {
422 #conclude
423 })
424 }
425 }
426 Fields::Unnamed(ref fields) => {
427 let l = fields.unnamed.len();
428 let setup: TokenStream = fields
429 .unnamed
430 .iter()
431 .enumerate()
432 .map(|(i, f)| {
433 let ty = &f.ty;
434 if i < (l-1) {
435 let field_name = format!("f_{}", i);
436 let fname = syn::Ident::new(&field_name, f.span());
437 quote_spanned! {f.span()=>
438 let (head, _tail) = generic_array::sequence::Split::split(_tail);
439 let #fname: #ty = ::generic_bytes::SizedBytes::from_arr(head)?;
440 }
441 } else {
442 let field_name = format!("f_{}", i);
443 let fname = syn::Ident::new(&field_name, f.span());
444 quote_spanned! {f.span()=>
445 let #fname: #ty = ::generic_bytes::SizedBytes::from_arr(_tail)?;
446 }
447 }
448 })
449 .collect();
450
451 let conclude: TokenStream = fields
452 .unnamed
453 .iter()
454 .enumerate()
455 .map(|(i, f)| {
456 let field_name = format!("f_{}", i);
457 let fname = syn::Ident::new(&field_name, f.span());
458 quote_spanned! {f.span()=>
459 #fname,
460 }
461 })
462 .collect();
463 quote! (
464 let _tail = arr;
465 #setup
466 Ok(#constr (
467 #conclude
468 ))
469 )
470 }
471 Fields::Unit => {
472 quote!(0)
474 }
475 }
476 }
477 Data::Enum(_) | Data::Union(_) => unimplemented!(),
478 }
479}
480
481#[proc_macro_derive(SizedBytes)]
482pub fn derive_sized_bytes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
483 let mut input: DeriveInput = match syn::parse(input) {
484 Ok(input) => input,
485 Err(e) => return e.to_compile_error().into(),
486 };
487 let name = &input.ident;
488
489 if let Err(e) = add_trait_bounds(
491 &mut input.generics,
492 &input.data,
493 parse_quote!(::generic_bytes::SizedBytes),
494 ) {
495 return e.to_compile_error().into();
496 };
497
498 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
499
500 let types_sum = sum(&input.data);
502
503 let to_arr_impl = byte_concatenation(&input.data);
505
506 let from_arr_impl = byte_splitting(name, &input.data);
508
509 let res = quote! (
510 impl #impl_generics ::generic_bytes::SizedBytes for #name #ty_generics #where_clause {
512
513 type Len = #types_sum;
514
515 fn to_arr(&self) -> GenericArray<u8, Self::Len> {
516 #to_arr_impl
517 }
518
519 fn from_arr(arr: &GenericArray<u8, Self::Len>) -> Result<Self, ::generic_bytes::TryFromSizedBytesError> {
520 #from_arr_impl
521 }
522 }
523 );
524 res.into()
525}