avantis_utils_derive/
lib.rs1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use syn::parse_macro_input;
5use syn::DeriveInput;
6
7use self::models::PaginatedStruct;
8
9#[proc_macro_derive(PaginatedQuery, attributes(limit, offset))]
10pub fn paginated_query_macro_derive(input: TokenStream) -> TokenStream {
11 let syntax_tree = parse_macro_input!(input as DeriveInput);
12
13 let model = PaginatedStruct::try_from(&syntax_tree).unwrap();
14
15 model.gen().into()
16}
17
18mod models {
19 use proc_macro2::Span;
20 use proc_macro2::TokenStream;
21 use quote::quote;
22 use syn::*;
23
24 #[derive(Clone, Debug)]
25 pub(super) struct PaginatedStruct {
26 name: Ident,
27 limit: PaginatedStructField,
28 offset: PaginatedStructField,
29 }
30
31 impl PaginatedStruct {
32 pub(super) fn gen(&self) -> TokenStream {
33 let name = &self.name;
34 let limit_fn = &self.limit.gen("limit");
35 let offset_fn = &self.offset.gen("offset");
36
37 quote! {
38 impl PaginatedQuery for #name {
39 #limit_fn
40
41 #offset_fn
42 }
43 }
44 .into()
45 }
46 }
47
48 #[derive(Clone, Debug)]
49 struct PaginatedStructField {
50 ident_opt: Option<Ident>,
51 default_value: LitInt,
52 }
53
54 impl PaginatedStructField {
55 fn gen(&self, fn_name: &'static str) -> TokenStream {
56 let default_value_lit = &self.default_value;
57
58 let impl_quote = match self.ident_opt.as_ref() {
59 Some(ident) => quote! { self.#ident.unwrap_or(#default_value_lit) },
60 None => quote! { #default_value_lit },
61 };
62
63 let fn_name = Ident::new(fn_name, Span::call_site());
64
65 quote! {
66 fn #fn_name(&self) -> i32 {
67 #impl_quote
68 }
69 }
70 }
71
72 fn limit_field<T>(
73 fields: &punctuated::Punctuated<syn::Field, T>,
74 ) -> core::result::Result<Self, &'static str> {
75 let matched_fields = fields
76 .iter()
77 .filter(|f| matches!(Attr::try_from(*f), Ok(Attr::Limit(_))))
78 .filter_map(|f| PaginatedStructField::try_from(f).ok())
79 .collect::<Vec<_>>();
80
81 if matched_fields.len() > 1 {
82 return Err("too many attributes");
83 }
84
85 Ok(matched_fields
86 .first()
87 .ok_or_else(|| "field not found")?
88 .clone())
89 }
90
91 fn offset_field<T>(
92 fields: &punctuated::Punctuated<syn::Field, T>,
93 ) -> core::result::Result<Self, &'static str> {
94 let matched_fields = fields
95 .iter()
96 .filter(|f| matches!(Attr::try_from(*f), Ok(Attr::Offset(_))))
97 .filter_map(|f| PaginatedStructField::try_from(f).ok())
98 .collect::<Vec<_>>();
99
100 if matched_fields.len() > 1 {
101 return Err("too many attributes");
102 }
103
104 Ok(matched_fields
105 .first()
106 .ok_or_else(|| "field not found")?
107 .clone())
108 }
109 }
110
111 #[derive(Clone, Debug)]
112 enum Attr {
113 Limit(LitInt),
114 Offset(LitInt),
115 }
116
117 impl Attr {
118 fn default_value(&self) -> &LitInt {
119 match self {
120 Attr::Limit(default) => default,
121 Attr::Offset(default) => default,
122 }
123 }
124 }
125
126 pub(super) mod extractors {
127 use super::*;
128
129 impl TryFrom<&DeriveInput> for PaginatedStruct {
130 type Error = &'static str;
131
132 fn try_from(input: &DeriveInput) -> core::result::Result<Self, Self::Error> {
133 match input.data {
134 syn::Data::Struct(syn::DataStruct {
135 fields: syn::Fields::Named(FieldsNamed { ref named, .. }),
136 ..
137 }) => Ok(PaginatedStruct {
138 name: input.ident.clone(),
139 limit: PaginatedStructField::limit_field(&named)?,
140 offset: PaginatedStructField::offset_field(&named)?,
141 }),
142 _ => Err("help!"),
143 }
144 }
145 }
146
147 impl TryFrom<&Field> for PaginatedStructField {
148 type Error = &'static str;
149
150 fn try_from(field: &Field) -> core::result::Result<Self, Self::Error> {
151 let ident_opt = field.ident.clone();
152 let default_value = Attr::try_from(field.attrs.as_slice())?
153 .default_value()
154 .clone();
155
156 match is_option_i32(&field.ty) {
157 true => Ok(PaginatedStructField {
158 ident_opt,
159 default_value,
160 }),
161 false => Err("not option i32"),
162 }
163 }
164 }
165
166 impl TryFrom<&Field> for Attr {
167 type Error = &'static str;
168
169 fn try_from(field: &Field) -> core::result::Result<Self, Self::Error> {
170 field.attrs.as_slice().try_into()
171 }
172 }
173
174 impl TryFrom<&[Attribute]> for Attr {
175 type Error = &'static str;
176
177 fn try_from(attrs: &[Attribute]) -> std::result::Result<Self, Self::Error> {
178 if attrs.len() != 1 {
179 return Err("unexpected attributes");
180 }
181
182 (&attrs[0]).try_into()
183 }
184 }
185
186 impl TryFrom<&Attribute> for Attr {
187 type Error = &'static str;
188
189 fn try_from(attr: &Attribute) -> core::result::Result<Self, Self::Error> {
190 let lit = match attr.parse_meta() {
191 Ok(Meta::List(MetaList { nested, .. })) if nested.len() == 1 => {
192 match &nested[0] {
193 NestedMeta::Meta(Meta::NameValue(MetaNameValue {
194 lit: Lit::Int(lit),
195 ..
196 })) => lit.clone(),
197 _ => return Err("unexpected attributes"),
198 }
199 }
200 _ => return Err("unexpected attributes"),
201 };
202
203 match attr.path.get_ident() {
204 Some(ident) if ident == "limit" => Ok(Attr::Limit(lit)),
205 Some(ident) if ident == "offset" => Ok(Attr::Offset(lit)),
206 _ => Err("unexpected attributes"),
207 }
208 }
209 }
210
211 fn is_option_i32(ty: &Type) -> bool {
212 match ty {
213 Type::Path(TypePath {
214 path: Path { segments, .. },
215 ..
216 }) if segments.len() == 1 => match &segments[0] {
217 PathSegment {
218 ident,
219 arguments:
220 PathArguments::AngleBracketed(AngleBracketedGenericArguments {
221 args: generic_args,
222 ..
223 }),
224 } if &ident.to_string() == "Option" && generic_args.len() == 1 => {
225 match &generic_args[0] {
226 GenericArgument::Type(Type::Path(TypePath { path, .. }))
227 if path.is_ident("i32") =>
228 {
229 true
230 }
231 _ => false,
232 }
233 }
234 _ => false,
235 },
236 _ => false,
237 }
238 }
239 }
240}