bolt_proto_derive/
lib.rs

1#![warn(rust_2018_idioms)]
2
3use proc_macro::TokenStream;
4
5use syn::{AttributeArgs, Fields, Generics, Ident, ItemStruct, NestedMeta, WhereClause};
6
7use quote::{format_ident, quote};
8
/// Marker byte for a struct with 0-15 fields; `get_struct_info` ORs the field count into it.
pub(crate) const MARKER_TINY_STRUCT: u8 = 0xb0;
/// Marker byte used by `get_struct_info` for structs with 16-255 fields.
pub(crate) const MARKER_SMALL_STRUCT: u8 = 0xdc;
/// Marker byte used by `get_struct_info` for structs with 256-65535 fields.
pub(crate) const MARKER_MEDIUM_STRUCT: u8 = 0xdd;
12
13fn get_struct_info(
14    structure: ItemStruct,
15    args: AttributeArgs,
16) -> (Ident, Generics, Option<WhereClause>, Fields, u8, NestedMeta) {
17    let name = structure.ident;
18    let type_args = structure.generics;
19    let where_clause = type_args.where_clause.clone();
20    let fields = structure.fields;
21
22    let marker = match fields.len() {
23        0..=15 => MARKER_TINY_STRUCT | fields.len() as u8,
24        16..=255 => MARKER_SMALL_STRUCT,
25        256..=65535 => MARKER_MEDIUM_STRUCT,
26        _ => panic!("struct has too many fields"),
27    };
28
29    let signature = args.into_iter().next().expect("signature is required");
30
31    (name, type_args, where_clause, fields, marker, signature)
32}
33
34#[proc_macro_attribute]
35pub fn bolt_structure(attr_args: TokenStream, item: TokenStream) -> TokenStream {
36    let structure = syn::parse_macro_input!(item as ItemStruct);
37    let args = syn::parse_macro_input!(attr_args as AttributeArgs);
38    let (name, type_args, where_clause, fields, marker, signature) =
39        get_struct_info(structure.clone(), args);
40
41    let field_names: Vec<Ident> = fields.into_iter().map(|f| f.ident.unwrap()).collect();
42    let byte_var_names: Vec<Ident> = field_names
43        .iter()
44        .map(|name| format_ident!("{}_bytes", name))
45        .collect();
46
47    let byte_var_defs = byte_var_names.iter()
48        .zip(field_names.iter())
49        .map(|(var_name, field_name)| {
50            quote!(let #var_name = crate::Value::from(self.#field_name).serialize()?;)
51        });
52
53    let deserialize_var_defs = field_names.iter().map(|name| {
54        quote!(
55            let (#name, remaining) = crate::Value::deserialize(bytes)?;
56            bytes = remaining;
57        )
58    });
59
60    let deserialize_fields = field_names
61        .iter()
62        .map(|name| quote!(#name: #name.try_into()?,));
63
64    quote!(
65        #structure
66
67        impl #type_args crate::serialization::BoltValue for #name #type_args
68        #where_clause
69        {
70            fn marker(&self) -> crate::error::SerializeResult<u8> {
71                Ok(#marker)
72            }
73
74            fn serialize(self) -> crate::error::SerializeResult<::bytes::Bytes> {
75                use ::bytes::BufMut;
76                use crate::serialization::{BoltStructure, BoltValue};
77
78                let marker = self.marker()?;
79                let signature = self.signature();
80                #(#byte_var_defs)*
81
82                // Marker byte, signature byte, then the rest of the data
83                let mut result_bytes_mut = ::bytes::BytesMut::with_capacity(
84                    std::mem::size_of::<u8>() * 2 #(+ #byte_var_names.len())*
85                );
86                result_bytes_mut.put_u8(marker);
87                result_bytes_mut.put_u8(signature);
88                #(result_bytes_mut.put(#byte_var_names);)*
89                Ok(result_bytes_mut.freeze())
90            }
91
92            fn deserialize<B>(mut bytes: B) -> crate::error::DeserializeResult<(Self, B)>
93            where B: ::bytes::Buf + ::std::panic::UnwindSafe
94            {
95                #(#deserialize_var_defs)*
96                Ok((Self { #(#deserialize_fields)* }, bytes))
97            }
98        }
99
100        impl #type_args crate::serialization::BoltStructure for #name #type_args
101        #where_clause
102        {
103            fn signature(&self) -> u8 {
104                #signature
105            }
106        }
107    )
108    .into()
109}