1#![doc = include_str!("../README.md")]
9#![warn(
10 clippy::default_trait_access,
11 clippy::dbg_macro,
12 clippy::print_stdout,
13 clippy::unimplemented,
14 clippy::use_self,
15 missing_copy_implementations,
16 missing_docs,
17 non_snake_case,
18 non_upper_case_globals,
19 rust_2018_idioms,
20 unreachable_pub
21)]
22
23use heck::ToSnakeCase;
24use proc_macro2::{Ident, Span, TokenStream};
25use quote::quote;
26use syn::{parse_macro_input, DataEnum, DeriveInput, Visibility};
27
28fn unit_fields_return(
30 variant_name: &syn::Ident,
31 err_name: &syn::Ident,
32 ty_generics: &syn::TypeGenerics<'_>,
33 (function_name_is, doc_is): (&Ident, &str),
34 (function_name_ref, doc_ref): (&Ident, &str),
35 (function_name_val, doc_val): (&Ident, &str),
36) -> TokenStream {
37 quote!(
38 #[doc = #doc_is]
39 #[inline]
40 pub fn #function_name_is(&self) -> bool {
41 matches!(self, Self::#variant_name)
42 }
43
44 #[doc = #doc_ref ]
45 #[inline]
46 pub fn #function_name_ref(&self) -> ::core::result::Result<&(), #err_name #ty_generics> {
47 match self {
48 Self::#variant_name => {
49 ::core::result::Result::Ok(&())
50 }
51 _ => {
52 ::core::result::Result::Err(#err_name::new(
53 stringify!(#variant_name),
54 self.variant_name(),
55 ::core::option::Option::None,
56 ))
57 }
58 }
59 }
60
61 #[doc = #doc_val ]
62 #[inline]
63 pub fn #function_name_val(self) -> ::core::result::Result<(), #err_name #ty_generics> {
64 match self {
65 Self::#variant_name => {
66 ::core::result::Result::Ok(())
67 }
68 _ => {
69 ::core::result::Result::Err(#err_name::new(
70 stringify!(#variant_name),
71 self.variant_name(),
72 ::core::option::Option::Some(self),
73 ))
74 }
75 }
76 }
77 )
78}
79
80#[allow(clippy::too_many_arguments)]
82fn unnamed_fields_return(
83 variant_name: &syn::Ident,
84 err_name: &syn::Ident,
85 ty_generics: &syn::TypeGenerics<'_>,
86 (function_name_is, doc_is): (&Ident, &str),
87 (function_name_mut_ref, doc_mut_ref): (&Ident, &str),
88 (function_name_ref, doc_ref): (&Ident, &str),
89 (function_name_val, doc_val): (&Ident, &str),
90 fields: &syn::FieldsUnnamed,
91) -> TokenStream {
92 let (returns_mut_ref, returns_ref, returns_val, matches) = match fields.unnamed.len() {
93 1 => {
94 let field = fields.unnamed.first().expect("no fields on type");
95
96 let returns = &field.ty;
97 let returns_mut_ref = quote!(&mut #returns);
98 let returns_ref = quote!(&#returns);
99 let returns_val = quote!(#returns);
100 let matches = quote!(inner);
101
102 (returns_mut_ref, returns_ref, returns_val, matches)
103 }
104 0 => (quote!(()), quote!(()), quote!(()), quote!()),
105 _ => {
106 let mut returns_mut_ref = TokenStream::new();
107 let mut returns_ref = TokenStream::new();
108 let mut returns_val = TokenStream::new();
109 let mut matches = TokenStream::new();
110
111 for (i, field) in fields.unnamed.iter().enumerate() {
112 let rt = &field.ty;
113 let match_name = Ident::new(&format!("match_{}", i), Span::call_site());
114 returns_mut_ref.extend(quote!(&mut #rt,));
115 returns_ref.extend(quote!(&#rt,));
116 returns_val.extend(quote!(#rt,));
117 matches.extend(quote!(#match_name,));
118 }
119
120 (
121 quote!((#returns_mut_ref)),
122 quote!((#returns_ref)),
123 quote!((#returns_val)),
124 quote!(#matches),
125 )
126 }
127 };
128
129 quote!(
130 #[doc = #doc_is ]
131 #[inline]
132 #[allow(unused_variables)]
133 pub fn #function_name_is(&self) -> bool {
134 matches!(self, Self::#variant_name(#matches))
135 }
136
137 #[doc = #doc_mut_ref ]
138 #[inline]
139 pub fn #function_name_mut_ref(&mut self) -> ::core::result::Result<#returns_mut_ref, #err_name #ty_generics> {
140 match self {
141 Self::#variant_name(#matches) => {
142 ::core::result::Result::Ok((#matches))
143 }
144 _ => {
145 ::core::result::Result::Err(#err_name::new(
146 stringify!(#variant_name),
147 self.variant_name(),
148 ::core::option::Option::None,
149 ))
150 }
151 }
152 }
153
154 #[doc = #doc_ref ]
155 #[inline]
156 pub fn #function_name_ref(&self) -> ::core::result::Result<#returns_ref, #err_name #ty_generics> {
157 match self {
158 Self::#variant_name(#matches) => {
159 ::core::result::Result::Ok((#matches))
160 }
161 _ => {
162 ::core::result::Result::Err(#err_name::new(
163 stringify!(#variant_name),
164 self.variant_name(),
165 ::core::option::Option::None,
166 ))
167 }
168 }
169 }
170
171 #[doc = #doc_val ]
172 #[inline]
173 pub fn #function_name_val(self) -> ::core::result::Result<#returns_val, #err_name #ty_generics> {
174 match self {
175 Self::#variant_name(#matches) => {
176 ::core::result::Result::Ok((#matches))
177 }
178 _ => {
179 ::core::result::Result::Err(#err_name::new(
180 stringify!(#variant_name),
181 self.variant_name(),
182 ::core::option::Option::Some(self),
183 ))
184 }
185 }
186 }
187 )
188}
189
190#[allow(clippy::too_many_arguments)]
192fn named_fields_return(
193 variant_name: &syn::Ident,
194 err_name: &syn::Ident,
195 ty_generics: &syn::TypeGenerics<'_>,
196 (function_name_is, doc_is): (&Ident, &str),
197 (function_name_mut_ref, doc_mut_ref): (&Ident, &str),
198 (function_name_ref, doc_ref): (&Ident, &str),
199 (function_name_val, doc_val): (&Ident, &str),
200 fields: &syn::FieldsNamed,
201) -> TokenStream {
202 let (returns_mut_ref, returns_ref, returns_val, matches) = match fields.named.len() {
203 1 => {
204 let field = fields.named.first().expect("no fields on type");
205 let match_name = field.ident.as_ref().expect("expected a named field");
206
207 let returns = &field.ty;
208 let returns_mut_ref = quote!(&mut #returns);
209 let returns_ref = quote!(&#returns);
210 let returns_val = quote!(#returns);
211 let matches = quote!(#match_name);
212
213 (returns_mut_ref, returns_ref, returns_val, matches)
214 }
215 0 => (quote!(()), quote!(()), quote!(()), quote!(())),
216 _ => {
217 let mut returns_mut_ref = TokenStream::new();
218 let mut returns_ref = TokenStream::new();
219 let mut returns_val = TokenStream::new();
220 let mut matches = TokenStream::new();
221
222 for field in fields.named.iter() {
223 let rt = &field.ty;
224 let match_name = field.ident.as_ref().expect("expected a named field");
225
226 returns_mut_ref.extend(quote!(&mut #rt,));
227 returns_ref.extend(quote!(&#rt,));
228 returns_val.extend(quote!(#rt,));
229 matches.extend(quote!(#match_name,));
230 }
231
232 (
233 quote!((#returns_mut_ref)),
234 quote!((#returns_ref)),
235 quote!((#returns_val)),
236 quote!(#matches),
237 )
238 }
239 };
240
241 quote!(
242 #[doc = #doc_is ]
243 #[inline]
244 #[allow(unused_variables)]
245 pub fn #function_name_is(&self) -> bool {
246 matches!(self, Self::#variant_name{ #matches })
247 }
248
249 #[doc = #doc_mut_ref ]
250 #[inline]
251 pub fn #function_name_mut_ref(&mut self) -> ::core::result::Result<#returns_mut_ref, #err_name #ty_generics> {
252 match self {
253 Self::#variant_name{ #matches } => {
254 ::core::result::Result::Ok((#matches))
255 }
256 _ => {
257 ::core::result::Result::Err(#err_name::new(
258 stringify!(#variant_name),
259 self.variant_name(),
260 ::core::option::Option::None,
261 ))
262 }
263 }
264 }
265
266 #[doc = #doc_ref ]
267 #[inline]
268 pub fn #function_name_ref(&self) -> ::core::result::Result<#returns_ref, #err_name #ty_generics> {
269 match self {
270 Self::#variant_name{ #matches } => {
271 ::core::result::Result::Ok((#matches))
272 }
273 _ => {
274 ::core::result::Result::Err(#err_name::new(
275 stringify!(#variant_name),
276 self.variant_name(),
277 ::core::option::Option::None,
278 ))
279 }
280 }
281 }
282
283 #[doc = #doc_val ]
284 #[inline]
285 pub fn #function_name_val(self) -> ::core::result::Result<#returns_val, #err_name #ty_generics> {
286 match self {
287 Self::#variant_name{ #matches } => {
288 ::core::result::Result::Ok((#matches))
289 }
290 _ => {
291 ::core::result::Result::Err(#err_name::new(
292 stringify!(#variant_name),
293 self.variant_name(),
294 ::core::option::Option::Some(self),
295 ))
296 }
297 }
298 }
299 )
300}
301
302fn impl_all_as_fns(
303 name: &Ident,
304 err_name: &Ident,
305 generics: &syn::Generics,
306 data: &DataEnum,
307) -> TokenStream {
308 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
309
310 let mut stream = TokenStream::new();
311 let mut variant_names = TokenStream::new();
312 for variant_data in &data.variants {
313 let variant_name = &variant_data.ident;
314 let function_name_ref = Ident::new(
315 &format!("try_as_{}", variant_name).to_snake_case(),
316 Span::call_site(),
317 );
318 let doc_ref = format!(
319 "Returns references to the inner fields if this is a `{}::{}`, otherwise an `{}`",
320 name, variant_name, &err_name,
321 );
322 let function_name_mut_ref = Ident::new(
323 &format!("try_as_{}_mut", variant_name).to_snake_case(),
324 Span::call_site(),
325 );
326 let doc_mut_ref = format!(
327 "Returns mutable references to the inner fields if this is a `{}::{}`, otherwise an `{}`",
328 name,
329 variant_name,
330 &err_name,
331 );
332
333 let function_name_val = Ident::new(
334 &format!("try_into_{}", variant_name).to_snake_case(),
335 Span::call_site(),
336 );
337 let doc_val = format!(
338 "Returns the inner fields if this is a `{}::{}`, otherwise returns back the enum in the `Err` case of the result",
339 name,
340 variant_name,
341 );
342
343 let function_name_is = Ident::new(
344 &format!("is_{}", variant_name).to_snake_case(),
345 Span::call_site(),
346 );
347 let doc_is = format!(
348 "Returns true if this is a `{}::{}`, otherwise false",
349 name, variant_name,
350 );
351
352 let tokens = match &variant_data.fields {
353 syn::Fields::Unit => unit_fields_return(
354 variant_name,
355 err_name,
356 &ty_generics,
357 (&function_name_is, &doc_is),
358 (&function_name_ref, &doc_ref),
359 (&function_name_val, &doc_val),
360 ),
361 syn::Fields::Unnamed(unnamed) => unnamed_fields_return(
362 variant_name,
363 err_name,
364 &ty_generics,
365 (&function_name_is, &doc_is),
366 (&function_name_mut_ref, &doc_mut_ref),
367 (&function_name_ref, &doc_ref),
368 (&function_name_val, &doc_val),
369 unnamed,
370 ),
371 syn::Fields::Named(named) => named_fields_return(
372 variant_name,
373 err_name,
374 &ty_generics,
375 (&function_name_is, &doc_is),
376 (&function_name_mut_ref, &doc_mut_ref),
377 (&function_name_ref, &doc_ref),
378 (&function_name_val, &doc_val),
379 named,
380 ),
381 };
382
383 stream.extend(tokens);
384
385 let variant_name = match &variant_data.fields {
386 syn::Fields::Unit => quote!(Self::#variant_name => stringify!(#variant_name),),
387 syn::Fields::Unnamed(_) => {
388 quote!(Self::#variant_name(..) => stringify!(#variant_name),)
389 }
390 syn::Fields::Named(_) => quote!(Self::#variant_name{..} => stringify!(#variant_name),),
391 };
392
393 variant_names.extend(variant_name);
394 }
395
396 quote!(
397 impl #impl_generics #name #ty_generics #where_clause {
398 #stream
399
400 fn variant_name(&self) -> &'static str {
402 match self {
403 #variant_names
404 _ => unreachable!(),
405 }
406 }
407 }
408 )
409}
410
411fn impl_err(
412 name: &Ident,
413 err_name: &Ident,
414 vis: &Visibility,
415 generics: &syn::Generics,
416 attrs: &[syn::Attribute],
417) -> TokenStream {
418 let doc_err = format!("An error type for the `{}::try_as_*` functions", name);
419
420 let mut derives = Vec::new();
422 let mut derive_debug = false;
423 for attr in attrs {
424 if attr.path().is_ident("derive_err") {
425 attr.parse_nested_meta(|meta| {
426 if meta.path.is_ident("Debug") {
427 derive_debug = true;
428 } else {
429 derives.push(meta.path);
430 }
431
432 Ok(())
433 })
434 .expect("failed to parse derive nested meta");
435 }
436 }
437
438 let derive_err = if derives.is_empty() {
439 quote!()
440 } else {
441 quote!(#[derive(#(#derives),*)])
442 };
443
444 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
445
446 let mut err_impl = quote!(
447 #[doc = #doc_err ]
448 #derive_err
449 #vis struct #err_name #generics {
450 expected: &'static str,
451 actual: &'static str,
452 value: ::core::option::Option<#name #ty_generics>,
453 }
454
455 impl #impl_generics #err_name #ty_generics #where_clause {
456 fn new(
458 expected: &'static str,
459 actual: &'static str,
460 value: ::core::option::Option<#name #ty_generics>
461 ) -> Self {
462 Self {
463 expected,
464 actual,
465 value,
466 }
467 }
468
469 pub fn expected(&self) -> &'static str {
471 self.expected
472 }
473
474 pub fn actual(&self) -> &'static str {
476 self.actual
477 }
478
479 pub fn value(&self) -> ::core::option::Option<&#name #ty_generics> {
481 self.value.as_ref()
482 }
483
484 pub fn into_value(self) -> ::core::option::Option<#name #ty_generics> {
486 self.value
487 }
488 }
489 );
490
491 if derive_debug {
492 let impl_debug_body = {
493 let where_clause = if let Some(where_clause) = where_clause {
494 quote!(#where_clause, #name #ty_generics: ::core::fmt::Debug)
495 } else {
496 quote!(where #name #ty_generics: ::core::fmt::Debug)
497 };
498
499 quote!(
500 impl #impl_generics ::core::fmt::Debug for #err_name #ty_generics #where_clause {
501 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
502 f.debug_struct(stringify!(#err_name))
503 .field("expected", &self.expected)
504 .field("actual", &self.actual)
505 .field("value", &self.value)
506 .finish()
507 }
508 }
509 )
510 };
511
512 let impl_display_body = {
513 let display_fmt = format!("expected {name}::{{}}, but got {name}::{{}}");
514 quote!(
515 impl #impl_generics ::core::fmt::Display for #err_name #ty_generics #where_clause {
516 fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
517 write!(
518 formatter,
519 #display_fmt,
520 self.expected(),
521 self.actual(),
522 )
523 }
524 }
525 )
526 };
527
528 let impl_err_body = {
529 let where_clause = if let Some(where_clause) = where_clause {
530 quote!(#where_clause, #name #ty_generics: ::core::fmt::Debug)
531 } else {
532 quote!(where #name #ty_generics: ::core::fmt::Debug)
533 };
534
535 quote!(
536 impl #impl_generics ::std::error::Error for #err_name #ty_generics #where_clause {}
537 )
538 };
539
540 err_impl.extend(quote!(
541 #impl_debug_body
542
543 #impl_display_body
544
545 #impl_err_body
546 ))
547 }
548
549 err_impl
550}
551
552#[proc_macro_derive(EnumTryAsInner, attributes(derive_err))]
554pub fn enum_try_as_inner(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
555 let ast: DeriveInput = parse_macro_input!(input as DeriveInput);
557
558 let name = &ast.ident;
559 let err_name = Ident::new(&format!("{}Error", name), Span::call_site());
560 let generics = &ast.generics;
561 let vis = &ast.vis;
562
563 let enum_data = if let syn::Data::Enum(data) = &ast.data {
564 data
565 } else {
566 panic!("{} is not an enum", name);
567 };
568
569 let mut expanded = TokenStream::new();
570
571 let fns = impl_all_as_fns(name, &err_name, generics, enum_data);
573
574 let err = impl_err(name, &err_name, vis, generics, &ast.attrs);
576
577 expanded.extend(fns);
578 expanded.extend(err);
579
580 proc_macro::TokenStream::from(expanded)
581}