use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{ExprStruct, Result};
/// Struct-expression rewriters, keyed by the struct's (single-identifier) type name.
///
/// When `make_expanded` sees a struct literal whose path is one of these names, it
/// hands the parsed expression to the paired function instead of re-emitting it as-is.
/// (`&str` in a `static` type is implicitly `'static`.)
static PRINTERS: [(&str, fn(&mut ExprStruct) -> Result<TokenStream2>); 1] =
    [("OVERLAPPED_ENTRY", construct_overlapped_entry)];
/// Element type names that have been audited as valid when all-bytes-zero.
///
/// `make_expand_zeroed` compares only the single-identifier type name against this
/// list, so a different type that happens to share one of these names would also be
/// accepted. (`&str` in a `static` type is implicitly `'static`.)
static ZEROABLE: [&str; 2] = ["OVERLAPPED", "OVERLAPPED_ENTRY"];
/// Looks up the rewriter registered in `PRINTERS` for the struct type `ident`.
///
/// Returns `None` when no entry's name equals the identifier.
fn find_printer(
    ident: &syn::Ident,
) -> Option<&'static fn(&mut ExprStruct) -> Result<TokenStream2>> {
    PRINTERS
        .iter()
        .find(|(type_name, _)| ident == type_name)
        .map(|(_, rewriter)| rewriter)
}
fn construct_overlapped_entry(fields: &mut ExprStruct) -> Result<TokenStream2> {
let internal = syn::parse2::<syn::FieldValue>(quote! {Internal:0})?;
fields.fields.push(internal);
Ok(fields.into_token_stream())
}
/// Parses `toks` as a struct-literal expression and expands it.
///
/// If the literal's path is a single identifier with a rewriter registered in
/// `PRINTERS`, that rewriter produces the output. Otherwise the expression is
/// re-emitted unchanged.
///
/// # Errors
///
/// Propagates parse errors and any error returned by the selected rewriter.
pub(crate) fn make_expanded(toks: TokenStream) -> Result<TokenStream2> {
    let mut expr = syn::parse::<ExprStruct>(toks)?;
    if let Some(rewriter) = expr.path.get_ident().and_then(find_printer) {
        return rewriter(&mut expr);
    }
    Ok(expr.into_token_stream())
}
pub(crate) fn make_expand_zeroed(toks: TokenStream) -> Result<TokenStream2> {
let type_array = syn::parse::<syn::TypeArray>(toks)?;
let len = match type_array.len {
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Int(lit),
..
}) => Ok(lit),
_ => Err(syn::Error::new_spanned(
type_array.len.into_token_stream(),
"Array length expression must be const integer",
)),
}?;
let ty = match *type_array.elem {
syn::Type::Path(syn::TypePath { ref path, .. }) => match path.get_ident() {
Some(ident) if ZEROABLE.contains(&ident.to_string().as_str()) => Some(&type_array.elem),
_ => None,
},
_ => None,
}
.ok_or_else(|| {
syn::Error::new_spanned(
type_array.elem.as_ref().into_token_stream(),
"This type has not been audited as safe to be zeroed",
)
})?;
Ok(quote! {unsafe { std::mem::zeroed::<[#ty;#len]>()}})
}