extendable_data/
lib.rs

1
2use std::collections::{HashMap, HashSet};
3use std::{mem};
4use std::ops::{IndexMut};
5pub use extendable_data_helpers::extendable_data;
6use syn::meta::ParseNestedMeta;
7use syn::spanned::Spanned;
8use syn::token::{Comma, Lt, Gt, Where, Brace, Paren};
9use syn::{self, DeriveInput, Generics, GenericParam, WhereClause, WherePredicate, Variant, Attribute, Path, Fields, FieldsNamed, FieldsUnnamed};
10use syn::parse::{Result};
11use syn::punctuated::{Punctuated};
12use syn::Fields::*;
13use proc_macro2::{TokenStream, Ident, Span};
14use quote::{quote, ToTokens};
15
16
/// Picks the second value when it is present, otherwise falls back to the first.
///
/// Lets a later definition override an earlier one while keeping the earlier value when the
/// later definition leaves it unset (used for enum discriminants when merging variants).
fn overwrite_optionals<T>(option_a: Option<T>, option_b: Option<T>) -> Option<T> {
    option_b.or(option_a)
}
23
24fn path_into_string(path: Path) -> String {
25    path.into_token_stream().to_string()
26}
27
28fn path_to_string(path: &Path) -> String {
29    path.to_token_stream().to_string()
30}
31
32fn meta_to_string(meta: &syn::Meta) -> String {
33    match meta {
34        syn::Meta::Path(p) => path_to_string(p),
35        syn::Meta::List(m) => path_to_string(&m.path),
36        syn::Meta::NameValue(m) => path_to_string(&m.path)
37    }
38}
39
40trait Len {
41    fn len(&self) -> usize;
42}
43
44impl<T> Len for Vec<T> {
45    fn len(&self) -> usize {
46        self.len()
47    }
48}
49
50impl<T, P> Len for Punctuated<T, P> {
51    fn len(&self) -> usize {
52        self.len()
53    }
54}
55
56fn combine_iters<T, U, F1, F2>(iter_a: T, iter_b: T, to_str: F1, handle_conflict: F2, args: &Args) -> Result<T>
57where
58    T: IntoIterator<Item = U> + Default + Extend<U> + IndexMut<usize, Output = U> + Len,
59    F1: Fn(&U) -> String,
60    F2: Fn(&mut T, usize, U, &Args) -> Result<()>
61{
62    let mut seen: HashMap<String, usize> = HashMap::with_capacity(iter_a.len()); // Will always need exactly this much
63    let mut combined: T = <T as std::default::Default>::default();
64    let mut i = 0;
65    for a in iter_a.into_iter() {
66        let repr = to_str(&a);
67        if !args.filter.contains(&repr) {
68            seen.insert(repr, i);
69            combined.extend([a]);
70            i += 1;
71        }
72    }
73    for b in iter_b.into_iter() {
74        if let Some(i) = seen.remove(&to_str(&b)) {
75            handle_conflict(&mut combined, i, b, args)?;
76        } else {
77            combined.extend([b]);
78        }
79    }
80    Ok(combined)
81}
82
83fn handle_conflicts_basic<T: IntoIterator<Item = U> + Default + Extend<U> + IndexMut<usize, Output = U> + Len, U>(list: &mut T, index: usize, data: U, _args: &Args) -> Result<()> {
84    list[index] = data;
85    Ok(())
86}
87
88fn combine_attrs(attr_a: Vec<Attribute>, attr_b: Vec<Attribute>, args: &Args) -> Result<Vec<Attribute>> {
89    combine_iters(attr_a, attr_b, |x| meta_to_string(&x.meta), handle_conflicts_basic, args)
90}
91
92fn combine_wheres(where_a: WhereClause, where_b: WhereClause) -> WhereClause {
93    let pred_a = where_a.predicates.into_iter();
94    let pred_b = where_b.predicates.into_iter();
95    let combined: Punctuated<WherePredicate, Comma> = Punctuated::from_iter(pred_a.chain(pred_b));
96    WhereClause { 
97        where_token: Where::default(), 
98        predicates: combined,
99    }
100}
101
102fn combine_generics(input_a: Generics, input_b: Generics) -> Generics {
103    let params_a = input_a.params.into_iter();
104    let params_b = input_b.params.into_iter();
105    let combined = params_a.chain(params_b);
106    let params_c: Punctuated<GenericParam, Comma> = Punctuated::from_iter(combined);
107    let where_c: Option<WhereClause> = match (input_a.where_clause, input_b.where_clause) {
108        (None, None) => None,
109        (None, Some(where_b)) => Some(where_b),
110        (Some(where_a), None) => Some(where_a),
111        (Some(where_a), Some(where_b)) => Some(combine_wheres(where_a, where_b)),
112    };
113    Generics { 
114        lt_token: Some(Lt::default()), 
115        params: params_c, 
116        gt_token: Some(Gt::default()), 
117        where_clause: where_c
118    }
119}
120
121fn combine_fields_named(fields_a: FieldsNamed, fields_b: FieldsNamed, args: &Args) -> Result<FieldsNamed> {
122    let named = combine_iters(fields_a.named, fields_b.named, |x| x.ident.as_ref().unwrap().to_string(), handle_conflicts_basic, args)?;
123    Ok(FieldsNamed {
124        brace_token: Brace::default(),
125        named
126    })
127}
128
129fn combine_fields(fields_a: Fields, fields_b: Fields, args: &Args, merging: bool) -> Result<Fields> {
130    let b_span = fields_b.span();
131    match (fields_a, fields_b) {
132        (_, f) if !merging => Ok(f),
133        (Named(fields_a), Named(fields_b)) => {
134            let resp = combine_fields_named(fields_a, fields_b, args)?;
135            Ok(Named(resp))
136        },
137        (Unnamed(fields_a), Unnamed(fields_b)) => {
138            let unnamed = combine_iters(fields_a.unnamed, fields_b.unnamed, |x| x.ty.to_token_stream().to_string(), handle_conflicts_basic, args)?;
139            Ok(Unnamed(FieldsUnnamed {
140                paren_token: Paren::default(),
141                unnamed
142            }))
143        },
144        (Unit, f) | (f, Unit) => Ok(f),
145        _ => Err(syn::Error::new(b_span, "Can not combine provided structs. Either make sure they are the same type, or filter out the offending struct."))
146    }
147}
148
149fn combine_enum_variants(variants_a: Punctuated<Variant, Comma>, variants_b: Punctuated<Variant, Comma>, args: Args) -> Result<Punctuated<Variant, Comma>> {
150    fn handle_merge_conflict(combined: &mut Punctuated<Variant, Comma>, i: usize, b: Variant, args: &Args) -> Result<()> {
151        let a = mem::replace(&mut combined[i], b);
152        combined[i].attrs = combine_attrs(a.attrs, mem::take(&mut combined[i].attrs), args)?;
153        combined[i].fields = combine_fields(a.fields, mem::replace(&mut combined[i].fields, Unit), args, args.merge)?;
154        combined[i].discriminant = overwrite_optionals(a.discriminant, mem::take(&mut combined[i].discriminant));
155        Ok(())
156    }
157    let handle_conflicts = if args.merge {
158        handle_merge_conflict
159    } else {
160        handle_conflicts_basic
161    };
162    combine_iters(variants_a, variants_b, |x| x.ident.to_string(), handle_conflicts, &args)
163}
164
165fn combine_enums(enum_a: syn::DataEnum, enum_b: syn::DataEnum, args: Args) -> Result<(TokenStream, &'static str)> {
166    let variants = combine_enum_variants(enum_a.variants, enum_b.variants, args)?;
167    let tokens = quote!({
168        #variants
169    });
170    Ok((tokens, "enum"))
171}
172
173fn combine_structs(struct_a: syn::DataStruct, struct_b: syn::DataStruct, args: Args) -> Result<(TokenStream, &'static str)> {
174    let fields = combine_fields(struct_a.fields, struct_b.fields, &args, true)?;
175    let tokens = match fields {
176        Named(fields) => quote!(#fields),
177        Unnamed(fields) => quote!(#fields;),
178        Unit => quote!(;)
179    };
180    Ok((tokens, "struct"))
181}
182
183fn combine_unions(union_a: syn::DataUnion, union_b: syn::DataUnion, args: Args) -> Result<(TokenStream, &'static str)> {
184    let fields = combine_fields_named(union_a.fields, union_b.fields, &args)?;
185    let tokens = quote!({#fields});
186    Ok((tokens, "union"))
187}
188
189fn construct_stream (
190        data: TokenStream, 
191        data_token: Ident, 
192        visibility: syn::Visibility, 
193        gens: Generics, 
194        name: syn::Ident, 
195        attrs: Vec<syn::Attribute>
196    ) -> TokenStream {
197    quote! {
198        #(#attrs)*
199        #visibility #data_token #name #gens #data
200    }
201}
202
/// Options parsed from the optional argument stream passed to `combine_data`.
#[derive(Default)]
struct Args {
    /// Set by `merge_on_conflict`: enum variants present in both inputs are merged
    /// (attributes, fields, discriminant) instead of the second replacing the first.
    merge: bool,
    /// Keys listed in `filter(...)`; items of the *first* input with one of these keys
    /// are dropped before combining.
    filter: HashSet<String>,
}
208
209impl Args {
210    fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> {
211        if meta.path.is_ident("filter") {
212            meta.parse_nested_meta(|meta| {
213                let ident: String = path_into_string(meta.path);
214                self.filter.insert(ident);
215                Ok(())
216            })
217        } else if meta.path.is_ident("merge_on_conflict") {
218            self.merge = true;
219            Ok(())
220        } else {
221            Err(meta.error("Unsupported Argument"))
222        }
223    }
224}
225
226/// Combines two datas into a single data using TokenStreams.
227///
228/// Additionally, optionally a stream of arguments can provided. At time of writing, the supported arguments are
229/// `merge_on_conflict` and `filter(list)`. 
230///
231/// The resulting TokenStream can then be used by a procedural macro to generate code that represents the new combined data.
232///
233/// See extendable_data_helpers::extendable_data for the macro that generates the code.
234pub fn combine_data(input_a: TokenStream, input_b: TokenStream, args_input: Option<TokenStream>) -> TokenStream {
235    let ast_a = match syn::parse2::<DeriveInput>(input_a) {
236        Ok(a) => a,
237        Err(e) => return e.to_compile_error()
238    };
239    let ast_b = match syn::parse2::<DeriveInput>(input_b) {
240        Ok(b) => b,
241        Err(e) => return e.to_compile_error()
242    };
243    let mut args = Args::default();
244    if let Some(a) = args_input {
245        let arg_parser = syn::meta::parser(|meta| args.parse(meta));
246        if let Err(e) = syn::parse::Parser::parse2(arg_parser, a) {
247            return e.to_compile_error();
248        }
249    }
250    let b_span = ast_b.span();
251    let generics = combine_generics(ast_a.generics, ast_b.generics);
252    let attrs = match combine_attrs(ast_a.attrs, ast_b.attrs, &args) {
253        Ok(attrs) => attrs,
254        Err(e) => return e.to_compile_error()
255    };
256    let resp = match (ast_a.data, ast_b.data) {
257        (syn::Data::Enum(enum_a), syn::Data::Enum(enum_b)) => combine_enums(enum_a, enum_b, args),
258        (syn::Data::Struct(struct_a), syn::Data::Struct(struct_b)) => combine_structs(struct_a, struct_b, args),
259        (syn::Data::Union(union_a), syn::Data::Union(union_b)) => combine_unions(union_a, union_b, args),
260        _ => Err(syn::Error::new(b_span, "Can only combine 2 of the same type of data structure!",))
261    };
262    match resp {
263        Ok((data, data_token)) => {
264            let vis_b = ast_b.vis;
265            construct_stream(data, Ident::new(data_token, Span::call_site()), vis_b, generics, ast_b.ident, attrs)
266        },
267        Err(e) => e.to_compile_error()
268    }
269}
270
#[cfg(test)]
mod tests {
    //! Round-trip tests for `combine_data`: each result is re-parsed as a `DeriveInput` and
    //! compared structurally with the expected definition, so token spacing does not matter.

    use super::combine_data;
    use assert_tokenstreams_eq::assert_tokenstreams_eq;
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::DeriveInput;

    /// Parses both streams as items and asserts they describe the same definition.
    fn assert_streams(result: TokenStream, expected: TokenStream) {
        assert_eq!(
            syn::parse2::<DeriveInput>(result).unwrap(),
            syn::parse2::<DeriveInput>(expected).unwrap()
        );
    }

    /// Asserts that `result` is exactly the `compile_error!` invocation carrying `msg`.
    fn assert_compiler_error(result: TokenStream, msg: &str) {
        let expected = quote!(::core::compile_error! { #msg });
        assert_tokenstreams_eq!(&result, &expected);
    }

    #[test]
    fn test_combine_enums() {
        let enum_a = quote! {
            enum A {
                One(Thing),
                Two {
                    SubOne: i32,
                    SubTwo: Another
                },
                Three
            }
        };
        let enum_b = quote! {
            enum B {
                Four(Thing, Thing, Thing),
                Five,
                Six
            }
        };
        let expected = quote! {
            enum B {
                One(Thing),
                Two {
                    SubOne: i32,
                    SubTwo: Another
                },
                Three,
                Four(Thing, Thing, Thing),
                Five,
                Six
            }
        };
        let result = combine_data(enum_a, enum_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_combine_named_structs() {
        let struct_a = quote! {
            struct A {
                one: i32
            }
        };
        let struct_b = quote! {
            struct B {
                two: SomeType
            }
        };
        let expected = quote! {
            struct B {
                one: i32,
                two: SomeType
            }
        };
        let result = combine_data(struct_a, struct_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_combine_unnamed_structs() {
        let struct_a = quote! {
            struct A(i32, SomeType);
        };
        let struct_b = quote! {
            struct B(i64, Another);
        };
        let expected = quote! {
            struct B(i32, SomeType, i64, Another);
        };
        let result = combine_data(struct_a, struct_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_combine_unit_structs() {
        let struct_a = quote! {
            struct A;
        };
        let struct_b = quote! {
            struct B;
        };
        let struct_c = quote! {
            struct C {
                one: i32
            }
        };
        let struct_d = quote! {
            struct D(i32, i32);
        };
        let expected_unit = quote! {
            struct B;
        };
        let expected_named_one = quote! {
            struct C {
                one: i32
            }
        };
        let expected_named_two = quote! {
            struct A {
                one: i32
            }
        };
        let expected_unnamed_one = quote! {
            struct D(i32, i32);
        };
        let expected_unnamed_two = quote! {
            struct A(i32, i32);
        };

        // The inputs are reused across cases; cloning keeps each call independent.
        let result_unit = combine_data(struct_a.clone(), struct_b, None);
        let result_named_one = combine_data(struct_a.clone(), struct_c.clone(), None);
        let result_named_two = combine_data(struct_c, struct_a.clone(), None);
        let result_unnamed_one = combine_data(struct_a.clone(), struct_d.clone(), None);
        let result_unnamed_two = combine_data(struct_d, struct_a, None);

        assert_streams(result_unit, expected_unit);
        assert_streams(result_named_one, expected_named_one);
        assert_streams(result_named_two, expected_named_two);
        assert_streams(result_unnamed_one, expected_unnamed_one);
        assert_streams(result_unnamed_two, expected_unnamed_two);
    }

    #[test]
    fn test_invalid_combine() {
        let input_a = quote! {
            struct A;
        };
        let input_b = quote! {
            enum B {
                Thing
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_compiler_error(result, "Can only combine 2 of the same type of data structure!");
    }

    #[test]
    fn test_invalid_combine_structs() {
        let input_a = quote! {
            struct A(i32, i32);
        };
        let input_b = quote! {
            struct B {
                one: i32
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_compiler_error(result, "Can not combine provided structs. Either make sure they are the same type, or filter out the offending struct.");
    }

    #[test]
    fn test_invalid_args() {
        let input_a = quote! {
            struct A;
        };
        let input_b = quote! {
            struct B;
        };
        let result = combine_data(input_a, input_b, Some(quote!(fake arg)));
        assert_compiler_error(result, "Unsupported Argument");
    }

    #[test]
    fn test_combine_visibility() {
        let input_a = quote! {
            enum A {
                One
            }
        };
        let input_b = quote! {
            pub enum B {
                Two
            }
        };
        let expected = quote! {
            pub enum B {
                One,
                Two
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_combine_attributes() {
        let input_a = quote! {
            #[some_attr]
            enum A {
                One
            }
        };
        let input_b = quote! {
            #[another(attr)]
            enum B {
                #[on_attr]
                Two
            }
        };
        let expected = quote! {
            #[some_attr]
            #[another(attr)]
            enum B {
                One,
                #[on_attr]
                Two
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);

        let input_a = quote! {
            #[some_attr]
            struct A;
        };
        let input_b = quote! {
            #[another_attr]
            struct B {
                one: i32
            }
        };
        let expected = quote! {
            #[some_attr]
            #[another_attr]
            struct B {
                one: i32
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_combine_generics() {
        let input_a = quote! {
            enum A<'life, T> {
                One(Thing<'life, T>)
            }
        };
        let input_b = quote! {
            enum B<'efil, U> {
                Two(Thing<'efil, U>)
            }
        };
        let expected = quote! {
            enum B<'life, 'efil, T, U> {
                One(Thing<'life, T>),
                Two(Thing<'efil, U>)
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_namespace_conflict_overwrite() {
        let input_a = quote! {
            enum A {
                One,
                #[one_attr]
                Two,
                Three {
                    x: i32
                }
            }
        };
        let input_b = quote! {
            enum B {
                #[two_attr]
                Two(Thing),
                Three {
                    y: i32
                },
                Four
            }
        };
        let expected = quote! {
            enum B {
                One,
                #[two_attr]
                Two(Thing),
                Three {
                    y: i32
                },
                Four
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_namespace_conflict_merge() {
        let input_a = quote! {
            enum A {
                One,
                #[one_attr]
                Two,
                Three {
                    x: i32
                }
            }
        };
        let input_b = quote! {
            enum B {
                #[two_attr]
                Two(Thing),
                Three {
                    y: i32
                },
                Four
            }
        };
        let expected = quote! {
            enum B {
                #[one_attr]
                #[two_attr]
                Two(Thing),
                Three {
                    x: i32,
                    y: i32
                },
                Four
            }
        };
        let args = quote!(merge_on_conflict, filter(One));
        let result = combine_data(input_a, input_b, Some(args));
        assert_streams(result, expected);
    }

    #[test]
    fn test_namespace_conflict_struct() {
        let input_a = quote! {
            struct A {
                x: i32,
                y: i32
            }
        };
        let input_b = quote! {
            struct B {
                x: i64
            }
        };
        let expected = quote! {
            struct B {
                x: i64,
                y: i32
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    #[test]
    fn test_filter() {
        let input_a = quote! {
            #[attr]
            enum A {
                #[another_attr]
                One,
                Two,
                #[another_attr]
                Three
            }
        };
        let input_b = quote! {
            enum B {
                One,
                Four,
            }
        };
        let expected = quote! {
            enum B {
                One,
                #[another_attr]
                Three,
                Four
            }
        };
        let filter = quote!(filter(Two, attr, another_attr));
        let result = combine_data(input_a, input_b, Some(filter));
        assert_streams(result, expected);
    }

    #[test]
    fn test_filter_struct_fields() {
        // `filter` also applies to the named fields of the first struct.
        let input_a = quote! {
            struct A {
                x: i32,
                y: i32
            }
        };
        let input_b = quote! {
            struct B {
                z: i32
            }
        };
        let expected = quote! {
            struct B {
                y: i32,
                z: i32
            }
        };
        let result = combine_data(input_a, input_b, Some(quote!(filter(x))));
        assert_streams(result, expected);
    }
}
673        assert_streams(result, expected);
674    }
675}