//! thistrace-macros 0.1.0
//!
//! Proc-macros for the thistrace crate.
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, spanned::Spanned, Fields, ItemEnum, Variant};

/// Attribute macro entry point for `#[traceable]`.
///
/// Parses the annotated item as an enum and hands it to `expand_traceable`.
/// Any `syn::Error` produced during expansion is turned into a compile error
/// token stream. The attribute's own arguments are currently ignored.
#[proc_macro_attribute]
pub fn traceable(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed_enum = parse_macro_input!(item as ItemEnum);
    let expanded = match expand_traceable(parsed_enum) {
        Ok(tokens) => tokens,
        Err(err) => err.to_compile_error(),
    };
    expanded.into()
}

fn expand_traceable(mut item: ItemEnum) -> syn::Result<proc_macro2::TokenStream> {
    let enum_ident = item.ident.clone();
    let generics = item.generics.clone();
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    let mut from_impls = Vec::new();

    let mut seen_from_sources: std::collections::HashMap<String, proc_macro2::Span> =
        std::collections::HashMap::new();

    for variant in &mut item.variants {
        let from_info = extract_from_source(variant)?;
        let Some(from_info) = from_info else {
            continue;
        };
        // Reserved for future: variant-level trace merging.

        let source_ty = from_info.source_ty.clone();
        let source_ty_key = quote!(#source_ty).to_string();
        if let Some(prev_span) = seen_from_sources.get(&source_ty_key) {
            let mut err = syn::Error::new(
                variant.span(),
                format!(
                    "duplicate #[from] source type `{}`; this would create conflicting `From<{}>` impls",
                    source_ty_key, source_ty_key
                ),
            );
            err.combine(syn::Error::new(*prev_span, "previous #[from] source type seen here"));
            return Err(err);
        }
        seen_from_sources.insert(source_ty_key, variant.span());

        rewrite_from_variant(variant, &from_info)?;

        let variant_ident = variant.ident.clone();
        let source_field = from_info.source_field.clone();
        let extra_fields = extra_default_inits(variant, &source_field)?;
        let merge_origin = is_thistrace_origin(&source_ty);
        let merge_bubbled = is_thistrace_bubbled(&source_ty);
        let from_impl = if merge_origin {
            quote! {
                impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
                    #[track_caller]
                    fn from(source: #source_ty) -> Self {
                        let __loc = ::core::panic::Location::caller();
                        let __frame = ::thistrace::Frame::from_location(__loc);
                        let mut __trace = ::thistrace::HasTrace::trace(&source)
                            .cloned()
                            .unwrap_or_else(::thistrace::Trace::empty);
                        __trace.push(__frame);

                        #enum_ident::#variant_ident {
                            #source_field: source,
                            #(#extra_fields,)*
                            trace: __trace,
                        }
                    }
                }
            }
        } else if merge_bubbled {
            quote! {
                impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
                    #[track_caller]
                    fn from(source: #source_ty) -> Self {
                        let __trace = ::thistrace::HasTrace::trace(&source)
                            .cloned()
                            .unwrap_or_else(::thistrace::Trace::empty);

                        #enum_ident::#variant_ident {
                            #source_field: source,
                            #(#extra_fields,)*
                            trace: __trace,
                        }
                    }
                }
            }
        } else {
            quote! {
                impl #impl_generics ::core::convert::From<#source_ty> for #enum_ident #ty_generics #where_clause {
                    #[track_caller]
                    fn from(source: #source_ty) -> Self {
                        let __loc = ::core::panic::Location::caller();
                        let __frame = ::thistrace::Frame::from_location(__loc);
                        #enum_ident::#variant_ident {
                            #source_field: source,
                            #(#extra_fields,)*
                            trace: ::thistrace::Trace::from_frame(__frame),
                        }
                    }
                }
            }
        };
        from_impls.push(from_impl);
    }

    // Generate HasTrace impl that returns the variant trace if present.
    let match_arms = item.variants.iter().map(|v| {
        let vident = &v.ident;
        match &v.fields {
            Fields::Named(named) => {
                let has_trace = named.named.iter().any(|f| {
                    f.ident
                        .as_ref()
                        .is_some_and(|id| id == "trace")
                });
                if has_trace {
                    quote! { Self::#vident { trace, .. } => ::core::option::Option::Some(trace), }
                } else {
                    quote! { Self::#vident { .. } => ::core::option::Option::None, }
                }
            }
            Fields::Unnamed(_) => quote! { Self::#vident ( .. ) => ::core::option::Option::None, },
            Fields::Unit => quote! { Self::#vident => ::core::option::Option::None, },
        }
    });

    let has_trace_impl = quote! {
        impl #impl_generics ::thistrace::HasTrace for #enum_ident #ty_generics #where_clause {
            fn trace(&self) -> ::core::option::Option<&::thistrace::Trace> {
                match self {
                    #(#match_arms)*
                }
            }
        }
    };

    Ok(quote! {
        #item
        #(#from_impls)*
        #has_trace_impl
    })
}

/// Description of the single `#[from]` field found in a variant, as produced
/// by `extract_from_source`.
struct FromInfo {
    /// Declared type of the `#[from]` field.
    source_ty: syn::Type,
    /// Name of the field that holds the source: always `source` for tuple
    /// variants, the field's own identifier for struct variants.
    source_field: syn::Ident,
    /// Whether the variant was written in tuple or struct form.
    shape: FromShape,
    /// Types of the remaining (non-`#[from]`) tuple fields, in declaration
    /// order. Always empty for struct variants.
    tuple_ctx_tys: Vec<syn::Type>,
}

/// Syntactic form of a variant that carries a `#[from]` field; selects which
/// rewrite `rewrite_from_variant` applies.
enum FromShape {
    /// Tuple form, e.g. `Foo(#[from] io::Error, Ctx0, Ctx1)`.
    Tuple,
    /// Struct form, e.g. `Foo { #[from] source: io::Error }`.
    Struct,
}

/// Locates the `#[from]` field of `variant`, if it has one.
///
/// Returns `Ok(None)` when no field of the variant carries `#[from]` (this
/// includes unit variants). Otherwise returns a `FromInfo` describing the
/// source field:
///
/// - tuple form `Foo(#[from] io::Error, Ctx0, Ctx1, ...)`: the source field is
///   reported under the name `source`, and the types of all other fields are
///   collected in order as context types;
/// - struct form `Foo { #[from] source: io::Error }`: the source field keeps
///   its own name and no context types are recorded.
///
/// # Errors
///
/// Returns a `syn::Error` spanning the variant when more than one field is
/// marked `#[from]`.
fn extract_from_source(variant: &Variant) -> syn::Result<Option<FromInfo>> {
    /// True when the field carries a `#[from]` attribute.
    fn has_from_attr(field: &syn::Field) -> bool {
        field.attrs.iter().any(|a| a.path().is_ident("from"))
    }

    match &variant.fields {
        // tuple form: Foo(#[from] io::Error, Ctx0, Ctx1, ...)
        Fields::Unnamed(fields) => {
            let mut from_indices = fields
                .unnamed
                .iter()
                .enumerate()
                .filter(|(_, f)| has_from_attr(f))
                .map(|(i, _)| i);
            let Some(from_index) = from_indices.next() else {
                return Ok(None);
            };
            if from_indices.next().is_some() {
                return Err(syn::Error::new(
                    variant.span(),
                    "multiple #[from] fields in a single tuple variant are not supported",
                ));
            }

            // Every field other than the source becomes a context field.
            let ctx_tys = fields
                .unnamed
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != from_index)
                .map(|(_, f)| f.ty.clone())
                .collect::<Vec<_>>();
            Ok(Some(FromInfo {
                source_ty: fields.unnamed[from_index].ty.clone(),
                source_field: format_ident!("source"),
                shape: FromShape::Tuple,
                tuple_ctx_tys: ctx_tys,
            }))
        }

        // struct form: Foo { #[from] source: io::Error }
        Fields::Named(fields) => {
            let mut from_fields = fields.named.iter().filter(|f| has_from_attr(f));
            let Some(field) = from_fields.next() else {
                return Ok(None);
            };
            if from_fields.next().is_some() {
                return Err(syn::Error::new(
                    variant.span(),
                    "multiple #[from] fields in a single struct variant are not supported",
                ));
            }

            let ident = field.ident.clone().ok_or_else(|| {
                syn::Error::new(field.span(), "expected a named field for struct #[from] variant")
            })?;
            Ok(Some(FromInfo {
                source_ty: field.ty.clone(),
                source_field: ident,
                shape: FromShape::Struct,
                tuple_ctx_tys: Vec::new(),
            }))
        }

        Fields::Unit => Ok(None),
    }
}

fn rewrite_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
    match info.shape {
        FromShape::Tuple => rewrite_tuple_from_variant(variant, &info.source_ty, &info.tuple_ctx_tys),
        FromShape::Struct => rewrite_struct_from_variant(variant, info),
    }
}

fn rewrite_tuple_from_variant(
    variant: &mut Variant,
    source_ty: &syn::Type,
    ctx_tys: &[syn::Type],
) -> syn::Result<()> {
    let variant_ident = variant.ident.clone();
    match &variant.fields {
        Fields::Unnamed(_) => {
            let mut named = syn::punctuated::Punctuated::new();
            named.push(syn::Field {
                attrs: vec![syn::parse_quote!(#[source])],
                vis: syn::Visibility::Inherited,
                mutability: syn::FieldMutability::None,
                ident: Some(format_ident!("source")),
                colon_token: Some(Default::default()),
                ty: source_ty.clone(),
            });

            for (i, ty) in ctx_tys.iter().enumerate() {
                named.push(syn::Field {
                    attrs: vec![],
                    vis: syn::Visibility::Inherited,
                    mutability: syn::FieldMutability::None,
                    ident: Some(format_ident!("ctx{i}")),
                    colon_token: Some(Default::default()),
                    ty: ty.clone(),
                });
            }

            named.push(syn::Field {
                attrs: vec![],
                vis: syn::Visibility::Inherited,
                mutability: syn::FieldMutability::None,
                ident: Some(format_ident!("trace")),
                colon_token: Some(Default::default()),
                ty: syn::parse_quote!(::thistrace::Trace),
            });

            variant.fields = Fields::Named(syn::FieldsNamed {
                brace_token: Default::default(),
                named,
            });
            Ok(())
        }
        _ => Err(syn::Error::new(
            variant_ident.span(),
            "only tuple variants can be rewritten for #[from]",
        )),
    }
}

/// Adjusts a struct-like `#[from]` variant in place.
///
/// On the field named by `info.source_field`, `#[from]` is stripped and
/// `#[source]` is added unless already present. A `trace: ::thistrace::Trace`
/// field is then appended unless the variant already declares a field called
/// `trace`.
///
/// # Errors
///
/// Returns a `syn::Error` spanning the variant when it is not a struct
/// variant.
fn rewrite_struct_from_variant(variant: &mut Variant, info: &FromInfo) -> syn::Result<()> {
    let Fields::Named(fields) = &mut variant.fields else {
        return Err(syn::Error::new(variant.span(), "expected struct variant"));
    };

    let source_fields = fields
        .named
        .iter_mut()
        .filter(|f| f.ident.as_ref() == Some(&info.source_field));
    for field in source_fields {
        // Drop #[from] so thiserror does not generate a conflicting From impl.
        field.attrs.retain(|attr| !attr.path().is_ident("from"));
        // Keep thiserror's source() chain working by guaranteeing #[source].
        let already_marked = field.attrs.iter().any(|attr| attr.path().is_ident("source"));
        if !already_marked {
            field.attrs.push(syn::parse_quote!(#[source]));
        }
    }

    let trace_declared = fields
        .named
        .iter()
        .filter_map(|f| f.ident.as_ref())
        .any(|name| name == "trace");
    if trace_declared {
        return Ok(());
    }

    fields.named.push(syn::Field {
        attrs: Vec::new(),
        vis: syn::Visibility::Inherited,
        mutability: syn::FieldMutability::None,
        ident: Some(format_ident!("trace")),
        colon_token: Some(Default::default()),
        ty: syn::parse_quote!(::thistrace::Trace),
    });
    Ok(())
}

/// Produces `field: Default::default()` initialisers for every named field of
/// `variant` except the source field and the `trace` field.
///
/// Variants without named fields yield an empty list. The function currently
/// never returns `Err`.
fn extra_default_inits(
    variant: &Variant,
    source_field: &syn::Ident,
) -> syn::Result<Vec<proc_macro2::TokenStream>> {
    let inits = match &variant.fields {
        Fields::Named(fields) => fields
            .named
            .iter()
            .filter_map(|field| field.ident.as_ref())
            .filter(|name| *name != source_field && *name != "trace")
            .map(|name| quote! { #name: ::core::default::Default::default() })
            .collect(),
        Fields::Unnamed(_) | Fields::Unit => Vec::new(),
    };

    Ok(inits)
}

/// Heuristically decides whether `ty` is the `Origin<T>` wrapper.
///
/// The check is purely syntactic: the type must be a path whose last segment
/// is named `Origin` and carries angle-bracketed generic arguments. The path
/// prefix is not inspected, so any `...::Origin<..>` matches.
///
/// Invisible-delimiter groups (produced when a type reaches the macro through
/// a `macro_rules!` `$t:ty` fragment) and redundant parentheses are looked
/// through, so `Origin<T>` is still recognised in those cases.
fn is_thistrace_origin(ty: &syn::Type) -> bool {
    // Peel wrappers that do not change which type is being named.
    let mut ty = ty;
    loop {
        match ty {
            syn::Type::Group(group) => ty = &*group.elem,
            syn::Type::Paren(paren) => ty = &*paren.elem,
            _ => break,
        }
    }

    let syn::Type::Path(p) = ty else {
        return false;
    };
    // If it is `Origin<T>` we treat it as our wrapper.
    p.path.segments.last().is_some_and(|seg| {
        seg.ident == "Origin" && matches!(seg.arguments, syn::PathArguments::AngleBracketed(_))
    })
}

/// Heuristically decides whether `ty` is the `Bubbled<T>` wrapper.
///
/// The check is purely syntactic: the type must be a path whose last segment
/// is named `Bubbled` and carries angle-bracketed generic arguments. The path
/// prefix is not inspected, so any `...::Bubbled<..>` matches.
///
/// Invisible-delimiter groups (produced when a type reaches the macro through
/// a `macro_rules!` `$t:ty` fragment) and redundant parentheses are looked
/// through, so `Bubbled<T>` is still recognised in those cases.
fn is_thistrace_bubbled(ty: &syn::Type) -> bool {
    // Peel wrappers that do not change which type is being named.
    let mut ty = ty;
    loop {
        match ty {
            syn::Type::Group(group) => ty = &*group.elem,
            syn::Type::Paren(paren) => ty = &*paren.elem,
            _ => break,
        }
    }

    let syn::Type::Path(p) = ty else {
        return false;
    };
    p.path.segments.last().is_some_and(|seg| {
        seg.ident == "Bubbled" && matches!(seg.arguments, syn::PathArguments::AngleBracketed(_))
    })
}