perstruct_proc_macros/
lib.rs

1//! This crate contains the proc macros for the [`perstruct`](https://crates.io/crates/perstruct) crate.
2//! It is not intended to be used directly.
3
4use quote::quote;
5use quote::ToTokens;
6use syn::{parse_macro_input, ItemStruct};
7
8#[proc_macro_attribute]
9pub fn perstruct(
10    _args: proc_macro::TokenStream,
11    input: proc_macro::TokenStream,
12) -> proc_macro::TokenStream {
13    let input: ItemStruct = parse_macro_input!(input as ItemStruct);
14    process_struct(input)
15        .unwrap_or_else(syn::Error::into_compile_error)
16        .into()
17}
18
19fn process_struct(mut input: ItemStruct) -> syn::Result<proc_macro2::TokenStream> {
20    if input.generics.params.len() != 0 {
21        panic!("Struct cannot be generic");
22    }
23
24    let mut fields = vec![];
25    let mut skipped_fields = vec![];
26    for field in input.fields.iter_mut() {
27        let ident = field.ident.clone().unwrap();
28        let mut to_remove: Vec<syn::Path> = vec![];
29        let mut key: Option<String> = None;
30        let mut default_fn = None;
31        let mut default_lit = None;
32        let mut skip = false;
33
34        for attr in &field.attrs {
35            let attr_path = attr.path().clone();
36            if attr_path.is_ident("perstruct") {
37                to_remove.push(attr_path);
38                let meta = attr.parse_args()?;
39                match meta {
40                    syn::Meta::NameValue(syn::MetaNameValue {
41                        path,
42                        value: syn::Expr::Lit(lit),
43                        ..
44                    }) => match path {
45                        p if p.is_ident("key") => {
46                            if let syn::Lit::Str(s) = lit.lit {
47                                key = Some(s.value());
48                            } else {
49                                return Err(syn::Error::new_spanned(
50                                    lit,
51                                    "Expected string literal",
52                                ));
53                            }
54                        }
55                        p if p.is_ident("default_fn") => {
56                            if let syn::Lit::Str(s) = lit.lit {
57                                default_fn = Some(s.value());
58                            } else {
59                                return Err(syn::Error::new_spanned(
60                                    lit,
61                                    "Expected string literal",
62                                ));
63                            }
64                        }
65                        p if p.is_ident("default") => {
66                            default_lit = Some(lit.lit);
67                        }
68                        thing => return Err(syn::Error::new_spanned(
69                            thing.into_token_stream(),
70                            "Unknown perstruct attribute (available: key, default_fn, default, skip)",
71                        )),
72                    },
73                    syn::Meta::Path(path) => {
74                        if path.is_ident("skip") {
75                            skip = true;
76                        } else {
77                            return Err(syn::Error::new_spanned(
78                                    path.into_token_stream(),
79                                    "Unknown perstruct attribute (available: key, default_fn, default, skip)",
80                                ));
81                        }
82                    }
83                    thing => {
84                        return Err(syn::Error::new_spanned(
85                            attr.into_token_stream(),
86                            format!("Parse args failed: {thing:?}"),
87                        ))
88                    }
89                }
90            }
91        }
92        for attr in to_remove {
93            field.attrs.retain(|a| a.path() != &attr);
94        }
95        if skip {
96            skipped_fields.push(ident);
97            continue;
98        }
99        field.vis = syn::Visibility::Inherited;
100        let ty = field.ty.clone();
101        fields.push(PerstructField {
102            key: key.unwrap_or(ident.to_string()),
103            ident,
104            default_fn,
105            default_lit,
106            ty,
107        });
108    }
109
110    // Add _perstruct_changed_keys field
111    let syn::Fields::Named(syn::FieldsNamed { named, .. }) = &mut input.fields else {
112        return Err(syn::Error::new_spanned(
113            input.ident,
114            "Perstruct: struct must have named fields",
115        ));
116    };
117    named.push(syn::Field {
118        attrs: vec![],
119        vis: syn::Visibility::Inherited,
120        mutability: syn::FieldMutability::None,
121        ident: Some(syn::Ident::new(
122            "_perstruct_changed_keys",
123            proc_macro2::Span::mixed_site(),
124        )),
125        colon_token: None,
126        ty: syn::Type::Verbatim(quote! { std::collections::HashSet<&'static str> }),
127    });
128
129    let ident = input.ident.clone();
130    let default_impl = generate_default_impl(&ident, &fields, &skipped_fields);
131    let methods_impl = generate_methods_impl(&ident, &fields);
132    let trait_impl = generate_trait_impl(&ident, &fields);
133
134    let tokens = quote::quote! {
135        #input
136
137        #default_impl
138
139        #methods_impl
140
141        #trait_impl
142    };
143    Ok(tokens)
144}
145
146fn generate_methods_impl(
147    ident: &syn::Ident,
148    fields: &[PerstructField],
149) -> proc_macro2::TokenStream {
150    let methods = fields.iter().map(|field| {
151        let ident = &field.ident;
152        let ty = &field.ty;
153        let (reference_return, reference_ty) = match ty {
154            // copy types should be returned by value - all integer, float, bool, char
155            syn::Type::Path(syn::TypePath { qself: None, path }) if path.segments.len() == 1 => {
156                let segment = &path.segments[0];
157                match segment.ident.to_string().as_str() {
158                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
159                    | "u64" | "u128" | "usize" | "f32" | "f64" | "bool" | "char" => {
160                        (quote! { self.#ident }, quote! { #ty })
161                    }
162                    _ => (quote! { &self.#ident }, quote! { &#ty }),
163                }
164            }
165            _ => (quote! { &self.#ident }, quote! { &#ty }),
166        };
167        let set_ident = syn::Ident::new(&format!("set_{}", ident), ident.span());
168        let key = field.key.clone();
169        let key_lit = syn::ExprLit {
170            attrs: vec![],
171            lit: syn::Lit::Str(syn::LitStr::new(&key.to_string(), ident.span())),
172        };
173        let update_ident = syn::Ident::new(&format!("update_{}", ident), ident.span());
174        quote! {
175            pub fn #ident(&self) -> #reference_ty {
176                #reference_return
177            }
178            pub fn #set_ident(&mut self, value: #ty) {
179                self.#ident = value;
180                self._perstruct_changed_keys.insert(#key_lit);
181            }
182            pub fn #update_ident(&mut self, f: impl FnOnce(&mut #ty)) {
183                f(&mut self.#ident);
184                self._perstruct_changed_keys.insert(#key_lit);
185            }
186        }
187    });
188    quote::quote! {
189        impl #ident {
190            #(#methods)*
191        }
192    }
193}
194
195fn generate_default_impl(
196    ident: &syn::Ident,
197    fields: &[PerstructField],
198    skipped_fields: &[syn::Ident],
199) -> proc_macro2::TokenStream {
200    let default_fields = fields.iter().map(|field| {
201        let ident = &field.ident;
202        if let Some(default_fn) = &field.default_fn {
203            let default_fn = syn::Ident::new(default_fn, ident.span());
204            quote::quote! { #ident: #default_fn() }
205        } else if let Some(default_lit) = &field.default_lit {
206            quote::quote! { #ident: #default_lit }
207        } else {
208            quote::quote! { #ident: Default::default() }
209        }
210    });
211    let default_skipped_fields = skipped_fields.iter().map(|ident| {
212        quote::quote! { #ident: Default::default() }
213    });
214    quote::quote! {
215        #[automatically_derived]
216        impl Default for #ident {
217            fn default() -> Self {
218                Self {
219                    _perstruct_changed_keys: Default::default(),
220                    #(#default_fields),*,
221                    #(#default_skipped_fields),*
222                }
223            }
224        }
225    }
226}
227
228fn generate_trait_impl(ident: &syn::Ident, fields: &[PerstructField]) -> proc_macro2::TokenStream {
229    let key_lits = fields.iter().map(|field| {
230        let key = field.key.clone();
231        syn::LitStr::new(&key, proc_macro2::Span::mixed_site())
232    });
233
234    let from_map_impl = generate_from_map_impl(fields);
235    let serialize_changes_impl = generate_serialize_changes_impl(fields);
236
237    quote::quote! {
238        impl ::perstruct::Perstruct for #ident {
239            #from_map_impl
240
241            fn keys() -> std::vec::Vec<&'static str> {
242                vec![#( #key_lits.clone() ),*]
243            }
244
245            fn changed_keys(&self) -> &std::collections::HashSet<&'static str> {
246                &self._perstruct_changed_keys
247            }
248
249            fn mark_keys_changed(&mut self, keys: &[&'static str]) {
250                for key in keys {
251                    self._perstruct_changed_keys.insert(*key);
252                }
253            }
254
255            #serialize_changes_impl
256
257            fn clear_changes(&mut self) {
258                self._perstruct_changed_keys.clear();
259            }
260        }
261    }
262}
263
264fn generate_from_map_impl(fields: &[PerstructField]) -> proc_macro2::TokenStream {
265    let key_lits = fields.iter().map(|field| {
266        let key = field.key.clone();
267        syn::LitStr::new(&key, proc_macro2::Span::mixed_site())
268    });
269    let field_matches = fields
270        .iter()
271        .map(|field| {
272            let key = field.key.clone();
273            let key_lit = syn::LitStr::new(&key, proc_macro2::Span::mixed_site());
274            let ty = &field.ty;
275            let ident = &field.ident;
276            quote! {
277                #key_lit => {
278                    match serde_json::from_str::<#ty>(value.as_ref()) {
279                        Ok(json_value) => {
280                            struct_value.#ident = json_value;
281                            changed_keys.remove(#key_lit);
282                        }
283                        Err(e) => {
284                            deserialization_errors.push((#key_lit, e.to_string()));
285                        }
286                    }
287                }
288            }
289        })
290        .collect::<Vec<_>>();
291    quote::quote! {
292        fn from_map(map: &std::collections::HashMap<&str, &str>) -> ::perstruct::LoadResult<Self> {
293            let mut changed_keys = vec![
294                #( #key_lits ),*
295            ].into_iter().collect::<std::collections::HashSet<&'static str>>();
296            let mut unknown_fields = vec![];
297
298            let mut struct_value = Self::default();
299            let mut deserialization_errors = vec![];
300            for (key, value) in map.iter() {
301                match *key {
302                    #(#field_matches)*
303                    unknown_key => {
304                        unknown_fields.push(unknown_key.to_string());
305                    }
306                }
307            }
308            struct_value._perstruct_changed_keys = changed_keys;
309            ::perstruct::LoadResult {
310                value: struct_value,
311                deserialization_errors,
312                unknown_fields,
313            }
314        }
315    }
316}
317
318fn generate_serialize_changes_impl(fields: &[PerstructField]) -> proc_macro2::TokenStream {
319    let change_matches = fields
320        .iter()
321        .map(|field| {
322            let ident = &field.ident;
323            let key = field.key.clone();
324            let key_lit = syn::LitStr::new(&key, proc_macro2::Span::mixed_site());
325            quote! {
326                #key_lit => {
327                    let value = serde_json::to_string(&self.#ident).map_err(|e| e.to_string())?;
328                    changes.push((#key_lit, value));
329                }
330            }
331        })
332        .collect::<Vec<_>>();
333    quote::quote! {
334        fn serialize_changes(&self) -> Result<Vec<(&'static str, String)>, String> {
335            let mut changes = vec![];
336            for key in self._perstruct_changed_keys.iter() {
337                match *key {
338                    #(#change_matches)*
339                    _ => {}
340                }
341            }
342            Ok(changes)
343        }
344    }
345}
346
/// Description of one tracked (non-skipped) struct field, collected while
/// parsing the `#[perstruct(...)]` attributes and consumed by the code
/// generators in this file.
#[derive(Debug)]
struct PerstructField {
    /// Name of the field in the annotated struct.
    ident: syn::Ident,
    /// Storage key: the `key = "..."` value, or the field name when absent.
    key: String,
    /// Value of `default_fn = "..."`, if given; used to build the field's default.
    default_fn: Option<String>,
    /// Literal from `default = ...`, if given; used as the field's default when
    /// no `default_fn` is set.
    default_lit: Option<syn::Lit>,
    /// Declared type of the field.
    ty: syn::Type,
}