exhaust_macros/
lib.rs

1use std::iter;
2
3use itertools::izip;
4use proc_macro::TokenStream;
5use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
6use quote::{quote, ToTokens as _};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{parse_macro_input, parse_quote, DeriveInput};
10
11mod common;
12use common::ExhaustContext;
13
14mod fields;
15use fields::{exhaust_iter_fields, ExhaustFields};
16
17use crate::common::ConstructorSyntax;
18
19// Note: documentation is on the reexport so that it can have working links.
20#[proc_macro_derive(Exhaust)]
21pub fn derive_exhaust(input: TokenStream) -> TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23    derive_impl(input)
24        .unwrap_or_else(|err| err.to_compile_error())
25        .into()
26}
27
28/// Generate an impl of Exhaust for a built-in tuple type.
29/// This macro is only useful within the `exhaust` crate.
30#[proc_macro]
31#[doc(hidden)]
32pub fn impl_exhaust_for_tuples(input: TokenStream) -> TokenStream {
33    let input = parse_macro_input!(input as syn::LitInt);
34    tuple_impls_up_to(input.base10_parse().unwrap())
35        .unwrap_or_else(|err| err.to_compile_error())
36        .into()
37}
38
39fn derive_impl(input: DeriveInput) -> Result<TokenStream2, syn::Error> {
40    let DeriveInput {
41        ident: item_type_name,
42        attrs: _,
43        vis,
44        generics,
45        data,
46    } = input;
47
48    let item_type_name_str = &item_type_name.to_string();
49    let factory_type_name = common::generated_type_name(item_type_name_str, "Factory");
50    let iterator_type_name = common::generated_type_name(item_type_name_str, "Iter");
51
52    let ctx = ExhaustContext {
53        vis,
54        generics,
55        iterator_type_name,
56        item_type: ConstructorSyntax::Braced(item_type_name.to_token_stream()),
57        factory_type: ConstructorSyntax::Braced(factory_type_name.to_token_stream()),
58        exhaust_crate_path: syn::parse_quote! { ::exhaust },
59    };
60    let ExhaustContext {
61        iterator_type_name,
62        exhaust_crate_path,
63        ..
64    } = &ctx;
65
66    let (iterator_and_factory_decl, from_factory_body) = match data {
67        syn::Data::Struct(s) => exhaust_iter_struct(s, &ctx),
68        syn::Data::Enum(e) => exhaust_iter_enum(e, &ctx),
69        syn::Data::Union(syn::DataUnion { union_token, .. }) => Err(syn::Error::new(
70            union_token.span,
71            "derive(Exhaust) does not support unions",
72        )),
73    }?;
74
75    let (impl_generics, ty_generics, augmented_where_predicates) =
76        ctx.generics_with_bounds(syn::parse_quote! {});
77
78    Ok(quote! {
79        // rust-analyzer (but not rustc) sometimes produces lints on macro generated code it
80        // shouldn't. We don't expect to actually hit this case normally, but in general,
81        // we don't want to *ever* bother our users with unfixable warnings about weird names.
82        #[allow(nonstandard_style)]
83        // This anonymous constant allows us to make all our generated types be public-in-private,
84        // without altering the meaning of any paths they use as a nested module would.
85        const _: () = {
86            impl #impl_generics #exhaust_crate_path::Exhaust for #item_type_name #ty_generics
87            where #augmented_where_predicates {
88                type Iter = #iterator_type_name #ty_generics;
89                type Factory = #factory_type_name #ty_generics;
90                fn exhaust_factories() -> Self::Iter {
91                    ::core::default::Default::default()
92                }
93                fn from_factory(factory: Self::Factory) -> Self {
94                    #from_factory_body
95                }
96            }
97
98            #iterator_and_factory_decl
99        };
100    })
101}
102
103fn tuple_impls_up_to(size: u64) -> Result<TokenStream2, syn::Error> {
104    (2..=size).map(tuple_impl).collect()
105}
106
107/// Generate an impl of Exhaust for a built-in tuple type.
108///
109/// This is almost but not quite identical to [`exhaust_iter_struct`], due to the syntax
110/// of tuples and due to it being used from the same crate (so that access is via
111/// crate::Exhaust instead of ::exhaust::Exhaust).
112fn tuple_impl(size: u64) -> Result<TokenStream2, syn::Error> {
113    if size < 2 {
114        return Err(syn::Error::new(
115            Span::call_site(),
116            "tuple type of size less than 2 not supported",
117        ));
118    }
119
120    let value_type_vars: Vec<Ident> = (0..size)
121        .map(|i| Ident::new(&format!("T{i}"), Span::mixed_site()))
122        .collect();
123    let factory_value_vars: Vec<Ident> = (0..size)
124        .map(|i| Ident::new(&format!("factory{i}"), Span::mixed_site()))
125        .collect();
126    let synthetic_fields: syn::Fields = syn::Fields::Unnamed(syn::FieldsUnnamed {
127        paren_token: syn::token::Paren(Span::mixed_site()),
128        unnamed: value_type_vars
129            .iter()
130            .map(|type_var| syn::Field {
131                attrs: vec![],
132                vis: parse_quote! { pub },
133                mutability: syn::FieldMutability::None,
134                ident: None,
135                colon_token: None,
136                ty: syn::Type::Verbatim(type_var.to_token_stream()),
137            })
138            .collect(),
139    });
140
141    // Synthesize a good-enough context to use the derive tools.
142    let ctx: ExhaustContext = ExhaustContext {
143        vis: parse_quote! { pub },
144        generics: syn::Generics {
145            lt_token: None,
146            params: value_type_vars
147                .iter()
148                .map(|var| {
149                    syn::GenericParam::Type(syn::TypeParam {
150                        attrs: vec![],
151                        ident: var.clone(),
152                        colon_token: None,
153                        bounds: Punctuated::default(),
154                        eq_token: None,
155                        default: None,
156                    })
157                })
158                .collect(),
159            gt_token: None,
160            where_clause: None,
161        },
162        item_type: ConstructorSyntax::Tuple,
163        factory_type: ConstructorSyntax::Tuple,
164        iterator_type_name: common::generated_type_name("Tuple", "Iter"),
165        exhaust_crate_path: parse_quote! { crate },
166    };
167
168    let iterator_type_name = &ctx.iterator_type_name;
169
170    // Generate the field-exhausting iteration logic
171    let ExhaustFields {
172        state_field_decls,
173        factory_field_decls: _, // unused because we use tuples instead
174        initializers,
175        cloners,
176        field_pats,
177        advance,
178    } = exhaust_iter_fields(
179        &ctx,
180        &synthetic_fields,
181        &quote! {},
182        &ConstructorSyntax::Tuple,
183    );
184
185    let iterator_impls = ctx.impl_iterator_and_factory_traits(
186        quote! {
187            match self {
188                Self { #field_pats } => {
189                    #advance
190                }
191            }
192        },
193        quote! { Self { #initializers } },
194        quote! {
195            let Self { #field_pats } = self;
196            Self { #cloners }
197        },
198    );
199
200    let iterator_doc = ctx.iterator_doc();
201
202    Ok(quote! {
203        const _: () = {
204            impl<#( #value_type_vars , )*> crate::Exhaust for ( #( #value_type_vars , )* )
205            where #( #value_type_vars : crate::Exhaust, )*
206            {
207                type Iter = #iterator_type_name <#( #value_type_vars , )*>;
208                type Factory = (#(
209                    <#value_type_vars as crate::Exhaust>::Factory,
210                )*);
211                fn exhaust_factories() -> Self::Iter {
212                    ::core::default::Default::default()
213                }
214                fn from_factory(factory: Self::Factory) -> Self {
215                    let (#( #factory_value_vars , )*) = factory;
216                    (#(
217                        <#value_type_vars as crate::Exhaust>::from_factory(#factory_value_vars),
218                    )*)
219                }
220            }
221
222            #[doc = #iterator_doc]
223            pub struct #iterator_type_name <#( #value_type_vars , )*>
224            where #( #value_type_vars : crate::Exhaust, )*
225            {
226                #state_field_decls
227            }
228
229            #iterator_impls
230        };
231    })
232}
233
234fn exhaust_iter_struct(
235    s: syn::DataStruct,
236    ctx: &ExhaustContext,
237) -> Result<(TokenStream2, TokenStream2), syn::Error> {
238    let vis = &ctx.vis;
239    let exhaust_crate_path = &ctx.exhaust_crate_path;
240    let (impl_or_decl_generics, ty_generics, augmented_where_predicates) =
241        ctx.generics_with_bounds(syn::parse_quote! {});
242    let iterator_type_name = &ctx.iterator_type_name;
243    let factory_type_name = &ctx.factory_type.path()?;
244    let factory_type = &ctx.factory_type.parameterized(&ctx.generics);
245
246    let (
247        factory_state_struct_decl,
248        factory_state_struct_type,
249        factory_state_struct_clone_expr,
250        factory_to_self_transform,
251        ExhaustFields {
252            state_field_decls,
253            factory_field_decls: _,
254            initializers,
255            cloners,
256            field_pats,
257            advance,
258        },
259    ) = if s.fields.is_empty() {
260        // If there are no fields, then
261        // * We don't need to generate a `FactoryState` struct, and can just use ().
262        // * The iterator needs a special `done` field to tell whether it has produced its 1 item.
263
264        let fields = ExhaustFields {
265            state_field_decls: quote! { done: bool, },
266            factory_field_decls: syn::Fields::Unit,
267            initializers: quote! { done: false, },
268            cloners: quote! { done: *done, },
269            field_pats: quote! { done, },
270            advance: quote! {
271                if *done {
272                    ::core::option::Option::None
273                } else {
274                    *done = true;
275                    ::core::option::Option::Some(#factory_type_name(()))
276                }
277            },
278        };
279
280        let output_type = ctx.item_type.path()?;
281        (
282            quote! {},
283            quote! { () },
284            quote! { () },
285            quote! { () => #output_type },
286            fields,
287        )
288    } else {
289        let factory_state_struct_type = ctx.generated_type_name("FactoryState")?;
290        let factory_state_ctor =
291            ConstructorSyntax::Braced(factory_state_struct_type.to_token_stream());
292
293        let fields: ExhaustFields = exhaust_iter_fields(
294            ctx,
295            &s.fields,
296            ctx.factory_type.path()?,
297            &factory_state_ctor,
298        );
299        let factory_field_decls = &fields.factory_field_decls;
300
301        // Generate factory state struct with the same syntax type as the original
302        // (for elegance, not because it matters functionally).
303        // This struct is always wrapped in a newtype struct to hide implementation details reliably.
304        let factory_state_struct_decl = match &factory_field_decls {
305            syn::Fields::Unit | syn::Fields::Unnamed(_) => quote! {
306                #vis struct #factory_state_struct_type #impl_or_decl_generics #factory_field_decls
307                where #augmented_where_predicates;
308
309            },
310
311            syn::Fields::Named(_) => quote! {
312                #vis struct #factory_state_struct_type #impl_or_decl_generics
313                where #augmented_where_predicates
314                #factory_field_decls
315            },
316        };
317
318        let factory_state_struct_clone_arm = common::clone_like_struct_conversion(
319            &s.fields,
320            factory_state_ctor.path()?,
321            factory_state_ctor.path()?,
322            &quote! { ref },
323            |expr| quote! { ::core::clone::Clone::clone(#expr) },
324        );
325
326        let factory_to_self_transform = common::clone_like_struct_conversion(
327            &s.fields,
328            factory_state_ctor.path()?,
329            ctx.item_type.path()?,
330            &quote! {},
331            |expr| quote! { #exhaust_crate_path::Exhaust::from_factory(#expr) },
332        );
333
334        (
335            factory_state_struct_decl,
336            factory_state_struct_type.to_token_stream(),
337            // TODO: replace this 1-arm match with a let?
338            quote! { match self.0 { #factory_state_struct_clone_arm } },
339            factory_to_self_transform,
340            fields,
341        )
342    };
343
344    let impls = ctx.impl_iterator_and_factory_traits(
345        quote! {
346            match self {
347                Self { #field_pats } => {
348                    #advance
349                }
350            }
351        },
352        quote! { Self { #initializers } },
353        quote! {
354            let Self { #field_pats } = self;
355            Self { #cloners }
356        },
357    );
358
359    Ok((
360        quote! {
361            // Struct that is exposed as the `<Self as Exhaust>::Iter` type.
362            // A wrapper struct is not needed because it always has at least one private field.
363            //
364            // Note: The iterator struct must have trait bounds because its fields, being of type
365            // `<SomeOtherTy as Exhaust>::Iter`, require that `SomeOtherTy: Exhaust`.
366            #vis struct #iterator_type_name #impl_or_decl_generics
367            where #augmented_where_predicates {
368                #state_field_decls
369            }
370
371            // Struct that is exposed as the `<Self as Exhaust>::Factory` type,
372            // wrapping the private factory_state_struct_type.
373            #vis struct #factory_type_name #impl_or_decl_generics (
374                #factory_state_struct_type #ty_generics
375            )
376            where #augmented_where_predicates;
377
378            #impls
379
380            // Declare the factory_state_struct_type (`Exhaust*FactoryState`) struct,
381            // which is a private field of the factory type.
382            // This is empty for unit structs, which use () as the state type instead.
383            #factory_state_struct_decl
384
385            // A manual impl of Clone is required to *not* have a `Clone` bound on the generics.
386            impl #impl_or_decl_generics ::core::clone::Clone for #factory_type
387            where #augmented_where_predicates {
388                fn clone(&self) -> Self {
389                    Self(#factory_state_struct_clone_expr)
390                }
391            }
392
393        },
394        quote! {
395            match factory.0 {
396                #factory_to_self_transform
397            }
398        },
399    ))
400}
401
402fn exhaust_iter_enum(
403    e: syn::DataEnum,
404    ctx: &ExhaustContext,
405) -> Result<(TokenStream2, TokenStream2), syn::Error> {
406    let vis = &ctx.vis;
407    let exhaust_crate_path = &ctx.exhaust_crate_path;
408    let iterator_type_name = &ctx.iterator_type_name;
409    let factory_outer_type_path = &ctx.factory_type.path()?;
410    let factory_type = &ctx.factory_type.parameterized(&ctx.generics);
411
412    // These enum types are both wrapped in structs,
413    // so that the user of the macro cannot depend on its implementation details.
414    let iter_state_enum_type = ctx.generated_type_name("IterState")?;
415    let factory_state_enum_type = ctx.generated_type_name("FactoryState")?.to_token_stream();
416    let factory_state_ctor = ConstructorSyntax::Braced(factory_state_enum_type.clone());
417
418    // One ident per variant of the original enum.
419    let state_enum_progress_variants: Vec<Ident> = e
420        .variants
421        .iter()
422        .map(|v| {
423            // Renaming the variant serves two purposes: less confusing error/debug text,
424            // and disambiguating from the “Done” variant.
425            Ident::new(&format!("Exhaust{}", v.ident), v.span())
426        })
427        .collect();
428
429    // TODO: ensure no name conflict, perhaps by renaming the others
430    let done_variant = Ident::new("Done", Span::mixed_site());
431
432    // All variants of our generated enum, which are equal to the original enum
433    // plus a "done" variant.
434    #[allow(clippy::type_complexity)]
435    let (
436        state_enum_variant_decls,
437        state_enum_variant_initializers,
438        state_enum_variant_cloners,
439        state_enum_field_pats,
440        state_enum_variant_advancers,
441        mut factory_variant_decls,
442    ): (
443        Vec<TokenStream2>,
444        Vec<TokenStream2>,
445        Vec<TokenStream2>,
446        Vec<TokenStream2>,
447        Vec<TokenStream2>,
448        Vec<TokenStream2>,
449    ) = itertools::multiunzip(e
450        .variants
451        .iter()
452        .zip(state_enum_progress_variants.iter())
453        .map(|(target_variant, state_ident)| {
454            let target_variant_ident = &target_variant.ident;
455            let fields::ExhaustFields {
456                state_field_decls,
457                factory_field_decls,
458                initializers: state_fields_init,
459                cloners: state_fields_clone,
460                field_pats,
461                advance,
462            } = if target_variant.fields.is_empty() {
463                // TODO: don't even construct this dummy value (needs refactoring)
464                fields::ExhaustFields {
465                    state_field_decls: quote! {},
466                    factory_field_decls: syn::Fields::Unit,
467                    initializers: quote! {},
468                    cloners: quote! {},
469                    field_pats: quote! {},
470                    advance: quote! {
471                        compile_error!("can't happen: fieldless ExhaustFields not used")
472                    },
473                }
474            } else {
475                fields::exhaust_iter_fields(
476                    ctx,
477                    &target_variant.fields,
478                    factory_outer_type_path,
479                    &factory_state_ctor.with_variant(target_variant_ident),
480                )
481            };
482
483            (
484                quote! {
485                    #state_ident {
486                        #state_field_decls
487                    }
488                },
489                quote! {
490                    #iter_state_enum_type :: #state_ident { #state_fields_init }
491                },
492                quote! {
493                    #iter_state_enum_type :: #state_ident { #field_pats } =>
494                        #iter_state_enum_type :: #state_ident { #state_fields_clone }
495                },
496                field_pats,
497                advance,
498                quote! {
499                    #target_variant_ident #factory_field_decls
500                },
501            )
502        })
503        .chain(iter::once((
504            done_variant.to_token_stream(),
505            quote! {
506                // iterator construction
507                #iter_state_enum_type :: #done_variant {}
508            },
509            quote! {
510                // clone() match arm
511                #iter_state_enum_type :: #done_variant {} => #iter_state_enum_type :: #done_variant {}
512            },
513            quote! {},
514            quote! { compile_error!("done advancer not used") },
515            quote! { compile_error!("done factory variant not used") },
516        ))));
517
518    factory_variant_decls.pop(); // no Done arm in the factory enum
519
520    let first_state_variant_initializer = &state_enum_variant_initializers[0];
521
522    // Match arms to advance the iterator.
523    let variant_next_arms = izip!(
524        e.variants.iter(),
525        state_enum_progress_variants.iter(),
526        state_enum_field_pats.iter(),
527        state_enum_variant_initializers.iter().skip(1),
528        state_enum_variant_advancers.iter(),
529    )
530    .map(
531        |(target_enum_variant, state_ident, pats, next_state_initializer, field_advancer)| {
532            let target_variant_ident = &target_enum_variant.ident;
533            let advancer = if target_enum_variant.fields.is_empty() {
534                let factory_state_expr = factory_state_ctor
535                    .with_variant(target_variant_ident)
536                    .value_expr([].iter(), [].iter());
537                quote! {
538                    self.0 = #next_state_initializer;
539                    ::core::option::Option::Some(#factory_outer_type_path(#factory_state_expr))
540                }
541            } else {
542                quote! {
543                    // TODO: merge this logic into field_advancer itself so we’re not creating an
544                    // `Option` and then immediately matching it again.
545                    let maybe_this_variant = #field_advancer;
546                    match maybe_this_variant {
547                        ::core::option::Option::Some(_) => maybe_this_variant,
548                        ::core::option::Option::None => {
549                            self.0 = #next_state_initializer;
550                            continue 'variants
551                        }
552                    }
553                }
554            };
555            quote! {
556                #iter_state_enum_type::#state_ident { #pats } => {
557                    #advancer
558                }
559            }
560        },
561    );
562
563    let factory_enum_variant_clone_arms: Vec<TokenStream2> = common::clone_like_match_arms(
564        &e.variants,
565        &factory_state_enum_type,
566        &factory_state_enum_type,
567        &quote! { ref },
568        |expr| quote! { ::core::clone::Clone::clone(#expr) },
569    );
570    let factory_to_self_transform = common::clone_like_match_arms(
571        &e.variants,
572        &factory_state_enum_type,
573        ctx.item_type.path()?,
574        &quote! {},
575        |expr| quote! { #exhaust_crate_path::Exhaust::from_factory(#expr) },
576    );
577
578    let (impl_generics, ty_generics, augmented_where_predicates) =
579        ctx.generics_with_bounds(syn::parse_quote! {});
580
581    let impls = ctx.impl_iterator_and_factory_traits(
582        quote! {
583            'variants: loop {
584                break 'variants match &mut self.0 {
585                    #( #variant_next_arms , )*
586                    #iter_state_enum_type::#done_variant => ::core::option::Option::None,
587                }
588            }
589        },
590        quote! {
591            Self(#first_state_variant_initializer)
592        },
593        quote! {
594            Self(match &self.0 {
595                #( #state_enum_variant_cloners , )*
596            })
597        },
598    );
599
600    let iterator_decl = quote! {
601        // Struct that is exposed as the `<Self as Exhaust>::Iter` type.
602        #vis struct #iterator_type_name #ty_generics
603        (#iter_state_enum_type #ty_generics)
604        where #augmented_where_predicates;
605
606        // Struct that is exposed as the `<Self as Exhaust>::Factory` type.
607        #vis struct #factory_outer_type_path #ty_generics (#factory_state_enum_type #ty_generics)
608        where #augmented_where_predicates;
609
610        #impls
611
612        // Enum wrapped in #factory_type_name with the actual data.
613        enum #factory_state_enum_type #ty_generics
614        where #augmented_where_predicates { #( #factory_variant_decls ,)* }
615
616        // A manual impl of Clone is required to *not* have a `Clone` bound on the generics.
617        impl #impl_generics ::core::clone::Clone for #factory_type
618        where #augmented_where_predicates {
619            fn clone(&self) -> Self {
620                #![allow(unreachable_code)] // in case of empty enum
621                Self(match self.0 {
622                    #( #factory_enum_variant_clone_arms , )*
623                })
624            }
625        }
626
627        enum #iter_state_enum_type #ty_generics
628        where #augmented_where_predicates
629        {
630            #( #state_enum_variant_decls , )*
631        }
632    };
633
634    let from_factory_body = quote! {
635        match factory.0 {
636            #( #factory_to_self_transform , )*
637        }
638    };
639
640    Ok((iterator_decl, from_factory_body))
641}