fieldfilter_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Data, DataStruct, Fields, FieldsNamed, GenericArgument, Meta, NestedMeta, Path, Type,
5};
6
7#[proc_macro_derive(FieldFilterable, attributes(field_filterable_on))]
8pub fn derive_field_filterable_impl(ts: TokenStream) -> TokenStream {
9    let input = syn::parse_macro_input!(ts as syn::DeriveInput);
10
11    let name = input.ident;
12
13    let object_path = get_object_path(input.attrs);
14
15    let fields = if let Data::Struct(DataStruct {
16        fields: Fields::Named(FieldsNamed { ref named, .. }),
17        ..
18    }) = input.data
19    {
20        named
21    } else {
22        unimplemented!();
23    };
24
25    let filter_logic_q = fields.iter().map(|f| {
26        let name = f.ident.as_ref().expect("cannot be derived on tuples");
27        let ty = &f.ty;
28        if ty_inner_type("Option", ty).is_some() {
29            let name_str = format!("{}", name);
30            quote! {
31                let #name = if fields.contains(#name_str) { Some(o.#name) } else { None };
32            }
33        } else {
34            quote! {
35                let #name = o.#name;
36            }
37        }
38    });
39
40    let return_o = fields.iter().map(|f| {
41        let name = f.ident.as_ref().expect("cannot be derived on tuples");
42        quote! { #name }
43    });
44
45    let expanded = quote! {
46        impl FieldFilterable<#object_path> for #name {
47            fn field_filter(o: #object_path, fields: std::collections::HashSet<String>) -> Self {
48                #(#filter_logic_q)*
49                Self { #(#return_o),* }
50            }
51        }
52    };
53    expanded.into()
54}
55
56fn get_object_path(attrs: Vec<Attribute>) -> Path {
57    for attr in attrs {
58        let meta = attr.parse_meta().unwrap();
59        if let Meta::List(list) = meta {
60            match get_single_segment(&list.path) {
61                // this is the attr name we care about
62                Some(ref seg) if seg == "field_filterable_on" => {
63                    for nested in list.nested {
64                        if let NestedMeta::Meta(Meta::Path(path)) = nested {
65                            return path;
66                        }
67                    }
68                }
69                _ => (),
70            }
71        }
72    }
73    unimplemented!(r#"#[field_filterable_on(<TYPE>)] must be set"#)
74}
75
76fn get_single_segment(path: &Path) -> Option<String> {
77    if path.segments.len() == 1 {
78        Some(path.segments[0].ident.to_string())
79    } else {
80        None
81    }
82}
83
84fn ty_inner_type<'a>(wrapper: &str, ty: &'a Type) -> Option<&'a Type> {
85    if let Type::Path(ref p) = ty {
86        if p.path.segments.len() != 1 || p.path.segments[0].ident != wrapper {
87            return None;
88        }
89
90        if let syn::PathArguments::AngleBracketed(ref inner_ty) = p.path.segments[0].arguments {
91            if inner_ty.args.len() != 1 {
92                return None;
93            }
94
95            let inner_ty = inner_ty.args.first().unwrap();
96            if let GenericArgument::Type(ref t) = inner_ty {
97                return Some(t);
98            }
99        }
100    }
101    None
102}