use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, spanned::Spanned, Fields, ItemEnum, Variant};
/// Attribute macro entry point.
///
/// Parses the annotated item as an enum and replaces it with the output of
/// `expand_traceable`. The attribute's own arguments are ignored. If expansion
/// fails, the error is emitted as a compile error instead.
#[proc_macro_attribute]
pub fn traceable(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(item as ItemEnum);
    let expanded = match expand_traceable(parsed) {
        Ok(tokens) => tokens,
        Err(err) => err.to_compile_error(),
    };
    expanded.into()
}
fn expand_traceable(mut item: ItemEnum) -> syn::Result<proc_macro2::TokenStream> {
let enum_ident = item.ident.clone();
let generics = item.generics.clone();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let mut from_impls = Vec::new();
let mut seen_from_sources: std::collections::HashMap<String, proc_macro2::Span> =
std::collections::HashMap::new();
for variant in &mut item.variants {
let from_info = extract_from_source(variant)?;
let Some(from_info) = from_info else {
continue;
};
let source_ty = from_info.source_ty.clone();
let source_ty_key = quote!(#source_ty).to_string();
if let Some(prev_span) = seen_from_sources.get(&source_ty_key) {
let mut err = syn::Error::new(
variant.span(),
format!(
"duplicate #[from] source type `{}`; this would create conflicting `From<{}>` impls",
source_ty_key, source_ty_key
),
);
err.combine(syn::Error::new(*prev_span, "previous #[from] source type seen here"));
return Err(err);
}
seen_from_sources.insert(source_ty_key, variant.span());
rewrite_from_variant(variant, &from_info)?;
let variant_ident = variant.ident.clone();
let source_field = from_info.source_field.clone();
let extra_fields = extra_default_inits(variant, &source_field)?;
let merge_origin = is_thistrace_origin(&source_ty);
let merge_bubbled = is_thistrace_bubbled(&source_ty);
let from_impl = if merge_origin {
quote! {
impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
#[track_caller]
fn from(source: #source_ty) -> Self {
let __loc = ::core::panic::Location::caller();
let __frame = ::thistrace::Frame::from_location(__loc);
let mut __trace = ::thistrace::HasTrace::trace(&source)
.cloned()
.unwrap_or_else(::thistrace::Trace::empty);
__trace.push(__frame);
#enum_ident::#variant_ident {
#source_field: source,
#(#extra_fields,)*
trace: __trace,
}
}
}
}
} else if merge_bubbled {
quote! {
impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
#[track_caller]
fn from(source: #source_ty) -> Self {
let __trace = ::thistrace::HasTrace::trace(&source)
.cloned()
.unwrap_or_else(::thistrace::Trace::empty);
#enum_ident::#variant_ident {
#source_field: source,
#(#extra_fields,)*
trace: __trace,
}
}
}
}
} else {
quote! {
impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
#[track_caller]
fn from(source: #source_ty) -> Self {
let __loc = ::core::panic::Location::caller();
let __frame = ::thistrace::Frame::from_location(__loc);
#enum_ident::#variant_ident {
#source_field: source,
#(#extra_fields,)*
trace: ::thistrace::Trace::from_frame(__frame),
}
}
}
}
};
from_impls.push(from_impl);
}
let match_arms = item.variants.iter().map(|v| {
let vident = &v.ident;
match &v.fields {
Fields::Named(named) => {
let has_trace = named.named.iter().any(|f| {
f.ident
.as_ref()
.is_some_and(|id| id == "trace")
});
if has_trace {
quote! { Self::#vident { trace, .. } => ::core::option::Option::Some(trace), }
} else {
quote! { Self::#vident { .. } => ::core::option::Option::None, }
}
}
Fields::Unnamed(_) => quote! { Self::#vident ( .. ) => ::core::option::Option::None, },
Fields::Unit => quote! { Self::#vident => ::core::option::Option::None, },
}
});
let has_trace_impl = quote! {
impl #impl_generics ::thistrace::HasTrace for #enum_ident #ty_generics #where_clause {
fn trace(&self) -> ::core::option::Option<&::thistrace::Trace> {
match self {
#(#match_arms)*
}
}
}
};
Ok(quote! {
#item
#(#from_impls)*
#has_trace_impl
})
}
/// Description of the single `#[from]` field found on an enum variant, as
/// produced by `extract_from_source`.
struct FromInfo {
/// Declared type of the `#[from]` field.
source_ty: syn::Type,
/// Name under which the source value is stored: the literal `source` for
/// tuple variants, the field's own identifier for struct variants.
source_field: syn::Ident,
/// Whether the variant was written as a tuple or as a struct variant.
shape: FromShape,
/// Types of the tuple fields other than the `#[from]` one, in declaration
/// order. Always empty for struct variants.
tuple_ctx_tys: Vec<syn::Type>,
}
/// Syntactic form of a variant that carries a `#[from]` field; selects which
/// rewrite routine `rewrite_from_variant` dispatches to.
enum FromShape {
/// Tuple variant, e.g. `Variant(#[from] T, ..)`.
Tuple,
/// Struct variant, e.g. `Variant { #[from] field: T, .. }`.
Struct,
}
/// Looks for a `#[from]`-annotated field on `variant` and describes it.
///
/// # Returns
///
/// * `Ok(None)` when no field of the variant carries `#[from]` (this includes
///   unit variants).
/// * `Ok(Some(info))` when exactly one field does. For tuple variants the
///   source field name is the literal `source` and `tuple_ctx_tys` holds the
///   types of the remaining fields in declaration order; for struct variants
///   the field's own name is used and `tuple_ctx_tys` is empty.
///
/// # Errors
///
/// Fails when more than one field of the same variant carries `#[from]`, or
/// when a struct-variant field unexpectedly has no identifier.
fn extract_from_source(variant: &Variant) -> syn::Result<Option<FromInfo>> {
    let is_from = |field: &syn::Field| field.attrs.iter().any(|a| a.path().is_ident("from"));

    match &variant.fields {
        Fields::Unnamed(fields) => {
            let mut marked = fields
                .unnamed
                .iter()
                .enumerate()
                .filter(|&(_, field)| is_from(field));
            let Some((from_index, from_field)) = marked.next() else {
                return Ok(None);
            };
            if marked.next().is_some() {
                return Err(syn::Error::new(
                    variant.span(),
                    "multiple #[from] fields in a single tuple variant are not supported",
                ));
            }

            // Every other tuple field is context that travels along with the source.
            // No further check is needed here: the field was selected because it
            // carries `#[from]`, so the variant always qualifies.
            let tuple_ctx_tys = fields
                .unnamed
                .iter()
                .enumerate()
                .filter(|&(index, _)| index != from_index)
                .map(|(_, field)| field.ty.clone())
                .collect::<Vec<_>>();

            Ok(Some(FromInfo {
                source_ty: from_field.ty.clone(),
                source_field: format_ident!("source"),
                shape: FromShape::Tuple,
                tuple_ctx_tys,
            }))
        }
        Fields::Named(fields) => {
            let mut marked = fields.named.iter().filter(|&field| is_from(field));
            let Some(field) = marked.next() else {
                return Ok(None);
            };
            if marked.next().is_some() {
                return Err(syn::Error::new(
                    variant.span(),
                    "multiple #[from] fields in a single struct variant are not supported",
                ));
            }

            let source_field = field.ident.clone().ok_or_else(|| {
                syn::Error::new(field.span(), "expected a named field for struct #[from] variant")
            })?;

            Ok(Some(FromInfo {
                source_ty: field.ty.clone(),
                source_field,
                shape: FromShape::Struct,
                tuple_ctx_tys: Vec::new(),
            }))
        }
        Fields::Unit => Ok(None),
    }
}
fn rewrite_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
match info.shape {
FromShape::Tuple => rewrite_tuple_from_variant(variant, &info.source_ty, &info.tuple_ctx_tys),
FromShape::Struct => rewrite_struct_from_variant(variant, info),
}
}
fn rewrite_tuple_from_variant(
variant: &mut Variant,
source_ty: &syn::Type,
ctx_tys: &[syn::Type],
) -> syn::Result<()> {
let variant_ident = variant.ident.clone();
match &variant.fields {
Fields::Unnamed(_) => {
let mut named = syn::punctuated::Punctuated::new();
named.push(syn::Field {
attrs: vec![syn::parse_quote!(#[source])],
vis: syn::Visibility::Inherited,
mutability: syn::FieldMutability::None,
ident: Some(format_ident!("source")),
colon_token: Some(Default::default()),
ty: source_ty.clone(),
});
for (i, ty) in ctx_tys.iter().enumerate() {
named.push(syn::Field {
attrs: vec![],
vis: syn::Visibility::Inherited,
mutability: syn::FieldMutability::None,
ident: Some(format_ident!("ctx{i}")),
colon_token: Some(Default::default()),
ty: ty.clone(),
});
}
named.push(syn::Field {
attrs: vec![],
vis: syn::Visibility::Inherited,
mutability: syn::FieldMutability::None,
ident: Some(format_ident!("trace")),
colon_token: Some(Default::default()),
ty: syn::parse_quote!(::thistrace::Trace),
});
variant.fields = Fields::Named(syn::FieldsNamed {
brace_token: Default::default(),
named,
});
Ok(())
}
_ => Err(syn::Error::new(
variant_ident.span(),
"only tuple variants can be rewritten for #[from]",
)),
}
}
/// Prepares a struct variant with a `#[from]` field for trace support.
///
/// On the field named `info.source_field`, `#[from]` is removed and `#[source]`
/// is added unless already present. A `trace: ::thistrace::Trace` field is
/// appended unless the variant already declares a field named `trace`.
///
/// # Errors
///
/// Fails if `variant` is not a struct variant.
fn rewrite_struct_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
    let Fields::Named(fields) = &mut variant.fields else {
        return Err(syn::Error::new(variant.span(), "expected struct variant"));
    };

    let mut trace_declared = false;
    for field in &mut fields.named {
        let Some(name) = field.ident.as_ref() else {
            continue;
        };
        trace_declared |= name == "trace";
        if name != &info.source_field {
            continue;
        }

        // Swap the `#[from]` marker for `#[source]` on the source field.
        field.attrs.retain(|attr| !attr.path().is_ident("from"));
        let already_source = field.attrs.iter().any(|attr| attr.path().is_ident("source"));
        if !already_source {
            field.attrs.push(syn::parse_quote!(#[source]));
        }
    }

    if trace_declared {
        return Ok(());
    }
    fields.named.push(syn::Field {
        attrs: Vec::new(),
        vis: syn::Visibility::Inherited,
        mutability: syn::FieldMutability::None,
        ident: Some(format_ident!("trace")),
        colon_token: Some(Default::default()),
        ty: syn::parse_quote!(::thistrace::Trace),
    });
    Ok(())
}
/// Builds `field: ::core::default::Default::default()` initialisers for every
/// named field of `variant` other than `source_field` and `trace`.
///
/// Tuple and unit variants yield an empty list. The function never returns
/// `Err`.
fn extra_default_inits(
    variant: &Variant,
    source_field: &syn::Ident,
) -> syn::Result<Vec<proc_macro2::TokenStream>> {
    let inits = match &variant.fields {
        Fields::Named(fields) => fields
            .named
            .iter()
            .filter_map(|field| field.ident.as_ref())
            .filter(|name| *name != source_field && *name != "trace")
            .map(|name| quote! { #name: ::core::default::Default::default() })
            .collect(),
        Fields::Unnamed(_) | Fields::Unit => Vec::new(),
    };
    Ok(inits)
}
/// Returns `true` when `ty` is a path type whose last segment is named
/// `Origin` and carries angle-bracketed generic arguments (e.g. `Origin<E>`).
///
/// The check is purely syntactic: only the last path segment is inspected.
fn is_thistrace_origin(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(type_path) => type_path.path.segments.last().is_some_and(|last| {
            last.ident == "Origin"
                && matches!(last.arguments, syn::PathArguments::AngleBracketed(_))
        }),
        _ => false,
    }
}
/// Returns `true` when `ty` is a path type whose last segment is named
/// `Bubbled` and carries angle-bracketed generic arguments (e.g. `Bubbled<E>`).
///
/// The check is purely syntactic: only the last path segment is inspected.
fn is_thistrace_bubbled(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(type_path) => type_path.path.segments.last().is_some_and(|last| {
            last.ident == "Bubbled"
                && matches!(last.arguments, syn::PathArguments::AngleBracketed(_))
        }),
        _ => false,
    }
}