Skip to main content

ploidy_codegen_rust/
query.rs

1use itertools::Itertools;
2use ploidy_core::{
3    codegen::UniqueNames,
4    ir::{OperationView, ParameterStyle, ParameterView, QueryParameter, View},
5};
6use proc_macro2::TokenStream;
7use quote::{ToTokens, TokenStreamExt, format_ident, quote};
8
9use super::{
10    derives::ExtraDerive,
11    ext::ParameterViewExt,
12    naming::{CodegenIdent, CodegenIdentScope, CodegenIdentUsage},
13    ref_::CodegenRef,
14};
15
16/// Generates a query parameter struct for an API operation.
17///
18/// The generated struct is named `{OperationId}Query`.
19/// It bundles all query parameters for that operation,
20/// derives `Serialize`, and has an associated `STYLES` table
21/// with per-parameter serialization style overrides.
22#[derive(Debug)]
23pub struct CodegenQueryParameters<'a> {
24    op: &'a OperationView<'a>,
25}
26
27impl<'a> CodegenQueryParameters<'a> {
28    /// Creates a new query parameter struct for the given operation.
29    #[inline]
30    pub fn new(op: &'a OperationView<'a>) -> Self {
31        Self { op }
32    }
33}
34
35impl ToTokens for CodegenQueryParameters<'_> {
36    fn to_tokens(&self, tokens: &mut TokenStream) {
37        let op_ident = CodegenIdent::new(self.op.id());
38        let query_name = format_ident!("{}Query", CodegenIdentUsage::Type(&op_ident));
39
40        let mut extra_derives = vec![];
41
42        // Derive `Eq` and `Hash` if all parameter types, and their
43        // transitively referenced types, are hashable.
44        if self.op.query().all(|param| param.hashable()) {
45            extra_derives.push(ExtraDerive::Eq);
46            extra_derives.push(ExtraDerive::Hash);
47        }
48
49        // Derive `Default` if all required parameters, and their
50        // transitively referenced types, are defaultable.
51        // Optional parameters become `Option<T>`, which is `Default`.
52        if self
53            .op
54            .query()
55            .all(|param| !param.required() || param.defaultable())
56        {
57            extra_derives.push(ExtraDerive::Default);
58        }
59
60        let unique = UniqueNames::new();
61        let mut scope = CodegenIdentScope::new(&unique);
62
63        let params = self
64            .op
65            .query()
66            .map(|param| (scope.uniquify(param.name()), param))
67            .collect_vec();
68
69        let fields = params.iter().map(|(ident, param)| {
70            let field_name = CodegenIdentUsage::Field(ident);
71            let serde_attr = SerdeQueryFieldAttr::new(field_name, param);
72
73            let ty = if param.optional() {
74                let view = param.ty();
75                let path = CodegenRef::new(&view);
76                quote! { ::std::option::Option<#path> }
77            } else {
78                let view = param.ty();
79                let path = CodegenRef::new(&view);
80                quote!(#path)
81            };
82
83            quote! {
84                #serde_attr
85                pub #field_name: #ty,
86            }
87        });
88
89        let styles = params
90            .iter()
91            .filter_map(|(_, param)| Some((param.name(), param.style()?)))
92            .map(|(name, style)| {
93                let style = match style {
94                    ParameterStyle::DeepObject => {
95                        quote!(::ploidy_util::QueryStyle::DeepObject)
96                    }
97                    ParameterStyle::SpaceDelimited => {
98                        quote!(::ploidy_util::QueryStyle::SpaceDelimited)
99                    }
100                    ParameterStyle::PipeDelimited => {
101                        quote!(::ploidy_util::QueryStyle::PipeDelimited)
102                    }
103                    ParameterStyle::Form { exploded } => {
104                        quote!(::ploidy_util::QueryStyle::Form { exploded: #exploded })
105                    }
106                };
107                quote!((#name, #style))
108            });
109
110        tokens.append_all(quote! {
111            #[derive(Debug, Clone, PartialEq, #(#extra_derives,)* ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
112            #[serde(crate = "::ploidy_util::serde")]
113            pub struct #query_name {
114                #(#fields)*
115            }
116
117            impl #query_name {
118                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[#(#styles,)*];
119            }
120        });
121    }
122}
123
124/// Generates a `#[serde(...)]` attribute for a query parameter struct field.
125#[derive(Debug)]
126struct SerdeQueryFieldAttr<'param, 'a> {
127    field_name: CodegenIdentUsage<'param>,
128    param: &'param ParameterView<'param, 'a, QueryParameter>,
129}
130
131impl<'param, 'a> SerdeQueryFieldAttr<'param, 'a> {
132    fn new(
133        field_name: CodegenIdentUsage<'param>,
134        param: &'param ParameterView<'param, 'a, QueryParameter>,
135    ) -> Self {
136        Self { field_name, param }
137    }
138}
139
140impl ToTokens for SerdeQueryFieldAttr<'_, '_> {
141    fn to_tokens(&self, tokens: &mut TokenStream) {
142        let mut attrs = vec![];
143
144        let param_name = self.param.name();
145        if self.field_name.display().to_string() != param_name {
146            attrs.push(quote! { rename = #param_name });
147        }
148
149        if self.param.optional() {
150            attrs.push(quote! { default });
151            attrs.push(quote! { skip_serializing_if = "Option::is_none" });
152        }
153
154        if !attrs.is_empty() {
155            tokens.append_all(quote! { #[serde(#(#attrs),*)] });
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    use ploidy_core::{
        arena::Arena,
        ir::{RawGraph, Spec},
        parse::Document,
    };
    use pretty_assertions::assert_eq;
    use syn::parse_quote;

    use crate::CodegenGraph;

    /// Parses `yaml` as an OpenAPI document and builds the codegen graph.
    /// Returns the query parameter struct rendered for the document's first
    /// operation, parsed back into a `syn::File` for structural comparison.
    fn render_first_operation(yaml: &str) -> syn::File {
        let doc = Document::from_yaml(yaml).unwrap();
        let arena = Arena::new();
        let spec = Spec::from_doc(&arena, &doc).unwrap();
        let graph = CodegenGraph::new(RawGraph::new(&arena, &spec).cook());
        let op = graph.operations().next().unwrap();
        let codegen = CodegenQueryParameters::new(&op);
        parse_quote!(#codegen)
    }

    /// Path parameters are excluded, and optional scalars become
    /// `Option<T>` fields with `default` and `skip_serializing_if`.
    #[test]
    fn test_all_optional_query_params() {
        let actual = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /charts/{chart_id}:
                get:
                  operationId: fetchChart
                  parameters:
                    - name: chart_id
                      in: path
                      required: true
                      schema:
                        type: string
                    - name: refresh
                      in: query
                      schema:
                        type: boolean
                    - name: client_job_id
                      in: query
                      schema:
                        type: string
                    - name: partition_idx
                      in: query
                      schema:
                        type: integer
                        format: int32
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct FetchChartQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub refresh: ::std::option::Option<bool>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub client_job_id: ::std::option::Option<::std::string::String>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub partition_idx: ::std::option::Option<i32>,
            }

            impl FetchChartQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(actual, expected);
    }

    /// Required parameters stay unwrapped, and camelCase names get a
    /// `rename`. Explicit styles are recorded in `STYLES`.
    #[test]
    fn test_required_and_optional_query_params() {
        let actual = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: listItems
                  parameters:
                    - name: page
                      in: query
                      required: true
                      schema:
                        type: integer
                        format: int32
                    - name: perPage
                      in: query
                      style: pipeDelimited
                      schema:
                        type: array
                        items:
                          type: integer
                          format: int32
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct ListItemsQuery {
                pub page: i32,
                #[serde(rename = "perPage", default, skip_serializing_if = "Option::is_none")]
                pub per_page: ::std::option::Option<::std::vec::Vec<i32>>,
            }

            impl ListItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[
                    ("perPage", ::ploidy_util::QueryStyle::PipeDelimited),
                ];
            }
        };
        assert_eq!(actual, expected);
    }

    /// A floating-point parameter suppresses the `Eq` and `Hash` derives.
    #[test]
    fn test_excludes_eq_hash_for_float_params() {
        let actual = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: getItems
                  parameters:
                    - name: threshold
                      in: query
                      schema:
                        type: number
                        format: double
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct GetItemsQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub threshold: ::std::option::Option<f64>,
            }

            impl GetItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(actual, expected);
    }

    /// A required parameter whose type is not defaultable (here a URL)
    /// suppresses the `Default` derive.
    #[test]
    fn test_excludes_default_for_non_defaultable_required_param() {
        let actual = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: getItems
                  parameters:
                    - name: callback
                      in: query
                      required: true
                      schema:
                        type: string
                        format: uri
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct GetItemsQuery {
                pub callback: ::ploidy_util::url::Url,
            }

            impl GetItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(actual, expected);
    }

    /// Every supported `style` value maps to its `QueryStyle` variant in
    /// `STYLES`.
    #[test]
    fn test_query_parameter_styles() {
        let actual = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: listItems
                  parameters:
                    - name: filter
                      in: query
                      style: deepObject
                      schema:
                        type: object
                        properties:
                          status:
                            type: string
                    - name: tags
                      in: query
                      style: pipeDelimited
                      schema:
                        type: array
                        items:
                          type: string
                    - name: ids
                      in: query
                      style: spaceDelimited
                      schema:
                        type: array
                        items:
                          type: integer
                          format: int32
                    - name: colors
                      in: query
                      style: form
                      explode: false
                      schema:
                        type: array
                        items:
                          type: string
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct ListItemsQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub filter: ::std::option::Option<crate::client::default::types::ListItemsFilter>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub tags: ::std::option::Option<::std::vec::Vec<::std::string::String>>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub ids: ::std::option::Option<::std::vec::Vec<i32>>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub colors: ::std::option::Option<::std::vec::Vec<::std::string::String>>,
            }

            impl ListItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[
                    ("filter", ::ploidy_util::QueryStyle::DeepObject),
                    ("tags", ::ploidy_util::QueryStyle::PipeDelimited),
                    ("ids", ::ploidy_util::QueryStyle::SpaceDelimited),
                    ("colors", ::ploidy_util::QueryStyle::Form { exploded: false }),
                ];
            }
        };
        assert_eq!(actual, expected);
    }

    /// A `$ref` schema resolves to the path of the named component type.
    #[test]
    fn test_ref_query_parameter() {
        let actual = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: listItems
                  parameters:
                    - name: sort
                      in: query
                      schema:
                        $ref: '#/components/schemas/SortOrder'
                  responses:
                    '200':
                      description: OK
            components:
              schemas:
                SortOrder:
                  type: string
                  enum:
                    - asc
                    - desc
        "});
        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct ListItemsQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub sort: ::std::option::Option<crate::types::SortOrder>,
            }

            impl ListItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(actual, expected);
    }
}