1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::generics::InferredBounds;
4use crate::span::MemberSpan;
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote, quote_spanned, ToTokens};
7use std::collections::BTreeSet as Set;
8use syn::{DeriveInput, GenericArgument, Member, PathArguments, Result, Token, Type};
9
10pub fn derive(input: &DeriveInput) -> TokenStream {
11 match try_expand(input) {
12 Ok(expanded) => expanded,
13 Err(error) => fallback(input, error),
17 }
18}
19
20fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
21 let input = Input::from_syn(input)?;
22 input.validate()?;
23 Ok(match input {
24 Input::Struct(input) => impl_struct(input),
25 Input::Enum(input) => impl_enum(input),
26 })
27}
28
29fn fallback(input: &DeriveInput, error: syn::Error) -> TokenStream {
30 let ty = &input.ident;
31 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
32
33 let error = error.to_compile_error();
34
35 quote! {
36 #error
37
38 #[allow(unused_qualifications)]
39 impl #impl_generics thiserror::StdError for #ty #ty_generics #where_clause
40 where
41 for<'workaround> #ty #ty_generics: ::core::fmt::Debug,
44 {}
45
46 #[allow(unused_qualifications)]
47 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
48 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
49 ::core::unreachable!()
50 }
51 }
52 }
53}
54
/// Expands `#[derive(Error)]` for a struct.
///
/// Emits:
/// - a `thiserror::StdError` impl, with a `source` method when the struct is
///   `#[error(transparent)]` or has a source field, and (only with the `std`
///   feature) a `provide` method when it has a backtrace field;
/// - a `Display` impl when the struct is transparent or has a display
///   attribute;
/// - a `From` impl when the struct has a `from` field.
fn impl_struct(input: Struct) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra where-predicates for the StdError impl. Bounds are only recorded
    // for fields whose type mentions the struct's generic parameters.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `StdError::source`, or `None` when the struct has no source.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        // Transparent: forward `source()` to the single wrapped field.
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(thiserror::StdError));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            thiserror::StdError::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // For an `Option<T>` field the bound goes on `T`, not the Option.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(thiserror::StdError + 'static));
        }
        // An `Option` source is borrowed and `?` makes `source()` return
        // `None` when the field is empty.
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.member_span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn thiserror::StdError + 'static)> {
                use thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // The generated `provide` refers to `std::backtrace::Backtrace` and
    // `std::error::Request`, so it is only produced with the `std` feature.
    #[cfg(feature = "std")]
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            // First let the source provide its values...
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.member_span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.member_span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            // ...then provide this struct's own backtrace, unless the
            // backtrace field is the source field itself.
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            // No source: only this struct's (optional) backtrace is provided.
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<std::backtrace::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #body
            }
        }
    });

    #[cfg(not(feature = "std"))]
    let provide_method: Option<TokenStream> = None;

    // (field index, trait) pairs that the Display body requires of fields.
    let mut display_implied_bounds = Set::new();
    let display_body = if input.attrs.transparent.is_some() {
        // Transparent: forward Display to the single field.
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        // Destructure `self` so the display tokens can refer to fields by
        // binding name.
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        // Turn the implied per-field bounds into where-predicates, but only
        // for fields whose type mentions a generic parameter.
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` for the `from` field (with `T` unwrapped from `Option<T>`).
    // A distinct backtrace field, if any, is filled in by `from_initializer`.
    // This impl uses the struct's declared where-clause as-is.
    let from_impl = input.from_field().map(|from_field| {
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty #body
                }
            }
        }
    });

    // With type parameters, state `Self: Debug + Display` explicitly on the
    // StdError impl (presumably because those supertraits may only hold
    // conditionally for a generic type).
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics thiserror::StdError for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
226
/// Expands `#[derive(Error)]` for an enum.
///
/// Emits:
/// - a `thiserror::StdError` impl, with a `source` method when any variant
///   has a source, and (only with the `std` feature) a `provide` method when
///   any variant has a backtrace;
/// - a `Display` impl when `input.has_display()`;
/// - one `From` impl per variant that has a `from` field.
fn impl_enum(input: Enum) -> TokenStream {
    let ty = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Extra where-predicates for the StdError impl. Bounds are only recorded
    // for fields whose type mentions the enum's generic parameters.
    let mut error_inferred_bounds = InferredBounds::new();

    let source_method = if input.has_source() {
        // NOTE: this iterator is lazy and mutates `error_inferred_bounds`; it
        // is fully consumed by the `quote!` below, before the StdError
        // where-clause is computed at the end of this function.
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                // Transparent variant: forward `source()` to its single field.
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(thiserror::StdError));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    thiserror::StdError::source(transparent.as_dyn_error())
                };
                // Braced patterns also work for tuple variants (`V { 0: x }`).
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    // For an `Option<T>` field the bound goes on `T`.
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(thiserror::StdError + 'static));
                }
                // `?` makes `source()` return `None` for an empty `Option`.
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.member_span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn thiserror::StdError + 'static)> {
                use thiserror::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    // The generated `provide` refers to `std::backtrace::Backtrace` and
    // `std::error::Request`, so it is only produced with the `std` feature.
    #[cfg(feature = "std")]
    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                // Backtrace and source both present, and the backtrace field
                // has no `backtrace` attribute: forward to the source first,
                // then provide the variant's own backtrace.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                // The backtrace field IS the source field: only forward to it.
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.member_span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.member_span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                // Any other variant with a backtrace field: provide only the
                // variant's own backtrace.
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<std::backtrace::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                // No backtrace in this variant: nothing to provide.
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut std::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    #[cfg(not(feature = "std"))]
    let provide_method: Option<TokenStream> = None;

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // A single `use AsDisplay` at the top of `fmt` serves every arm.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // For an enum with no variants, match on `*self` (an uninhabited
        // value) so the empty `match` is exhaustive; an empty match on `&Self`
        // would not be.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        // Lazy and mutates `display_inferred_bounds`; collected below before
        // the Display where-clause is computed.
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = match &variant.attrs.display {
                Some(display) => {
                    display_implied_bounds.clone_from(&display.implied_bounds);
                    display.to_token_stream()
                }
                None => {
                    // No display attribute: forward to the first field's
                    // Display (presumably an `#[error(transparent)]` variant;
                    // `validate` is expected to guarantee the field exists).
                    // `_N` matches the binding names produced by `fields_pat`.
                    let only_field = match &variant.fields[0].member {
                        Member::Named(ident) => ident.clone(),
                        Member::Unnamed(index) => format_ident!("_{}", index),
                    };
                    display_implied_bounds.insert((0, Trait::Display));
                    quote!(::core::fmt::Display::fmt(#only_field, __formatter))
                }
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant with a `from` field (with `T` unwrapped
    // from `Option<T>`). These use the enum's declared where-clause as-is.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let body = from_initializer(from_field, backtrace_field);
        Some(quote! {
            #[allow(unused_qualifications)]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #[allow(deprecated)]
                fn from(source: #from) -> Self {
                    #ty::#variant #body
                }
            }
        })
    });

    // With type parameters, state `Self: Debug + Display` explicitly on the
    // StdError impl (presumably because those supertraits may only hold
    // conditionally for a generic type).
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        impl #impl_generics thiserror::StdError for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
486
487fn fields_pat(fields: &[Field]) -> TokenStream {
488 let mut members = fields.iter().map(|field| &field.member).peekable();
489 match members.peek() {
490 Some(Member::Named(_)) => quote!({ #(#members),* }),
491 Some(Member::Unnamed(_)) => {
492 let vars = members.map(|member| match member {
493 Member::Unnamed(member) => format_ident!("_{}", member),
494 Member::Named(_) => unreachable!(),
495 });
496 quote!((#(#vars),*))
497 }
498 None => quote!({}),
499 }
500}
501
502fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
503 if needs_as_display {
504 Some(quote! {
505 use thiserror::__private::AsDisplay as _;
506 })
507 } else {
508 None
509 }
510}
511
512fn from_initializer(from_field: &Field, backtrace_field: Option<&Field>) -> TokenStream {
513 let from_member = &from_field.member;
514 let some_source = if type_is_option(from_field.ty) {
515 quote!(::core::option::Option::Some(source))
516 } else {
517 quote!(source)
518 };
519 let backtrace = backtrace_field.map(|backtrace_field| {
520 let backtrace_member = &backtrace_field.member;
521 if type_is_option(backtrace_field.ty) {
522 quote! {
523 #backtrace_member: ::core::option::Option::Some(std::backtrace::Backtrace::capture()),
524 }
525 } else {
526 quote! {
527 #backtrace_member: ::core::convert::From::from(std::backtrace::Backtrace::capture()),
528 }
529 }
530 });
531 quote!({
532 #from_member: #some_source,
533 #backtrace
534 })
535}
536
537fn type_is_option(ty: &Type) -> bool {
538 type_parameter_of_option(ty).is_some()
539}
540
541fn unoptional_type(ty: &Type) -> TokenStream {
542 let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
543 quote!(#unoptional)
544}
545
546fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
547 let path = match ty {
548 Type::Path(ty) => &ty.path,
549 _ => return None,
550 };
551
552 let last = path.segments.last().unwrap();
553 if last.ident != "Option" {
554 return None;
555 }
556
557 let bracketed = match &last.arguments {
558 PathArguments::AngleBracketed(bracketed) => bracketed,
559 _ => return None,
560 };
561
562 if bracketed.args.len() != 1 {
563 return None;
564 }
565
566 match &bracketed.args[0] {
567 GenericArgument::Type(arg) => Some(arg),
568 _ => None,
569 }
570}