Skip to main content

auto_default/
lib.rs

1#![doc = concat!("[![crates.io](https://img.shields.io/crates/v/", env!("CARGO_PKG_NAME"), "?style=flat-square&logo=rust)](https://crates.io/crates/", env!("CARGO_PKG_NAME"), ")")]
2#![doc = concat!("[![docs.rs](https://img.shields.io/docsrs/", env!("CARGO_PKG_NAME"), "?style=flat-square&logo=docs.rs)](https://docs.rs/", env!("CARGO_PKG_NAME"), ")")]
3#![doc = "![license](https://img.shields.io/badge/license-Apache--2.0_OR_MIT-blue?style=flat-square)"]
4//! ![msrv](https://img.shields.io/badge/msrv-nightly-blue?style=flat-square&logo=rust)
5//! [![github](https://img.shields.io/github/stars/nik-rev/auto-default)](https://github.com/nik-rev/auto-default)
6//!
7//! This crate provides an attribute macro `#[auto_default]`, which adds a default field value of
8//! `Default::default()` to fields that do not have one.
9//!
10//! ```toml
11#![doc = concat!(env!("CARGO_PKG_NAME"), " = ", "\"", env!("CARGO_PKG_VERSION_MAJOR"), ".", env!("CARGO_PKG_VERSION_MINOR"), "\"")]
12//! ```
13//!
14//! Note: `auto-default` has *zero* dependencies. Not even `syn`! The compile times are very fast.
15//!
16//! # Showcase
17//!
//! Rust's [default field values](https://github.com/rust-lang/rust/issues/132162) allow
//! the shorthand `Struct { field, .. }` instead of the lengthy `Struct { field, ..Default::default() }`.
20//!
21//! For `..` instead of `..Default::default()` to work,
22//! your `Struct` needs **all** fields to have a default value.
23//!
//! This often means `= Default::default()` boilerplate on every field, because it is
//! very common to want field defaults to be the value of their `Default` implementation.
26//!
27//! ## Before
28//!
29//! ```rust
30//! # #![feature(default_field_values)]
31//! # #![feature(const_trait_impl)]
32//! # #![feature(const_default)]
33//! # #![feature(derive_const)]
34//! # use auto_default::auto_default;
35//! # #[derive_const(Default)]
36//! # struct Rect { value: f32 }
37//! # #[derive_const(Default)]
38//! # struct Size { value: f32 }
39//! # #[derive_const(Default)]
40//! # struct Point { value: f32 }
41//! #[derive(Default)]
42//! pub struct Layout {
43//!     order: u32 = Default::default(),
44//!     location: Point = Default::default(),
45//!     size: Size = Default::default(),
46//!     content_size: Size = Default::default(),
47//!     scrollbar_size: Size = Default::default(),
48//!     border: Rect = Default::default(),
49//!     padding: Rect = Default::default(),
50//!     margin: Rect = Default::default(),
51//! }
52//! ```
53//!
54//! ## With `#[auto_default]`
55//!
56//! ```rust
57//! # #![feature(default_field_values)]
58//! # #![feature(const_trait_impl)]
59//! # #![feature(const_default)]
60//! # #![feature(derive_const)]
61//! # use auto_default::auto_default;
62//! # #[derive_const(Default)]
63//! # struct Rect { value: f32 }
64//! # #[derive_const(Default)]
65//! # struct Size { value: f32 }
66//! # #[derive_const(Default)]
67//! # struct Point { value: f32 }
68//! #[auto_default]
69//! #[derive(Default)]
70//! pub struct Layout {
71//!     order: u32,
72//!     location: Point,
73//!     size: Size,
74//!     content_size: Size,
75//!     scrollbar_size: Size,
76//!     border: Rect,
77//!     padding: Rect,
78//!     margin: Rect,
79//! }
80//! ```
81//!
82//! You can apply the [`#[auto_default]`](macro@auto_default) macro to `struct`s with named fields, and `enum`s.
83//!
//! If a field has the `#[auto_default(skip)]` attribute, a default field value of `Default::default()`
//! will **not** be added to it. If an enum variant has it, none of that variant's fields get one.
86//!
87//! # `#[auto_default(Option)]`
88//!
89//! By default, a default field value will be added to every field without one.
90//!
91//! If `#[auto_default(Option)]` is used, a default field value will be added only to fields
92//! that are `Option<T>`.
93//!
94//! # Global Import
95//!
96//! This will make `#[auto_default]` globally accessible in your entire crate, without needing to import it:
97//!
98//! ```
99//! #[macro_use(auto_default)]
100//! extern crate auto_default;
101//! ```
102use std::iter::Peekable;
103
104use proc_macro::Delimiter;
105use proc_macro::Group;
106use proc_macro::Ident;
107use proc_macro::Literal;
108use proc_macro::Punct;
109use proc_macro::Spacing;
110use proc_macro::Span;
111use proc_macro::TokenStream;
112use proc_macro::TokenTree;
113
114/// Adds a default field value of `Default::default()` to fields that don't have one
115///
116/// # Example
117///
118/// Turns this:
119///
120/// ```rust
121/// # #![feature(default_field_values)]
122/// # #![feature(const_trait_impl)]
123/// # #![feature(const_default)]
124/// #[auto_default]
125/// struct User {
126///     age: u8,
127///     is_admin: bool = false
128/// }
129/// # use auto_default::auto_default;
130/// ```
131///
132/// Into this:
133///
134/// ```rust
135/// # #![feature(default_field_values)]
136/// # #![feature(const_trait_impl)]
137/// # #![feature(const_default)]
138/// struct User {
139///     age: u8 = Default::default(),
140///     is_admin: bool = false
141/// }
142/// ```
143///
144/// This macro applies to `struct`s with named fields, and enums.
145///
146/// # Do not add `= Default::default()` field value to select fields
147///
148/// If you do not want a specific field to have a default, you can opt-out
149/// with `#[auto_default(skip)]`:
150///
151/// ```rust
152/// # #![feature(default_field_values)]
153/// # #![feature(const_trait_impl)]
154/// # #![feature(const_default)]
155/// #[auto_default]
156/// struct User {
157///     #[auto_default(skip)]
158///     age: u8,
159///     is_admin: bool
160/// }
161/// # use auto_default::auto_default;
162/// ```
163///
164/// The above is transformed into this:
165///
166/// ```rust
167/// # #![feature(default_field_values)]
168/// # #![feature(const_trait_impl)]
169/// # #![feature(const_default)]
170/// struct User {
171///     age: u8,
172///     is_admin: bool = Default::default()
173/// }
174/// ```
175#[proc_macro_attribute]
176pub fn auto_default(args: TokenStream, input: TokenStream) -> TokenStream {
177    let mut compile_errors = TokenStream::new();
178
179    let mut args = args.into_iter();
180
181    // If `#[auto_default(Option)]` is specified, only add a default field value to
182    // fields that are of the type `Option<T>`
183    let is_only_option = match args.next() {
184        Some(TokenTree::Ident(ident)) if ident.to_string() == "Option" => {
185            if let Some(tt) = args.next() {
186                compile_errors.extend(CompileError::new(tt.span(), "unexpected token"));
187            }
188            true
189        }
190        Some(tt) => {
191            compile_errors.extend(CompileError::new(tt.span(), "unexpected token"));
192            return compile_errors;
193        }
194        None => false,
195    };
196    let is_only_option = IsOnlyOption(is_only_option);
197
198    // Input supplied by the user. All tokens from here will
199    // get sent back to `output`
200    let mut source = input.into_iter().peekable();
201
202    // We collect all tokens into here and then return this
203    let mut sink = TokenStream::new();
204
205    // #[derive(Foo)]
206    // pub(in crate) struct Foo
207    stream_attrs(
208        &mut source,
209        &mut sink,
210        &mut compile_errors,
211        // no skip allowed on the container, would make no sense
212        // (just don't use the `#[auto_default]` at all at that point!)
213        IsSkipAllowed(false),
214    );
215
216    // pub(in crate) struct Foo
217    // ^^^^^^^^^^^^^
218    stream_vis(&mut source, &mut sink);
219
220    // pub(in crate) struct Foo
221    //               ^^^^^^
222    let item_kind = match source.next() {
223        Some(TokenTree::Ident(kw)) if kw.to_string() == "struct" => {
224            sink.extend([kw]);
225            ItemKind::Struct
226        }
227        Some(TokenTree::Ident(kw)) if kw.to_string() == "enum" => {
228            sink.extend([kw]);
229            ItemKind::Enum
230        }
231        tt => {
232            compile_errors.extend(create_compile_error!(
233                tt,
234                "expected a `struct` or an `enum`"
235            ));
236            return compile_errors;
237        }
238    };
239
240    // struct Foo
241    //        ^^^
242    let item_ident_span = stream_ident(&mut source, &mut sink)
243        .expect("`struct` or `enum` keyword is always followed by an identifier");
244
245    // Generics
246    //
247    // struct Foo<Bar, Baz: Trait> where Baz: Quux { ... }
248    //           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
249    let source_item_fields = loop {
250        match source.next() {
251            // Fields of the struct
252            Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => break group,
253            // This token is part of the generics of the struct
254            Some(tt) => sink.extend([tt]),
255            // reached end of input
256            None => {
257                // note: if enum, this is unreachable because `enum Foo` is invalid (requires `{}`),
258                // whilst `struct Foo;` is completely valid
259                compile_errors.extend(CompileError::new(
260                    item_ident_span,
261                    "expected struct with named fields",
262                ));
263                return compile_errors;
264            }
265        }
266    };
267
268    match item_kind {
269        ItemKind::Struct => {
270            sink.extend([add_default_field_values(
271                source_item_fields,
272                &mut compile_errors,
273                // none of the fields are considered to be skipped initially
274                IsSkip(false),
275                is_only_option,
276            )]);
277        }
278        ItemKind::Enum => {
279            let mut source_variants = source_item_fields.stream().into_iter().peekable();
280            let mut sink_variants = TokenStream::new();
281
282            loop {
283                // if this variant is marked #[auto_default(skip)]
284                let is_skip = stream_attrs(
285                    &mut source_variants,
286                    &mut sink_variants,
287                    &mut compile_errors,
288                    // can skip the variant, which removes auto-default for all
289                    // fields
290                    IsSkipAllowed(true),
291                );
292
293                // variants technically can have visibility, at least on a syntactic level
294                //
295                // pub Variant {  }
296                // ^^^
297                stream_vis(&mut source_variants, &mut sink_variants);
298
299                // Variant {  }
300                // ^^^^^^^
301                let Some(variant_ident_span) =
302                    stream_ident(&mut source_variants, &mut sink_variants)
303                else {
304                    // that means we have an enum with no variants, e.g.:
305                    //
306                    // enum Never {}
307                    //
308                    // When we parse the variants, there won't be an identifier
309                    break;
310                };
311
312                // only variants with named fields can be marked `#[auto_default(skip)]`
313                let mut disallow_skip = || {
314                    if is_skip.0 {
315                        compile_errors.extend(CompileError::new(
316                            variant_ident_span,
317                            concat!(
318                                "`#[auto_default(skip)]` is",
319                                " only allowed on variants with named fields"
320                            ),
321                        ));
322                    }
323                };
324
325                match source_variants.peek() {
326                    // Enum variant with named fields. Add default field values.
327                    Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => {
328                        let Some(TokenTree::Group(named_variant_fields)) = source_variants.next()
329                        else {
330                            unreachable!()
331                        };
332                        sink_variants.extend([add_default_field_values(
333                            named_variant_fields,
334                            &mut compile_errors,
335                            is_skip,
336                            is_only_option,
337                        )]);
338
339                        stream_enum_variant_discriminant_and_comma(
340                            &mut source_variants,
341                            &mut sink_variants,
342                        );
343                    }
344                    // Enum variant with unnamed fields.
345                    Some(TokenTree::Group(group))
346                        if group.delimiter() == Delimiter::Parenthesis =>
347                    {
348                        disallow_skip();
349                        let Some(TokenTree::Group(unnamed_variant_fields)) = source_variants.next()
350                        else {
351                            unreachable!()
352                        };
353                        sink_variants.extend([unnamed_variant_fields]);
354
355                        stream_enum_variant_discriminant_and_comma(
356                            &mut source_variants,
357                            &mut sink_variants,
358                        );
359                    }
360                    // This was a unit variant. Next variant may exist,
361                    // if it does it is parsed on next iteration
362                    Some(TokenTree::Punct(punct))
363                        if punct.as_char() == ',' || punct.as_char() == '=' =>
364                    {
365                        disallow_skip();
366                        stream_enum_variant_discriminant_and_comma(
367                            &mut source_variants,
368                            &mut sink_variants,
369                        );
370                    }
371                    // Unit variant, with no comma at the end. This is the last variant
372                    None => {
373                        disallow_skip();
374                        break;
375                    }
376                    Some(_) => unreachable!(),
377                }
378            }
379
380            let mut sink_variants = Group::new(source_item_fields.delimiter(), sink_variants);
381            sink_variants.set_span(source_item_fields.span());
382            sink.extend([sink_variants]);
383        }
384    }
385
386    sink.extend(compile_errors);
387
388    sink
389}
390
/// Whether `#[auto_default(skip)]` was found on the thing currently being
/// processed: a single field, or an entire enum variant.
struct IsSkip(bool);

/// Whether the macro was invoked as `#[auto_default(Option)]`, in which case
/// only fields of type `Option<T>` get a default field value.
#[derive(Copy, Clone)]
struct IsOnlyOption(bool);

/// Whether `#[auto_default(skip)]` may appear at the current position.
/// It is rejected on the container (the `struct`/`enum` itself).
struct IsSkipAllowed(bool);
396
397/// Streams enum variant discriminant + comma at the end from `source` into `sink`
398///
399/// enum Example {
400///     Three,
401///          ^
402///     Two(u32) = 2,
403///             ^^^^^
404///     Four { hello: u32 } = 4,
405///                        ^^^^^
406/// }
407fn stream_enum_variant_discriminant_and_comma(source: &mut Source, sink: &mut Sink) {
408    match source.next() {
409        // No discriminant, there may be another variant after this
410        Some(TokenTree::Punct(punct)) if punct.as_char() == ',' => {
411            sink.extend([punct]);
412        }
413        // No discriminant, this is the final enum variant
414        None => {}
415        // Enum variant has a discriminant
416        Some(TokenTree::Punct(punct)) if punct.as_char() == '=' => {
417            sink.extend([punct]);
418
419            // Stream discriminant expression from `source` into `sink`
420            loop {
421                match source.next() {
422                    // End of discriminant, there may be a variant after this
423                    Some(TokenTree::Punct(punct)) if punct.as_char() == ',' => {
424                        sink.extend([punct]);
425                        break;
426                    }
427                    // This token is part of the variant's expression
428                    Some(tt) => {
429                        sink.extend([tt]);
430                    }
431                    // End of discriminant, this is the last variant
432                    None => break,
433                }
434            }
435        }
436        Some(_) => unreachable!(),
437    }
438}
439
/// Token iterator the parsing helpers consume from (with one-token lookahead).
type Source = Peekable<<TokenStream as IntoIterator>::IntoIter>;
/// Token stream the parsing helpers append their output to.
type Sink = TokenStream;
442
443/// Streams the identifier from `input` into `output`, returning its span, if the identifier exists
444fn stream_ident(source: &mut Source, sink: &mut Sink) -> Option<Span> {
445    let ident = source.next()?;
446    let span = ident.span();
447    sink.extend([ident]);
448    Some(span)
449}
450
451// Parses attributes
452//
453// #[attr] #[attr] pub field: Type
454// #[attr] #[attr] struct Foo
455// #[attr] #[attr] enum Foo
456//
457// Returns `true` if `#[auto_default(skip)]` was encountered
458fn stream_attrs(
459    source: &mut Source,
460    sink: &mut Sink,
461    errors: &mut TokenStream,
462    is_skip_allowed: IsSkipAllowed,
463) -> IsSkip {
464    let mut is_skip = None;
465
466    let is_skip = loop {
467        if !matches!(source.peek(), Some(TokenTree::Punct(hash)) if *hash == '#') {
468            break is_skip;
469        };
470
471        // #[some_attr]
472        // ^
473        let pound = source.next();
474
475        // #[some_attr]
476        //  ^^^^^^^^^^^
477        let Some(TokenTree::Group(attr)) = source.next() else {
478            unreachable!()
479        };
480
481        // #[some_attr = hello]
482        //   ^^^^^^^^^^^^^^^^^
483        let mut attr_tokens = attr.stream().into_iter().peekable();
484
485        // Check if this attribute is `#[auto_default(skip)]`
486        if let Some(skip_span) = is_skip_attribute(&mut attr_tokens, errors) {
487            if is_skip.is_some() {
488                // Disallow 2 attributes on a single field:
489                //
490                // #[auto_default(skip)]
491                // #[auto_default(skip)]
492                errors.extend(CompileError::new(
493                    skip_span,
494                    "duplicate `#[auto_default(skip)]`",
495                ));
496            } else {
497                is_skip = Some(skip_span);
498            }
499            continue;
500        }
501
502        // #[attr]
503        // ^
504        sink.extend(pound);
505
506        // Re-construct the `[..]` for the attribute
507        //
508        // #[attr]
509        //  ^^^^^^
510        let mut group = Group::new(attr.delimiter(), attr_tokens.collect());
511        group.set_span(attr.span());
512
513        // #[attr]
514        //  ^^^^^^
515        sink.extend([group]);
516    };
517
518    if let Some(skip_span) = is_skip
519        && !is_skip_allowed.0
520    {
521        errors.extend(CompileError::new(
522            skip_span,
523            "`#[auto_default(skip)]` is not allowed on container",
524        ));
525    }
526
527    IsSkip(is_skip.is_some())
528}
529
530/// if `source` is exactly `auto_default(skip)`, returns `Some(span)`
531/// with `span` being the `Span` of the `skip` identifier
532fn is_skip_attribute(source: &mut Source, errors: &mut TokenStream) -> Option<Span> {
533    let Some(TokenTree::Ident(ident)) = source.peek() else {
534        return None;
535    };
536
537    if ident.to_string() != "auto_default" {
538        return None;
539    };
540
541    // #[auto_default(skip)]
542    //   ^^^^^^^^^^^^
543    let ident = source.next().unwrap();
544
545    // We know it is `#[auto_default ???]`, we need to validate that `???`
546    // is exactly `(skip)` now
547
548    // #[auto_default(skip)]
549    //   ^^^^^^^^^^^^
550    let auto_default_span = ident.span();
551
552    // #[auto_default(skip)]
553    //               ^^^^^^
554    let group = match source.next() {
555        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => group,
556        Some(tt) => {
557            errors.extend(CompileError::new(tt.span(), "expected `(skip)`"));
558            return None;
559        }
560        None => {
561            errors.extend(CompileError::new(
562                auto_default_span,
563                "expected `(skip)` after this",
564            ));
565            return None;
566        }
567    };
568
569    // #[auto_default(skip)]
570    //                ^^^^
571    let mut inside = group.stream().into_iter();
572
573    // #[auto_default(skip)]
574    //                ^^^^
575    let ident_skip = match inside.next() {
576        Some(TokenTree::Ident(ident)) => ident,
577        Some(tt) => {
578            errors.extend(CompileError::new(tt.span(), "expected `skip`"));
579            return None;
580        }
581        None => {
582            errors.extend(CompileError::new(
583                group.span(),
584                "expected `(skip)`, found `()`",
585            ));
586            return None;
587        }
588    };
589
590    if ident_skip.to_string() != "skip" {
591        errors.extend(CompileError::new(ident_skip.span(), "expected `skip`"));
592        return None;
593    }
594
595    // Validate that there's nothing after `skip`
596    //
597    // #[auto_default(skip    )]
598    //                    ^^^^
599    if let Some(tt) = inside.next() {
600        errors.extend(CompileError::new(tt.span(), "unexpected token"));
601        return None;
602    }
603
604    Some(ident_skip.span())
605}
606
607fn stream_vis(source: &mut Source, sink: &mut Sink) {
608    // Remove visibility if it is present
609    //
610    // pub(in crate) struct
611    // ^^^^^^^^^^^^^
612    if let Some(TokenTree::Ident(vis)) = source.peek()
613        && vis.to_string() == "pub"
614    {
615        // pub(in crate) struct
616        // ^^^
617        sink.extend(source.next());
618
619        if let Some(TokenTree::Group(group)) = source.peek()
620            && let Delimiter::Parenthesis = group.delimiter()
621        {
622            // pub(in crate) struct
623            //    ^^^^^^^^^^
624            sink.extend(source.next());
625        }
626    };
627}
628
/// The kind of item `#[auto_default]` was applied to, taken from its keyword.
#[derive(PartialEq)]
enum ItemKind {
    /// `struct`: its brace-delimited named fields get default values
    Struct,
    /// `enum`: each variant with named fields gets default values
    Enum,
}
634
635/// `fields` is [`StructFields`] in Rust's grammar.
636///
637/// [`StructFields`]: https://doc.rust-lang.org/reference/items/structs.html#grammar-StructFields
638///
639/// It is the curly braces, and everything within, for a struct with named fields,
640/// or an enum variant with named fields.
641///
642/// These fields are transformed by adding `= Default::default()` to every
643/// field that doesn't already have a default value.
644///
645/// If a field is marked with `#[auto_default(skip)]`, no default value will be
646/// added
647///
648/// # `#[auto_default(Option)]`
649///
650/// If `is_only_option` is true, then no default value will be added,
651/// unless:
652///
653/// - Last segment of the field's type's path is `Option`
654/// - The field's type is generic (has `<` and `>`)
655/// - The field's type has only 1 generic type parameter
656fn add_default_field_values(
657    fields: Group,
658    compile_errors: &mut TokenStream,
659    is_skip_variant: IsSkip,
660    is_only_option: IsOnlyOption,
661) -> Group {
662    // All the tokens corresponding to the struct's field, passed by the user
663    // These tokens will eventually all be sent to `output_fields`,
664    // plus a few extra for any `Default::default()` that we output
665    let mut input_fields = fields.stream().into_iter().peekable();
666
667    // The tokens corresponding to the fields of the output struct
668    let mut output_fields = TokenStream::new();
669
670    // Parses all fields.
671    // Each iteration parses a single field
672    'parse_field: loop {
673        // #[serde(rename = "x")] pub field: Type
674        // ^^^^^^^^^^^^^^^^^^^^^^
675        let is_skip_field = stream_attrs(
676            &mut input_fields,
677            &mut output_fields,
678            compile_errors,
679            IsSkipAllowed(true),
680        );
681
682        // If #[auto_default(skip)] is present
683        let is_skip = is_skip_field.0 || is_skip_variant.0;
684
685        // If this is `true`, no default field value will be added
686        //
687        // mut: This is set to `true` if `is_only_option` is true and
688        // we discover that this field is of type `Option<T>`
689        //
690        // We start out by considering every field as skipped if it isn't in an
691        // Option. Only if we confirm it to be in an Option, do we add a default field value
692        let mut add_default_field_value = if is_only_option.0 { false } else { !is_skip };
693
694        // pub field: Type
695        // ^^^
696        stream_vis(&mut input_fields, &mut output_fields);
697
698        // If the type path ends in an `Option`.
699        //
700        // ```txt
701        // ::core::option::Option
702        //                 ^^^^^^
703        // ```
704        let mut is_potentially_option = false;
705
706        // pub field: Type
707        //     ^^^^^
708        let Some(field_ident_span) = stream_ident(&mut input_fields, &mut output_fields) else {
709            // No fields. e.g.: `struct Struct {}`
710            break;
711        };
712
713        // field: Type
714        //      ^
715        output_fields.extend(input_fields.next());
716
717        // Everything after the `:` in the field
718        //
719        // Involves:
720        //
721        // - Adding default value of `= Default::default()` if one is not present
722        // - Continue to next iteration of the loop
723        loop {
724            let mut validate_attr = || {
725                // Only error if the skip was REDUNDANT.
726                //
727                // Option-only mode is on, but this isn't an option.
728                if is_only_option.0 && !is_potentially_option && is_skip_field.0 {
729                    compile_errors.extend(CompileError::new(
730                            field_ident_span,
731                            "this field is marked `#[auto_default(skip)]`, which does nothing since this item is marked `#[auto_default(Option)]` and this field doesn't appear to be an `Option<T>`"
732                        ));
733                }
734                // The variant/struct itself was already skipped.
735                else if is_skip_variant.0 && is_skip_field.0 {
736                    compile_errors.extend(CompileError::new(
737                            field_ident_span,
738                            "this field is marked `#[auto_default(skip)]`, which is redundant because the entire variant/struct is already skipped"
739                        ));
740                }
741            };
742            match input_fields.peek() {
743                // This field has a custom default field value
744                //
745                // field: Type = default
746                //             ^
747                Some(TokenTree::Punct(p)) if p.as_char() == '=' => {
748                    if is_skip_field.0 {
749                        compile_errors.extend(CompileError::new(
750                                field_ident_span,
751                                "this field is marked `#[auto_default(skip)]`, which does nothing since this field has a default value: `= ...`"
752                            ));
753                    }
754
755                    // Take all tokens representing the default field value from
756                    // `input_fields` and move them into `output_fields`
757                    loop {
758                        match input_fields.next() {
759                            // Comma after field. Field is finished.
760                            Some(TokenTree::Punct(p)) if p == ',' => {
761                                output_fields.extend([p]);
762                                continue 'parse_field;
763                            }
764                            // This token is part of the field's default value
765                            //
766                            // foo = Some(42)
767                            //       ^^^^
768                            Some(tt) => output_fields.extend([tt]),
769                            // End of input. Field is finished. This is the last field
770                            None => {
771                                break 'parse_field;
772                            }
773                        }
774                    }
775                }
776                // Reached end of field, has comma at the end, no custom default value
777                //
778                // field: Type,
779                //            ^
780                //
781                // OR, this comma is part of the field's type:
782                //
783                // field: HashMap<u8, u8>,
784                //                  ^
785                Some(TokenTree::Punct(p)) if p.as_char() == ',' => {
786                    let comma = input_fields.next().expect("match on `Some`");
787
788                    /// What does this comma represent?
789                    #[derive(Debug)]
790                    enum CommaStatus {
791                        /// End of field
792                        EndOfField0,
793                        /// End of field, with 2 tokens to insert after end of the field
794                        EndOfField1(TokenTree),
795                        /// Part of type
796                        PartOfType0,
797                        /// Part of type, with 1 token to insert after the comma
798                        PartOfType1(TokenTree),
799                    }
800
801                    // When we encounter a COMMA token while parsing
802                    let comma_status = match input_fields.peek() {
803                        // pub field: Type
804                        // ^^^
805                        Some(TokenTree::Ident(ident)) if ident.to_string() == "pub" => {
806                            // this `comma` is end of the field
807                            //
808                            //     field: HashMap<u8, u8>,
809                            //                           ^ `comma`
810                            // pub next_field: Type
811                            // ^^^ `ident`
812                            CommaStatus::EndOfField0
813                        }
814
815                        // #[foo(bar)] pub field: Type
816                        // ^
817                        Some(TokenTree::Punct(punct)) if *punct == '#' => {
818                            // this `comma` is end of the field
819                            //
820                            //             field: HashMap<u8, u8>,
821                            //                                   ^ this
822                            // #[foo(bar)] next_field: HashMap<u8, u8>,
823                            CommaStatus::EndOfField0
824                        }
825
826                        // field: Type
827                        // ^^^^^
828                        Some(TokenTree::Ident(_)) => {
829                            let field_or_ty_ident = input_fields.next().expect("match on `Some`");
830
831                            match input_fields.peek() {
832                                // field: Type
833                                //      ^
834                                Some(TokenTree::Punct(punct)) if *punct == ':' => {
835                                    let field_ident = field_or_ty_ident;
836
837                                    CommaStatus::EndOfField1(field_ident)
838                                }
839                                // This identifier is part of the type
840                                //
841                                // field: HashMap<u8, u8>,
842                                //                    ^^
843                                _ => {
844                                    let ty_ident = field_or_ty_ident;
845                                    CommaStatus::PartOfType1(ty_ident)
846                                }
847                            }
848                        }
849
850                        // This comma is part of a type, NOT end of field!
851                        //
852                        // pub field: HashMap<String, String>
853                        //                          ^
854                        Some(_) => CommaStatus::PartOfType0,
855
856                        // Reached end of input. This comma is end of the field
857                        //
858                        // field: Type,
859                        //            ^
860                        None => CommaStatus::EndOfField0,
861                    };
862
863                    let insert_extra_token = match comma_status {
864                        CommaStatus::EndOfField0 => None,
865                        CommaStatus::EndOfField1(token_tree) => Some(token_tree),
866                        CommaStatus::PartOfType0 => {
867                            output_fields.extend([comma]);
868                            continue;
869                        }
870                        CommaStatus::PartOfType1(token_tree) => {
871                            output_fields.extend([comma]);
872                            output_fields.extend([token_tree]);
873                            continue;
874                        }
875                    };
876
877                    // Now that we're here, we can be 100% sure that the `comma` we have
878                    // is at the END of the field
879
880                    if is_skip && !add_default_field_value {
881                        validate_attr();
882                    }
883
884                    // Insert default value before the comma
885                    //
886                    // field: Type = Default::default(),
887                    //             ^^^^^^^^^^^^^^^^^^^^
888                    if add_default_field_value {
889                        output_fields.extend(default(field_ident_span));
890                    }
891
892                    // field: Type = Default::default(),
893                    //                                 ^
894                    output_fields.extend([comma]);
895
896                    if let Some(token_tree) = insert_extra_token {
897                        // Insert the extra token which we needed to take when figuring out if this
898                        // comma is part of the type, or end of the field
899
900                        output_fields.extend([token_tree]);
901                    }
902
903                    // Next iteration handles the next field
904                    continue 'parse_field;
905                }
906                // This token is part of the field's type
907                //
908                // field: some::Option
909                //              ^^^^^^
910                Some(TokenTree::Ident(ident))
911                    if is_only_option.0
912                        && ident.to_string() == "Option"
913                        && !is_potentially_option =>
914                {
915                    is_potentially_option = true;
916                    output_fields.extend(input_fields.next())
917                }
918                // This isn't actually an Option.
919                //
920                // field: some::Option::Type
921                //                      ^^^^
922                Some(TokenTree::Ident(ident)) if is_only_option.0 && is_potentially_option => {
923                    is_potentially_option = false;
924                    output_fields.extend(input_fields.next())
925                }
926                // This token is part of the field's type
927                //
928                // field: some::Option<
929                //                    ^
930                Some(TokenTree::Punct(punct))
931                    if is_only_option.0 && punct.as_char() == '<' && is_potentially_option =>
932                {
933                    output_fields.extend(input_fields.next());
934
935                    // The nesting level of angle brackets that we are currently in
936                    //
937                    //     Option<HashMap<u32, u32>, String>
938                    //        ^ 0
939                    //              ^ 1
940                    //                      ^ 2
941                    //                            ^ 1
942                    //                                ^ 1
943                    //                                      ^ 1
944                    //
945                    //
946                    // INVARIANT: We assume that the input TokenStream contains a balanced amount of
947                    // '<' and '>', as is the case for rust's types (when not descending into nested TokenStreams
948                    // e.g. `foo::<{ 1 < 2 }>`). only has the top level tokens `foo`, `:`, `:`, `<`, `{...}`, `>`, which
949                    // has a balanced amount of angle brackets.
950                    let mut nesting_level = 1;
951
952                    // If we've seen a ',' token at the top-level
953                    //
954                    // Option<T,>
955                    //         ^ true
956                    //
957                    // First comma in the flat list of tokens is not at the top-level:
958                    //
959                    // Option<HashMap<u32, u32>, String>
960                    //                   ^ false
961                    //                         ^ true
962                    let mut seen_comma_at_top_level = false;
963
964                    // If the top-level generics has a 2nd type parameter
965                    //
966                    // false:
967                    //
968                    //     Option<T,>
969                    //
970                    // false:
971                    //
972                    //     Option<HashMap<u32, u32>>
973                    //
974                    // true:
975                    //
976                    //     Option<u32, u8>
977                    //
978                    // Can only be `true` if `seen_comma_at_the_top_level` is true.
979                    let mut has_second_type_parameter_at_top_level = false;
980
981                    loop {
982                        let is_top_level = nesting_level == 1;
983
984                        match input_fields.peek() {
985                            Some(TokenTree::Punct(p)) if p.as_char() == '<' => {
986                                nesting_level += 1;
987                                output_fields.extend(input_fields.next());
988                            }
989                            Some(TokenTree::Punct(p)) if p.as_char() == '>' => {
990                                if is_top_level {
991                                    output_fields.extend(input_fields.next());
992                                    // Closing the generic type of the Option.
993                                    break;
994                                } else {
995                                    nesting_level -= 1;
996                                    output_fields.extend(input_fields.next());
997                                }
998                            }
999                            Some(TokenTree::Punct(p)) if p.as_char() == ',' && is_top_level => {
1000                                seen_comma_at_top_level = true;
1001                                output_fields.extend(input_fields.next());
1002                            }
1003                            // type parameter or identifier following a comma at top level
1004                            Some(_) if seen_comma_at_top_level && is_top_level => {
1005                                has_second_type_parameter_at_top_level = true;
1006                                output_fields.extend(input_fields.next());
1007                            }
1008                            // any other token (ident, literal, punct that isn't <, >, or top-level ,)
1009                            Some(_) => {
1010                                output_fields.extend(input_fields.next());
1011                            }
1012                            None => break,
1013                        }
1014                    }
1015
1016                    // if the type looks something like `Option<T>`
1017                    let is_option = !has_second_type_parameter_at_top_level;
1018
1019                    if is_option {
1020                        // it is likely a standard Option<T>.
1021                        //
1022                        // Only skip if #[auto_default(skip)] was applied.
1023                        add_default_field_value = !is_skip;
1024                    } else {
1025                        // We skip this, because it is NOT an option
1026                        add_default_field_value = false;
1027                    }
1028                }
1029                // This token is part of the field's type
1030                //
1031                // field: some::Type
1032                //              ^^^^
1033                Some(_) => output_fields.extend(input_fields.next()),
1034                // Reached end of input, and it has no comma.
1035                // This is the last field.
1036                //
1037                // struct Foo {
1038                //     field: Type
1039                //                ^
1040                // }
1041                None => {
1042                    if add_default_field_value {
1043                        validate_attr();
1044                        output_fields.extend(default(field_ident_span));
1045                    }
1046                    // No more fields
1047                    break 'parse_field;
1048                }
1049            }
1050        }
1051    }
1052    let mut g = Group::new(Delimiter::Brace, output_fields);
1053    g.set_span(fields.span());
1054    g
1055}
1056
1057// = ::core::default::Default::default()
1058fn default(span: Span) -> [TokenTree; 14] {
1059    [
1060        TokenTree::Punct(Punct::new('=', Spacing::Alone)),
1061        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1062        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1063        TokenTree::Ident(Ident::new("core", span)),
1064        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1065        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1066        TokenTree::Ident(Ident::new("default", span)),
1067        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1068        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1069        TokenTree::Ident(Ident::new("Default", span)),
1070        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1071        TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
1072        TokenTree::Ident(Ident::new("default", span)),
1073        TokenTree::Group(Group::new(Delimiter::Parenthesis, TokenStream::new())).with_span(span),
1074    ]
1075}
1076
/// Builds a [`CompileError`] from an optional spanned value and `format!`-style arguments.
///
/// The first argument must evaluate to an `Option` of something with a `.span()` method.
/// - `Some(x)` reports the error at `x.span()`.
/// - `None` reports it at `Span::call_site()`.
///
/// All remaining tokens are forwarded unchanged to `format!` to build the message.
macro_rules! create_compile_error {
    ($spanned:expr, $($tt:tt)*) => {{
        let span = match $spanned {
            Some(spanned) => spanned.span(),
            None => Span::call_site(),
        };
        CompileError::new(span, format!($($tt)*))
    }};
}
// Put the `macro_rules!` macro into the module's item namespace so it can also be named by path.
use create_compile_error;
1088
/// A pending `compile_error!` diagnostic.
///
/// Iterating it (through its `IntoIterator` impl) yields the tokens of
/// `::core::compile_error! { "<message>" }`, all spanned at `span`, ready to be spliced into
/// a macro's output.
struct CompileError {
    /// Where the compile error is reported
    pub span: Span,
    /// Message of the compile error
    pub message: String,
}
1096
1097impl CompileError {
1098    /// Create a new compile error
1099    pub fn new(span: Span, message: impl AsRef<str>) -> Self {
1100        Self {
1101            span,
1102            message: message.as_ref().to_string(),
1103        }
1104    }
1105}
1106
1107impl IntoIterator for CompileError {
1108    type Item = TokenTree;
1109    type IntoIter = std::array::IntoIter<Self::Item, 8>;
1110
1111    fn into_iter(self) -> Self::IntoIter {
1112        [
1113            TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
1114            TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
1115            TokenTree::Ident(Ident::new("core", self.span)),
1116            TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
1117            TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
1118            TokenTree::Ident(Ident::new("compile_error", self.span)),
1119            TokenTree::Punct(Punct::new('!', Spacing::Alone)).with_span(self.span),
1120            TokenTree::Group(Group::new(Delimiter::Brace, {
1121                TokenStream::from(
1122                    TokenTree::Literal(Literal::string(&self.message)).with_span(self.span),
1123                )
1124            }))
1125            .with_span(self.span),
1126        ]
1127        .into_iter()
1128    }
1129}
1130
/// Builder-style helpers for [`TokenTree`].
trait TokenTreeExt {
    /// Returns `self` with its span replaced by `span`.
    ///
    /// This lets a token be built and re-spanned in one expression, without a `let mut` binding.
    fn with_span(self, span: Span) -> TokenTree;
}

impl TokenTreeExt for TokenTree {
    fn with_span(self, span: Span) -> TokenTree {
        let mut token = self;
        token.set_span(span);
        token
    }
}