// avantis_utils_derive/lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use syn::parse_macro_input;
5use syn::DeriveInput;
6
7use self::models::PaginatedStruct;
8
9#[proc_macro_derive(PaginatedQuery, attributes(limit, offset))]
10pub fn paginated_query_macro_derive(input: TokenStream) -> TokenStream {
11    let syntax_tree = parse_macro_input!(input as DeriveInput);
12
13    let model = PaginatedStruct::try_from(&syntax_tree).unwrap();
14
15    model.gen().into()
16}
17
18mod models {
19    use proc_macro2::Span;
20    use proc_macro2::TokenStream;
21    use quote::quote;
22    use syn::*;
23
24    #[derive(Clone, Debug)]
25    pub(super) struct PaginatedStruct {
26        name: Ident,
27        limit: PaginatedStructField,
28        offset: PaginatedStructField,
29    }
30
31    impl PaginatedStruct {
32        pub(super) fn gen(&self) -> TokenStream {
33            let name = &self.name;
34            let limit_fn = &self.limit.gen("limit");
35            let offset_fn = &self.offset.gen("offset");
36
37            quote! {
38                impl PaginatedQuery for #name {
39                    #limit_fn
40
41                    #offset_fn
42                }
43            }
44            .into()
45        }
46    }
47
48    #[derive(Clone, Debug)]
49    struct PaginatedStructField {
50        ident_opt: Option<Ident>,
51        default_value: LitInt,
52    }
53
54    impl PaginatedStructField {
55        fn gen(&self, fn_name: &'static str) -> TokenStream {
56            let default_value_lit = &self.default_value;
57
58            let impl_quote = match self.ident_opt.as_ref() {
59                Some(ident) => quote! { self.#ident.unwrap_or(#default_value_lit) },
60                None => quote! { #default_value_lit },
61            };
62
63            let fn_name = Ident::new(fn_name, Span::call_site());
64
65            quote! {
66                fn #fn_name(&self) -> i32 {
67                    #impl_quote
68                }
69            }
70        }
71
72        fn limit_field<T>(
73            fields: &punctuated::Punctuated<syn::Field, T>,
74        ) -> core::result::Result<Self, &'static str> {
75            let matched_fields = fields
76                .iter()
77                .filter(|f| matches!(Attr::try_from(*f), Ok(Attr::Limit(_))))
78                .filter_map(|f| PaginatedStructField::try_from(f).ok())
79                .collect::<Vec<_>>();
80
81            if matched_fields.len() > 1 {
82                return Err("too many attributes");
83            }
84
85            Ok(matched_fields
86                .first()
87                .ok_or_else(|| "field not found")?
88                .clone())
89        }
90
91        fn offset_field<T>(
92            fields: &punctuated::Punctuated<syn::Field, T>,
93        ) -> core::result::Result<Self, &'static str> {
94            let matched_fields = fields
95                .iter()
96                .filter(|f| matches!(Attr::try_from(*f), Ok(Attr::Offset(_))))
97                .filter_map(|f| PaginatedStructField::try_from(f).ok())
98                .collect::<Vec<_>>();
99
100            if matched_fields.len() > 1 {
101                return Err("too many attributes");
102            }
103
104            Ok(matched_fields
105                .first()
106                .ok_or_else(|| "field not found")?
107                .clone())
108        }
109    }
110
111    #[derive(Clone, Debug)]
112    enum Attr {
113        Limit(LitInt),
114        Offset(LitInt),
115    }
116
117    impl Attr {
118        fn default_value(&self) -> &LitInt {
119            match self {
120                Attr::Limit(default) => default,
121                Attr::Offset(default) => default,
122            }
123        }
124    }
125
126    pub(super) mod extractors {
127        use super::*;
128
129        impl TryFrom<&DeriveInput> for PaginatedStruct {
130            type Error = &'static str;
131
132            fn try_from(input: &DeriveInput) -> core::result::Result<Self, Self::Error> {
133                match input.data {
134                    syn::Data::Struct(syn::DataStruct {
135                        fields: syn::Fields::Named(FieldsNamed { ref named, .. }),
136                        ..
137                    }) => Ok(PaginatedStruct {
138                        name: input.ident.clone(),
139                        limit: PaginatedStructField::limit_field(&named)?,
140                        offset: PaginatedStructField::offset_field(&named)?,
141                    }),
142                    _ => Err("help!"),
143                }
144            }
145        }
146
147        impl TryFrom<&Field> for PaginatedStructField {
148            type Error = &'static str;
149
150            fn try_from(field: &Field) -> core::result::Result<Self, Self::Error> {
151                let ident_opt = field.ident.clone();
152                let default_value = Attr::try_from(field.attrs.as_slice())?
153                    .default_value()
154                    .clone();
155
156                match is_option_i32(&field.ty) {
157                    true => Ok(PaginatedStructField {
158                        ident_opt,
159                        default_value,
160                    }),
161                    false => Err("not option i32"),
162                }
163            }
164        }
165
166        impl TryFrom<&Field> for Attr {
167            type Error = &'static str;
168
169            fn try_from(field: &Field) -> core::result::Result<Self, Self::Error> {
170                field.attrs.as_slice().try_into()
171            }
172        }
173
174        impl TryFrom<&[Attribute]> for Attr {
175            type Error = &'static str;
176
177            fn try_from(attrs: &[Attribute]) -> std::result::Result<Self, Self::Error> {
178                if attrs.len() != 1 {
179                    return Err("unexpected attributes");
180                }
181
182                (&attrs[0]).try_into()
183            }
184        }
185
186        impl TryFrom<&Attribute> for Attr {
187            type Error = &'static str;
188
189            fn try_from(attr: &Attribute) -> core::result::Result<Self, Self::Error> {
190                let lit = match attr.parse_meta() {
191                    Ok(Meta::List(MetaList { nested, .. })) if nested.len() == 1 => {
192                        match &nested[0] {
193                            NestedMeta::Meta(Meta::NameValue(MetaNameValue {
194                                lit: Lit::Int(lit),
195                                ..
196                            })) => lit.clone(),
197                            _ => return Err("unexpected attributes"),
198                        }
199                    }
200                    _ => return Err("unexpected attributes"),
201                };
202
203                match attr.path.get_ident() {
204                    Some(ident) if ident == "limit" => Ok(Attr::Limit(lit)),
205                    Some(ident) if ident == "offset" => Ok(Attr::Offset(lit)),
206                    _ => Err("unexpected attributes"),
207                }
208            }
209        }
210
211        fn is_option_i32(ty: &Type) -> bool {
212            match ty {
213                Type::Path(TypePath {
214                    path: Path { segments, .. },
215                    ..
216                }) if segments.len() == 1 => match &segments[0] {
217                    PathSegment {
218                        ident,
219                        arguments:
220                            PathArguments::AngleBracketed(AngleBracketedGenericArguments {
221                                args: generic_args,
222                                ..
223                            }),
224                    } if &ident.to_string() == "Option" && generic_args.len() == 1 => {
225                        match &generic_args[0] {
226                            GenericArgument::Type(Type::Path(TypePath { path, .. }))
227                                if path.is_ident("i32") =>
228                            {
229                                true
230                            }
231                            _ => false,
232                        }
233                    }
234                    _ => false,
235                },
236                _ => false,
237            }
238        }
239    }
240}