query_params_macro/
lib.rs

//! `QueryParams` is a procedural macro that derives a [`Hyper`]-centric representation of a
//! struct as query parameters, which can then easily be appended to query parameters in the Hyper
//! framework. *This crate is only meant to be tested and re-exported by the `QueryParams` crate,
//! and is not meant for direct consumption.*
5//!
6//! [`Hyper`]: https://crates.io/crates/hyper
7use proc_macro::{self, TokenStream};
8use quote::quote;
9use std::collections::HashSet;
10use std::vec::Vec;
11use syn::__private::TokenStream2;
12use syn::{parse_macro_input, Attribute, DeriveInput, Field, Fields, Ident, LitStr, Path, Type};
13
/// Metadata that a single `#[query(...)]` entry can attach to a struct field.
///
/// Values are stored in a `HashSet` per field, hence the `Eq` and `Hash` derives.
#[derive(Debug, PartialEq, Eq, Hash)]
enum FieldAttributes {
    /// `#[query(required)]`: the field is always emitted and is stringified directly.
    Required,
    /// `#[query(exclude)]`: the field never appears in the generated output.
    Excluded,
    /// `#[query(rename = "...")]`: the contained string replaces the field's own name as the key.
    Rename(String),
}
20
21struct FieldDescription<'f> {
22    pub field: &'f Field,
23    pub field_name: String,
24    pub ident: Ident,
25    pub attributes: HashSet<FieldAttributes>,
26}
27
/// [`QueryParams`] derives `fn to_query_params(&self) -> Vec<(String, String)>` and its
/// URL-encoded counterpart `fn to_encoded_params(&self) -> Vec<(String, String)>` for
/// any struct with field values supporting `.to_string()`.
30///
31/// Optional values are only included if present,
32/// and fields marked `#[query(required)]` must be non-optional. Renaming and excluding of fields is
33/// also available, using `#[query(rename = "new_name")]` or `#[query(exclude)]` on the field.
34///
35/// # Example: Query Params
/// QueryParams supports both required and optional fields. Optional fields won't be included in
/// the output if their value is `None`.
38///
39/// ```
40/// # use query_params_macro::QueryParams;
41/// # // trait defined here again since it can't be provided by macro crate
/// # pub trait ToQueryParams {
/// #    fn to_query_params(&self) -> Vec<(String, String)>;
/// #    fn to_encoded_params(&self) -> Vec<(String, String)>;
/// # }
45/// // Eq and PartialEq are just for assertions
46/// #[derive(QueryParams, Debug, PartialEq, Eq)]
47/// struct ProductRequest {
48///     #[query(required)]
49///     id: i32,
50///     min_price: Option<i32>,
51///     max_price: Option<i32>,
52/// }
53///
54/// pub fn main() {
55///     let request = ProductRequest {
56///         id: 999, // will be included in output
57///         min_price: None, // will *not* be included in output
58///         max_price: Some(100), // will be included in output
59///     };
60///
61///     let expected = vec![
62///         ("id".into(), "999".into()),
63///         ("max_price".into(), "100".into())
64///     ];
65///     
66///     let query_params = request.to_query_params();
67///
68///     assert_eq!(expected, query_params);
69/// }
70/// ```
71///
72/// ## Attributes
73/// QueryParams supports attributes under `#[query(...)]` on individual fields to carry metadata.
74/// At this time, the available attributes are:
75/// - required -- marks a field as required, meaning it can be `T` instead of `Option<T>` on the struct
76/// and will always appear in the resulting `Vec`
77/// - rename -- marks a field to be renamed when it is output in the resulting Vec.
78/// E.g. `#[query(rename = "newName")]`
79/// - exclude -- marks a field to never be included in the output query params
80///
81/// # Example: Renaming and Excluding
82/// In some cases, names of query parameters are not valid identifiers, or don't adhere to Rust's
/// default style of "snake_case". [`QueryParams`] can rename individual fields when creating the
/// query parameters Vec if the field is tagged with the rename attribute:
/// `#[query(rename = "new_name")]`.
85///
86/// In the below example, an API expects a type of product and a max price, given as
/// `type=something&maxPrice=123`, which would be an invalid identifier and a non-Rust style
88/// field name respectively. A field containing local data that won't be included in the query
89/// is also tagged as `#[query(exclude)]` to exclude it.
90///
91/// ```
92/// # use query_params_macro::QueryParams;
93/// # // trait defined here again since it can't be provided by macro crate
/// # pub trait ToQueryParams {
/// #    fn to_query_params(&self) -> Vec<(String, String)>;
/// #    fn to_encoded_params(&self) -> Vec<(String, String)>;
/// # }
97/// // Eq and PartialEq are just for assertions
98/// #[derive(QueryParams, Debug, PartialEq, Eq)]
99/// struct ProductRequest {
100///     #[query(required)]
101///     id: i32,
102///     #[query(rename = "type")]
103///     product_type: Option<String>,
104///     #[query(rename = "maxPrice")]
105///     max_price: Option<i32>,
106///     #[query(exclude)]
107///     private_data: i32,
108/// }
109///
110/// pub fn main() {
111///     let request = ProductRequest {
112///         id: 999,
113///         product_type: Some("accessory".into()),
114///         max_price: Some(100),
115///         private_data: 42, // will not be part of the output
116///     };
117///
118///     let expected = vec![
119///         ("id".into(), "999".into()),
120///         ("type".into(), "accessory".into()),
121///         ("maxPrice".into(), "100".into())
122///     ];
123///     
124///     let query_params = request.to_query_params();
125///
126///     assert_eq!(expected, query_params);
127/// }
128/// ```
129#[proc_macro_derive(QueryParams, attributes(query))]
130pub fn derive(input: TokenStream) -> TokenStream {
131    let ast: DeriveInput = parse_macro_input!(input);
132    let ident = ast.ident;
133
134    let fields: &Fields = match ast.data {
135        syn::Data::Struct(ref s) => &s.fields,
136        _ => panic!("Can only derive QueryParams for structs."),
137    };
138
139    let named_fields: Vec<&Field> = fields
140        .iter()
141        .filter_map(|field| field.ident.as_ref().map(|_ident| field))
142        .collect();
143
144    let field_descriptions = named_fields
145        .into_iter()
146        .map(map_field_to_description)
147        .filter(|field| !field.attributes.contains(&FieldAttributes::Excluded))
148        .collect::<Vec<FieldDescription>>();
149
150    let required_fields: Vec<&FieldDescription> = field_descriptions
151        .iter()
152        .filter(|desc| desc.attributes.contains(&FieldAttributes::Required))
153        .collect();
154
155    let req_names: Vec<String> = required_fields
156        .iter()
157        .map(|field| field.field_name.clone())
158        .collect();
159
160    let req_idents: Vec<&Ident> = required_fields.iter().map(|field| &field.ident).collect();
161
162    let vec_definition = quote! {
163        let mut query_params: ::std::vec::Vec<(String, String)> =
164        vec![#((
165            #req_names.to_string(),
166            self.#req_idents.to_string()
167        )),*];
168    };
169
170    let vec_encoded_definition = quote! {
171        let mut query_params: ::std::vec::Vec<(String, String)> =
172        vec![#(
173            (
174                ::to_query_params::urlencoding::encode(#req_names).into_owned(),
175                ::to_query_params::urlencoding::encode(&self.#req_idents.to_string()).into_owned()
176            )
177        ),*];
178    };
179
180    let optional_fields: Vec<&FieldDescription> = field_descriptions
181        .iter()
182        .filter(|desc| !desc.attributes.contains(&FieldAttributes::Required))
183        .collect();
184
185    optional_fields.iter().for_each(validate_optional_field);
186
187    let optional_assignments: TokenStream2 = optional_fields
188        .iter()
189        .map(|field| {
190            let ident = &field.ident;
191            let name = &field.field_name;
192            quote! {
193                if let Some(val) = &self.#ident {
194                    query_params.push((
195                        #name.to_string(),
196                        val.to_string()
197                    ));
198                }
199            }
200        })
201        .collect();
202
203    let optional_encoded_assignments: TokenStream2 = optional_fields
204        .iter()
205        .map(|field| {
206            let ident = &field.ident;
207            let name = &field.field_name;
208            quote! {
209                if let Some(val) = &self.#ident {
210                    query_params.push(
211                        (
212                            ::to_query_params::urlencoding::encode(#name).into_owned(),
213                            ::to_query_params::urlencoding::encode(&val.to_string()).into_owned()
214                        )
215                    );
216                }
217            }
218        })
219        .collect();
220
221    let trait_impl = quote! {
222        #[allow(dead_code)]
223        impl ToQueryParams for #ident {
224            fn to_query_params(&self) -> ::std::vec::Vec<(String, String)> {
225                #vec_definition
226                #optional_assignments
227                query_params
228            }
229
230            fn to_encoded_params(&self) -> ::std::vec::Vec<(String, String)> {
231                #vec_encoded_definition
232                #optional_encoded_assignments
233                query_params
234            }
235        }
236    };
237
238    trait_impl.into()
239}
240
241fn map_field_to_description(field: &Field) -> FieldDescription {
242    let attributes = field
243        .attrs
244        .iter()
245        .flat_map(parse_query_attributes)
246        .collect::<HashSet<FieldAttributes>>();
247
248    let mut desc = FieldDescription {
249        field,
250        field_name: field.ident.as_ref().unwrap().to_string(),
251        ident: field.ident.clone().unwrap(),
252        attributes,
253    };
254
255    let name = name_from_field_description(&desc);
256    desc.field_name = name;
257    desc
258}
259
260fn name_from_field_description(field: &FieldDescription) -> String {
261    let mut name = field.ident.to_string();
262    for attribute in field.attributes.iter() {
263        if let FieldAttributes::Rename(rename) = attribute {
264            name = (*rename).clone();
265        }
266    }
267
268    name
269}
270
271fn parse_query_attributes(attr: &Attribute) -> Vec<FieldAttributes> {
272    let mut attrs = Vec::new();
273
274    if attr.path().is_ident("query") {
275        attr.parse_nested_meta(|m| {
276            if m.path.is_ident("required") {
277                attrs.push(FieldAttributes::Required);
278            }
279
280            if m.path.is_ident("exclude") {
281                attrs.push(FieldAttributes::Excluded);
282            }
283
284            if m.path.is_ident("rename") {
285                let value = m.value().unwrap();
286                let rename: LitStr = value.parse().unwrap();
287
288                attrs.push(FieldAttributes::Rename(rename.value()));
289            }
290
291            Ok(())
292        })
293        .expect("Unsupported attribute found in #[query(...)] attribute");
294    }
295
296    attrs
297}
298
299fn validate_optional_field(field_desc: &&FieldDescription) {
300    if let Type::Path(type_path) = &field_desc.field.ty {
301        if !(type_path.qself.is_none() && path_is_option(&type_path.path)) {
302            panic!("Non-optional types must be marked with #[query(required)] attribute")
303        }
304    }
305}
306
307fn path_is_option(path: &Path) -> bool {
308    path.leading_colon.is_none()
309        && path.segments.len() == 1
310        && path.segments.iter().next().unwrap().ident == "Option"
311}