errorstack_derive/lib.rs
1use heck::ToSnakeCase;
2use proc_macro::TokenStream;
3use proc_macro2::TokenStream as TokenStream2;
4use quote::quote;
5use syn::{Data, DeriveInput, Field, Fields, Ident};
6
7/// Derive macro for [`ErrorStack`].
8///
9/// Supports enums and structs with named fields. Note that the type must
10/// also implement [`Display`](std::fmt::Display) and
11/// [`Error`](std::error::Error). This can be accomplished manually or via
12/// [`thiserror`](https://crates.io/crates/thiserror).
13///
14/// This macro implements [`ErrorStack`] according to field names and
15/// attributes, and generates an ergonomic constructor for each struct or
16/// enum variant that captures caller location via `#[track_caller]` and
17/// composes naturally with [`Result::map_err`] for error chaining.
18///
19/// # Attributes
20///
21/// The following field attributes are available:
22///
23/// | Attribute | Effect | Auto-detected |
24/// |-------------------|---------------------------------------------------------------------------|---------------|
25/// | `#[source]` | Marks a field as the error source. | when field is named `source` |
26/// | `#[stack_source]` | Marks the field as both the error source and an [`ErrorStack`] implementor, enabling typed chain walking via `ErrorStack::stack_source`. Implies `#[source]`. | no |
27/// | `#[location]` | Indicates the field stores a `&'static Location<'static>`, captured automatically at construction time. | no |
28///
29/// These attributes follow the same field conventions as
30/// [`thiserror`](https://crates.io/crates/thiserror), allowing
31/// both crates to be ergonomically used together.
32///
33/// # Stack sources
34///
35/// Any source field that implements [`ErrorStack`] should be annotated with
36/// `#[stack_source]` to preserve the typed error chain. The macro cannot
37/// inspect trait implementations, so without this annotation the source is
38/// treated as a plain [`std::error::Error`] and chain walking stops at that
39/// field.
40///
41/// # Error constructors
42///
43/// This macro also generates helper constructors for each struct or enum
44/// variant. Every constructor is marked `#[track_caller]`, so the
45/// call-site location is recorded without manual boilerplate. When a
/// source field is present, the constructor returns
47/// `impl FnOnce(SourceTy) -> Self`, so it can be passed directly to
48/// [`Result::map_err`] without an intermediate closure.
49///
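/// Constructors are `pub(crate)` and named `new` for structs or
/// `snake_cased_variant` for enum variants. Fields that are neither the
/// source nor the `#[location]` field become parameters. The source field
/// (marked `#[source]` or `#[stack_source]`, or simply named `source`) is
/// supplied later as the argument of the returned closure. The
/// `#[location]` field is filled automatically.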
54///
55/// # Examples
56///
57/// The macro may be derived on enums and structs with named fields. This
58/// example shows both, with `thiserror` compatibility.
59///
60/// ```
61/// # use errorstack::ErrorStack;
62/// #[derive(thiserror::Error, ErrorStack, Debug)]
63/// pub enum AppError {
64/// #[error("io failed: {path}")]
65/// Io {
66/// path: String,
67/// source: std::io::Error,
68/// #[location]
69/// location: &'static std::panic::Location<'static>,
70/// },
71///
72/// #[error("inner failed")]
73/// Inner {
74/// #[stack_source]
75/// source: ConfigError,
76/// #[location]
77/// location: &'static std::panic::Location<'static>,
78/// },
79///
80/// #[error("not found: {id}")]
81/// NotFound {
82/// id: String,
83/// #[location]
84/// location: &'static std::panic::Location<'static>,
85/// },
86/// }
87///
88/// #[derive(thiserror::Error, ErrorStack, Debug)]
89/// #[error("config: {detail}")]
90/// pub struct ConfigError {
91/// detail: String,
92/// #[location]
93/// location: &'static std::panic::Location<'static>,
94/// }
95/// ```
96///
97/// The derive above generates the following constructors:
98///
99/// ```text
100/// // AppError: one constructor per variant
101/// impl AppError {
102/// // Source variants return a closure for use with map_err.
103/// pub(crate) fn io(path: String) -> impl FnOnce(io::Error) -> Self;
104/// pub(crate) fn inner() -> impl FnOnce(ConfigError) -> Self;
105/// // Sourceless variants return Self directly.
106/// pub(crate) fn not_found(id: String) -> Self;
107/// }
108///
109/// // ConfigError: struct constructor is named `new`
110/// impl ConfigError {
111/// pub(crate) fn new(detail: String) -> Self;
112/// }
113/// ```
114///
115/// Source and location fields are handled automatically by these
116/// constructors, keeping call sites concise:
117///
118/// ```
119/// # use errorstack::ErrorStack;
120/// # #[derive(thiserror::Error, ErrorStack, Debug)]
121/// # pub enum AppError {
122/// # #[error("io failed: {path}")]
123/// # Io {
124/// # path: String,
125/// # source: std::io::Error,
126/// # #[location]
127/// # location: &'static std::panic::Location<'static>,
128/// # },
129/// # #[error("not found: {id}")]
130/// # NotFound {
131/// # id: String,
132/// # #[location]
133/// # location: &'static std::panic::Location<'static>,
134/// # },
135/// # }
136/// # fn main() -> Result<(), AppError> {
137/// let _content = std::fs::read_to_string("Cargo.toml")
138/// .map_err(AppError::io("Cargo.toml".into()))?;
139///
140/// let _err = AppError::not_found("abc".into());
141/// # Ok(())
142/// # }
143/// ```
144#[proc_macro_derive(ErrorStack, attributes(source, stack_source, location))]
145pub fn derive_error_stack(input: TokenStream) -> TokenStream {
146 let input = syn::parse_macro_input!(input as DeriveInput);
147 match derive_impl(input) {
148 Ok(tokens) => tokens.into(),
149 Err(err) => err.to_compile_error().into(),
150 }
151}
152
153fn derive_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
154 let name = &input.ident;
155
156 match &input.data {
157 Data::Enum(data) => {
158 let mut constructor_methods = Vec::new();
159 let mut location_arms = Vec::new();
160 let mut stack_source_arms = Vec::new();
161
162 for variant in &data.variants {
163 let variant_name = &variant.ident;
164 let fields = match &variant.fields {
165 Fields::Named(f) => f,
166 Fields::Unnamed(_) => {
167 return Err(syn::Error::new(
168 variant_name.span(),
169 format!(
170 "ErrorStack derive requires named (struct) variants; \
171 found tuple variant `{variant_name}`"
172 ),
173 ));
174 }
175 Fields::Unit => {
176 return Err(syn::Error::new(
177 variant_name.span(),
178 format!(
179 "ErrorStack derive requires named (struct) variants; \
180 found unit variant `{variant_name}`"
181 ),
182 ));
183 }
184 };
185
186 let parsed = parse_fields(&fields.named, variant_name)?;
187
188 constructor_methods.push(gen_constructor_enum(variant_name, &parsed));
189 location_arms.push(gen_location_arm_enum(variant_name, &parsed));
190 stack_source_arms.push(gen_stack_source_arm_enum(variant_name, &parsed));
191 }
192
193 Ok(quote! {
194 impl #name {
195 #(#constructor_methods)*
196 }
197
198 impl ::errorstack::ErrorStack for #name {
199 fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
200 match self {
201 #(#location_arms)*
202 }
203 }
204
205 fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
206 match self {
207 #(#stack_source_arms)*
208 }
209 }
210 }
211 })
212 }
213
214 Data::Struct(data) => {
215 let fields = match &data.fields {
216 Fields::Named(f) => f,
217 _ => {
218 return Err(syn::Error::new(
219 name.span(),
220 "ErrorStack derive requires named fields",
221 ));
222 }
223 };
224
225 let parsed = parse_fields(&fields.named, name)?;
226 let constructor = gen_constructor_struct(name, &parsed);
227
228 let location_body = if let Some(loc) = &parsed.location {
229 let loc_ident = &loc.ident;
230 quote! { Some(self.#loc_ident) }
231 } else {
232 quote! { None }
233 };
234
235 let stack_source_body = if parsed.stack_source {
236 let src = parsed.source.as_ref().unwrap();
237 let src_ident = &src.ident;
238 quote! { Some(&self.#src_ident as &dyn ::errorstack::ErrorStack) }
239 } else {
240 quote! { None }
241 };
242
243 Ok(quote! {
244 impl #name {
245 #constructor
246 }
247
248 impl ::errorstack::ErrorStack for #name {
249 fn location(&self) -> Option<&'static ::std::panic::Location<'static>> {
250 #location_body
251 }
252
253 fn stack_source(&self) -> Option<&dyn ::errorstack::ErrorStack> {
254 #stack_source_body
255 }
256 }
257 })
258 }
259
260 Data::Union(_) => Err(syn::Error::new(
261 name.span(),
262 "ErrorStack derive is not supported on unions",
263 )),
264 }
265}
266
267struct ParsedFields<'a> {
268 source: Option<&'a Field>,
269 location: Option<&'a Field>,
270 stack_source: bool,
271 user_fields: Vec<&'a Field>,
272}
273
274fn attr(field: &Field, name: &str) -> bool {
275 field.attrs.iter().any(|a| a.path().is_ident(name))
276}
277
278fn parse_fields<'a>(
279 fields: &'a syn::punctuated::Punctuated<Field, syn::Token![,]>,
280 context_name: &Ident,
281) -> syn::Result<ParsedFields<'a>> {
282 let mut source: Option<&Field> = None;
283 let mut location: Option<&Field> = None;
284 let mut stack_source = false;
285 let mut user_fields = Vec::new();
286
287 for field in fields {
288 let ident = field.ident.as_ref().unwrap();
289 let source_by_name = ident == "source";
290 let source_by_attr = attr(field, "source");
291 let location_attr = attr(field, "location");
292 let stack_source_attr = attr(field, "stack_source");
293
294 if source_by_name || source_by_attr || stack_source_attr {
295 if source.is_some() {
296 return Err(syn::Error::new(
297 ident.span(),
298 format!("variant `{context_name}` has multiple source fields"),
299 ));
300 }
301 source = Some(field);
302 if stack_source_attr {
303 stack_source = true;
304 }
305 } else if location_attr {
306 if location.is_some() {
307 return Err(syn::Error::new(
308 ident.span(),
309 format!("variant `{context_name}` has multiple location fields"),
310 ));
311 }
312 location = Some(field);
313 } else {
314 user_fields.push(field);
315 }
316 }
317
318 Ok(ParsedFields {
319 source,
320 location,
321 stack_source,
322 user_fields,
323 })
324}
325
326fn gen_constructor_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
327 let method_name = Ident::new(
328 &variant_name.to_string().to_snake_case(),
329 variant_name.span(),
330 );
331
332 let user_params: Vec<_> = parsed
333 .user_fields
334 .iter()
335 .map(|f| {
336 let ident = &f.ident;
337 let ty = &f.ty;
338 quote! { #ident: #ty }
339 })
340 .collect();
341
342 let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
343
344 let location_init = parsed.location.as_ref().map(|f| {
345 let ident = &f.ident;
346 quote! { #ident: location, }
347 });
348
349 let location_capture = parsed.location.as_ref().map(|_| {
350 quote! { let location = ::std::panic::Location::caller(); }
351 });
352
353 let doc = format!("Constructs a [`{variant_name}`](Self::{variant_name}) error.");
354
355 if let Some(src) = &parsed.source {
356 let src_ident = &src.ident;
357 let src_ty = &src.ty;
358 quote! {
359 #[doc = #doc]
360 #[track_caller]
361 pub(crate) fn #method_name(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
362 #location_capture
363 move |#src_ident| Self::#variant_name {
364 #src_ident,
365 #(#user_field_names,)*
366 #location_init
367 }
368 }
369 }
370 } else {
371 quote! {
372 #[doc = #doc]
373 #[track_caller]
374 pub(crate) fn #method_name(#(#user_params),*) -> Self {
375 #location_capture
376 Self::#variant_name {
377 #(#user_field_names,)*
378 #location_init
379 }
380 }
381 }
382 }
383}
384
385fn gen_constructor_struct(type_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
386 let user_params: Vec<_> = parsed
387 .user_fields
388 .iter()
389 .map(|f| {
390 let ident = &f.ident;
391 let ty = &f.ty;
392 quote! { #ident: #ty }
393 })
394 .collect();
395
396 let user_field_names: Vec<_> = parsed.user_fields.iter().map(|f| &f.ident).collect();
397
398 let location_init = parsed.location.as_ref().map(|f| {
399 let ident = &f.ident;
400 quote! { #ident: location, }
401 });
402
403 let location_capture = parsed.location.as_ref().map(|_| {
404 quote! { let location = ::std::panic::Location::caller(); }
405 });
406
407 let doc = format!("Constructs a [`{type_name}`].");
408
409 if let Some(src) = &parsed.source {
410 let src_ident = &src.ident;
411 let src_ty = &src.ty;
412 quote! {
413 #[doc = #doc]
414 #[track_caller]
415 pub(crate) fn new(#(#user_params),*) -> impl ::std::ops::FnOnce(#src_ty) -> Self {
416 #location_capture
417 move |#src_ident| Self {
418 #src_ident,
419 #(#user_field_names,)*
420 #location_init
421 }
422 }
423 }
424 } else {
425 quote! {
426 #[doc = #doc]
427 #[track_caller]
428 pub(crate) fn new(#(#user_params),*) -> Self {
429 #location_capture
430 Self {
431 #(#user_field_names,)*
432 #location_init
433 }
434 }
435 }
436 }
437}
438
439fn gen_location_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
440 if let Some(loc) = &parsed.location {
441 let loc_ident = &loc.ident;
442 quote! {
443 Self::#variant_name { #loc_ident, .. } => Some(#loc_ident),
444 }
445 } else {
446 quote! {
447 Self::#variant_name { .. } => None,
448 }
449 }
450}
451
452fn gen_stack_source_arm_enum(variant_name: &Ident, parsed: &ParsedFields<'_>) -> TokenStream2 {
453 if parsed.stack_source {
454 let src_ident = &parsed.source.unwrap().ident;
455 quote! {
456 Self::#variant_name { #src_ident, .. } => Some(#src_ident as &dyn ::errorstack::ErrorStack),
457 }
458 } else {
459 quote! {
460 Self::#variant_name { .. } => None,
461 }
462 }
463}