thirtyfour_querier_derive/
lib.rs1#![doc = include_str!("../README.md")]
2
3use std::{collections::HashSet, hash::Hash};
4
5use darling::{ast, FromDeriveInput, FromField};
6use quote::{format_ident, quote};
7use regex::Regex;
8use syn::{
9 parse_macro_input, DeriveInput, GenericArgument, Ident, Path, PathArguments, Type, TypePath,
10 TypeTuple,
11};
12
13#[derive(FromField)]
14#[darling(attributes(querier))]
15struct QuerierField {
16 ident: Option<Ident>,
17 #[allow(dead_code)]
18 ty: Type,
19 css: String,
20 #[darling(default)]
21 wait: Option<u64>,
22
23 #[darling(default)]
24 maybe: bool,
25 #[darling(default)]
26 all: bool,
27
28 #[darling(default)]
29 nested: bool,
30}
31
/// How many elements a field's query fetches, selected by the `maybe` / `all` flags.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RequiredNum {
    /// `#[querier(maybe)]`: fetched with `.first_opt()`.
    Maybe,
    /// Neither flag: fetched with `.single()`.
    One,
    /// `#[querier(all)]`: fetched with `.all()`.
    All,
}

impl RequiredNum {
    /// Builds the variant from the `maybe` and `all` attribute flags.
    ///
    /// # Panics
    ///
    /// Panics if both flags are set, since they request contradictory cardinalities. Inside
    /// the derive macro this panic is reported as a compile error at the derive site.
    fn new(maybe: bool, all: bool) -> Self {
        match (maybe, all) {
            (false, false) => Self::One,
            (true, false) => Self::Maybe,
            (false, true) => Self::All,
            (true, true) => panic!("#[querier(maybe)] and #[querier(all)] can't co-exist"),
        }
    }
}
48}
49
50#[derive(FromDeriveInput)]
51#[darling(supports(struct_named))]
52struct Querier {
53 ident: Ident,
54 data: ast::Data<darling::util::Ignored, QuerierField>,
55}
56
57fn unwrap_generic(ty: Type) -> Type {
58 let segment = match ty {
59 Type::Path(TypePath {
60 path: Path { segments, .. },
61 ..
62 }) => segments.last().unwrap().clone(),
63 _ => {
64 panic!("Expected Type<...> type in Querier");
65 }
66 };
67
68 let args = match segment.arguments {
69 PathArguments::AngleBracketed(args) => args.args,
70 _ => {
71 panic!("Expected Type<...> type in Querier");
72 }
73 };
74
75 assert_eq!(args.len(), 1, "Expected Type<...> type in Querier");
76
77 match &args[0] {
78 GenericArgument::Type(ty) => ty.clone(),
79 _ => panic!("Expected Type<...> type in Querier"),
80 }
81}
82
83fn unwrap_two_tuple(ty: Type) -> (Type, Type) {
84 let elems = match ty {
85 Type::Tuple(TypeTuple { elems, .. }) => elems,
86 _ => panic!("Expected (..., ...) tuple type"),
87 };
88 let elems = elems.into_iter().collect::<Vec<_>>();
89 assert_eq!(elems.len(), 2, "Expected (..., ...) tuple type");
90 let [t0, t1]: [Type; 2] = elems.try_into().unwrap();
91 (t0, t1)
92}
93
94#[proc_macro_derive(Querier, attributes(querier))]
95pub fn derive_querier_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
96 let input: DeriveInput = parse_macro_input!(input);
97 let querier: Querier = match Querier::from_derive_input(&input) {
98 Ok(q) => q,
99 Err(e) => {
100 return proc_macro::TokenStream::from(e.write_errors());
101 }
102 };
103
104 let fields = querier.data.take_struct().unwrap();
105
106 let mut field_names = vec![];
107 let mut query_async_blocks = vec![];
108 let re = Regex::new(r"\{(.*?)\}").unwrap();
109
110 let mut extra_args_names = vec![];
111
112 for qf in fields {
113 field_names.push(qf.ident.unwrap());
114
115 let wait_clause = if let Some(wait) = qf.wait {
117 quote! { .wait(::std::time::Duration::from_secs(#wait), ::std::time::Duration::from_millis(150)) }
118 } else {
119 quote! { .nowait() }
120 };
121
122 let required_num = RequiredNum::new(qf.maybe, qf.all);
123 let fetch_clause = match required_num {
124 RequiredNum::Maybe => quote! { .first_opt() },
125 RequiredNum::One => quote! { .single() },
126 RequiredNum::All => quote! { .all() },
127 };
128
129 let css = qf.css;
130
131 let arg_names: Vec<_> = re
134 .captures_iter(&css)
135 .map(|cap| cap.get(1).unwrap().as_str().to_string())
136 .collect();
137 extra_args_names.extend(arg_names.iter().cloned());
138 let arg_names: Vec<_> = arg_names
139 .into_iter()
140 .map(|name| format_ident!("{}", name))
141 .map(|name| quote! { #name = #name })
142 .collect();
143
144 let query_clause = quote! {
145 .query(::thirtyfour::By::Css(&format!(#css, #( #arg_names ,)*)))
146 };
147
148 let query_stmts = if qf.nested {
149 match required_num {
150 RequiredNum::Maybe => {
151 let (_, q_ty) = unwrap_two_tuple(unwrap_generic(qf.ty));
152 quote! {
153 let elem = driver
154 #query_clause
155 #wait_clause
156 #fetch_clause
157 .await?;
158 if let Some(elem) = elem {
159 let q = #q_ty::query(&elem).await?;
160 ::std::result::Result::<
161 ::std::option::Option<(::thirtyfour::WebElement, #q_ty)>,
162 ::thirtyfour::error::WebDriverError
163 >::Ok(Some((elem, q)))
164 } else {
165 ::std::result::Result::<
166 ::std::option::Option<(::thirtyfour::WebElement, #q_ty)>,
167 ::thirtyfour::error::WebDriverError
168 >::Ok(None)
169 }
170 }
171 }
172 RequiredNum::One => {
173 let (_, q_ty) = unwrap_two_tuple(qf.ty);
174 quote! {
175 let elem = driver
176 #query_clause
177 #wait_clause
178 #fetch_clause
179 .await?;
180 let sub_querier = #q_ty::query(&elem).await?;
181 ::std::result::Result::<(WebElement, #q_ty), ::thirtyfour::error::WebDriverError>::Ok((elem, sub_querier))
182 }
183 }
184 RequiredNum::All => {
185 let (_, q_ty) = unwrap_two_tuple(unwrap_generic(qf.ty));
186 quote! {
187 use ::thirtyfour::WebElement;
188 let elems = driver
189 #query_clause
190 #wait_clause
191 #fetch_clause
192 .await?;
193 let outputs =
194 ::futures::future::try_join_all(elems.into_iter().map(|elem| async move {
195 let sub_querier = #q_ty::query(&elem).await?;
196 ::std::result::Result::<
197 (WebElement, #q_ty), ::thirtyfour::error::WebDriverError
198 >::Ok((
199 elem,
200 sub_querier,
201 ))
202 }))
203 .await?;
204 ::std::result::Result::<
205 Vec<(WebElement, #q_ty)>, ::thirtyfour::error::WebDriverError
206 >::Ok(outputs)
207 }
208 }
209 }
210 } else {
211 quote! {
212 driver
213 #query_clause
214 #wait_clause
215 #fetch_clause
216 .await
217 }
218 };
219
220 query_async_blocks.push(quote! {
221 async {
222 use ::thirtyfour::prelude::ElementQueryable;
223 #query_stmts
224 }
225 });
226 }
227
228 let query_body = quote! {
229 let ( #(#field_names ,)* ) = ::futures::join!(
230 #(#query_async_blocks ,)*
231 );
232 let ( #(#field_names ,)* ) = ( #(#field_names ? ,)* );
233 Ok(Self { #(#field_names),* })
234 };
235
236 dedup(&mut extra_args_names);
237
238 let extra_args_typenames: Vec<_> = extra_args_names
240 .iter()
241 .map(|name| name.to_uppercase().to_string())
242 .map(|name| format_ident!("{}", name))
243 .collect();
244 let extra_args_names: Vec<_> = extra_args_names
245 .iter()
246 .map(|name| format_ident!("{}", name))
247 .collect();
248
249 let extra_args_typeargs: Vec<_> = extra_args_typenames
251 .iter()
252 .map(|typename| {
253 quote! { #typename: ::std::fmt::Display }
254 })
255 .collect();
256 let extra_args_args: Vec<_> = extra_args_typenames
258 .iter()
259 .zip(extra_args_names.iter())
260 .map(|(typename, name)| {
261 quote! { #name: #typename }
262 })
263 .collect();
264
265 let ident = querier.ident;
266 let output = quote! {
267 impl #ident {
268 pub async fn query<
269 T: ::thirtyfour::prelude::ElementQueryable,
270 #( #extra_args_typeargs ,)*
271 >(
272 driver: &T,
273 #( #extra_args_args ,)*
274 )
275 -> ::std::result::Result<Self, ::thirtyfour::error::WebDriverError> {
276 #query_body
277 }
278 }
279 };
280
281 output.into()
282}
283
/// Removes repeated elements from `v` in place.
///
/// Only the first occurrence of each value is kept, and the survivors stay in their
/// original relative order.
fn dedup<T: Eq + Hash + Clone>(v: &mut Vec<T>) {
    let mut seen: HashSet<T> = HashSet::new();
    v.retain(|item| {
        if seen.contains(item) {
            false
        } else {
            seen.insert(item.clone());
            true
        }
    });
}