Skip to main content

errorstack_derive/
lib.rs

1use heck::ToSnakeCase;
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{Data, DeriveInput, Field, Fields, Ident};
6
7/// Derive macro for [`ErrorStack`].
8///
9/// Supports enums and structs with named fields. Note that the type must
10/// also implement [`Display`](std::fmt::Display) and
11/// [`Error`](std::error::Error). This can be accomplished manually or via
12/// [`thiserror`](https://crates.io/crates/thiserror).
13///
14/// This macro implements [`ErrorStack`] according to field names and
15/// attributes, and generates an ergonomic constructor for each struct or
16/// enum variant that captures caller location via `#[track_caller]` and
17/// composes naturally with [`Result::map_err`] for error chaining.
18///
19/// # Attributes
20///
21/// The following field attributes are available:
22///
23/// | Attribute         | Effect                                                                    | Auto-detected |
24/// |-------------------|---------------------------------------------------------------------------|---------------|
25/// | `#[source]`       | Marks a field as the error source.                                        | when field is named `source` |
26/// | `#[stack_source]` | Marks the field as both the error source and an [`ErrorStack`] implementor, enabling typed chain walking via `ErrorStack::stack_source`. Implies `#[source]`. | no |
27/// | `#[location]`     | Indicates the field stores a `&'static Location<'static>`, captured automatically at construction time. | no |
28///
29/// These attributes follow the same field conventions as
30/// [`thiserror`](https://crates.io/crates/thiserror), allowing
31/// both crates to be ergonomically used together.
32///
33/// # Stack sources
34///
35/// Any source field that implements [`ErrorStack`] should be annotated with
36/// `#[stack_source]` to preserve the typed error chain. The macro cannot
37/// inspect trait implementations, so without this annotation the source is
38/// treated as a plain [`std::error::Error`] and chain walking stops at that
39/// field.
40///
41/// # Optional sources
42///
43/// A source field may be wrapped in [`Option`] to represent errors that
44/// do not always have an underlying cause. When the macro detects an
45/// `Option<T>` source field, it generates two constructors instead of
46/// one:
47///
48/// | Constructor | Signature | Source value |
49/// |-------------|-----------|-------------|
50/// | `variant_name` / `new` | `(user_fields…) -> Self` | `None` |
51/// | `variant_name_with` / `new_with` | `(user_fields…) -> impl FnOnce(T) -> Self` | `Some(source)` |
52///
53/// # Error constructors
54///
55/// This macro also generates helper constructors for each struct or enum
56/// variant. Every constructor is marked `#[track_caller]`, so the
57/// call-site location is recorded without manual boilerplate. When a
58/// source field is present the constructor returns
59/// `impl FnOnce(SourceTy) -> Self`, so it can be passed directly to
60/// [`Result::map_err`] without an intermediate closure.
61///
62/// Constructors are `pub(crate)` and named `new` for structs or
63/// `snake_cased_variant` for enum variants. Remaining fields
64/// become parameters, while `#[source]` and `#[location]` fields are filled
65/// automatically.
66///
67/// # Examples
68///
69/// The macro may be derived on enums and structs with named fields. This
70/// example shows both, with `thiserror` compatibility.
71///
72/// ```
73/// # use errorstack::ErrorStack;
74/// #[derive(thiserror::Error, ErrorStack, Debug)]
75/// pub enum AppError {
76///     #[error("io failed: {path}")]
77///     Io {
78///         path: String,
79///         source: std::io::Error,
80///         #[location]
81///         location: &'static std::panic::Location<'static>,
82///     },
83///
84///     #[error("inner failed")]
85///     Inner {
86///         #[stack_source]
87///         source: ConfigError,
88///         #[location]
89///         location: &'static std::panic::Location<'static>,
90///     },
91///
92///     #[error("not found: {id}")]
93///     NotFound {
94///         id: String,
95///         #[location]
96///         location: &'static std::panic::Location<'static>,
97///     },
98/// }
99///
100/// #[derive(thiserror::Error, ErrorStack, Debug)]
101/// #[error("config: {detail}")]
102/// pub struct ConfigError {
103///     detail: String,
104///     #[location]
105///     location: &'static std::panic::Location<'static>,
106/// }
107/// ```
108///
109/// The derive above generates the following constructors:
110///
111/// ```text
112/// impl AppError {
113///     // Source variants return a closure for use with map_err.
114///     pub(crate) fn io(path: String) -> impl FnOnce(io::Error) -> Self;
115///     pub(crate) fn inner() -> impl FnOnce(ConfigError) -> Self;
116///     // Sourceless variants return Self directly.
117///     pub(crate) fn not_found(id: String) -> Self;
118/// }
119///
120/// impl ConfigError {
121///     pub(crate) fn new(detail: String) -> Self;
122/// }
123/// ```
124///
125/// Source and location fields are handled automatically by these
126/// constructors, keeping call sites concise:
127///
128/// ```
129/// # use errorstack::ErrorStack;
130/// # #[derive(thiserror::Error, ErrorStack, Debug)]
131/// # pub enum AppError {
132/// #     #[error("io failed: {path}")]
133/// #     Io {
134/// #         path: String,
135/// #         source: std::io::Error,
136/// #         #[location]
137/// #         location: &'static std::panic::Location<'static>,
138/// #     },
139/// #     #[error("not found: {id}")]
140/// #     NotFound {
141/// #         id: String,
142/// #         #[location]
143/// #         location: &'static std::panic::Location<'static>,
144/// #     },
145/// # }
146/// # fn main() -> Result<(), AppError> {
147/// let _content = std::fs::read_to_string("Cargo.toml")
148///     .map_err(AppError::io("Cargo.toml".into()))?;
149///
150/// let _err = AppError::not_found("abc".into());
151/// # Ok(())
152/// # }
153/// ```
154#[proc_macro_derive(ErrorStack, attributes(source, stack_source, location))]
155pub fn derive_error_stack(input: TokenStream) -> TokenStream {
156    let input = syn::parse_macro_input!(input as DeriveInput);
157    match derive_impl(input) {
158        Ok(tokens) => tokens.into(),
159        Err(err) => err.to_compile_error().into(),
160    }
161}
162
163fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
164    let name = &input.ident;
165
166    match &input.data {
167        Data::Enum(data) => {
168            let mut constructor_methods = Vec::new();
169            let mut location_arms = Vec::new();
170            let mut stack_source_arms = Vec::new();
171
172            for variant in &data.variants {
173                let variant_name = &variant.ident;
174                let fields = match &variant.fields {
175                    Fields::Named(f) => f,
176                    Fields::Unnamed(_) => {
177                        return Err(syn::Error::new(
178                            variant_name.span(),
179                            format!(
180                                "ErrorStack derive requires named (struct) variants; \
181                                 found tuple variant `{variant_name}`"
182                            ),
183                        ));
184                    }
185                    Fields::Unit => {
186                        return Err(syn::Error::new(
187                            variant_name.span(),
188                            format!(
189                                "ErrorStack derive requires named (struct) variants; \
190                                 found unit variant `{variant_name}`"
191                            ),
192                        ));
193                    }
194                };
195
196                let parsed = parse_fields(&fields.named, variant_name)?;
197
198                constructor_methods.push(gen_constructor_enum(variant_name, &parsed));
199                location_arms.push(gen_location_arm_enum(variant_name, &parsed));
200                stack_source_arms.push(gen_stack_source_arm_enum(variant_name, &parsed));
201            }
202
203            Ok(quote! {
204                impl #name {
205                    #(#constructor_methods)*
206                }
207
208                impl ::errorstack::ErrorStack for #name {
209                    fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
210                        match self {
211                            #(#location_arms)*
212                        }
213                    }
214
215                    fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
216                        match self {
217                            #(#stack_source_arms)*
218                        }
219                    }
220                }
221            })
222        }
223
224        Data::Struct(data) => {
225            let fields = match &data.fields {
226                Fields::Named(f) => f,
227                _ => {
228                    return Err(syn::Error::new(
229                        name.span(),
230                        "ErrorStack derive requires named fields",
231                    ));
232                }
233            };
234
235            let parsed = parse_fields(&fields.named, name)?;
236            let constructor = gen_constructor_struct(name, &parsed);
237
238            let location_body = if let Some(loc) = &parsed.location {
239                let loc_ident = &loc.ident;
240                quote! { Some(self.#loc_ident) }
241            } else {
242                quote! { None }
243            };
244
245            let stack_source_body = if parsed.stack_source {
246                let src = parsed.source.as_ref().unwrap();
247                let src_ident = &src.ident;
248                if parsed.optional_source {
249                    quote! { self.#src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack) }
250                } else {
251                    quote! { Some(&self.#src_ident as &dyn ::errorstack::ErrorStack) }
252                }
253            } else {
254                quote! { None }
255            };
256
257            Ok(quote! {
258                impl #name {
259                    #constructor
260                }
261
262                impl ::errorstack::ErrorStack for #name {
263                    fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
264                        #location_body
265                    }
266
267                    fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
268                        #stack_source_body
269                    }
270                }
271            })
272        }
273
274        Data::Union(_) => Err(syn::Error::new(
275            name.span(),
276            "ErrorStack derive is not supported on unions",
277        )),
278    }
279}
280
281struct ParsedFields<'a> {
282    source: Option<&'a Field>,
283    location: Option<&'a Field>,
284    stack_source: bool,
285    optional_source: bool,
286    /// The inner type `T` when source is `Option<T>`.
287    inner_source_ty: Option<syn::Type>,
288    user_fields: Vec<&'a Field>,
289}
290
291fn attr(field: &Field, name: &str) -> bool {
292    field.attrs.iter().any(|a| a.path().is_ident(name))
293}
294
295/// If `ty` is `Option<T>`, returns the inner type `T`.
296fn extract_option_inner(ty: &syn::Type) -> Option<&syn::Type> {
297    let syn::Type::Path(type_path) = ty else {
298        return None;
299    };
300    let segment = type_path.path.segments.last()?;
301    if segment.ident != "Option" {
302        return None;
303    }
304    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
305        return None;
306    };
307    if args.args.len() != 1 {
308        return None;
309    }
310    let syn::GenericArgument::Type(inner) = args.args.first()? else {
311        return None;
312    };
313    Some(inner)
314}
315
316fn parse_fields<'a>(
317    fields: &'a syn::punctuated::Punctuated<Field, syn::Token![,]>,
318    context_name: &Ident,
319) -> syn::Result<ParsedFields<'a>> {
320    let mut source: Option<&Field> = None;
321    let mut location: Option<&Field> = None;
322    let mut stack_source = false;
323    let mut optional_source = false;
324    let mut inner_source_ty = None;
325    let mut user_fields = Vec::new();
326
327    for field in fields {
328        let ident = field.ident.as_ref().unwrap();
329        let source_by_name = ident == "source";
330        let source_by_attr = attr(field, "source");
331        let location_attr = attr(field, "location");
332        let stack_source_attr = attr(field, "stack_source");
333
334        if source_by_name || source_by_attr || stack_source_attr {
335            if source.is_some() {
336                return Err(syn::Error::new(
337                    ident.span(),
338                    format!("variant `{context_name}` has multiple source fields"),
339                ));
340            }
341            source = Some(field);
342            if stack_source_attr {
343                stack_source = true;
344            }
345            if let Some(inner) = extract_option_inner(&field.ty) {
346                optional_source = true;
347                inner_source_ty = Some(inner.clone());
348            }
349        } else if location_attr {
350            if location.is_some() {
351                return Err(syn::Error::new(
352                    ident.span(),
353                    format!("variant `{context_name}` has multiple location fields"),
354                ));
355            }
356            location = Some(field);
357        } else {
358            user_fields.push(field);
359        }
360    }
361
362    Ok(ParsedFields {
363        source,
364        location,
365        stack_source,
366        optional_source,
367        inner_source_ty,
368        user_fields,
369    })
370}
371
372/// Names and self-expression that vary between enum variants and structs.
373struct ConstructorCtx {
374    method_name: Ident,
375    with_method_name: Ident,
376    doc: String,
377    doc_with: String,
378    /// Token stream for constructing the type, e.g. `Self::Variant` or `Self`.
379    self_path: TokenStream2,
380}
381
382fn gen_constructor(ctx: &ConstructorCtx, parsed: &ParsedFields<'_>) -> TokenStream2 {
383    let ConstructorCtx {
384        method_name,
385        with_method_name,
386        doc,
387        doc_with,
388        self_path,
389    } = ctx;
390
391    let user_params: Vec<_> = parsed
392        .user_fields
393        .iter()
394        .map(|f| {
395            let ident = &f.ident;
396            let ty = &f.ty;
397            quote! { #ident: #ty }
398        })
399        .collect();
400
401    let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
402
403    let location_init = parsed.location.as_ref().map(|f| {
404        let ident = &f.ident;
405        quote! { #ident: location, }
406    });
407
408    let location_capture = parsed.location.as_ref().map(|_| {
409        quote! { let location = ::std::panic::Location::caller(); }
410    });
411
412    if let Some(src) = &parsed.source {
413        let src_ident = &src.ident;
414
415        if parsed.optional_source {
416            let inner_ty = parsed.inner_source_ty.as_ref().unwrap();
417            quote! {
418                #[doc = #doc]
419                #[track_caller]
420                pub(crate) fn #method_name(#(#user_params),*) -> Self {
421                    #location_capture
422                    #self_path {
423                        #src_ident: None,
424                        #(#user_field_names,)*
425                        #location_init
426                    }
427                }
428
429                #[doc = #doc_with]
430                #[track_caller]
431                pub(crate) fn #with_method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#inner_ty) -> Self {
432                    #location_capture
433                    move |#src_ident| #self_path {
434                        #src_ident: Some(#src_ident),
435                        #(#user_field_names,)*
436                        #location_init
437                    }
438                }
439            }
440        } else {
441            let src_ty = &src.ty;
442            quote! {
443                #[doc = #doc]
444                #[track_caller]
445                pub(crate) fn #method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
446                    #location_capture
447                    move |#src_ident| #self_path {
448                        #src_ident,
449                        #(#user_field_names,)*
450                        #location_init
451                    }
452                }
453            }
454        }
455    } else {
456        quote! {
457            #[doc = #doc]
458            #[track_caller]
459            pub(crate) fn #method_name(#(#user_params),*) -> Self {
460                #location_capture
461                #self_path {
462                    #(#user_field_names,)*
463                    #location_init
464                }
465            }
466        }
467    }
468}
469
470fn gen_constructor_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
471    let snake = variant_name.to_string().to_snake_case();
472    let ctx = ConstructorCtx {
473        method_name: Ident::new(&snake, variant_name.span()),
474        with_method_name: Ident::new(&format!("{snake}_with"), variant_name.span()),
475        doc: format!("Constructs a [`{variant_name}`](Self::{variant_name}) error."),
476        doc_with: format!(
477            "Constructs a [`{variant_name}`](Self::{variant_name}) error with a source."
478        ),
479        self_path: quote! { Self::#variant_name },
480    };
481    gen_constructor(&ctx, parsed)
482}
483
484fn gen_constructor_struct(type_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
485    let ctx = ConstructorCtx {
486        method_name: Ident::new("new", type_name.span()),
487        with_method_name: Ident::new("new_with", type_name.span()),
488        doc: format!("Constructs a [`{type_name}`]."),
489        doc_with: format!("Constructs a [`{type_name}`] with a source."),
490        self_path: quote! { Self },
491    };
492    gen_constructor(&ctx, parsed)
493}
494
495fn gen_location_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
496    if let Some(loc) = &parsed.location {
497        let loc_ident = &loc.ident;
498        quote! {
499            Self::#variant_name { #loc_ident, .. } => Some(#loc_ident),
500        }
501    } else {
502        quote! {
503            Self::#variant_name { .. } => None,
504        }
505    }
506}
507
508fn gen_stack_source_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
509    if parsed.stack_source {
510        let src_ident = &parsed.source.unwrap().ident;
511        if parsed.optional_source {
512            quote! {
513                Self::#variant_name { #src_ident, .. } => #src_ident.as_ref().map(|s| s as &dyn ::errorstack::ErrorStack),
514            }
515        } else {
516            quote! {
517                Self::#variant_name { #src_ident, .. } => Some(#src_ident as &dyn ::errorstack::ErrorStack),
518            }
519        }
520    } else {
521        quote! {
522            Self::#variant_name { .. } => None,
523        }
524    }
525}