clap_serde_proc/lib.rs
1// Copyright (C) 2022 Davide Peressoni
2//
3// This program is free software: you can redistribute it and/or modify
4// it under the terms of the GNU Affero General Public License as published by
5// the Free Software Foundation, either version 3 of the License, or
6// (at your option) any later version.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU Affero General Public License for more details.
12//
13// You should have received a copy of the GNU Affero General Public License
14// along with this program. If not, see <http://www.gnu.org/licenses/>.
15
16use proc_macro::TokenStream;
17use quote::{format_ident, quote};
18use syn::parse_quote;
19
20/// Enables clap and serde derive on the following struct.
21///
22/// It will automatically implement also the `ClapSerde` and [`Default`] traits.
23/// Use `#[default(value)]` attribute to change the default value of a field.
24/// Use `#[clap_serde]` attribute on the fields which type is a struct generated by this macro
25/// (recursive behaviour).
26///
27/// This derive macro generates a struct (`ClapSerde::Opt`) with the same fields of the original,
28/// but which types are wrapped in [`Option`].
29/// Such structure can be parsed from command line with clap and from any serde Deserializer.
30/// It can also be merged to the original struct: only the not-`None` fields will be used to update
31/// it. See `ClapSerde::update` and `ClapSerde::merge`.
32#[proc_macro_derive(
33 ClapSerde,
34 attributes(clap, structopt, command, arg, group, serde, default, clap_serde)
35)]
36pub fn clap_serde(item: TokenStream) -> TokenStream {
37 // AST of the struct on which this derive proc macro is applied.
38 // It will be modified to generate the corresponding Opt.
39 let mut ast: syn::DeriveInput = syn::parse(item).unwrap();
40
41 let mut no_fields = syn::punctuated::Punctuated::new();
42
43 // Get a mutable reference to the fields.
44 let fields = match &mut ast.data {
45 syn::Data::Struct(data) => match &mut data.fields {
46 syn::Fields::Named(fields) => &mut fields.named,
47 syn::Fields::Unit => &mut no_fields,
48 _ => panic!("clap supports only non-tuple structs"),
49 },
50 syn::Data::Enum(_) => panic!("ClapSerde currently supports only structs"),
51 _ => panic!("clap supports only structs and enums"),
52 };
53
54 // Get the name of the struct and generate name for the corresponding Opt
55 let name = ast.ident;
56 ast.ident = format_ident!("ClapSerdeOptional{}", name);
57 let opt_name = &ast.ident;
58 let default_doc = format!("Create default {}", name);
59
60 // Get generics for trait implementation
61 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
62
63 let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone().unwrap()).collect();
64
65 // Fields which implement ClapSerde
66 let mut recursive_fields = Vec::new();
67 // Fields which do not implement ClapSerde
68 let mut not_recursive_fields = Vec::new();
69
70 let default_values: Vec<_> = fields
71 .iter_mut()
72 .map(|f| {
73 let ty = f.ty.clone();
74
75 // Default value for this field
76 // If no default value was provided use the default for the type
77 let mut def_val = parse_quote!(<#ty as core::default::Default>::default());
78 // Is this field recursive? (Does it implement ClapSerde?)
79 let mut not_recursive = true;
80
81 f.attrs.retain(|attr| {
82 if attr.path.is_ident(&format_ident!("default")) {
83 // Found default value
84 def_val = attr.tokens.clone();
85 false
86 } else if attr.path.is_ident(&format_ident!("clap_serde")) {
87 // Found declaration of recursive field
88 not_recursive = false;
89 recursive_fields.push(f.ident.clone());
90 f.ty = parse_quote!(<#ty as clap_serde_derive::ClapSerde>::Opt);
91 false
92 } else {
93 // Other attributes must be retained
94 true
95 }
96 });
97
98 if not_recursive {
99 not_recursive_fields.push(f.ident.clone());
100 }
101
102 // Wrap field type in option
103 let ty = &f.ty; // Use possible updated type
104 f.ty = parse_quote!(Option<#ty>);
105
106 def_val
107 })
108 .collect();
109
110 quote! {
111 // the struct with options
112 #[doc(hidden)]
113 #[derive(clap::Parser, serde::Deserialize, core::default::Default)]
114 #ast
115
116 // implement ClapSerde
117 impl #impl_generics clap_serde_derive::ClapSerde for #name #ty_generics
118 #where_clause
119 {
120 type Opt = #opt_name;
121
122 fn update(&mut self, mut other: impl core::borrow::BorrowMut<Self::Opt>) {
123 let other = other.borrow_mut();
124 #(
125 if let core::option::Option::Some(v) = other.#not_recursive_fields.take() {
126 self.#not_recursive_fields = v;
127 }
128 )*
129 #(
130 if let core::option::Option::Some(mut v) = other.#recursive_fields.take() {
131 self.#recursive_fields.update(&mut v);
132 }
133 )*
134 }
135 }
136
137 // implement Default
138 impl #impl_generics core::default::Default for #name #ty_generics #where_clause {
139 #[doc = #default_doc]
140 fn default() -> Self {
141 Self {
142 #(
143 #field_names: #default_values,
144 )*
145 }
146 }
147 }
148
149 // implement From
150 impl #impl_generics core::convert::From<<Self as clap_serde_derive::ClapSerde>::Opt>
151 for #name #ty_generics #where_clause
152 {
153 /// Create new object from Opt.
154 fn from(data: <Self as clap_serde_derive::ClapSerde>::Opt) -> Self {
155 Self::default().merge(data)
156 }
157 }
158 impl #impl_generics core::convert::From<&mut <Self as clap_serde_derive::ClapSerde>::Opt>
159 for #name #ty_generics #where_clause
160 {
161 /// Create new object from &mut Opt.
162 fn from(data: &mut <Self as clap_serde_derive::ClapSerde>::Opt) -> Self {
163 Self::default().merge(data)
164 }
165 }
166 }
167 .into()
168}