1#![recursion_limit = "256"]
2use darling::{FromDeriveInput, FromMeta};
7use proc_macro::TokenStream;
8use quote::quote;
9use std::convert::TryInto;
10use syn::{parse_macro_input, DataEnum, DataStruct, DeriveInput, Ident};
11
/// Largest selector value a union variant may be assigned by this crate.
///
/// `compute_union_selectors` rejects enums whose highest selector exceeds it, and the generated
/// `Decode` impl for unions compares it against `ssz::MAX_UNION_SELECTOR` in a debug assertion.
const MAX_UNION_SELECTOR: u8 = 127;
15
16#[derive(Debug, FromDeriveInput)]
17#[darling(attributes(ssz))]
18struct StructOpts {
19 #[darling(default)]
20 enum_behaviour: Option<String>,
21}
22
23#[derive(Debug, Default, FromMeta)]
25struct FieldOpts {
26 #[darling(default)]
27 with: Option<Ident>,
28 #[darling(default)]
29 skip_serializing: bool,
30 #[darling(default)]
31 skip_deserializing: bool,
32}
33
/// `enum_behaviour` value selecting the transparent enum encoding.
const ENUM_TRANSPARENT: &str = "transparent";
/// `enum_behaviour` value selecting the union enum encoding.
const ENUM_UNION: &str = "union";
/// Every accepted `enum_behaviour` value; printed when an unknown value is rejected.
const ENUM_VARIANTS: &[&str] = &[ENUM_TRANSPARENT, ENUM_UNION];
/// Panic message used when an enum is derived without an `enum_behaviour` attribute.
const NO_ENUM_BEHAVIOUR_ERROR: &str = concat!(
    "enums require an \"enum_behaviour\" attribute, ",
    "e.g., #[ssz(enum_behaviour = \"transparent\")]"
);
39
/// How an enum is treated by the derives, as chosen by the `enum_behaviour` attribute.
enum EnumBehaviour {
    /// Selected by `enum_behaviour = "transparent"`.
    Transparent,
    /// Selected by `enum_behaviour = "union"`.
    Union,
}
44
45impl EnumBehaviour {
46 pub fn new(s: Option<String>) -> Option<Self> {
47 s.map(|s| match s.as_ref() {
48 ENUM_TRANSPARENT => EnumBehaviour::Transparent,
49 ENUM_UNION => EnumBehaviour::Union,
50 other => panic!(
51 "{} is an invalid enum_behaviour, use either {:?}",
52 other, ENUM_VARIANTS
53 ),
54 })
55 }
56}
57
58fn parse_ssz_fields(struct_data: &syn::DataStruct) -> Vec<(&syn::Type, &syn::Ident, FieldOpts)> {
59 struct_data
60 .fields
61 .iter()
62 .map(|field| {
63 let ty = &field.ty;
64 let ident = match &field.ident {
65 Some(ref ident) => ident,
66 _ => panic!("ssz_derive only supports named struct fields."),
67 };
68
69 let field_opts_candidates = field
70 .attrs
71 .iter()
72 .filter(|attr| attr.path.get_ident().map_or(false, |ident| *ident == "ssz"))
73 .collect::<Vec<_>>();
74
75 if field_opts_candidates.len() > 1 {
76 panic!("more than one field-level \"ssz\" attribute provided")
77 }
78
79 let field_opts = field_opts_candidates
80 .first()
81 .map(|attr| {
82 let meta = attr.parse_meta().unwrap();
83 FieldOpts::from_meta(&meta).unwrap()
84 })
85 .unwrap_or_default();
86
87 (ty, ident, field_opts)
88 })
89 .collect()
90}
91
92#[proc_macro_derive(Encode, attributes(ssz))]
94pub fn ssz_encode_derive(input: TokenStream) -> TokenStream {
95 let item = parse_macro_input!(input as DeriveInput);
96 let opts = StructOpts::from_derive_input(&item).unwrap();
97 let enum_opt = EnumBehaviour::new(opts.enum_behaviour);
98
99 match &item.data {
100 syn::Data::Struct(s) => {
101 if enum_opt.is_some() {
102 panic!("enum_behaviour is invalid for structs");
103 }
104 ssz_encode_derive_struct(&item, s)
105 }
106 syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) {
107 EnumBehaviour::Transparent => ssz_encode_derive_enum_transparent(&item, s),
108 EnumBehaviour::Union => ssz_encode_derive_enum_union(&item, s),
109 },
110 _ => panic!("ssz_derive only supports structs and enums"),
111 }
112}
113
114fn ssz_encode_derive_struct(derive_input: &DeriveInput, struct_data: &DataStruct) -> TokenStream {
122 let name = &derive_input.ident;
123 let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
124
125 let field_is_ssz_fixed_len = &mut vec![];
126 let field_fixed_len = &mut vec![];
127 let field_ssz_bytes_len = &mut vec![];
128 let field_encoder_append = &mut vec![];
129
130 for (ty, ident, field_opts) in parse_ssz_fields(struct_data) {
131 if field_opts.skip_serializing {
132 continue;
133 }
134
135 if let Some(module) = field_opts.with {
136 let module = quote! { #module::encode };
137 field_is_ssz_fixed_len.push(quote! { #module::is_ssz_fixed_len() });
138 field_fixed_len.push(quote! { #module::ssz_fixed_len() });
139 field_ssz_bytes_len.push(quote! { #module::ssz_bytes_len(&self.#ident) });
140 field_encoder_append.push(quote! {
141 encoder.append_parameterized(
142 #module::is_ssz_fixed_len(),
143 |buf| #module::ssz_append(&self.#ident, buf)
144 )
145 });
146 } else {
147 field_is_ssz_fixed_len.push(quote! { <#ty as ssz::Encode>::is_ssz_fixed_len() });
148 field_fixed_len.push(quote! { <#ty as ssz::Encode>::ssz_fixed_len() });
149 field_ssz_bytes_len.push(quote! { self.#ident.ssz_bytes_len() });
150 field_encoder_append.push(quote! { encoder.append(&self.#ident) });
151 }
152 }
153
154 let output = quote! {
155 impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
156 fn is_ssz_fixed_len() -> bool {
157 #(
158 #field_is_ssz_fixed_len &&
159 )*
160 true
161 }
162
163 fn ssz_fixed_len() -> usize {
164 if <Self as ssz::Encode>::is_ssz_fixed_len() {
165 let mut len: usize = 0;
166 #(
167 len = len
168 .checked_add(#field_fixed_len)
169 .expect("encode ssz_fixed_len length overflow");
170 )*
171 len
172 } else {
173 ssz::BYTES_PER_LENGTH_OFFSET
174 }
175 }
176
177 fn ssz_bytes_len(&self) -> usize {
178 if <Self as ssz::Encode>::is_ssz_fixed_len() {
179 <Self as ssz::Encode>::ssz_fixed_len()
180 } else {
181 let mut len: usize = 0;
182 #(
183 if #field_is_ssz_fixed_len {
184 len = len
185 .checked_add(#field_fixed_len)
186 .expect("encode ssz_bytes_len length overflow");
187 } else {
188 len = len
189 .checked_add(ssz::BYTES_PER_LENGTH_OFFSET)
190 .expect("encode ssz_bytes_len length overflow for offset");
191 len = len
192 .checked_add(#field_ssz_bytes_len)
193 .expect("encode ssz_bytes_len length overflow for bytes");
194 }
195 )*
196
197 len
198 }
199 }
200
201 fn ssz_append(&self, buf: &mut Vec<u8>) {
202 let mut offset: usize = 0;
203 #(
204 offset = offset
205 .checked_add(#field_fixed_len)
206 .expect("encode ssz_append offset overflow");
207 )*
208
209 let mut encoder = ssz::SszEncoder::container(buf, offset);
210
211 #(
212 #field_encoder_append;
213 )*
214
215 encoder.finalize();
216 }
217 }
218 };
219 output.into()
220}
221
222fn ssz_encode_derive_enum_transparent(
240 derive_input: &DeriveInput,
241 enum_data: &DataEnum,
242) -> TokenStream {
243 let name = &derive_input.ident;
244 let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
245
246 let (patterns, assert_exprs): (Vec<_>, Vec<_>) = enum_data
247 .variants
248 .iter()
249 .map(|variant| {
250 let variant_name = &variant.ident;
251
252 if variant.fields.len() != 1 {
253 panic!("ssz::Encode can only be derived for enums with 1 field per variant");
254 }
255
256 let pattern = quote! {
257 #name::#variant_name(ref inner)
258 };
259
260 let ty = &(&variant.fields).into_iter().next().unwrap().ty;
261 let type_assert = quote! {
262 !<#ty as ssz::Encode>::is_ssz_fixed_len()
263 };
264 (pattern, type_assert)
265 })
266 .unzip();
267
268 let output = quote! {
269 impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
270 fn is_ssz_fixed_len() -> bool {
271 assert!(
272 #(
273 #assert_exprs &&
274 )* true,
275 "not all enum variants are variably-sized"
276 );
277 false
278 }
279
280 fn ssz_bytes_len(&self) -> usize {
281 match self {
282 #(
283 #patterns => inner.ssz_bytes_len(),
284 )*
285 }
286 }
287
288 fn ssz_append(&self, buf: &mut Vec<u8>) {
289 match self {
290 #(
291 #patterns => inner.ssz_append(buf),
292 )*
293 }
294 }
295 }
296 };
297 output.into()
298}
299
300fn ssz_encode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum) -> TokenStream {
310 let name = &derive_input.ident;
311 let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
312
313 let patterns: Vec<_> = enum_data
314 .variants
315 .iter()
316 .map(|variant| {
317 let variant_name = &variant.ident;
318
319 if variant.fields.len() != 1 {
320 panic!("ssz::Encode can only be derived for enums with 1 field per variant");
321 }
322
323 let pattern = quote! {
324 #name::#variant_name(ref inner)
325 };
326 pattern
327 })
328 .collect();
329
330 let union_selectors = compute_union_selectors(patterns.len());
331
332 let output = quote! {
333 impl #impl_generics ssz::Encode for #name #ty_generics #where_clause {
334 fn is_ssz_fixed_len() -> bool {
335 false
336 }
337
338 fn ssz_bytes_len(&self) -> usize {
339 match self {
340 #(
341 #patterns => inner
342 .ssz_bytes_len()
343 .checked_add(1)
344 .expect("encoded length must be less than usize::max_value"),
345 )*
346 }
347 }
348
349 fn ssz_append(&self, buf: &mut Vec<u8>) {
350 match self {
351 #(
352 #patterns => {
353 let union_selector: u8 = #union_selectors;
354 debug_assert!(union_selector <= ssz::MAX_UNION_SELECTOR);
355 buf.push(union_selector);
356 inner.ssz_append(buf)
357 },
358 )*
359 }
360 }
361 }
362 };
363 output.into()
364}
365
366#[proc_macro_derive(Decode, attributes(ssz))]
368pub fn ssz_decode_derive(input: TokenStream) -> TokenStream {
369 let item = parse_macro_input!(input as DeriveInput);
370 let opts = StructOpts::from_derive_input(&item).unwrap();
371 let enum_opt = EnumBehaviour::new(opts.enum_behaviour);
372
373 match &item.data {
374 syn::Data::Struct(s) => {
375 if enum_opt.is_some() {
376 panic!("enum_behaviour is invalid for structs");
377 }
378 ssz_decode_derive_struct(&item, s)
379 }
380 syn::Data::Enum(s) => match enum_opt.expect(NO_ENUM_BEHAVIOUR_ERROR) {
381 EnumBehaviour::Transparent => panic!(
382 "Decode cannot be derived for enum_behaviour \"{}\", only \"{}\" is valid.",
383 ENUM_TRANSPARENT, ENUM_UNION
384 ),
385 EnumBehaviour::Union => ssz_decode_derive_enum_union(&item, s),
386 },
387 _ => panic!("ssz_derive only supports structs and enums"),
388 }
389}
390
391fn ssz_decode_derive_struct(item: &DeriveInput, struct_data: &DataStruct) -> TokenStream {
401 let name = &item.ident;
402 let (impl_generics, ty_generics, where_clause) = &item.generics.split_for_impl();
403
404 let mut register_types = vec![];
405 let mut field_names = vec![];
406 let mut fixed_decodes = vec![];
407 let mut decodes = vec![];
408 let mut is_fixed_lens = vec![];
409 let mut fixed_lens = vec![];
410
411 for (ty, ident, field_opts) in parse_ssz_fields(struct_data) {
412 field_names.push(quote! {
413 #ident
414 });
415
416 if field_opts.skip_deserializing {
418 decodes.push(quote! {
419 let #ident = <_>::default();
420 });
421
422 fixed_decodes.push(quote! {
423 let #ident = <_>::default();
424 });
425
426 continue;
427 }
428
429 let is_ssz_fixed_len;
430 let ssz_fixed_len;
431 let from_ssz_bytes;
432 if let Some(module) = field_opts.with {
433 let module = quote! { #module::decode };
434
435 is_ssz_fixed_len = quote! { #module::is_ssz_fixed_len() };
436 ssz_fixed_len = quote! { #module::ssz_fixed_len() };
437 from_ssz_bytes = quote! { #module::from_ssz_bytes(slice) };
438
439 register_types.push(quote! {
440 builder.register_type_parameterized(#is_ssz_fixed_len, #ssz_fixed_len)?;
441 });
442 decodes.push(quote! {
443 let #ident = decoder.decode_next_with(|slice| #module::from_ssz_bytes(slice))?;
444 });
445 } else {
446 is_ssz_fixed_len = quote! { <#ty as ssz::Decode>::is_ssz_fixed_len() };
447 ssz_fixed_len = quote! { <#ty as ssz::Decode>::ssz_fixed_len() };
448 from_ssz_bytes = quote! { <#ty as ssz::Decode>::from_ssz_bytes(slice) };
449
450 register_types.push(quote! {
451 builder.register_type::<#ty>()?;
452 });
453 decodes.push(quote! {
454 let #ident = decoder.decode_next()?;
455 });
456 }
457
458 fixed_decodes.push(quote! {
459 let #ident = {
460 start = end;
461 end = end
462 .checked_add(#ssz_fixed_len)
463 .ok_or_else(|| ssz::DecodeError::OutOfBoundsByte {
464 i: usize::max_value()
465 })?;
466 let slice = bytes.get(start..end)
467 .ok_or_else(|| ssz::DecodeError::InvalidByteLength {
468 len: bytes.len(),
469 expected: end
470 })?;
471 #from_ssz_bytes?
472 };
473 });
474 is_fixed_lens.push(is_ssz_fixed_len);
475 fixed_lens.push(ssz_fixed_len);
476 }
477
478 let output = quote! {
479 impl #impl_generics ssz::Decode for #name #ty_generics #where_clause {
480 fn is_ssz_fixed_len() -> bool {
481 #(
482 #is_fixed_lens &&
483 )*
484 true
485 }
486
487 fn ssz_fixed_len() -> usize {
488 if <Self as ssz::Decode>::is_ssz_fixed_len() {
489 let mut len: usize = 0;
490 #(
491 len = len
492 .checked_add(#fixed_lens)
493 .expect("decode ssz_fixed_len overflow");
494 )*
495 len
496 } else {
497 ssz::BYTES_PER_LENGTH_OFFSET
498 }
499 }
500
501 fn from_ssz_bytes(bytes: &[u8]) -> std::result::Result<Self, ssz::DecodeError> {
502 if <Self as ssz::Decode>::is_ssz_fixed_len() {
503 if bytes.len() != <Self as ssz::Decode>::ssz_fixed_len() {
504 return Err(ssz::DecodeError::InvalidByteLength {
505 len: bytes.len(),
506 expected: <Self as ssz::Decode>::ssz_fixed_len(),
507 });
508 }
509
510 let mut start: usize = 0;
511 let mut end = start;
512
513 #(
514 #fixed_decodes
515 )*
516
517 Ok(Self {
518 #(
519 #field_names,
520 )*
521 })
522 } else {
523 let mut builder = ssz::SszDecoderBuilder::new(bytes);
524
525 #(
526 #register_types
527 )*
528
529 let mut decoder = builder.build()?;
530
531 #(
532 #decodes
533 )*
534
535
536 Ok(Self {
537 #(
538 #field_names,
539 )*
540 })
541 }
542 }
543 }
544 };
545 output.into()
546}
547
548fn ssz_decode_derive_enum_union(derive_input: &DeriveInput, enum_data: &DataEnum) -> TokenStream {
550 let name = &derive_input.ident;
551 let (impl_generics, ty_generics, where_clause) = &derive_input.generics.split_for_impl();
552
553 let (constructors, var_types): (Vec<_>, Vec<_>) = enum_data
554 .variants
555 .iter()
556 .map(|variant| {
557 let variant_name = &variant.ident;
558
559 if variant.fields.len() != 1 {
560 panic!("ssz::Encode can only be derived for enums with 1 field per variant");
561 }
562
563 let constructor = quote! {
564 #name::#variant_name
565 };
566
567 let ty = &(&variant.fields).into_iter().next().unwrap().ty;
568 (constructor, ty)
569 })
570 .unzip();
571
572 let union_selectors = compute_union_selectors(constructors.len());
573
574 let output = quote! {
575 impl #impl_generics ssz::Decode for #name #ty_generics #where_clause {
576 fn is_ssz_fixed_len() -> bool {
577 false
578 }
579
580 fn from_ssz_bytes(bytes: &[u8]) -> Result<Self, ssz::DecodeError> {
581 debug_assert_eq!(#MAX_UNION_SELECTOR, ssz::MAX_UNION_SELECTOR);
584
585 let (selector, body) = ssz::split_union_bytes(bytes)?;
586
587 match selector.into() {
588 #(
589 #union_selectors => {
590 <#var_types as ssz::Decode>::from_ssz_bytes(body).map(#constructors)
591 },
592 )*
593 other => Err(ssz::DecodeError::UnionSelectorInvalid(other))
594 }
595 }
596 }
597 };
598 output.into()
599}
600
601fn compute_union_selectors(num_variants: usize) -> Vec<u8> {
602 let union_selectors = (0..num_variants)
603 .map(|i| {
604 i.try_into()
605 .expect("union selector exceeds u8::max_value, union has too many variants")
606 })
607 .collect::<Vec<u8>>();
608
609 let highest_selector = union_selectors
610 .last()
611 .copied()
612 .expect("0-variant union is not permitted");
613
614 assert!(
615 highest_selector <= MAX_UNION_SELECTOR,
616 "union selector {} exceeds limit of {}, enum has too many variants",
617 highest_selector,
618 MAX_UNION_SELECTOR
619 );
620
621 union_selectors
622}