use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
Data, DeriveInput, Fields, Index, LitInt, LitStr, Meta, Path, Token, Type, TypePath,
};
#[proc_macro_derive(AFastSerialize, attributes(validate))]
pub fn derive_serialize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let generics = &input.generics;
let mut generics_with_bounds = generics.clone();
for param in &mut generics_with_bounds.params {
if let syn::GenericParam::Type(ref mut ty) = *param {
ty.bounds
.push(syn::parse_quote!(::afastdata_core::AFastSerialize));
}
}
let (impl_generics, _, _) = generics_with_bounds.split_for_impl();
let (_, ty_generics, _) = generics.split_for_impl();
let expanded = match &input.data {
Data::Struct(data) => {
let serialize_body = generate_serialize_fields(&data.fields, quote!(self));
quote! {
impl #impl_generics ::afastdata_core::AFastSerialize for #name #ty_generics {
fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
#(#serialize_body)*
bytes
}
}
}
}
Data::Enum(data) => {
let mut arms = Vec::new();
for (i, variant) in data.variants.iter().enumerate() {
let variant_name = &variant.ident;
let tag = i as u32;
match &variant.fields {
Fields::Unit => {
arms.push(quote! {
#name::#variant_name => {
bytes.extend(#tag.to_le_bytes());
}
});
}
Fields::Unnamed(fields) => {
let field_names: Vec<_> = (0..fields.unnamed.len())
.map(|i| {
let ident =
syn::Ident::new(&format!("__f{}", i), variant_name.span());
ident
})
.collect();
let field_patterns = &field_names;
let mut serialize_fields = Vec::new();
for fname in &field_names {
serialize_fields.push(quote! {
bytes.extend(::afastdata_core::AFastSerialize::to_bytes(#fname));
});
}
arms.push(quote! {
#name::#variant_name(#(#field_patterns),*) => {
bytes.extend(#tag.to_le_bytes());
#(#serialize_fields)*
}
});
}
Fields::Named(fields) => {
let field_names: Vec<_> = fields
.named
.iter()
.map(|f| f.ident.as_ref().unwrap())
.collect();
let mut serialize_fields = Vec::new();
for fname in &field_names {
serialize_fields.push(quote! {
bytes.extend(::afastdata_core::AFastSerialize::to_bytes(#fname));
});
}
arms.push(quote! {
#name::#variant_name { #(#field_names),* } => {
bytes.extend(#tag.to_le_bytes());
#(#serialize_fields)*
}
});
}
}
}
quote! {
impl #impl_generics ::afastdata_core::AFastSerialize for #name #ty_generics {
fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::new();
match self {
#(#arms)*
}
bytes
}
}
}
}
Data::Union(_) => panic!("AFastSerialize does not support unions"),
};
TokenStream::from(expanded)
}
#[proc_macro_derive(AFastDeserialize, attributes(validate))]
pub fn derive_deserialize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let generics = &input.generics;
let mut generics_with_bounds = generics.clone();
for param in &mut generics_with_bounds.params {
if let syn::GenericParam::Type(ref mut ty) = *param {
ty.bounds
.push(syn::parse_quote!(::afastdata_core::AFastSerialize));
ty.bounds
.push(syn::parse_quote!(::afastdata_core::AFastDeserialize));
}
}
let (impl_generics, _, _) = generics_with_bounds.split_for_impl();
let (_, ty_generics, _) = generics.split_for_impl();
let expanded = match &input.data {
Data::Struct(data) => {
let (construct, field_desers) =
generate_deserialize_fields(&data.fields, &name, &ty_generics);
quote! {
impl #impl_generics ::afastdata_core::AFastDeserialize for #name #ty_generics {
fn from_bytes(data: &[u8]) -> Result<(Self, usize), ::afastdata_core::Error> {
let mut offset: usize = 0;
#(#field_desers)*
Ok((#construct, offset))
}
}
}
}
Data::Enum(data) => {
let mut arms = Vec::new();
for (i, variant) in data.variants.iter().enumerate() {
let variant_name = &variant.ident;
let tag = i as u32;
match &variant.fields {
Fields::Unit => {
arms.push(quote! {
#tag => {
Ok((#name::#variant_name, offset))
}
});
}
Fields::Unnamed(fields) => {
let mut field_desers = Vec::new();
let mut field_names = Vec::new();
for _ in &fields.unnamed {
let fname = syn::Ident::new(
&format!("__f{}", field_names.len()),
variant_name.span(),
);
field_desers.push(quote! {
let (__val, __new_offset) = ::afastdata_core::AFastDeserialize::from_bytes(&data[offset..])?;
let #fname = __val;
offset += __new_offset;
});
field_names.push(fname);
}
arms.push(quote! {
#tag => {
#(#field_desers)*
Ok((#name::#variant_name(#(#field_names),*), offset))
}
});
}
Fields::Named(fields) => {
let mut field_desers = Vec::new();
let mut field_names = Vec::new();
for f in &fields.named {
let fname = f.ident.as_ref().unwrap();
field_desers.push(quote! {
let (__val, __new_offset) = ::afastdata_core::AFastDeserialize::from_bytes(&data[offset..])?;
let #fname = __val;
offset += __new_offset;
});
field_names.push(fname);
}
arms.push(quote! {
#tag => {
#(#field_desers)*
Ok((#name::#variant_name { #(#field_names),* }, offset))
}
});
}
}
}
quote! {
impl #impl_generics ::afastdata_core::AFastDeserialize for #name #ty_generics {
fn from_bytes(data: &[u8]) -> Result<(Self, usize), ::afastdata_core::Error> {
let mut offset: usize = 0;
let (__tag_bytes, __new_offset) = <u32 as ::afastdata_core::AFastDeserialize>::from_bytes(&data[offset..])?;
offset += __new_offset;
match __tag_bytes {
#(#arms)*
v => Err(::afastdata_core::Error::deserialize(format!("Unknown variant tag: {} for {}", v, ::std::stringify!(#name)))),
}
}
}
}
}
Data::Union(_) => panic!("AFastDeserialize does not support unions"),
};
TokenStream::from(expanded)
}
/// Produces one `bytes.extend(...)` statement per field, each serializing
/// `&<self_prefix>.<member>` via `AFastSerialize::to_bytes`, in declaration order.
///
/// Named fields are accessed by identifier and tuple fields by positional index. A unit
/// struct yields no statements.
fn generate_serialize_fields(
    fields: &Fields,
    self_prefix: proc_macro2::TokenStream,
) -> Vec<proc_macro2::TokenStream> {
    // First collect how each field is accessed (`name` or `0`, `1`, ...).
    let members: Vec<proc_macro2::TokenStream> = match fields {
        Fields::Unit => Vec::new(),
        Fields::Unnamed(unnamed) => (0..unnamed.unnamed.len())
            .map(|position| {
                let index = Index::from(position);
                quote!(#index)
            })
            .collect(),
        Fields::Named(named) => named
            .named
            .iter()
            .map(|field| {
                let ident = field.ident.as_ref().unwrap();
                quote!(#ident)
            })
            .collect(),
    };

    // Then wrap every accessor in the same serialization statement.
    members
        .iter()
        .map(|member| {
            quote! {
                bytes.extend(::afastdata_core::AFastSerialize::to_bytes(&#self_prefix.#member));
            }
        })
        .collect()
}
/// Parsed arguments of a comparison validator, written as `gt(value, code, "message")`
/// (likewise `gte`, `lt`, `lte`).
struct Range {
    /// The bound the field value is compared against.
    int: LitInt,
    _comma1: Token![,],
    /// Error code passed to `Error::validate` when the check fails.
    code: LitInt,
    _comma2: Token![,],
    /// Error message; `${field}` is replaced by the field name.
    msg: LitStr,
}
impl Parse for Range {
    /// Parses exactly `<int> , <int> , "<string>"`, in that order.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let int = input.parse()?;
        let _comma1 = input.parse()?;
        let code = input.parse()?;
        let _comma2 = input.parse()?;
        let msg = input.parse()?;
        Ok(Self {
            int,
            _comma1,
            code,
            _comma2,
            msg,
        })
    }
}
/// Parsed arguments of the length validator, written as `len(min, max, code, "message")`.
///
/// How negative `min`/`max` values are interpreted is decided by the `len` handling in
/// `generate_deserialize_fields`.
struct Length {
    /// Lower length bound.
    min: LitInt,
    _comma1: Token![,],
    /// Upper length bound.
    max: LitInt,
    _comma2: Token![,],
    /// Error code passed to `Error::validate` when the check fails.
    code: LitInt,
    _comma3: Token![,],
    /// Error message; `${field}` is replaced by the field name.
    msg: LitStr,
}
impl Parse for Length {
    /// Parses exactly `<int> , <int> , <int> , "<string>"`, in that order.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let min = input.parse()?;
        let _comma1 = input.parse()?;
        let max = input.parse()?;
        let _comma2 = input.parse()?;
        let code = input.parse()?;
        let _comma3 = input.parse()?;
        let msg = input.parse()?;
        Ok(Self {
            min,
            _comma1,
            max,
            _comma2,
            code,
            _comma3,
            msg,
        })
    }
}
fn generate_deserialize_fields(
fields: &Fields,
name: &syn::Ident,
ty_generics: &syn::TypeGenerics,
) -> (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>) {
let ty_params = ty_generics.as_turbofish();
match fields {
Fields::Named(named) => {
let mut desers = Vec::new();
let mut field_names = Vec::new();
for f in &named.named {
let fname = f.ident.as_ref().unwrap();
let ftype = &f.ty;
field_names.push(fname.clone());
let mut validates = Vec::new();
for attr in &f.attrs {
if attr.path().is_ident("validate") {
let nested = attr
.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
.unwrap();
for meta in nested {
match meta {
Meta::List(meta) => {
if meta.path.is_ident("gt") {
let inner = meta.parse_args::<Range>().unwrap();
let gt_value = inner.int.base10_parse::<i64>().unwrap();
let code = inner.code.base10_parse::<i64>().unwrap();
let err_msg = inner
.msg
.value()
.replace("${field}", &fname.to_string());
validates.push(quote! {
if #fname <= #gt_value {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
} else if meta.path.is_ident("gte") {
let inner = meta.parse_args::<Range>().unwrap();
let gt_value = inner.int.base10_parse::<i64>().unwrap();
let code = inner.code.base10_parse::<i64>().unwrap();
let err_msg = inner
.msg
.value()
.replace("${field}", &fname.to_string());
validates.push(quote! {
if #fname < #gt_value {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
} else if meta.path.is_ident("lt") {
let inner = meta.parse_args::<Range>().unwrap();
let lt_value = inner.int.base10_parse::<i64>().unwrap();
let code = inner.code.base10_parse::<i64>().unwrap();
let err_msg = inner
.msg
.value()
.replace("${field}", &fname.to_string());
validates.push(quote! {
if #fname >= #lt_value {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
} else if meta.path.is_ident("lte") {
let inner = meta.parse_args::<Range>().unwrap();
let lt_value = inner.int.base10_parse::<i64>().unwrap();
let code = inner.code.base10_parse::<i64>().unwrap();
let err_msg = inner
.msg
.value()
.replace("${field}", &fname.to_string());
validates.push(quote! {
if #fname > #lt_value {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
} else if meta.path.is_ident("len") {
let field_is_option = match ftype {
Type::Path(TypePath {
path: Path { segments, .. },
..
}) => {
segments.len() == 1 && segments[0].ident == "Option"
}
_ => false,
};
let inner = meta.parse_args::<Length>().unwrap();
let min_value = inner.min.base10_parse::<i64>().unwrap();
let max_value = inner.max.base10_parse::<i64>().unwrap();
let code = inner.code.base10_parse::<i64>().unwrap();
let err_msg = inner
.msg
.value()
.replace("${field}", &fname.to_string());
if min_value > max_value {
panic!("Invalid validation: min value {} is greater than max value {} for field {}", min_value, max_value, fname);
}
if min_value < 0 && max_value < 0 {
panic!("Invalid validation: both min and max values are negative for field {}", fname);
} else if min_value < 0 {
let max: usize = max_value.try_into().unwrap();
validates.push(quote! {
if #fname.len() > #max {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
} else if max_value < 0 {
let min: usize = min_value.try_into().unwrap();
validates.push(quote! {
if #fname.len() < #min {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
} else {
let min: usize = min_value.try_into().unwrap();
let max: usize = max_value.try_into().unwrap();
if field_is_option {
validates.push(quote! {
let length = match &#fname { Some(s) => {
let __length = s.len();
if __length < #min || __length > #max {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
},
None => {},
};
});
} else {
validates.push(quote! {
if #fname.len() > #max {
return Err(::afastdata_core::Error::validate(#code, #err_msg.to_string()));
}
});
}
}
} else if meta.path.is_ident("func") {
let inner = meta.parse_args::<LitStr>().unwrap();
let ident =
syn::parse_str::<syn::Ident>(&inner.value()).unwrap();
let field = fname.to_string();
validates.push(quote! {
match #ident(&#fname, #field) {
Ok(()) => {},
Err(e) => return Err(e.to_afastdata_error()),
}
});
}
}
_ => {}
}
}
}
}
desers.push(quote! {
let (__val, __new_offset) = ::afastdata_core::AFastDeserialize::from_bytes(&data[offset..])?;
let #fname: #ftype = __val;
#(#validates)*
offset += __new_offset;
});
}
let construct = quote! {
#name #ty_params { #(#field_names),* }
};
(construct, desers)
}
Fields::Unnamed(unnamed) => {
let mut desers = Vec::new();
let mut field_names = Vec::new();
for i in 0..unnamed.unnamed.len() {
let fname = syn::Ident::new(&format!("__f{}", i), name.span());
desers.push(quote! {
let (__val, __new_offset) = ::afastdata_core::AFastDeserialize::from_bytes(&data[offset..])?;
let #fname = __val;
offset += __new_offset;
});
field_names.push(fname);
}
let construct = quote! {
#name #ty_params ( #(#field_names),* )
};
(construct, desers)
}
Fields::Unit => {
let construct = quote! { #name #ty_params };
(construct, vec![])
}
}
}