use crate::helper::{
combine_errors, extract_display_error_inner, has_snafu_keyword, looks_like_location_type,
};
use proc_macro2::{Span, TokenStream};
use std::collections::HashSet;
use syn::parse_quote;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{Attribute, Data, DeriveInput, Error, Field, Fields, GenericParam, Ident, Meta, Token};
pub(crate) fn process_suzu_attrs(
input: &mut DeriveInput,
crate_path: &TokenStream,
) -> Result<(), Error> {
process_non_field_attrs(&mut input.attrs)?;
let generic_type_params: HashSet<Ident> = input
.generics
.params
.iter()
.filter_map(|p| match p {
GenericParam::Type(tp) => Some(tp.ident.clone()),
_ => None,
})
.collect();
match &mut input.data {
Data::Struct(data_struct) => {
match &mut data_struct.fields {
Fields::Named(fields) => {
process_fields(&mut fields.named, crate_path, &generic_type_params)?
}
fields => reject_suzu_on_non_named_fields(fields)?,
}
Ok(())
}
Data::Enum(data_enum) => {
let mut errors = Vec::new();
for variant in &mut data_enum.variants {
if let Err(e) = process_non_field_attrs(&mut variant.attrs) {
errors.push(e);
}
match &mut variant.fields {
Fields::Named(fields) => {
if let Err(e) =
process_fields(&mut fields.named, crate_path, &generic_type_params)
{
errors.push(e);
}
}
fields => {
if let Err(e) = reject_suzu_on_non_named_fields(fields) {
errors.push(e);
}
}
}
}
combine_errors(errors)
}
Data::Union(_) => Err(Error::new(input.span(), "#[suzu] cannot be used on unions")),
}
}
fn reject_suzu_on_non_named_fields(fields: &Fields) -> Result<(), Error> {
let mut errors = Vec::new();
for field in fields {
for attr in &field.attrs {
if attr.path().is_ident("suzu") {
errors.push(Error::new(
attr.span(),
"#[suzu(...)] is not supported on unnamed fields; use named fields instead",
));
}
}
}
combine_errors(errors)
}
fn process_non_field_attrs(attrs: &mut Vec<Attribute>) -> Result<(), Error> {
let level = Level::NonField;
let mut new_attrs = Vec::new();
let mut errors = Vec::new();
for attr in attrs.drain(..) {
if !attr.path().is_ident("suzu") {
new_attrs.push(attr);
continue;
}
match process_single_suzu_attr(&attr, level) {
Ok(result) => {
if let Some(snafu_attr) = result.snafu_passthrough {
new_attrs.push(snafu_attr);
}
}
Err(e) => errors.push(e),
}
}
*attrs = new_attrs;
combine_errors(errors)
}
fn process_fields(
fields: &mut Punctuated<Field, Token![,]>,
crate_path: &TokenStream,
generic_type_params: &HashSet<Ident>,
) -> Result<(), Error> {
let mut errors = Vec::new();
let mut first_from_span: Option<Span> = None;
let mut first_location_span: Option<Span> = None;
for field in fields.iter_mut() {
let old_attrs = std::mem::take(&mut field.attrs);
let mut new_attrs = Vec::new();
let mut current_from_span: Option<Span> = None;
let mut current_location_span: Option<Span> = None;
for attr in old_attrs {
if !attr.path().is_ident("suzu") {
new_attrs.push(attr);
continue;
}
match process_single_suzu_attr(&attr, Level::Field) {
Ok(result) => {
if let Some(snafu_attr) = result.snafu_passthrough {
new_attrs.push(snafu_attr);
}
match result.effect {
SuzuEffect::From(keyword_span) => {
if let Some(first_span) = first_from_span {
let msg = if current_from_span.is_some() {
"duplicate #[suzu(from)] on the same field"
} else {
"multiple #[suzu(from)] fields; only one source field is allowed per struct/variant"
};
let mut err = Error::new(keyword_span, msg);
err.combine(Error::new(
first_span,
"first occurrence of #[suzu(from)] is here",
));
errors.push(err);
current_from_span = None;
} else {
first_from_span = Some(keyword_span);
current_from_span = Some(keyword_span);
}
}
SuzuEffect::Location(keyword_span) => {
if let Some(first_span) = first_location_span {
let msg = if current_location_span.is_some() {
"duplicate #[suzu(location)] on the same field"
} else {
"multiple #[suzu(location)] fields; only one is allowed per struct/variant"
};
let mut err = Error::new(keyword_span, msg);
err.combine(Error::new(
first_span,
"first occurrence of #[suzu(location)] is here",
));
errors.push(err);
current_location_span = None;
} else {
first_location_span = Some(keyword_span);
current_location_span = Some(keyword_span);
}
}
SuzuEffect::PassthroughOnly => {}
}
}
Err(e) => errors.push(e),
}
}
match (current_from_span, current_location_span) {
(Some(from_span), Some(loc_span)) => {
let mut err = Error::new(
from_span,
"`from` and `location` cannot be used on the same field",
);
err.combine(Error::new(loc_span, "`location` defined here"));
errors.push(err);
}
(Some(from_span), None) => match apply_from(
field,
&new_attrs,
crate_path,
from_span,
generic_type_params,
) {
Ok(snafu_source_attr) => new_attrs.push(snafu_source_attr),
Err(e) => errors.push(e),
},
(None, Some(_)) => {
if !looks_like_location_type(&field.ty) {
errors.push(Error::new(
field.ty.span(),
"#[suzu(location)] requires the field type to be `suzunari_error::Location`",
));
} else {
apply_location(&mut new_attrs);
}
}
(None, None) => {}
}
field.attrs = new_attrs;
}
combine_errors(errors)
}
#[derive(Clone, Copy)]
enum Level {
NonField,
Field,
}
enum SuzuEffect {
PassthroughOnly,
From(Span),
Location(Span),
}
struct SingleAttrResult {
snafu_passthrough: Option<Attribute>,
effect: SuzuEffect,
}
fn process_single_suzu_attr(attr: &Attribute, level: Level) -> Result<SingleAttrResult, Error> {
let Meta::List(meta_list) = &attr.meta else {
return Err(Error::new(
attr.span(),
"#[suzu] requires arguments, e.g., #[suzu(location)] or #[suzu(display(\"...\"))]",
));
};
let nested: Punctuated<Meta, Token![,]> =
meta_list.parse_args_with(Punctuated::parse_terminated)?;
if nested.is_empty() {
return Err(Error::new(
attr.span(),
"#[suzu()] requires arguments, e.g., #[suzu(location)] or #[suzu(display(\"...\"))]",
));
}
let mut effect = SuzuEffect::PassthroughOnly;
let mut passthrough_tokens: Vec<Meta> = Vec::new();
let mut has_source_in_passthrough = false;
for meta in &nested {
if meta.path().is_ident("from") {
if !matches!(meta, Meta::Path(_)) {
return Err(Error::new(
meta.span(),
"`from` does not accept arguments; use `#[suzu(from)]` as a bare keyword",
));
}
if matches!(level, Level::NonField) {
return Err(Error::new(meta.span(), "`from` can only be used on fields"));
}
if matches!(effect, SuzuEffect::Location(_)) {
return Err(Error::new(
meta.span(),
"`from` and `location` cannot be used on the same field",
));
}
effect = SuzuEffect::From(meta.span());
} else if meta.path().is_ident("location") {
if !matches!(meta, Meta::Path(_)) {
return Err(Error::new(
meta.span(),
"`location` does not accept arguments; use `#[suzu(location)]` as a bare keyword",
));
}
if matches!(level, Level::NonField) {
return Err(Error::new(
meta.span(),
"`location` can only be used on fields",
));
}
if matches!(effect, SuzuEffect::From(_)) {
return Err(Error::new(
meta.span(),
"`from` and `location` cannot be used on the same field",
));
}
effect = SuzuEffect::Location(meta.span());
} else {
if meta.path().is_ident("source") {
has_source_in_passthrough = true;
}
passthrough_tokens.push(meta.clone());
}
}
if matches!(effect, SuzuEffect::From(_)) && has_source_in_passthrough {
return Err(Error::new(
attr.span(),
"`from` conflicts with `source(...)`: `from` generates `source(from(...))` automatically",
));
}
let snafu_passthrough = if passthrough_tokens.is_empty() {
None
} else {
Some(parse_quote!(#[snafu(#(#passthrough_tokens),*)]))
};
Ok(SingleAttrResult {
snafu_passthrough,
effect,
})
}
fn apply_from(
field: &mut Field,
existing_attrs: &[Attribute],
crate_path: &TokenStream,
from_span: Span,
generic_type_params: &HashSet<Ident>,
) -> Result<Attribute, Error> {
if has_snafu_keyword(existing_attrs, "source") {
return Err(Error::new(
from_span,
"`from` conflicts with existing `#[snafu(source(...))]`",
));
}
if type_uses_generic_params(&field.ty, generic_type_params) {
return Err(Error::new(
from_span,
"`from` cannot be used on fields with generic type parameters; use a concrete type instead",
));
}
let original_type = match extract_display_error_inner(&field.ty) {
Some(inner) => inner.clone(),
None => {
let orig = field.ty.clone();
field.ty = parse_quote!(#crate_path::DisplayError<#orig>);
orig
}
};
Ok(parse_quote!(
#[snafu(source(from(#original_type, {
fn __wrap(__source: #original_type) -> #crate_path::DisplayError<#original_type> {
let __get_source: fn(&#original_type) -> Option<&(dyn ::core::error::Error + 'static)>
= #crate_path::__private::DisplayErrorSourceResolver(&__source).get_source_fn();
#crate_path::__private::display_error_with_get_source(__source, __get_source)
}
__wrap
})))]
))
}
fn type_uses_generic_params(ty: &syn::Type, params: &HashSet<Ident>) -> bool {
use syn::{GenericArgument, PathArguments, ReturnType, Type};
fn angle_bracketed_uses(
args: &syn::AngleBracketedGenericArguments,
params: &HashSet<Ident>,
) -> bool {
args.args.iter().any(|arg| match arg {
GenericArgument::Type(inner) => type_uses_generic_params(inner, params),
GenericArgument::AssocType(assoc) => {
type_uses_generic_params(&assoc.ty, params)
|| assoc
.generics
.as_ref()
.is_some_and(|g| angle_bracketed_uses(g, params))
}
_ => false,
})
}
fn path_args_uses(args: &PathArguments, params: &HashSet<Ident>) -> bool {
match args {
PathArguments::AngleBracketed(args) => angle_bracketed_uses(args, params),
PathArguments::Parenthesized(paren) => {
paren
.inputs
.iter()
.any(|t| type_uses_generic_params(t, params))
|| matches!(&paren.output, ReturnType::Type(_, t) if type_uses_generic_params(t, params))
}
PathArguments::None => false,
}
}
fn bounds_use(
bounds: &syn::punctuated::Punctuated<syn::TypeParamBound, Token![+]>,
params: &HashSet<Ident>,
) -> bool {
bounds.iter().any(|bound| match bound {
syn::TypeParamBound::Trait(trait_bound) => trait_bound
.path
.segments
.iter()
.any(|seg| path_args_uses(&seg.arguments, params)),
_ => false,
})
}
match ty {
Type::Path(type_path) => {
if let Some(qself) = &type_path.qself {
if type_uses_generic_params(&qself.ty, params) {
return true;
}
}
type_path
.path
.segments
.iter()
.any(|seg| params.contains(&seg.ident) || path_args_uses(&seg.arguments, params))
}
Type::Reference(type_ref) => type_uses_generic_params(&type_ref.elem, params),
Type::Tuple(type_tuple) => type_tuple
.elems
.iter()
.any(|t| type_uses_generic_params(t, params)),
Type::Array(type_array) => type_uses_generic_params(&type_array.elem, params),
Type::Slice(type_slice) => type_uses_generic_params(&type_slice.elem, params),
Type::Paren(type_paren) => type_uses_generic_params(&type_paren.elem, params),
Type::Group(type_group) => type_uses_generic_params(&type_group.elem, params),
Type::TraitObject(type_trait_object) => bounds_use(&type_trait_object.bounds, params),
Type::BareFn(type_bare_fn) => {
type_bare_fn
.inputs
.iter()
.any(|arg| type_uses_generic_params(&arg.ty, params))
|| matches!(&type_bare_fn.output, ReturnType::Type(_, t) if type_uses_generic_params(t, params))
}
Type::Ptr(type_ptr) => type_uses_generic_params(&type_ptr.elem, params),
Type::ImplTrait(type_impl_trait) => bounds_use(&type_impl_trait.bounds, params),
_ => false,
}
}
fn apply_location(attrs: &mut Vec<Attribute>) {
if !has_snafu_keyword(attrs, "implicit") {
attrs.push(parse_quote!(#[snafu(implicit)]));
}
let already_has_stack_location = attrs.iter().any(|a| a.path().is_ident("stack"));
if !already_has_stack_location {
attrs.push(parse_quote!(#[stack(location)]));
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_apply_location_is_idempotent() {
let mut attrs: Vec<Attribute> = Vec::new();
apply_location(&mut attrs);
assert_eq!(
attrs.len(),
2,
"first call should add implicit + stack(location)"
);
apply_location(&mut attrs);
assert_eq!(
attrs.len(),
2,
"second call should not duplicate any attributes"
);
}
#[test]
fn test_apply_location_preserves_existing_implicit() {
let mut attrs: Vec<Attribute> = vec![parse_quote!(#[snafu(implicit)])];
apply_location(&mut attrs);
assert_eq!(attrs.len(), 2, "should add only #[stack(location)]");
let implicit_count = attrs.iter().filter(|a| a.path().is_ident("snafu")).count();
assert_eq!(implicit_count, 1, "should not duplicate #[snafu(implicit)]");
}
#[test]
fn test_apply_location_preserves_existing_stack_location() {
let mut attrs: Vec<Attribute> = vec![parse_quote!(#[stack(location)])];
apply_location(&mut attrs);
assert_eq!(attrs.len(), 2, "should add only #[snafu(implicit)]");
let stack_count = attrs.iter().filter(|a| a.path().is_ident("stack")).count();
assert_eq!(stack_count, 1, "should not duplicate #[stack(location)]");
}
}