#![deny(missing_docs)]
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::parse::Result;
use syn::spanned::Spanned;
use syn::{
parse_macro_input, Data, DataEnum, DeriveInput, Fields, GenericArgument, Pat, PathArguments,
Type, TypePath,
};
use proc_macro2::Span;
use syn::LitInt;
mod attributes;
use attributes::{Attributes, Endianness};
/// Unwraps a value inside a proc-macro entry point, turning the failure
/// case into a compile-error token stream returned from the enclosing
/// function.
///
/// - `unwrap!(result)` takes a `syn::Result` and yields the `Ok` value,
///   or converts the `syn::Error` into a compile error on `Err`.
/// - `unwrap!(option, span, "message")` takes an `Option` and yields the
///   `Some` value, or builds a `syn::Error` located at `span` with the
///   given message on `None`.
macro_rules! unwrap {
    ($expression:expr) => {
        match $expression {
            Ok(a) => a,
            Err(e) => return e.to_compile_error().into(),
        }
    };
    ($expression:expr, $span:expr, $message:literal) => {
        match $expression {
            Some(a) => a,
            None => {
                return syn::Error::new($span.span(), $message)
                    .to_compile_error()
                    .into()
            }
        }
    };
}
/// Tells whether `ty` names one of the fixed-width numeric primitives
/// that Plod can (de)serialize directly.
fn primitive_type(ty: &Ident) -> bool {
    let name = ty.to_string();
    matches!(
        name.as_str(),
        "f32" | "f64"
            | "i8" | "i16" | "i32" | "i64" | "i128"
            | "u8" | "u16" | "u32" | "u64" | "u128"
    )
}
/// Returns the serialized byte width of a primitive numeric type as an
/// integer literal ready to be spliced into generated code.
///
/// # Panics
///
/// Panics if `ty` is not one of the types accepted by [`primitive_type`];
/// callers must validate the type first.
fn primitive_size(ty: &Ident) -> LitInt {
    // Byte widths of every type accepted by `primitive_type`.
    let size = [
        ("f32", 4),
        ("f64", 8),
        ("i8", 1),
        ("i16", 2),
        ("i32", 4),
        ("i64", 8),
        ("i128", 16),
        ("u8", 1),
        ("u16", 2),
        ("u32", 4),
        ("u64", 8),
        ("u128", 16),
    ]
    .iter()
    .find_map(|(name, size)| if ty == name { Some(*size) } else { None })
    // An unknown type here is a bug in the caller, not in user input:
    // fail with a message that names the broken invariant instead of the
    // original bare `.unwrap()`.
    .expect("primitive_size called with a non-primitive type; check with primitive_type first");
    LitInt::new(&size.to_string(), Span::call_site())
}
/// Builds the identifiers of the `from_*_bytes` / `to_*_bytes` conversion
/// methods matching the requested endianness.
fn primitive_function(endianness: Endianness) -> (Ident, Ident) {
    let infix = match endianness {
        Endianness::Big => "be",
        Endianness::Little => "le",
        Endianness::Native => "ne",
    };
    let from = Ident::new(&format!("from_{}_bytes", infix), Span::call_site());
    let to = Ident::new(&format!("to_{}_bytes", infix), Span::call_site());
    (from, to)
}
fn syn_error<S: Spanned, T>(span: &S, message: &str) -> Result<T> {
Err(syn::Error::new(span.span(), message))
}
/// Derive macro generating an implementation of the `Plod` trait.
///
/// Parses the annotated type and its `#[plod(...)]` attributes, then emits
/// an `impl plod::Plod` providing `size_at_rest`, `read_from` and
/// `write_to`. Any parse or attribute error is returned as a compile
/// error token stream.
#[proc_macro_derive(Plod, attributes(plod))]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let attributes = unwrap!(Attributes::parse(&input.attrs));
    let plod_impl = unwrap!(plod_impl(&input, &attributes));
    let name = input.ident;
    // NOTE(review): the impl_generics from split_for_impl are discarded and
    // the generics rebuilt from type_params() only; lifetime and const
    // parameters would be lost — confirm only plain type parameters are
    // expected on derived types.
    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
    let type_params = input.generics.type_params();
    // Context type declared via attributes (defaults handled in Attributes).
    let ctx_ty = attributes.context_type;
    let expanded = quote! {
        #[automatically_derived]
        impl <#(#type_params),*> plod::Plod for #name #ty_generics #where_clause {
            type Context= #ctx_ty;
            #plod_impl
        }
    };
    proc_macro::TokenStream::from(expanded)
}
/// Builds the bodies of the three `Plod` trait methods for the input type.
///
/// Dispatches on the input kind: structs get their field code generated
/// directly, enums are delegated to [`enum_impl`], and unions are rejected
/// with a compile error.
fn plod_impl(input: &DeriveInput, attributes: &Attributes) -> Result<TokenStream> {
    let self_name = &input.ident;
    let (size_impl, read_impl, write_impl) = match &input.data {
        Data::Struct(data) => {
            // Struct fields are reached through `self.` on write/size, and
            // the read bindings are collected into `#self_name #field_list`.
            let (size_code, read_code, write_code, field_list) = generate_for_fields(
                &data.fields,
                Some(&quote! { self. }),
                &input.ident,
                &attributes,
            )?;
            (
                size_code,
                quote! {
                    #read_code
                    Ok(#self_name #field_list)
                },
                quote! {
                    #write_code
                    Ok(())
                },
            )
        }
        Data::Enum(data) => enum_impl(self_name, data, attributes)?,
        Data::Union(u) => {
            return Err(syn::Error::new(
                u.union_token.span(),
                "Union types are not supported by plod",
            ))
        }
    };
    // Wrap the three bodies into the actual trait method definitions.
    Ok(quote! {
        fn size_at_rest(&self) -> usize {
            #size_impl
        }
        fn read_from<R: std::io::Read>(from: &mut R, ctx: &Self::Context) -> plod::Result<Self> {
            #read_impl
        }
        fn write_to<W: std::io::Write>(&self, to: &mut W, ctx: &Self::Context) -> plod::Result<()> {
            #write_impl
        }
    })
}
/// Generates `size_at_rest`, `read_from` and `write_to` bodies for an enum.
///
/// Variants are distinguished on the wire by a tag of the primitive type
/// declared with `#[plod(tag_type(<type>))]`. Reading decodes the tag and
/// matches it against each variant's `#[plod(tag(<value>))]`; at most one
/// variant may omit its tag to act as the catch-all default, and it must
/// come last. Variants marked `#[plod(skip)]` get no read arm, a size of 0
/// and a runtime error when written.
fn enum_impl(
    self_name: &Ident,
    data: &DataEnum,
    attributes: &Attributes,
) -> Result<(TokenStream, TokenStream, TokenStream)> {
    // Match arms accumulated across variants for each generated method.
    let mut size_impl = TokenStream::new();
    let mut read_impl = TokenStream::new();
    let mut write_impl = TokenStream::new();
    let tag_type = match &attributes.tag_type {
        Some(t) => t,
        // Fixed: error message previously had unbalanced parentheses.
        None => return syn_error(self_name, "#[plod(tag_type(<type>))] is mandatory for enum"),
    };
    if !primitive_type(tag_type) {
        return syn_error(
            tag_type,
            "#[plod(tag_type(<type>))] tag only works with primitive types",
        );
    }
    let tag_size = primitive_size(tag_type);
    let (from_method, to_method) = primitive_function(attributes.endianness);
    // Set once the catch-all (tag-less) variant has been generated; any
    // variant following it would be unreachable on read.
    let mut default_done = false;
    for variant in data.variants.iter() {
        let ident = &variant.ident;
        // Variant-level attributes extend / override the enum-level ones.
        let variant_attributes = attributes.extend(&variant.attrs)?;
        let tag_value = &variant_attributes.tag;
        if variant_attributes.skip {
            // Skipped variants never appear on the wire: size 0, writing
            // fails at runtime, and no read arm is generated at all.
            let error_token = quote! { #self_name::#ident };
            let error_str = error_token.to_string();
            let fields_token = if let Fields::Unit = variant.fields {
                TokenStream::new()
            } else {
                quote! { (..) }
            };
            size_impl.extend(quote! {
                #self_name::#ident #fields_token => 0,
            });
            write_impl.extend(quote! {
                #self_name::#ident #fields_token => {
                    return Err(std::io::Error::other(format!("Variant {} cannot be written because it is plod(skipped)", #error_str)));
                }
            });
            continue;
        }
        if default_done {
            return syn_error(
                &variant.ident,
                "The variant without #[plod(tag(<value>))] must come last",
            );
        }
        let (size_code, read_code, write_code, field_list) =
            generate_for_fields(&variant.fields, None, &variant.ident, &variant_attributes)?;
        // Read arm: either match the declared tag pattern, or `_` for the
        // catch-all default variant.
        match &tag_value {
            Some(value) => read_impl.extend(quote! {
                #value => {
                    #read_code
                    Ok(#self_name::#ident #field_list)
                }
            }),
            None => {
                read_impl.extend(quote! {
                    _ => {
                        #read_code
                        Ok(#self_name::#ident #field_list)
                    }
                });
                default_done = true;
            }
        }
        // Unless the tag is kept as a regular field (`keep_tag`), the write
        // arm must emit the tag value itself before the field data.
        let add_tag = if variant_attributes.keep_tag {
            TokenStream::new()
        } else {
            let tag_pattern = match &variant_attributes.tag {
                Some(t) => t,
                None => {
                    return syn_error(ident, "#[plod(tag(<value>))] is mandatory without keep_tag")
                }
            };
            // Only a literal tag can be written back; a pattern (range,
            // or-pattern, ...) is ambiguous and requires keep_tag.
            let tag_value = match tag_pattern {
                Pat::Lit(expr) => expr,
                _ => {
                    return syn_error(tag_type, "#[plod(keep_tag)] is mandatory with tag patterns")
                }
            };
            quote! {
                let buffer: [u8; #tag_size] = (#tag_value as #tag_type).#to_method();
                to.write_all(&buffer)?;
            }
        };
        write_impl.extend(quote! {
            #self_name::#ident #field_list => {
                #add_tag
                #write_code
            }
        });
        size_impl.extend(quote! {
            #self_name::#ident #field_list => #size_code,
        });
    }
    // Wrap the accumulated arms into complete method bodies.
    size_impl = quote! {
        match self {
            #size_impl
        }
    };
    let read_tag = quote! {
        let mut buffer: [u8; #tag_size] = [0; #tag_size];
        from.read_exact(&mut buffer)?;
        let discriminant = #tag_type::#from_method(buffer);
    };
    if default_done {
        // The default variant already matches `_`: no error arm needed.
        read_impl = quote! {
            #read_tag
            match discriminant {
                #read_impl
            }
        };
    } else {
        read_impl = quote! {
            #read_tag
            match discriminant {
                #read_impl
                _ => return Err(std::io::Error::other(format!("Tag value {} not found", discriminant))),
            }
        };
    }
    write_impl = quote! {
        match self {
            #write_impl
        }
        Ok(())
    };
    Ok((size_impl, read_impl, write_impl))
}
/// Generates size / read / write code and a constructor field list for all
/// fields of a struct or of one enum variant.
///
/// `field_prefix` is `Some(self.)` when fields are reached through `self`
/// (structs) and `None` when they are bound by a match pattern (enum
/// variants). Returns `(size_code, read_code, write_code, field_list)`;
/// `field_list` is `{ a, b, }`, `(field_0, field_1,)` or empty depending
/// on the field style, ready to be spliced after the type or variant name.
fn generate_for_fields(
    fields: &Fields,
    field_prefix: Option<&TokenStream>,
    ident: &Ident,
    attributes: &Attributes,
) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream)> {
    let mut size_code = TokenStream::new();
    let mut read_code = TokenStream::new();
    let mut write_code = TokenStream::new();
    let mut field_list = TokenStream::new();
    // Context expression passed to nested Plod calls: starts as the `ctx`
    // method argument and is replaced by the latest field marked as context
    // (see `is_context` below). Two forms are tracked: one valid in read
    // position (local binding) and one valid in write/size position
    // (prefixed field access).
    let mut context_val = quote! { ctx };
    let mut prefixed_context_val = quote! { ctx };
    if let Some((ty, value)) = &attributes.magic {
        // #[plod(magic(...))]: a fixed primitive marker written before the
        // fields and checked against the expected value on read.
        let (from_method, to_method) = primitive_function(attributes.endianness);
        if !primitive_type(ty) {
            return syn_error(ty, "magic only works with primitive types");
        }
        let ty_size = primitive_size(ty);
        // Size fragments are emitted with a trailing `+`; the chain is
        // closed by the final term appended at the end of this function.
        size_code.extend(quote! {
            #ty_size +
        });
        read_code.extend(quote! {
            let mut buffer: [u8; #ty_size] = [0; #ty_size];
            from.read_exact(&mut buffer)?;
            let magic = #ty::#from_method(buffer);
            if magic != #value {
                return Err(std::io::Error::other(format!("Magic value {} expected, found {}", #value, magic)));
            }
        });
        write_code.extend(quote! {
            let buffer: [u8; #ty_size] = (#value as #ty).#to_method();
            to.write_all(&buffer)?;
        });
    }
    match fields {
        Fields::Named(fields) => {
            let mut i = 0;
            for field in fields.named.iter() {
                let field_attributes = attributes.extend(&field.attrs)?;
                // Named fields always carry an ident.
                let field_ident = field.ident.as_ref().unwrap();
                let (prefixed_field_ref, prefixed_field_dotted) = match field_prefix {
                    None => (quote! { #field_ident }, quote! { #field_ident .}),
                    Some(prefix) => (
                        quote! { (& #prefix #field_ident) },
                        quote! { #prefix #field_ident . },
                    ),
                };
                generate_for_item(
                    &field_ident,
                    &field.ty,
                    &prefixed_field_ref,
                    &prefixed_field_dotted,
                    // Only the first field of a keep_tag variant is the tag.
                    i == 0 && attributes.keep_tag,
                    &field_attributes,
                    &mut size_code,
                    &mut read_code,
                    &mut write_code,
                    &context_val,
                    &prefixed_context_val,
                )?;
                if field_attributes.is_context {
                    // Subsequent fields receive this field as their context.
                    context_val = quote! { (&#field_ident) };
                    prefixed_context_val = prefixed_field_ref;
                }
                field_list.extend(quote! {
                    #field_ident,
                });
                i += 1;
            }
            field_list = quote! { { #field_list } };
        }
        Fields::Unnamed(fields) => {
            for (i, field) in fields.unnamed.iter().enumerate() {
                let field_attributes = attributes.extend(&field.attrs)?;
                // Tuple fields get synthetic `field_<i>` bindings on read.
                let field_ident = Ident::new(&format!("field_{}", i), field.span());
                let (prefixed_field_ref, prefixed_field_dotted) = match field_prefix {
                    None => (quote! { #field_ident }, quote! { #field_ident .}),
                    Some(prefix) => {
                        // Through a prefix (`self.`), tuple fields are
                        // reached by numeric index, not by binding name.
                        let i = syn::Index::from(i);
                        (quote! { ( & #prefix #i ) }, quote! { #prefix #i . })
                    }
                };
                generate_for_item(
                    &field_ident,
                    &field.ty,
                    &prefixed_field_ref,
                    &prefixed_field_dotted,
                    i == 0 && attributes.keep_tag,
                    &field_attributes,
                    &mut size_code,
                    &mut read_code,
                    &mut write_code,
                    &context_val,
                    &prefixed_context_val,
                )?;
                if field_attributes.is_context {
                    context_val = quote! { (&#field_ident) };
                    prefixed_context_val = quote! { #prefixed_field_ref };
                }
                field_list.extend(quote! {
                    #field_ident,
                });
            }
            field_list = quote! { (#field_list) };
        }
        Fields::Unit => {
            // A unit variant has no field able to hold the tag value.
            if attributes.keep_tag {
                return syn_error(ident, "Cannot keep tag on unit variant");
            }
        }
    };
    // Close the `+`-terminated size chain: with keep_tag the tag bytes were
    // already counted as the first field, otherwise add the tag size itself
    // (or 0 when there is no tag at all, e.g. for plain structs).
    if attributes.keep_tag {
        size_code.extend(quote! { 0 });
    } else {
        match &attributes.tag_type {
            None => size_code.extend(quote! { 0 }),
            Some(ty) => {
                // NOTE(review): tag_type is validated as primitive in
                // enum_impl; if it can be set where that check does not run,
                // primitive_size would panic here — confirm.
                let ty_size = primitive_size(ty);
                size_code.extend(quote! { #ty_size });
            }
        }
    }
    Ok((size_code, read_code, write_code, field_list))
}
/// Generates size / read / write code for a single field or nested item.
///
/// On read, a local binding named `field_ident` is produced. On write and
/// size, the value is reached through `prefixed_field_ref` (an expression
/// evaluating to a reference) and `prefixed_field_dotted` (the same with a
/// trailing `.` for method calls). `is_tag` marks the first field of a
/// `keep_tag` enum variant, whose value is taken from the already-decoded
/// `discriminant` instead of being read from the stream.
fn generate_for_item(
    field_ident: &Ident,
    field_type: &Type,
    prefixed_field_ref: &TokenStream,
    prefixed_field_dotted: &TokenStream,
    is_tag: bool,
    attributes: &Attributes,
    size_code: &mut TokenStream,
    read_code: &mut TokenStream,
    write_code: &mut TokenStream,
    context_val: &TokenStream,
    prefixed_context_val: &TokenStream,
) -> Result<()> {
    if attributes.skip {
        // #[plod(skip)]: not on the wire; read fills the field with its
        // Default value, and nothing is emitted for size or write.
        read_code.extend(quote! {
            let #field_ident = <#field_type as std::default::Default>::default();
        });
        return Ok(());
    }
    match field_type {
        Type::Path(type_path) => {
            // Classify the type from its first path segment.
            // NOTE(review): any path whose first segment is literally `Vec`
            // is treated as std's Vec — confirm this is acceptable.
            let mut is_vec = false;
            let mut is_primitive = false;
            if let Some(id) = type_path.path.segments.first() {
                is_vec = id.ident == "Vec";
                is_primitive = primitive_type(&id.ident);
            };
            if is_vec {
                generate_for_vec(
                    type_path,
                    field_ident,
                    prefixed_field_dotted,
                    attributes,
                    size_code,
                    read_code,
                    write_code,
                    context_val,
                    prefixed_context_val,
                )?;
            } else if is_primitive {
                // Fixed-width numeric type: raw from/to_*_bytes conversion.
                let ty = type_path.path.get_ident().unwrap();
                let ty_size = primitive_size(ty);
                let (from_method, to_method) = primitive_function(attributes.endianness);
                size_code.extend(quote! {
                    #ty_size +
                });
                if is_tag {
                    // The tag bytes were consumed when the discriminant was
                    // read; recover the field value from it, undoing the
                    // optional keep_diff offset.
                    if let Some(diff) = &attributes.keep_diff {
                        read_code.extend(quote! {
                            let #field_ident = discriminant as #ty - #diff;
                        });
                    } else {
                        read_code.extend(quote! {
                            let #field_ident = discriminant as #ty;
                        });
                    }
                } else {
                    read_code.extend(quote! {
                        let mut buffer: [u8; #ty_size] = [0; #ty_size];
                        from.read_exact(&mut buffer)?;
                        let #field_ident = #ty::#from_method(buffer);
                    });
                }
                // On write, re-apply the keep_diff offset so the stored tag
                // matches what the read path expects.
                let diff = if is_tag && attributes.keep_diff.is_some() {
                    let diff = attributes.keep_diff.as_ref().unwrap();
                    quote! { + #diff }
                } else {
                    TokenStream::new()
                };
                write_code.extend(quote! {
                    let buffer: [u8; #ty_size] = (#prefixed_field_ref #diff). #to_method();
                    to.write_all(&buffer)?;
                });
            } else {
                // Any other named type: delegate to its own Plod impl; the
                // current context value is adapted with `.into()` to the
                // context type expected by the nested implementation.
                size_code.extend(quote! {
                    <#type_path as plod::Plod>::size_at_rest(#prefixed_field_ref) +
                });
                read_code.extend(quote! {
                    let #field_ident = <#type_path as plod::Plod>::read_from(from, #context_val.into())?;
                });
                write_code.extend(quote! {
                    <#type_path as plod::Plod>::write_to(#prefixed_field_ref, to, #prefixed_context_val.into())?;
                });
            }
        }
        Type::Tuple(t) => {
            // Recurse on each element with a synthetic `infield_<i>`
            // binding; size and write code are emitted by the recursion,
            // the read bindings are then packed into the tuple.
            let mut field_list = TokenStream::new();
            for (i, field_ty) in t.elems.iter().enumerate() {
                let field_ident = Ident::new(&format!("infield_{}", i), field_ty.span());
                let (prefixed_field_ref, prefixed_field_dotted) = {
                    let i = syn::Index::from(i);
                    (
                        quote! { ( & #prefixed_field_dotted #i ) },
                        quote! { #prefixed_field_dotted #i . },
                    )
                };
                generate_for_item(
                    &field_ident,
                    field_ty,
                    &prefixed_field_ref,
                    &prefixed_field_dotted,
                    false,
                    attributes,
                    size_code,
                    read_code,
                    write_code,
                    context_val,
                    prefixed_context_val,
                )?;
                field_list.extend(quote! {
                    #field_ident,
                });
            }
            read_code.extend(quote! {
                let #field_ident = (#field_list);
            });
        }
        Type::Array(t) => {
            let n = &t.len;
            let ty_ = &t.elem;
            // Generate per-item code once, binding each element as `item`.
            let mut item_size_code = TokenStream::new();
            let mut item_read_code = TokenStream::new();
            let mut item_write_code = TokenStream::new();
            let item_name = Ident::new("item", field_ident.span());
            generate_for_item(
                &item_name,
                ty_,
                &quote! { #item_name },
                &quote! { #item_name . },
                false,
                attributes,
                &mut item_size_code,
                &mut item_read_code,
                &mut item_write_code,
                context_val,
                prefixed_context_val,
            )?;
            // `item_size_code` ends with `+`, so the trailing `0` closes the
            // expression. The total is folded at runtime because elements
            // may have dynamic size (e.g. nested Vec).
            size_code.extend(quote! {
                #prefixed_field_dotted iter().fold(0, |n, item| n + #item_size_code 0) +
            });
            read_code.extend(quote! {
                let mut vec = Vec::new();
                for _ in 0..#n {
                    #item_read_code
                    vec.push(item);
                }
                // Exactly #n items were pushed, so the conversion to the
                // fixed-size array cannot fail.
                let #field_ident: #t = vec.try_into().unwrap();
            });
            write_code.extend(quote! {
                for item in #prefixed_field_dotted iter() {
                    #item_write_code
                }
            });
        }
        _ => {
            return syn_error(field_ident, "Unsupported type for Plod");
        }
    }
    Ok(())
}
/// Generates size / read / write code for a `Vec<T>` field.
///
/// The vector is stored as a length prefix of the primitive type declared
/// with `#[plod(size_type(<type>))]`, followed by the items. With
/// `size_is_next`, the stored count is the logical length plus one; with
/// `byte_sized`, the prefix counts bytes rather than items.
fn generate_for_vec(
    type_path: &TypePath,
    field_ident: &Ident,
    prefixed_field_dotted: &TokenStream,
    attributes: &Attributes,
    size_code: &mut TokenStream,
    read_code: &mut TokenStream,
    write_code: &mut TokenStream,
    context_val: &TokenStream,
    prefixed_context_val: &TokenStream,
) -> Result<()> {
    let size_ty = match &attributes.size_type {
        Some(ty) => ty,
        None => {
            return syn_error(
                type_path,
                "#[plod(size_type(<value>))] is mandatory for Vec<type>",
            );
        }
    };
    if !primitive_type(size_ty) {
        return syn_error(size_ty, "vec length magic only works with primitive types");
    }
    let ty_size = primitive_size(size_ty);
    let (from_method, to_method) = primitive_function(attributes.endianness);
    // Extract the single generic argument `T` of `Vec<T>`.
    let vec_generic = match &type_path.path.segments.first().unwrap().arguments {
        PathArguments::AngleBracketed(pa) => {
            if pa.args.len() != 1 {
                return syn_error(
                    type_path,
                    "Plod only support regular Vec<Type>: unknown type Vec<X,Y,...>",
                );
            }
            match pa.args.first().unwrap() {
                GenericArgument::Type(t) => t,
                _ => {
                    return syn_error(
                        type_path,
                        "Plod only support regular Vec<Type>: unknown Vec<...>",
                    )
                }
            }
        }
        _ => {
            return syn_error(
                type_path,
                "Plod only support regular Vec<Type>: unknown Vec...",
            );
        }
    };
    // Generate per-item code once, binding each element as `item`.
    let mut item_size_code = TokenStream::new();
    let mut item_read_code = TokenStream::new();
    let mut item_write_code = TokenStream::new();
    let item_name = Ident::new("item", field_ident.span());
    generate_for_item(
        &item_name,
        vec_generic,
        &quote! { #item_name },
        &quote! { #item_name . },
        false,
        attributes,
        &mut item_size_code,
        &mut item_read_code,
        &mut item_write_code,
        context_val,
        prefixed_context_val,
    )?;
    // Size = length prefix + sum of item sizes (`item_size_code` ends with
    // `+`; the literal 0 closes the chain).
    size_code.extend(quote! {
        #ty_size + #prefixed_field_dotted iter().fold(0, |n, item| n + #item_size_code 0) +
    });
    // size_is_next: the stored count is the logical length plus one.
    let (plus_one, minus_one) = if attributes.size_is_next {
        (quote! { + 1 }, quote! { - 1 })
    } else {
        (quote! {}, quote! {})
    };
    read_code.extend(quote! {
        let mut #field_ident = Vec::new();
        let mut buffer: [u8; #ty_size] = [0; #ty_size];
        from.read_exact(&mut buffer)?;
        let mut size = #size_ty::#from_method(buffer) as usize #minus_one;
    });
    if attributes.byte_sized {
        // The prefix counts bytes: keep reading items until the byte budget
        // is exhausted.
        // NOTE(review): if item sizes do not divide the announced byte count
        // exactly, `size -= item_size` wraps around on usize — confirm that
        // malformed input cannot trigger this.
        read_code.extend(quote! {
            while size > 0 {
                #item_read_code
                size -= #item_size_code 0;
                #field_ident.push(item);
            }
        });
        write_code.extend(quote! {
            let size = #prefixed_field_dotted iter().fold(0, |n, item| n + #item_size_code 0);
            let buffer: [u8; #ty_size] = (size as #size_ty #plus_one).#to_method();
            to.write_all(&buffer)?;
        });
    } else {
        // The prefix counts items.
        read_code.extend(quote! {
            for _ in 0..size {
                #item_read_code
                #field_ident.push(item);
            }
        });
        write_code.extend(quote! {
            let size = #prefixed_field_dotted len();
            let buffer: [u8; #ty_size] = (size as #size_ty #plus_one).#to_method();
            to.write_all(&buffer)?;
        });
    }
    // Items are written after the length prefix in both modes.
    write_code.extend(quote! {
        for item in #prefixed_field_dotted iter() {
            #item_write_code
        }
    });
    Ok(())
}