Skip to main content

errorstack_derive/
lib.rs

1use heck::ToSnakeCase;
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{Data, DeriveInput, Field, Fields, Ident};
6
7/// Derive macro for [`ErrorStack`].
8///
9/// Supports enums and structs with named fields. Note that the type must
10/// also implement [`Display`](std::fmt::Display) and
11/// [`Error`](std::error::Error). This can be accomplished manually or via
12/// [`thiserror`](https://crates.io/crates/thiserror).
13///
14/// This macro implements [`ErrorStack`] according to field names and
15/// attributes, and generates an ergonomic constructor for each struct or
16/// enum variant that captures caller location via `#[track_caller]` and
17/// composes naturally with [`Result::map_err`] for error chaining.
18///
19/// # Attributes
20///
21/// The following field attributes are available:
22///
23/// | Attribute         | Effect                                                                    | Auto-detected |
24/// |-------------------|---------------------------------------------------------------------------|---------------|
25/// | `#[source]`       | Marks a field as the error source.                                        | when field is named `source` |
26/// | `#[stack_source]` | Marks the field as both the error source and an [`ErrorStack`] implementor, enabling typed chain walking via `ErrorStack::stack_source`. Implies `#[source]`. | no |
27/// | `#[location]`     | Indicates the field stores a `&'static Location<'static>`, captured automatically at construction time. | no |
28///
29/// These attributes follow the same field conventions as
30/// [`thiserror`](https://crates.io/crates/thiserror), allowing
31/// both crates to be ergonomically used together.
32///
33/// # Stack sources
34///
35/// Any source field that implements [`ErrorStack`] should be annotated with
36/// `#[stack_source]` to preserve the typed error chain. The macro cannot
37/// inspect trait implementations, so without this annotation the source is
38/// treated as a plain [`std::error::Error`] and chain walking stops at that
39/// field.
40///
41/// # Optional sources
42///
43/// A source field may be wrapped in [`Option`] to represent errors that
44/// do not always have an underlying cause. When the macro detects an
45/// `Option<T>` source field, it generates two constructors instead of
46/// one:
47///
48/// | Constructor | Signature | Source value |
49/// |-------------|-----------|-------------|
50/// | `variant_name` / `new` | `(user_fields…) -> Self` | `None` |
51/// | `variant_name_with` / `new_with` | `(user_fields…) -> impl FnOnce(T) -> Self` | `Some(source)` |
52///
53/// # Error constructors
54///
55/// This macro also generates helper constructors for each struct or enum
56/// variant. Every constructor is marked `#[track_caller]`, so the
57/// call-site location is recorded without manual boilerplate. When a
58/// source field is present alongside user fields, the constructor returns
59/// `impl FnOnce(SourceTy) -> Self`, capturing the user fields and
60/// composing directly with [`Result::map_err`]. When the source is the
61/// only field, the constructor takes it as a parameter and returns `Self`
62/// directly.
63///
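/// Constructors are `pub(crate)` and named `new` for structs or
/// `snake_cased_variant` for enum variants (plus `new_with` /
/// `snake_cased_variant_with` for optional sources). Fields that are neither
/// the source nor the location become parameters in declaration order;
/// `#[location]` fields are always filled automatically, and the source is
/// supplied as described above.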
68///
69/// # Examples
70///
71/// The macro may be derived on enums and structs with named fields. This
72/// example shows both, with `thiserror` compatibility.
73///
74/// ```
75/// # use errorstack::ErrorStack;
76/// #[derive(thiserror::Error, ErrorStack, Debug)]
77/// pub enum AppError {
78///     #[error("io failed: {path}")]
79///     Io {
80///         path: String,
81///         source: std::io::Error,
82///         #[location]
83///         location: &'static std::panic::Location<'static>,
84///     },
85///
86///     #[error("inner failed")]
87///     Inner {
88///         #[stack_source]
89///         source: ConfigError,
90///         #[location]
91///         location: &'static std::panic::Location<'static>,
92///     },
93///
94///     #[error("not found: {id}")]
95///     NotFound {
96///         id: String,
97///         #[location]
98///         location: &'static std::panic::Location<'static>,
99///     },
100/// }
101///
102/// #[derive(thiserror::Error, ErrorStack, Debug)]
103/// #[error("config: {detail}")]
104/// pub struct ConfigError {
105///     detail: String,
106///     #[location]
107///     location: &'static std::panic::Location<'static>,
108/// }
109/// ```
110///
111/// The derive above generates the following constructors:
112///
113/// ```text
114/// impl AppError {
115///     // Source variants return a closure for use with map_err.
116///     pub(crate) fn io(path: String) -> impl FnOnce(io::Error) -> Self;
117///     pub(crate) fn inner(source: ConfigError) -> Self;
118///     // Sourceless variants return Self directly.
119///     pub(crate) fn not_found(id: String) -> Self;
120/// }
121///
122/// impl ConfigError {
123///     pub(crate) fn new(detail: String) -> Self;
124/// }
125/// ```
126///
127/// Source and location fields are handled automatically by these
128/// constructors, keeping call sites concise:
129///
130/// ```
131/// # use errorstack::ErrorStack;
132/// # #[derive(thiserror::Error, ErrorStack, Debug)]
133/// # pub enum AppError {
134/// #     #[error("io failed: {path}")]
135/// #     Io {
136/// #         path: String,
137/// #         source: std::io::Error,
138/// #         #[location]
139/// #         location: &'static std::panic::Location<'static>,
140/// #     },
141/// #     #[error("not found: {id}")]
142/// #     NotFound {
143/// #         id: String,
144/// #         #[location]
145/// #         location: &'static std::panic::Location<'static>,
146/// #     },
147/// # }
148/// # fn main() -> Result<(), AppError> {
149/// let _content = std::fs::read_to_string("Cargo.toml")
150///     .map_err(AppError::io("Cargo.toml".into()))?;
151///
152/// let _err = AppError::not_found("abc".into());
153/// # Ok(())
154/// # }
155/// ```
156#[proc_macro_derive(ErrorStack, attributes(source, stack_source, location))]
157pub fn derive_error_stack(input: TokenStream) -> TokenStream {
158    let input = syn::parse_macro_input!(input as DeriveInput);
159    match derive_impl(input) {
160        Ok(tokens) => tokens.into(),
161        Err(err) => err.to_compile_error().into(),
162    }
163}
164
165fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
166    let name = &input.ident;
167
168    match &input.data {
169        Data::Enum(data) => {
170            let mut constructor_methods = Vec::new();
171            let mut location_arms = Vec::new();
172            let mut stack_source_arms = Vec::new();
173
174            for variant in &data.variants {
175                let variant_name = &variant.ident;
176                let fields = match &variant.fields {
177                    Fields::Named(f) => f,
178                    Fields::Unnamed(_) => {
179                        return Err(syn::Error::new(
180                            variant_name.span(),
181                            format!(
182                                "ErrorStack derive requires named (struct) variants; \
183                                 found tuple variant `{variant_name}`"
184                            ),
185                        ));
186                    }
187                    Fields::Unit => {
188                        return Err(syn::Error::new(
189                            variant_name.span(),
190                            format!(
191                                "ErrorStack derive requires named (struct) variants; \
192                                 found unit variant `{variant_name}`"
193                            ),
194                        ));
195                    }
196                };
197
198                let parsed = parse_fields(&fields.named, variant_name)?;
199
200                constructor_methods.push(gen_constructor_enum(variant_name, &parsed));
201                location_arms.push(gen_location_arm_enum(variant_name, &parsed));
202                stack_source_arms.push(gen_stack_source_arm_enum(variant_name, &parsed));
203            }
204
205            Ok(quote! {
206                impl #name {
207                    #(#constructor_methods)*
208                }
209
210                impl ::errorstack::ErrorStack for #name {
211                    fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
212                        match self {
213                            #(#location_arms)*
214                        }
215                    }
216
217                    fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
218                        match self {
219                            #(#stack_source_arms)*
220                        }
221                    }
222                }
223            })
224        }
225
226        Data::Struct(data) => {
227            let fields = match &data.fields {
228                Fields::Named(f) => f,
229                _ => {
230                    return Err(syn::Error::new(
231                        name.span(),
232                        "ErrorStack derive requires named fields",
233                    ));
234                }
235            };
236
237            let parsed = parse_fields(&fields.named, name)?;
238            let constructor = gen_constructor_struct(name, &parsed);
239
240            let location_body = if let Some(loc) = &parsed.location {
241                let loc_ident = &loc.ident;
242                quote! { Some(self.#loc_ident) }
243            } else {
244                quote! { None }
245            };
246
247            let stack_source_body = if parsed.stack_source {
248                let src = parsed.source.as_ref().unwrap();
249                let src_ident = &src.ident;
250                if parsed.optional_source {
251                    quote! { self.#src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack) }
252                } else {
253                    quote! { Some(&self.#src_ident as &dyn ::errorstack::ErrorStack) }
254                }
255            } else {
256                quote! { None }
257            };
258
259            Ok(quote! {
260                impl #name {
261                    #constructor
262                }
263
264                impl ::errorstack::ErrorStack for #name {
265                    fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
266                        #location_body
267                    }
268
269                    fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
270                        #stack_source_body
271                    }
272                }
273            })
274        }
275
276        Data::Union(_) => Err(syn::Error::new(
277            name.span(),
278            "ErrorStack derive is not supported on unions",
279        )),
280    }
281}
282
283struct ParsedFields<'a> {
284    source: Option<&'a Field>,
285    location: Option<&'a Field>,
286    stack_source: bool,
287    optional_source: bool,
288    /// The inner type `T` when source is `Option<T>`.
289    inner_source_ty: Option<syn::Type>,
290    user_fields: Vec<&'a Field>,
291}
292
293fn attr(field: &Field, name: &str) -> bool {
294    field.attrs.iter().any(|a| a.path().is_ident(name))
295}
296
297/// If `ty` is `Option<T>`, returns the inner type `T`.
298fn extract_option_inner(ty: &syn::Type) -> Option<&syn::Type> {
299    let syn::Type::Path(type_path) = ty else {
300        return None;
301    };
302    let segment = type_path.path.segments.last()?;
303    if segment.ident != "Option" {
304        return None;
305    }
306    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
307        return None;
308    };
309    if args.args.len() != 1 {
310        return None;
311    }
312    let syn::GenericArgument::Type(inner) = args.args.first()? else {
313        return None;
314    };
315    Some(inner)
316}
317
318fn parse_fields<'a>(
319    fields: &'a syn::punctuated::Punctuated<Field, syn::Token![,]>,
320    context_name: &Ident,
321) -> syn::Result<ParsedFields<'a>> {
322    let mut source: Option<&Field> = None;
323    let mut location: Option<&Field> = None;
324    let mut stack_source = false;
325    let mut optional_source = false;
326    let mut inner_source_ty = None;
327    let mut user_fields = Vec::new();
328
329    for field in fields {
330        let ident = field.ident.as_ref().unwrap();
331        let source_by_name = ident == "source";
332        let source_by_attr = attr(field, "source");
333        let location_attr = attr(field, "location");
334        let stack_source_attr = attr(field, "stack_source");
335
336        if source_by_name || source_by_attr || stack_source_attr {
337            if source.is_some() {
338                return Err(syn::Error::new(
339                    ident.span(),
340                    format!("variant `{context_name}` has multiple source fields"),
341                ));
342            }
343            source = Some(field);
344            if stack_source_attr {
345                stack_source = true;
346            }
347            if let Some(inner) = extract_option_inner(&field.ty) {
348                optional_source = true;
349                inner_source_ty = Some(inner.clone());
350            }
351        } else if location_attr {
352            if location.is_some() {
353                return Err(syn::Error::new(
354                    ident.span(),
355                    format!("variant `{context_name}` has multiple location fields"),
356                ));
357            }
358            location = Some(field);
359        } else {
360            user_fields.push(field);
361        }
362    }
363
364    Ok(ParsedFields {
365        source,
366        location,
367        stack_source,
368        optional_source,
369        inner_source_ty,
370        user_fields,
371    })
372}
373
374/// Names and self-expression that vary between enum variants and structs.
375struct ConstructorCtx {
376    method_name: Ident,
377    with_method_name: Ident,
378    doc: String,
379    doc_with: String,
380    /// Token stream for constructing the type, e.g. `Self::Variant` or `Self`.
381    self_path: TokenStream2,
382}
383
384fn gen_constructor(ctx: &ConstructorCtx, parsed: &ParsedFields<'_>) -> TokenStream2 {
385    let ConstructorCtx {
386        method_name,
387        with_method_name,
388        doc,
389        doc_with,
390        self_path,
391    } = ctx;
392
393    let user_params: Vec<_> = parsed
394        .user_fields
395        .iter()
396        .map(|f| {
397            let ident = &f.ident;
398            let ty = &f.ty;
399            quote! { #ident: #ty }
400        })
401        .collect();
402
403    let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
404
405    let location_init = parsed.location.as_ref().map(|f| {
406        let ident = &f.ident;
407        quote! { #ident: location, }
408    });
409
410    let location_capture = parsed.location.as_ref().map(|_| {
411        quote! { let location = ::std::panic::Location::caller(); }
412    });
413
414    if let Some(src) = &parsed.source {
415        let src_ident = &src.ident;
416
417        if parsed.optional_source {
418            let inner_ty = parsed.inner_source_ty.as_ref().unwrap();
419            quote! {
420                #[doc = #doc]
421                #[track_caller]
422                pub(crate) fn #method_name(#(#user_params),*) -> Self {
423                    #location_capture
424                    #self_path {
425                        #src_ident: None,
426                        #(#user_field_names,)*
427                        #location_init
428                    }
429                }
430
431                #[doc = #doc_with]
432                #[track_caller]
433                pub(crate) fn #with_method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#inner_ty) -> Self {
434                    #location_capture
435                    move |#src_ident| #self_path {
436                        #src_ident: Some(#src_ident),
437                        #(#user_field_names,)*
438                        #location_init
439                    }
440                }
441            }
442        } else if parsed.user_fields.is_empty() {
443            let src_ty = &src.ty;
444            quote! {
445                #[doc = #doc]
446                #[track_caller]
447                pub(crate) fn #method_name(#src_ident: #src_ty) -> Self {
448                    #location_capture
449                    #self_path {
450                        #src_ident,
451                        #location_init
452                    }
453                }
454            }
455        } else {
456            let src_ty = &src.ty;
457            quote! {
458                #[doc = #doc]
459                #[track_caller]
460                pub(crate) fn #method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
461                    #location_capture
462                    move |#src_ident| #self_path {
463                        #src_ident,
464                        #(#user_field_names,)*
465                        #location_init
466                    }
467                }
468            }
469        }
470    } else {
471        quote! {
472            #[doc = #doc]
473            #[track_caller]
474            pub(crate) fn #method_name(#(#user_params),*) -> Self {
475                #location_capture
476                #self_path {
477                    #(#user_field_names,)*
478                    #location_init
479                }
480            }
481        }
482    }
483}
484
485fn gen_constructor_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
486    let snake = variant_name.to_string().to_snake_case();
487    let ctx = ConstructorCtx {
488        method_name: Ident::new(&snake, variant_name.span()),
489        with_method_name: Ident::new(&format!("{snake}_with"), variant_name.span()),
490        doc: format!("Constructs a [`{variant_name}`](Self::{variant_name}) error."),
491        doc_with: format!(
492            "Constructs a [`{variant_name}`](Self::{variant_name}) error with a source."
493        ),
494        self_path: quote! { Self::#variant_name },
495    };
496    gen_constructor(&ctx, parsed)
497}
498
499fn gen_constructor_struct(type_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
500    let ctx = ConstructorCtx {
501        method_name: Ident::new("new", type_name.span()),
502        with_method_name: Ident::new("new_with", type_name.span()),
503        doc: format!("Constructs a [`{type_name}`]."),
504        doc_with: format!("Constructs a [`{type_name}`] with a source."),
505        self_path: quote! { Self },
506    };
507    gen_constructor(&ctx, parsed)
508}
509
510fn gen_location_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
511    if let Some(loc) = &parsed.location {
512        let loc_ident = &loc.ident;
513        quote! {
514            Self::#variant_name { #loc_ident, .. } => Some(#loc_ident),
515        }
516    } else {
517        quote! {
518            Self::#variant_name { .. } => None,
519        }
520    }
521}
522
523fn gen_stack_source_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
524    if parsed.stack_source {
525        let src_ident = &parsed.source.unwrap().ident;
526        if parsed.optional_source {
527            quote! {
528                Self::#variant_name { #src_ident, .. } => #src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack),
529            }
530        } else {
531            quote! {
532                Self::#variant_name { #src_ident, .. } => Some(#src_ident as &dyn ::errorstack::ErrorStack),
533            }
534        }
535    } else {
536        quote! {
537            Self::#variant_name { .. } => None,
538        }
539    }
540}