auto_default/lib.rs
//! [crates.io](https://crates.io/crates/auto-default)
//! [docs.rs](https://docs.rs/auto-default)
//! [GitHub](https://github.com/nik-rev/auto-default)
6//!
7//! This crate provides an attribute macro `#[auto_default]`, which adds a default field value of
8//! `Default::default()` to fields that do not have one.
9//!
10//! ```toml
11//! [dependencies]
12//! auto-default = "0.1"
13//! ```
14//!
15//! Note: `auto-default` has *zero* dependencies. Not even `syn`! The compile times are very fast.
16//!
17//! ## Showcase
18//!
//! Rust's [default field values](https://github.com/rust-lang/rust/issues/132162) allow
//! the shorthand `Struct { field, .. }` instead of the lengthy `Struct { field, ..Default::default() }`.
21//!
22//! For `..` instead of `..Default::default()` to work,
23//! your `Struct` needs **all** fields to have a default value.
24//!
//! This often means `= Default::default()` boilerplate on every field, because it is
//! very common to want a field's default to be the value of its type's `Default` implementation.
27//!
28//! ### Before
29//!
30//! ```rust
31//! # #![feature(default_field_values)]
32//! # #![feature(const_trait_impl)]
33//! # #![feature(const_default)]
34//! # #![feature(derive_const)]
35//! # use auto_default::auto_default;
36//! # #[derive_const(Default)]
37//! # struct Rect { value: f32 }
38//! # #[derive_const(Default)]
39//! # struct Size { value: f32 }
40//! # #[derive_const(Default)]
41//! # struct Point { value: f32 }
42//! #[derive(Default)]
43//! pub struct Layout {
44//! order: u32 = Default::default(),
45//! location: Point = Default::default(),
46//! size: Size = Default::default(),
47//! content_size: Size = Default::default(),
48//! scrollbar_size: Size = Default::default(),
49//! border: Rect = Default::default(),
50//! padding: Rect = Default::default(),
51//! margin: Rect = Default::default(),
52//! }
53//! ```
54//!
55//! ### With `#[auto_default]`
56//!
57//! ```rust
58//! # #![feature(default_field_values)]
59//! # #![feature(const_trait_impl)]
60//! # #![feature(const_default)]
61//! # #![feature(derive_const)]
62//! # use auto_default::auto_default;
63//! # #[derive_const(Default)]
64//! # struct Rect { value: f32 }
65//! # #[derive_const(Default)]
66//! # struct Size { value: f32 }
67//! # #[derive_const(Default)]
68//! # struct Point { value: f32 }
69//! #[auto_default]
70//! #[derive(Default)]
71//! pub struct Layout {
72//! order: u32,
73//! location: Point,
74//! size: Size,
75//! content_size: Size,
76//! scrollbar_size: Size,
77//! border: Rect,
78//! padding: Rect,
79//! margin: Rect,
80//! }
81//! ```
82//!
//! You can apply the [`#[auto_default]`](macro@auto_default) macro to `struct`s with named fields and to `enum`s.
84//!
//! A field marked with `#[auto_default(skip)]` does not get a default field value of
//! `Default::default()`. Marking an enum variant with `#[auto_default(skip)]` does the same
//! for all of that variant's fields.
87use std::iter::Peekable;
88
89use proc_macro::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
90
91/// Adds a default field value of `Default::default()` to fields that don't have one
92///
93/// # Example
94///
95/// Turns this:
96///
97/// ```rust
98/// # #![feature(default_field_values)]
99/// # #![feature(const_trait_impl)]
100/// # #![feature(const_default)]
101/// #[auto_default]
102/// struct User {
103/// age: u8,
104/// is_admin: bool = false
105/// }
106/// # use auto_default::auto_default;
107/// ```
108///
109/// Into this:
110///
111/// ```rust
112/// # #![feature(default_field_values)]
113/// # #![feature(const_trait_impl)]
114/// # #![feature(const_default)]
115/// struct User {
116/// age: u8 = Default::default(),
117/// is_admin: bool = false
118/// }
119/// ```
120///
121/// This macro applies to `struct`s with named fields, and enums.
122///
123/// # Do not add `= Default::default()` field value to select fields
124///
125/// If you do not want a specific field to have a default, you can opt-out
126/// with `#[auto_default(skip)]`:
127///
128/// ```rust
129/// # #![feature(default_field_values)]
130/// # #![feature(const_trait_impl)]
131/// # #![feature(const_default)]
132/// #[auto_default]
133/// struct User {
134/// #[auto_default(skip)]
135/// age: u8,
136/// is_admin: bool
137/// }
138/// # use auto_default::auto_default;
139/// ```
140///
141/// The above is transformed into this:
142///
143/// ```rust
144/// # #![feature(default_field_values)]
145/// # #![feature(const_trait_impl)]
146/// # #![feature(const_default)]
147/// struct User {
148/// age: u8,
149/// is_admin: bool = false
150/// }
151/// ```
152#[proc_macro_attribute]
153pub fn auto_default(args: TokenStream, input: TokenStream) -> TokenStream {
154 let mut compile_errors = TokenStream::new();
155
156 if !args.is_empty() {
157 compile_errors.extend(create_compile_error!(
158 args.into_iter().next(),
159 "no arguments expected",
160 ));
161 }
162
163 // Input supplied by the user. All tokens from here will
164 // get sent back to `output`
165 let mut source = input.into_iter().peekable();
166
167 // We collect all tokens into here and then return this
168 let mut sink = TokenStream::new();
169
170 stream_attrs(
171 &mut source,
172 &mut sink,
173 &mut compile_errors,
174 // no skip allowed on the container, would make no sense
175 // (just don't use the `#[auto_default]` at all at that point!)
176 IsSkipAllowed(false),
177 );
178 stream_vis(&mut source, &mut sink);
179
180 // pub(in crate) struct Foo
181 // ^^^^^^
182 let item_kind = match source.next() {
183 Some(TokenTree::Ident(kw)) if kw.to_string() == "struct" => {
184 sink.extend([kw]);
185 ItemKind::Struct
186 }
187 Some(TokenTree::Ident(kw)) if kw.to_string() == "enum" => {
188 sink.extend([kw]);
189 ItemKind::Enum
190 }
191 tt => {
192 compile_errors.extend(create_compile_error!(
193 tt,
194 "expected a `struct` or an `enum`"
195 ));
196 return compile_errors;
197 }
198 };
199
200 // struct Foo
201 // ^^^
202 let item_ident_span = stream_ident(&mut source, &mut sink)
203 .expect("`struct` or `enum` keyword is always followed by an identifier");
204
205 // Generics
206 //
207 // struct Foo<Bar, Baz: Trait> where Baz: Quux { ... }
208 // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
209 let source_item_fields = loop {
210 match source.next() {
211 // Fields of the struct
212 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => break group,
213 // This token is part of the generics of the struct
214 Some(tt) => sink.extend([tt]),
215 // reached end of input
216 None => {
217 // note: if enum, this is unreachable because `enum Foo` is invalid (requires `{}`),
218 // whilst `struct Foo;` is completely valid
219 compile_errors.extend(CompileError::new(
220 item_ident_span,
221 "expected struct with named fields",
222 ));
223 return compile_errors;
224 }
225 }
226 };
227
228 match item_kind {
229 ItemKind::Struct => {
230 sink.extend([add_default_field_values(
231 source_item_fields,
232 &mut compile_errors,
233 // none of the fields are considered to be skipped initially
234 IsSkip(false),
235 )]);
236 }
237 ItemKind::Enum => {
238 let mut source_variants = source_item_fields.stream().into_iter().peekable();
239 let mut sink_variants = TokenStream::new();
240
241 loop {
242 // if this variant is marked #[auto_default(skip)]
243 let is_skip = stream_attrs(
244 &mut source_variants,
245 &mut sink_variants,
246 &mut compile_errors,
247 // can skip the variant, which removes auto-default for all
248 // fields
249 IsSkipAllowed(true),
250 );
251
252 // variants technically can have visibility, at least on a syntactic level
253 //
254 // pub Variant { }
255 // ^^^
256 stream_vis(&mut source_variants, &mut sink_variants);
257
258 // Variant { }
259 // ^^^^^^^
260 let Some(variant_ident_span) =
261 stream_ident(&mut source_variants, &mut sink_variants)
262 else {
263 // that means we have an enum with no variants, e.g.:
264 //
265 // enum Never {}
266 //
267 // When we parse the variants, there won't be an identifier
268 break;
269 };
270
271 // only variants with named fields can be marked `#[auto_default(skip)]`
272 let mut disallow_skip = || {
273 if is_skip.0 {
274 compile_errors.extend(CompileError::new(
275 variant_ident_span,
276 concat!(
277 "`#[auto_default(skip)]` is",
278 " only allowed on variants with named fields"
279 ),
280 ));
281 }
282 };
283
284 match source_variants.peek() {
285 // Enum variant with named fields. Add default field values.
286 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => {
287 let Some(TokenTree::Group(named_variant_fields)) = source_variants.next()
288 else {
289 unreachable!()
290 };
291 sink_variants.extend([add_default_field_values(
292 named_variant_fields,
293 &mut compile_errors,
294 is_skip,
295 )]);
296
297 stream_enum_variant_discriminant_and_comma(
298 &mut source_variants,
299 &mut sink_variants,
300 );
301 }
302 // Enum variant with unnamed fields.
303 Some(TokenTree::Group(group))
304 if group.delimiter() == Delimiter::Parenthesis =>
305 {
306 disallow_skip();
307 let Some(TokenTree::Group(unnamed_variant_fields)) = source_variants.next()
308 else {
309 unreachable!()
310 };
311 sink_variants.extend([unnamed_variant_fields]);
312
313 stream_enum_variant_discriminant_and_comma(
314 &mut source_variants,
315 &mut sink_variants,
316 );
317 }
318 // This was a unit variant. Next variant may exist,
319 // if it does it is parsed on next iteration
320 Some(TokenTree::Punct(punct))
321 if punct.as_char() == ',' || punct.as_char() == '=' =>
322 {
323 disallow_skip();
324 stream_enum_variant_discriminant_and_comma(
325 &mut source_variants,
326 &mut sink_variants,
327 );
328 }
329 // Unit variant, with no comma at the end. This is the last variant
330 None => {
331 disallow_skip();
332 break;
333 }
334 Some(_) => unreachable!(),
335 }
336 }
337
338 let mut sink_variants = Group::new(source_item_fields.delimiter(), sink_variants);
339 sink_variants.set_span(source_item_fields.span());
340 sink.extend([sink_variants]);
341 }
342 }
343
344 sink.extend(compile_errors);
345
346 sink
347}
348
/// Whether `= Default::default()` must *not* be added: the field, or the whole
/// enum variant, was marked `#[auto_default(skip)]`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct IsSkip(bool);

/// Whether `#[auto_default(skip)]` is permitted at the current position
/// (fields and variants: yes; the container itself: no).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct IsSkipAllowed(bool);
351
352/// Streams enum variant discriminant + comma at the end from `source` into `sink`
353///
354/// enum Example {
355/// Three,
356/// ^
357/// Two(u32) = 2,
358/// ^^^^^
359/// Four { hello: u32 } = 4,
360/// ^^^^^
361/// }
362fn stream_enum_variant_discriminant_and_comma(source: &mut Source, sink: &mut Sink) {
363 match source.next() {
364 // No discriminant, there may be another variant after this
365 Some(TokenTree::Punct(punct)) if punct.as_char() == ',' => {
366 sink.extend([punct]);
367 }
368 // No discriminant, this is the final enum variant
369 None => {}
370 // Enum variant has a discriminant
371 Some(TokenTree::Punct(punct)) if punct.as_char() == '=' => {
372 sink.extend([punct]);
373
374 // Stream discriminant expression from `source` into `sink`
375 loop {
376 match source.next() {
377 // End of discriminant, there may be a variant after this
378 Some(TokenTree::Punct(punct)) if punct.as_char() == ',' => {
379 sink.extend([punct]);
380 break;
381 }
382 // This token is part of the variant's expression
383 Some(tt) => {
384 sink.extend([tt]);
385 }
386 // End of discriminant, this is the last variant
387 None => break,
388 }
389 }
390 }
391 Some(_) => unreachable!(),
392 }
393}
394
/// Tokens not yet processed, with one token of lookahead.
type Source = Peekable<<TokenStream as IntoIterator>::IntoIter>;
/// Tokens already processed, which become part of the macro's output.
type Sink = TokenStream;
397
398/// Streams the identifier from `input` into `output`, returning its span, if the identifier exists
399fn stream_ident(source: &mut Source, sink: &mut Sink) -> Option<Span> {
400 let ident = source.next()?;
401 let span = ident.span();
402 sink.extend([ident]);
403 Some(span)
404}
405
406// Parses attributes
407//
408// #[attr] #[attr] pub field: Type
409// #[attr] #[attr] struct Foo
410// #[attr] #[attr] enum Foo
411//
412// Returns `true` if `#[auto_default(skip)]` was encountered
413fn stream_attrs(
414 source: &mut Source,
415 sink: &mut Sink,
416 errors: &mut TokenStream,
417 is_skip_allowed: IsSkipAllowed,
418) -> IsSkip {
419 let mut is_skip = None;
420
421 let is_skip = loop {
422 if !matches!(source.peek(), Some(TokenTree::Punct(hash)) if *hash == '#') {
423 break is_skip;
424 };
425
426 // #[some_attr]
427 // ^
428 let pound = source.next();
429
430 // #[some_attr]
431 // ^^^^^^^^^^^
432 let Some(TokenTree::Group(attr)) = source.next() else {
433 unreachable!()
434 };
435
436 // #[some_attr = hello]
437 // ^^^^^^^^^^^^^^^^^
438 let mut attr_tokens = attr.stream().into_iter().peekable();
439
440 // Check if this attribute is `#[auto_default(skip)]`
441 if let Some(skip_span) = is_skip_attribute(&mut attr_tokens, errors) {
442 if is_skip.is_some() {
443 // Disallow 2 attributes on a single field:
444 //
445 // #[auto_default(skip)]
446 // #[auto_default(skip)]
447 errors.extend(CompileError::new(
448 skip_span,
449 "duplicate `#[auto_default(skip)]`",
450 ));
451 } else {
452 is_skip = Some(skip_span);
453 }
454 continue;
455 }
456
457 // #[attr]
458 // ^
459 sink.extend(pound);
460
461 // Re-construct the `[..]` for the attribute
462 //
463 // #[attr]
464 // ^^^^^^
465 let mut group = Group::new(attr.delimiter(), attr_tokens.collect());
466 group.set_span(attr.span());
467
468 // #[attr]
469 // ^^^^^^
470 sink.extend([group]);
471 };
472
473 if let Some(skip_span) = is_skip
474 && !is_skip_allowed.0
475 {
476 errors.extend(CompileError::new(
477 skip_span,
478 "`#[auto_default(skip)]` is not allowed on container",
479 ));
480 }
481
482 IsSkip(is_skip.is_some())
483}
484
485/// if `source` is exactly `auto_default(skip)`, returns `Some(span)`
486/// with `span` being the `Span` of the `skip` identifier
487fn is_skip_attribute(source: &mut Source, errors: &mut TokenStream) -> Option<Span> {
488 let Some(TokenTree::Ident(ident)) = source.peek() else {
489 return None;
490 };
491
492 if ident.to_string() != "auto_default" {
493 return None;
494 };
495
496 // #[auto_default(skip)]
497 // ^^^^^^^^^^^^
498 let ident = source.next().unwrap();
499
500 // We know it is `#[auto_default ???]`, we need to validate that `???`
501 // is exactly `(skip)` now
502
503 // #[auto_default(skip)]
504 // ^^^^^^^^^^^^
505 let auto_default_span = ident.span();
506
507 // #[auto_default(skip)]
508 // ^^^^^^
509 let group = match source.next() {
510 Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => group,
511 Some(tt) => {
512 errors.extend(CompileError::new(tt.span(), "expected `(skip)`"));
513 return None;
514 }
515 None => {
516 errors.extend(CompileError::new(
517 auto_default_span,
518 "expected `(skip)` after this",
519 ));
520 return None;
521 }
522 };
523
524 // #[auto_default(skip)]
525 // ^^^^
526 let mut inside = group.stream().into_iter();
527
528 // #[auto_default(skip)]
529 // ^^^^
530 let ident_skip = match inside.next() {
531 Some(TokenTree::Ident(ident)) => ident,
532 Some(tt) => {
533 errors.extend(CompileError::new(tt.span(), "expected `skip`"));
534 return None;
535 }
536 None => {
537 errors.extend(CompileError::new(
538 group.span(),
539 "expected `(skip)`, found `()`",
540 ));
541 return None;
542 }
543 };
544
545 if ident_skip.to_string() != "skip" {
546 errors.extend(CompileError::new(ident_skip.span(), "expected `skip`"));
547 return None;
548 }
549
550 // Validate that there's nothing after `skip`
551 //
552 // #[auto_default(skip )]
553 // ^^^^
554 if let Some(tt) = inside.next() {
555 errors.extend(CompileError::new(tt.span(), "unexpected token"));
556 return None;
557 }
558
559 Some(ident_skip.span())
560}
561
562fn stream_vis(source: &mut Source, sink: &mut Sink) {
563 // Remove visibility if it is present
564 //
565 // pub(in crate) struct
566 // ^^^^^^^^^^^^^
567 if let Some(TokenTree::Ident(vis)) = source.peek()
568 && vis.to_string() == "pub"
569 {
570 // pub(in crate) struct
571 // ^^^
572 sink.extend(source.next());
573
574 if let Some(TokenTree::Group(group)) = source.peek()
575 && let Delimiter::Parenthesis = group.delimiter()
576 {
577 // pub(in crate) struct
578 // ^^^^^^^^^^
579 sink.extend(source.next());
580 }
581 };
582}
583
/// The kind of item `#[auto_default]` was applied to.
#[derive(PartialEq)]
enum ItemKind {
    /// `struct Foo { ... }`: the struct's own fields get default values
    Struct,
    /// `enum Foo { ... }`: fields of variants with named fields get default values
    Enum,
}
589
590/// `fields` is [`StructFields`] in the grammar.
591///
592/// It is the curly braces, and everything within, for a struct with named fields,
593/// or an enum variant with named fields.
594///
595/// These fields are transformed by adding `= Default::default()` to every
596/// field that doesn't already have a default value.
597///
598/// If a field is marked with `#[auto_default(skip)]`, no default value will be
599/// added
600///
601/// [`StructFields`]: https://doc.rust-lang.org/reference/items/structs.html#grammar-StructFields
602fn add_default_field_values(
603 fields: Group,
604 compile_errors: &mut TokenStream,
605 is_skip_variant: IsSkip,
606) -> Group {
607 // All the tokens corresponding to the struct's field, passed by the user
608 // These tokens will eventually all be sent to `output_fields`,
609 // plus a few extra for any `Default::default()` that we output
610 let mut input_fields = fields.stream().into_iter().peekable();
611
612 // The tokens corresponding to the fields of the output struct
613 let mut output_fields = TokenStream::new();
614
615 // Parses all fields.
616 // Each iteration parses a single field
617 'parse_field: loop {
618 let is_skip_field = stream_attrs(
619 &mut input_fields,
620 &mut output_fields,
621 compile_errors,
622 IsSkipAllowed(true),
623 );
624 let is_skip = is_skip_field.0 || is_skip_variant.0;
625 stream_vis(&mut input_fields, &mut output_fields);
626 let Some(field_ident_span) = stream_ident(&mut input_fields, &mut output_fields) else {
627 // No fields. e.g.: `struct Struct {}`
628 break;
629 };
630
631 // field: Type
632 // ^
633 output_fields.extend(input_fields.next());
634
635 // Everything after the `:` in the field
636 //
637 // Involves:
638 //
639 // - Adding default value of `= Default::default()` if one is not present
640 // - Continue to next iteration of the loop
641 loop {
642 match input_fields.peek() {
643 // This field has a custom default field value
644 //
645 // field: Type = default
646 // ^
647 Some(TokenTree::Punct(p)) if p.as_char() == '=' => loop {
648 match input_fields.next() {
649 Some(TokenTree::Punct(p)) if p == ',' => {
650 output_fields.extend([p]);
651 // Comma after field. Field is finished.
652 continue 'parse_field;
653 }
654 Some(tt) => output_fields.extend([tt]),
655 // End of input. Field is finished. This is the last field
656 None => break 'parse_field,
657 }
658 },
659 // Reached end of field, has comma at the end, no custom default value
660 //
661 // field: Type,
662 // ^
663 Some(TokenTree::Punct(p)) if p.as_char() == ',' => {
664 // Insert default value before the comma
665 //
666 // field: Type = Default::default(),
667 // ^^^^^^^^^^^^^^^^^^^^
668 if !is_skip {
669 output_fields.extend(default(field_ident_span));
670 }
671 // field: Type = Default::default(),
672 // ^
673 output_fields.extend(input_fields.next());
674 // Next iteration handles the next field
675 continue 'parse_field;
676 }
677 // This token is part of the field's type
678 //
679 // field: some::Type
680 // ^^^^
681 Some(_) => output_fields.extend(input_fields.next()),
682 // Reached end of input, and it has no comma.
683 // This is the last field.
684 //
685 // struct Foo {
686 // field: Type
687 // ^
688 // }
689 None => {
690 if !is_skip {
691 output_fields.extend(default(field_ident_span));
692 }
693 // No more fields
694 break 'parse_field;
695 }
696 }
697 }
698 }
699 let mut g = Group::new(Delimiter::Brace, output_fields);
700 g.set_span(fields.span());
701 g
702}
703
704// = ::core::default::Default::default()
705fn default(span: Span) -> [TokenTree; 14] {
706 [
707 TokenTree::Punct(Punct::new('=', Spacing::Alone)),
708 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
709 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
710 TokenTree::Ident(Ident::new("core", span)),
711 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
712 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
713 TokenTree::Ident(Ident::new("default", span)),
714 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
715 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
716 TokenTree::Ident(Ident::new("Default", span)),
717 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
718 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(span),
719 TokenTree::Ident(Ident::new("default", span)),
720 TokenTree::Group(Group::new(Delimiter::Parenthesis, TokenStream::new())).with_span(span),
721 ]
722}
723
/// Creates a [`CompileError`] with a `format!`-style message.
///
/// The error is placed at the span of `$spanned` (an `Option` of something
/// with a `.span()` method), or at [`Span::call_site`] when it is `None`.
macro_rules! create_compile_error {
    ($spanned:expr, $($tt:tt)*) => {{
        let span = match $spanned {
            Some(spanned) => spanned.span(),
            None => Span::call_site(),
        };
        CompileError::new(span, format!($($tt)*))
    }};
}
// Lets the macro be referred to by path, including above its definition
use create_compile_error;
735
736/// `.into_iter()` generates `compile_error!($message)` at `$span`
737struct CompileError {
738 /// Where the compile error is generates
739 pub span: Span,
740 /// Message of the compile error
741 pub message: String,
742}
743
744impl CompileError {
745 /// Create a new compile error
746 pub fn new(span: Span, message: impl AsRef<str>) -> Self {
747 Self {
748 span,
749 message: message.as_ref().to_string(),
750 }
751 }
752}
753
754impl IntoIterator for CompileError {
755 type Item = TokenTree;
756 type IntoIter = std::array::IntoIter<Self::Item, 8>;
757
758 fn into_iter(self) -> Self::IntoIter {
759 [
760 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
761 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
762 TokenTree::Ident(Ident::new("core", self.span)),
763 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
764 TokenTree::Punct(Punct::new(':', Spacing::Joint)).with_span(self.span),
765 TokenTree::Ident(Ident::new("compile_error", self.span)),
766 TokenTree::Punct(Punct::new('!', Spacing::Alone)).with_span(self.span),
767 TokenTree::Group(Group::new(Delimiter::Brace, {
768 TokenStream::from(
769 TokenTree::Literal(Literal::string(&self.message)).with_span(self.span),
770 )
771 }))
772 .with_span(self.span),
773 ]
774 .into_iter()
775 }
776}
777
/// Extension methods for [`TokenTree`].
trait TokenTreeExt {
    /// Returns `self` with its span replaced by `span`.
    ///
    /// Unlike [`TokenTree::set_span`], this consumes and returns the token, so
    /// it can be used inline while building a list of tokens.
    fn with_span(self, span: Span) -> TokenTree;
}

impl TokenTreeExt for TokenTree {
    fn with_span(self, span: Span) -> TokenTree {
        let mut tree = self;
        tree.set_span(span);
        tree
    }
}