err_as_you_go/
lib.rs

1//! Generate `enum` error types inline.
2//!
3//! If you want:
//! - to easily throw errors inline, like with [anyhow]
5//! - to make your error types handleable in a nice enum like [thiserror]
6//!
7//! then this is the crate for you!
8//!
9//! ```
10//! use err_as_you_go::err_as_you_go;
11//!
12//! #[err_as_you_go]
13//! fn shave_yaks(
14//!     num_yaks: usize,
15//!     empty_buckets: usize,
16//!     num_razors: usize,
17//! ) -> Result<(), ShaveYaksError> {
18//!     if num_razors == 0 {
19//!         return Err(err!(NotEnoughRazors));
20//!     }
21//!     if num_yaks > empty_buckets {
22//!         return Err(err!(NotEnoughBuckets {
23//!             got: usize = empty_buckets,
24//!             required: usize = num_yaks,
25//!         }));
26//!     }
27//!     Ok(())
28//! }
29//! ```
//! Under the hood, an enum like this is generated:
31//! ```
32//! enum ShaveYaksError { // name and visibility are taken from function return type and visibility
33//!     NotEnoughRazors,
34//!     NotEnoughBuckets {
35//!         got: usize,
36//!         required: usize,
37//!     }
38//! }
39//! ```
40//!
//! Importantly, you can derive traits on the generated enum _and_ pass attributes through to its variants, allowing you to use crates like [thiserror].
42//! ```
43//! # use err_as_you_go::err_as_you_go;
44//!
45//! #[err_as_you_go(derive(Debug, thiserror::Error))]
46//! fn shave_yaks(
47//!     num_yaks: usize,
48//!     empty_buckets: usize,
49//!     num_razors: usize,
50//! ) -> Result<(), ShaveYaksError> {
51//!     if num_razors == 0 {
52//!         return Err(err!(
53//!             #[error("not enough razors!")]
54//!             NotEnoughRazors
55//!         ));
56//!     }
57//!     if num_yaks > empty_buckets {
58//!         return Err(err!(
59//!             #[error("not enough buckets - needed {required}")]
60//!             NotEnoughBuckets {
61//!                 got: usize = empty_buckets,
62//!                 required: usize = num_yaks,
63//!             }
64//!         ));
65//!     }
66//!     Ok(())
67//! }
68//! ```
69//!
70//! Which generates the following:
71//! ```
72//! #[derive(Debug, thiserror::Error)]
73//! enum ShaveYaksError {
74//!     #[error("not enough razors!")]
75//!     NotEnoughRazors,
76//!     #[error("not enough buckets - needed {required}")]
77//!     NotEnoughBuckets {
78//!         got: usize,
79//!         required: usize,
80//!     }
81//! }
82//! ```
//! And `err!` macro invocations are replaced with constructions of the matching enum variant - anywhere in the function body (except inside other macro invocations, whose tokens are not parsed)!
84//!
85//! If you need to reuse the same variant within a function, just use the normal construction syntax:
86//! ```
87//! # use err_as_you_go::err_as_you_go;
88//! # use std::io;
89//! # fn fallible_op() -> Result<(), io::Error> { todo!() }
90//! #[err_as_you_go]
91//! fn foo() -> Result<(), FooError> {
92//!     fallible_op().map_err(|e| err!(IoError(io::Error = e)));
93//!     Err(FooError::IoError(todo!()))
94//! }
95//! ```
96//!
97//! [anyhow]: https://docs.rs/anyhow
98//! [thiserror]: https://docs.rs/thiserror
99
100use config::Config;
101use data::VariantWithValue;
102use log::debug;
103use proc_macro2::{Ident, Span, TokenStream};
104use proc_macro_error::{emit_error, proc_macro_error};
105use quote::{quote, ToTokens};
106use syn::{
107    parse2, parse_macro_input, visit_mut::VisitMut, AngleBracketedGenericArguments,
108    GenericArgument, ItemFn, Path, PathArguments, PathSegment, ReturnType, TypePath,
109};
110
111mod config;
112mod data;
113
114/// See [module documentation](index.html) for general usage.
115///
116/// # `err!` construction
117/// Instances of `err!` will be parsed like so:
118/// ```
119/// # #[err_as_you_go::err_as_you_go]
120/// # fn foo() -> Result<(), FooError> {
121/// err!(Unity);                        // A unit enum variant
122/// err!(Tuply(usize = 1, char = 'a')); // A tuple enum variant
123/// err!(Structy {                      // A struct enum variant
124///         u: usize = 1,
125///         c: char = 'a',
126/// });
127/// # Ok(())
128/// # }
129/// ```
130/// # Arguments
131/// `derive` arguments are passed through to the generated struct.
132/// ```
133/// # use err_as_you_go::err_as_you_go;
134/// #[err_as_you_go(derive(Debug, Clone, Copy))]
135/// # fn foo() -> Result<(), FooError> { Ok(()) }
136/// ```
137///
138/// `attributes` arguments are passed through to the top of the generated struct
139/// ```
140/// # use err_as_you_go::err_as_you_go;
141/// #[err_as_you_go(attributes(
142///     #[must_use = "maybe you missed something!"]
143///     #[repr(u8)]
144/// ))]
145/// # fn foo() -> Result<(), FooError> { Ok(()) }
146/// ```
147/// `visibility` can be used to override the generated struct's visibility.
148/// ```
149/// # use err_as_you_go::err_as_you_go;
150/// #[err_as_you_go(visibility(pub))]
151/// # fn foo() -> Result<(), FooError> { Ok(()) }
152/// ```
153#[proc_macro_attribute]
154#[proc_macro_error]
155pub fn err_as_you_go(
156    attr: proc_macro::TokenStream,
157    item: proc_macro::TokenStream,
158) -> proc_macro::TokenStream {
159    pretty_env_logger::try_init_custom_env("RUST_LOG_ERR_AS_YOU_GO").ok();
160
161    debug!("attr={attr:?}");
162
163    //////////////////////
164    // Parse our inputs //
165    //////////////////////
166    let config = parse_macro_input!(attr as Config);
167    let mut item = parse_macro_input!(item as ItemFn);
168
169    debug!("config={config:?}");
170
171    let Some(error_name) = get_struct_name_from_return_type(&item.sig.output) else {
172        emit_error!(
173            item.sig,
174            "unsupported return type - function must return a `Result<_, SomeConcreteErr>`"
175        );
176        return quote!(#item).into();
177    };
178    let error_vis = config.visibility.unwrap_or_else(|| item.vis.clone());
179
180    let mut visitor = ErrAsYouGoVisitor::new(error_name.clone());
181    visitor.visit_item_fn_mut(&mut item);
182
183    for (src, e) in visitor.collection_errors {
184        emit_error!(src, "{}", e)
185    }
186
187    let variants = visitor.variants;
188    let derives = match config.derives {
189        Some(derives) => quote!(#[derive(
190            #(#derives),*
191        )]),
192        None => quote!(),
193    };
194
195    quote! {
196        #derives
197        #error_vis enum #error_name {
198            #(#variants),*
199        }
200
201        #item
202    }
203    .into()
204}
205
206fn get_struct_name_from_return_type(return_type: &ReturnType) -> Option<Ident> {
207    if let ReturnType::Type(_, ty) = return_type {
208        if let syn::Type::Path(TypePath {
209            qself: None,
210            path: Path { ref segments, .. },
211        }) = **ty
212        {
213            if let Some(PathSegment {
214                ident,
215                arguments:
216                    PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }),
217            }) = segments.last()
218            {
219                if ident == "Result" && args.len() == 2 {
220                    if let Some(GenericArgument::Type(syn::Type::Path(TypePath {
221                        qself: None,
222                        path:
223                            Path {
224                                segments,
225                                leading_colon: None,
226                            },
227                    }))) = args.into_iter().nth(1)
228                    {
229                        if segments.len() == 1 {
230                            let PathSegment { ident, arguments } = &segments[0];
231                            if arguments.is_empty() {
232                                return Some(ident.clone());
233                            }
234                        }
235                    }
236                }
237            }
238        }
239    }
240    None
241}
242
243/// Implementation detail
244// Allows use to swap the macro in-place in our visitor.
245#[doc(hidden)]
246#[proc_macro]
247#[proc_macro_error]
248pub fn __nothing(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
249    input
250}
251
252struct ErrAsYouGoVisitor {
253    error_name: Ident,
254    variants: Vec<syn::Variant>,
255    collection_errors: Vec<(TokenStream, syn::Error)>,
256}
257
258impl ErrAsYouGoVisitor {
259    fn new(error_name: Ident) -> Self {
260        Self {
261            error_name,
262            variants: Vec::new(),
263            collection_errors: Vec::new(),
264        }
265    }
266}
267
268impl syn::visit_mut::VisitMut for ErrAsYouGoVisitor {
269    fn visit_macro_mut(&mut self, i: &mut syn::Macro) {
270        if i.path.is_ident("err") {
271            match parse2::<VariantWithValue>(i.tokens.clone()) {
272                Ok(variant_with_value) => {
273                    self.variants
274                        .push(variant_with_value.clone().into_syn_variant());
275                    i.path = path(["err_as_you_go", "__nothing"]);
276                    i.tokens = variant_with_value
277                        .into_syn_expr_with_prefix(Path::from(self.error_name.clone()))
278                        .into_token_stream();
279                }
280                Err(e) => self.collection_errors.push((i.tokens.clone(), e)),
281            }
282        }
283    }
284}
285
286fn path<'a>(segments: impl IntoIterator<Item = &'a str>) -> Path {
287    syn::Path {
288        leading_colon: None,
289        segments: segments
290            .into_iter()
291            .map(|segment| PathSegment::from(ident(segment)))
292            .collect(),
293    }
294}
295
296fn ident(s: &str) -> Ident {
297    Ident::new(s, Span::call_site())
298}
299
#[cfg(test)]
mod test_utils {
    /// Parses `tokens` as a `T` and asserts the result equals `expected`.
    ///
    /// # Panics
    /// If `tokens` do not parse as a `T`, or if the parsed value differs from `expected`.
    pub fn test_parse<T>(tokens: proc_macro2::TokenStream, expected: T)
    where
        T: syn::parse::Parse + PartialEq + std::fmt::Debug,
    {
        let parsed: T = syn::parse2(tokens).expect("couldn't parse tokens");
        pretty_assertions::assert_eq!(expected, parsed);
    }
}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile tests: everything under `trybuild/pass` must build and everything under
    /// `trybuild/fail` must fail to build.
    #[test]
    fn trybuild() {
        let t = trybuild::TestCases::new();
        t.pass("trybuild/pass/**/*.rs");
        t.compile_fail("trybuild/fail/**/*.rs");
    }

    /// `README.md` must match what `cargo readme` generates from the crate docs.
    #[test]
    fn readme() {
        let output = std::process::Command::new("cargo")
            .arg("readme")
            .output()
            .expect("couldn't run `cargo readme`");
        // If `cargo-readme` isn't installed (or errors), `cargo` still runs but exits non-zero
        // with nothing useful on stdout; fail loudly here instead of with a bogus README diff.
        assert!(
            output.status.success(),
            "`cargo readme` failed ({}): {}",
            output.status,
            String::from_utf8_lossy(&output.stderr)
        );
        let expected = String::from_utf8_lossy(&output.stdout);
        let actual = std::fs::read("README.md").expect("couldn't read README.md");
        let actual = String::from_utf8_lossy(&actual);
        pretty_assertions::assert_eq!(expected, actual);
    }

    /// The enum name is taken from the `Err` slot of `Result`, bare or fully qualified.
    #[test]
    fn get_result_name() {
        let ident = get_struct_name_from_return_type(
            &syn::parse2(quote!(-> Result<T, SomeConcreteErr>)).unwrap(),
        )
        .unwrap();
        assert_eq!(ident, "SomeConcreteErr");

        let ident = get_struct_name_from_return_type(
            &syn::parse2(quote!(-> ::std::result::Result<T, SomeConcreteErr>)).unwrap(),
        )
        .unwrap();
        assert_eq!(ident, "SomeConcreteErr");
    }

    /// Return types that don't name a single bare concrete error type are rejected.
    #[test]
    fn get_result_name_rejects_unsupported() {
        for return_type in [
            quote!(),
            quote!(-> ()),
            quote!(-> Option<SomeConcreteErr>),
            quote!(-> Result<T>),
            quote!(-> Result<T, std::io::Error>),
            quote!(-> Result<T, ::SomeConcreteErr>),
            quote!(-> Result<T, SomeConcreteErr<U>>),
        ] {
            let parsed: ReturnType = syn::parse2(return_type.clone()).unwrap();
            assert_eq!(
                get_struct_name_from_return_type(&parsed),
                None,
                "{return_type}"
            );
        }
    }
}