compound_error/
lib.rs

1extern crate proc_macro;
2
3mod util;
4
5use std::collections::HashMap;
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::parse_macro_input;
10use syn::Data;
11use syn::DeriveInput;
12use syn::Fields;
13use syn::Ident;
14use syn::Meta;
15use syn::NestedMeta;
16use syn::Path;
17use syn::Type;
18use util::attr_args;
19use util::error;
20use util::flag;
21
/// Unwraps a `Result`: on `Ok` the macro evaluates to the contained value, on
/// `Err` it binds the error to `$err` and returns `$ret` from the enclosing
/// function.
macro_rules! try_compile {
	($what:expr, | $err:ident | $ret:expr) => {{
		match $what {
			Ok(value) => value,
			Err($err) => return $ret,
		}
	}};
}
30
/// Evaluates `flag($args, $arg)` (from `util`) and yields its `Ok` value, which
/// callers use as a `bool`. On `Err`, returns a compile error from the enclosing
/// function, attached to the returned error value, stating that the attribute
/// `$arg` takes no arguments.
macro_rules! flag {
	($args:expr, $arg:expr) => {
		try_compile!(flag($args, $arg), |offending| {
			let message = format!("'{}' attribute takes no arguments!", $arg);
			error(offending, &message)
		})
	};
}
38
39#[derive(Debug, Clone, Eq, PartialEq, Hash)]
40enum PathOrLit {
41	Path(syn::Path),
42	Lit(syn::TypePath),
43}
44
45impl PathOrLit {
46	fn path(&self) -> syn::Path {
47		let mut path = {
48			match self {
49				Self::Path(path) => path,
50				Self::Lit(type_path) => &type_path.path,
51			}
52			.clone()
53		};
54
55		path.segments.last_mut().unwrap().arguments = syn::PathArguments::None;
56		path
57	}
58}
59
60impl quote::ToTokens for PathOrLit {
61	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
62		match self {
63			Self::Path(path) => path.to_tokens(tokens),
64			Self::Lit(ty) => ty.to_tokens(tokens),
65		}
66	}
67}
68
69/// Implement `CompoundError` functionality for the target type.
70///
71/// If the target is an enum, `From` is implemented for each variant.
72/// Additionally, variants can be annotated with
73/// `#[compound_error( inline_from(X) )]`
74/// to specify that an "inlining from `X`" should be implemented. In addition
75/// to the `From` impls, by default also `std::error::Error` and
76/// `std::fmt::Display` are implemented for the target type. If the target type
77/// is a struct, no `From` impls, but only `std::error::Error` and
78/// `std::fmt::Display` impls are generated.
79///
80/// The generation of the `Error` and `Display` impls can by suppressed by
81/// specifying `#[compound_error( skip_error )]` and
82/// `#[compound_error( skip_display )]` on the target type.
83///
84/// If the target type is an enum, all variants must take exactly one argument.
85/// By default, this argument must implement `std::error::Error`. This can be
86/// circumvented by either specifying the `skip_error` attribute on the target
87/// type or by specifying the `no_source` attribute on the respective variant.
88/// `no_source` causes `None` to be returned by the implementation of
89/// `std::error::Error::source()` on the target type for the respective enum
90/// variant.
91///
92/// # Attributes
93///
94/// Attributes are specified in the following form:
95///
96/// ```text
97/// #[compound_error( attr1, attr2, attr3, ... )]
98/// #[compound_error( attr4, attr5, ... )]
99/// <ELEMENT>
100/// ```
101///
102/// `<ELEMENT>` can be the target type or an enum variant. The following
103/// attributes are available:
104///
105/// On the target type:
106/// * `title = "<title>"`: Set the title of this error to `"<title>"`. This is
107///   relevant for the automatic `Display` implementation on the target type.
108/// * `description = "<description>"`: Set the description of this error to
109///   `"<description>"`. This is relevant for the automatic `Display`
110///   implementation on the target type.
111/// * `skip_display`: Skip the automatic implementation of `std::fmt::Display`
112///   on the target type.
113/// * `skip_error`: Skip the automatic implementation of `std::error::Error` on
114///   the target type.
115/// * `transparent`: forward the source and Display methods through to all
116///   underlying errors without adding an additional message.
117///
118/// On each enum variant:
119/// * `inline_from(A,B,C,...)`: Inline the Errors `A`, `B`, `C`, ... in the
120///   target type.
121/// * `no_source`: Return `None` from `<Self as std::error::Error>::source()`
122///   for this enum variant. This lifts the requirement that `std::error::Error`
123///   is implemented for the argument of this variant.
124/// * `convert_source(fn)`: Applies `fn` to the error of this variant before
125///   returing it from `<Self as std::error::Error>::source()`
126/// * `transparent`: forward the source and Display methods through to the
127///   argument of this variant without adding an additional message.
128///
129#[proc_macro_derive(CompoundError, attributes(compound_error))]
130pub fn derive_compound_error(input: TokenStream) -> TokenStream {
131	let input = parse_macro_input!(input as DeriveInput);
132	let original_input = input.clone();
133	let ident = input.ident.clone();
134	let generics = input.generics;
135	let (generics_impl, generics_type, generics_where) = generics.split_for_impl();
136
137	let mut toplevel_args = try_compile!(
138		attr_args(
139			&input.attrs,
140			"compound_error",
141			&[
142				"title",
143				"description",
144				"skip_display",
145				"skip_error",
146				"transparent"
147			]
148		),
149		|err| err.explain()
150	);
151
152	let title_attr = toplevel_args.remove(&"title");
153	let title = {
154		if let Some(attr) = title_attr {
155			if attr.values.len() != 1 {
156				return error(&attr.path, "'title' takes exactly one string argument!");
157			}
158			match &attr.values[0] {
159				NestedMeta::Lit(syn::Lit::Str(lit)) => lit.value(),
160				_ => return error(&attr.path, "'title' argument must be a string!"),
161			}
162		} else {
163			ident.to_string()
164		}
165	};
166
167	let description_attr = toplevel_args.remove(&"description");
168	let description = {
169		if let Some(attr) = description_attr {
170			if attr.values.len() != 1 {
171				return error(
172					&attr.path,
173					"'description' takes exactly one string argument!",
174				);
175			}
176			match &attr.values[0] {
177				NestedMeta::Lit(syn::Lit::Str(lit)) => {
178					format!(" ({})", lit.value())
179				},
180				_ => return error(&attr.path, "'description' argument must be a string!"),
181			}
182		} else {
183			"".into()
184		}
185	};
186
187	let skip_display = flag!(&toplevel_args, &"skip_display");
188	let skip_error = flag!(&toplevel_args, &"skip_error");
189
190	#[allow(unused_assignments)]
191	let mut err_source = proc_macro2::TokenStream::new();
192	let mut from_enums: HashMap<PathOrLit, Vec<Ident>> = HashMap::new();
193	let mut from_structs: Vec<(Path, Ident)> = Vec::new();
194
195	#[allow(unused_assignments)]
196	let mut display = proc_macro2::TokenStream::new();
197
198	match input.data {
199		Data::Enum(data) => {
200			let transparent_enum = flag!(&toplevel_args, &"transparent");
201
202			let mut err_sources = proc_macro2::TokenStream::new();
203
204			let mut display_cases = Vec::new();
205
206			for variant in data.variants {
207				let variant_ident = variant.ident;
208				let variant_ident_str = variant_ident.to_string();
209				let field = {
210					match variant.fields {
211						Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
212							fields.unnamed[0].clone()
213						},
214						_ => {
215							return error(
216								&original_input,
217								&format!(
218									"Variant '{}' must specify exactly one unnamed field!",
219									variant_ident
220								),
221							)
222						},
223					}
224				};
225
226				let primitive_type_path = {
227					if let Type::Path(ty) = field.ty {
228						ty.path
229					} else {
230						return error(
231							&original_input,
232							&format!(
233								"Variant '{}' must specify exactly one unnamed field referencing \
234								 a type!",
235								variant_ident
236							),
237						);
238					}
239				};
240
241				let mut args = {
242					match attr_args(
243						&variant.attrs,
244						"compound_error",
245						&[
246							"inline_from",
247							"skip_single_from",
248							"no_source",
249							"convert_source",
250							"transparent",
251						],
252					) {
253						Err(err) => return err.explain(),
254						Ok(ok) => ok,
255					}
256				};
257
258				if let Some(from_attr) = args.remove(&"inline_from") {
259					for nested in from_attr.values {
260						match nested {
261							NestedMeta::Meta(Meta::Path(path)) => {
262								from_enums
263									.entry(PathOrLit::Path(path))
264									.or_default()
265									.push(variant_ident.clone());
266							},
267							NestedMeta::Lit(syn::Lit::Str(lit)) => {
268								let parsed_ty = {
269									match lit.parse() {
270										Err(_) => {
271											return error(
272												&from_attr.path,
273												"'inline_from' attribute must be a list of types!",
274											)
275										},
276										Ok(ok) => ok,
277									}
278								};
279								from_enums
280									.entry(PathOrLit::Lit(parsed_ty))
281									.or_default()
282									.push(variant_ident.clone());
283							},
284							_ => {
285								return error(
286									&from_attr.path,
287									"'inline_from' attribute must be a list of types!",
288								)
289							},
290						}
291					}
292				}
293
294				let skip_single_from = flag!(&args, &"skip_single_from");
295				let transparent = flag!(&args, &"transparent") || transparent_enum;
296
297				// If it's not a pure generic variant, implement from
298				if !skip_single_from
299					&& !generics
300						.type_params()
301						.any(|p| primitive_type_path.is_ident(&p.ident))
302				{
303					from_structs.push((primitive_type_path, variant_ident.clone()));
304				}
305
306				let variant_display;
307
308				let no_source = flag!(&args, &"no_source");
309
310				if !no_source {
311					let src_ret = {
312						if let Some(convert_source_attr) = args.remove(&"convert_source") {
313							if convert_source_attr.values.len() != 1 {
314								return crate::error(
315									&convert_source_attr.path,
316									"'convert_source' takes exactly one argument!",
317								);
318							}
319
320							match &convert_source_attr.values[0] {
321								NestedMeta::Meta(Meta::Path(path)) => {
322									quote!( #path (x) )
323								},
324								_ => {
325									return crate::error(
326										&convert_source_attr.path,
327										"The argument of 'convert_source' must be a path!",
328									)
329								},
330							}
331						} else {
332							quote!(x)
333						}
334					};
335
336					variant_display = quote!(x);
337
338					if transparent {
339						err_sources.extend(quote! {
340							Self::#variant_ident(x) => std::error::Error::source(x),
341						});
342					} else {
343						err_sources.extend(quote! {
344							Self::#variant_ident(x) => Some( #src_ret ),
345						});
346					}
347				} else {
348					variant_display = quote!(#variant_ident_str);
349				}
350
351				if transparent {
352					display_cases.push(quote! {
353						Self::#variant_ident (x) => {
354							std::fmt::Display::fmt(x, f)?;
355						}
356					});
357				} else {
358					display_cases.push(quote! {
359						Self::#variant_ident (x) => {
360							writeln!(f, "{}{}:", #title, #description)?;
361							write!(f, "  └ {}", #variant_display)?;
362						}
363					});
364				}
365			}
366
367			display_cases.push(quote! {
368				_ => {}
369			});
370
371			display = quote! {
372				match self {
373					#(#display_cases),*
374				}
375				Ok(())
376			};
377
378			err_source = quote! {
379				match self {
380					#err_sources
381					_ => ::core::option::Option::None
382				}
383			};
384		},
385		Data::Struct(_) => {
386			display = quote! {
387				write!(f, "{}{}", #title, #description)
388			};
389
390			err_source = quote!(None);
391		},
392		_ => {
393			return error(&original_input, "Can only be used on enums!");
394		},
395	}
396
397	let mut generated = proc_macro2::TokenStream::new();
398
399	for (from_struct, variant_ident) in from_structs {
400		let stream = quote! {
401			#[automatically_derived]
402			impl #generics_impl ::core::convert::From< #from_struct > for #ident #generics_type #generics_where {
403				fn from(primitive: #from_struct) -> Self {
404					Self::#variant_ident( primitive )
405				}
406			}
407		};
408
409		generated.extend(stream);
410	}
411
412	for (from_enum, variant_idents) in from_enums {
413		let mut cases = proc_macro2::TokenStream::new();
414		let from_enum_path = from_enum.path();
415
416		for variant_ident in variant_idents {
417			cases.extend(quote! {
418				#from_enum_path::#variant_ident( p ) => Self::#variant_ident(p),
419			});
420		}
421
422		let stream = quote! {
423			#[automatically_derived]
424			impl #generics_impl ::core::convert::From< #from_enum > for #ident #generics_type #generics_where {
425				fn from(composite: #from_enum) -> Self {
426					match composite {
427						#cases
428					}
429				}
430			}
431		};
432
433		generated.extend(stream);
434	}
435
436	if !skip_display {
437		generated.extend(quote! {
438			#[automatically_derived]
439			impl #generics_impl ::core::fmt::Display for #ident #generics_type #generics_where {
440				fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
441					#display
442				}
443			}
444		});
445	}
446
447	// BTW: requires `std`
448	if !skip_error {
449		generated.extend(quote! {
450			#[automatically_derived]
451			impl #generics_impl ::std::error::Error for #ident #generics_type #generics_where {
452				fn source(&self) -> ::std::option::Option<&(dyn ::std::error::Error + 'static)> {
453					#err_source
454				}
455			}
456		});
457	}
458
459	generated.into()
460}