Skip to main content

errorstack_derive/
lib.rs

1use heck::ToSnakeCase;
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{Data, DeriveInput, Field, Fields, Ident};
6
7/// Derive macro for [`ErrorStack`].
8///
9/// Supports enums and structs with named fields. Note that the type must
10/// also implement [`Display`](std::fmt::Display) and
11/// [`Error`](std::error::Error). This can be accomplished manually or via
12/// [`thiserror`](https://crates.io/crates/thiserror).
13///
14/// This macro implements [`ErrorStack`] according to field names and
15/// attributes, and generates an ergonomic constructor for each struct or
16/// enum variant that captures caller location via `#[track_caller]` and
17/// composes naturally with [`Result::map_err`] for error chaining.
18///
19/// # Attributes
20///
21/// The following field attributes are available:
22///
23/// | Attribute         | Effect                                                                    | Auto-detected |
24/// |-------------------|---------------------------------------------------------------------------|---------------|
25/// | `#[source]`       | Marks a field as the error source.                                        | when field is named `source` |
26/// | `#[stack_source]` | Marks the field as both the error source and an [`ErrorStack`] implementor, enabling typed chain walking via `ErrorStack::stack_source`. Implies `#[source]`. | no |
27/// | `#[location]`     | Indicates the field stores a `&'static Location<'static>`, captured automatically at construction time. | no |
28///
29/// These attributes follow the same field conventions as
30/// [`thiserror`](https://crates.io/crates/thiserror), allowing
31/// both crates to be ergonomically used together.
32///
33/// # Stack sources
34///
35/// Any source field that implements [`ErrorStack`] should be annotated with
36/// `#[stack_source]` to preserve the typed error chain. The macro cannot
37/// inspect trait implementations, so without this annotation the source is
38/// treated as a plain [`std::error::Error`] and chain walking stops at that
39/// field.
40///
41/// # Error constructors
42///
43/// This macro also generates helper constructors for each struct or enum
44/// variant. Every constructor is marked `#[track_caller]`, so the
45/// call-site location is recorded without manual boilerplate. When a
46/// source field is present the constructor returns
47/// `impl FnOnce(SourceTy) -> Self`, so it can be passed directly to
48/// [`Result::map_err`] without an intermediate closure.
49///
50/// Constructors are `pub(crate)` and named `new` for structs or
51/// `snake_cased_variant` for enum variants. Remaining fields
52/// become parameters, while `#[source]` and `#[location]` fields are filled
53/// automatically.
54///
55/// # Examples
56///
57/// The macro may be derived on enums and structs with named fields. This
58/// example shows both, with `thiserror` compatibility.
59///
60/// ```
61/// # use errorstack::ErrorStack;
62/// #[derive(thiserror::Error, ErrorStack, Debug)]
63/// pub enum AppError {
64///     #[error("io failed: {path}")]
65///     Io {
66///         path: String,
67///         source: std::io::Error,
68///         #[location]
69///         location: &'static std::panic::Location<'static>,
70///     },
71///
72///     #[error("inner failed")]
73///     Inner {
74///         #[stack_source]
75///         source: ConfigError,
76///         #[location]
77///         location: &'static std::panic::Location<'static>,
78///     },
79///
80///     #[error("not found: {id}")]
81///     NotFound {
82///         id: String,
83///         #[location]
84///         location: &'static std::panic::Location<'static>,
85///     },
86/// }
87///
88/// #[derive(thiserror::Error, ErrorStack, Debug)]
89/// #[error("config: {detail}")]
90/// pub struct ConfigError {
91///     detail: String,
92///     #[location]
93///     location: &'static std::panic::Location<'static>,
94/// }
95/// ```
96///
97/// The derive above generates the following constructors:
98///
99/// ```text
100/// // AppError: one constructor per variant
101/// impl AppError {
102///     // Source variants return a closure for use with map_err.
103///     pub(crate) fn io(path: String) -> impl FnOnce(io::Error) -> Self;
104///     pub(crate) fn inner() -> impl FnOnce(ConfigError) -> Self;
105///     // Sourceless variants return Self directly.
106///     pub(crate) fn not_found(id: String) -> Self;
107/// }
108///
109/// // ConfigError: struct constructor is named `new`
110/// impl ConfigError {
111///     pub(crate) fn new(detail: String) -> Self;
112/// }
113/// ```
114///
115/// Source and location fields are handled automatically by these
116/// constructors, keeping call sites concise:
117///
118/// ```
119/// # use errorstack::ErrorStack;
120/// # #[derive(thiserror::Error, ErrorStack, Debug)]
121/// # pub enum AppError {
122/// #     #[error("io failed: {path}")]
123/// #     Io {
124/// #         path: String,
125/// #         source: std::io::Error,
126/// #         #[location]
127/// #         location: &'static std::panic::Location<'static>,
128/// #     },
129/// #     #[error("not found: {id}")]
130/// #     NotFound {
131/// #         id: String,
132/// #         #[location]
133/// #         location: &'static std::panic::Location<'static>,
134/// #     },
135/// # }
136/// # fn main() -> Result<(), AppError> {
137/// let _content = std::fs::read_to_string("Cargo.toml")
138///     .map_err(AppError::io("Cargo.toml".into()))?;
139///
140/// let _err = AppError::not_found("abc".into());
141/// # Ok(())
142/// # }
143/// ```
144#[proc_macro_derive(ErrorStack, attributes(source, stack_source, location))]
145pub fn derive_error_stack(input: TokenStream) -> TokenStream {
146    let input = syn::parse_macro_input!(input as DeriveInput);
147    match derive_impl(input) {
148        Ok(tokens) => tokens.into(),
149        Err(err) => err.to_compile_error().into(),
150    }
151}
152
153fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
154    let name = &input.ident;
155
156    match &input.data {
157        Data::Enum(data) => {
158            let mut constructor_methods = Vec::new();
159            let mut location_arms = Vec::new();
160            let mut stack_source_arms = Vec::new();
161
162            for variant in &data.variants {
163                let variant_name = &variant.ident;
164                let fields = match &variant.fields {
165                    Fields::Named(f) => f,
166                    Fields::Unnamed(_) => {
167                        return Err(syn::Error::new(
168                            variant_name.span(),
169                            format!(
170                                "ErrorStack derive requires named (struct) variants; \
171                                 found tuple variant `{variant_name}`"
172                            ),
173                        ));
174                    }
175                    Fields::Unit => {
176                        return Err(syn::Error::new(
177                            variant_name.span(),
178                            format!(
179                                "ErrorStack derive requires named (struct) variants; \
180                                 found unit variant `{variant_name}`"
181                            ),
182                        ));
183                    }
184                };
185
186                let parsed = parse_fields(&fields.named, variant_name)?;
187
188                constructor_methods.push(gen_constructor_enum(variant_name, &parsed));
189                location_arms.push(gen_location_arm_enum(variant_name, &parsed));
190                stack_source_arms.push(gen_stack_source_arm_enum(variant_name, &parsed));
191            }
192
193            Ok(quote! {
194                impl #name {
195                    #(#constructor_methods)*
196                }
197
198                impl ::errorstack::ErrorStack for #name {
199                    fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
200                        match self {
201                            #(#location_arms)*
202                        }
203                    }
204
205                    fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
206                        match self {
207                            #(#stack_source_arms)*
208                        }
209                    }
210                }
211            })
212        }
213
214        Data::Struct(data) => {
215            let fields = match &data.fields {
216                Fields::Named(f) => f,
217                _ => {
218                    return Err(syn::Error::new(
219                        name.span(),
220                        "ErrorStack derive requires named fields",
221                    ));
222                }
223            };
224
225            let parsed = parse_fields(&fields.named, name)?;
226            let constructor = gen_constructor_struct(name, &parsed);
227
228            let location_body = if let Some(loc) = &parsed.location {
229                let loc_ident = &loc.ident;
230                quote! { Some(self.#loc_ident) }
231            } else {
232                quote! { None }
233            };
234
235            let stack_source_body = if parsed.stack_source {
236                let src = parsed.source.as_ref().unwrap();
237                let src_ident = &src.ident;
238                quote! { Some(&self.#src_ident as &dyn ::errorstack::ErrorStack) }
239            } else {
240                quote! { None }
241            };
242
243            Ok(quote! {
244                impl #name {
245                    #constructor
246                }
247
248                impl ::errorstack::ErrorStack for #name {
249                    fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
250                        #location_body
251                    }
252
253                    fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
254                        #stack_source_body
255                    }
256                }
257            })
258        }
259
260        Data::Union(_) => Err(syn::Error::new(
261            name.span(),
262            "ErrorStack derive is not supported on unions",
263        )),
264    }
265}
266
267struct ParsedFields<'a> {
268    source: Option<&'a Field>,
269    location: Option<&'a Field>,
270    stack_source: bool,
271    user_fields: Vec<&'a Field>,
272}
273
274fn attr(field: &Field, name: &str) -> bool {
275    field.attrs.iter().any(|a| a.path().is_ident(name))
276}
277
278fn parse_fields<'a>(
279    fields: &'a syn::punctuated::Punctuated<Field, syn::Token![,]>,
280    context_name: &Ident,
281) -> syn::Result<ParsedFields<'a>> {
282    let mut source: Option<&Field> = None;
283    let mut location: Option<&Field> = None;
284    let mut stack_source = false;
285    let mut user_fields = Vec::new();
286
287    for field in fields {
288        let ident = field.ident.as_ref().unwrap();
289        let source_by_name = ident == "source";
290        let source_by_attr = attr(field, "source");
291        let location_attr = attr(field, "location");
292        let stack_source_attr = attr(field, "stack_source");
293
294        if source_by_name || source_by_attr || stack_source_attr {
295            if source.is_some() {
296                return Err(syn::Error::new(
297                    ident.span(),
298                    format!("variant `{context_name}` has multiple source fields"),
299                ));
300            }
301            source = Some(field);
302            if stack_source_attr {
303                stack_source = true;
304            }
305        } else if location_attr {
306            if location.is_some() {
307                return Err(syn::Error::new(
308                    ident.span(),
309                    format!("variant `{context_name}` has multiple location fields"),
310                ));
311            }
312            location = Some(field);
313        } else {
314            user_fields.push(field);
315        }
316    }
317
318    Ok(ParsedFields {
319        source,
320        location,
321        stack_source,
322        user_fields,
323    })
324}
325
326fn gen_constructor_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
327    let method_name = Ident::new(
328        &variant_name.to_string().to_snake_case(),
329        variant_name.span(),
330    );
331
332    let user_params: Vec<_> = parsed
333        .user_fields
334        .iter()
335        .map(|f| {
336            let ident = &f.ident;
337            let ty = &f.ty;
338            quote! { #ident: #ty }
339        })
340        .collect();
341
342    let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
343
344    let location_init = parsed.location.as_ref().map(|f| {
345        let ident = &f.ident;
346        quote! { #ident: location, }
347    });
348
349    let location_capture = parsed.location.as_ref().map(|_| {
350        quote! { let location = ::std::panic::Location::caller(); }
351    });
352
353    let doc = format!("Constructs a [`{variant_name}`](Self::{variant_name}) error.");
354
355    if let Some(src) = &parsed.source {
356        let src_ident = &src.ident;
357        let src_ty = &src.ty;
358        quote! {
359            #[doc = #doc]
360            #[track_caller]
361            pub(crate) fn #method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
362                #location_capture
363                move |#src_ident| Self::#variant_name {
364                    #src_ident,
365                    #(#user_field_names,)*
366                    #location_init
367                }
368            }
369        }
370    } else {
371        quote! {
372            #[doc = #doc]
373            #[track_caller]
374            pub(crate) fn #method_name(#(#user_params),*) -> Self {
375                #location_capture
376                Self::#variant_name {
377                    #(#user_field_names,)*
378                    #location_init
379                }
380            }
381        }
382    }
383}
384
385fn gen_constructor_struct(type_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
386    let user_params: Vec<_> = parsed
387        .user_fields
388        .iter()
389        .map(|f| {
390            let ident = &f.ident;
391            let ty = &f.ty;
392            quote! { #ident: #ty }
393        })
394        .collect();
395
396    let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
397
398    let location_init = parsed.location.as_ref().map(|f| {
399        let ident = &f.ident;
400        quote! { #ident: location, }
401    });
402
403    let location_capture = parsed.location.as_ref().map(|_| {
404        quote! { let location = ::std::panic::Location::caller(); }
405    });
406
407    let doc = format!("Constructs a [`{type_name}`].");
408
409    if let Some(src) = &parsed.source {
410        let src_ident = &src.ident;
411        let src_ty = &src.ty;
412        quote! {
413            #[doc = #doc]
414            #[track_caller]
415            pub(crate) fn new(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
416                #location_capture
417                move |#src_ident| Self {
418                    #src_ident,
419                    #(#user_field_names,)*
420                    #location_init
421                }
422            }
423        }
424    } else {
425        quote! {
426            #[doc = #doc]
427            #[track_caller]
428            pub(crate) fn new(#(#user_params),*) -> Self {
429                #location_capture
430                Self {
431                    #(#user_field_names,)*
432                    #location_init
433                }
434            }
435        }
436    }
437}
438
439fn gen_location_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
440    if let Some(loc) = &parsed.location {
441        let loc_ident = &loc.ident;
442        quote! {
443            Self::#variant_name { #loc_ident, .. } => Some(#loc_ident),
444        }
445    } else {
446        quote! {
447            Self::#variant_name { .. } => None,
448        }
449    }
450}
451
452fn gen_stack_source_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
453    if parsed.stack_source {
454        let src_ident = &parsed.source.unwrap().ident;
455        quote! {
456            Self::#variant_name { #src_ident, .. } => Some(#src_ident as &dyn ::errorstack::ErrorStack),
457        }
458    } else {
459        quote! {
460            Self::#variant_name { .. } => None,
461        }
462    }
463}