1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, quote_spanned};
3use syn::parse::{Parse, ParseStream};
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::{Error, Result};
7
8mod tests;
9
10#[proc_macro_derive(Gusket, attributes(gusket))]
11pub fn gusket(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
12 match gusket_impl(ts.into()) {
13 Ok(output) => output,
14 Err(err) => err.into_compile_error(),
15 }
16 .into()
17}
18
19fn gusket_impl(ts: TokenStream) -> Result<TokenStream> {
20 let input = syn::parse2::<syn::DeriveInput>(ts)?;
21 let input_ident = &input.ident;
22
23 let mut input_attrs = InputAttrs::new(&input.vis);
24
25 for attr in &input.attrs {
26 if attr.path.is_ident("gusket") {
27 input_attrs.apply(attr)?;
28 }
29 }
30
31 let (generics_decl, generics_usage) = if input.generics.params.is_empty() {
32 (quote!(), quote!())
33 } else {
34 let decl: Vec<_> = input.generics.params.iter().collect();
35 let usage: Vec<_> = input
36 .generics
37 .params
38 .iter()
39 .map(|param| match param {
40 syn::GenericParam::Type(syn::TypeParam { ident, .. }) => quote!(#ident),
41 syn::GenericParam::Lifetime(syn::LifetimeDef { lifetime, .. }) => {
42 quote!(#lifetime)
43 }
44 syn::GenericParam::Const(syn::ConstParam { ident, .. }) => quote!(#ident),
45 })
46 .collect();
47 (quote!(<#(#decl),*>), quote!(<#(#usage),*>))
48 };
49 let generics_where = &input.generics.where_clause;
50
51 let data = match &input.data {
52 syn::Data::Struct(data) => data,
53 syn::Data::Enum(data) => {
54 return Err(Error::new_spanned(&data.enum_token, "Enums are not supported"));
55 }
56 syn::Data::Union(data) => {
57 return Err(Error::new_spanned(&data.union_token, "Unions are not supported"));
58 }
59 };
60
61 let named = match &data.fields {
62 syn::Fields::Named(fields) => fields,
63 syn::Fields::Unnamed(fields) => {
64 return Err(Error::new(fields.paren_token.span, "Tuple structs are not supported"));
65 }
66 syn::Fields::Unit => {
67 return Err(Error::new_spanned(&data.semi_token, "Tuple structs are not supported"));
68 }
69 };
70
71 let mut methods = TokenStream::new();
72
73 for field in &named.named {
74 process_field(field, &input_attrs, &mut methods)?;
75 }
76
77 let output = quote! {
78 impl #generics_decl #input_ident #generics_usage #generics_where {
79 #methods
80 }
81 };
82
83 Ok(output)
84}
85
86fn process_field(
87 field: &syn::Field,
88 input_attrs: &InputAttrs,
89 methods: &mut TokenStream,
90) -> Result<()> {
91 let field_ident = field.ident.as_ref().expect("Struct is named");
92 let field_ty = &field.ty;
93
94 let mut field_vis = input_attrs.vis.clone();
95 let mut is_copy = None;
96 let mut derive = input_attrs.derive;
97 let mut mutable = input_attrs.mutable;
98
99 let mut docs = Vec::new();
100
101 for attr in &field.attrs {
102 if attr.path.is_ident("gusket") {
103 derive = true;
104
105 if !attr.tokens.is_empty() {
106 let attr_list: Punctuated<FieldAttr, syn::Token![,]> =
107 attr.parse_args_with(Punctuated::parse_terminated)?;
108 for attr in attr_list {
109 match attr {
110 FieldAttr::Vis(_, vis) => field_vis = vis,
111 FieldAttr::Immut(_) => mutable = false,
112 FieldAttr::Mut(_) => mutable = true,
113 FieldAttr::Copy(ident) => is_copy = Some(ident),
114 FieldAttr::Skip(_) => derive = false,
115 }
116 }
117 }
118 } else if attr.path.is_ident("doc") {
119 docs.push(attr);
120 }
121 }
122
123 if !derive {
124 return Ok(());
125 }
126
127 let ref_op = match is_copy {
128 Some(_) => quote!(),
129 None => quote_spanned!(field.span() => &),
130 };
131
132 methods.extend(quote_spanned! { field.span() =>
133 #(#docs)*
134 #[must_use = "Getters have no side effect"]
135 #[inline(always)]
136 #field_vis fn #field_ident(&self) -> #ref_op #field_ty {
137 #ref_op self.#field_ident
138 }
139 });
140
141 if mutable {
142 let setter = format_ident!("set_{}", &field_ident);
143 let mut_getter = format_ident!("{}_mut", &field_ident);
144
145 methods.extend(quote_spanned! { field.span() =>
146 #(#docs)*
147 #[must_use = "Mutable getters have no side effect"]
148 #[inline(always)]
149 #field_vis fn #mut_getter(&mut self) -> &mut #field_ty {
150 &mut self.#field_ident
151 }
152
153 #(#docs)*
154 #[inline(always)]
155 #field_vis fn #setter(&mut self, #field_ident: #field_ty) {
156 self.#field_ident = #field_ident;
157 }
158 })
159 }
160
161 Ok(())
162}
163
164struct InputAttrs {
165 vis: syn::Visibility,
166 mutable: bool,
167 derive: bool,
168}
169
170impl InputAttrs {
171 fn new(vis: &syn::Visibility) -> Self {
172 InputAttrs { vis: vis.clone(), mutable: true, derive: false }
173 }
174
175 fn apply(&mut self, attr: &syn::Attribute) -> Result<()> {
176 let attr_list: Punctuated<InputAttr, syn::Token![,]> =
177 attr.parse_args_with(Punctuated::parse_terminated)?;
178
179 for attr in attr_list {
180 match attr {
181 InputAttr::Vis(_, vis) => self.vis = vis,
182 InputAttr::Immut(_) => self.mutable = false,
183 InputAttr::All(_) => self.derive = true,
184 }
185 }
186
187 Ok(())
188 }
189}
190
191enum InputAttr {
192 Vis(syn::Ident, syn::Visibility),
193 Immut(syn::Ident),
194 All(syn::Ident),
195}
196
197impl Parse for InputAttr {
198 fn parse(input: ParseStream) -> Result<Self> {
199 let ident: syn::Ident = input.parse()?;
200 if ident == "vis" {
201 input.parse::<syn::Token![=]>()?;
202 let vis: syn::Visibility = input.parse()?;
203 Ok(Self::Vis(ident, vis))
204 } else if ident == "immut" {
205 Ok(Self::Immut(ident))
206 } else if ident == "all" {
207 Ok(Self::All(ident))
208 } else {
209 Err(Error::new_spanned(ident, "Unsupported attribute"))
210 }
211 }
212}
213
214enum FieldAttr {
215 Vis(syn::Ident, syn::Visibility),
216 Immut(syn::Ident),
217 Mut(syn::Token![mut]),
218 Copy(syn::Ident),
219 Skip(syn::Ident),
220}
221
222impl Parse for FieldAttr {
223 fn parse(input: ParseStream) -> Result<Self> {
224 if input.peek(syn::Token![mut]) {
225 let mut_token: syn::Token![mut] = input.parse()?;
226 return Ok(Self::Mut(mut_token));
227 }
228
229 let ident: syn::Ident = input.parse()?;
230 if ident == "vis" {
231 input.parse::<syn::Token![=]>()?;
232 let vis: syn::Visibility = input.parse()?;
233 Ok(Self::Vis(ident, vis))
234 } else if ident == "immut" {
235 Ok(Self::Immut(ident))
236 } else if ident == "copy" {
237 Ok(Self::Copy(ident))
238 } else if ident == "skip" {
239 Ok(Self::Skip(ident))
240 } else {
241 Err(Error::new_spanned(ident, "Unsupported attribute"))
242 }
243 }
244}