Skip to main content

state_validation_derive/
lib.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToSnakeCase;
4use itertools::Itertools;
5use proc_macro::TokenStream;
6use quote::TokenStreamExt;
7use syn::{
8    Expr, GenericArgument, GenericParam, Generics, Ident, Lifetime, Type, TypePath, ext,
9    parse_macro_input, parse_quote,
10};
11
/// A conversion target type tagged with the position of the field it belongs to.
///
/// Note that ordering (`Ord`/`PartialOrd`, implemented by hand) compares
/// `sort_number` only, whereas the derived `PartialEq`/`Eq`/`Hash` also take
/// `ty` into account.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct ConversionSort {
    /// Index of the owning field; struct-level extra fields are numbered
    /// after the declared fields.
    sort_number: usize,
    /// The type the field has in a generated struct.
    ty: ConversionType,
}
17impl Ord for ConversionSort {
18    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
19        self.sort_number.cmp(&other.sort_number)
20    }
21}
22impl PartialOrd for ConversionSort {
23    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
24        Some(self.cmp(other))
25    }
26}
27impl quote::ToTokens for ConversionSort {
28    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
29        self.ty.to_tokens(tokens)
30    }
31}
/// The argument of a `#[conversion(...)]` attribute (or a field's own type).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum ConversionType {
    /// A plain type, e.g. `#[conversion(AdminUser)]`.
    Type(syn::Type),
    /// A type that introduces its own generics,
    /// e.g. `#[conversion(T1, T2 = UserWithData<T1, T2>)]`.
    Generic {
        /// The generic names listed before the `=`.
        generic_ident: Vec<syn::Ident>,
        /// The type written after the `=`.
        ty: syn::Type,
    },
}
40impl syn::parse::Parse for ConversionType {
41    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
42        if input.peek(syn::Ident) && input.peek2(syn::Token![=]) {
43            let generic_ident = input.parse()?;
44            let generic_ident = vec![generic_ident];
45            let _: syn::Token![=] = input.parse()?;
46            let ty = input.parse()?;
47            Ok(ConversionType::Generic { generic_ident, ty })
48        } else if input.peek(syn::Ident) && input.peek2(syn::Token![,]) {
49            let mut generic_ident = Vec::with_capacity(2);
50            loop {
51                let generic = input.parse::<syn::Ident>()?;
52                generic_ident.push(generic);
53                if input.parse::<syn::Token![,]>().is_err() {
54                    break;
55                }
56            }
57            let _: syn::Token![=] = input.parse()?;
58            let ty = input.parse()?;
59            Ok(ConversionType::Generic { generic_ident, ty })
60        } else {
61            input.parse().map(ConversionType::Type)
62        }
63    }
64}
65impl quote::ToTokens for ConversionType {
66    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
67        match self {
68            ConversionType::Type(ty) => {
69                tokens.append_all(ty.to_token_stream());
70            }
71            ConversionType::Generic { ty, .. } => {
72                tokens.append_all(quote::quote!(#ty));
73            }
74        }
75    }
76}
77
78/// Implements `StateFilterInputConversion` and `StateFilterInputCombination`,
79/// so that this struct can be split into every individual field and later combined together.
80///
81/// The types each field can be converted to must be specified.
82///
83/// Use `conversion` on the struct fields to convert them to a different type:
84/// ```ignore
85/// #[derive(StateFilterConversion)]
86/// struct ExampleStruct {
87///     #[conversion(AdminUser)]
88///     some_value: UnknownUser,
89/// }
90/// # struct UnknownUser;
91/// # struct AdminUser;
92/// ```
93/// The above code will let a `StateFilter` take `ExampleStruct` as an input,
94/// and output a new struct whose `some_value` is of type `AdminUser` or the original `UnknownUser`.
95///
96/// In some cases, your filters may result in more data output than what was given.
97/// In those cases, you can use the `conversion` attribute on the struct itself for extra fields:
98/// ```ignore
99/// #[derive(StateFilterConversion)]
100/// #[conversion(Age)]
101/// struct ExampleStruct {
102///     some_value: UnknownUser,
103/// }
104/// # struct UnknownUser;
105/// # struct Age;
106/// ```
107/// The above code will allow `ExampleStruct` to be deconstructed and
108/// then reconstructed into a new struct which also contains the field `age: Age`.
109///
110/// If you use generics, use this syntax:
111/// ```ignore
112/// #[derive(StateFilterConversion)]
113/// struct ExampleStruct {
114///     #[conversion(T = UserWithData<T>)]
115///     some_value: UnknownUser,
116/// }
117/// # struct UnknownUser;
118/// # struct UserWithData<T>(T);
119/// ```
120/// And if you want more than one generic, be sure to use a different generic name:
121/// ```ignore
122/// #[derive(StateFilterConversion)]
123/// #[conversion(T0 = SomeMoreData<T0>)]
124/// struct ExampleStruct {
125///     #[conversion(T1, T2 = UserWithData<T1, T2>)]
126///     some_value: UnknownUser,
127/// }
128/// # struct UnknownUser;
129/// # struct UserWithData<T0, T1>(T0, T1);
130/// # struct SomeMoreData<T>(T);
131/// ```
132///
133/// The `conversion` attribute can be used multiple times on a single field for different conversion types:
134/// ```
135/// #[derive(StateFilterConversion)]
136/// struct ExampleStruct {
137///     #[conversion(AdminUser)]
138///     #[conversion(UserWithData)]
139///     some_value: UnknownUser,
140/// }
141/// # struct UnknownUser;
142/// # struct AdminUser;
143/// # struct UserWithData;
144/// ```
145#[proc_macro_derive(StateFilterConversion, attributes(conversion))]
146pub fn state_filter_conversion(input: TokenStream) -> TokenStream {
147    let ast = parse_macro_input!(input as syn::DeriveInput);
148    let name = &ast.ident;
149    let state_conversions = match &ast.data {
150        syn::Data::Struct(s) => {
151            let fields_count = s.fields.len();
152            let mut state_conversions = Vec::with_capacity(fields_count);
153            let (iter, extra_fields_count) = {
154                let mut iter: Vec<_> = s
155                    .fields
156                    .iter()
157                    .enumerate()
158                    .map(|(i, field)| {
159                        let field_name = field.ident.as_ref().expect("expected a named field");
160                        let mut all_conversion_fields = Vec::new();
161                        all_conversion_fields.push((
162                            field_name.clone(),
163                            ConversionSort {
164                                sort_number: i,
165                                ty: ConversionType::Type(field.ty.clone()),
166                            },
167                            extract_generics_from_type(&field.ty, &ast.generics),
168                        ));
169                        for attr in field
170                            .attrs
171                            .iter()
172                            .filter(|attr| attr.path().is_ident("conversion"))
173                        {
174                            let f = attr
175                                .parse_args::<ConversionType>()
176                                .expect("expected a conversion type");
177                            let generics = match &f {
178                                ConversionType::Type(ty) => {
179                                    extract_generics_from_type(ty, &ast.generics)
180                                }
181                                ConversionType::Generic { generic_ident, .. } => {
182                                    parse_quote!(<#(#generic_ident),*>)
183                                }
184                            };
185                            all_conversion_fields.push((
186                                field_name.clone(),
187                                ConversionSort {
188                                    sort_number: i,
189                                    ty: f,
190                                },
191                                generics,
192                            ));
193                        }
194                        all_conversion_fields
195                    })
196                    .collect();
197                let extra_struct_fields: Vec<_> = ast
198                    .attrs
199                    .into_iter()
200                    .filter(|attr| attr.path().is_ident("conversion"))
201                    .enumerate()
202                    .map(|(i, attr)| {
203                        let f = attr
204                            .parse_args::<ConversionType>()
205                            .expect("expected a conversion type");
206                        let (field_name, generics) = match &f {
207                            ConversionType::Type(ty) => {
208                                let ident = type_to_ident(ty);
209                                (
210                                    quote::format_ident!("{}", ident.to_string().to_snake_case()),
211                                    extract_generics_from_type(ty, &ast.generics),
212                                )
213                            }
214                            ConversionType::Generic { generic_ident, ty } => {
215                                let ident = type_to_ident(ty);
216                                (
217                                    quote::format_ident!("{}", ident.to_string().to_snake_case()),
218                                    parse_quote!(<#(#generic_ident),*>),
219                                )
220                            }
221                        };
222                        // TODO: for now, the extra fields can be of only 1 type
223                        vec![(
224                            field_name,
225                            ConversionSort {
226                                sort_number: i + iter.len(),
227                                ty: f,
228                            },
229                            generics,
230                        )]
231                    })
232                    .collect();
233                let extra_fields_count = extra_struct_fields.len();
234                iter.extend(extra_struct_fields);
235                (iter, extra_fields_count)
236            };
237            let mut combination_names = HashMap::new();
238            let mut remainder_names = HashMap::new();
239            let mut i = 0;
240            for powerset in iter.iter().powerset() {
241                for (field_names, mut field_types, field_generics) in
242                    powerset.into_iter().multi_cartesian_product().map(|f| {
243                        let mut field_names = Vec::with_capacity(f.len());
244                        let mut field_types = Vec::with_capacity(f.len());
245                        let mut generics = Vec::with_capacity(f.len());
246                        for (field_name, field_type, field_generics) in f {
247                            field_names.push(field_name);
248                            field_types.push(field_type.clone());
249                            generics.push(field_generics);
250                        }
251                        (field_names, field_types, generics)
252                    })
253                {
254                    let combination_struct_name =
255                        quote::format_ident!("__StateValidationGeneration_{name}Combined_{i}");
256                    let mut generics = Generics::default();
257                    for g in field_generics {
258                        generics = merge_generics(generics, g);
259                    }
260                    let q = quote::quote! {
261                        pub struct #combination_struct_name #generics {
262                            #(pub #field_names: #field_types),*
263                        }
264                    };
265                    state_conversions.push(q);
266                    field_types.sort();
267                    combination_names.insert(field_types, combination_struct_name);
268                    i += 1;
269                }
270            }
271            let mut i = 0;
272            for powerset in iter.iter().powerset() {
273                for (field_names, mut field_types, field_generics) in
274                    powerset.into_iter().multi_cartesian_product().map(|f| {
275                        let mut field_names = Vec::with_capacity(f.len());
276                        let mut field_types = Vec::with_capacity(f.len());
277                        let mut generics = Vec::with_capacity(f.len());
278                        for (field_name, field_type, field_generics) in f {
279                            field_names.push(field_name);
280                            field_types.push(field_type.clone());
281                            generics.push(field_generics);
282                        }
283                        (field_names, field_types, generics)
284                    })
285                {
286                    let remainder_struct_name =
287                        quote::format_ident!("__StateValidationGeneration_{name}Remainder_{i}");
288                    let mut generics = Generics::default();
289                    for g in field_generics {
290                        generics = merge_generics(generics, g);
291                    }
292                    let q = quote::quote! {
293                        pub struct #remainder_struct_name #generics {
294                            #(#field_names: #field_types),*
295                        }
296                    };
297                    state_conversions.push(q);
298                    field_types.sort();
299                    remainder_names.insert(field_types, remainder_struct_name);
300                    i += 1;
301                }
302            }
303            create_original_conversion_combinations(
304                &mut state_conversions,
305                &ast.generics,
306                &combination_names,
307                &remainder_names,
308                name,
309                &s.fields,
310                ast.generics.clone(),
311            );
312            let cartesian_product = iter.iter().multi_cartesian_product().map(|f| {
313                let mut field_names = Vec::with_capacity(f.len());
314                let mut field_types = Vec::with_capacity(f.len());
315                let mut generics = Vec::with_capacity(f.len());
316                for (field_name, field_type, field_generics) in f {
317                    field_names.push(field_name);
318                    field_types.push(field_type);
319                    generics.push(field_generics);
320                }
321                (field_names, field_types, generics)
322            });
323            for (k, (field_names, field_types, field_generics)) in cartesian_product.enumerate() {
324                let mut all_field_generics = Generics::default();
325                for field_generics in field_generics.iter() {
326                    all_field_generics = merge_generics(all_field_generics, field_generics);
327                }
328                let fields_name_type_generics: Vec<_> = field_names
329                    .clone()
330                    .into_iter()
331                    .zip(field_types.clone().into_iter())
332                    .zip(field_generics.clone().into_iter())
333                    .collect();
334                for count in 0..=(fields_count + extra_fields_count) {
335                    for f in fields_name_type_generics.iter().combinations(count) {
336                        for (
337                            current_field_names,
338                            current_field_types,
339                            current_field_generics,
340                            other_field_names,
341                            other_field_types,
342                            other_field_generics,
343                        ) in f.into_iter().permutations(count).map(|subset| {
344                            let remainder: Vec<_> = fields_name_type_generics
345                                .iter()
346                                .filter(|((field_name_a, ..), ..)| {
347                                    !subset.iter().any(|((field_name_b, ..), ..)| {
348                                        field_name_a == field_name_b
349                                    })
350                                })
351                                .collect();
352                            let mut current_field_names = Vec::with_capacity(subset.len());
353                            let mut current_field_types = Vec::with_capacity(subset.len());
354                            let mut current_field_generics = Vec::with_capacity(subset.len());
355                            for ((field_name, field_type), generics) in subset {
356                                current_field_names.push((*field_name).clone());
357                                current_field_types.push((*field_type).clone());
358                                current_field_generics.push((*generics).clone());
359                            }
360                            let mut other_field_names = Vec::with_capacity(remainder.len());
361                            let mut other_field_types = Vec::with_capacity(remainder.len());
362                            let mut other_field_generics = Vec::with_capacity(remainder.len());
363                            for ((field_name, field_type), generics) in remainder {
364                                other_field_names.push((*field_name).clone());
365                                other_field_types.push((*field_type).clone());
366                                other_field_generics.push((*generics).clone());
367                            }
368                            (
369                                current_field_names,
370                                current_field_types,
371                                current_field_generics,
372                                other_field_names,
373                                other_field_types,
374                                other_field_generics,
375                            )
376                        }) {
377                            let r = current_field_types
378                                .iter()
379                                .chain(other_field_types.iter())
380                                .cloned()
381                                .sorted()
382                                .collect::<Vec<_>>();
383                            let combined_struct_name = combination_names
384                                .get(&r)
385                                .expect(&format!("0: expected a combined struct: {:?}", r));
386                            let remainder_struct_name = {
387                                let mut other_field_types = other_field_types.clone();
388                                other_field_types.sort();
389                                remainder_names.get(&other_field_types).unwrap()
390                            };
391                            let mut o = Generics::default();
392                            for other_field_generics in other_field_generics {
393                                o = merge_generics(o, &other_field_generics);
394                            }
395                            let other_field_generics = o;
396                            let q = quote::quote! {
397                                impl #all_field_generics state_validation::StateFilterInputCombination<(#(#current_field_types),*)> for #remainder_struct_name #other_field_generics {
398                                    type Combined = #combined_struct_name #all_field_generics;
399                                    fn combine(self, (#(#current_field_names),*): (#(#current_field_types),*)) -> Self::Combined {
400                                        #combined_struct_name {
401                                            #(#current_field_names,)*
402                                            #(#other_field_names: self.#other_field_names),*
403                                        }
404                                    }
405                                }
406                                impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #combined_struct_name #all_field_generics {
407                                    type Remainder = #remainder_struct_name #other_field_generics;
408                                    fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
409                                        (
410                                            (#(self.#current_field_names),*),
411                                            #remainder_struct_name {
412                                                #(#other_field_names: self.#other_field_names),*
413                                            },
414                                        )
415                                    }
416                                }
417                            };
418                            state_conversions.push(q);
419                        }
420                    }
421                }
422            }
423            state_conversions
424        }
425        _ => todo!(),
426    };
427    quote::quote! {
428        #(#state_conversions)*
429    }
430    .into()
431}
432
433fn create_original_conversion_combinations(
434    state_conversions: &mut Vec<proc_macro2::TokenStream>,
435    original_generics: &Generics,
436    combination_names: &HashMap<Vec<ConversionSort>, Ident>,
437    remainder_names: &HashMap<Vec<ConversionSort>, Ident>,
438    name: &Ident,
439    fields: &syn::Fields,
440    mut all_field_generics: Generics,
441) {
442    let fields: Vec<_> = fields
443        .iter()
444        .enumerate()
445        .map(|(i, field)| {
446            let field_name = field.ident.as_ref().expect("expected a named field");
447            (
448                field_name,
449                ConversionSort {
450                    sort_number: i,
451                    ty: ConversionType::Type(field.ty.clone()),
452                },
453                extract_generics_from_type(&field.ty, original_generics),
454            )
455        })
456        .collect();
457    for (_, _, generics_b) in fields.iter() {
458        all_field_generics = merge_generics(all_field_generics, generics_b);
459    }
460    for k in 0..=fields.len() {
461        for combination in fields.iter().combinations(k) {
462            for (
463                current_field_names,
464                current_field_types,
465                current_field_generics,
466                other_field_names,
467                other_field_types,
468                other_field_generics,
469            ) in combination.into_iter().permutations(k).map(|subset| {
470                let remainder: Vec<_> = fields
471                    .iter()
472                    .filter(|(field_name_a, ..)| {
473                        !subset
474                            .iter()
475                            .any(|(field_name_b, ..)| field_name_a == field_name_b)
476                    })
477                    .collect();
478                let mut current_field_names = Vec::with_capacity(subset.len());
479                let mut current_field_types = Vec::with_capacity(subset.len());
480                let mut current_field_generics = Vec::new();
481                for (field_name, field_type, generics) in subset {
482                    current_field_names.push((*field_name).clone());
483                    current_field_types.push(field_type.clone());
484                    current_field_generics.push(generics.clone());
485                }
486                let mut other_field_names = Vec::with_capacity(remainder.len());
487                let mut other_field_types = Vec::with_capacity(remainder.len());
488                let mut other_field_generics = Vec::new();
489                for (field_name, field_type, generics) in remainder {
490                    other_field_names.push((*field_name).clone());
491                    other_field_types.push(field_type.clone());
492                    other_field_generics.push(generics.clone());
493                }
494                (
495                    current_field_names,
496                    current_field_types,
497                    current_field_generics,
498                    other_field_names,
499                    other_field_types,
500                    other_field_generics,
501                )
502            }) {
503                let r = current_field_types
504                    .iter()
505                    .chain(other_field_types.iter())
506                    .cloned()
507                    .sorted()
508                    .collect::<Vec<_>>();
509                let combined_struct_name = combination_names.get(&r).expect(&format!(
510                    "1: expected a combined struct: {:#?}\nCOMBINATION NAMES: {:#?}",
511                    r, combination_names,
512                ));
513                let remainder_struct_name = {
514                    let mut other_field_types = other_field_types.clone();
515                    other_field_types.sort();
516                    remainder_names
517                        .get(&other_field_types)
518                        .expect("expected a remainder struct")
519                };
520                let mut current_field_generic = Generics::default();
521                for current_generics in current_field_generics {
522                    current_field_generic =
523                        merge_generics(current_field_generic, &current_generics);
524                }
525                let current_field_generics = current_field_generic;
526                let mut other_field_generic = Generics::default();
527                for other_generics in other_field_generics {
528                    other_field_generic = merge_generics(other_field_generic, &other_generics);
529                }
530                let other_field_generics = other_field_generic;
531                let q = quote::quote! {
532                    impl #all_field_generics state_validation::StateFilterInputConversion<(#(#current_field_types),*)> for #name #all_field_generics {
533                        type Remainder = #remainder_struct_name #other_field_generics;
534                        fn split_take(self) -> ((#(#current_field_types),*), Self::Remainder) {
535                            (
536                                (#(self.#current_field_names),*),
537                                #remainder_struct_name {
538                                    #(#other_field_names: self.#other_field_names),*
539                                },
540                            )
541                        }
542                    }
543                };
544                state_conversions.push(q);
545            }
546        }
547    }
548}
549
550// UTILITY //
551
552fn extract_generics_from_type(ty: &Type, original_generics: &Generics) -> Generics {
553    let mut type_params = BTreeSet::new();
554    let mut lifetime_params = BTreeSet::new();
555    let mut const_params = BTreeSet::new();
556
557    collect_generics(
558        ty,
559        original_generics,
560        &mut type_params,
561        &mut lifetime_params,
562        &mut const_params,
563    );
564
565    let mut generics = Generics::default();
566
567    for lt in lifetime_params {
568        generics
569            .params
570            .push(GenericParam::Lifetime(parse_quote!(#lt)));
571    }
572    for tp in type_params {
573        generics.params.push(GenericParam::Type(parse_quote!(#tp)));
574    }
575    for cp in const_params {
576        generics
577            .params
578            .push(GenericParam::Const(parse_quote!(const #cp: usize)));
579    }
580
581    generics
582}
583
584fn collect_generics(
585    ty: &Type,
586    original_generics: &Generics,
587    type_params: &mut BTreeSet<syn::Ident>,
588    lifetime_params: &mut BTreeSet<Lifetime>,
589    const_params: &mut BTreeSet<syn::Ident>,
590) {
591    match ty {
592        Type::Path(TypePath { path, .. }) => {
593            for segment in &path.segments {
594                // Extract angle bracketed generics
595                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
596                    for arg in &args.args {
597                        match arg {
598                            GenericArgument::Type(inner_ty) => {
599                                if let Type::Path(p) = inner_ty
600                                    && let Some(ident) = p.path.get_ident()
601                                    && original_generics.type_params().any(|ty| ty.ident == *ident)
602                                {
603                                    type_params.insert(ident.clone());
604                                }
605                                collect_generics(
606                                    inner_ty,
607                                    original_generics,
608                                    type_params,
609                                    lifetime_params,
610                                    const_params,
611                                );
612                            }
613                            GenericArgument::Lifetime(lt) => {
614                                lifetime_params.insert(lt.clone());
615                            }
616                            GenericArgument::Const(expr) => {
617                                if let syn::Expr::Path(expr_path) = expr
618                                    && let Some(ident) = expr_path.path.get_ident()
619                                {
620                                    const_params.insert(ident.clone());
621                                }
622                            }
623                            _ => {}
624                        }
625                    }
626                }
627            }
628        }
629        Type::Reference(r) => {
630            if let Some(lt) = &r.lifetime {
631                lifetime_params.insert(lt.clone());
632            }
633            collect_generics(
634                &r.elem,
635                original_generics,
636                type_params,
637                lifetime_params,
638                const_params,
639            );
640        }
641        _ => {}
642    }
643}
644
645fn merge_generics(mut generics_a: Generics, generics_b: &Generics) -> Generics {
646    let mut existing = BTreeSet::new();
647    for param in &generics_a.params {
648        match param {
649            GenericParam::Type(tp) => {
650                existing.insert(tp.ident.to_string());
651            }
652            GenericParam::Lifetime(lt) => {
653                existing.insert(lt.lifetime.ident.to_string());
654            }
655            GenericParam::Const(cp) => {
656                existing.insert(cp.ident.to_string());
657            }
658        }
659    }
660
661    for param in &generics_b.params {
662        let name = match param {
663            GenericParam::Type(tp) => tp.ident.to_string(),
664            GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
665            GenericParam::Const(cp) => cp.ident.to_string(),
666        };
667        if !existing.contains(&name) {
668            generics_a.params.push(param.clone());
669            existing.insert(name);
670        }
671    }
672
673    match (&mut generics_a.where_clause, &generics_b.where_clause) {
674        (Some(a_wc), Some(b_wc)) => {
675            a_wc.predicates.extend(b_wc.predicates.clone());
676        }
677        (None, Some(b_wc)) => {
678            generics_a.where_clause = Some(b_wc.clone());
679        }
680        _ => {}
681    }
682
683    generics_a
684}
685
686fn type_to_ident(ty: &Type) -> &Ident {
687    match ty {
688        Type::Path(type_path) => type_path
689            .path
690            .segments
691            .last()
692            .map(|seg| &seg.ident)
693            .unwrap(),
694        _ => unimplemented!(),
695    }
696}