derive_new/
lib.rs

1//! # A custom derive implementation for `#[derive(new)]`
2//!
3//! A `derive(new)` attribute creates a `new` constructor function for the annotated
4//! type. That function takes an argument for each field in the type giving a
5//! trivial constructor. This is useful since as your type evolves you can make the
6//! constructor non-trivial (and add or remove fields) without changing client code
7//! (i.e., without breaking backwards compatibility). It is also the most succinct
8//! way to initialise a struct or an enum.
9//!
10//! Implementation uses macros 1.1 custom derive (which works in stable Rust from
11//! 1.15 onwards).
12//!
13//! ## Examples
14//!
15//! Cargo.toml:
16//!
17//! ```toml
18//! [dependencies]
19//! derive-new = "0.5"
20//! ```
21//!
22//! Include the macro:
23//!
24//! ```rust
25//! use derive_new::new;
26//!
27//! fn main() {}
28//! ```
29//!
30//! Generating constructor for a simple struct:
31//!
32//! ```rust
33//! use derive_new::new;
34//!
35//! #[derive(new)]
36//! struct Bar {
37//!     a: i32,
38//!     b: String,
39//! }
40//!
41//! fn main() {
42//!   let _ = Bar::new(42, "Hello".to_owned());
43//! }
44//! ```
45//!
46//! Default values can be specified either via `#[new(default)]` attribute which removes
47//! the argument from the constructor and populates the field with `Default::default()`,
48//! or via `#[new(value = "..")]` which initializes the field with a given expression:
49//!
50//! ```rust
51//! use derive_new::new;
52//!
53//! #[derive(new)]
54//! struct Foo {
55//!     x: bool,
56//!     #[new(value = "42")]
57//!     y: i32,
58//!     #[new(default)]
59//!     z: Vec<String>,
60//! }
61//!
62//! fn main() {
63//!   let _ = Foo::new(true);
64//! }
65//! ```
66//!
67//! To make type conversion easier, `#[new(into)]` attribute changes the parameter type
68//! to `impl Into<T>`, and populates the field with `value.into()`:
69//!
70//! ```rust
71//! # use derive_new::new;
72//! #[derive(new)]
73//! struct Foo {
74//!     #[new(into)]
75//!     x: String,
76//! }
77//!
78//! let _ = Foo::new("Hello");
79//! ```
80//!
81//! For iterators/collections, `#[new(into_iter = "T")]` attribute changes the parameter type
82//! to `impl IntoIterator<Item = T>`, and populates the field with `value.into_iter().collect()`:
83//!
84//! ```rust
85//! # use derive_new::new;
86//! #[derive(new)]
87//! struct Foo {
88//!     #[new(into_iter = "bool")]
89//!     x: Vec<bool>,
90//! }
91//!
92//! let _ = Foo::new([true, false]);
93//! let _ = Foo::new(Some(true));
94//! ```
95//!
//! Generic types are supported; in particular, `PhantomData<T>` fields will not be
//! included in the argument list and will be initialized automatically:
98//!
99//! ```rust
100//! use derive_new::new;
101//!
102//! use std::marker::PhantomData;
103//!
104//! #[derive(new)]
105//! struct Generic<'a, T: Default, P> {
106//!     x: &'a str,
107//!     y: PhantomData<P>,
108//!     #[new(default)]
109//!     z: T,
110//! }
111//!
112//! fn main() {
113//!   let _ = Generic::<i32, u8>::new("Hello");
114//! }
115//! ```
116//!
//! For enums, one constructor method is generated for each variant, named `new_`
//! followed by the variant name converted to snake case; otherwise, all features
//! supported for structs work for enum variants as well:
120//!
121//! ```rust
122//! use derive_new::new;
123//!
124//! #[derive(new)]
125//! enum Enum {
126//!     FirstVariant,
127//!     SecondVariant(bool, #[new(default)] u8),
128//!     ThirdVariant { x: i32, #[new(value = "vec![1]")] y: Vec<u8> }
129//! }
130//!
131//! fn main() {
132//!   let _ = Enum::new_first_variant();
133//!   let _ = Enum::new_second_variant(true);
134//!   let _ = Enum::new_third_variant(42);
135//! }
136//! ```
137//! ### Setting Visibility for the Constructor
138//!
//! By default, the generated constructor will be `pub`. You can control its visibility by
//! adding a `#[new(visibility = "...")]` attribute to the struct or enum; the string is
//! parsed as a Rust visibility, so an empty string makes the constructor private.
140//!
141//! #### Public Constructor (default)
142//!
143//! ```rust
144//! use derive_new::new;
145//!
146//! #[derive(new)]
147//! pub struct Bar {
148//!     a: i32,
149//!     b: String,
150//! }
151//!
152//! fn main() {
153//!   let _ = Bar::new(42, "Hello".to_owned());
154//! }
155//! ```
156//!
157//! #### Crate-Visible Constructor
158//!
159//! ```rust
160//! use derive_new::new;
161//!
162//! #[derive(new)]
163//! #[new(visibility = "pub(crate)")]
164//! pub struct Bar {
165//!     a: i32,
166//!     b: String,
167//! }
168//!
169//! fn main() {
170//!   let _ = Bar::new(42, "Hello".to_owned());
171//! }
172//! ```
173//!
174//! #### Private Constructor
175//!
176//! ```rust
177//! use derive_new::new;
178//!
179//! #[derive(new)]
180//! #[new(visibility = "")]
181//! pub struct Bar {
182//!     a: i32,
183//!     b: String,
184//! }
185//!
186//! fn main() {
//!   // A private constructor is still callable from its defining module, so this
//!   // call compiles; calling `Bar::new` from any other module would not.
//!   let _ = Bar::new(42, "Hello".to_owned());
189//! }
190//! ```
191#![crate_type = "proc-macro"]
192#![recursion_limit = "192"]
193
194extern crate proc_macro;
195extern crate proc_macro2;
196#[macro_use]
197extern crate quote;
198extern crate syn;
199
/// `quote!` variant that gives every token it produces the call-site span, so
/// generated code resolves names as if written at the `#[derive(new)]` site.
macro_rules! my_quote {
    ($($tokens:tt)*) => {
        quote_spanned!(proc_macro2::Span::call_site() => $($tokens)*)
    };
}
203
204fn path_to_string(path: &syn::Path) -> String {
205    path.segments
206        .iter()
207        .map(|s| s.ident.to_string())
208        .collect::<Vec<String>>()
209        .join("::")
210}
211
212use proc_macro::TokenStream;
213use proc_macro2::TokenStream as TokenStream2;
214use syn::{punctuated::Punctuated, Attribute, Lit, Token, Visibility};
215
216#[proc_macro_derive(new, attributes(new))]
217pub fn derive(input: TokenStream) -> TokenStream {
218    let ast: syn::DeriveInput = syn::parse(input).expect("Couldn't parse item");
219    let options = NewOptions::from_attributes(&ast.attrs);
220    let result = match ast.data {
221        syn::Data::Enum(ref e) => new_for_enum(&ast, e, &options),
222        syn::Data::Struct(ref s) => new_for_struct(&ast, &s.fields, None, &options),
223        syn::Data::Union(_) => panic!("doesn't work with unions yet"),
224    };
225    result.into()
226}
227
228fn new_for_struct(
229    ast: &syn::DeriveInput,
230    fields: &syn::Fields,
231    variant: Option<&syn::Ident>,
232    options: &NewOptions,
233) -> proc_macro2::TokenStream {
234    match *fields {
235        syn::Fields::Named(ref fields) => {
236            new_impl(ast, Some(&fields.named), true, variant, options)
237        }
238        syn::Fields::Unit => new_impl(ast, None, false, variant, options),
239        syn::Fields::Unnamed(ref fields) => {
240            new_impl(ast, Some(&fields.unnamed), false, variant, options)
241        }
242    }
243}
244
245fn new_for_enum(
246    ast: &syn::DeriveInput,
247    data: &syn::DataEnum,
248    options: &NewOptions,
249) -> proc_macro2::TokenStream {
250    if data.variants.is_empty() {
251        panic!("#[derive(new)] cannot be implemented for enums with zero variants");
252    }
253    let impls = data.variants.iter().map(|v| {
254        if v.discriminant.is_some() {
255            panic!("#[derive(new)] cannot be implemented for enums with discriminants");
256        }
257        new_for_struct(ast, &v.fields, Some(&v.ident), options)
258    });
259    my_quote!(#(#impls)*)
260}
261
262fn new_impl(
263    ast: &syn::DeriveInput,
264    fields: Option<&Punctuated<syn::Field, Token![,]>>,
265    named: bool,
266    variant: Option<&syn::Ident>,
267    options: &NewOptions,
268) -> proc_macro2::TokenStream {
269    let name = &ast.ident;
270    let unit = fields.is_none();
271    let empty = Default::default();
272    let fields: Vec<_> = fields
273        .unwrap_or(&empty)
274        .iter()
275        .enumerate()
276        .map(|(i, f)| FieldExt::new(f, i, named))
277        .collect();
278    let args = fields.iter().filter_map(|f| f.as_arg());
279    let inits = fields.iter().map(|f| f.as_init());
280    let inits = if unit {
281        my_quote!()
282    } else if named {
283        my_quote![{ #(#inits),* }]
284    } else {
285        my_quote![( #(#inits),* )]
286    };
287    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
288    let (mut new, qual, doc) = match variant {
289        None => (
290            syn::Ident::new("new", proc_macro2::Span::call_site()),
291            my_quote!(),
292            format!("Constructs a new `{}`.", name),
293        ),
294        Some(ref variant) => (
295            syn::Ident::new(
296                &format!("new_{}", to_snake_case(&variant.to_string())),
297                proc_macro2::Span::call_site(),
298            ),
299            my_quote!(::#variant),
300            format!("Constructs a new `{}::{}`.", name, variant),
301        ),
302    };
303    new.set_span(proc_macro2::Span::call_site());
304    let lint_attrs = collect_parent_lint_attrs(&ast.attrs);
305    let lint_attrs = my_quote![#(#lint_attrs),*];
306    let visibility = &options.visibility;
307    my_quote! {
308        impl #impl_generics #name #ty_generics #where_clause {
309            #[doc = #doc]
310            #lint_attrs
311            #visibility fn #new(#(#args),*) -> Self {
312                #name #qual #inits
313            }
314        }
315    }
316}
317
318fn collect_parent_lint_attrs(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
319    fn is_lint(item: &syn::Meta) -> bool {
320        if let syn::Meta::List(ref l) = *item {
321            let path = &l.path;
322            return path.is_ident("allow")
323                || path.is_ident("deny")
324                || path.is_ident("forbid")
325                || path.is_ident("warn");
326        }
327        false
328    }
329
330    fn is_cfg_attr_lint(item: &syn::Meta) -> bool {
331        if let syn::Meta::List(ref l) = *item {
332            if l.path.is_ident("cfg_attr") {
333                if let Ok(nested) =
334                    l.parse_args_with(Punctuated::<syn::Meta, Token![,]>::parse_terminated)
335                {
336                    return nested.len() == 2 && is_lint(&nested[1]);
337                }
338            }
339        }
340        false
341    }
342
343    attrs
344        .iter()
345        .filter(|a| is_lint(&a.meta) || is_cfg_attr_lint(&a.meta))
346        .cloned()
347        .collect()
348}
349
350struct NewOptions {
351    visibility: Option<syn::Visibility>,
352}
353
354impl NewOptions {
355    fn from_attributes(attrs: &[Attribute]) -> Self {
356        // Default visibility is public
357        let mut visibility = Some(Visibility::Public(syn::token::Pub {
358            span: proc_macro2::Span::call_site(),
359        }));
360
361        for attr in attrs {
362            if attr.path().is_ident("new") {
363                attr.parse_nested_meta(|meta| {
364                    if meta.path.is_ident("visibility") {
365                        let value: Lit = meta.value()?.parse()?;
366                        if let Lit::Str(lit_str) = value {
367                            // Parse the visibility string into a syn::Visibility type
368                            let parsed_visibility: Visibility =
369                                lit_str.parse().expect("Invalid visibility");
370                            visibility = Some(parsed_visibility);
371                        }
372                        Ok(())
373                    } else {
374                        Err(meta.error("unsupported attribute"))
375                    }
376                })
377                .unwrap_or(());
378            }
379        }
380
381        NewOptions { visibility }
382    }
383}
384
385enum FieldAttr {
386    Default,
387    Into,
388    IntoIter(proc_macro2::TokenStream),
389    Value(proc_macro2::TokenStream),
390}
391
392impl FieldAttr {
393    pub fn as_tokens(&self, name: &syn::Ident) -> proc_macro2::TokenStream {
394        match *self {
395            FieldAttr::Default => my_quote!(::core::default::Default::default()),
396            FieldAttr::Into => my_quote!(::core::convert::Into::into(#name)),
397            FieldAttr::IntoIter(_) => {
398                my_quote!(::core::iter::Iterator::collect(::core::iter::IntoIterator::into_iter(#name)))
399            }
400            FieldAttr::Value(ref s) => my_quote!(#s),
401        }
402    }
403
404    pub fn parse(attrs: &[syn::Attribute]) -> Option<FieldAttr> {
405        let mut result = None;
406        for attr in attrs.iter() {
407            match attr.style {
408                syn::AttrStyle::Outer => {}
409                _ => continue,
410            }
411            let last_attr_path = attr
412                .path()
413                .segments
414                .last()
415                .expect("Expected at least one segment where #[segment[::segment*](..)]");
416            if last_attr_path.ident != "new" {
417                continue;
418            }
419            let list = match attr.meta {
420                syn::Meta::List(ref l) => l,
421                _ if attr.path().is_ident("new") => {
422                    panic!("Invalid #[new] attribute, expected #[new(..)]")
423                }
424                _ => continue,
425            };
426            if result.is_some() {
427                panic!("Expected at most one #[new] attribute");
428            }
429            for item in list
430                .parse_args_with(Punctuated::<syn::Meta, Token![,]>::parse_terminated)
431                .unwrap_or_else(|err| panic!("Invalid #[new] attribute: {}", err))
432            {
433                match item {
434                    syn::Meta::Path(path) => match path.get_ident() {
435                        Some(ident) if ident == "default" => {
436                            result = Some(FieldAttr::Default);
437                        }
438                        Some(ident) if ident == "into" => {
439                            result = Some(FieldAttr::Into);
440                        }
441                        _ => panic!(
442                            "Invalid #[new] attribute: #[new({})]",
443                            path_to_string(&path)
444                        ),
445                    },
446                    syn::Meta::NameValue(kv) => {
447                        if let syn::Expr::Lit(syn::ExprLit {
448                            lit: syn::Lit::Str(ref s),
449                            ..
450                        }) = kv.value
451                        {
452                            let tokens = lit_str_to_token_stream(s)
453                                .ok()
454                                .expect(&format!("Invalid expression in #[new]: `{}`", s.value()));
455
456                            match kv.path.get_ident() {
457                                Some(ident) if ident == "into_iter" => {
458                                    result = Some(FieldAttr::IntoIter(tokens));
459                                }
460                                Some(ident) if ident == "value" => {
461                                    result = Some(FieldAttr::Value(tokens));
462                                }
463                                _ => panic!(
464                                    "Invalid #[new] attribute: #[new({} = ..)]",
465                                    path_to_string(&kv.path)
466                                ),
467                            }
468                        } else {
469                            panic!("Non-string literal value in #[new] attribute");
470                        }
471                    }
472                    syn::Meta::List(l) => {
473                        panic!(
474                            "Invalid #[new] attribute: #[new({}(..))]",
475                            path_to_string(&l.path)
476                        );
477                    }
478                }
479            }
480        }
481        result
482    }
483}
484
485struct FieldExt<'a> {
486    ty: &'a syn::Type,
487    attr: Option<FieldAttr>,
488    ident: syn::Ident,
489    named: bool,
490}
491
492impl<'a> FieldExt<'a> {
493    pub fn new(field: &'a syn::Field, idx: usize, named: bool) -> FieldExt<'a> {
494        FieldExt {
495            ty: &field.ty,
496            attr: FieldAttr::parse(&field.attrs),
497            ident: if named {
498                field.ident.clone().unwrap()
499            } else {
500                syn::Ident::new(&format!("f{}", idx), proc_macro2::Span::call_site())
501            },
502            named,
503        }
504    }
505
506    pub fn is_phantom_data(&self) -> bool {
507        match *self.ty {
508            syn::Type::Path(syn::TypePath {
509                qself: None,
510                ref path,
511            }) => path
512                .segments
513                .last()
514                .map(|x| x.ident == "PhantomData")
515                .unwrap_or(false),
516            _ => false,
517        }
518    }
519
520    pub fn as_arg(&self) -> Option<proc_macro2::TokenStream> {
521        if self.is_phantom_data() {
522            return None;
523        }
524
525        let ident = &self.ident;
526        let ty = &self.ty;
527
528        match self.attr {
529            Some(FieldAttr::Default) => None,
530            Some(FieldAttr::Into) => Some(my_quote!(#ident: impl ::core::convert::Into<#ty>)),
531            Some(FieldAttr::IntoIter(ref s)) => {
532                Some(my_quote!(#ident: impl ::core::iter::IntoIterator<Item = #s>))
533            }
534            Some(FieldAttr::Value(_)) => None,
535            None => Some(my_quote!(#ident: #ty)),
536        }
537    }
538
539    pub fn as_init(&self) -> proc_macro2::TokenStream {
540        let f_name = &self.ident;
541        let init = if self.is_phantom_data() {
542            my_quote!(::core::marker::PhantomData)
543        } else {
544            match self.attr {
545                None => my_quote!(#f_name),
546                Some(ref attr) => attr.as_tokens(f_name),
547            }
548        };
549        if self.named {
550            my_quote!(#f_name: #init)
551        } else {
552            my_quote!(#init)
553        }
554    }
555}
556
557fn lit_str_to_token_stream(s: &syn::LitStr) -> Result<TokenStream2, proc_macro2::LexError> {
558    let code = s.value();
559    let ts: TokenStream2 = code.parse()?;
560    Ok(set_ts_span_recursive(ts, &s.span()))
561}
562
563fn set_ts_span_recursive(ts: TokenStream2, span: &proc_macro2::Span) -> TokenStream2 {
564    ts.into_iter()
565        .map(|mut tt| {
566            tt.set_span(*span);
567            if let proc_macro2::TokenTree::Group(group) = &mut tt {
568                let stream = set_ts_span_recursive(group.stream(), span);
569                *group = proc_macro2::Group::new(group.delimiter(), stream);
570            }
571            tt
572        })
573        .collect()
574}
575
/// Converts a `CamelCase` identifier, such as an enum variant name, to
/// `snake_case`.
///
/// Every character is lowercased. An underscore is inserted before an
/// uppercase character (other than the first) when the previous character is
/// lowercase or numeric (`bC` -> `b_c`, `a1B` -> `a1_b`). One is also inserted
/// when the previous character is uppercase and the next one is lowercase,
/// which splits acronyms (`ISNot` -> `is_not`). Existing underscores are kept
/// and never doubled.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut snake = String::with_capacity(s.len());
    for (i, &c) in chars.iter().enumerate() {
        if i > 0 && c.is_uppercase() {
            let before = chars[i - 1];
            let after_is_lower = chars.get(i + 1).map_or(false, |n| n.is_lowercase());
            let word_boundary = before.is_lowercase()
                || before.is_numeric()
                || (before.is_uppercase() && after_is_lower);
            if word_boundary {
                snake.push('_');
            }
        }
        snake.extend(c.to_lowercase());
    }
    snake
}
604
/// Table-driven check of `to_snake_case` on acronyms, digits, leading
/// underscores and mixed-case input.
#[test]
fn test_to_snake_case() {
    let cases: &[(&str, &str)] = &[
        ("", ""),
        ("a", "a"),
        ("B", "b"),
        ("BC", "bc"),
        ("Bc", "bc"),
        ("bC", "b_c"),
        ("Fred", "fred"),
        ("CARGO", "cargo"),
        ("_Hello", "_hello"),
        ("QuxBaz", "qux_baz"),
        ("FreeBSD", "free_bsd"),
        ("specialK", "special_k"),
        ("hello1World", "hello1_world"),
        ("Keep_underscore", "keep_underscore"),
        ("ThisISNotADrill", "this_is_not_a_drill"),
    ];
    for &(input, expected) in cases {
        assert_eq!(to_snake_case(input), expected, "to_snake_case({:?})", input);
    }
}