constdefault_derive/
lib.rs

1#![doc(html_root_url = "http://docs.rs/constdefault-derive/0.2.0")]
2
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use proc_macro2::{Ident, Literal, Span, TokenStream as TokenStream2};
7use quote::{quote, quote_spanned};
8use syn::{spanned::Spanned, Error};
9
10/// Derives an implementation for the [`ConstDefault`] trait.
11///
12/// # Note
13///
14/// Currently only works with `struct` inputs.
15///
16/// # Example
17///
18/// ## Struct
19///
20/// ```
21/// # use constdefault::ConstDefault;
22/// #[derive(ConstDefault)]
23/// # #[derive(Debug, PartialEq)]
24/// pub struct Color {
25///     r: u8,
26///     g: u8,
27///     b: u8,
28/// }
29///
30/// assert_eq!(
31///     <Color as ConstDefault>::DEFAULT,
32///     Color { r: 0, g: 0, b: 0 },
33/// )
34/// ```
35///
36/// ## Tuple Struct
37///
38/// ```
39/// # use constdefault::ConstDefault;
40/// #[derive(ConstDefault)]
41/// # #[derive(Debug, PartialEq)]
42/// pub struct Vec3(f32, f32, f32);
43///
44/// assert_eq!(
45///     <Vec3 as ConstDefault>::DEFAULT,
46///     Vec3(0.0, 0.0, 0.0),
47/// )
48/// ```
49#[proc_macro_derive(ConstDefault, attributes(constdefault))]
50pub fn derive(input: TokenStream) -> TokenStream {
51    match derive_default(input.into()) {
52        Ok(output) => output.into(),
53        Err(error) => error.to_compile_error().into(),
54    }
55}
56
57/// Implements the derive of `#[derive(ConstDefault)]` for struct types.
58fn derive_default(input: TokenStream2) -> Result<TokenStream2, syn::Error> {
59    let input = syn::parse2::<syn::DeriveInput>(input)?;
60    let ident = input.ident;
61    let data_struct = match input.data {
62        syn::Data::Struct(data_struct) => data_struct,
63        _ => {
64            return Err(Error::new(
65                Span::call_site(),
66                "ConstDefault derive only works on struct types",
67            ))
68        }
69    };
70    let default_impl = generate_default_impl_struct(&data_struct)?;
71    let mut generics = input.generics;
72    generate_default_impl_where_bounds(&data_struct, &mut generics)?;
73    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
74    Ok(quote! {
75        impl #impl_generics ConstDefault for #ident #ty_generics #where_clause {
76            const DEFAULT: Self = #default_impl;
77        }
78    })
79}
80
81/// Generates the `ConstDefault` implementation for `struct` input types.
82///
83/// # Note
84///
85/// The generated code abuses the fact that in Rust struct types can always
86/// be represented with braces and either using identifiers for fields or
87/// raw number literals in case of tuple-structs.
88///
89/// For example `struct Foo(u32)` can be represented as `Foo { 0: 42 }`.
90fn generate_default_impl_struct(
91    data_struct: &syn::DataStruct,
92) -> Result<TokenStream2, syn::Error> {
93    let fields_impl =
94        data_struct.fields.iter().enumerate().map(|(n, field)| {
95            let field_span = field.span();
96            let field_type = &field.ty;
97            let field_pos = Literal::usize_unsuffixed(n);
98            let field_ident = field
99                .ident
100                .as_ref()
101                .map(|ident| quote_spanned!(field_span=> #ident))
102                .unwrap_or_else(|| quote_spanned!(field_span=> #field_pos));
103            quote_spanned!(field_span=>
104                #field_ident: <#field_type as ConstDefault>::DEFAULT
105            )
106        });
107    Ok(quote! {
108        Self {
109            #( #fields_impl ),*
110        }
111    })
112}
113
114/// Generates `ConstDefault` where bounds for all fields of the input.
115fn generate_default_impl_where_bounds(
116    data_struct: &syn::DataStruct,
117    generics: &mut syn::Generics,
118) -> Result<(), syn::Error> {
119    let where_clause = generics.make_where_clause();
120    for field in &data_struct.fields {
121        let field_type = &field.ty;
122        where_clause.predicates.push(syn::parse_quote!(
123            #field_type: ConstDefault
124        ))
125    }
126    Ok(())
127}