delta_struct_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenTree};
5use proc_macro_error::abort_call_site;
6use quote::{format_ident, quote};
7use std::{iter::FromIterator, str::FromStr};
8use syn::{
9    parse_macro_input, punctuated::Punctuated, Attribute, Data, DeriveInput, Fields, Ident, Lit,
10    Meta, MetaList, MetaNameValue, NestedMeta, Path, PredicateType, Token, TraitBound,
11    TraitBoundModifier, Type, TypeParamBound, WherePredicate,
12};
13
/// How a single struct field is diffed and re-applied by the `Delta` derive.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldType {
    /// Order-sensitive collection diff. Not implemented yet; selecting it makes the derive fail.
    Ordered,
    /// Collection diffed as a multiset: the delta stores `<field>_add` / `<field>_remove`
    /// vectors of the collection's items.
    Unordered,
    /// Value replaced wholesale: the delta stores `Option<T>`, `Some` only when it changed.
    Scalar,
    /// Field that itself implements `Delta`: the delta stores `Option<<T as Delta>::Output>`.
    Delta,
}
21
/// Human-readable list of the values accepted for `default = ...` / `field_type = ...`,
/// used in error messages. Must stay in sync with `string_to_fieldtype`, which also accepts
/// `"delta"`.
const VALID_FIELD_TYPES: &str = "\"ordered\", \"unordered\", \"scalar\", or \"delta\"";
23
24#[proc_macro_derive(Delta, attributes(delta_struct))]
25pub fn derive_delta(input: TokenStream) -> TokenStream {
26    let DeriveInput {
27        attrs,
28        vis,
29        ident,
30        mut generics,
31        data,
32    } = parse_macro_input!(input as DeriveInput);
33    let (default_field_type, delta_leader) = match get_fieldtype_from_attrs(attrs.into_iter(), "default") {
34        Ok((v, delta_leader)) => (v.unwrap_or(FieldType::Scalar), delta_leader),
35        Err(_) => {
36            abort_call_site!(
37                "delta_struct(default = ...) for {} is not an accepted value, expected {}.",
38                ident,
39                VALID_FIELD_TYPES
40            );
41        }
42    };
43
44    let (named, fields) = match data {
45        Data::Struct(strukt) => match strukt.fields {
46            Fields::Named(named) => (
47                true,
48                collect_results(
49                    named.named.into_iter().map(|field| {
50                        (
51                            field.ident.unwrap().to_string(),
52                            field.ty,
53                            get_fieldtype_from_attrs(field.attrs.into_iter(), "field_type"),
54                        )
55                    }),
56                    default_field_type,
57                ),
58            ),
59            Fields::Unnamed(unnamed) => (
60                false,
61                collect_results(
62                    unnamed.unnamed.into_iter().enumerate().map(|(i, field)| {
63                        (
64                            i.to_string(),
65                            field.ty,
66                            get_fieldtype_from_attrs(field.attrs.into_iter(), "field_type"),
67                        )
68                    }),
69                    default_field_type,
70                ),
71            ),
72            Fields::Unit => {
73                (false, Ok(vec![]))
74            }
75        },
76        _ => {
77            abort_call_site!(
78                "delta_struct::Delta may only be derived for struct types currently. {} is not a struct type."
79            , ident)
80        }
81    };
82    let fields = match fields {
83        Ok(fields) => fields,
84        Err(bad_fields) => {
85            let bad_fields = format!("{:?}", bad_fields);
86            abort_call_site!(
87                "delta_struct(field_type = ...) for fields in {}: {} are not valid values. Expected {}.",
88                ident,
89                bad_fields,
90                VALID_FIELD_TYPES
91            )
92        }
93    };
94    let delta_leader = proc_macro2::TokenStream::from_str(&delta_leader).unwrap();
95    let delta_ident = format_ident!("{}Delta", ident);
96    let delta_fields = delta_fields(named, fields.iter().cloned());
97    let delta_struct = quote! {
98      #delta_leader
99      #vis struct #delta_ident #generics {
100          #delta_fields
101      }
102    };
103    let (delta_compute_let, delta_compute_fields) =
104        delta_compute_fields(named, fields.iter().cloned());
105    let (delta_apply_let, delta_apply_actions) = delta_apply_fields(named, fields.into_iter());
106    let partial_eq_types = generics
107        .type_params()
108        .map(|t| t.ident.clone())
109        .collect::<Vec<_>>();
110    let where_clause = generics.make_where_clause();
111    for ty in partial_eq_types {
112        let mut bounds = Punctuated::new();
113        let mut segments = Punctuated::new();
114        segments.push(Ident::new("std", Span::call_site()).into());
115        segments.push(Ident::new("cmp", Span::call_site()).into());
116        segments.push(Ident::new("PartialEq", Span::call_site()).into());
117        bounds.push(TypeParamBound::Trait(TraitBound {
118            paren_token: None,
119            modifier: TraitBoundModifier::None,
120            lifetimes: None,
121            path: Path {
122                leading_colon: Some(Token!(::)(Span::call_site())),
123                segments,
124            },
125        }));
126        where_clause
127            .predicates
128            .push(WherePredicate::Type(PredicateType {
129                lifetimes: None,
130                bounded_ty: Type::Verbatim(<Ident as Into<TokenTree>>::into(ty).into()),
131                colon_token: Token!(:)(Span::call_site()),
132                bounds,
133            }));
134    }
135    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
136    let delta_impl = quote! {
137      impl #impl_generics Delta for #ident #ty_generics #where_clause  {
138          type Output = #delta_ident #generics;
139
140          fn delta(old: Self, new: Self) -> Option<Self::Output> {
141           let mut delta_is_some = false;
142           #delta_compute_let
143           if delta_is_some {
144               Some(Self::Output {
145                #delta_compute_fields
146               })
147           } else {
148               None
149           }
150          }
151
152          fn apply_delta(&mut self, delta: Self::Output) {
153            let Self::Output {
154                #delta_apply_let
155            } = delta;
156            #delta_apply_actions
157          }
158      }
159    };
160    let output = quote! {
161        #delta_struct
162
163        #delta_impl
164    };
165    TokenStream::from(output)
166}
167
168fn delta_fields(
169    named: bool,
170    iter: impl Iterator<Item = (String, Type, FieldType, String)>,
171) -> proc_macro2::TokenStream {
172    FromIterator::from_iter(iter.map(|(ident, ty, field_ty, field_leader)| {
173        let field_leader = proc_macro2::TokenStream::from_str(&field_leader).unwrap();
174        let ident = if named {
175            format_ident!("{}", ident)
176        } else {
177            format_ident!("field_{}", ident)
178        };
179        match field_ty {
180            FieldType::Ordered => unimplemented!(),
181            FieldType::Unordered => {
182                let add = format_ident!("{}_add", ident);
183                let remove = format_ident!("{}_remove", ident);
184                quote! {
185                 #field_leader
186                 pub #add: Vec<<#ty as ::std::iter::IntoIterator>::Item>,
187                 #field_leader
188                 pub #remove: Vec<<#ty as ::std::iter::IntoIterator>::Item>,
189                }
190            }
191            FieldType::Scalar => {
192                quote! {
193                  #field_leader
194                  pub #ident: ::std::option::Option<#ty>,
195                }
196            }
197            FieldType::Delta => {
198                quote! {
199                    #field_leader
200                    pub #ident: ::std::option::Option<<#ty as Delta>::Output>,
201                }
202            }
203        }
204    }))
205}
206
207fn delta_compute_fields(
208    named: bool,
209    iter: impl Iterator<Item = (String, Type, FieldType, String)>,
210) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
211    iter.map(|(og_ident, _ty, field_ty, _field_leader)| {
212        let ident = if named {
213            format_ident!("{}", og_ident)
214        } else {
215            format_ident!("field_{}", og_ident)
216        };
217        let og_ident: proc_macro2::TokenStream = FromStr::from_str(&og_ident).unwrap();
218        match field_ty {
219            FieldType::Ordered => unimplemented!(),
220            FieldType::Unordered => {
221                let add = format_ident!("{}_add", ident);
222                let remove = format_ident!("{}_remove", ident);
223
224                (
225                    quote! {
226                        let mut #add = new.#og_ident.into_iter().collect::<::std::vec::Vec<_>>();
227                        let #remove = old.#og_ident.into_iter().filter_map(|i| {
228                            if let Some(index) = #add.iter().position(|a| a == &i) {
229                                #add.remove(index);
230                                None
231                            } else {
232                                Some(i)
233                            }
234                        }).collect::<::std::vec::Vec<_>>();
235                        delta_is_some = delta_is_some || !#add.is_empty() || !#remove.is_empty();
236                    },
237                    quote! {
238                        #add,
239                        #remove,
240                    },
241                )
242            }
243            FieldType::Scalar => (
244                quote! {
245                   let #ident = if old.#og_ident != new.#og_ident {
246                       delta_is_some = true;
247                       Some(new.#og_ident)
248                   } else {
249                       None
250                   };
251                },
252                quote! {
253                    #ident,
254                },
255            ),
256            FieldType::Delta => (
257                quote! {
258                    let #ident = Delta::delta(old.#og_ident, new.#og_ident);
259                    delta_is_some = delta_is_some || #ident.is_some();
260
261                },
262                quote! {
263                    #ident,
264                },
265            ),
266        }
267    })
268    .unzip()
269}
270
271fn delta_apply_fields(
272    named: bool,
273    iter: impl Iterator<Item = (String, Type, FieldType, String)>,
274) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
275    iter.map(|(og_ident, ty, field_ty, _field_leader)| {
276        let ident = if named {
277            format_ident!("{}", og_ident)
278        } else {
279            format_ident!("field_{}", og_ident)
280        };
281        let og_ident: proc_macro2::TokenStream = FromStr::from_str(&og_ident).unwrap();
282        match field_ty {
283            FieldType::Ordered => unimplemented!(),
284            FieldType::Unordered => {
285                let add = format_ident!("{}_add", ident);
286                let remove = format_ident!("{}_remove", ident);
287                (
288                    quote! {
289                        #add,
290                        mut #remove,
291                    },
292                    quote! {
293                        {
294                            let og = ::std::mem::replace(&mut self.#og_ident, ::std::iter::FromIterator::from_iter(vec![]));
295                            let mut #ident: #ty = ::std::iter::FromIterator::from_iter(og.into_iter().filter_map(|i| {
296                               if let Some(index) = #remove.iter().position(|a| a == &i) {
297                                 #remove.remove(index);
298                                 None
299                               } else {
300                                 Some(i)
301                               }
302                            }));
303                            #ident.extend(#add.into_iter());
304                            self.#og_ident = #ident;
305                        }
306                    }
307                )
308            }
309            FieldType::Scalar => 
310            (
311                quote! {
312                    #ident,
313                },
314                quote! {
315                   if let Some(v) = #ident {
316                       self.#og_ident = v; 
317                   }
318                }
319            ),
320            FieldType::Delta => 
321            (
322                quote! {
323                    #ident,
324                },
325                quote!{
326                   if let Some(v) = #ident {
327                       self.#og_ident.apply_delta(v); 
328                   }
329                }
330            ),
331        }
332    }).unzip()
333}
334
335fn collect_results(
336    iter: impl Iterator<Item = (String, Type, Result<(Option<FieldType>, String), FieldTypeError>)>,
337    default_field_type: FieldType,
338) -> Result<Vec<(String, Type, FieldType, String)>, Vec<String>> {
339    iter.fold(Ok(vec![]), |v, i| match (v, i) {
340        (Ok(mut v), (ident, b, Ok((c, d)))) => {
341            v.push((ident, b, c.unwrap_or(default_field_type), d));
342            Ok(v)
343        }
344        (Ok(_), (ident, _, Err(_))) => Err(vec![ident]),
345        (Err(mut v), (ident, _, Err(_))) => {
346            v.push(ident);
347            Err(v)
348        }
349        (v @ Err(_), _) => v,
350    })
351}
352
353enum FieldTypeError {
354    UnrecognizedJunkFound(Vec<NestedMeta>),
355}
356
357fn get_fieldtype_from_attrs(
358    iter: impl Iterator<Item = Attribute>,
359    attr_name: &str,
360) -> Result<(Option<FieldType>, String), FieldTypeError> {
361    for attr in iter {
362        if let Ok(Meta::List(MetaList { path, nested, .. })) = attr.parse_meta() {
363            let Path { segments, .. } = path;
364            if segments
365                .iter()
366                .map(|p| &p.ident)
367                .eq(["delta_struct"].iter().cloned())
368            {
369                let values: Result<Vec<_>, Vec<NestedMeta>> = nested
370                    .iter()
371                    .map(|nested_meta| match nested_meta {
372                        NestedMeta::Meta(Meta::NameValue(MetaNameValue {
373                            path,
374                            lit: Lit::Str(s),
375                            ..
376                        })) => Ok((path.get_ident().map(|i| i.to_string()), s.value())),
377                        e @ _ => Err(e),
378                    })
379                    .fold(Ok(vec![]), |v, i| match (v, i) {
380                        (Ok(mut v), Ok(i)) => {
381                            v.push(i);
382                            Ok(v)
383                        }
384                        (Ok(_), Err(e)) => Err(vec![e.clone()]),
385                        (Err(mut v), Err(e)) => {
386                            v.push(e.clone());
387                            Err(v)
388                        }
389                        (v @ Err(_), _) => v,
390                    });
391                return match values {
392                    Ok(v) => {
393                        let mut field_type = None;
394                        let mut delta_leader = String::new();
395                        for i in v {
396                            match i.0.as_deref() {
397                                Some("delta_leader") => {
398                                    delta_leader = i.1;
399                                },
400                                a @ _ if Some(attr_name) == a => {
401                                   field_type = string_to_fieldtype(&i.1); 
402                                },
403                                a @ _ => {
404                                    abort_call_site!("Unrecognized value {:?}", a);
405                                }
406                            }
407                        }
408                        Ok((field_type, delta_leader))
409                    }
410                    Err(v) => Err(FieldTypeError::UnrecognizedJunkFound(v)),
411                };
412            }
413        }
414    }
415    Ok((None, String::new()))
416}
417
418fn string_to_fieldtype(s: &str) -> Option<FieldType> {
419    match s {
420        "ordered" => Some(FieldType::Ordered),
421        "unordered" => Some(FieldType::Unordered),
422        "scalar" => Some(FieldType::Scalar),
423        "delta" => Some(FieldType::Delta),
424        _ => None,
425    }
426}