#![warn(missing_docs)]
#[cfg(all(
test,
any(
target_arch = "x86_64",
target_arch = "x86",
target_arch = "aarch64",
target_arch = "loongarch64"
)
))]
mod tests;
use proc_macro::TokenStream;
use proc_macro_crate::{FoundCrate, crate_name};
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{
Attribute, Data, DeriveInput, Fields, Ident, Index, LitStr, Meta, Type, parse_macro_input,
};
/// Derives zeroization support for a struct.
///
/// The generated code implements `ZeroizeMetadata`, `FastZeroizable` and
/// `ZeroizationProbe` from whichever of `redoubt-zero-core`, `redoubt-zero` or
/// `redoubt::zero` is a dependency of the calling crate. Enums and unions are
/// rejected with a compile error.
///
/// Supported attributes:
/// - `#[fast_zeroize(skip)]` on a field excludes it from zeroization and from the
///   `is_zeroized` check. Fields of type `&T` must carry this attribute.
/// - `#[fast_zeroize(drop)]` on the struct additionally generates a `Drop` impl
///   that calls `fast_zeroize`.
///
/// If the struct contains a sentinel (a named field called `__sentinel`, or a
/// tuple field whose type is `ZeroizeOnDropSentinel`), an `AssertZeroizeOnDrop`
/// impl is generated too, and the sentinel is left out of `is_zeroized`.
#[proc_macro_derive(RedoubtZero, attributes(fast_zeroize))]
pub fn derive_redoubt_zero(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    // Both the generated impls and a `compile_error!` stream are plain tokens,
    // so either outcome is handed straight back to the compiler.
    match expand(input) {
        Ok(tokens) | Err(tokens) => tokens.into(),
    }
}
/// Resolves the path under which the runtime crate is reachable from the crate
/// invoking the derive.
///
/// Each entry of `candidates` is tried in order. An entry is either a bare
/// package name (`"redoubt-zero"`) or a package name followed by a module path
/// (`"redoubt::zero"`).
///
/// The first package that `proc_macro_crate::crate_name` can locate wins:
/// - If the invoking crate is that package itself, the result is `crate`.
/// - Otherwise it is the identifier reported by `crate_name`.
///
/// In both cases the module path is appended when the entry has one.
///
/// If no candidate resolves, the returned tokens are a `compile_error!`
/// invocation telling the user which dependency to add.
pub(crate) fn find_root_with_candidates(candidates: &[&'static str]) -> TokenStream2 {
    for &candidate in candidates {
        let (package, module_path) = match candidate.split_once("::") {
            Some((head, tail)) => (head, Some(tail)),
            None => (candidate, None),
        };

        let base = match crate_name(package) {
            Ok(FoundCrate::Itself) => quote!(crate),
            Ok(FoundCrate::Name(name)) => {
                let crate_ident = Ident::new(&name, Span::call_site());
                quote!(#crate_ident)
            }
            Err(_) => continue,
        };

        return match module_path {
            None => base,
            Some(tail) => {
                // NOTE(review): an unparsable module path silently becomes empty,
                // leaving a dangling `base::`; candidates are compile-time
                // constants, so this should never trigger in practice.
                let suffix: TokenStream2 = tail.parse().unwrap_or_else(|_| quote!());
                quote!(#base::#suffix)
            }
        };
    }

    let message = LitStr::new(
        "RedoubtZero: could not find redoubt-zero or redoubt-zero-core. Add redoubt-zero to Cargo.toml.",
        Span::call_site(),
    );
    quote! { compile_error!(#message); }
}
/// Returns `true` when `ty` is a path type whose final segment is named
/// `ZeroizeOnDropSentinel` (for example `ZeroizeOnDropSentinel` or
/// `some_crate::ZeroizeOnDropSentinel`).
///
/// Only the identifier of the last path segment is inspected. The crate prefix
/// and any generic arguments are ignored.
pub(crate) fn is_zeroize_on_drop_sentinel_type(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    type_path
        .path
        .segments
        .last()
        .is_some_and(|segment| segment.ident == "ZeroizeOnDropSentinel")
}
/// Returns `true` if `ty` is written as a mutable reference (`&mut T`).
pub(crate) fn is_mut_reference_type(ty: &Type) -> bool {
    matches!(ty, Type::Reference(reference) if reference.mutability.is_some())
}
/// Returns `true` if `ty` is written as a shared reference (`&T`, no `mut`).
pub(crate) fn is_immut_reference_type(ty: &Type) -> bool {
    matches!(ty, Type::Reference(reference) if reference.mutability.is_none())
}
fn has_fast_zeroize_skip(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| match &attr.meta {
Meta::List(meta_list) => {
meta_list.path.is_ident("fast_zeroize") && meta_list.tokens.to_string().contains("skip")
}
_ => false,
})
}
fn has_fast_zeroize_drop(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| match &attr.meta {
Meta::List(meta_list) => {
meta_list.path.is_ident("fast_zeroize") && meta_list.tokens.to_string().contains("drop")
}
_ => false,
})
}
/// Where the zeroize-on-drop sentinel field lives in the struct being derived.
struct SentinelState {
    /// Expression that reads the sentinel from `self`
    /// (`self.__sentinel` for named structs, `self.<index>` for tuple structs).
    access: TokenStream2,
    /// Position of the sentinel among the struct's fields, in declaration order.
    index: usize,
}
/// Generates the trait impls for `#[derive(RedoubtZero)]`.
///
/// Always emits `ZeroizeMetadata`, `FastZeroizable` and `ZeroizationProbe`
/// impls for the struct. It also emits:
/// - a `Drop` impl, when the struct carries `#[fast_zeroize(drop)]`;
/// - an `AssertZeroizeOnDrop` impl, when a sentinel field is found.
///
/// # Errors
///
/// Returns a `compile_error!` token stream when:
/// - `input` is not a struct; or
/// - a non-sentinel field of type `&T` lacks `#[fast_zeroize(skip)]`.
fn expand(input: DeriveInput) -> Result<TokenStream2, TokenStream2> {
    let struct_name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let root = find_root_with_candidates(&["redoubt-zero-core", "redoubt-zero", "redoubt::zero"]);

    let all_fields = collect_struct_fields(&input)?;
    let maybe_sentinel_state = find_sentinel(&all_fields);
    let sentinel_idx = maybe_sentinel_state.as_ref().map(|s| s.index);
    reject_unskipped_immut_refs(&all_fields, sentinel_idx)?;

    // `is_zeroized` probes every non-skipped field except the sentinel.
    let immut_refs_without_sentinel: Vec<TokenStream2> = all_fields
        .iter()
        .filter(|(i, f)| Some(*i) != sentinel_idx && !has_fast_zeroize_skip(&f.attrs))
        .map(|(i, f)| field_access(*i, f, false))
        .collect();
    // `fast_zeroize` wipes every non-skipped field, sentinel included.
    let mut_refs_with_sentinel: Vec<TokenStream2> = all_fields
        .iter()
        .filter(|(_, f)| !has_fast_zeroize_skip(&f.attrs))
        .map(|(i, f)| field_access(*i, f, true))
        .collect();

    // Unsuffixed integer literals used as the fixed array lengths below.
    let len_without_sentinel_lit = syn::LitInt::new(
        &immut_refs_without_sentinel.len().to_string(),
        Span::call_site(),
    );
    let len_with_sentinel_lit =
        syn::LitInt::new(&mut_refs_with_sentinel.len().to_string(), Span::call_site());

    let drop_impl = if has_fast_zeroize_drop(&input.attrs) {
        quote! {
            impl #impl_generics Drop for #struct_name #ty_generics #where_clause {
                fn drop(&mut self) {
                    #root::FastZeroizable::fast_zeroize(self);
                }
            }
        }
    } else {
        quote! {}
    };

    let output = quote! {
        impl #impl_generics #root::ZeroizeMetadata for #struct_name #ty_generics #where_clause {
            const CAN_BE_BULK_ZEROIZED: bool = false;
        }
        impl #impl_generics #root::FastZeroizable for #struct_name #ty_generics #where_clause {
            fn fast_zeroize(&mut self) {
                let fields: [&mut dyn #root::FastZeroizable; #len_with_sentinel_lit] = [
                    #( #root::collections::to_fast_zeroizable_dyn_mut(#mut_refs_with_sentinel) ),*
                ];
                #root::collections::zeroize_collection(&mut fields.into_iter())
            }
        }
        impl #impl_generics #root::ZeroizationProbe for #struct_name #ty_generics #where_clause {
            fn is_zeroized(&self) -> bool {
                let fields: [&dyn #root::ZeroizationProbe; #len_without_sentinel_lit] = [
                    #( #root::collections::to_zeroization_probe_dyn_ref(#immut_refs_without_sentinel) ),*
                ];
                #root::collections::collection_zeroed(&mut fields.into_iter())
            }
        }
        #drop_impl
    };

    let assert_impl = match maybe_sentinel_state {
        Some(sentinel_state) => {
            let sentinel_access = sentinel_state.access;
            quote! {
                impl #impl_generics #root::AssertZeroizeOnDrop for #struct_name #ty_generics #where_clause {
                    fn clone_sentinel(&self) -> #root::ZeroizeOnDropSentinel {
                        #sentinel_access.clone()
                    }
                    fn assert_zeroize_on_drop(self) {
                        #root::assert::assert_zeroize_on_drop(self);
                    }
                }
            }
        }
        None => quote! {},
    };

    Ok(quote! {
        #output
        #assert_impl
    })
}

/// Lists the struct's fields paired with their declaration index.
///
/// Named and tuple structs are both supported; unit structs yield an empty list.
/// Enums and unions are rejected with a compile error spanned on the type name.
fn collect_struct_fields(
    input: &DeriveInput,
) -> Result<Vec<(usize, &syn::Field)>, TokenStream2> {
    let Data::Struct(data) = &input.data else {
        return Err(syn::Error::new_spanned(
            &input.ident,
            "RedoubtZero can only be derived for structs (named or tuple).",
        )
        .to_compile_error());
    };
    Ok(match &data.fields {
        Fields::Named(named) => named.named.iter().enumerate().collect(),
        Fields::Unnamed(unnamed) => unnamed.unnamed.iter().enumerate().collect(),
        Fields::Unit => Vec::new(),
    })
}

/// Locates the first zeroize-on-drop sentinel field, if any.
///
/// - Named fields are recognised by the name `__sentinel`.
/// - Tuple fields are recognised by a type whose last segment is
///   `ZeroizeOnDropSentinel`.
fn find_sentinel(fields: &[(usize, &syn::Field)]) -> Option<SentinelState> {
    let sentinel_ident = format_ident!("__sentinel");
    fields.iter().find_map(|(i, f)| match &f.ident {
        Some(ident) if *ident == sentinel_ident => Some(SentinelState {
            index: *i,
            access: quote! { self.#sentinel_ident },
        }),
        Some(_) => None,
        None if is_zeroize_on_drop_sentinel_type(&f.ty) => {
            let idx = Index::from(*i);
            Some(SentinelState {
                index: *i,
                access: quote! { self.#idx },
            })
        }
        None => None,
    })
}

/// Fails on the first non-sentinel `&T` field that lacks `#[fast_zeroize(skip)]`.
///
/// A shared reference cannot be written through, so such a field can never be
/// zeroized and must be opted out explicitly.
fn reject_unskipped_immut_refs(
    fields: &[(usize, &syn::Field)],
    sentinel_idx: Option<usize>,
) -> Result<(), TokenStream2> {
    for (i, f) in fields {
        if Some(*i) == sentinel_idx {
            continue;
        }
        if is_immut_reference_type(&f.ty) && !has_fast_zeroize_skip(&f.attrs) {
            let field_name = match &f.ident {
                Some(ident) => format!("field `{}`", ident),
                None => format!("field at index {}", i),
            };
            return Err(syn::Error::new_spanned(
                &f.ty,
                format!(
                    "{} has type `&T` (immutable reference) which cannot be zeroized. \
                     Add `#[fast_zeroize(skip)]` to exclude it from zeroization.",
                    field_name
                ),
            )
            .to_compile_error());
        }
    }
    Ok(())
}

/// Builds the expression that passes field `index` of `self` to a zeroization
/// helper.
///
/// - Fields typed `&mut T` are passed as `self.field` unchanged, since they are
///   already references.
/// - Other fields are borrowed with `&mut` when `mutable` is set, and with `&`
///   otherwise.
fn field_access(index: usize, field: &syn::Field, mutable: bool) -> TokenStream2 {
    let member = match &field.ident {
        Some(ident) => syn::Member::Named(ident.clone()),
        None => syn::Member::Unnamed(Index::from(index)),
    };
    if is_mut_reference_type(&field.ty) {
        quote! { self.#member }
    } else if mutable {
        quote! { &mut self.#member }
    } else {
        quote! { &self.#member }
    }
}