cset_derive/
lib.rs

1use proc_macro2::{Ident, TokenStream};
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Attribute, DataStruct, DeriveInput, Error, Meta, NestedMeta, Type};
4
5#[proc_macro_derive(Track, attributes(track))]
6pub fn macro_entry(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8
9    let expanded = match &input.data {
10        syn::Data::Struct(data) => derive_tracked_struct(&input, data),
11        syn::Data::Enum(data) => {
12            syn::Error::new_spanned(data.enum_token, "Cannot derive Undo for enums")
13                .into_compile_error()
14        }
15        syn::Data::Union(data) => {
16            syn::Error::new_spanned(data.union_token, "Cannot derive Undo for unions")
17                .into_compile_error()
18        }
19    };
20
21    expanded.into()
22}
23
24struct TrackedField {
25    index: usize,
26    ident: Ident,
27    ty: Type,
28    flattened_ident: Option<Ident>,
29}
30
31fn derive_tracked_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
32    let struct_ident = &input.ident;
33
34    for field in &data.fields {
35        if field.ident.is_none() {
36            return syn::Error::new_spanned(&data.fields, "Cannot derive Undo for tuple structs")
37                .to_compile_error();
38        }
39    }
40
41    let fields = data
42        .fields
43        .iter()
44        .enumerate()
45        .map(|(index, field)| {
46            let ident = field.ident.clone().unwrap();
47            let is_flattened = field.attrs.iter().any(|attr| {
48                get_meta_items(attr).unwrap().iter().any(|meta| match meta {
49                    NestedMeta::Meta(Meta::Path(path)) => path.is_ident("flatten"),
50                    _ => false,
51                })
52            });
53            
54            let ty = field.ty.clone();
55            let flattened_ident = if is_flattened {
56                Some(flattened_struct_ident(&ty))
57            } else {
58                None
59            };
60
61            TrackedField {
62                index,
63                ident,
64                ty: field.ty.clone(),
65                flattened_ident,
66            }
67        })
68        .collect::<Vec<_>>();
69
70    let draft_struct = derive_draft_struct(struct_ident, &fields[..]);
71    let draft_ident = create_draft_ident(struct_ident);
72    let draft_setters = fields.iter().map(|field| {
73        let TrackedField {
74            ident,
75            flattened_ident,
76            ..
77        } = field;
78
79        if flattened_ident.is_some() {
80            quote!(#ident: self.#ident.edit())
81        } else {
82            quote!(#ident: ::cset::DraftField::new(&mut self.#ident))
83        }
84    });
85
86    let apply_value_fields = fields
87        .iter()
88        .filter(|field| field.flattened_ident.is_none())
89        .map(|field| {
90            let TrackedField {
91                index, ident, ty, ..
92            } = field;
93
94            quote! {
95                #index => {
96                    let new_value = *value.downcast::<#ty>().unwrap();
97                    let old_value = ::std::mem::replace(&mut self.#ident, new_value);
98                    reverse_changes.push(::cset::Change {
99                        field_id: change.field_id,
100                        value: ::cset::ChangeValue::Value(::std::boxed::Box::new(old_value)),
101                    });
102                }
103            }
104        });
105
106    let apply_changeset_fields = fields
107        .iter()
108        .filter(|field| field.flattened_ident.is_some())
109        .map(|field| {
110            let TrackedField {
111                index, ident, ..
112            } = field;
113
114            quote! {
115                #index => {
116                    let reverse_change = self.#ident.apply_impl(field_changes, depth + 1);
117                    reverse_changes.push(::cset::Change {
118                        field_id: change.field_id,
119                        value: ::cset::ChangeValue::ChangeSet(reverse_change),
120                    });
121                }
122            }
123        });
124
125    quote! {
126        impl #struct_ident {
127            pub fn edit(&mut self) -> #draft_ident {
128                #draft_ident {
129                    #(#draft_setters,)*
130                }
131            }
132
133            pub fn apply(&mut self, changeset: ::cset::ChangeSet) -> ::cset::ChangeSet {
134                self.apply_impl(changeset, 0)
135            }
136
137            fn apply_impl(&mut self, changeset: ::cset::ChangeSet, depth: usize) -> ::cset::ChangeSet {
138                assert!(changeset.for_type::<#struct_ident>());
139                let mut reverse_changes = Vec::new();
140
141                for change in changeset.changes {
142                    let field_index = change.field_id.field_index(depth);
143
144                    match change.value {
145                        ::cset::ChangeValue::Value(value) => match field_index {
146                            #(#apply_value_fields,)*
147                            _ => unreachable!(),
148                        },
149                        ::cset::ChangeValue::ChangeSet(field_changes) => match field_index {
150                            #(#apply_changeset_fields,)*
151                            _ => unreachable!(),
152                        },
153                    };
154                }
155
156                ::cset::ChangeSet::new::<#struct_ident>(reverse_changes)
157            }
158        }
159
160        #draft_struct
161    }
162}
163
164fn derive_draft_struct(struct_ident: &Ident, fields: &[TrackedField]) -> TokenStream {
165    let draft_ident = create_draft_ident(struct_ident);
166
167    let draft_fields = fields.iter().map(|field| {
168        let TrackedField { ident, ty, flattened_ident, .. } = field;
169
170        if let Some(flattened_ident) = flattened_ident {
171            let draft_ident = create_draft_ident(flattened_ident);
172            quote!(#ident: #draft_ident<'b>)
173        } else {
174            quote!(#ident: ::cset::DraftField::<'b, #ty>)
175        }
176    });
177
178    let field_api_fns = fields.iter().map(|field| {
179        let TrackedField { ident, ty, flattened_ident, .. } = field;
180        let dirty_checker = create_dirty_check_ident(ident);
181        let resetter = create_resetter_ident(ident);
182                
183        if let Some(flattened_ident) = flattened_ident {
184            let editor = format_ident!("edit_{ident}");
185            let flattened_draft_ident = create_draft_ident(flattened_ident); 
186            quote! {
187                pub fn #editor(&mut self) -> &mut #flattened_draft_ident<'b> {
188                    &mut self.#ident
189                }
190
191                pub fn #dirty_checker(&self) -> bool {
192                    self.#ident.is_dirty()
193                }
194
195                pub fn #resetter(&mut self) {
196                    self.#ident.reset();
197                }
198            }
199        } else {
200            let getter = format_ident!("get_{ident}");
201            let setter = format_ident!("set_{ident}");
202            quote! {
203                pub fn #getter(&self) -> &#ty {
204                    if let Some(#ident) = &self.#ident.draft {
205                        #ident
206                    } else {
207                        &self.#ident.original
208                    }
209                }
210
211                pub fn #setter(&mut self, #ident: #ty) {
212                    self.#ident.draft = Some(#ident);
213                }
214
215                pub fn #dirty_checker(&self) -> bool {
216                    self.#ident.draft.is_some()
217                }
218
219                pub fn #resetter(&mut self) -> Option<#ty> {
220                    self.#ident.draft.take()
221                }
222            }
223        }
224    });
225
226    let draft_change_checkers = fields.iter().map(|field| {
227        let TrackedField { ident, .. } = field;
228        let dirty_checker = create_dirty_check_ident(ident);
229        quote!(self.#dirty_checker())
230    });
231
232    let draft_resetters = fields.iter().map(|field| {
233        let TrackedField { ident, .. } = field;
234        let resetter = create_resetter_ident(ident);
235        quote!(self.#resetter())
236    });
237
238    let field_commits = fields.iter().map(|field| {
239        let TrackedField { index, ident, flattened_ident, .. } = field;
240
241        if flattened_ident.is_some() {
242            quote! {
243                {
244                    let new_field_idx = field_idx.push_field(#index);
245                    changes.push(::cset::Change {
246                        field_id: new_field_idx.clone(),
247                        value: ::cset::ChangeValue::ChangeSet(self.#ident.apply_impl(new_field_idx)),
248                    });
249                }
250            }
251        } else {   
252            quote! {
253                if let Some(change) = self.#ident.apply(field_idx.push_field(#index)) {
254                    changes.push(change);
255                }
256            }
257        }
258    });
259
260    quote! {
261        pub struct #draft_ident<'b> {
262            #(#draft_fields,)*
263        }
264
265        impl<'b> #draft_ident<'b> {
266            #(#field_api_fns)*
267
268            /// Returns true if the draft will modify the underlying struct if
269            /// committed.
270            pub fn is_dirty(&self) -> bool {
271                #(#draft_change_checkers)||*
272            }
273
274            /// Clear all updates to changed fields.
275            pub fn reset(&mut self) {
276                #(#draft_resetters;)*
277            }
278
279            pub fn apply(self) -> ::cset::ChangeSet {
280                self.apply_impl(::cset::FieldId::default())
281            }
282    
283            fn apply_impl(self, field_idx: ::cset::FieldId) -> ::cset::ChangeSet {
284                let mut changes = Vec::new();
285    
286                #(#field_commits)*
287    
288                ::cset::ChangeSet::new::<#struct_ident>(changes)
289            }
290        }
291    }
292}
293
294fn create_draft_ident(ident: &Ident) -> Ident {
295    format_ident!("{ident}Draft")
296}
297
298fn create_dirty_check_ident(ident: &Ident) -> Ident {
299    format_ident!("is_{ident}_dirty")
300}
301
302fn create_resetter_ident(ident: &Ident) -> Ident {
303    format_ident!("reset_{ident}")
304}
305
306fn get_meta_items(attr: &Attribute) -> syn::Result<Vec<NestedMeta>> {
307    if attr.path.is_ident("track") {
308        match attr.parse_meta()? {
309            Meta::List(meta) => Ok(Vec::from_iter(meta.nested)),
310            bad => Err(Error::new_spanned(bad, "unrecognized attribute")),
311        }
312    } else {
313        Ok(Vec::new())
314    }
315}
316
317fn flattened_struct_ident(ty: &Type) -> Ident {
318    match ty {
319        Type::Path(path) => {
320            path.path.get_ident().unwrap().clone()
321        },
322        _ => todo!(),
323    }
324}