Skip to main content

ploidy_codegen_rust/
query.rs

1use itertools::Itertools;
2use ploidy_core::{
3    codegen::UniqueNames,
4    ir::{OperationView, ParameterStyle, ParameterView, QueryParameter},
5};
6use proc_macro2::TokenStream;
7use quote::{ToTokens, TokenStreamExt, format_ident, quote};
8
9use super::{
10    derives::ExtraDerive,
11    ext::ParameterViewExt,
12    ext::ViewExt,
13    naming::{CodegenIdent, CodegenIdentScope, CodegenIdentUsage},
14    ref_::CodegenRef,
15};
16
17/// Generates a query parameter struct for an API operation.
18///
19/// The generated struct is named `{OperationId}Query`.
20/// It bundles all query parameters for that operation,
21/// derives `Serialize`, and has an associated `STYLES` table
22/// with per-parameter serialization style overrides.
23#[derive(Debug)]
24pub struct CodegenQueryParameters<'a> {
25    op: &'a OperationView<'a>,
26}
27
28impl<'a> CodegenQueryParameters<'a> {
29    /// Creates a new query parameter struct for the given operation.
30    #[inline]
31    pub fn new(op: &'a OperationView<'a>) -> Self {
32        Self { op }
33    }
34}
35
36impl ToTokens for CodegenQueryParameters<'_> {
37    fn to_tokens(&self, tokens: &mut TokenStream) {
38        let op_ident = CodegenIdent::new(self.op.id());
39        let query_name = format_ident!("{}Query", CodegenIdentUsage::Type(&op_ident));
40
41        let mut extra_derives = vec![];
42
43        // Derive `Eq` and `Hash` if all parameter types, and their
44        // transitively referenced types, are hashable.
45        if self.op.query().all(|param| param.hashable()) {
46            extra_derives.push(ExtraDerive::Eq);
47            extra_derives.push(ExtraDerive::Hash);
48        }
49
50        // Derive `Default` if all required parameters, and their
51        // transitively referenced types, are defaultable.
52        // Optional parameters become `Option<T>`, which is `Default`.
53        if self
54            .op
55            .query()
56            .all(|param| !param.required() || param.defaultable())
57        {
58            extra_derives.push(ExtraDerive::Default);
59        }
60
61        let unique = UniqueNames::new();
62        let mut scope = CodegenIdentScope::new(&unique);
63
64        let params = self
65            .op
66            .query()
67            .map(|param| (scope.uniquify(param.name()), param))
68            .collect_vec();
69
70        let fields = params.iter().map(|(ident, param)| {
71            let field_name = CodegenIdentUsage::Field(ident);
72            let serde_attr = SerdeQueryFieldAttr::new(field_name, param);
73
74            let ty = if param.optional() {
75                let view = param.ty();
76                let path = CodegenRef::new(&view);
77                quote! { ::std::option::Option<#path> }
78            } else {
79                let view = param.ty();
80                let path = CodegenRef::new(&view);
81                quote!(#path)
82            };
83
84            quote! {
85                #serde_attr
86                pub #field_name: #ty,
87            }
88        });
89
90        let styles = params
91            .iter()
92            .filter_map(|(_, param)| Some((param.name(), param.style()?)))
93            .map(|(name, style)| {
94                let style = match style {
95                    ParameterStyle::DeepObject => {
96                        quote!(::ploidy_util::QueryStyle::DeepObject)
97                    }
98                    ParameterStyle::SpaceDelimited => {
99                        quote!(::ploidy_util::QueryStyle::SpaceDelimited)
100                    }
101                    ParameterStyle::PipeDelimited => {
102                        quote!(::ploidy_util::QueryStyle::PipeDelimited)
103                    }
104                    ParameterStyle::Form { exploded } => {
105                        quote!(::ploidy_util::QueryStyle::Form { exploded: #exploded })
106                    }
107                };
108                quote!((#name, #style))
109            });
110
111        tokens.append_all(quote! {
112            #[derive(Debug, Clone, PartialEq, #(#extra_derives,)* ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
113            #[serde(crate = "::ploidy_util::serde")]
114            pub struct #query_name {
115                #(#fields)*
116            }
117
118            impl #query_name {
119                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[#(#styles,)*];
120            }
121        });
122    }
123}
124
125/// Generates a `#[serde(...)]` attribute for a query parameter struct field.
126#[derive(Debug)]
127struct SerdeQueryFieldAttr<'param, 'a> {
128    field_name: CodegenIdentUsage<'param>,
129    param: &'param ParameterView<'param, 'a, QueryParameter>,
130}
131
132impl<'param, 'a> SerdeQueryFieldAttr<'param, 'a> {
133    fn new(
134        field_name: CodegenIdentUsage<'param>,
135        param: &'param ParameterView<'param, 'a, QueryParameter>,
136    ) -> Self {
137        Self { field_name, param }
138    }
139}
140
141impl ToTokens for SerdeQueryFieldAttr<'_, '_> {
142    fn to_tokens(&self, tokens: &mut TokenStream) {
143        let mut attrs = vec![];
144
145        let param_name = self.param.name();
146        if self.field_name.display().to_string() != param_name {
147            attrs.push(quote! { rename = #param_name });
148        }
149
150        if self.param.optional() {
151            attrs.push(quote! { default });
152            attrs.push(quote! { skip_serializing_if = "Option::is_none" });
153        }
154
155        if !attrs.is_empty() {
156            tokens.append_all(quote! { #[serde(#(#attrs),*)] });
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    use ploidy_core::{
        arena::Arena,
        ir::{RawGraph, Spec},
        parse::Document,
    };
    use pretty_assertions::assert_eq;
    use syn::parse_quote;

    use crate::CodegenGraph;

    /// Renders the query parameter struct for the first operation of `yaml`.
    ///
    /// Parses `yaml` as an OpenAPI document and builds the cooked codegen
    /// graph. It then renders `CodegenQueryParameters` for the first
    /// operation and returns the output as a `syn::File`.
    fn render_first_operation(yaml: &str) -> syn::File {
        let doc = Document::from_yaml(yaml).unwrap();

        let arena = Arena::new();
        let spec = Spec::from_doc(&arena, &doc).unwrap();
        let graph = CodegenGraph::new(RawGraph::new(&arena, &spec).cook());

        let op = graph.operations().next().unwrap();
        let codegen = CodegenQueryParameters::new(&op);
        parse_quote!(#codegen)
    }

    #[test]
    fn test_all_optional_query_params() {
        let generated = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /charts/{chart_id}:
                get:
                  operationId: fetchChart
                  parameters:
                    - name: chart_id
                      in: path
                      required: true
                      schema:
                        type: string
                    - name: refresh
                      in: query
                      schema:
                        type: boolean
                    - name: client_job_id
                      in: query
                      schema:
                        type: string
                    - name: partition_idx
                      in: query
                      schema:
                        type: integer
                        format: int32
                  responses:
                    '200':
                      description: OK
        "});

        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct FetchChartQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub refresh: ::std::option::Option<bool>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub client_job_id: ::std::option::Option<::std::string::String>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub partition_idx: ::std::option::Option<i32>,
            }

            impl FetchChartQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(generated, expected);
    }

    #[test]
    fn test_required_and_optional_query_params() {
        let generated = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: listItems
                  parameters:
                    - name: page
                      in: query
                      required: true
                      schema:
                        type: integer
                        format: int32
                    - name: perPage
                      in: query
                      style: pipeDelimited
                      schema:
                        type: array
                        items:
                          type: integer
                          format: int32
                  responses:
                    '200':
                      description: OK
        "});

        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct ListItemsQuery {
                pub page: i32,
                #[serde(rename = "perPage", default, skip_serializing_if = "Option::is_none")]
                pub per_page: ::std::option::Option<::std::vec::Vec<i32>>,
            }

            impl ListItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[
                    ("perPage", ::ploidy_util::QueryStyle::PipeDelimited),
                ];
            }
        };
        assert_eq!(generated, expected);
    }

    #[test]
    fn test_excludes_eq_hash_for_float_params() {
        let generated = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: getItems
                  parameters:
                    - name: threshold
                      in: query
                      schema:
                        type: number
                        format: double
                  responses:
                    '200':
                      description: OK
        "});

        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct GetItemsQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub threshold: ::std::option::Option<f64>,
            }

            impl GetItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(generated, expected);
    }

    #[test]
    fn test_excludes_default_for_non_defaultable_required_param() {
        let generated = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: getItems
                  parameters:
                    - name: callback
                      in: query
                      required: true
                      schema:
                        type: string
                        format: uri
                  responses:
                    '200':
                      description: OK
        "});

        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct GetItemsQuery {
                pub callback: ::ploidy_util::url::Url,
            }

            impl GetItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(generated, expected);
    }

    #[test]
    fn test_query_parameter_styles() {
        let generated = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: listItems
                  parameters:
                    - name: filter
                      in: query
                      style: deepObject
                      schema:
                        type: object
                        properties:
                          status:
                            type: string
                    - name: tags
                      in: query
                      style: pipeDelimited
                      schema:
                        type: array
                        items:
                          type: string
                    - name: ids
                      in: query
                      style: spaceDelimited
                      schema:
                        type: array
                        items:
                          type: integer
                          format: int32
                    - name: colors
                      in: query
                      style: form
                      explode: false
                      schema:
                        type: array
                        items:
                          type: string
                  responses:
                    '200':
                      description: OK
        "});

        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct ListItemsQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub filter: ::std::option::Option<crate::client::default::types::ListItemsFilter>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub tags: ::std::option::Option<::std::vec::Vec<::std::string::String>>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub ids: ::std::option::Option<::std::vec::Vec<i32>>,
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub colors: ::std::option::Option<::std::vec::Vec<::std::string::String>>,
            }

            impl ListItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[
                    ("filter", ::ploidy_util::QueryStyle::DeepObject),
                    ("tags", ::ploidy_util::QueryStyle::PipeDelimited),
                    ("ids", ::ploidy_util::QueryStyle::SpaceDelimited),
                    ("colors", ::ploidy_util::QueryStyle::Form { exploded: false }),
                ];
            }
        };
        assert_eq!(generated, expected);
    }

    #[test]
    fn test_ref_query_parameter() {
        let generated = render_first_operation(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: listItems
                  parameters:
                    - name: sort
                      in: query
                      schema:
                        $ref: '#/components/schemas/SortOrder'
                  responses:
                    '200':
                      description: OK
            components:
              schemas:
                SortOrder:
                  type: string
                  enum:
                    - asc
                    - desc
        "});

        let expected: syn::File = parse_quote! {
            #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, ::ploidy_util::serde::Serialize, ::ploidy_util::serde::Deserialize)]
            #[serde(crate = "::ploidy_util::serde")]
            pub struct ListItemsQuery {
                #[serde(default, skip_serializing_if = "Option::is_none")]
                pub sort: ::std::option::Option<crate::types::SortOrder>,
            }

            impl ListItemsQuery {
                pub const STYLES: &[(&str, ::ploidy_util::QueryStyle)] = &[];
            }
        };
        assert_eq!(generated, expected);
    }
}