use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::Parser, parse2, punctuated::Punctuated, Fields, ItemStruct, LitInt, LitStr, Meta, Token,
};
/// Expands `#[hopper::args(...)]` on a named-field struct.
///
/// Re-emits the struct unchanged and adds an inherent impl containing:
/// - `PACKED_SIZE`: the sum of `size_of` over every field type,
/// - `CU_HINT`: the `cu = <integer>` argument (`0` when omitted),
/// - `ARG_DESCRIPTORS`: one `ArgDescriptor` per field (name, canonical type name, size),
/// - `parse`, `validate_tags` and `parse_checked`,
/// - `parse_with_tail`, only when the bare `tail` flag is given.
///
/// It also emits a compile-time layout guard. The guard requires
/// `size_of::<Self>() == PACKED_SIZE`, so the zero-copy reference produced by
/// `parse` never extends past the bytes it was length-checked against. It also
/// requires `PACKED_SIZE <= u16::MAX`, so the `u16` size fields are lossless.
///
/// # Errors
///
/// Returns a spanned `syn::Error` in two cases:
/// - `item` is not a struct with at least one named field;
/// - the attribute arguments are malformed, `cu` is not an integer literal, or an
///   unknown argument is present.
pub fn expand(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream> {
    let input: ItemStruct = parse2(item)?;
    let ArgsOptions {
        cu_hint,
        allow_tail,
    } = parse_args_options(attr)?;

    let name = input.ident.clone();
    let fields = match &input.fields {
        Fields::Named(n) => n.named.iter().collect::<Vec<_>>(),
        _ => {
            return Err(syn::Error::new_spanned(
                &name,
                "#[hopper::args] requires a named-field struct",
            ));
        }
    };
    if fields.is_empty() {
        return Err(syn::Error::new_spanned(
            &name,
            "#[hopper::args] requires at least one field",
        ));
    }

    // One `ArgDescriptor` literal per field, in declaration order.
    let descriptor_entries: Vec<TokenStream> = fields
        .iter()
        .map(|f| {
            let ident = f
                .ident
                .as_ref()
                .expect("fields of a `Fields::Named` struct always carry an ident");
            let fname = LitStr::new(&ident.to_string(), ident.span());
            let canonical = LitStr::new(
                &canonical_ty_name(&f.ty),
                f.ty.clone().into_token_stream_span(),
            );
            let ty = &f.ty;
            quote! {
                ::hopper::hopper_schema::ArgDescriptor {
                    name: #fname,
                    canonical_type: #canonical,
                    // Lossless: the layout guard bounds every field by PACKED_SIZE <= u16::MAX.
                    size: ::core::mem::size_of::<#ty>() as u16,
                }
            }
        })
        .collect();

    let cu_lit = LitInt::new(&format!("{}u32", cu_hint), name.span());
    let ty_list: Vec<_> = fields.iter().map(|f| &f.ty).collect();

    // `parse` only reinterprets bytes, so every `OptionByte` field gets an explicit
    // tag check inside `validate_tags`. That method takes `&self`, so the fields must
    // be reached through `self`.
    let tag_validators: Vec<TokenStream> = fields
        .iter()
        .filter(|f| is_option_byte_type(&f.ty))
        .filter_map(|f| f.ident.as_ref())
        .map(|ident| quote! { self.#ident.validate_tag()?; })
        .collect();

    let parse_with_tail_fn = if allow_tail {
        quote! {
            /// Parses the packed head like `parse` and also returns the bytes that
            /// follow the first `PACKED_SIZE` bytes.
            #[inline]
            pub fn parse_with_tail(data: &[u8])
                -> ::core::result::Result<
                    (&Self, &[u8]),
                    ::hopper::hopper_schema::ArgParseError,
                >
            {
                let head = Self::parse(data)?;
                // In bounds: `parse` succeeded, so data.len() >= PACKED_SIZE.
                let tail = &data[Self::PACKED_SIZE..];
                ::core::result::Result::Ok((head, tail))
            }
        }
    } else {
        TokenStream::new()
    };

    let gen = quote! {
        #input

        impl #name {
            /// Number of bytes occupied by the packed argument layout (sum of field sizes).
            pub const PACKED_SIZE: usize = 0 #( + ::core::mem::size_of::<#ty_list>() )*;
            /// Compute-unit hint supplied through the `cu` macro argument.
            pub const CU_HINT: u32 = #cu_lit;
            /// Per-field schema descriptors, in declaration order.
            pub const ARG_DESCRIPTORS: &'static [::hopper::hopper_schema::ArgDescriptor] = &[
                #( #descriptor_entries ),*
            ];

            /// Reinterprets the start of `data` as `Self` without copying.
            /// Does not validate `OptionByte` tags; `parse_checked` does.
            #[inline]
            pub fn parse(data: &[u8]) -> ::core::result::Result<&Self, ::hopper::hopper_schema::ArgParseError> {
                if data.len() < Self::PACKED_SIZE {
                    // Both casts are lossless: PACKED_SIZE <= u16::MAX is checked at
                    // compile time, and this branch runs only when data.len() < PACKED_SIZE.
                    return ::core::result::Result::Err(
                        ::hopper::hopper_schema::ArgParseError::TooShort {
                            required: Self::PACKED_SIZE as u16,
                            got: data.len() as u16,
                        }
                    );
                }
                // SAFETY: `data` holds at least PACKED_SIZE bytes, and the layout guard
                // emitted next to this impl enforces size_of::<Self>() == PACKED_SIZE,
                // so the reference never covers bytes outside `data`. Its lifetime is
                // tied to `data` by the signature.
                // NOTE(review): soundness also requires `data.as_ptr()` to meet
                // align_of::<Self>() and every bit pattern to be valid for each field
                // type. Neither is checked here — confirm field types are alignment-1,
                // plain-old-data wire types.
                let r = unsafe { &*(data.as_ptr() as *const Self) };
                ::core::result::Result::Ok(r)
            }

            /// Checks the tag byte of every `OptionByte` field.
            #[inline]
            pub fn validate_tags(&self)
                -> ::core::result::Result<(), ::hopper::__runtime::ProgramError>
            {
                #( #tag_validators )*
                ::core::result::Result::Ok(())
            }

            /// `parse` followed by `validate_tags`. A short input is reported as
            /// `InvalidInstructionData`.
            #[inline]
            pub fn parse_checked(data: &[u8])
                -> ::core::result::Result<&Self, ::hopper::__runtime::ProgramError>
            {
                let r = Self::parse(data).map_err(|_| {
                    ::hopper::__runtime::ProgramError::InvalidInstructionData
                })?;
                r.validate_tags()?;
                ::core::result::Result::Ok(r)
            }

            #parse_with_tail_fn
        }

        // Compile-time layout guard for the zero-copy cast in `parse`.
        const _: () = {
            ::core::assert!(
                ::core::mem::size_of::<#name>() == #name::PACKED_SIZE,
                "#[hopper::args] struct layout contains padding: size_of::<Self>() must equal the sum of its field sizes"
            );
            ::core::assert!(
                #name::PACKED_SIZE <= u16::MAX as usize,
                "#[hopper::args] struct is too large: PACKED_SIZE must fit in a u16"
            );
        };
    };
    Ok(gen)
}

/// Options recognised inside `#[hopper::args(...)]`.
struct ArgsOptions {
    /// Value given by `cu = <integer>`; `0` when the argument is absent.
    cu_hint: u32,
    /// Set by the bare `tail` flag; enables the generated `parse_with_tail`.
    allow_tail: bool,
}

/// Parses the comma-separated attribute arguments of `#[hopper::args(...)]`.
///
/// Accepts `cu = <integer literal>` and the bare flag `tail`. If `cu` is repeated,
/// the last value wins. Malformed token streams, a non-integer `cu`, and any other
/// argument are reported as errors spanned at the offending tokens rather than
/// being ignored.
fn parse_args_options(attr: TokenStream) -> syn::Result<ArgsOptions> {
    let metas = Punctuated::<Meta, Token![,]>::parse_terminated.parse2(attr)?;
    let mut options = ArgsOptions {
        cu_hint: 0,
        allow_tail: false,
    };
    for meta in &metas {
        match meta {
            Meta::NameValue(nv) if nv.path.is_ident("cu") => match &nv.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Int(li),
                    ..
                }) => {
                    options.cu_hint = li.base10_parse::<u32>()?;
                }
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "#[hopper::args] `cu` expects an integer literal, e.g. `cu = 1200`",
                    ));
                }
            },
            Meta::Path(p) if p.is_ident("tail") => {
                options.allow_tail = true;
            }
            other => {
                return Err(syn::Error::new_spanned(
                    other,
                    "unknown #[hopper::args] argument; expected `cu = <integer>` or `tail`",
                ));
            }
        }
    }
    Ok(options)
}
/// Returns `true` when `ty` names an `OptionByte` type.
///
/// Only the last path segment is compared, so `OptionByte<T>` and
/// `some::path::OptionByte<T>` both match.
///
/// Parenthesised types are unwrapped first. So are invisible-delimited groups,
/// which `macro_rules!` wraps around `$t:ty` fragments. This keeps
/// macro-generated structs from silently losing their tag validation.
fn is_option_byte_type(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(p) => {
            matches!(p.path.segments.last(), Some(seg) if seg.ident == "OptionByte")
        }
        syn::Type::Group(g) => is_option_byte_type(&g.elem),
        syn::Type::Paren(p) => is_option_byte_type(&p.elem),
        _ => false,
    }
}
/// Builds the type name recorded in an `ArgDescriptor`'s `canonical_type`.
///
/// - Path types use only the last segment's identifier, and generic arguments are
///   dropped: `OptionByte<u64>` becomes `"OptionByte"`.
/// - Arrays render as `"[<elem>;<len>]"`, with the length from `describe_array_len`.
/// - Parenthesised types and invisible-delimited groups (as produced by
///   `macro_rules!` for `$t:ty`) are unwrapped before naming.
/// - Every other type kind is `"unknown"`.
fn canonical_ty_name(ty: &syn::Type) -> String {
    match ty {
        syn::Type::Path(p) => p
            .path
            .segments
            .last()
            .map(|s| s.ident.to_string())
            .unwrap_or_else(|| "unknown".to_string()),
        syn::Type::Array(a) => {
            let inner = canonical_ty_name(&a.elem);
            format!("[{};{}]", inner, describe_array_len(&a.len))
        }
        syn::Type::Group(g) => canonical_ty_name(&g.elem),
        syn::Type::Paren(p) => canonical_ty_name(&p.elem),
        _ => "unknown".to_string(),
    }
}
/// Renders an array-length expression for a canonical type name.
///
/// Integer literals yield their base-10 digits, without underscores or suffix.
/// Parenthesised and invisible-delimited expressions (as produced by
/// `macro_rules!` for `$n:expr`) are unwrapped first. Anything else, such as a
/// named constant, renders as `"?"`.
fn describe_array_len(expr: &syn::Expr) -> String {
    match expr {
        syn::Expr::Lit(syn::ExprLit {
            lit: syn::Lit::Int(li),
            ..
        }) => li.base10_digits().to_string(),
        syn::Expr::Group(g) => describe_array_len(&g.expr),
        syn::Expr::Paren(p) => describe_array_len(&p.expr),
        _ => "?".to_string(),
    }
}
/// Picks a span to attach to a generated literal that describes a type.
///
/// Despite its name, the method yields a `Span`, not a token stream.
trait IntoTokenStreamSpan {
    fn into_token_stream_span(self) -> proc_macro2::Span;
}

impl IntoTokenStreamSpan for syn::Type {
    /// For a path type, returns the span of the last segment's identifier.
    /// For every other type kind, or a path with no segments, returns `Span::call_site()`.
    fn into_token_stream_span(self) -> proc_macro2::Span {
        if let syn::Type::Path(type_path) = &self {
            if let Some(tail_segment) = type_path.path.segments.last() {
                return tail_segment.ident.span();
            }
        }
        proc_macro2::Span::call_site()
    }
}
#[cfg(test)]
mod args_tests {
    use super::*;
    use quote::quote;

    /// Expands `item` under `attr` and renders the result as a string.
    /// Panics if expansion fails.
    fn expand_ok(attr: TokenStream, item: TokenStream) -> String {
        let tokens = expand(attr, item).expect("expand ok");
        tokens.to_string()
    }

    #[test]
    fn plain_args_emit_parse_checked_and_validate_tags() {
        let item = quote! {
            #[repr(C)]
            pub struct Simple {
                pub amount: u64,
            }
        };
        let rendered = expand_ok(TokenStream::new(), item);

        for needle in &["fn parse (", "fn parse_checked (", "fn validate_tags ("] {
            assert!(rendered.contains(*needle));
        }
        assert!(!rendered.contains("fn parse_with_tail ("));
    }

    #[test]
    fn tail_flag_emits_parse_with_tail() {
        let item = quote! {
            #[repr(C)]
            pub struct WithTail {
                pub amount: u64,
            }
        };
        let rendered = expand_ok(quote!(tail), item);

        assert!(rendered.contains("fn parse_with_tail ("));
    }

    #[test]
    fn cu_hint_is_recorded_on_the_impl() {
        let item = quote! {
            #[repr(C)]
            pub struct Costed {
                pub amount: u64,
            }
        };
        let rendered = expand_ok(quote!(cu = 1200), item);

        for needle in &["CU_HINT", "1200u32"] {
            assert!(rendered.contains(*needle));
        }
    }

    #[test]
    fn option_byte_field_emits_tag_validator() {
        let item = quote! {
            #[repr(C)]
            pub struct WithOpt {
                pub flag: OptionByte<u64>,
            }
        };
        let rendered = expand_ok(TokenStream::new(), item);

        assert!(rendered.contains("validate_tags"));
        let spaced = rendered.contains(". flag . validate_tag");
        let compact = rendered.contains(".flag.validate_tag");
        assert!(spaced || compact);
    }
}