#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(docsrs, allow(unused_attributes))]
#![doc = include_str!("../README.md")]
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse::Parse, punctuated::Punctuated, DeriveInput, Expr, LitStr, Token};
/// Parsed contents of a `#[fmt(...)]`, `#[fmt_display(...)]` or `#[fmt_debug(...)]`
/// attribute: a string literal, optionally followed by a comma and a comma-separated
/// list of expressions.
struct FmtArgs {
    /// The format string literal (first token of the attribute arguments).
    format_str: LitStr,
    /// The comma following the format string, if one was written.
    _comma: Option<Token![,]>,
    /// Additional arguments forwarded to `write!` after the format string.
    args: Punctuated<Expr, Token![,]>,
}
impl Parse for FmtArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let format_str: LitStr = input.parse()?;
let _comma: Option<Token![,]> = input.parse()?;
let args: Punctuated<Expr, Token![,]> =
input.parse_terminated(Expr::parse, Token![,])?;
Ok(FmtArgs {
format_str,
_comma,
args,
})
}
}
/// Returns the parsed arguments of the first attribute whose path is exactly `name`.
///
/// NOTE: only that first matching attribute is considered. If its arguments fail to
/// parse, the error is discarded and `None` is returned; later attributes with the same
/// name are not tried.
fn parse_specific_attr(attrs: &[syn::Attribute], name: &str) -> Option<FmtArgs> {
    let matching = attrs.iter().find(|candidate| candidate.path().is_ident(name))?;
    matching.parse_args().ok()
}
/// Looks up the format arguments for the `Display` derive.
///
/// `#[fmt_display(...)]` is tried first; if it is absent or fails to parse, the shared
/// `#[fmt(...)]` attribute is used instead.
fn parse_display_attr(attrs: &[syn::Attribute]) -> Option<FmtArgs> {
    ["fmt_display", "fmt"]
        .iter()
        .find_map(|attr_name| parse_specific_attr(attrs, attr_name))
}
/// Looks up the format arguments for the `Debug` derive.
///
/// `#[fmt_debug(...)]` is tried first; if it is absent or fails to parse, the shared
/// `#[fmt(...)]` attribute is used instead.
fn parse_debug_attr(attrs: &[syn::Attribute]) -> Option<FmtArgs> {
    ["fmt_debug", "fmt"]
        .iter()
        .find_map(|attr_name| parse_specific_attr(attrs, attr_name))
}
/// Reports whether `expr` is a bare single-segment path such as `name`, which the
/// generated code treats as a field of `self`.
fn is_simple_field(expr: &Expr) -> bool {
    if let Expr::Path(path_expr) = expr {
        path_expr.path.get_ident().is_some()
    } else {
        false
    }
}
/// Extracts the identifier from a placeholder such as `{name}` or `{name:>8}`.
///
/// `spec` must be wrapped in `{` and `}`. The part before the first `:` is returned when
/// it is a non-empty ASCII identifier (letter or `_` first, then letters, digits or `_`).
/// Empty and numeric placeholders (`{}`, `{:?}`, `{0}`) and anything else yield `None`.
fn extract_field_from_format(spec: &str) -> Option<&str> {
    let inner = spec.strip_prefix('{').and_then(|s| s.strip_suffix('}'))?;
    let candidate = match inner.find(':') {
        Some(colon) => &inner[..colon],
        None => inner,
    };
    let mut chars = candidate.chars();
    let head = chars.next()?;
    let head_ok = head == '_' || head.is_ascii_alphabetic();
    let tail_ok = chars.all(|c| c == '_' || c.is_ascii_alphanumeric());
    (head_ok && tail_ok).then_some(candidate)
}
/// Rewrites each `{field}` / `{field:spec}` placeholder whose argument is a plain ASCII
/// identifier into the positional form `{}` / `{:spec}`.
///
/// Returns the rewritten string together with the removed identifiers, in order of
/// appearance (duplicates are kept). Escaped braces (`{{`, `}}`), empty or numeric
/// placeholders (`{}`, `{0}`, `{:?}`), lone `}` characters and an unterminated `{...`
/// tail are copied through unchanged.
fn transform_format_string(input: &str) -> (String, Vec<String>) {
    let mut output = String::with_capacity(input.len());
    let mut names: Vec<String> = Vec::new();
    let mut rest = input.chars().peekable();

    while let Some(current) = rest.next() {
        match current {
            // Escaped `{{` and `}}` pairs pass through as a unit.
            '{' | '}' if rest.peek() == Some(&current) => {
                rest.next();
                output.push(current);
                output.push(current);
            }
            '{' => {
                // Gather everything up to the next `}` (or to the end of the input).
                let mut inner = String::new();
                let mut terminated = false;
                for next in rest.by_ref() {
                    if next == '}' {
                        terminated = true;
                        break;
                    }
                    inner.push(next);
                }
                output.push('{');
                if terminated {
                    match extract_field_from_format(&format!("{{{}}}", inner)) {
                        Some(field) => {
                            names.push(field.to_string());
                            // Keep only the format spec that follows the identifier.
                            output.push_str(&inner[field.len()..]);
                        }
                        None => output.push_str(&inner),
                    }
                    output.push('}');
                } else {
                    output.push_str(&inner);
                }
            }
            other => output.push(other),
        }
    }

    (output, names)
}
/// Builds the `write!(f, ...)` expression used as the body of a generated `fmt` method.
///
/// The attribute's format string literal is emitted unchanged. The explicit arguments
/// come first, in order: a bare identifier such as `name` becomes `self.name`, and any
/// other expression is emitted verbatim. Every `{field}` / `{field:spec}` placeholder
/// found in the format string is then bound as a named argument `field = self.field`,
/// once per distinct field, unless the attribute already binds that name with
/// `field = expr`.
///
/// Binding placeholders by name, rather than rewriting them into positional `{}` slots,
/// keeps `{}` and `{0}` pointing at the explicit arguments. For example,
/// `#[fmt("{} {name}", id)]` formats `self.id` followed by `self.name`.
fn generate_fmt_body(fmt_args: &FmtArgs) -> proc_macro2::TokenStream {
    let format_str = &fmt_args.format_str;
    let (_, placeholder_fields) = transform_format_string(&format_str.value());

    // Names the user binds explicitly (`name = expr`). The check is done on tokens so it
    // does not depend on which syn features are enabled. Requiring a lone `=` (spacing
    // `Alone`) excludes `==`; `+=`, `<=` etc. start with a different punct anyway.
    let explicit_names: Vec<String> = fmt_args
        .args
        .iter()
        .filter_map(|arg| {
            let mut tokens = quote! { #arg }.into_iter();
            match (tokens.next(), tokens.next()) {
                (
                    Some(proc_macro2::TokenTree::Ident(ident)),
                    Some(proc_macro2::TokenTree::Punct(punct)),
                ) if punct.as_char() == '='
                    && punct.spacing() == proc_macro2::Spacing::Alone =>
                {
                    Some(ident.to_string())
                }
                _ => None,
            }
        })
        .collect();

    // Explicit arguments go first: `format_args!` requires positional arguments to
    // precede named ones.
    let mut all_args: Vec<proc_macro2::TokenStream> = fmt_args
        .args
        .iter()
        .map(|arg| {
            if is_simple_field(arg) {
                quote! { self.#arg }
            } else {
                quote! { #arg }
            }
        })
        .collect();

    let mut bound_fields: Vec<&str> = Vec::new();
    for field in &placeholder_fields {
        let field = field.as_str();
        // A named argument may be bound only once, even when `{field}` repeats.
        if bound_fields.contains(&field) || explicit_names.iter().any(|name| name == field) {
            continue;
        }
        bound_fields.push(field);
        // Spanned on the literal so errors such as an unknown field point at the attribute.
        let field_ident = syn::Ident::new(field, format_str.span());
        all_args.push(quote! { #field_ident = self.#field_ident });
    }

    quote! {
        write!(f, #format_str #(, #all_args)*)
    }
}
#[proc_macro_derive(DisplayAttr, attributes(fmt, fmt_display))]
pub fn derive_display(input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fmt_args = match parse_display_attr(&input.attrs) {
Some(args) => args,
None => {
return syn::Error::new_spanned(
&input,
"DisplayAttr requires a #[fmt(...)] or #[fmt_display(...)] attribute",
)
.to_compile_error()
.into();
}
};
let fmt_body = generate_fmt_body(&fmt_args);
let expanded = quote! {
impl #impl_generics std::fmt::Display for #name #ty_generics #where_clause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
#fmt_body
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_derive(DebugAttr, attributes(fmt, fmt_debug))]
pub fn derive_debug(input: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fmt_args = match parse_debug_attr(&input.attrs) {
Some(args) => args,
None => {
return syn::Error::new_spanned(
&input,
"DebugAttr requires a #[fmt(...)] or #[fmt_debug(...)] attribute",
)
.to_compile_error()
.into();
}
};
let fmt_body = generate_fmt_body(&fmt_args);
let expanded = quote! {
impl #impl_generics std::fmt::Debug for #name #ty_generics #where_clause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
#fmt_body
}
}
};
TokenStream::from(expanded)
}
/// Derives `std::fmt::Display` by forwarding to the type's own `std::fmt::Debug`
/// implementation, so both traits produce the same text.
#[proc_macro_derive(DisplayAsDebug)]
pub fn derive_display_as_debug(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as DeriveInput);
    let type_ident = &ast.ident;
    let (impl_gen, type_gen, where_gen) = ast.generics.split_for_impl();
    quote! {
        impl #impl_gen std::fmt::Display for #type_ident #type_gen #where_gen {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Debug::fmt(self, f)
            }
        }
    }
    .into()
}