1use proc_macro2::{Ident, TokenStream};
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Attribute, DataStruct, DeriveInput, Error, Meta, NestedMeta, Type};
4
5#[proc_macro_derive(Track, attributes(track))]
6pub fn macro_entry(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8
9 let expanded = match &input.data {
10 syn::Data::Struct(data) => derive_tracked_struct(&input, data),
11 syn::Data::Enum(data) => {
12 syn::Error::new_spanned(data.enum_token, "Cannot derive Undo for enums")
13 .into_compile_error()
14 }
15 syn::Data::Union(data) => {
16 syn::Error::new_spanned(data.union_token, "Cannot derive Undo for unions")
17 .into_compile_error()
18 }
19 };
20
21 expanded.into()
22}
23
/// Per-field information collected from the input struct and shared by the
/// code generators.
struct TrackedField {
    /// Position of the field in declaration order; emitted as the match arm
    /// index in the generated `apply_impl` and passed to `push_field` when a
    /// draft is applied.
    index: usize,
    /// Name of the field.
    ident: Ident,
    /// Declared type of the field.
    ty: Type,
    /// For `#[track(flatten)]` fields, the ident of the field's type (the
    /// nested draft type name is derived from it); `None` for plain fields.
    flattened_ident: Option<Ident>,
}
30
31fn derive_tracked_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
32 let struct_ident = &input.ident;
33
34 for field in &data.fields {
35 if field.ident.is_none() {
36 return syn::Error::new_spanned(&data.fields, "Cannot derive Undo for tuple structs")
37 .to_compile_error();
38 }
39 }
40
41 let fields = data
42 .fields
43 .iter()
44 .enumerate()
45 .map(|(index, field)| {
46 let ident = field.ident.clone().unwrap();
47 let is_flattened = field.attrs.iter().any(|attr| {
48 get_meta_items(attr).unwrap().iter().any(|meta| match meta {
49 NestedMeta::Meta(Meta::Path(path)) => path.is_ident("flatten"),
50 _ => false,
51 })
52 });
53
54 let ty = field.ty.clone();
55 let flattened_ident = if is_flattened {
56 Some(flattened_struct_ident(&ty))
57 } else {
58 None
59 };
60
61 TrackedField {
62 index,
63 ident,
64 ty: field.ty.clone(),
65 flattened_ident,
66 }
67 })
68 .collect::<Vec<_>>();
69
70 let draft_struct = derive_draft_struct(struct_ident, &fields[..]);
71 let draft_ident = create_draft_ident(struct_ident);
72 let draft_setters = fields.iter().map(|field| {
73 let TrackedField {
74 ident,
75 flattened_ident,
76 ..
77 } = field;
78
79 if flattened_ident.is_some() {
80 quote!(#ident: self.#ident.edit())
81 } else {
82 quote!(#ident: ::cset::DraftField::new(&mut self.#ident))
83 }
84 });
85
86 let apply_value_fields = fields
87 .iter()
88 .filter(|field| field.flattened_ident.is_none())
89 .map(|field| {
90 let TrackedField {
91 index, ident, ty, ..
92 } = field;
93
94 quote! {
95 #index => {
96 let new_value = *value.downcast::<#ty>().unwrap();
97 let old_value = ::std::mem::replace(&mut self.#ident, new_value);
98 reverse_changes.push(::cset::Change {
99 field_id: change.field_id,
100 value: ::cset::ChangeValue::Value(::std::boxed::Box::new(old_value)),
101 });
102 }
103 }
104 });
105
106 let apply_changeset_fields = fields
107 .iter()
108 .filter(|field| field.flattened_ident.is_some())
109 .map(|field| {
110 let TrackedField {
111 index, ident, ..
112 } = field;
113
114 quote! {
115 #index => {
116 let reverse_change = self.#ident.apply_impl(field_changes, depth + 1);
117 reverse_changes.push(::cset::Change {
118 field_id: change.field_id,
119 value: ::cset::ChangeValue::ChangeSet(reverse_change),
120 });
121 }
122 }
123 });
124
125 quote! {
126 impl #struct_ident {
127 pub fn edit(&mut self) -> #draft_ident {
128 #draft_ident {
129 #(#draft_setters,)*
130 }
131 }
132
133 pub fn apply(&mut self, changeset: ::cset::ChangeSet) -> ::cset::ChangeSet {
134 self.apply_impl(changeset, 0)
135 }
136
137 fn apply_impl(&mut self, changeset: ::cset::ChangeSet, depth: usize) -> ::cset::ChangeSet {
138 assert!(changeset.for_type::<#struct_ident>());
139 let mut reverse_changes = Vec::new();
140
141 for change in changeset.changes {
142 let field_index = change.field_id.field_index(depth);
143
144 match change.value {
145 ::cset::ChangeValue::Value(value) => match field_index {
146 #(#apply_value_fields,)*
147 _ => unreachable!(),
148 },
149 ::cset::ChangeValue::ChangeSet(field_changes) => match field_index {
150 #(#apply_changeset_fields,)*
151 _ => unreachable!(),
152 },
153 };
154 }
155
156 ::cset::ChangeSet::new::<#struct_ident>(reverse_changes)
157 }
158 }
159
160 #draft_struct
161 }
162}
163
164fn derive_draft_struct(struct_ident: &Ident, fields: &[TrackedField]) -> TokenStream {
165 let draft_ident = create_draft_ident(struct_ident);
166
167 let draft_fields = fields.iter().map(|field| {
168 let TrackedField { ident, ty, flattened_ident, .. } = field;
169
170 if let Some(flattened_ident) = flattened_ident {
171 let draft_ident = create_draft_ident(flattened_ident);
172 quote!(#ident: #draft_ident<'b>)
173 } else {
174 quote!(#ident: ::cset::DraftField::<'b, #ty>)
175 }
176 });
177
178 let field_api_fns = fields.iter().map(|field| {
179 let TrackedField { ident, ty, flattened_ident, .. } = field;
180 let dirty_checker = create_dirty_check_ident(ident);
181 let resetter = create_resetter_ident(ident);
182
183 if let Some(flattened_ident) = flattened_ident {
184 let editor = format_ident!("edit_{ident}");
185 let flattened_draft_ident = create_draft_ident(flattened_ident);
186 quote! {
187 pub fn #editor(&mut self) -> &mut #flattened_draft_ident<'b> {
188 &mut self.#ident
189 }
190
191 pub fn #dirty_checker(&self) -> bool {
192 self.#ident.is_dirty()
193 }
194
195 pub fn #resetter(&mut self) {
196 self.#ident.reset();
197 }
198 }
199 } else {
200 let getter = format_ident!("get_{ident}");
201 let setter = format_ident!("set_{ident}");
202 quote! {
203 pub fn #getter(&self) -> &#ty {
204 if let Some(#ident) = &self.#ident.draft {
205 #ident
206 } else {
207 &self.#ident.original
208 }
209 }
210
211 pub fn #setter(&mut self, #ident: #ty) {
212 self.#ident.draft = Some(#ident);
213 }
214
215 pub fn #dirty_checker(&self) -> bool {
216 self.#ident.draft.is_some()
217 }
218
219 pub fn #resetter(&mut self) -> Option<#ty> {
220 self.#ident.draft.take()
221 }
222 }
223 }
224 });
225
226 let draft_change_checkers = fields.iter().map(|field| {
227 let TrackedField { ident, .. } = field;
228 let dirty_checker = create_dirty_check_ident(ident);
229 quote!(self.#dirty_checker())
230 });
231
232 let draft_resetters = fields.iter().map(|field| {
233 let TrackedField { ident, .. } = field;
234 let resetter = create_resetter_ident(ident);
235 quote!(self.#resetter())
236 });
237
238 let field_commits = fields.iter().map(|field| {
239 let TrackedField { index, ident, flattened_ident, .. } = field;
240
241 if flattened_ident.is_some() {
242 quote! {
243 {
244 let new_field_idx = field_idx.push_field(#index);
245 changes.push(::cset::Change {
246 field_id: new_field_idx.clone(),
247 value: ::cset::ChangeValue::ChangeSet(self.#ident.apply_impl(new_field_idx)),
248 });
249 }
250 }
251 } else {
252 quote! {
253 if let Some(change) = self.#ident.apply(field_idx.push_field(#index)) {
254 changes.push(change);
255 }
256 }
257 }
258 });
259
260 quote! {
261 pub struct #draft_ident<'b> {
262 #(#draft_fields,)*
263 }
264
265 impl<'b> #draft_ident<'b> {
266 #(#field_api_fns)*
267
268 pub fn is_dirty(&self) -> bool {
271 #(#draft_change_checkers)||*
272 }
273
274 pub fn reset(&mut self) {
276 #(#draft_resetters;)*
277 }
278
279 pub fn apply(self) -> ::cset::ChangeSet {
280 self.apply_impl(::cset::FieldId::default())
281 }
282
283 fn apply_impl(self, field_idx: ::cset::FieldId) -> ::cset::ChangeSet {
284 let mut changes = Vec::new();
285
286 #(#field_commits)*
287
288 ::cset::ChangeSet::new::<#struct_ident>(changes)
289 }
290 }
291 }
292}
293
/// Name of the draft type generated for the struct `ident`,
/// e.g. `Foo` -> `FooDraft`.
fn create_draft_ident(ident: &Ident) -> Ident {
    format_ident!("{ident}Draft")
}
297
/// Name of the generated per-field dirty check on a draft,
/// e.g. field `foo` -> `is_foo_dirty`.
fn create_dirty_check_ident(ident: &Ident) -> Ident {
    format_ident!("is_{ident}_dirty")
}
301
/// Name of the generated per-field reset method on a draft,
/// e.g. field `foo` -> `reset_foo`.
fn create_resetter_ident(ident: &Ident) -> Ident {
    format_ident!("reset_{ident}")
}
305
306fn get_meta_items(attr: &Attribute) -> syn::Result<Vec<NestedMeta>> {
307 if attr.path.is_ident("track") {
308 match attr.parse_meta()? {
309 Meta::List(meta) => Ok(Vec::from_iter(meta.nested)),
310 bad => Err(Error::new_spanned(bad, "unrecognized attribute")),
311 }
312 } else {
313 Ok(Vec::new())
314 }
315}
316
317fn flattened_struct_ident(ty: &Type) -> Ident {
318 match ty {
319 Type::Path(path) => {
320 path.path.get_ident().unwrap().clone()
321 },
322 _ => todo!(),
323 }
324}