gusket_codegen/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, quote_spanned};
3use syn::parse::{Parse, ParseStream};
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::{Error, Result};
7
8mod tests;
9
10#[proc_macro_derive(Gusket, attributes(gusket))]
11pub fn gusket(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
12    match gusket_impl(ts.into()) {
13        Ok(output) => output,
14        Err(err) => err.into_compile_error(),
15    }
16    .into()
17}
18
19fn gusket_impl(ts: TokenStream) -> Result<TokenStream> {
20    let input = syn::parse2::<syn::DeriveInput>(ts)?;
21    let input_ident = &input.ident;
22
23    let mut input_attrs = InputAttrs::new(&input.vis);
24
25    for attr in &input.attrs {
26        if attr.path.is_ident("gusket") {
27            input_attrs.apply(attr)?;
28        }
29    }
30
31    let (generics_decl, generics_usage) = if input.generics.params.is_empty() {
32        (quote!(), quote!())
33    } else {
34        let decl: Vec<_> = input.generics.params.iter().collect();
35        let usage: Vec<_> = input
36            .generics
37            .params
38            .iter()
39            .map(|param| match param {
40                syn::GenericParam::Type(syn::TypeParam { ident, .. }) => quote!(#ident),
41                syn::GenericParam::Lifetime(syn::LifetimeDef { lifetime, .. }) => {
42                    quote!(#lifetime)
43                }
44                syn::GenericParam::Const(syn::ConstParam { ident, .. }) => quote!(#ident),
45            })
46            .collect();
47        (quote!(<#(#decl),*>), quote!(<#(#usage),*>))
48    };
49    let generics_where = &input.generics.where_clause;
50
51    let data = match &input.data {
52        syn::Data::Struct(data) => data,
53        syn::Data::Enum(data) => {
54            return Err(Error::new_spanned(&data.enum_token, "Enums are not supported"));
55        }
56        syn::Data::Union(data) => {
57            return Err(Error::new_spanned(&data.union_token, "Unions are not supported"));
58        }
59    };
60
61    let named = match &data.fields {
62        syn::Fields::Named(fields) => fields,
63        syn::Fields::Unnamed(fields) => {
64            return Err(Error::new(fields.paren_token.span, "Tuple structs are not supported"));
65        }
66        syn::Fields::Unit => {
67            return Err(Error::new_spanned(&data.semi_token, "Tuple structs are not supported"));
68        }
69    };
70
71    let mut methods = TokenStream::new();
72
73    for field in &named.named {
74        process_field(field, &input_attrs, &mut methods)?;
75    }
76
77    let output = quote! {
78        impl #generics_decl #input_ident #generics_usage #generics_where {
79            #methods
80        }
81    };
82
83    Ok(output)
84}
85
86fn process_field(
87    field: &syn::Field,
88    input_attrs: &InputAttrs,
89    methods: &mut TokenStream,
90) -> Result<()> {
91    let field_ident = field.ident.as_ref().expect("Struct is named");
92    let field_ty = &field.ty;
93
94    let mut field_vis = input_attrs.vis.clone();
95    let mut is_copy = None;
96    let mut derive = input_attrs.derive;
97    let mut mutable = input_attrs.mutable;
98
99    let mut docs = Vec::new();
100
101    for attr in &field.attrs {
102        if attr.path.is_ident("gusket") {
103            derive = true;
104
105            if !attr.tokens.is_empty() {
106                let attr_list: Punctuated<FieldAttr, syn::Token![,]> =
107                    attr.parse_args_with(Punctuated::parse_terminated)?;
108                for attr in attr_list {
109                    match attr {
110                        FieldAttr::Vis(_, vis) => field_vis = vis,
111                        FieldAttr::Immut(_) => mutable = false,
112                        FieldAttr::Mut(_) => mutable = true,
113                        FieldAttr::Copy(ident) => is_copy = Some(ident),
114                        FieldAttr::Skip(_) => derive = false,
115                    }
116                }
117            }
118        } else if attr.path.is_ident("doc") {
119            docs.push(attr);
120        }
121    }
122
123    if !derive {
124        return Ok(());
125    }
126
127    let ref_op = match is_copy {
128        Some(_) => quote!(),
129        None => quote_spanned!(field.span() => &),
130    };
131
132    methods.extend(quote_spanned! { field.span() =>
133        #(#docs)*
134        #[must_use = "Getters have no side effect"]
135        #[inline(always)]
136        #field_vis fn #field_ident(&self) -> #ref_op #field_ty {
137            #ref_op self.#field_ident
138        }
139    });
140
141    if mutable {
142        let setter = format_ident!("set_{}", &field_ident);
143        let mut_getter = format_ident!("{}_mut", &field_ident);
144
145        methods.extend(quote_spanned! { field.span() =>
146            #(#docs)*
147            #[must_use = "Mutable getters have no side effect"]
148            #[inline(always)]
149            #field_vis fn #mut_getter(&mut self) -> &mut #field_ty {
150                &mut self.#field_ident
151            }
152
153            #(#docs)*
154            #[inline(always)]
155            #field_vis fn #setter(&mut self, #field_ident: #field_ty) {
156                self.#field_ident = #field_ident;
157            }
158        })
159    }
160
161    Ok(())
162}
163
164struct InputAttrs {
165    vis:     syn::Visibility,
166    mutable: bool,
167    derive:  bool,
168}
169
170impl InputAttrs {
171    fn new(vis: &syn::Visibility) -> Self {
172        InputAttrs { vis: vis.clone(), mutable: true, derive: false }
173    }
174
175    fn apply(&mut self, attr: &syn::Attribute) -> Result<()> {
176        let attr_list: Punctuated<InputAttr, syn::Token![,]> =
177            attr.parse_args_with(Punctuated::parse_terminated)?;
178
179        for attr in attr_list {
180            match attr {
181                InputAttr::Vis(_, vis) => self.vis = vis,
182                InputAttr::Immut(_) => self.mutable = false,
183                InputAttr::All(_) => self.derive = true,
184            }
185        }
186
187        Ok(())
188    }
189}
190
191enum InputAttr {
192    Vis(syn::Ident, syn::Visibility),
193    Immut(syn::Ident),
194    All(syn::Ident),
195}
196
197impl Parse for InputAttr {
198    fn parse(input: ParseStream) -> Result<Self> {
199        let ident: syn::Ident = input.parse()?;
200        if ident == "vis" {
201            input.parse::<syn::Token![=]>()?;
202            let vis: syn::Visibility = input.parse()?;
203            Ok(Self::Vis(ident, vis))
204        } else if ident == "immut" {
205            Ok(Self::Immut(ident))
206        } else if ident == "all" {
207            Ok(Self::All(ident))
208        } else {
209            Err(Error::new_spanned(ident, "Unsupported attribute"))
210        }
211    }
212}
213
214enum FieldAttr {
215    Vis(syn::Ident, syn::Visibility),
216    Immut(syn::Ident),
217    Mut(syn::Token![mut]),
218    Copy(syn::Ident),
219    Skip(syn::Ident),
220}
221
222impl Parse for FieldAttr {
223    fn parse(input: ParseStream) -> Result<Self> {
224        if input.peek(syn::Token![mut]) {
225            let mut_token: syn::Token![mut] = input.parse()?;
226            return Ok(Self::Mut(mut_token));
227        }
228
229        let ident: syn::Ident = input.parse()?;
230        if ident == "vis" {
231            input.parse::<syn::Token![=]>()?;
232            let vis: syn::Visibility = input.parse()?;
233            Ok(Self::Vis(ident, vis))
234        } else if ident == "immut" {
235            Ok(Self::Immut(ident))
236        } else if ident == "copy" {
237            Ok(Self::Copy(ident))
238        } else if ident == "skip" {
239            Ok(Self::Skip(ident))
240        } else {
241            Err(Error::new_spanned(ident, "Unsupported attribute"))
242        }
243    }
244}