Skip to main content

ploidy_codegen_rust/
operation.rs

1use itertools::Itertools;
2use ploidy_core::{
3    codegen::UniqueNames,
4    ir::{OperationView, ParameterView, PathParameter, RequestView, ResponseView},
5    parse::{Method, path::PathFragment},
6};
7use proc_macro2::{Span, TokenStream};
8use quote::{ToTokens, TokenStreamExt, format_ident, quote};
9use syn::Ident;
10
11use super::{
12    doc_attrs,
13    naming::{CodegenIdent, CodegenIdentScope, CodegenIdentUsage},
14    ref_::CodegenRef,
15};
16
17/// Generates a single client method for an API operation.
18pub struct CodegenOperation<'a> {
19    op: &'a OperationView<'a>,
20}
21
22impl<'a> CodegenOperation<'a> {
23    pub fn new(op: &'a OperationView<'a>) -> Self {
24        Self { op }
25    }
26
27    /// Generates code to build and interpolate path parameters into
28    /// the request URL.
29    fn url(&self, params: &[(CodegenIdent, ParameterView<'_, '_, PathParameter>)]) -> TokenStream {
30        let segments = self
31            .op
32            .path()
33            .segments()
34            .map(|segment| match segment.fragments() {
35                [] => quote! { "" },
36                [PathFragment::Literal(text)] => quote! { #text },
37                [PathFragment::Param(name)] => {
38                    let (ident, _) = params
39                        .iter()
40                        .find(|(_, param)| param.name() == *name)
41                        .unwrap();
42                    let usage = CodegenIdentUsage::Param(ident);
43                    quote!(#usage)
44                }
45                fragments => {
46                    // Build a format string, with placeholders for parameter fragments.
47                    let format = fragments.iter().fold(String::new(), |mut f, fragment| {
48                        match fragment {
49                            PathFragment::Literal(text) => {
50                                f.push_str(&text.replace('{', "{{").replace('}', "}}"))
51                            }
52                            PathFragment::Param(_) => f.push_str("{}"),
53                        }
54                        f
55                    });
56                    let args = fragments
57                        .iter()
58                        .filter_map(|fragment| match fragment {
59                            PathFragment::Param(name) => Some(name),
60                            PathFragment::Literal(_) => None,
61                        })
62                        .map(|name| {
63                            // `url::PathSegmentsMut::push` percent-encodes the
64                            // full segment, so we can interpolate fragments
65                            // directly.
66                            let (ident, _) = params
67                                .iter()
68                                .find(|(_, param)| param.name() == *name)
69                                .unwrap();
70                            CodegenIdentUsage::Param(ident)
71                        });
72                    quote! { &format!(#format, #(#args),*) }
73                }
74            });
75        quote! {
76            let url = {
77                let mut url = self.base_url.clone();
78                let _ = url
79                    .path_segments_mut()
80                    .map(|mut segments| {
81                        segments.pop_if_empty()
82                            #(.push(#segments))*;
83                    });
84                url
85            };
86        }
87    }
88
89    /// Generates code to serialize query parameters into the URL.
90    fn query(&self) -> Option<TokenStream> {
91        self.op.query().next().is_some().then(|| {
92            let op_ident = CodegenIdent::new(self.op.id());
93            let query_name = format_ident!("{}Query", CodegenIdentUsage::Type(&op_ident));
94            quote! {
95                let url = ::ploidy_util::serde::Serialize::serialize(
96                    query,
97                    ::ploidy_util::QuerySerializer::new(
98                        url,
99                        parameters::#query_name::STYLES,
100                    ),
101                )?;
102            }
103        })
104    }
105}
106
107impl ToTokens for CodegenOperation<'_> {
108    fn to_tokens(&self, tokens: &mut TokenStream) {
109        let operation_id = CodegenIdent::new(self.op.id());
110        let method_name = CodegenIdentUsage::Method(&operation_id);
111
112        let unique = UniqueNames::new();
113        let mut scope = CodegenIdentScope::with_reserved(
114            &unique,
115            // `query`, `request`, and `form` are argument names;
116            // `url` and `response` are local variables.
117            &["query", "request", "form", "url", "response"],
118        );
119        let mut params = vec![];
120
121        let paths = self
122            .op
123            .path()
124            .params()
125            .map(|param| (scope.uniquify(param.name()), param))
126            .collect_vec();
127        for (ident, _) in &paths {
128            let usage = CodegenIdentUsage::Param(ident);
129            params.push(quote! { #usage: &str });
130        }
131
132        if self.op.query().next().is_some() {
133            // Include the `query` argument if we have
134            // at least one query parameter.
135            let op_ident = CodegenIdent::new(self.op.id());
136            let query_type_name = format_ident!("{}Query", CodegenIdentUsage::Type(&op_ident));
137            params.push(quote! { query: &parameters::#query_type_name });
138        }
139
140        if let Some(request) = self.op.request() {
141            match request {
142                RequestView::Json(view) => {
143                    let param_type = CodegenRef::new(&view);
144                    params.push(quote! { request: impl Into<#param_type> });
145                }
146                RequestView::Multipart => {
147                    params.push(quote! { form: crate::util::reqwest::multipart::Form });
148                }
149            }
150        }
151
152        let return_type = match self.op.response() {
153            Some(response) => match response {
154                ResponseView::Json(view) => CodegenRef::new(&view).into_token_stream(),
155            },
156            None => quote! { () },
157        };
158
159        let build_url = self.url(&paths);
160
161        let build_query = self.query();
162
163        let http_method = CodegenMethod(self.op.method());
164
165        let build_request = match self.op.request() {
166            Some(RequestView::Json(_)) => quote! {
167                let response = self.client
168                    .#http_method(url)
169                    .headers(self.headers.clone())
170                    .json(&request.into())
171                    .send()
172                    .await?
173                    .error_for_status()?;
174            },
175            Some(RequestView::Multipart) => quote! {
176                let response = self.client
177                    .#http_method(url)
178                    .headers(self.headers.clone())
179                    .multipart(form)
180                    .send()
181                    .await?
182                    .error_for_status()?;
183            },
184            None => quote! {
185                let response = self.client
186                    .#http_method(url)
187                    .headers(self.headers.clone())
188                    .send()
189                    .await?
190                    .error_for_status()?;
191            },
192        };
193
194        let parse_response = if self.op.response().is_some() {
195            quote! {
196                let body = response.bytes().await?;
197                let deserializer = &mut ::ploidy_util::serde_json::Deserializer::from_slice(&body);
198                let result = ::ploidy_util::serde_path_to_error::deserialize(deserializer)
199                    .map_err(crate::error::JsonError::from)?;
200                Ok(result)
201            }
202        } else {
203            quote! {
204                let _ = response;
205                Ok(())
206            }
207        };
208
209        let doc = self.op.description().map(doc_attrs);
210
211        tokens.append_all(quote! {
212            #doc
213            pub async fn #method_name(
214                &self,
215                #(#params),*
216            ) -> Result<#return_type, crate::error::Error> {
217                #build_url
218                #build_query
219                #build_request
220                #parse_response
221            }
222        });
223    }
224}
225
226#[derive(Clone, Copy, Debug)]
227pub struct CodegenMethod(pub Method);
228
229impl ToTokens for CodegenMethod {
230    fn to_tokens(&self, tokens: &mut TokenStream) {
231        tokens.append(match self.0 {
232            Method::Get => Ident::new("get", Span::call_site()),
233            Method::Post => Ident::new("post", Span::call_site()),
234            Method::Put => Ident::new("put", Span::call_site()),
235            Method::Patch => Ident::new("patch", Span::call_site()),
236            Method::Delete => Ident::new("delete", Span::call_site()),
237        });
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    use ploidy_core::{
        arena::Arena,
        ir::{RawGraph, Spec},
        parse::Document,
    };
    use pretty_assertions::assert_eq;
    use syn::parse_quote;

    use crate::CodegenGraph;

    /// Parses `yaml` as an OpenAPI document and returns the client method
    /// generated for its first operation. The method is parsed into a
    /// `syn::ImplItemFn`, so comparisons ignore token spacing.
    fn generate(yaml: &str) -> syn::ImplItemFn {
        let doc = Document::from_yaml(yaml).unwrap();
        let arena = Arena::new();
        let spec = Spec::from_doc(&arena, &doc).unwrap();
        let graph = CodegenGraph::new(RawGraph::new(&arena, &spec).cook());
        let op = graph.operations().next().unwrap();
        let codegen = CodegenOperation::new(&op);
        parse_quote!(#codegen)
    }

    // MARK: With query params

    #[test]
    fn test_operation_with_path_and_query_params() {
        let actual = generate(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items/{item_id}:
                get:
                  operationId: getItem
                  parameters:
                    - name: item_id
                      in: path
                      required: true
                      schema:
                        type: string
                    - name: expand
                      in: query
                      schema:
                        type: boolean
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::ImplItemFn = parse_quote! {
            pub async fn get_item(
                &self,
                item_id: &str,
                query: &parameters::GetItemQuery
            ) -> Result<(), crate::error::Error> {
                let url = {
                    let mut url = self.base_url.clone();
                    let _ = url
                        .path_segments_mut()
                        .map(|mut segments| {
                            segments.pop_if_empty()
                                .push("items")
                                .push(item_id);
                        });
                    url
                };
                let url = ::ploidy_util::serde::Serialize::serialize(
                    query,
                    ::ploidy_util::QuerySerializer::new(
                        url,
                        parameters::GetItemQuery::STYLES,
                    ),
                )?;
                let response = self
                    .client
                    .get(url)
                    .headers(self.headers.clone())
                    .send()
                    .await?
                    .error_for_status()?;
                let _ = response;
                Ok(())
            }
        };
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_operation_with_query_params_only() {
        let actual = generate(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items:
                get:
                  operationId: getItems
                  parameters:
                    - name: limit
                      in: query
                      schema:
                        type: integer
                        format: int32
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::ImplItemFn = parse_quote! {
            pub async fn get_items(
                &self,
                query: &parameters::GetItemsQuery
            ) -> Result<(), crate::error::Error> {
                let url = {
                    let mut url = self.base_url.clone();
                    let _ = url
                        .path_segments_mut()
                        .map(|mut segments| {
                            segments.pop_if_empty()
                                .push("items");
                        });
                    url
                };
                let url = ::ploidy_util::serde::Serialize::serialize(
                    query,
                    ::ploidy_util::QuerySerializer::new(
                        url,
                        parameters::GetItemsQuery::STYLES,
                    ),
                )?;
                let response = self
                    .client
                    .get(url)
                    .headers(self.headers.clone())
                    .send()
                    .await?
                    .error_for_status()?;
                let _ = response;
                Ok(())
            }
        };
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_path_param_named_query_does_not_shadow() {
        let actual = generate(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /search/{query}:
                get:
                  operationId: search
                  parameters:
                    - name: query
                      in: path
                      required: true
                      schema:
                        type: string
                    - name: limit
                      in: query
                      schema:
                        type: integer
                        format: int32
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::ImplItemFn = parse_quote! {
            pub async fn search(
                &self,
                query2: &str,
                query: &parameters::SearchQuery
            ) -> Result<(), crate::error::Error> {
                let url = {
                    let mut url = self.base_url.clone();
                    let _ = url
                        .path_segments_mut()
                        .map(|mut segments| {
                            segments.pop_if_empty()
                                .push("search")
                                .push(query2);
                        });
                    url
                };
                let url = ::ploidy_util::serde::Serialize::serialize(
                    query,
                    ::ploidy_util::QuerySerializer::new(
                        url,
                        parameters::SearchQuery::STYLES,
                    ),
                )?;
                let response = self
                    .client
                    .get(url)
                    .headers(self.headers.clone())
                    .send()
                    .await?
                    .error_for_status()?;
                let _ = response;
                Ok(())
            }
        };
        assert_eq!(actual, expected);
    }

    // MARK: With query params and request body

    #[test]
    fn test_operation_with_query_params_and_request_body() {
        let actual = generate(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items/{item_id}:
                put:
                  operationId: updateItem
                  parameters:
                    - name: item_id
                      in: path
                      required: true
                      schema:
                        type: string
                    - name: dry_run
                      in: query
                      schema:
                        type: boolean
                  requestBody:
                    content:
                      application/json:
                        schema:
                          $ref: '#/components/schemas/Item'
                  responses:
                    '200':
                      description: OK
                      content:
                        application/json:
                          schema:
                            $ref: '#/components/schemas/Item'
            components:
              schemas:
                Item:
                  type: object
                  properties:
                    name:
                      type: string
        "});
        let expected: syn::ImplItemFn = parse_quote! {
            pub async fn update_item(
                &self,
                item_id: &str,
                query: &parameters::UpdateItemQuery,
                request: impl Into<crate::types::Item>
            ) -> Result<crate::types::Item, crate::error::Error> {
                let url = {
                    let mut url = self.base_url.clone();
                    let _ = url
                        .path_segments_mut()
                        .map(|mut segments| {
                            segments.pop_if_empty()
                                .push("items")
                                .push(item_id);
                        });
                    url
                };
                let url = ::ploidy_util::serde::Serialize::serialize(
                    query,
                    ::ploidy_util::QuerySerializer::new(
                        url,
                        parameters::UpdateItemQuery::STYLES,
                    ),
                )?;
                let response = self
                    .client
                    .put(url)
                    .headers(self.headers.clone())
                    .json(&request.into())
                    .send()
                    .await?
                    .error_for_status()?;
                let body = response.bytes().await?;
                let deserializer = &mut ::ploidy_util::serde_json::Deserializer::from_slice(&body);
                let result = ::ploidy_util::serde_path_to_error::deserialize(deserializer)
                    .map_err(crate::error::JsonError::from)?;
                Ok(result)
            }
        };
        assert_eq!(actual, expected);
    }

    // MARK: Without query params

    #[test]
    fn test_operation_without_query_params() {
        let actual = generate(indoc::indoc! {"
            openapi: 3.0.0
            info:
              title: Test API
              version: 1.0.0
            paths:
              /items/{item_id}:
                get:
                  operationId: getItem
                  parameters:
                    - name: item_id
                      in: path
                      required: true
                      schema:
                        type: string
                  responses:
                    '200':
                      description: OK
        "});
        let expected: syn::ImplItemFn = parse_quote! {
            pub async fn get_item(
                &self,
                item_id: &str
            ) -> Result<(), crate::error::Error> {
                let url = {
                    let mut url = self.base_url.clone();
                    let _ = url
                        .path_segments_mut()
                        .map(|mut segments| {
                            segments.pop_if_empty()
                                .push("items")
                                .push(item_id);
                        });
                    url
                };
                let response = self
                    .client
                    .get(url)
                    .headers(self.headers.clone())
                    .send()
                    .await?
                    .error_for_status()?;
                let _ = response;
                Ok(())
            }
        };
        assert_eq!(actual, expected);
    }
}