use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse::Parser, parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Expr, Fields, Lit,
Meta, Token,
};
#[proc_macro_derive(SecureSerialize, attributes(redact, secure_serialize))]
pub fn derive_secure_serialize(input: TokenStream) -> TokenStream {
let DeriveInput {
ident,
data,
generics,
attrs,
..
} = parse_macro_input!(input);
let (gen_debug, gen_display) = match extract_secure_serialize_options(&attrs) {
Ok(v) => v,
Err(e) => return e.to_compile_error().into(),
};
let fields = match data {
Data::Struct(s) => match s.fields {
Fields::Named(f) => f.named,
_ => {
return syn::Error::new_spanned(
&ident,
"SecureSerialize only supports structs with named fields",
)
.to_compile_error()
.into();
}
},
_ => {
return syn::Error::new_spanned(&ident, "SecureSerialize only supports structs")
.to_compile_error()
.into();
}
};
let mut redacted_fields: Vec<(syn::Ident, String, String)> = Vec::new(); let mut redacted_custom_fields: Vec<(
syn::Ident,
String,
String,
proc_macro2::TokenStream,
syn::Type,
)> = Vec::new(); let mut custom_serialize_fields: Vec<(syn::Ident, proc_macro2::TokenStream, syn::Type)> =
Vec::new(); let mut normal_field_names: Vec<syn::Ident> = Vec::new();
for field in &fields {
let name = field.ident.as_ref().expect("named field");
let name_str = name.to_string();
let field_type = field.ty.clone();
let (is_redacted, redaction_string) = extract_redact_attribute(&field.attrs);
let custom_serialize_path = extract_serialize_with_attribute(&field.attrs);
match (is_redacted, &custom_serialize_path) {
(true, Some(path)) => redacted_custom_fields.push((
name.clone(),
name_str,
redaction_string,
path.clone(),
field_type,
)),
(true, None) => redacted_fields.push((name.clone(), name_str, redaction_string)),
(false, Some(path)) => {
custom_serialize_fields.push((name.clone(), path.clone(), field_type))
}
(false, None) => {
normal_field_names.push(name.clone());
}
}
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let redacted_field_names: Vec<String> =
redacted_fields.iter().map(|(_, n, _)| n.clone()).collect();
let redacted_field_idents: Vec<syn::Ident> =
redacted_fields.iter().map(|(i, _, _)| i.clone()).collect();
let redaction_strings: Vec<proc_macro2::TokenStream> = redacted_fields
.iter()
.map(|(_, _, r)| {
r.parse::<proc_macro2::TokenStream>()
.unwrap_or_else(|_| quote! { "<redacted>" })
})
.collect();
let redacted_custom_field_names: Vec<String> = redacted_custom_fields
.iter()
.map(|(_, n, _, _, _)| n.clone())
.collect();
let redacted_custom_strings: Vec<proc_macro2::TokenStream> = redacted_custom_fields
.iter()
.map(|(_, _, r, _, _)| {
r.parse::<proc_macro2::TokenStream>()
.unwrap_or_else(|_| quote! { "<redacted>" })
})
.collect();
let custom_serialize_idents: Vec<syn::Ident> = custom_serialize_fields
.iter()
.map(|(i, _, _)| i.clone())
.collect();
let generate_wrapper = |field_ident: &syn::Ident,
path: &proc_macro2::TokenStream,
field_type: &syn::Type| {
quote! {
{
struct _Wrapper<'a>(&'a #field_type);
impl<'a> ::serde::Serialize for _Wrapper<'a>
{
fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
where
S: ::serde::Serializer,
{
#path(self.0, serializer)
}
}
_Wrapper(&self.#field_ident)
}
}
};
let custom_field_wrappers: Vec<proc_macro2::TokenStream> = custom_serialize_fields
.iter()
.map(|(ident, path, ty)| generate_wrapper(ident, path, ty))
.collect();
let redacted_custom_json_wrappers: Vec<proc_macro2::TokenStream> = redacted_custom_fields
.iter()
.map(|(ident, _, _, path, ty)| {
let wrapper = generate_wrapper(ident, path, ty);
quote! { ::serde_json::to_value(#wrapper)? }
})
.collect();
let custom_json_wrappers: Vec<proc_macro2::TokenStream> = custom_serialize_fields
.iter()
.map(|(ident, path, ty)| {
let wrapper = generate_wrapper(ident, path, ty);
quote! { ::serde_json::to_value(#wrapper)? }
})
.collect();
let debug_field_fragments: Vec<proc_macro2::TokenStream> = fields
.iter()
.map(|field| {
let name = field.ident.as_ref().expect("named field");
let name_literal = name.to_string();
let (is_redacted, redaction_string) = extract_redact_attribute(&field.attrs);
if is_redacted {
let redact_ts = redaction_string
.parse::<proc_macro2::TokenStream>()
.unwrap_or_else(|_| quote! { "<redacted>" });
quote! {
.field(#name_literal, &#redact_ts)
}
} else {
quote! {
.field(#name_literal, &self.#name)
}
}
})
.collect();
let debug_impl = if gen_debug {
quote! {
impl #impl_generics ::std::fmt::Debug for #ident #ty_generics #where_clause {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
f.debug_struct(stringify!(#ident))
#(#debug_field_fragments)*
.finish()
}
}
}
} else {
quote! {}
};
let display_impl = if gen_display {
quote! {
impl #impl_generics ::std::fmt::Display for #ident #ty_generics #where_clause {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
match ::serde_json::to_string(self) {
Ok(ref json) => f.write_str(json),
Err(e) => ::std::write!(
f,
concat!(stringify!(#ident), "(serialization error: {})"),
e
),
}
}
}
}
} else {
quote! {}
};
let expanded = quote! {
impl #impl_generics ::serde::Serialize for #ident #ty_generics #where_clause {
fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
where
S: ::serde::Serializer,
{
use ::serde::ser::SerializeStruct;
let mut s = serializer.serialize_struct(
stringify!(#ident),
0usize
#(+ { let _ = stringify!(#redacted_field_names); 1usize })*
#(+ { let _ = stringify!(#redacted_custom_field_names); 1usize })*
#(+ { let _ = stringify!(#custom_serialize_idents); 1usize })*
#(+ { let _ = stringify!(#normal_field_names); 1usize })*
)?;
#(s.serialize_field(#redacted_field_names, #redaction_strings)?;)*
#(s.serialize_field(#redacted_custom_field_names, #redacted_custom_strings)?;)*
#(s.serialize_field(stringify!(#custom_serialize_idents), &#custom_field_wrappers)?;)*
#(s.serialize_field(stringify!(#normal_field_names), &self.#normal_field_names)?;)*
s.end()
}
}
impl #impl_generics ::secure_serialize::SecureSerialize for #ident #ty_generics #where_clause {
fn redacted_keys() -> &'static [&'static str] {
&[#(#redacted_field_names,)* #(#redacted_custom_field_names,)*]
}
fn to_json_unredacted(&self) -> ::std::result::Result<::serde_json::Value, ::serde_json::Error> {
use ::serde_json::Value as JsonValue;
let mut result = ::serde_json::Map::new();
#(result.insert(#redacted_field_names.to_string(), ::serde_json::to_value(&self.#redacted_field_idents)?);)*
#(result.insert(#redacted_custom_field_names.to_string(), #redacted_custom_json_wrappers);)*
#(result.insert(stringify!(#custom_serialize_idents).to_string(), #custom_json_wrappers);)*
#(result.insert(stringify!(#normal_field_names).to_string(), ::serde_json::to_value(&self.#normal_field_names)?);)*
Ok(JsonValue::Object(result))
}
}
#debug_impl
#display_impl
};
let tokens = expanded.into();
tokens
}
fn extract_secure_serialize_options(attrs: &[syn::Attribute]) -> Result<(bool, bool), syn::Error> {
let mut gen_debug = false;
let mut gen_display = false;
for attr in attrs {
if !attr.path().is_ident("secure_serialize") {
continue;
}
match &attr.meta {
Meta::Path(_) => {
return Err(syn::Error::new_spanned(
attr,
"expected #[secure_serialize(debug)], #[secure_serialize(display)], or both",
));
}
Meta::List(list) => {
if list.tokens.is_empty() {
return Err(syn::Error::new_spanned(
list,
"expected `debug` and/or `display` inside #[secure_serialize(...)]",
));
}
let metas = Punctuated::<Meta, Token![,]>::parse_terminated
.parse2(list.tokens.clone())?;
for meta in metas {
match meta {
Meta::Path(p) => {
if p.is_ident("debug") {
gen_debug = true;
} else if p.is_ident("display") {
gen_display = true;
} else {
return Err(syn::Error::new_spanned(
p,
"expected `debug` or `display`",
));
}
}
other => {
return Err(syn::Error::new_spanned(
other,
"expected `debug` or `display`",
));
}
}
}
}
Meta::NameValue(_) => {
return Err(syn::Error::new_spanned(
attr,
"invalid #[secure_serialize(...)] syntax",
));
}
}
}
Ok((gen_debug, gen_display))
}
fn extract_redact_attribute(attrs: &[syn::Attribute]) -> (bool, String) {
for attr in attrs {
if !attr.path().is_ident("redact") {
continue;
}
match &attr.meta {
syn::Meta::Path(_) => {
return (true, "\"<redacted>\"".to_string());
}
syn::Meta::List(list) => {
if let Ok(Meta::NameValue(nv)) = list.parse_args::<Meta>().and_then(|m| match m {
Meta::NameValue(nv) if nv.path.is_ident("with") => Ok(Meta::NameValue(nv)),
_ => Err(syn::Error::new_spanned(
&list,
"redact attribute expects: #[redact(with = \"string\")]",
)),
}) {
if let syn::Expr::Lit(expr_lit) = &nv.value {
if let syn::Lit::Str(lit_str) = &expr_lit.lit {
return (true, format!("\"{}\"", lit_str.value()));
}
}
}
}
_ => {}
}
}
(false, String::new())
}
/// Returns the function path named by the first `#[serde(serialize_with = "path")]` on a field,
/// as tokens ready to be called as `path(&value, serializer)`.
///
/// Every `#[serde(...)]` attribute is scanned. One whose contents do not parse as
/// comma-separated meta items is skipped rather than ending the search, so a later attribute
/// carrying `serialize_with` is still found. A `serialize_with` whose value is not a string
/// literal is ignored.
///
/// Returns `None` in two cases: no usable `serialize_with` is present, or the first one found
/// does not tokenize.
fn extract_serialize_with_attribute(attrs: &[syn::Attribute]) -> Option<proc_macro2::TokenStream> {
    for attr in attrs {
        if !attr.path().is_ident("serde") {
            continue;
        }
        let Meta::List(list) = &attr.meta else {
            continue;
        };
        // Skip (rather than abort on) unparsable contents.
        let Ok(metas) = Punctuated::<Meta, Token![,]>::parse_terminated.parse2(list.tokens.clone())
        else {
            continue;
        };
        for meta in metas {
            let Meta::NameValue(name_value) = meta else {
                continue;
            };
            if !name_value.path.is_ident("serialize_with") {
                continue;
            }
            let Expr::Lit(expr_lit) = &name_value.value else {
                continue;
            };
            let Lit::Str(value) = &expr_lit.lit else {
                continue;
            };
            return value.parse::<proc_macro2::TokenStream>().ok();
        }
    }
    None
}