thirtyfour_querier_derive/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{collections::HashSet, hash::Hash};
4
5use darling::{ast, FromDeriveInput, FromField};
6use quote::{format_ident, quote};
7use regex::Regex;
8use syn::{
9    parse_macro_input, DeriveInput, GenericArgument, Ident, Path, PathArguments, Type, TypePath,
10    TypeTuple,
11};
12
13#[derive(FromField)]
14#[darling(attributes(querier))]
15struct QuerierField {
16    ident: Option<Ident>,
17    #[allow(dead_code)]
18    ty: Type,
19    css: String,
20    #[darling(default)]
21    wait: Option<u64>,
22
23    #[darling(default)]
24    maybe: bool,
25    #[darling(default)]
26    all: bool,
27
28    #[darling(default)]
29    nested: bool,
30}
31
/// How many elements a field's query yields, selected by the `maybe` / `all` flags.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RequiredNum {
    /// `#[querier(maybe)]`: an optional element, fetched with `.first_opt()`.
    Maybe,
    /// Neither flag: a single element, fetched with `.single()`.
    One,
    /// `#[querier(all)]`: every matching element, fetched with `.all()`.
    All,
}

impl RequiredNum {
    /// Maps the `maybe` / `all` attribute flags to a [`RequiredNum`].
    ///
    /// # Panics
    ///
    /// Panics when both flags are set. Inside the derive this surfaces as a compile error
    /// that names the offending attributes.
    fn new(maybe: bool, all: bool) -> Self {
        match (maybe, all) {
            (false, false) => Self::One,
            (true, false) => Self::Maybe,
            (false, true) => Self::All,
            (true, true) => panic!("#[querier(maybe)] and #[querier(all)] can't co-exist"),
        }
    }
}
49
50#[derive(FromDeriveInput)]
51#[darling(supports(struct_named))]
52struct Querier {
53    ident: Ident,
54    data: ast::Data<darling::util::Ignored, QuerierField>,
55}
56
57fn unwrap_generic(ty: Type) -> Type {
58    let segment = match ty {
59        Type::Path(TypePath {
60            path: Path { segments, .. },
61            ..
62        }) => segments.last().unwrap().clone(),
63        _ => {
64            panic!("Expected Type<...> type in Querier");
65        }
66    };
67
68    let args = match segment.arguments {
69        PathArguments::AngleBracketed(args) => args.args,
70        _ => {
71            panic!("Expected Type<...> type in Querier");
72        }
73    };
74
75    assert_eq!(args.len(), 1, "Expected Type<...> type in Querier");
76
77    match &args[0] {
78        GenericArgument::Type(ty) => ty.clone(),
79        _ => panic!("Expected Type<...> type in Querier"),
80    }
81}
82
83fn unwrap_two_tuple(ty: Type) -> (Type, Type) {
84    let elems = match ty {
85        Type::Tuple(TypeTuple { elems, .. }) => elems,
86        _ => panic!("Expected (..., ...) tuple type"),
87    };
88    let elems = elems.into_iter().collect::<Vec<_>>();
89    assert_eq!(elems.len(), 2, "Expected (..., ...) tuple type");
90    let [t0, t1]: [Type; 2] = elems.try_into().unwrap();
91    (t0, t1)
92}
93
94#[proc_macro_derive(Querier, attributes(querier))]
95pub fn derive_querier_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
96    let input: DeriveInput = parse_macro_input!(input);
97    let querier: Querier = match Querier::from_derive_input(&input) {
98        Ok(q) => q,
99        Err(e) => {
100            return proc_macro::TokenStream::from(e.write_errors());
101        }
102    };
103
104    let fields = querier.data.take_struct().unwrap();
105
106    let mut field_names = vec![];
107    let mut query_async_blocks = vec![];
108    let re = Regex::new(r"\{(.*?)\}").unwrap();
109
110    let mut extra_args_names = vec![];
111
112    for qf in fields {
113        field_names.push(qf.ident.unwrap());
114
115        // Construct the query async block
116        let wait_clause = if let Some(wait) = qf.wait {
117            quote! { .wait(::std::time::Duration::from_secs(#wait), ::std::time::Duration::from_millis(150)) }
118        } else {
119            quote! { .nowait() }
120        };
121
122        let required_num = RequiredNum::new(qf.maybe, qf.all);
123        let fetch_clause = match required_num {
124            RequiredNum::Maybe => quote! { .first_opt() },
125            RequiredNum::One => quote! { .single() },
126            RequiredNum::All => quote! { .all() },
127        };
128
129        let css = qf.css;
130
131        // Parse all extra arguments and add to the set
132
133        let arg_names: Vec<_> = re
134            .captures_iter(&css)
135            .map(|cap| cap.get(1).unwrap().as_str().to_string())
136            .collect();
137        extra_args_names.extend(arg_names.iter().cloned());
138        let arg_names: Vec<_> = arg_names
139            .into_iter()
140            .map(|name| format_ident!("{}", name))
141            .map(|name| quote! { #name = #name })
142            .collect();
143
144        let query_clause = quote! {
145            .query(::thirtyfour::By::Css(&format!(#css, #( #arg_names ,)*)))
146        };
147
148        let query_stmts = if qf.nested {
149            match required_num {
150                RequiredNum::Maybe => {
151                    let (_, q_ty) = unwrap_two_tuple(unwrap_generic(qf.ty));
152                    quote! {
153                        let elem = driver
154                            #query_clause
155                            #wait_clause
156                            #fetch_clause
157                            .await?;
158                        if let Some(elem) = elem {
159                            let q = #q_ty::query(&elem).await?;
160                            ::std::result::Result::<
161                                ::std::option::Option<(::thirtyfour::WebElement, #q_ty)>,
162                                ::thirtyfour::error::WebDriverError
163                            >::Ok(Some((elem, q)))
164                        } else {
165                            ::std::result::Result::<
166                                ::std::option::Option<(::thirtyfour::WebElement, #q_ty)>,
167                                ::thirtyfour::error::WebDriverError
168                            >::Ok(None)
169                        }
170                    }
171                }
172                RequiredNum::One => {
173                    let (_, q_ty) = unwrap_two_tuple(qf.ty);
174                    quote! {
175                        let elem = driver
176                            #query_clause
177                            #wait_clause
178                            #fetch_clause
179                            .await?;
180                        let sub_querier = #q_ty::query(&elem).await?;
181                        ::std::result::Result::<(WebElement, #q_ty), ::thirtyfour::error::WebDriverError>::Ok((elem, sub_querier))
182                    }
183                }
184                RequiredNum::All => {
185                    let (_, q_ty) = unwrap_two_tuple(unwrap_generic(qf.ty));
186                    quote! {
187                        use ::thirtyfour::WebElement;
188                        let elems = driver
189                            #query_clause
190                            #wait_clause
191                            #fetch_clause
192                            .await?;
193                        let outputs =
194                            ::futures::future::try_join_all(elems.into_iter().map(|elem| async move {
195                                let sub_querier = #q_ty::query(&elem).await?;
196                                ::std::result::Result::<
197                                    (WebElement, #q_ty), ::thirtyfour::error::WebDriverError
198                                >::Ok((
199                                    elem,
200                                    sub_querier,
201                                ))
202                            }))
203                            .await?;
204                        ::std::result::Result::<
205                            Vec<(WebElement, #q_ty)>, ::thirtyfour::error::WebDriverError
206                        >::Ok(outputs)
207                    }
208                }
209            }
210        } else {
211            quote! {
212                driver
213                    #query_clause
214                    #wait_clause
215                    #fetch_clause
216                    .await
217            }
218        };
219
220        query_async_blocks.push(quote! {
221            async {
222                use ::thirtyfour::prelude::ElementQueryable;
223                #query_stmts
224            }
225        });
226    }
227
228    let query_body = quote! {
229        let ( #(#field_names ,)* ) = ::futures::join!(
230            #(#query_async_blocks ,)*
231        );
232        let ( #(#field_names ,)* ) = ( #(#field_names ? ,)* );
233        Ok(Self { #(#field_names),* })
234    };
235
236    dedup(&mut extra_args_names);
237
238    // Transform String into Ident (and uppercase the typenames)
239    let extra_args_typenames: Vec<_> = extra_args_names
240        .iter()
241        .map(|name| name.to_uppercase().to_string())
242        .map(|name| format_ident!("{}", name))
243        .collect();
244    let extra_args_names: Vec<_> = extra_args_names
245        .iter()
246        .map(|name| format_ident!("{}", name))
247        .collect();
248
249    // Like "ID: Display, NAME: Display"
250    let extra_args_typeargs: Vec<_> = extra_args_typenames
251        .iter()
252        .map(|typename| {
253            quote! { #typename: ::std::fmt::Display }
254        })
255        .collect();
256    // Like "id: ID, name: NAME"
257    let extra_args_args: Vec<_> = extra_args_typenames
258        .iter()
259        .zip(extra_args_names.iter())
260        .map(|(typename, name)| {
261            quote! { #name: #typename }
262        })
263        .collect();
264
265    let ident = querier.ident;
266    let output = quote! {
267        impl #ident {
268            pub async fn query<
269                T: ::thirtyfour::prelude::ElementQueryable,
270                #( #extra_args_typeargs ,)*
271            >(
272                driver: &T,
273                #( #extra_args_args ,)*
274            )
275                -> ::std::result::Result<Self, ::thirtyfour::error::WebDriverError> {
276                #query_body
277            }
278        }
279    };
280
281    output.into()
282}
283
/// Removes repeated values from `v` in place. The first occurrence of each value is kept,
/// and the survivors stay in their original relative order.
fn dedup<T: Eq + Hash + Clone>(v: &mut Vec<T>) {
    let mut seen: HashSet<T> = HashSet::new();
    v.retain(|item| {
        let first_time = !seen.contains(item);
        if first_time {
            seen.insert(item.clone());
        }
        first_time
    });
}