use darling::{ast, FromDeriveInput, FromField, FromVariant};
use proc_macro2::TokenStream;
use quote::quote;
mod macros;
use crate::macros::{deku_read::emit_deku_read, deku_write::emit_deku_write};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
/// Validated, code-generation-ready description of the type a `DekuRead` /
/// `DekuWrite` derive is applied to.
///
/// Built from a [`DekuReceiver`] by `DekuData::from_receiver` and handed to
/// `emit_deku_read` / `emit_deku_write` to produce the expansion.
#[derive(Debug)]
struct DekuData {
/// Visibility of the deriving type.
vis: syn::Visibility,
/// Name of the deriving type.
ident: syn::Ident,
/// Generics declared on the deriving type.
generics: syn::Generics,
/// Converted struct fields, or converted enum variants.
data: ast::Data<VariantData, FieldData>,
/// Container-level `endian` attribute string, exactly as written.
endian: Option<syn::LitStr>,
/// Container-level `ctx` string, parsed as a comma-separated list of function arguments.
ctx: Option<syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>>,
/// Enum `id` attribute, parsed from its string literal into tokens.
id: Option<TokenStream>,
/// Enum `id_type` attribute.
id_type: Option<syn::Ident>,
/// Width of the enum id in bits: `id_bytes * 8` when `id_bytes` was given,
/// otherwise the `id_bits` attribute.
id_bits: Option<usize>,
}
impl DekuData {
fn from_receiver(receiver: DekuReceiver) -> Result<Self, TokenStream> {
DekuData::validate(&receiver)
.map_err(|(span, msg)| syn::Error::new(span, msg).to_compile_error())?;
let data = match receiver.data {
ast::Data::Struct(fields) => ast::Data::Struct(ast::Fields {
style: fields.style,
fields: fields
.fields
.into_iter()
.map(FieldData::from_receiver)
.collect::<Result<Vec<_>, _>>()?,
}),
ast::Data::Enum(variants) => ast::Data::Enum(
variants
.into_iter()
.map(VariantData::from_receiver)
.collect::<Result<Vec<_>, _>>()?,
),
};
let ctx = receiver
.ctx
.map(|s| s.parse_with(syn::punctuated::Punctuated::parse_terminated))
.transpose()
.map_err(|e| e.to_compile_error())?;
let id_bits = receiver.id_bytes.map(|b| b * 8).or(receiver.id_bits);
Ok(Self {
vis: receiver.vis,
ident: receiver.ident,
generics: receiver.generics,
data,
endian: receiver.endian,
ctx,
id: receiver.id,
id_type: receiver.id_type,
id_bits,
})
}
fn validate(receiver: &DekuReceiver) -> Result<(), (proc_macro2::Span, &str)> {
match receiver.data {
ast::Data::Struct(_) => {
if receiver.id_type.is_some() {
Err((receiver.id_type.span(), "`id_type` only supported on enum"))
} else if receiver.id.is_some() {
Err((receiver.id.span(), "`id` only supported on enum"))
} else if receiver.id_bytes.is_some() {
Err((
receiver.id_bytes.span(),
"`id_bytes` only supported on enum",
))
} else if receiver.id_bits.is_some() {
Err((receiver.id_bits.span(), "`id_bits` only supported on enum"))
} else {
Ok(())
}
}
ast::Data::Enum(_) => {
if receiver.id_type.is_none() && receiver.id.is_none() {
return Err((
receiver.ident.span(),
"`id_type` or `id` must be specified on enum",
));
}
if receiver.id_type.is_some() && receiver.id.is_some() {
return Err((
receiver.ident.span(),
"conflicting: both `id_type` and `id` specified on enum",
));
}
if receiver.id_bits.is_some() && receiver.id_bytes.is_some() {
return Err((
receiver.id_bits.span(),
"conflicting: both `id_bits` and `id_bytes` specified on enum",
));
}
Ok(())
}
}
}
fn emit_reader(&self) -> TokenStream {
match self.emit_reader_checked() {
Ok(tks) => tks,
Err(e) => e.to_compile_error(),
}
}
fn emit_writer(&self) -> TokenStream {
match self.emit_writer_checked() {
Ok(tks) => tks,
Err(e) => e.to_compile_error(),
}
}
fn emit_reader_checked(&self) -> Result<TokenStream, syn::Error> {
emit_deku_read(self)
}
fn emit_writer_checked(&self) -> Result<TokenStream, syn::Error> {
emit_deku_write(self)
}
}
/// Validated per-field configuration, built from a [`DekuFieldReceiver`]
/// by `FieldData::from_receiver`.
#[derive(Debug)]
struct FieldData {
/// Field name; `None` for unnamed (tuple) fields, which are identified by index instead.
ident: Option<syn::Ident>,
/// Declared type of the field.
ty: syn::Type,
/// Field-level `endian` attribute string, exactly as written.
endian: Option<syn::LitStr>,
/// Size in bits: `bytes * 8` when `bytes` was given, otherwise the `bits` attribute.
bits: Option<usize>,
/// `count` attribute, parsed from its string into tokens.
count: Option<TokenStream>,
/// `map` attribute, parsed from its string into tokens.
map: Option<TokenStream>,
/// `ctx` attribute string, parsed as a comma-separated list of expressions.
ctx: Option<Punctuated<syn::Expr, syn::token::Comma>>,
/// `update` attribute, parsed from its string into tokens.
update: Option<TokenStream>,
/// `reader` attribute, parsed from its string into tokens.
reader: Option<TokenStream>,
/// `writer` attribute, parsed from its string into tokens.
writer: Option<TokenStream>,
/// Set by the `skip` flag attribute.
skip: bool,
/// `default` attribute tokens, or `Default::default()` when the attribute is absent.
/// Validation only permits an explicit `default` together with `skip` or `cond`.
default: TokenStream,
/// `cond` attribute, parsed from its string into tokens.
cond: Option<TokenStream>,
}
impl FieldData {
    /// Converts a raw darling field receiver into validated [`FieldData`].
    ///
    /// Runs [`FieldData::validate`] first. Any validation error, or a failure to
    /// parse the `ctx` string as comma-separated expressions, is returned as a
    /// ready-to-emit `compile_error!` token stream.
    ///
    /// `bytes`, when present, is converted to bits (`bytes * 8`) and takes
    /// precedence over `bits`. A missing `default` falls back to `Default::default()`.
    fn from_receiver(receiver: DekuFieldReceiver) -> Result<Self, TokenStream> {
        FieldData::validate(&receiver)
            .map_err(|(span, msg)| syn::Error::new(span, msg).to_compile_error())?;

        let bits = receiver.bytes.map(|b| b * 8).or(receiver.bits);
        // Only build the fallback tokens when no explicit `default` was given.
        let default = receiver
            .default
            .unwrap_or_else(|| quote! { Default::default() });
        let ctx = receiver
            .ctx
            .map(|s| s.parse_with(Punctuated::parse_terminated))
            .transpose()
            .map_err(|e| e.to_compile_error())?;

        Ok(Self {
            ident: receiver.ident,
            ty: receiver.ty,
            endian: receiver.endian,
            bits,
            count: receiver.count,
            map: receiver.map,
            ctx,
            update: receiver.update,
            reader: receiver.reader,
            writer: receiver.writer,
            skip: receiver.skip,
            default,
            cond: receiver.cond,
        })
    }

    /// Checks the field-level attribute combinations.
    ///
    /// Rejects `bits` together with `bytes`. Also rejects `default` unless the
    /// field is `skip`ped or has a `cond`. On failure returns the span to report
    /// plus the error message.
    fn validate(receiver: &DekuFieldReceiver) -> Result<(), (proc_macro2::Span, &str)> {
        if receiver.bits.is_some() && receiver.bytes.is_some() {
            // `bits`/`bytes` are plain integers with no source span of their own
            // (their tokens only carry the call-site span), so point the diagnostic
            // at the offending field's type instead.
            return Err((
                receiver.ty.span(),
                "conflicting: both `bits` and `bytes` specified on field",
            ));
        }
        if receiver.default.is_some() && (!receiver.skip && receiver.cond.is_none()) {
            return Err((
                receiver.default.span(),
                "`default` attribute cannot be used here",
            ));
        }
        Ok(())
    }

    /// Returns the tokens naming this field in generated code.
    ///
    /// Named fields use their identifier. Unnamed fields use `index`, optionally
    /// prefixed with `field_` when `prefix` is true.
    fn get_ident(&self, index: usize, prefix: bool) -> TokenStream {
        let field_ident = gen_field_ident(self.ident.as_ref(), index, prefix);
        quote! { #field_ident }
    }
}
/// Validated per-variant configuration for enums, built from a
/// [`DekuVariantReceiver`] by `VariantData::from_receiver`.
#[derive(Debug)]
struct VariantData {
/// Variant name.
ident: syn::Ident,
/// Converted fields of the variant, keeping the original field style.
fields: ast::Fields<FieldData>,
/// `reader` attribute, parsed from its string into tokens.
reader: Option<TokenStream>,
/// `writer` attribute, parsed from its string into tokens.
writer: Option<TokenStream>,
/// `id` attribute, parsed from its string into tokens; mutually exclusive with `id_pat`.
id: Option<TokenStream>,
/// `id_pat` attribute, parsed from its string into tokens; mutually exclusive with `id`.
id_pat: Option<TokenStream>,
}
impl VariantData {
fn from_receiver(receiver: DekuVariantReceiver) -> Result<Self, TokenStream> {
VariantData::validate(&receiver)
.map_err(|(span, msg)| syn::Error::new(span, msg).to_compile_error())?;
let fields = ast::Fields {
style: receiver.fields.style,
fields: receiver
.fields
.fields
.into_iter()
.map(FieldData::from_receiver)
.collect::<Result<Vec<_>, _>>()?,
};
Ok(Self {
ident: receiver.ident,
fields,
reader: receiver.reader,
writer: receiver.writer,
id: receiver.id,
id_pat: receiver.id_pat,
})
}
fn validate(receiver: &DekuVariantReceiver) -> Result<(), (proc_macro2::Span, &str)> {
if receiver.id.is_some() && receiver.id_pat.is_some() {
return Err((
receiver.id.span(),
"conflicting: both `id` and `id_pat` specified on variant",
));
}
Ok(())
}
}
/// Raw container-level input collected by darling from the derive input and
/// its `#[deku(...)]` attributes. Accepts structs and enums of any shape.
/// Combination rules are enforced later by `DekuData::validate`.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(deku), supports(struct_any, enum_any))]
struct DekuReceiver {
vis: syn::Visibility,
ident: syn::Ident,
generics: syn::Generics,
data: ast::Data<DekuVariantReceiver, DekuFieldReceiver>,
/// `#[deku(endian = "...")]`, kept as the raw string literal.
#[darling(default)]
endian: Option<syn::LitStr>,
/// `#[deku(ctx = "...")]`, kept as a string; parsed into function arguments later.
#[darling(default)]
ctx: Option<syn::LitStr>,
/// `#[deku(id = "...")]`, parsed from the string into tokens (enum only).
#[darling(default, map = "option_as_tokenstream")]
id: Option<TokenStream>,
/// `#[deku(id_type = ...)]` (enum only).
#[darling(default)]
id_type: Option<syn::Ident>,
/// `#[deku(id_bits = N)]` (enum only; conflicts with `id_bytes`).
#[darling(default)]
id_bits: Option<usize>,
/// `#[deku(id_bytes = N)]` (enum only; conflicts with `id_bits`).
#[darling(default)]
id_bytes: Option<usize>,
}
/// Parses the contents of an optional string-literal attribute into tokens.
///
/// `None` stays `None`. A string that is not valid Rust tokens panics with
/// "could not parse token stream". Used as a darling `map` function, which
/// cannot return an error.
fn option_as_tokenstream(input: Option<syn::LitStr>) -> Option<TokenStream> {
    let literal = input?;
    let tokens = literal
        .parse::<TokenStream>()
        .expect("could not parse token stream");
    Some(tokens)
}
/// Builds the tokens used to refer to a field in generated code.
///
/// Named fields render as their own name. Unnamed fields render as their
/// positional `index` (e.g. `0`), or as `field_<index>` when `prefix` is true.
fn gen_field_ident<T: ToString>(ident: Option<T>, index: usize, prefix: bool) -> TokenStream {
    let rendered = ident.map(|name| name.to_string()).unwrap_or_else(|| {
        let position = syn::Index::from(index);
        let lead = if prefix { "field_" } else { "" };
        format!("{}{}", lead, quote! { #position })
    });
    rendered.parse().unwrap()
}
/// Raw per-field input collected by darling from `#[deku(...)]` field attributes.
/// Attributes outside the `deku` namespace are not read.
/// Combination rules are enforced later by `FieldData::validate`.
#[derive(Debug, FromField)]
#[darling(attributes(deku))]
struct DekuFieldReceiver {
ident: Option<syn::Ident>,
ty: syn::Type,
/// `endian = "..."`, kept as the raw string literal.
#[darling(default)]
endian: Option<syn::LitStr>,
/// `bits = N` (conflicts with `bytes`).
#[darling(default)]
bits: Option<usize>,
/// `bytes = N` (conflicts with `bits`).
#[darling(default)]
bytes: Option<usize>,
/// `count = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
count: Option<TokenStream>,
/// `map = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
map: Option<TokenStream>,
/// `ctx = "..."`, kept as a string; parsed into expressions later.
#[darling(default)]
ctx: Option<syn::LitStr>,
/// `update = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
update: Option<TokenStream>,
/// `reader = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
reader: Option<TokenStream>,
/// `writer = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
writer: Option<TokenStream>,
/// `skip` flag.
#[darling(default)]
skip: bool,
/// `default = "..."`, parsed from the string into tokens; only valid with `skip` or `cond`.
#[darling(default, map = "option_as_tokenstream")]
default: Option<TokenStream>,
/// `cond = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
cond: Option<TokenStream>,
}
/// Raw per-variant input collected by darling from `#[deku(...)]` variant attributes.
/// Combination rules are enforced later by `VariantData::validate`.
#[derive(Debug, FromVariant)]
#[darling(attributes(deku))]
struct DekuVariantReceiver {
ident: syn::Ident,
fields: ast::Fields<DekuFieldReceiver>,
/// `reader = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
reader: Option<TokenStream>,
/// `writer = "..."`, parsed from the string into tokens.
#[darling(default, map = "option_as_tokenstream")]
writer: Option<TokenStream>,
/// `id = "..."`, parsed from the string into tokens (conflicts with `id_pat`).
#[darling(default, map = "option_as_tokenstream")]
id: Option<TokenStream>,
/// `id_pat = "..."`, parsed from the string into tokens (conflicts with `id`).
#[darling(default, map = "option_as_tokenstream")]
id_pat: Option<TokenStream>,
}
/// Entry point for `#[derive(DekuRead)]`.
///
/// Parses the item, collects its `deku` attributes, validates them and emits
/// the reader implementation. Any failure along the way is returned as
/// `compile_error!` tokens instead of the implementation.
#[proc_macro_derive(DekuRead, attributes(deku))]
pub fn proc_deku_read(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let derive_input: syn::DeriveInput = match syn::parse(input) {
        Ok(parsed) => parsed,
        Err(parse_err) => return parse_err.to_compile_error().into(),
    };

    let expansion = DekuReceiver::from_derive_input(&derive_input)
        .map_err(|darling_err| darling_err.write_errors())
        .and_then(DekuData::from_receiver)
        .map(|deku_data| deku_data.emit_reader());

    match expansion {
        Ok(tokens) | Err(tokens) => tokens.into(),
    }
}
/// Entry point for `#[derive(DekuWrite)]`.
///
/// Parses the item, collects its `deku` attributes, validates them and emits
/// the writer implementation. Any failure along the way is returned as
/// `compile_error!` tokens instead of the implementation.
#[proc_macro_derive(DekuWrite, attributes(deku))]
pub fn proc_deku_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let derive_input: syn::DeriveInput = match syn::parse(input) {
        Ok(parsed) => parsed,
        Err(parse_err) => return parse_err.to_compile_error().into(),
    };

    let expansion = DekuReceiver::from_derive_input(&derive_input)
        .map_err(|darling_err| darling_err.write_errors())
        .and_then(DekuData::from_receiver)
        .map(|deku_data| deku_data.emit_writer());

    match expansion {
        Ok(tokens) | Err(tokens) => tokens.into(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use syn::parse_str;

    /// Every case must parse, convert into `DekuData`, and expand for both
    /// `DekuRead` and `DekuWrite` without error.
    ///
    /// Attribute values bound to `syn::LitStr` receiver fields (such as `endian`)
    /// must be written as string literals, or darling rejects the attribute.
    #[rstest(input,
        case::struct_empty(r#"struct Test {}"#),
        case::struct_unnamed(r#"struct Test(u8, u8);"#),
        case::struct_unnamed_attrs(r#"struct Test(#[deku(bits=4)] u8, u8);"#),
        case::struct_all_attrs(r#"
struct Test {
#[deku(bits = 4)]
field_a: u8,
#[deku(bytes = 4)]
field_b: u64,
#[deku(endian = "little")]
field_c: u32,
#[deku(endian = "big")]
field_d: u32,
#[deku(skip, default = "5")]
field_e: u32,
}"#),
        case::enum_empty(r#"#[deku(id_type = "u8")] enum Test {}"#),
        case::enum_all(r#"
#[deku(id_type = "u8")]
enum Test {
#[deku(id = "1")]
A,
#[deku(id = "2")]
B(#[deku(bits = 4)] u8),
#[deku(id = "3")]
C { field_n: u8 },
}"#),
        // Named `invalid_*`, but currently expected to expand successfully:
        // `#[endian=big]` is outside the `deku` attribute namespace and is ignored.
        case::invalid_storage(r#"struct Test(#[deku(bits=9)] u8);"#),
        case::invalid_endian(r#"struct Test(#[endian=big] u8);"#),
    )]
    fn test_macro(input: &str) {
        let parsed = parse_str(input).unwrap();
        let receiver = DekuReceiver::from_derive_input(&parsed).unwrap();
        let data = DekuData::from_receiver(receiver).unwrap();
        let res_reader = data.emit_reader_checked();
        let res_writer = data.emit_writer_checked();
        res_reader.unwrap();
        res_writer.unwrap();
    }
}