derive_extras/
lib.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
7use quote::{quote, ToTokens};
8use syn::{
9	parse_macro_input,
10	parse_quote,
11	spanned::Spanned,
12	Attribute,
13	Data,
14	DataEnum,
15	DataStruct,
16	DeriveInput,
17	Error,
18	Fields,
19	GenericParam,
20	Index,
21	Meta,
22	Type,
23	Variant,
24};
25
26use crate::numbers_to_words::encode_ordinal;
27
28mod numbers_to_words;
29
30/// Derives [`Default`] with more flexibility than Rust's built-in `#[derive(Default)]`.
31///
32/// In addition to the functionality allowed by the built-in `#[derive(Default)]`, this version
33/// allows:
34/// - overriding the default values for particular fields using `#[default = ...]` or
35///   `#[default(...)]`; and
36/// - using `#[default]` on a non-unit enum variant
37///
38/// # Examples
39///
40/// ## Structs
41/// ```
42/// # #[allow(unused_qualifications)]
43/// #[derive(derive_extras::Default)]
44/// struct ExampleStruct {
45///     x: i32,
46///     #[default = 15]
47///     y: i32,
48///     z: i32,
49/// }
50/// ```
51/// For this struct, the following [`Default`] implementation is derived:
52/// ```
53/// # struct ExampleStruct {
54/// #     x: i32,
55/// #     y: i32,
56/// #     z: i32,
57/// # }
58/// #
59/// impl Default for ExampleStruct {
60///     fn default() -> Self {
61///         Self {
62///             x: Default::default(),
63///             y: 15,
64///             z: Default::default(),
65///         }
66///     }
67/// }
68/// ```
69///
70/// ## Enums
71/// ```
72/// # #[allow(unused_qualifications)]
73/// #[derive(derive_extras::Default)]
74/// enum ExampleEnum {
75///     Unit,
76///     Tuple(i32, i32, i32),
77///
78///     #[default]
79///     Struct {
80///         x: i32,
81///         #[default = 15]
82///         y: i32,
83///         z: i32,
84///     },
85/// }
86/// ```
87/// For this enum, the following [`Default`] implementation is derived:
88/// ```
89/// # enum ExampleEnum {
90/// #     Unit,
91/// #     Tuple(i32, i32, i32),
92/// #     Struct {
93/// #         x: i32,
94/// #         y: i32,
95/// #         z: i32,
96/// #     },
97/// # }
98/// #
99/// impl Default for ExampleEnum {
100///     fn default() -> Self {
101///         Self::Struct {
102///             x: Default::default(),
103///             y: 15,
104///             z: Default::default(),
105///         }
106///     }
107/// }
108/// ```
109///
110/// [`Default`]: core::default::Default
111#[proc_macro_derive(Default, attributes(default))]
112pub fn default(input: TokenStream) -> TokenStream {
113	let input = parse_macro_input!(input as DeriveInput);
114
115	let DeriveInput {
116		mut generics,
117		ident,
118		data,
119		..
120	} = input;
121
122	for param in &mut generics.params {
123		if let GenericParam::Type(param) = param {
124			param.bounds.push(parse_quote!(::core::default::Default))
125		}
126	}
127
128	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
129
130	let body = match data {
131		Data::Struct(DataStruct { fields, .. }) => {
132			let init = expand_default_for_fields(fields);
133
134			quote!(Self #init)
135		},
136
137		Data::Enum(r#enum) => expand_default_for_enum(r#enum).unwrap_or_else(|error| error.into_compile_error()),
138
139		Data::Union(_) => unimplemented!("unions are not supported"),
140	};
141
142	let tokens = quote! {
143		impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
144			fn default() -> Self {
145				#body
146			}
147		}
148	};
149
150	tokens.into()
151}
152
153fn expand_default_for_enum(r#enum: DataEnum) -> syn::Result<TokenStream2> {
154	let variant = {
155		let mut variant = None;
156
157		for var in r#enum.variants {
158			for attr in &var.attrs {
159				if attr.meta.require_path_only().map(|path| path.is_ident("default"))? {
160					match &variant {
161						None => variant = Some(var),
162						Some(_) => return Err(Error::new(attr.span(), "conflicting #[default] attribute")),
163					}
164
165					break;
166				}
167			}
168		}
169
170		variant
171	};
172
173	match variant {
174		Some(Variant { ident, fields, .. }) => {
175			let init = expand_default_for_fields(fields);
176
177			Ok(quote!(Self::#ident #init))
178		},
179
180		None => Err(Error::new(
181			Span::call_site(),
182			"one enum variant must have a #[default] attribute",
183		)),
184	}
185}
186
187fn expand_default_for_fields(fields: Fields) -> Option<TokenStream2> {
188	match fields {
189		Fields::Unit => None,
190
191		Fields::Named(fields) => {
192			let field_init = fields.named.into_pairs().map(|pair| {
193				let (field, comma) = pair.into_tuple();
194
195				let (ident, ty, attrs) = (field.ident, field.ty, field.attrs);
196
197				let value = field_value(ty, attrs);
198
199				quote!(#ident: #value #comma)
200			});
201
202			Some(quote!({ #(#field_init)* }))
203		},
204
205		Fields::Unnamed(fields) => {
206			let field_init = fields.unnamed.into_pairs().map(|pair| {
207				let (field, comma) = pair.into_tuple();
208
209				let value = field_value(field.ty, field.attrs);
210
211				quote!(#value #comma)
212			});
213
214			Some(quote!((#(#field_init)*)))
215		},
216	}
217}
218
219fn field_value(ty: Type, attrs: Vec<Attribute>) -> TokenStream2 {
220	let default_attr = attrs
221		.into_iter()
222		.find(|attribute| attribute.meta.path().is_ident("default"));
223
224	match default_attr {
225		// If there is a default attribute, use its value.
226		Some(default_attr) => match default_attr.meta {
227			Meta::Path(path) => Error::new(
228				path.span(),
229				format!(
230					"expected a value for this attribute: `{}(...)` or `{} = ...`",
231					"default", "default",
232				),
233			)
234			.into_compile_error(),
235
236			Meta::List(meta) => {
237				let tokens = meta.tokens;
238
239				quote!({ #tokens })
240			},
241
242			Meta::NameValue(meta) => meta.value.into_token_stream(),
243		},
244
245		// If there is no default attribute, use `Default::default()`.
246		None => {
247			quote!(<#ty as ::core::default::Default>::default())
248		},
249	}
250}
251
252/// Generates appropriate builder methods for a struct.
253///
254/// This assumes that the struct itself acts like a builder.
255///
256/// This derive macro also adds a `#[new]` helper attribute. If this is added to the struct, a `new`
257/// function is also generated with a `where Self: Default` bound.
258///
259/// The builder methods generated for each field will have that field's visibility: private fields
260/// will have private methods, etc.
261///
262/// # Examples
263/// ```
264/// # use derive_extras::builder;
265/// #
266/// #[derive(Default, builder)]
267/// #[new]
268/// struct Example {
269///     pub x: i32,
270///     pub y: i32,
271/// }
272/// ```
273/// This will derive the following implementations:
274/// ```
275/// # #[derive(Default)]
276/// # struct Example {
277/// #     pub x: i32,
278/// #     pub y: i32,
279/// # }
280/// #
281/// impl Example {
282///     /// Creates a new `Example`.
283///     ///
284///     /// This is equivalent to <code>Example::[default()]</code>.
285///     ///
286///     /// [default()]: Default::default()
287///     pub fn new() -> Self
288///     where
289///         Self: Default,
290///     {
291///         Self::default()
292///     }
293///
294///     /// Sets `x` to the given value.
295///     pub fn x(mut self, x: i32) -> Self {
296///         self.x = x;
297///
298///         self
299///     }
300///
301///     /// Sets `y` to the given value.
302///     pub fn y(mut self, y: i32) -> Self {
303///         self.y = y;
304///
305///         self
306///     }
307/// }
308/// ```
309///
310///
311/// `#[derive(builder)]` also works on tuple structs (with any number of fields):
312/// ```
313/// # use derive_extras::builder;
314/// #
315/// #[derive(Default, builder)]
316/// struct Example(pub i32, pub i32);
317/// ```
318/// This will derive the following implementations:
319/// ```
320/// # #[derive(Default)]
321/// # struct Example(pub i32, pub i32);
322/// #
323/// impl Example {
324///     /// Sets the first field to the given value.
325///     pub fn first(mut self, first: i32) -> Self {
326///         self.0 = first;
327///
328///         self
329///     }
330///
331///     /// Sets the second field to the given value.
332///     pub fn second(mut self, second: i32) -> Self {
333///         self.1 = second;
334///
335///         self
336///     }
337/// }
338/// ```
339#[proc_macro_derive(builder, attributes(new))]
340pub fn builder(input: TokenStream) -> TokenStream {
341	let input = parse_macro_input!(input as DeriveInput);
342
343	let tokens = match input.data {
344		Data::Struct(r#struct) => {
345			let name = input.ident;
346
347			let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
348
349			let new = input.attrs.iter().any(|attr| attr.path().is_ident("new")).then(|| {
350				quote! {
351					#[doc = concat!("Creates a new `", stringify!(#name), "`.")]
352					#[doc = ""]
353					#[doc = concat!("This is equivalent to <code>", stringify!(#name), "::[default()]</code>.")]
354					#[doc = ""]
355					#[doc = "[default()]: ::core::default::Default::default()"]
356					pub fn new() -> Self
357					where
358						Self: ::core::default::Default,
359					{
360						<Self as ::core::default::Default>::default()
361					}
362				}
363			});
364
365			let config_methods = match &r#struct.fields {
366				Fields::Unit => None,
367
368				// Named fields.
369				Fields::Named(fields) => {
370					let methods: TokenStream2 = fields
371						.named
372						.iter()
373						.map(|field| {
374							let vis = &field.vis;
375							let ty = &field.ty;
376
377							let ident = &field.ident;
378
379							let docs = ident
380								.as_ref()
381								.map(|ident| format!("Sets `{ident}` to the given value."));
382
383							quote! {
384								#[doc = #docs]
385								#vis fn #ident(mut self, #ident: #ty) -> Self {
386									self.#ident = #ident;
387
388									self
389								}
390							}
391						})
392						.collect();
393
394					Some(methods)
395				},
396
397				// Unnamed fields.
398				Fields::Unnamed(fields) => {
399					let methods = fields
400						.unnamed
401						.iter()
402						.enumerate()
403						.map(|(i, field)| {
404							let vis = &field.vis;
405							let ty = &field.ty;
406
407							let index = Index::from(i);
408							let ident = Ident::new(&encode_ordinal(i + 1, '_'), Span::call_site());
409
410							let ordinal = encode_ordinal(i + 1, ' ');
411							let docs = format!("Sets the {ordinal} field to the given value.");
412
413							quote! {
414								#[doc = #docs]
415								#vis fn #ident(mut self, #ident: #ty) -> Self {
416									self.#index = #ident;
417
418									self
419								}
420							}
421						})
422						.collect();
423
424					Some(methods)
425				},
426			};
427
428			// Final generated implementation.
429			quote! {
430				impl #impl_generics #name #ty_generics #where_clause {
431					#new
432
433					#config_methods
434				}
435			}
436		},
437
438		Data::Enum(_enum) => unimplemented!("enums are not supported"),
439		Data::Union(_union) => unimplemented!("unions are not supported"),
440	};
441
442	tokens.into()
443}