query_params_macro/
lib.rs

//! QueryParams is a procedural macro that derives a [`Hyper`]-centric representation of a struct
//! as a list of query parameters, which can then easily be appended to a request in the Hyper
//! framework. *This crate is only meant to be tested and re-exported by the `QueryParams` crate,
//! and is not meant for direct consumption.*
5//!
6//! [`Hyper`]: https://crates.io/crates/hyper
7use proc_macro::{self, TokenStream};
8use quote::quote;
9use std::collections::HashSet;
10use std::vec::Vec;
11use syn::__private::TokenStream2;
12use syn::{Attribute, DeriveInput, Field, Fields, Ident, LitStr, Path, Type, parse_macro_input};
13
/// One option parsed out of a field's `#[query(...)]` attribute.
///
/// A field may carry several of these; they are gathered into a `HashSet`, so identical
/// options given more than once collapse into one.
#[derive(Debug, Hash, PartialEq, Eq)]
enum FieldAttributes {
    /// `#[query(required)]` — the field is not an `Option` and is always emitted.
    Required,
    /// `#[query(exclude)]` — the field never appears in the output.
    Excluded,
    /// `#[query(rename = "...")]` — the field is emitted under the given name.
    Rename(String),
}
20
21struct FieldDescription<'f> {
22    pub field: &'f Field,
23    pub field_name: String,
24    pub ident: Ident,
25    pub attributes: HashSet<FieldAttributes>,
26}
27
28/// [`QueryParams`] derives `fn to_query_params(&self) -> Vec<(String, String)>` for
29/// any struct with field values supporting `.to_string()`.
30///
31/// Optional values are only included if present,
32/// and fields marked `#[query(required)]` must be non-optional. Renaming and excluding of fields is
33/// also available, using `#[query(rename = "new_name")]` or `#[query(exclude)]` on the field.
34///
35/// # Example: Query Params
36/// QueryParams supports both required and optional fields, which won't be included in the output
37/// if their value is None.
38///
39/// ```
40/// # use query_params_macro::QueryParams;
41/// # // trait defined here again since it can't be provided by macro crate
42/// # pub trait ToQueryParams {
43/// #    fn to_query_params(&self) -> Vec<(String, String)>;
44/// # }
45/// // Eq and PartialEq are just for assertions
46/// #[derive(QueryParams, Debug, PartialEq, Eq)]
47/// struct ProductRequest {
48///     #[query(required)]
49///     id: i32,
50///     min_price: Option<i32>,
51///     max_price: Option<i32>,
52/// }
53///
54/// pub fn main() {
55///     let request = ProductRequest {
56///         id: 999, // will be included in output
57///         min_price: None, // will *not* be included in output
58///         max_price: Some(100), // will be included in output
59///     };
60///
61///     let expected = vec![
62///         ("id".into(), "999".into()),
63///         ("max_price".into(), "100".into())
64///     ];
65///     
66///     let query_params = request.to_query_params();
67///
68///     assert_eq!(expected, query_params);
69/// }
70/// ```
71///
72/// ## Attributes
73/// QueryParams supports attributes under `#[query(...)]` on individual fields to carry metadata.
74/// At this time, the available attributes are:
75/// - required -- marks a field as required, meaning it can be `T` instead of `Option<T>` on the struct and will always appear in the resulting `Vec`
76/// - rename -- marks a field to be renamed when it is output in the resulting Vec, e.g. `#[query(rename = "newName")]`
77/// - exclude -- marks a field to never be included in the output query params
78///
79/// # Example: Renaming and Excluding
80/// In some cases, names of query parameters are not valid identifiers, or don't adhere to Rust's
81/// default style of "snake_case". [`QueryParams`] can rename individual fields when creating the
82/// query parameters Vec if the attribute with the rename attribute: `#[query(rename = "new_name")]`.
83///
84/// In the below example, an API expects a type of product and a max price, given as
85/// `type=something&maxPrice=123`, which would be an invalid identifier and a non-Rust style
86/// field name respectively. A field containing local data that won't be included in the query
87/// is also tagged as `#[query(exclude)]` to exclude it.
88///
89/// ```
90/// # use query_params_macro::QueryParams;
91/// # // trait defined here again since it can't be provided by macro crate
92/// # pub trait ToQueryParams {
93/// #    fn to_query_params(&self) -> Vec<(String, String)>;
94/// # }
95/// // Eq and PartialEq are just for assertions
96/// #[derive(QueryParams, Debug, PartialEq, Eq)]
97/// struct ProductRequest {
98///     #[query(required)]
99///     id: i32,
100///     #[query(rename = "type")]
101///     product_type: Option<String>,
102///     #[query(rename = "maxPrice")]
103///     max_price: Option<i32>,
104///     #[query(exclude)]
105///     private_data: i32,
106/// }
107///
108/// pub fn main() {
109///     let request = ProductRequest {
110///         id: 999,
111///         product_type: Some("accessory".into()),
112///         max_price: Some(100),
113///         private_data: 42, // will not be part of the output
114///     };
115///
116///     let expected = vec![
117///         ("id".into(), "999".into()),
118///         ("type".into(), "accessory".into()),
119///         ("maxPrice".into(), "100".into())
120///     ];
121///     
122///     let query_params = request.to_query_params();
123///
124///     assert_eq!(expected, query_params);
125/// }
126/// ```
127#[proc_macro_derive(QueryParams, attributes(query))]
128pub fn derive(input: TokenStream) -> TokenStream {
129    let ast: DeriveInput = parse_macro_input!(input);
130    let ident = ast.ident;
131
132    let fields: &Fields = match ast.data {
133        syn::Data::Struct(ref s) => &s.fields,
134        _ => panic!("Can only derive QueryParams for structs."),
135    };
136
137    let named_fields: Vec<&Field> = fields
138        .iter()
139        .filter_map(|field| field.ident.as_ref().map(|_ident| field))
140        .collect();
141
142    let field_descriptions = named_fields
143        .into_iter()
144        .map(map_field_to_description)
145        .filter(|field| !field.attributes.contains(&FieldAttributes::Excluded))
146        .collect::<Vec<FieldDescription>>();
147
148    let required_fields: Vec<&FieldDescription> = field_descriptions
149        .iter()
150        .filter(|desc| desc.attributes.contains(&FieldAttributes::Required))
151        .collect();
152
153    let req_names: Vec<String> = required_fields
154        .iter()
155        .map(|field| field.field_name.clone())
156        .collect();
157
158    let req_idents: Vec<&Ident> = required_fields.iter().map(|field| &field.ident).collect();
159
160    let vec_definition = quote! {
161        let mut query_params: ::std::vec::Vec<(String, String)> =
162        vec![#((
163            #req_names.to_string(),
164            self.#req_idents.to_string()
165        )),*];
166    };
167
168    let optional_fields: Vec<&FieldDescription> = field_descriptions
169        .iter()
170        .filter(|desc| !desc.attributes.contains(&FieldAttributes::Required))
171        .collect();
172
173    optional_fields.iter().for_each(validate_optional_field);
174
175    let optional_assignments: TokenStream2 = optional_fields
176        .iter()
177        .map(|field| {
178            let ident = &field.ident;
179            let name = &field.field_name;
180            quote! {
181                if let Some(val) = &self.#ident {
182                    query_params.push((
183                        #name.to_string(),
184                        val.to_string()
185                    ));
186                }
187            }
188        })
189        .collect();
190
191    let trait_impl = quote! {
192        #[allow(dead_code)]
193        impl to_query_params::ToQueryParams for #ident {
194            fn to_query_params(&self) -> ::std::vec::Vec<(String, String)> {
195                #vec_definition
196                #optional_assignments
197                query_params
198            }
199        }
200    };
201
202    trait_impl.into()
203}
204
205fn map_field_to_description<'f>(field: &'f Field) -> FieldDescription<'f> {
206    let attributes = field
207        .attrs
208        .iter()
209        .flat_map(parse_query_attributes)
210        .collect::<HashSet<FieldAttributes>>();
211
212    let mut desc = FieldDescription {
213        field,
214        field_name: field.ident.as_ref().unwrap().to_string(),
215        ident: field.ident.clone().unwrap(),
216        attributes,
217    };
218
219    let name = name_from_field_description(&desc);
220    desc.field_name = name;
221    desc
222}
223
224fn name_from_field_description(field: &FieldDescription) -> String {
225    let mut name = field.ident.to_string();
226    for attribute in field.attributes.iter() {
227        if let FieldAttributes::Rename(rename) = attribute {
228            name = (*rename).clone();
229        }
230    }
231
232    name
233}
234
235fn parse_query_attributes(attr: &Attribute) -> Vec<FieldAttributes> {
236    let mut attrs = Vec::new();
237
238    if attr.path().is_ident("query") {
239        attr.parse_nested_meta(|m| {
240            if m.path.is_ident("required") {
241                attrs.push(FieldAttributes::Required);
242            }
243
244            if m.path.is_ident("exclude") {
245                attrs.push(FieldAttributes::Excluded);
246            }
247
248            if m.path.is_ident("rename") {
249                let value = m.value().unwrap();
250                let rename: LitStr = value.parse().unwrap();
251
252                attrs.push(FieldAttributes::Rename(rename.value()));
253            }
254
255            Ok(())
256        })
257        .expect("Unsupported attribute found in #[query(...)] attribute");
258    }
259
260    attrs
261}
262
263fn validate_optional_field(field_desc: &&FieldDescription) {
264    if let Type::Path(type_path) = &field_desc.field.ty
265        && !(type_path.qself.is_none() && path_is_option(&type_path.path))
266    {
267        panic!("Non-optional types must be marked with #[query(required)] attribute")
268    }
269}
270
271fn path_is_option(path: &Path) -> bool {
272    path.leading_colon.is_none()
273        && path.segments.len() == 1
274        && path.segments.iter().next().unwrap().ident == "Option"
275}