1
2use std::collections::{HashMap, HashSet};
3use std::{mem};
4use std::ops::{IndexMut};
5pub use extendable_data_helpers::extendable_data;
6use syn::meta::ParseNestedMeta;
7use syn::spanned::Spanned;
8use syn::token::{Comma, Lt, Gt, Where, Brace, Paren};
9use syn::{self, DeriveInput, Generics, GenericParam, WhereClause, WherePredicate, Variant, Attribute, Path, Fields, FieldsNamed, FieldsUnnamed};
10use syn::parse::{Result};
11use syn::punctuated::{Punctuated};
12use syn::Fields::*;
13use proc_macro2::{TokenStream, Ident, Span};
14use quote::{quote, ToTokens};
15
16
/// Picks between two optional values, giving priority to the second one.
///
/// Returns `option_b` whenever it is `Some`; otherwise falls back to
/// `option_a` (which may itself be `None`).
fn overwrite_optionals<T>(option_a: Option<T>, option_b: Option<T>) -> Option<T> {
    option_b.or(option_a)
}
23
24fn path_into_string(path: Path) -> String {
25 path.into_token_stream().to_string()
26}
27
28fn path_to_string(path: &Path) -> String {
29 path.to_token_stream().to_string()
30}
31
32fn meta_to_string(meta: &syn::Meta) -> String {
33 match meta {
34 syn::Meta::Path(p) => path_to_string(p),
35 syn::Meta::List(m) => path_to_string(&m.path),
36 syn::Meta::NameValue(m) => path_to_string(&m.path)
37 }
38}
39
/// Minimal "number of elements" abstraction so generic code can query the
/// size of any supported collection through a single trait bound.
trait Len {
    /// Returns the number of elements currently held by the collection.
    fn len(&self) -> usize;
}
43
44impl<T> Len for Vec<T> {
45 fn len(&self) -> usize {
46 self.len()
47 }
48}
49
50impl<T, P> Len for Punctuated<T, P> {
51 fn len(&self) -> usize {
52 self.len()
53 }
54}
55
56fn combine_iters<T, U, F1, F2>(iter_a: T, iter_b: T, to_str: F1, handle_conflict: F2, args: &Args) -> Result<T>
57where
58 T: IntoIterator<Item = U> + Default + Extend<U> + IndexMut<usize, Output = U> + Len,
59 F1: Fn(&U) -> String,
60 F2: Fn(&mut T, usize, U, &Args) -> Result<()>
61{
62 let mut seen: HashMap<String, usize> = HashMap::with_capacity(iter_a.len()); let mut combined: T = <T as std::default::Default>::default();
64 let mut i = 0;
65 for a in iter_a.into_iter() {
66 let repr = to_str(&a);
67 if !args.filter.contains(&repr) {
68 seen.insert(repr, i);
69 combined.extend([a]);
70 i += 1;
71 }
72 }
73 for b in iter_b.into_iter() {
74 if let Some(i) = seen.remove(&to_str(&b)) {
75 handle_conflict(&mut combined, i, b, args)?;
76 } else {
77 combined.extend([b]);
78 }
79 }
80 Ok(combined)
81}
82
83fn handle_conflicts_basic<T: IntoIterator<Item = U> + Default + Extend<U> + IndexMut<usize, Output = U> + Len, U>(list: &mut T, index: usize, data: U, _args: &Args) -> Result<()> {
84 list[index] = data;
85 Ok(())
86}
87
88fn combine_attrs(attr_a: Vec<Attribute>, attr_b: Vec<Attribute>, args: &Args) -> Result<Vec<Attribute>> {
89 combine_iters(attr_a, attr_b, |x| meta_to_string(&x.meta), handle_conflicts_basic, args)
90}
91
92fn combine_wheres(where_a: WhereClause, where_b: WhereClause) -> WhereClause {
93 let pred_a = where_a.predicates.into_iter();
94 let pred_b = where_b.predicates.into_iter();
95 let combined: Punctuated<WherePredicate, Comma> = Punctuated::from_iter(pred_a.chain(pred_b));
96 WhereClause {
97 where_token: Where::default(),
98 predicates: combined,
99 }
100}
101
102fn combine_generics(input_a: Generics, input_b: Generics) -> Generics {
103 let params_a = input_a.params.into_iter();
104 let params_b = input_b.params.into_iter();
105 let combined = params_a.chain(params_b);
106 let params_c: Punctuated<GenericParam, Comma> = Punctuated::from_iter(combined);
107 let where_c: Option<WhereClause> = match (input_a.where_clause, input_b.where_clause) {
108 (None, None) => None,
109 (None, Some(where_b)) => Some(where_b),
110 (Some(where_a), None) => Some(where_a),
111 (Some(where_a), Some(where_b)) => Some(combine_wheres(where_a, where_b)),
112 };
113 Generics {
114 lt_token: Some(Lt::default()),
115 params: params_c,
116 gt_token: Some(Gt::default()),
117 where_clause: where_c
118 }
119}
120
121fn combine_fields_named(fields_a: FieldsNamed, fields_b: FieldsNamed, args: &Args) -> Result<FieldsNamed> {
122 let named = combine_iters(fields_a.named, fields_b.named, |x| x.ident.as_ref().unwrap().to_string(), handle_conflicts_basic, args)?;
123 Ok(FieldsNamed {
124 brace_token: Brace::default(),
125 named
126 })
127}
128
129fn combine_fields(fields_a: Fields, fields_b: Fields, args: &Args, merging: bool) -> Result<Fields> {
130 let b_span = fields_b.span();
131 match (fields_a, fields_b) {
132 (_, f) if !merging => Ok(f),
133 (Named(fields_a), Named(fields_b)) => {
134 let resp = combine_fields_named(fields_a, fields_b, args)?;
135 Ok(Named(resp))
136 },
137 (Unnamed(fields_a), Unnamed(fields_b)) => {
138 let unnamed = combine_iters(fields_a.unnamed, fields_b.unnamed, |x| x.ty.to_token_stream().to_string(), handle_conflicts_basic, args)?;
139 Ok(Unnamed(FieldsUnnamed {
140 paren_token: Paren::default(),
141 unnamed
142 }))
143 },
144 (Unit, f) | (f, Unit) => Ok(f),
145 _ => Err(syn::Error::new(b_span, "Can not combine provided structs. Either make sure they are the same type, or filter out the offending struct."))
146 }
147}
148
149fn combine_enum_variants(variants_a: Punctuated<Variant, Comma>, variants_b: Punctuated<Variant, Comma>, args: Args) -> Result<Punctuated<Variant, Comma>> {
150 fn handle_merge_conflict(combined: &mut Punctuated<Variant, Comma>, i: usize, b: Variant, args: &Args) -> Result<()> {
151 let a = mem::replace(&mut combined[i], b);
152 combined[i].attrs = combine_attrs(a.attrs, mem::take(&mut combined[i].attrs), args)?;
153 combined[i].fields = combine_fields(a.fields, mem::replace(&mut combined[i].fields, Unit), args, args.merge)?;
154 combined[i].discriminant = overwrite_optionals(a.discriminant, mem::take(&mut combined[i].discriminant));
155 Ok(())
156 }
157 let handle_conflicts = if args.merge {
158 handle_merge_conflict
159 } else {
160 handle_conflicts_basic
161 };
162 combine_iters(variants_a, variants_b, |x| x.ident.to_string(), handle_conflicts, &args)
163}
164
165fn combine_enums(enum_a: syn::DataEnum, enum_b: syn::DataEnum, args: Args) -> Result<(TokenStream, &'static str)> {
166 let variants = combine_enum_variants(enum_a.variants, enum_b.variants, args)?;
167 let tokens = quote!({
168 #variants
169 });
170 Ok((tokens, "enum"))
171}
172
173fn combine_structs(struct_a: syn::DataStruct, struct_b: syn::DataStruct, args: Args) -> Result<(TokenStream, &'static str)> {
174 let fields = combine_fields(struct_a.fields, struct_b.fields, &args, true)?;
175 let tokens = match fields {
176 Named(fields) => quote!(#fields),
177 Unnamed(fields) => quote!(#fields;),
178 Unit => quote!(;)
179 };
180 Ok((tokens, "struct"))
181}
182
183fn combine_unions(union_a: syn::DataUnion, union_b: syn::DataUnion, args: Args) -> Result<(TokenStream, &'static str)> {
184 let fields = combine_fields_named(union_a.fields, union_b.fields, &args)?;
185 let tokens = quote!({#fields});
186 Ok((tokens, "union"))
187}
188
189fn construct_stream (
190 data: TokenStream,
191 data_token: Ident,
192 visibility: syn::Visibility,
193 gens: Generics,
194 name: syn::Ident,
195 attrs: Vec<syn::Attribute>
196 ) -> TokenStream {
197 quote! {
198 #(#attrs)*
199 #visibility #data_token #name #gens #data
200 }
201}
202
/// Options controlling how two items are combined.
///
/// The default has an empty filter and merging disabled.
#[derive(Default)]
struct Args {
    // String keys (variant, field or attribute names) of entries from the
    // first input that must be dropped from the result.
    filter: HashSet<String>,
    // When true, enum variants present in both inputs are merged instead of
    // the second one simply replacing the first.
    merge: bool
}
208
209impl Args {
210 fn parse(&mut self, meta: ParseNestedMeta) -> Result<()> {
211 if meta.path.is_ident("filter") {
212 meta.parse_nested_meta(|meta| {
213 let ident: String = path_into_string(meta.path);
214 self.filter.insert(ident);
215 Ok(())
216 })
217 } else if meta.path.is_ident("merge_on_conflict") {
218 self.merge = true;
219 Ok(())
220 } else {
221 Err(meta.error("Unsupported Argument"))
222 }
223 }
224}
225
226pub fn combine_data(input_a: TokenStream, input_b: TokenStream, args_input: Option<TokenStream>) -> TokenStream {
235 let ast_a = match syn::parse2::<DeriveInput>(input_a) {
236 Ok(a) => a,
237 Err(e) => return e.to_compile_error()
238 };
239 let ast_b = match syn::parse2::<DeriveInput>(input_b) {
240 Ok(b) => b,
241 Err(e) => return e.to_compile_error()
242 };
243 let mut args = Args::default();
244 if let Some(a) = args_input {
245 let arg_parser = syn::meta::parser(|meta| args.parse(meta));
246 if let Err(e) = syn::parse::Parser::parse2(arg_parser, a) {
247 return e.to_compile_error();
248 }
249 }
250 let b_span = ast_b.span();
251 let generics = combine_generics(ast_a.generics, ast_b.generics);
252 let attrs = match combine_attrs(ast_a.attrs, ast_b.attrs, &args) {
253 Ok(attrs) => attrs,
254 Err(e) => return e.to_compile_error()
255 };
256 let resp = match (ast_a.data, ast_b.data) {
257 (syn::Data::Enum(enum_a), syn::Data::Enum(enum_b)) => combine_enums(enum_a, enum_b, args),
258 (syn::Data::Struct(struct_a), syn::Data::Struct(struct_b)) => combine_structs(struct_a, struct_b, args),
259 (syn::Data::Union(union_a), syn::Data::Union(union_b)) => combine_unions(union_a, union_b, args),
260 _ => Err(syn::Error::new(b_span, "Can only combine 2 of the same type of data structure!",))
261 };
262 match resp {
263 Ok((data, data_token)) => {
264 let vis_b = ast_b.vis;
265 construct_stream(data, Ident::new(data_token, Span::call_site()), vis_b, generics, ast_b.ident, attrs)
266 },
267 Err(e) => e.to_compile_error()
268 }
269}
270
// Unit tests for `combine_data`, driven entirely through token streams.
#[cfg(test)]
mod tests {

    use super::combine_data;
    use quote::quote;
    use syn::{DeriveInput};
    use proc_macro2::TokenStream;
    use assert_tokenstreams_eq::assert_tokenstreams_eq;

    /// Asserts that two token streams parse to equal `DeriveInput`s.
    /// Panics if either stream does not parse.
    fn assert_streams(result: TokenStream, expected: TokenStream) {
        assert_eq!(syn::parse2::<DeriveInput>(result).unwrap(), syn::parse2::<DeriveInput>(expected).unwrap());
    }

    /// Asserts that `result` is exactly a `compile_error!` invocation
    /// carrying `msg`.
    fn assert_compiler_error(result: TokenStream, msg: &str) {
        let expected = quote!(::core::compile_error! { #msg });
        assert_tokenstreams_eq!(&result, &expected);
    }

    /// Enums with disjoint variants: the first input's variants come first,
    /// followed by the second's, under the second input's name.
    #[test]
    fn test_combine_enums() {
        let enum_a = quote! {
            enum A {
                One(Thing),
                Two {
                    SubOne: i32,
                    SubTwo: Another
                },
                Three
            }
        };
        let enum_b = quote! {
            enum B {
                Four(Thing, Thing, Thing),
                Five,
                Six
            }
        };
        let expected = quote! {
            enum B {
                One(Thing),
                Two {
                    SubOne: i32,
                    SubTwo: Another
                },
                Three,
                Four(Thing, Thing, Thing),
                Five,
                Six
            }
        };
        let result = combine_data(enum_a, enum_b, None);
        assert_streams(result, expected);
    }

    /// Named-field structs with disjoint fields are concatenated.
    #[test]
    fn test_combine_named_structs() {
        let struct_a = quote! {
            struct A {
                one: i32
            }
        };
        let struct_b = quote! {
            struct B {
                two: SomeType
            }
        };
        let expected = quote! {
            struct B {
                one: i32,
                two: SomeType
            }
        };
        let result = combine_data(struct_a, struct_b, None);
        assert_streams(result, expected);
    }

    /// Tuple structs with distinct field types are concatenated.
    #[test]
    fn test_combine_unnamed_structs() {
        let struct_a = quote! {
            struct A(i32, SomeType);
        };
        let struct_b = quote! {
            struct B(i64, Another);
        };
        let expected = quote! {
            struct B(i32, SomeType, i64, Another);
        };
        let result = combine_data(struct_a, struct_b, None);
        assert_streams(result, expected);
    }

    /// A unit struct combined with any struct shape yields the other shape's
    /// fields; the name always comes from the second input.
    #[test]
    fn test_combine_unit_structs() {
        let struct_a = quote! {
            struct A;
        };
        let struct_b = quote! {
            struct B;
        };
        let struct_c = quote! {
            struct C {
                one: i32
            }
        };
        let struct_d = quote! {
            struct D(i32, i32);
        };
        let expected_unit = quote! {
            struct B;
        };
        let expected_named_one = quote! {
            struct C {
                one: i32
            }
        };
        let expected_named_two = quote! {
            struct A {
                one: i32
            }
        };
        let expected_unnamed_one = quote! {
            struct D(i32, i32);
        };
        let expected_unnamed_two = quote! {
            struct A(i32, i32);
        };

        let result_unit = combine_data(struct_a.clone(), struct_b, None);
        let result_named_one = combine_data(struct_a.clone(), struct_c.clone(), None);
        let result_named_two = combine_data(struct_c, struct_a.clone(), None);
        let result_unnamed_one = combine_data(struct_a.clone(), struct_d.clone(), None);
        let result_unnamed_two = combine_data(struct_d, struct_a, None);

        assert_streams(result_unit, expected_unit);
        assert_streams(result_named_one, expected_named_one);
        assert_streams(result_named_two, expected_named_two);
        assert_streams(result_unnamed_one, expected_unnamed_one);
        assert_streams(result_unnamed_two, expected_unnamed_two);
    }

    /// Combining a struct with an enum is rejected with a compile error.
    #[test]
    fn test_invalid_combine() {
        let input_a = quote! {
            struct A;
        };
        let input_b = quote! {
            enum B {
                Thing
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_compiler_error(result, "Can only combine 2 of the same type of data structure!");
    }

    /// Combining a tuple struct with a named-field struct is rejected.
    #[test]
    fn test_invalid_combine_structs() {
        let input_a = quote! {
            struct A(i32, i32);
        };
        let input_b = quote! {
            struct B {
                one: i32
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_compiler_error(result, "Can not combine provided structs. Either make sure they are the same type, or filter out the offending struct.");
    }

    /// Unknown macro arguments produce an "Unsupported Argument" error.
    #[test]
    fn test_invalid_args() {
        let input_a = quote! {
            struct A;
        };
        let input_b = quote! {
            struct B;
        };
        let result = combine_data(input_a, input_b, Some(quote!(fake arg)));
        assert_compiler_error(result, "Unsupported Argument");
    }

    /// The result uses the visibility of the second input.
    #[test]
    fn test_combine_visibility() {
        let input_a = quote! {
            enum A {
                One
            }
        };
        let input_b = quote! {
            pub enum B {
                Two
            }
        };
        let expected = quote! {
            pub enum B {
                One,
                Two
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    /// Item-level attributes from both inputs are kept (first input's first);
    /// variant-level attributes stay on their variant.
    #[test]
    fn test_combine_attributes() {
        let input_a = quote! {
            #[some_attr]
            enum A {
                One
            }
        };
        let input_b = quote! {
            #[another(attr)]
            enum B {
                #[on_attr]
                Two
            }
        };
        let expected = quote! {
            #[some_attr]
            #[another(attr)]
            enum B {
                One,

                #[on_attr]
                Two
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);

        // Same check for structs, with a unit struct as the first input.
        let input_a = quote! {
            #[some_attr]
            struct A;
        };
        let input_b = quote! {
            #[another_attr]
            struct B {
                one: i32
            }
        };
        let expected = quote! {
            #[some_attr]
            #[another_attr]
            struct B {
                one: i32
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);

    }

    /// Generic parameters of both inputs are combined; the expected output
    /// lists lifetimes before type parameters.
    #[test]
    fn test_combine_generics() {
        let input_a = quote! {
            enum A<'life, T> {
                One(Thing<'life, T>)
            }
        };
        let input_b = quote! {
            enum B<'efil, U> {
                Two(Thing<'efil, U>)
            }
        };
        let expected = quote! {
            enum B<'life, 'efil, T, U> {
                One(Thing<'life, T>),
                Two(Thing<'efil, U>)
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    /// Without `merge_on_conflict`, a variant in the second input fully
    /// replaces the same-named variant of the first (attributes included),
    /// keeping the first input's position.
    #[test]
    fn test_namespace_confict_overwrite() {
        let input_a = quote! {
            enum A {
                One,
                #[one_attr]
                Two,
                Three {
                    x: i32
                }
            }
        };
        let input_b = quote! {
            enum B {
                #[two_attr]
                Two(Thing),
                Three {
                    y: i32
                },
                Four
            }
        };
        let expected = quote! {
            enum B {
                One,
                #[two_attr]
                Two(Thing),
                Three {
                    y: i32
                },
                Four
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    /// With `merge_on_conflict`, same-named variants have their attributes
    /// and fields merged; `filter(One)` drops variant `One` of the first input.
    #[test]
    fn test_namespace_conflict_merge() {
        let input_a = quote! {
            enum A {
                One,
                #[one_attr]
                Two,
                Three {
                    x: i32
                }
            }
        };
        let input_b = quote! {
            enum B {
                #[two_attr]
                Two(Thing),
                Three {
                    y: i32
                },
                Four
            }
        };
        let expected = quote! {
            enum B {
                #[one_attr]
                #[two_attr]
                Two(Thing),
                Three {
                    x: i32,
                    y: i32
                },
                Four
            }
        };
        let args = quote!(merge_on_conflict, filter(One));
        let result = combine_data(input_a, input_b, Some(args));
        assert_streams(result, expected);
    }

    /// For structs, a field in the second input replaces the same-named field
    /// of the first in place, even without any arguments.
    #[test]
    fn test_namespace_conflict_struct() {
        let input_a = quote! {
            struct A {
                x: i32,
                y: i32
            }
        };
        let input_b = quote! {
            struct B {
                x: i64
            }
        };
        let expected = quote! {
            struct B {
                x: i64,
                y: i32
            }
        };
        let result = combine_data(input_a, input_b, None);
        assert_streams(result, expected);
    }

    /// `filter(...)` drops the named variant and item-level attribute of the
    /// first input; the attribute on the surviving variant `Three` is kept.
    #[test]
    fn test_filter() {
        let input_a = quote! {
            #[attr]
            enum A {
                #[another_attr]
                One,
                Two,
                #[another_attr]
                Three
            }
        };
        let input_b = quote! {
            enum B {
                One,
                Four,
            }
        };
        let expected = quote! {
            enum B {
                One,
                #[another_attr]
                Three,
                Four
            }
        };
        let filter = quote!(filter(Two, attr, another_attr));
        let result = combine_data(input_a, input_b, Some(filter));
        assert_streams(result, expected);
    }
}