// clap_serde_proc/lib.rs

// Copyright (C) 2022 Davide Peressoni
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

use proc_macro::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::parse_quote;

20/// Enables clap and serde derive on the following struct.
21///
22/// It will automatically implement also the `ClapSerde` and [`Default`] traits.  
23/// Use `#[default(value)]` attribute to change the default value of a field.  
24/// Use `#[clap_serde]` attribute on the fields which type is a struct generated by this macro
25/// (recursive behaviour).
26///
27/// This derive macro generates a struct (`ClapSerde::Opt`) with the same fields of the original,
28/// but which types are wrapped in [`Option`].  
29/// Such structure can be parsed from command line with clap and from any serde Deserializer.
30/// It can also be merged to the original struct: only the not-`None` fields will be used to update
31/// it. See `ClapSerde::update` and `ClapSerde::merge`.
32#[proc_macro_derive(
33    ClapSerde,
34    attributes(clap, structopt, command, arg, group, serde, default, clap_serde)
35)]
36pub fn clap_serde(item: TokenStream) -> TokenStream {
37    // AST of the struct on which this derive proc macro is applied.
38    // It will be modified to generate the corresponding Opt.
39    let mut ast: syn::DeriveInput = syn::parse(item).unwrap();
40
41    let mut no_fields = syn::punctuated::Punctuated::new();
42
43    // Get a mutable reference to the fields.
44    let fields = match &mut ast.data {
45        syn::Data::Struct(data) => match &mut data.fields {
46            syn::Fields::Named(fields) => &mut fields.named,
47            syn::Fields::Unit => &mut no_fields,
48            _ => panic!("clap supports only non-tuple structs"),
49        },
50        syn::Data::Enum(_) => panic!("ClapSerde currently supports only structs"),
51        _ => panic!("clap supports only structs and enums"),
52    };
53
54    // Get the name of the struct and generate name for the corresponding Opt
55    let name = ast.ident;
56    ast.ident = format_ident!("ClapSerdeOptional{}", name);
57    let opt_name = &ast.ident;
58    let default_doc = format!("Create default {}", name);
59
60    // Get generics for trait implementation
61    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
62
63    let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone().unwrap()).collect();
64
65    // Fields which implement ClapSerde
66    let mut recursive_fields = Vec::new();
67    // Fields which do not implement ClapSerde
68    let mut not_recursive_fields = Vec::new();
69
70    let default_values: Vec<_> = fields
71        .iter_mut()
72        .map(|f| {
73            let ty = f.ty.clone();
74
75            // Default value for this field
76            // If no default value was provided use the default for the type
77            let mut def_val = parse_quote!(<#ty as core::default::Default>::default());
78            // Is this field recursive? (Does it implement ClapSerde?)
79            let mut not_recursive = true;
80
81            f.attrs.retain(|attr| {
82                if attr.path.is_ident(&format_ident!("default")) {
83                    // Found default value
84                    def_val = attr.tokens.clone();
85                    false
86                } else if attr.path.is_ident(&format_ident!("clap_serde")) {
87                    // Found declaration of recursive field
88                    not_recursive = false;
89                    recursive_fields.push(f.ident.clone());
90                    f.ty = parse_quote!(<#ty as clap_serde_derive::ClapSerde>::Opt);
91                    false
92                } else {
93                    // Other attributes must be retained
94                    true
95                }
96            });
97
98            if not_recursive {
99                not_recursive_fields.push(f.ident.clone());
100            }
101
102            // Wrap field type in option
103            let ty = &f.ty; // Use possible updated type
104            f.ty = parse_quote!(Option<#ty>);
105
106            def_val
107        })
108        .collect();
109
110    quote! {
111        // the struct with options
112        #[doc(hidden)]
113        #[derive(clap::Parser, serde::Deserialize, core::default::Default)]
114        #ast
115
116        // implement ClapSerde
117        impl #impl_generics clap_serde_derive::ClapSerde for #name #ty_generics
118            #where_clause
119        {
120            type Opt = #opt_name;
121
122            fn update(&mut self, mut other: impl core::borrow::BorrowMut<Self::Opt>) {
123                let other = other.borrow_mut();
124                #(
125                    if let core::option::Option::Some(v) = other.#not_recursive_fields.take() {
126                        self.#not_recursive_fields = v;
127                    }
128                )*
129                #(
130                    if let core::option::Option::Some(mut v) = other.#recursive_fields.take() {
131                        self.#recursive_fields.update(&mut v);
132                    }
133                )*
134            }
135        }
136
137        // implement Default
138        impl #impl_generics core::default::Default for #name #ty_generics #where_clause {
139            #[doc = #default_doc]
140            fn default() -> Self {
141                Self {
142                    #(
143                        #field_names: #default_values,
144                    )*
145                }
146            }
147        }
148
149        // implement From
150        impl #impl_generics core::convert::From<<Self as clap_serde_derive::ClapSerde>::Opt>
151            for #name #ty_generics #where_clause
152        {
153            /// Create new object from Opt.
154            fn from(data: <Self as clap_serde_derive::ClapSerde>::Opt) -> Self {
155                Self::default().merge(data)
156            }
157        }
158        impl #impl_generics core::convert::From<&mut <Self as clap_serde_derive::ClapSerde>::Opt>
159            for #name #ty_generics #where_clause
160        {
161            /// Create new object from &mut Opt.
162            fn from(data: &mut <Self as clap_serde_derive::ClapSerde>::Opt) -> Self {
163                Self::default().merge(data)
164            }
165        }
166    }
167    .into()
168}