Skip to main content

openapi_to_rust/
client_generator.rs

1//! HTTP client generation for OpenAPI specifications.
2//!
3//! This module is part of the code generator that creates production-ready HTTP clients
4//! from OpenAPI specifications. It generates clients with middleware support including
5//! retry logic and request tracing.
6//!
7//! # Overview
8//!
9//! The client generator creates:
10//! - `HttpClient` struct with middleware stack (reqwest-middleware)
11//! - Retry logic with exponential backoff (reqwest-retry)
12//! - Request/response tracing (reqwest-tracing)
13//! - Direct methods for all API operations (GET, POST, PUT, DELETE, PATCH)
14//! - Comprehensive error handling with [`HttpError`](crate::http_error::HttpError)
15//! - Builder pattern for configuration
16//!
17//! # Generated Code Structure
18//!
19//! For each OpenAPI specification, the generator creates:
20//!
21//! ```rust,ignore
22//! // Generated client.rs file
23//!
24//! use crate::types::*;
25//! use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
26//! use std::collections::BTreeMap;
27//!
28//! pub struct HttpClient {
29//!     base_url: String,
30//!     api_key: Option<String>,
31//!     http_client: ClientWithMiddleware,
32//!     custom_headers: BTreeMap<String, String>,
33//! }
34//!
35//! impl HttpClient {
36//!     pub fn new() -> Self { /* ... */ }
37//!     pub fn with_config(retry_config: Option<RetryConfig>, enable_tracing: bool) -> Self { /* ... */ }
//!     pub fn with_base_url(self, base_url: impl Into<String>) -> Self { /* ... */ }
//!     pub fn with_api_key(self, api_key: impl Into<String>) -> Self { /* ... */ }
//!     pub fn with_header(self, name: impl Into<String>, value: impl Into<String>) -> Self { /* ... */ }
41//!
42//!     // Generated operation methods
43//!     pub async fn list_items(&self) -> Result<ItemList, HttpError> { /* ... */ }
44//!     pub async fn create_item(&self, request: CreateItemRequest) -> Result<Item, HttpError> { /* ... */ }
45//!     pub async fn get_item(&self, id: impl AsRef<str>) -> Result<Item, HttpError> { /* ... */ }
46//! }
47//! ```
48//!
49//! # Middleware Stack
50//!
51//! The generated client uses `reqwest-middleware` to build a composable middleware stack:
52//!
53//! 1. **Tracing Middleware** (optional, enabled by default)
54//!    - Logs HTTP requests/responses
55//!    - Creates spans for distributed tracing
56//!    - Integrates with `tracing` ecosystem
57//!
58//! 2. **Retry Middleware** (optional, configured via TOML)
59//!    - Exponential backoff retry policy
60//!    - Automatically retries transient errors (429, 500, 502, 503, 504)
61//!    - Configurable max retries and delay bounds
62//!
63//! # Configuration
64//!
65//! ## Via TOML
66//!
67//! ```toml
68//! [http_client]
69//! base_url = "https://api.example.com"
70//! timeout_seconds = 30
71//!
72//! [http_client.retry]
73//! max_retries = 3
74//! initial_delay_ms = 500
75//! max_delay_ms = 16000
76//!
77//! [http_client.tracing]
78//! enabled = true
79//! ```
80//!
81//! ## Via Rust API
82//!
83//! ```no_run
84//! use openapi_to_rust::{GeneratorConfig, http_config::*};
85//! use std::path::PathBuf;
86//!
87//! let config = GeneratorConfig {
88//!     spec_path: PathBuf::from("openapi.json"),
89//!     enable_async_client: true,
90//!     retry_config: Some(RetryConfig {
91//!         max_retries: 3,
92//!         initial_delay_ms: 500,
93//!         max_delay_ms: 16000,
94//!     }),
95//!     tracing_enabled: true,
96//!     // ... other fields
97//!     ..Default::default()
98//! };
99//! ```
100//!
101//! # Generated Client Usage
102//!
103//! ```rust,ignore
104//! use crate::generated::client::HttpClient;
105//!
106//! #[tokio::main]
107//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
108//!     // Create client with retry and tracing
109//!     let client = HttpClient::new()
110//!         .with_base_url("https://api.example.com".to_string())
111//!         .with_api_key("your-api-key".to_string())
112//!         .with_header("X-Custom-Header".to_string(), "value".to_string());
113//!
114//!     // Make API calls - retries happen automatically
115//!     let items = client.list_items().await?;
116//!     println!("Found {} items", items.items.len());
117//!
118//!     Ok(())
119//! }
120//! ```
121//!
122//! # HTTP Method Support
123//!
124//! The generator supports all standard HTTP methods:
125//! - `GET` - List and retrieve operations
126//! - `POST` - Create operations
127//! - `PUT` - Full update operations
128//! - `PATCH` - Partial update operations
129//! - `DELETE` - Delete operations
130//!
131//! # Error Handling
132//!
133//! All generated methods return `Result<T, HttpError>` where `HttpError` provides:
134//! - Detailed error information
135//! - Retry detection via `is_retryable()`
136//! - Error categorization (client errors, server errors)
137//!
138//! See [`http_error`](crate::http_error) module for details.
139//!
140//! # Implementation Details
141//!
142//! The generator uses the following approach:
143//! 1. Analyzes OpenAPI operations to extract HTTP methods, paths, parameters
144//! 2. Generates typed request/response handling
145//! 3. Creates method signatures with proper parameter types
146//! 4. Generates path parameter substitution
147//! 5. Handles query parameters and request bodies
148//! 6. Configures middleware stack based on generator config
149
150use crate::analysis::{OperationInfo, SchemaAnalysis};
151use crate::generator::CodeGenerator;
152use heck::ToSnakeCase;
153use proc_macro2::TokenStream;
154use quote::quote;
155
156impl CodeGenerator {
157    /// Generate the HTTP client struct with middleware support
158    pub fn generate_http_client_struct(&self) -> TokenStream {
159        let has_retry = self.config().retry_config.is_some();
160        let has_tracing = self.config().tracing_enabled;
161
162        // Generate RetryConfig struct if needed
163        let retry_config_struct = if has_retry {
164            quote! {
165                /// Retry configuration for HTTP requests
166                #[derive(Debug, Clone)]
167                pub struct RetryConfig {
168                    pub max_retries: u32,
169                    pub initial_delay_ms: u64,
170                    pub max_delay_ms: u64,
171                }
172
173                impl Default for RetryConfig {
174                    fn default() -> Self {
175                        Self {
176                            max_retries: 3,
177                            initial_delay_ms: 500,
178                            max_delay_ms: 16000,
179                        }
180                    }
181                }
182            }
183        } else {
184            quote! {}
185        };
186
187        // Generate the main HttpClient struct
188        let client_struct = quote! {
189            use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
190            use std::collections::BTreeMap;
191
192            /// HTTP client for making API requests
193            #[derive(Clone)]
194            pub struct HttpClient {
195                base_url: String,
196                api_key: Option<String>,
197                http_client: ClientWithMiddleware,
198                custom_headers: BTreeMap<String, String>,
199            }
200        };
201
202        // Generate constructor
203        let constructor = self.generate_constructor(has_retry, has_tracing);
204
205        // Generate builder methods
206        let builder_methods = self.generate_builder_methods();
207
208        // Generate Default implementation
209        let default_impl = quote! {
210            impl Default for HttpClient {
211                fn default() -> Self {
212                    Self::new()
213                }
214            }
215        };
216
217        // Combine all parts
218        quote! {
219            #retry_config_struct
220            #client_struct
221
222            impl HttpClient {
223                #constructor
224                #builder_methods
225            }
226
227            #default_impl
228        }
229    }
230
231    /// Generate the constructor method
232    fn generate_constructor(&self, has_retry: bool, has_tracing: bool) -> TokenStream {
233        let retry_param = if has_retry {
234            quote! { retry_config: Option<RetryConfig>, }
235        } else {
236            quote! {}
237        };
238
239        let tracing_param = if has_tracing {
240            quote! { enable_tracing: bool, }
241        } else {
242            quote! {}
243        };
244
245        let retry_middleware = if has_retry {
246            quote! {
247                if let Some(config) = retry_config {
248                    use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
249
250                    let retry_policy = ExponentialBackoff::builder()
251                        .retry_bounds(
252                            std::time::Duration::from_millis(config.initial_delay_ms),
253                            std::time::Duration::from_millis(config.max_delay_ms),
254                        )
255                        .build_with_max_retries(config.max_retries);
256
257                    let retry_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
258                    client_builder = client_builder.with(retry_middleware);
259                }
260            }
261        } else {
262            quote! {}
263        };
264
265        let tracing_middleware = if has_tracing {
266            quote! {
267                if enable_tracing {
268                    use reqwest_tracing::TracingMiddleware;
269                    client_builder = client_builder.with(TracingMiddleware::default());
270                }
271            }
272        } else {
273            quote! {}
274        };
275
276        let default_constructor = if has_retry && has_tracing {
277            quote! {
278                /// Create a new HTTP client with default configuration
279                pub fn new() -> Self {
280                    Self::with_config(None, true)
281                }
282            }
283        } else if has_retry {
284            quote! {
285                /// Create a new HTTP client with default configuration
286                pub fn new() -> Self {
287                    Self::with_config(None)
288                }
289            }
290        } else if has_tracing {
291            quote! {
292                /// Create a new HTTP client with default configuration
293                pub fn new() -> Self {
294                    Self::with_config(true)
295                }
296            }
297        } else {
298            quote! {
299                /// Create a new HTTP client with default configuration
300                pub fn new() -> Self {
301                    let reqwest_client = reqwest::Client::new();
302                    let client_builder = ClientBuilder::new(reqwest_client);
303                    let http_client = client_builder.build();
304
305                    Self {
306                        base_url: String::new(),
307                        api_key: None,
308                        http_client,
309                        custom_headers: BTreeMap::new(),
310                    }
311                }
312            }
313        };
314
315        if has_retry || has_tracing {
316            quote! {
317                #default_constructor
318
319                /// Create a new HTTP client with custom configuration
320                pub fn with_config(#retry_param #tracing_param) -> Self {
321                    let reqwest_client = reqwest::Client::new();
322                    let mut client_builder = ClientBuilder::new(reqwest_client);
323
324                    #tracing_middleware
325                    #retry_middleware
326
327                    let http_client = client_builder.build();
328
329                    Self {
330                        base_url: String::new(),
331                        api_key: None,
332                        http_client,
333                        custom_headers: BTreeMap::new(),
334                    }
335                }
336            }
337        } else {
338            default_constructor
339        }
340    }
341
342    /// Generate builder methods for configuration
343    fn generate_builder_methods(&self) -> TokenStream {
344        quote! {
345            /// Set the base URL for all requests
346            pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
347                self.base_url = base_url.into();
348                self
349            }
350
351            /// Set the API key for authentication
352            pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
353                self.api_key = Some(api_key.into());
354                self
355            }
356
357            /// Add a custom header to all requests
358            pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
359                self.custom_headers.insert(name.into(), value.into());
360                self
361            }
362
363            /// Add multiple custom headers
364            pub fn with_headers(mut self, headers: BTreeMap<String, String>) -> Self {
365                self.custom_headers.extend(headers);
366                self
367            }
368        }
369    }
370
371    /// Generate HTTP operation methods for the client
372    pub fn generate_operation_methods(&self, analysis: &SchemaAnalysis) -> TokenStream {
373        let methods: Vec<TokenStream> = analysis
374            .operations
375            .values()
376            .map(|op| self.generate_single_operation_method(op))
377            .collect();
378
379        quote! {
380            impl HttpClient {
381                #(#methods)*
382            }
383        }
384    }
385
386    /// Generate a single operation method
387    fn generate_single_operation_method(&self, op: &OperationInfo) -> TokenStream {
388        let method_name = self.get_method_name(op);
389        let http_method = self.get_http_method(op);
390        let path = &op.path;
391        let request_param = self.generate_request_param(op);
392        let request_body = self.generate_request_body(op);
393        let query_params = self.generate_query_params(op);
394        let response_type = self.get_response_type(op);
395        let has_response_body = self.get_success_response_schema(op).is_some();
396        let error_handling = self.generate_error_handling(has_response_body);
397        let url_construction = self.generate_url_construction(path, op);
398        let doc_comment = self.generate_operation_doc_comment(op);
399
400        quote! {
401            #doc_comment
402            pub async fn #method_name(
403                &self,
404                #request_param
405            ) -> HttpResult<#response_type> {
406                #url_construction
407
408                let mut req = self.http_client
409                    .#http_method(request_url)
410                    #request_body;
411
412                #query_params
413
414                // Add API key if configured
415                if let Some(api_key) = &self.api_key {
416                    req = req.bearer_auth(api_key);
417                }
418
419                // Add custom headers
420                for (name, value) in &self.custom_headers {
421                    req = req.header(name, value);
422                }
423
424                let response = req.send().await?;
425                #error_handling
426            }
427        }
428    }
429
430    /// Generate query parameter handling
431    fn generate_query_params(&self, op: &OperationInfo) -> TokenStream {
432        let query_params: Vec<_> = op
433            .parameters
434            .iter()
435            .filter(|p| p.location == "query")
436            .collect();
437
438        if query_params.is_empty() {
439            return quote! {};
440        }
441
442        let mut param_building = Vec::new();
443
444        for param in query_params {
445            // Use snake_case for Rust variable name with keyword escaping
446            let param_name_snake = self.sanitize_param_name(&param.name);
447            let param_name = Self::to_field_ident(&param_name_snake);
448
449            // Use the original parameter name from OpenAPI spec as the query string key
450            let param_key = &param.name;
451
452            if param.required {
453                // Required parameters: always add
454                if param.rust_type == "String" {
455                    param_building.push(quote! {
456                        query_params.push((#param_key, #param_name.as_ref().to_string()));
457                    });
458                } else {
459                    param_building.push(quote! {
460                        query_params.push((#param_key, #param_name.to_string()));
461                    });
462                }
463            } else {
464                // Optional parameters: add only if Some
465                if param.rust_type == "String" {
466                    param_building.push(quote! {
467                        if let Some(v) = #param_name {
468                            query_params.push((#param_key, v.as_ref().to_string()));
469                        }
470                    });
471                } else {
472                    param_building.push(quote! {
473                        if let Some(v) = #param_name {
474                            query_params.push((#param_key, v.to_string()));
475                        }
476                    });
477                }
478            }
479        }
480
481        quote! {
482            // Add query parameters
483            {
484                let mut query_params: Vec<(&str, String)> = Vec::new();
485                #(#param_building)*
486                if !query_params.is_empty() {
487                    req = req.query(&query_params);
488                }
489            }
490        }
491    }
492
493    /// Generate documentation comment for the operation
494    fn generate_operation_doc_comment(&self, op: &OperationInfo) -> TokenStream {
495        let method = op.method.to_uppercase();
496        let path = &op.path;
497        let doc = format!("{} {}", method, path);
498
499        quote! {
500            #[doc = #doc]
501        }
502    }
503
504    /// Get the method name from the operation
505    fn get_method_name(&self, op: &OperationInfo) -> syn::Ident {
506        let name = if !op.operation_id.is_empty() {
507            op.operation_id.to_snake_case()
508        } else {
509            // Fallback: generate from HTTP method and path
510            format!(
511                "{}_{}",
512                op.method,
513                op.path.replace('/', "_").replace(['{', '}'], "")
514            )
515            .to_snake_case()
516        };
517
518        syn::Ident::new(&name, proc_macro2::Span::call_site())
519    }
520
521    /// Get the HTTP method
522    fn get_http_method(&self, op: &OperationInfo) -> syn::Ident {
523        let method = match op.method.to_uppercase().as_str() {
524            "GET" => "get",
525            "POST" => "post",
526            "PUT" => "put",
527            "DELETE" => "delete",
528            "PATCH" => "patch",
529            _ => "get", // Default fallback
530        };
531
532        syn::Ident::new(method, proc_macro2::Span::call_site())
533    }
534
535    /// Generate request parameters including path parameters, query parameters, and request body
536    fn generate_request_param(&self, op: &OperationInfo) -> TokenStream {
537        let mut params = Vec::new();
538
539        // Add path parameters
540        for param in &op.parameters {
541            if param.location == "path" {
542                let param_name_snake = self.sanitize_param_name(&param.name);
543                let param_name = Self::to_field_ident(&param_name_snake);
544                let param_type = self.get_param_rust_type(param);
545                params.push(quote! { #param_name: #param_type });
546            }
547        }
548
549        // Add query parameters (all as Option<T>)
550        for param in &op.parameters {
551            if param.location == "query" {
552                let param_name_snake = self.sanitize_param_name(&param.name);
553                let param_name = Self::to_field_ident(&param_name_snake);
554                let param_type = self.get_param_rust_type(param);
555
556                // Query parameters should be Option unless explicitly required
557                if param.required {
558                    params.push(quote! { #param_name: #param_type });
559                } else {
560                    params.push(quote! { #param_name: Option<#param_type> });
561                }
562            }
563        }
564
565        // Add request body parameter based on content type
566        if let Some(ref rb) = op.request_body {
567            use crate::analysis::RequestBodyContent;
568            match rb {
569                RequestBodyContent::Json { schema_name }
570                | RequestBodyContent::FormUrlEncoded { schema_name } => {
571                    let rust_type_name = self.to_rust_type_name(schema_name);
572                    let request_ident =
573                        syn::Ident::new(&rust_type_name, proc_macro2::Span::call_site());
574                    params.push(quote! { request: #request_ident });
575                }
576                RequestBodyContent::Multipart => {
577                    params.push(quote! { form: reqwest::multipart::Form });
578                }
579                RequestBodyContent::OctetStream => {
580                    params.push(quote! { body: Vec<u8> });
581                }
582                RequestBodyContent::TextPlain => {
583                    params.push(quote! { body: String });
584                }
585            }
586        }
587
588        if params.is_empty() {
589            quote! {}
590        } else {
591            quote! { #(#params),* }
592        }
593    }
594
595    /// Get the Rust type for a parameter
596    fn get_param_rust_type(&self, param: &crate::analysis::ParameterInfo) -> TokenStream {
597        let type_str = &param.rust_type;
598        match type_str.as_str() {
599            "String" => quote! { impl AsRef<str> },
600            "i64" => quote! { i64 },
601            "i32" => quote! { i32 },
602            "f64" => quote! { f64 },
603            "bool" => quote! { bool },
604            _ => {
605                let type_ident = syn::Ident::new(type_str, proc_macro2::Span::call_site());
606                quote! { #type_ident }
607            }
608        }
609    }
610
611    /// Generate request body serialization based on content type
612    fn generate_request_body(&self, op: &OperationInfo) -> TokenStream {
613        if let Some(ref rb) = op.request_body {
614            use crate::analysis::RequestBodyContent;
615            match rb {
616                RequestBodyContent::Json { .. } => {
617                    quote! {
618                        .body(serde_json::to_vec(&request).map_err(HttpError::serialization_error)?)
619                        .header("content-type", "application/json")
620                    }
621                }
622                RequestBodyContent::FormUrlEncoded { .. } => {
623                    quote! {
624                        .body(serde_urlencoded::to_string(&request).map_err(HttpError::serialization_error)?)
625                        .header("content-type", "application/x-www-form-urlencoded")
626                    }
627                }
628                RequestBodyContent::Multipart => {
629                    quote! {
630                        .multipart(form)
631                    }
632                }
633                RequestBodyContent::OctetStream => {
634                    quote! {
635                        .body(body)
636                        .header("content-type", "application/octet-stream")
637                    }
638                }
639                RequestBodyContent::TextPlain => {
640                    quote! {
641                        .body(body)
642                        .header("content-type", "text/plain")
643                    }
644                }
645            }
646        } else {
647            quote! {}
648        }
649    }
650
651    /// Find the success (2xx) response schema name, if any.
652    ///
653    /// Only considers 2xx status codes. Error schemas (4xx, 5xx) are ignored
654    /// so that endpoints like 204 No Content correctly return `()` instead of
655    /// accidentally picking up the error schema (e.g. `BadRequestError`).
656    fn get_success_response_schema<'a>(&self, op: &'a OperationInfo) -> Option<&'a String> {
657        op.response_schemas
658            .get("200")
659            .or_else(|| op.response_schemas.get("201"))
660            .or_else(|| {
661                op.response_schemas
662                    .iter()
663                    .find(|(code, _)| code.starts_with('2'))
664                    .map(|(_, v)| v)
665            })
666    }
667
668    /// Get response type
669    fn get_response_type(&self, op: &OperationInfo) -> TokenStream {
670        if let Some(response_type) = self.get_success_response_schema(op) {
671            // Convert schema name to Rust type name (handles underscores, etc.)
672            let rust_type_name = self.to_rust_type_name(response_type);
673            let response_ident = syn::Ident::new(&rust_type_name, proc_macro2::Span::call_site());
674            quote! { #response_ident }
675        } else {
676            quote! { () }
677        }
678    }
679
680    /// Generate error handling.
681    ///
682    /// When `has_response_body` is false the endpoint returns no JSON body
683    /// (e.g. 204 No Content) and we skip deserialization entirely.
684    fn generate_error_handling(&self, has_response_body: bool) -> TokenStream {
685        let success_branch = if has_response_body {
686            quote! {
687                let body = response.json().await
688                    .map_err(HttpError::deserialization_error)?;
689                Ok(body)
690            }
691        } else {
692            quote! {
693                Ok(())
694            }
695        };
696
697        quote! {
698            let status = response.status();
699
700            if status.is_success() {
701                #success_branch
702            } else {
703                let status_code = status.as_u16();
704                let message = status.canonical_reason().unwrap_or("Unknown error");
705                let body = response.text().await.ok();
706                Err(HttpError::from_status(status_code, message, body))
707            }
708        }
709    }
710
711    /// Generate URL construction with path parameter substitution
712    fn generate_url_construction(&self, path: &str, op: &OperationInfo) -> TokenStream {
713        // Check if path has parameters (contains {...})
714        if path.contains('{') {
715            self.generate_url_with_params(path, op)
716        } else {
717            quote! {
718                let request_url = format!("{}{}", self.base_url, #path);
719            }
720        }
721    }
722
723    /// Generate URL with path parameters
724    fn generate_url_with_params(&self, path: &str, op: &OperationInfo) -> TokenStream {
725        // Parse path to find all parameter placeholders
726        let mut format_string = path.to_string();
727        let mut format_args = Vec::new();
728
729        // Find all path parameters in the operation
730        let path_params: Vec<_> = op
731            .parameters
732            .iter()
733            .filter(|p| p.location == "path")
734            .collect();
735
736        // Replace {paramName} with {} and collect parameter names for format args
737        for param in &path_params {
738            let placeholder = format!("{{{}}}", param.name);
739            if format_string.contains(&placeholder) {
740                format_string = format_string.replace(&placeholder, "{}");
741
742                // Use snake_case for the Rust variable name with keyword escaping
743                let param_name_snake = self.sanitize_param_name(&param.name);
744                let param_ident = Self::to_field_ident(&param_name_snake);
745
746                // Use .as_ref() for string types to handle impl AsRef<str>
747                if param.rust_type == "String" {
748                    format_args.push(quote! { #param_ident.as_ref() });
749                } else {
750                    format_args.push(quote! { #param_ident });
751                }
752            }
753        }
754
755        if format_args.is_empty() {
756            // No path parameters found, use simple format
757            quote! {
758                let request_url = format!("{}{}", self.base_url, #path);
759            }
760        } else {
761            // Build format call with path parameters
762            quote! {
763                let request_url = format!("{}{}", self.base_url, format!(#format_string, #(#format_args),*));
764            }
765        }
766    }
767
768    /// Sanitize a parameter name by escaping Rust reserved keywords with raw identifiers
769    fn sanitize_param_name(&self, name: &str) -> String {
770        let snake_case = name.to_snake_case();
771        if Self::is_rust_keyword(&snake_case) {
772            format!("r#{snake_case}")
773        } else {
774            snake_case
775        }
776    }
777}