derive_error_kind/
lib.rs

1//! # derive-error-kind
2//!
3//! A Rust procedural macro for implementing the ErrorKind pattern that simplifies error classification and handling in complex applications.
4//!
5//! ## Motivation
6//!
7//! The ErrorKind pattern is a common technique in Rust for separating:
8//! - The **kind** of an error (represented by a simple enum)
9//! - The **details** of the error (contained in the error structure)
10//!
11//! This allows developers to handle errors more granularly without losing context.
12//!
13//! Rust's standard library uses this pattern in `std::io::ErrorKind`, and many other libraries have adopted it due to its flexibility. However, manually implementing this pattern can be repetitive and error-prone, especially in applications with multiple nested error types.
14//!
15//! This crate solves this problem by providing a derive macro that automates the implementation of the ErrorKind pattern.
16//!
17//! ## Overview
18//!
19//! The `ErrorKind` macro allows you to associate error types with a specific kind from an enum. This creates a clean and consistent way to categorize errors in your application, enabling more precise error handling.
20//!
21//! Key features:
22//! - Automatically implements a `.kind()` method that returns a categorized error type
23//! - Supports nested error types via the `transparent` attribute
24//! - Works with unit variants, named fields, and tuple variants
25//! - Enables transparent error propagation through error hierarchies
26//!
27//! ## Basic Usage
28//!
29//! First, define an enum for your error kinds:
30//!
31//! ```rust
32//! #[derive(Copy, Clone, Debug, Eq, PartialEq)]
33//! pub enum ErrorKind {
34//!     NotFound,
35//!     InvalidInput,
36//!     InternalError,
37//! }
38//! ```
39//!
40//! Then, use the `ErrorKind` derive macro on your error enums:
41//!
42//! ```rust
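//! use derive_error_kind::ErrorKind;
//! # #[derive(Copy, Clone, Debug, Eq, PartialEq)]
//! # pub enum ErrorKind {
//! #     NotFound,
//! #     InvalidInput,
//! #     InternalError,
//! # }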
44//!
45//! #[derive(Debug, ErrorKind)]
46//! #[error_kind(ErrorKind)]
47//! pub enum MyError {
48//!     #[error_kind(ErrorKind, NotFound)]
49//!     ResourceNotFound,
50//!
51//!     #[error_kind(ErrorKind, InvalidInput)]
52//!     BadRequest { details: String },
53//!
54//!     #[error_kind(ErrorKind, InternalError)]
55//!     ServerError(String),
56//! }
57//!
58//! // Now you can use the .kind() method
59//! let error = MyError::ResourceNotFound;
60//! assert_eq!(error.kind(), ErrorKind::NotFound);
61//! ```
62//!
63//! ## Attribute Reference
64//!
65//! - `#[error_kind(KindEnum)]`: Top-level attribute that specifies which enum to use for error kinds
66//! - `#[error_kind(KindEnum, Variant)]`: Variant-level attribute that specifies which variant of the kind enum to return
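//! - `#[error_kind(transparent)]`: Variant-level attribute for nested errors, indicating that the inner error's kind should be used. It is only valid on a tuple variant with exactly one field, and that field's type must itself provide a `kind()` method returning the same kind enum (for example, another type deriving `ErrorKind`)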
68//!
69//! ## Requirements
70//!
71//! - The macro can only be applied to enums
72//! - Each variant must have an `error_kind` attribute
73//! - The kind enum must be in scope and accessible
74
75use proc_macro::TokenStream;
76use quote::quote;
77use syn::{
78    parse_macro_input, punctuated::Punctuated, DeriveInput, Meta, MetaList, NestedMeta, Path,
79};
80
81/// Create a kind method for struct
82/// # Examples
83/// ```
84/// use derive_error_kind::ErrorKind;
85///#[derive(Copy, Clone, Debug, Eq, PartialEq)]
86/// enum ErrorType {
87///     A,
88///     B,
89///     C,
90/// }
91///
92/// #[derive(ErrorKind)]
93/// #[error_kind(ErrorType)]
94/// enum CacheError {
95///     #[error_kind(ErrorType, A)]
96///     Poisoned,
97///
98///     #[error_kind(ErrorType, B)]
99///     Missing,
100/// }
101///
102/// #[derive(ErrorKind)]
103/// #[error_kind(ErrorType)]
104/// enum ServiceError {
105///     #[error_kind(transparent)]
106///     Cache(CacheError),
107///
108///     #[error_kind(ErrorType, C)]
109///     Db,
110/// }
111///
112/// assert_eq!(ServiceError::Cache(CacheError::Missing).kind(), ErrorType::B);
113/// assert_eq!(ServiceError::Db.kind(), ErrorType::C);
114/// ```
115#[proc_macro_derive(ErrorKind, attributes(error_kind))]
116pub fn error_kind(input: TokenStream) -> TokenStream {
117    error_kind_macro(input)
118}
119
120fn error_kind_macro(input: TokenStream) -> TokenStream {
121    let input = parse_macro_input!(input as DeriveInput);
122    let kind_ty = get_kind_ty(&input);
123
124    let name = input.ident;
125    let variants = if let syn::Data::Enum(data) = input.data {
126        data.variants
127    } else {
128        panic!("ImplKind just can be used in enums");
129    };
130
131    let mut kind_variants = Vec::new();
132
133    for variant in variants.clone() {
134        let ident = variant.ident;
135        if let Some(attr) = variant
136            .attrs
137            .into_iter()
138            .find(|attr| attr.path.is_ident("error_kind"))
139        {
140            if let Ok(syn::Meta::List(meta)) = attr.parse_meta() {
141                if meta.nested.len() == 2 {
142                    if let (
143                        syn::NestedMeta::Meta(syn::Meta::Path(enum_ty)),
144                        syn::NestedMeta::Meta(syn::Meta::Path(variant)),
145                    ) = (&meta.nested[0], &meta.nested[1])
146                    {
147                        kind_variants.push((ident, enum_ty.clone(), Some(variant.clone())));
148                    } else {
149                        panic!("Invalid value for error_kind");
150                    }
151                } else if meta.nested.len() == 1 {
152                    for sub_meta in meta.nested {
153                        if let NestedMeta::Meta(Meta::Path(path)) = sub_meta {
154                            if path.is_ident("transparent") {
155                                kind_variants.push((ident.clone(), kind_ty.clone(), None));
156                            }
157                        } else {
158                            panic!("Invalid value for #[error_kind]");
159                        }
160                    }
161                } else {
162                    panic!("error_kind must have one two arguments");
163                }
164            } else {
165                panic!("Error parsing meta");
166            }
167        } else {
168            panic!("Enum variants must have the attribute `error_kind`");
169        }
170    }
171
172    let kind_enum = kind_variants
173        .first()
174        .expect("No variants in Enum")
175        .1
176        .clone();
177    let match_arms = kind_variants.into_iter().map(|(ident, enum_ty, variant)| {
178        let fields = &variants.iter().find(|v| v.ident == ident).unwrap().fields;
179        match fields {
180            syn::Fields::Unit => {
181                quote! {
182                    Self::#ident => #enum_ty::#variant,
183                }
184            }
185            syn::Fields::Named(_) => {
186                quote! {
187                    Self::#ident{..} => #enum_ty::#variant,
188                }
189            }
190            syn::Fields::Unnamed(_) => match variant {
191                Some(v) => quote! {
192                    Self::#ident(..) => #enum_ty::#v,
193                },
194                None => quote! {
195                    Self::#ident(inner) => inner.kind(),
196                },
197            },
198        }
199    });
200
201    let expanded = quote! {
202        impl #name {
203            pub fn kind(&self) -> #kind_enum {
204                match self {
205                    #(#match_arms)*
206                }
207            }
208        }
209    };
210
211    TokenStream::from(expanded)
212}
213
214fn get_kind_ty(input: &DeriveInput) -> Path {
215    let metas = find_attribute(input, "error_kind")
216        .expect("#[derive(ErrorKind)] requires error_kind attribute");
217    if let Some(&NestedMeta::Meta(Meta::Path(ref path))) = metas.iter().next() {
218        path.to_owned()
219    } else {
220        panic!("#[error_kind(KIND_IDENT)] attribute requires and identifier");
221    }
222}
223
224/// Get an attribute from the input.
225/// 
226/// Adapted from https://crates.io/crates/enum-kinds
227fn find_attribute(
228    definition: &DeriveInput,
229    name: &str,
230) -> Option<Punctuated<NestedMeta, syn::token::Comma>> {
231    for attr in definition.attrs.iter() {
232        match attr.parse_meta() {
233            Ok(Meta::List(MetaList {
234                ref path,
235                ref nested,
236                ..
237            })) if path.is_ident(name) => return Some(nested.clone()),
238            _ => continue,
239        }
240    }
241    None
242}