1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenTree};
5use proc_macro_error::abort_call_site;
6use quote::{format_ident, quote};
7use std::{iter::FromIterator, str::FromStr};
8use syn::{
9 parse_macro_input, punctuated::Punctuated, Attribute, Data, DeriveInput, Fields, Ident, Lit,
10 Meta, MetaList, MetaNameValue, NestedMeta, Path, PredicateType, Token, TraitBound,
11 TraitBoundModifier, Type, TypeParamBound, WherePredicate,
12};
13
/// How one struct field is represented in the generated delta, and how changes to it are
/// computed and applied.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldType {
    /// Ordered collection (`"ordered"`); the code generators do not support it yet.
    Ordered,
    /// Collection diffed as a multiset: the delta lists added and removed items.
    Unordered,
    /// Value replaced wholesale when it differs; the delta holds `Option<T>`.
    Scalar,
    /// Field whose own type implements `Delta`; the delta holds the nested delta.
    Delta,
}
21
/// Human-readable list of the strings accepted for `field_type` / `default`, used in error
/// messages. Kept in sync with the names recognised by `string_to_fieldtype`.
const VALID_FIELD_TYPES: &str = "\"ordered\", \"unordered\", \"scalar\", or \"delta\"";
23
24#[proc_macro_derive(Delta, attributes(delta_struct))]
25pub fn derive_delta(input: TokenStream) -> TokenStream {
26 let DeriveInput {
27 attrs,
28 vis,
29 ident,
30 mut generics,
31 data,
32 } = parse_macro_input!(input as DeriveInput);
33 let (default_field_type, delta_leader) = match get_fieldtype_from_attrs(attrs.into_iter(), "default") {
34 Ok((v, delta_leader)) => (v.unwrap_or(FieldType::Scalar), delta_leader),
35 Err(_) => {
36 abort_call_site!(
37 "delta_struct(default = ...) for {} is not an accepted value, expected {}.",
38 ident,
39 VALID_FIELD_TYPES
40 );
41 }
42 };
43
44 let (named, fields) = match data {
45 Data::Struct(strukt) => match strukt.fields {
46 Fields::Named(named) => (
47 true,
48 collect_results(
49 named.named.into_iter().map(|field| {
50 (
51 field.ident.unwrap().to_string(),
52 field.ty,
53 get_fieldtype_from_attrs(field.attrs.into_iter(), "field_type"),
54 )
55 }),
56 default_field_type,
57 ),
58 ),
59 Fields::Unnamed(unnamed) => (
60 false,
61 collect_results(
62 unnamed.unnamed.into_iter().enumerate().map(|(i, field)| {
63 (
64 i.to_string(),
65 field.ty,
66 get_fieldtype_from_attrs(field.attrs.into_iter(), "field_type"),
67 )
68 }),
69 default_field_type,
70 ),
71 ),
72 Fields::Unit => {
73 (false, Ok(vec![]))
74 }
75 },
76 _ => {
77 abort_call_site!(
78 "delta_struct::Delta may only be derived for struct types currently. {} is not a struct type."
79 , ident)
80 }
81 };
82 let fields = match fields {
83 Ok(fields) => fields,
84 Err(bad_fields) => {
85 let bad_fields = format!("{:?}", bad_fields);
86 abort_call_site!(
87 "delta_struct(field_type = ...) for fields in {}: {} are not valid values. Expected {}.",
88 ident,
89 bad_fields,
90 VALID_FIELD_TYPES
91 )
92 }
93 };
94 let delta_leader = proc_macro2::TokenStream::from_str(&delta_leader).unwrap();
95 let delta_ident = format_ident!("{}Delta", ident);
96 let delta_fields = delta_fields(named, fields.iter().cloned());
97 let delta_struct = quote! {
98 #delta_leader
99 #vis struct #delta_ident #generics {
100 #delta_fields
101 }
102 };
103 let (delta_compute_let, delta_compute_fields) =
104 delta_compute_fields(named, fields.iter().cloned());
105 let (delta_apply_let, delta_apply_actions) = delta_apply_fields(named, fields.into_iter());
106 let partial_eq_types = generics
107 .type_params()
108 .map(|t| t.ident.clone())
109 .collect::<Vec<_>>();
110 let where_clause = generics.make_where_clause();
111 for ty in partial_eq_types {
112 let mut bounds = Punctuated::new();
113 let mut segments = Punctuated::new();
114 segments.push(Ident::new("std", Span::call_site()).into());
115 segments.push(Ident::new("cmp", Span::call_site()).into());
116 segments.push(Ident::new("PartialEq", Span::call_site()).into());
117 bounds.push(TypeParamBound::Trait(TraitBound {
118 paren_token: None,
119 modifier: TraitBoundModifier::None,
120 lifetimes: None,
121 path: Path {
122 leading_colon: Some(Token!(::)(Span::call_site())),
123 segments,
124 },
125 }));
126 where_clause
127 .predicates
128 .push(WherePredicate::Type(PredicateType {
129 lifetimes: None,
130 bounded_ty: Type::Verbatim(<Ident as Into<TokenTree>>::into(ty).into()),
131 colon_token: Token!(:)(Span::call_site()),
132 bounds,
133 }));
134 }
135 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
136 let delta_impl = quote! {
137 impl #impl_generics Delta for #ident #ty_generics #where_clause {
138 type Output = #delta_ident #generics;
139
140 fn delta(old: Self, new: Self) -> Option<Self::Output> {
141 let mut delta_is_some = false;
142 #delta_compute_let
143 if delta_is_some {
144 Some(Self::Output {
145 #delta_compute_fields
146 })
147 } else {
148 None
149 }
150 }
151
152 fn apply_delta(&mut self, delta: Self::Output) {
153 let Self::Output {
154 #delta_apply_let
155 } = delta;
156 #delta_apply_actions
157 }
158 }
159 };
160 let output = quote! {
161 #delta_struct
162
163 #delta_impl
164 };
165 TokenStream::from(output)
166}
167
168fn delta_fields(
169 named: bool,
170 iter: impl Iterator<Item = (String, Type, FieldType, String)>,
171) -> proc_macro2::TokenStream {
172 FromIterator::from_iter(iter.map(|(ident, ty, field_ty, field_leader)| {
173 let field_leader = proc_macro2::TokenStream::from_str(&field_leader).unwrap();
174 let ident = if named {
175 format_ident!("{}", ident)
176 } else {
177 format_ident!("field_{}", ident)
178 };
179 match field_ty {
180 FieldType::Ordered => unimplemented!(),
181 FieldType::Unordered => {
182 let add = format_ident!("{}_add", ident);
183 let remove = format_ident!("{}_remove", ident);
184 quote! {
185 #field_leader
186 pub #add: Vec<<#ty as ::std::iter::IntoIterator>::Item>,
187 #field_leader
188 pub #remove: Vec<<#ty as ::std::iter::IntoIterator>::Item>,
189 }
190 }
191 FieldType::Scalar => {
192 quote! {
193 #field_leader
194 pub #ident: ::std::option::Option<#ty>,
195 }
196 }
197 FieldType::Delta => {
198 quote! {
199 #field_leader
200 pub #ident: ::std::option::Option<<#ty as Delta>::Output>,
201 }
202 }
203 }
204 }))
205}
206
207fn delta_compute_fields(
208 named: bool,
209 iter: impl Iterator<Item = (String, Type, FieldType, String)>,
210) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
211 iter.map(|(og_ident, _ty, field_ty, _field_leader)| {
212 let ident = if named {
213 format_ident!("{}", og_ident)
214 } else {
215 format_ident!("field_{}", og_ident)
216 };
217 let og_ident: proc_macro2::TokenStream = FromStr::from_str(&og_ident).unwrap();
218 match field_ty {
219 FieldType::Ordered => unimplemented!(),
220 FieldType::Unordered => {
221 let add = format_ident!("{}_add", ident);
222 let remove = format_ident!("{}_remove", ident);
223
224 (
225 quote! {
226 let mut #add = new.#og_ident.into_iter().collect::<::std::vec::Vec<_>>();
227 let #remove = old.#og_ident.into_iter().filter_map(|i| {
228 if let Some(index) = #add.iter().position(|a| a == &i) {
229 #add.remove(index);
230 None
231 } else {
232 Some(i)
233 }
234 }).collect::<::std::vec::Vec<_>>();
235 delta_is_some = delta_is_some || !#add.is_empty() || !#remove.is_empty();
236 },
237 quote! {
238 #add,
239 #remove,
240 },
241 )
242 }
243 FieldType::Scalar => (
244 quote! {
245 let #ident = if old.#og_ident != new.#og_ident {
246 delta_is_some = true;
247 Some(new.#og_ident)
248 } else {
249 None
250 };
251 },
252 quote! {
253 #ident,
254 },
255 ),
256 FieldType::Delta => (
257 quote! {
258 let #ident = Delta::delta(old.#og_ident, new.#og_ident);
259 delta_is_some = delta_is_some || #ident.is_some();
260
261 },
262 quote! {
263 #ident,
264 },
265 ),
266 }
267 })
268 .unzip()
269}
270
271fn delta_apply_fields(
272 named: bool,
273 iter: impl Iterator<Item = (String, Type, FieldType, String)>,
274) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
275 iter.map(|(og_ident, ty, field_ty, _field_leader)| {
276 let ident = if named {
277 format_ident!("{}", og_ident)
278 } else {
279 format_ident!("field_{}", og_ident)
280 };
281 let og_ident: proc_macro2::TokenStream = FromStr::from_str(&og_ident).unwrap();
282 match field_ty {
283 FieldType::Ordered => unimplemented!(),
284 FieldType::Unordered => {
285 let add = format_ident!("{}_add", ident);
286 let remove = format_ident!("{}_remove", ident);
287 (
288 quote! {
289 #add,
290 mut #remove,
291 },
292 quote! {
293 {
294 let og = ::std::mem::replace(&mut self.#og_ident, ::std::iter::FromIterator::from_iter(vec![]));
295 let mut #ident: #ty = ::std::iter::FromIterator::from_iter(og.into_iter().filter_map(|i| {
296 if let Some(index) = #remove.iter().position(|a| a == &i) {
297 #remove.remove(index);
298 None
299 } else {
300 Some(i)
301 }
302 }));
303 #ident.extend(#add.into_iter());
304 self.#og_ident = #ident;
305 }
306 }
307 )
308 }
309 FieldType::Scalar =>
310 (
311 quote! {
312 #ident,
313 },
314 quote! {
315 if let Some(v) = #ident {
316 self.#og_ident = v;
317 }
318 }
319 ),
320 FieldType::Delta =>
321 (
322 quote! {
323 #ident,
324 },
325 quote!{
326 if let Some(v) = #ident {
327 self.#og_ident.apply_delta(v);
328 }
329 }
330 ),
331 }
332 }).unzip()
333}
334
335fn collect_results(
336 iter: impl Iterator<Item = (String, Type, Result<(Option<FieldType>, String), FieldTypeError>)>,
337 default_field_type: FieldType,
338) -> Result<Vec<(String, Type, FieldType, String)>, Vec<String>> {
339 iter.fold(Ok(vec![]), |v, i| match (v, i) {
340 (Ok(mut v), (ident, b, Ok((c, d)))) => {
341 v.push((ident, b, c.unwrap_or(default_field_type), d));
342 Ok(v)
343 }
344 (Ok(_), (ident, _, Err(_))) => Err(vec![ident]),
345 (Err(mut v), (ident, _, Err(_))) => {
346 v.push(ident);
347 Err(v)
348 }
349 (v @ Err(_), _) => v,
350 })
351}
352
353enum FieldTypeError {
354 UnrecognizedJunkFound(Vec<NestedMeta>),
355}
356
357fn get_fieldtype_from_attrs(
358 iter: impl Iterator<Item = Attribute>,
359 attr_name: &str,
360) -> Result<(Option<FieldType>, String), FieldTypeError> {
361 for attr in iter {
362 if let Ok(Meta::List(MetaList { path, nested, .. })) = attr.parse_meta() {
363 let Path { segments, .. } = path;
364 if segments
365 .iter()
366 .map(|p| &p.ident)
367 .eq(["delta_struct"].iter().cloned())
368 {
369 let values: Result<Vec<_>, Vec<NestedMeta>> = nested
370 .iter()
371 .map(|nested_meta| match nested_meta {
372 NestedMeta::Meta(Meta::NameValue(MetaNameValue {
373 path,
374 lit: Lit::Str(s),
375 ..
376 })) => Ok((path.get_ident().map(|i| i.to_string()), s.value())),
377 e @ _ => Err(e),
378 })
379 .fold(Ok(vec![]), |v, i| match (v, i) {
380 (Ok(mut v), Ok(i)) => {
381 v.push(i);
382 Ok(v)
383 }
384 (Ok(_), Err(e)) => Err(vec![e.clone()]),
385 (Err(mut v), Err(e)) => {
386 v.push(e.clone());
387 Err(v)
388 }
389 (v @ Err(_), _) => v,
390 });
391 return match values {
392 Ok(v) => {
393 let mut field_type = None;
394 let mut delta_leader = String::new();
395 for i in v {
396 match i.0.as_deref() {
397 Some("delta_leader") => {
398 delta_leader = i.1;
399 },
400 a @ _ if Some(attr_name) == a => {
401 field_type = string_to_fieldtype(&i.1);
402 },
403 a @ _ => {
404 abort_call_site!("Unrecognized value {:?}", a);
405 }
406 }
407 }
408 Ok((field_type, delta_leader))
409 }
410 Err(v) => Err(FieldTypeError::UnrecognizedJunkFound(v)),
411 };
412 }
413 }
414 }
415 Ok((None, String::new()))
416}
417
418fn string_to_fieldtype(s: &str) -> Option<FieldType> {
419 match s {
420 "ordered" => Some(FieldType::Ordered),
421 "unordered" => Some(FieldType::Unordered),
422 "scalar" => Some(FieldType::Scalar),
423 "delta" => Some(FieldType::Delta),
424 _ => None,
425 }
426}