openapi_to_rust/client_generator.rs
1//! HTTP client generation for OpenAPI specifications.
2//!
3//! This module is part of the code generator that creates production-ready HTTP clients
4//! from OpenAPI specifications. It generates clients with middleware support including
5//! retry logic and request tracing.
6//!
7//! # Overview
8//!
9//! The client generator creates:
10//! - `HttpClient` struct with middleware stack (reqwest-middleware)
11//! - Retry logic with exponential backoff (reqwest-retry)
12//! - Request/response tracing (reqwest-tracing)
13//! - Direct methods for all API operations (GET, POST, PUT, DELETE, PATCH)
14//! - Comprehensive error handling with [`HttpError`](crate::http_error::HttpError)
15//! - Builder pattern for configuration
16//!
17//! # Generated Code Structure
18//!
19//! For each OpenAPI specification, the generator creates:
20//!
21//! ```rust,ignore
22//! // Generated client.rs file
23//!
24//! use crate::types::*;
25//! use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
26//! use std::collections::BTreeMap;
27//!
28//! pub struct HttpClient {
29//! base_url: String,
30//! api_key: Option<String>,
31//! http_client: ClientWithMiddleware,
32//! custom_headers: BTreeMap<String, String>,
33//! }
34//!
//! impl HttpClient {
//!     pub fn new() -> Self { /* ... */ }
//!     // `with_config` exists only when retry and/or tracing is enabled; its
//!     // parameters follow that configuration.
//!     pub fn with_config(retry_config: Option<RetryConfig>, enable_tracing: bool) -> Self { /* ... */ }
//!     pub fn with_base_url(self, base_url: impl Into<String>) -> Self { /* ... */ }
//!     pub fn with_api_key(self, api_key: impl Into<String>) -> Self { /* ... */ }
//!     pub fn with_header(self, name: impl Into<String>, value: impl Into<String>) -> Self { /* ... */ }
//!     pub fn with_headers(self, headers: BTreeMap<String, String>) -> Self { /* ... */ }
//!
//!     // Generated operation methods
//!     pub async fn list_items(&self) -> Result<ItemList, HttpError> { /* ... */ }
//!     pub async fn create_item(&self, request: CreateItemRequest) -> Result<Item, HttpError> { /* ... */ }
//!     pub async fn get_item(&self, id: impl AsRef<str>) -> Result<Item, HttpError> { /* ... */ }
//! }
47//! ```
48//!
49//! # Middleware Stack
50//!
51//! The generated client uses `reqwest-middleware` to build a composable middleware stack:
52//!
53//! 1. **Tracing Middleware** (optional, enabled by default)
54//! - Logs HTTP requests/responses
55//! - Creates spans for distributed tracing
56//! - Integrates with `tracing` ecosystem
57//!
58//! 2. **Retry Middleware** (optional, configured via TOML)
59//! - Exponential backoff retry policy
//!    - Automatically retries transient failures: network errors such as connection failures and timeouts, plus 408, 429 and 5xx responses
61//! - Configurable max retries and delay bounds
62//!
63//! # Configuration
64//!
65//! ## Via TOML
66//!
67//! ```toml
68//! [http_client]
69//! base_url = "https://api.example.com"
70//! timeout_seconds = 30
71//!
72//! [http_client.retry]
73//! max_retries = 3
74//! initial_delay_ms = 500
75//! max_delay_ms = 16000
76//!
77//! [http_client.tracing]
78//! enabled = true
79//! ```
80//!
81//! ## Via Rust API
82//!
83//! ```no_run
84//! use openapi_to_rust::{GeneratorConfig, http_config::*};
85//! use std::path::PathBuf;
86//!
87//! let config = GeneratorConfig {
88//! spec_path: PathBuf::from("openapi.json"),
89//! enable_async_client: true,
90//! retry_config: Some(RetryConfig {
91//! max_retries: 3,
92//! initial_delay_ms: 500,
93//! max_delay_ms: 16000,
94//! }),
95//! tracing_enabled: true,
96//! // ... other fields
97//! ..Default::default()
98//! };
99//! ```
100//!
101//! # Generated Client Usage
102//!
103//! ```rust,ignore
104//! use crate::generated::client::HttpClient;
105//!
106//! #[tokio::main]
107//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
108//! // Create client with retry and tracing
109//! let client = HttpClient::new()
110//! .with_base_url("https://api.example.com".to_string())
111//! .with_api_key("your-api-key".to_string())
112//! .with_header("X-Custom-Header".to_string(), "value".to_string());
113//!
114//! // Make API calls - retries happen automatically
115//! let items = client.list_items().await?;
116//! println!("Found {} items", items.items.len());
117//!
118//! Ok(())
119//! }
120//! ```
121//!
122//! # HTTP Method Support
123//!
124//! The generator supports all standard HTTP methods:
125//! - `GET` - List and retrieve operations
126//! - `POST` - Create operations
127//! - `PUT` - Full update operations
128//! - `PATCH` - Partial update operations
129//! - `DELETE` - Delete operations
130//!
131//! # Error Handling
132//!
133//! All generated methods return `Result<T, HttpError>` where `HttpError` provides:
134//! - Detailed error information
135//! - Retry detection via `is_retryable()`
136//! - Error categorization (client errors, server errors)
137//!
138//! See [`http_error`](crate::http_error) module for details.
139//!
140//! # Implementation Details
141//!
142//! The generator uses the following approach:
143//! 1. Analyzes OpenAPI operations to extract HTTP methods, paths, parameters
144//! 2. Generates typed request/response handling
145//! 3. Creates method signatures with proper parameter types
146//! 4. Generates path parameter substitution
147//! 5. Handles query parameters and request bodies
148//! 6. Configures middleware stack based on generator config
149
150use crate::analysis::{OperationInfo, SchemaAnalysis};
151use crate::generator::CodeGenerator;
152use heck::ToSnakeCase;
153use proc_macro2::TokenStream;
154use quote::quote;
155
156impl CodeGenerator {
157 /// Generate the HTTP client struct with middleware support
158 pub fn generate_http_client_struct(&self) -> TokenStream {
159 let has_retry = self.config().retry_config.is_some();
160 let has_tracing = self.config().tracing_enabled;
161
162 // Generate RetryConfig struct if needed
163 let retry_config_struct = if has_retry {
164 quote! {
165 /// Retry configuration for HTTP requests
166 #[derive(Debug, Clone)]
167 pub struct RetryConfig {
168 pub max_retries: u32,
169 pub initial_delay_ms: u64,
170 pub max_delay_ms: u64,
171 }
172
173 impl Default for RetryConfig {
174 fn default() -> Self {
175 Self {
176 max_retries: 3,
177 initial_delay_ms: 500,
178 max_delay_ms: 16000,
179 }
180 }
181 }
182 }
183 } else {
184 quote! {}
185 };
186
187 // Generate the main HttpClient struct
188 let client_struct = quote! {
189 use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
190 use std::collections::BTreeMap;
191
192 /// HTTP client for making API requests
193 #[derive(Clone)]
194 pub struct HttpClient {
195 base_url: String,
196 api_key: Option<String>,
197 http_client: ClientWithMiddleware,
198 custom_headers: BTreeMap<String, String>,
199 }
200 };
201
202 // Generate constructor
203 let constructor = self.generate_constructor(has_retry, has_tracing);
204
205 // Generate builder methods
206 let builder_methods = self.generate_builder_methods();
207
208 // Generate Default implementation
209 let default_impl = quote! {
210 impl Default for HttpClient {
211 fn default() -> Self {
212 Self::new()
213 }
214 }
215 };
216
217 // Combine all parts
218 quote! {
219 #retry_config_struct
220 #client_struct
221
222 impl HttpClient {
223 #constructor
224 #builder_methods
225 }
226
227 #default_impl
228 }
229 }
230
231 /// Generate the constructor method
232 fn generate_constructor(&self, has_retry: bool, has_tracing: bool) -> TokenStream {
233 let retry_param = if has_retry {
234 quote! { retry_config: Option<RetryConfig>, }
235 } else {
236 quote! {}
237 };
238
239 let tracing_param = if has_tracing {
240 quote! { enable_tracing: bool, }
241 } else {
242 quote! {}
243 };
244
245 let retry_middleware = if has_retry {
246 quote! {
247 if let Some(config) = retry_config {
248 use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
249
250 let retry_policy = ExponentialBackoff::builder()
251 .retry_bounds(
252 std::time::Duration::from_millis(config.initial_delay_ms),
253 std::time::Duration::from_millis(config.max_delay_ms),
254 )
255 .build_with_max_retries(config.max_retries);
256
257 let retry_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
258 client_builder = client_builder.with(retry_middleware);
259 }
260 }
261 } else {
262 quote! {}
263 };
264
265 let tracing_middleware = if has_tracing {
266 quote! {
267 if enable_tracing {
268 use reqwest_tracing::TracingMiddleware;
269 client_builder = client_builder.with(TracingMiddleware::default());
270 }
271 }
272 } else {
273 quote! {}
274 };
275
276 let default_constructor = if has_retry && has_tracing {
277 quote! {
278 /// Create a new HTTP client with default configuration
279 pub fn new() -> Self {
280 Self::with_config(None, true)
281 }
282 }
283 } else if has_retry {
284 quote! {
285 /// Create a new HTTP client with default configuration
286 pub fn new() -> Self {
287 Self::with_config(None)
288 }
289 }
290 } else if has_tracing {
291 quote! {
292 /// Create a new HTTP client with default configuration
293 pub fn new() -> Self {
294 Self::with_config(true)
295 }
296 }
297 } else {
298 quote! {
299 /// Create a new HTTP client with default configuration
300 pub fn new() -> Self {
301 let reqwest_client = reqwest::Client::new();
302 let client_builder = ClientBuilder::new(reqwest_client);
303 let http_client = client_builder.build();
304
305 Self {
306 base_url: String::new(),
307 api_key: None,
308 http_client,
309 custom_headers: BTreeMap::new(),
310 }
311 }
312 }
313 };
314
315 if has_retry || has_tracing {
316 quote! {
317 #default_constructor
318
319 /// Create a new HTTP client with custom configuration
320 pub fn with_config(#retry_param #tracing_param) -> Self {
321 let reqwest_client = reqwest::Client::new();
322 let mut client_builder = ClientBuilder::new(reqwest_client);
323
324 #tracing_middleware
325 #retry_middleware
326
327 let http_client = client_builder.build();
328
329 Self {
330 base_url: String::new(),
331 api_key: None,
332 http_client,
333 custom_headers: BTreeMap::new(),
334 }
335 }
336 }
337 } else {
338 default_constructor
339 }
340 }
341
342 /// Generate builder methods for configuration
343 fn generate_builder_methods(&self) -> TokenStream {
344 quote! {
345 /// Set the base URL for all requests
346 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
347 self.base_url = base_url.into();
348 self
349 }
350
351 /// Set the API key for authentication
352 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
353 self.api_key = Some(api_key.into());
354 self
355 }
356
357 /// Add a custom header to all requests
358 pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
359 self.custom_headers.insert(name.into(), value.into());
360 self
361 }
362
363 /// Add multiple custom headers
364 pub fn with_headers(mut self, headers: BTreeMap<String, String>) -> Self {
365 self.custom_headers.extend(headers);
366 self
367 }
368 }
369 }
370
371 /// Generate HTTP operation methods for the client
372 pub fn generate_operation_methods(&self, analysis: &SchemaAnalysis) -> TokenStream {
373 let methods: Vec<TokenStream> = analysis
374 .operations
375 .values()
376 .map(|op| self.generate_single_operation_method(op))
377 .collect();
378
379 quote! {
380 impl HttpClient {
381 #(#methods)*
382 }
383 }
384 }
385
386 /// Generate a single operation method
387 fn generate_single_operation_method(&self, op: &OperationInfo) -> TokenStream {
388 let method_name = self.get_method_name(op);
389 let http_method = self.get_http_method(op);
390 let path = &op.path;
391 let request_param = self.generate_request_param(op);
392 let request_body = self.generate_request_body(op);
393 let query_params = self.generate_query_params(op);
394 let response_type = self.get_response_type(op);
395 let has_response_body = self.get_success_response_schema(op).is_some();
396 let error_handling = self.generate_error_handling(has_response_body);
397 let url_construction = self.generate_url_construction(path, op);
398 let doc_comment = self.generate_operation_doc_comment(op);
399
400 quote! {
401 #doc_comment
402 pub async fn #method_name(
403 &self,
404 #request_param
405 ) -> HttpResult<#response_type> {
406 #url_construction
407
408 let mut req = self.http_client
409 .#http_method(request_url)
410 #request_body;
411
412 #query_params
413
414 // Add API key if configured
415 if let Some(api_key) = &self.api_key {
416 req = req.bearer_auth(api_key);
417 }
418
419 // Add custom headers
420 for (name, value) in &self.custom_headers {
421 req = req.header(name, value);
422 }
423
424 let response = req.send().await?;
425 #error_handling
426 }
427 }
428 }
429
430 /// Generate query parameter handling
431 fn generate_query_params(&self, op: &OperationInfo) -> TokenStream {
432 let query_params: Vec<_> = op
433 .parameters
434 .iter()
435 .filter(|p| p.location == "query")
436 .collect();
437
438 if query_params.is_empty() {
439 return quote! {};
440 }
441
442 let mut param_building = Vec::new();
443
444 for param in query_params {
445 // Use snake_case for Rust variable name with keyword escaping
446 let param_name_snake = self.sanitize_param_name(¶m.name);
447 let param_name = Self::to_field_ident(¶m_name_snake);
448
449 // Use the original parameter name from OpenAPI spec as the query string key
450 let param_key = ¶m.name;
451
452 if param.required {
453 // Required parameters: always add
454 if param.rust_type == "String" {
455 param_building.push(quote! {
456 query_params.push((#param_key, #param_name.as_ref().to_string()));
457 });
458 } else {
459 param_building.push(quote! {
460 query_params.push((#param_key, #param_name.to_string()));
461 });
462 }
463 } else {
464 // Optional parameters: add only if Some
465 if param.rust_type == "String" {
466 param_building.push(quote! {
467 if let Some(v) = #param_name {
468 query_params.push((#param_key, v.as_ref().to_string()));
469 }
470 });
471 } else {
472 param_building.push(quote! {
473 if let Some(v) = #param_name {
474 query_params.push((#param_key, v.to_string()));
475 }
476 });
477 }
478 }
479 }
480
481 quote! {
482 // Add query parameters
483 {
484 let mut query_params: Vec<(&str, String)> = Vec::new();
485 #(#param_building)*
486 if !query_params.is_empty() {
487 req = req.query(&query_params);
488 }
489 }
490 }
491 }
492
493 /// Generate documentation comment for the operation
494 fn generate_operation_doc_comment(&self, op: &OperationInfo) -> TokenStream {
495 let method = op.method.to_uppercase();
496 let path = &op.path;
497 let doc = format!("{} {}", method, path);
498
499 quote! {
500 #[doc = #doc]
501 }
502 }
503
504 /// Get the method name from the operation
505 fn get_method_name(&self, op: &OperationInfo) -> syn::Ident {
506 let name = if !op.operation_id.is_empty() {
507 op.operation_id.to_snake_case()
508 } else {
509 // Fallback: generate from HTTP method and path
510 format!(
511 "{}_{}",
512 op.method,
513 op.path.replace('/', "_").replace(['{', '}'], "")
514 )
515 .to_snake_case()
516 };
517
518 syn::Ident::new(&name, proc_macro2::Span::call_site())
519 }
520
521 /// Get the HTTP method
522 fn get_http_method(&self, op: &OperationInfo) -> syn::Ident {
523 let method = match op.method.to_uppercase().as_str() {
524 "GET" => "get",
525 "POST" => "post",
526 "PUT" => "put",
527 "DELETE" => "delete",
528 "PATCH" => "patch",
529 _ => "get", // Default fallback
530 };
531
532 syn::Ident::new(method, proc_macro2::Span::call_site())
533 }
534
535 /// Generate request parameters including path parameters, query parameters, and request body
536 fn generate_request_param(&self, op: &OperationInfo) -> TokenStream {
537 let mut params = Vec::new();
538
539 // Add path parameters
540 for param in &op.parameters {
541 if param.location == "path" {
542 let param_name_snake = self.sanitize_param_name(¶m.name);
543 let param_name = Self::to_field_ident(¶m_name_snake);
544 let param_type = self.get_param_rust_type(param);
545 params.push(quote! { #param_name: #param_type });
546 }
547 }
548
549 // Add query parameters (all as Option<T>)
550 for param in &op.parameters {
551 if param.location == "query" {
552 let param_name_snake = self.sanitize_param_name(¶m.name);
553 let param_name = Self::to_field_ident(¶m_name_snake);
554 let param_type = self.get_param_rust_type(param);
555
556 // Query parameters should be Option unless explicitly required
557 if param.required {
558 params.push(quote! { #param_name: #param_type });
559 } else {
560 params.push(quote! { #param_name: Option<#param_type> });
561 }
562 }
563 }
564
565 // Add request body parameter based on content type
566 if let Some(ref rb) = op.request_body {
567 use crate::analysis::RequestBodyContent;
568 match rb {
569 RequestBodyContent::Json { schema_name }
570 | RequestBodyContent::FormUrlEncoded { schema_name } => {
571 let request_ident =
572 syn::Ident::new(schema_name, proc_macro2::Span::call_site());
573 params.push(quote! { request: #request_ident });
574 }
575 RequestBodyContent::Multipart => {
576 params.push(quote! { form: reqwest::multipart::Form });
577 }
578 RequestBodyContent::OctetStream => {
579 params.push(quote! { body: Vec<u8> });
580 }
581 RequestBodyContent::TextPlain => {
582 params.push(quote! { body: String });
583 }
584 }
585 }
586
587 if params.is_empty() {
588 quote! {}
589 } else {
590 quote! { #(#params),* }
591 }
592 }
593
594 /// Get the Rust type for a parameter
595 fn get_param_rust_type(&self, param: &crate::analysis::ParameterInfo) -> TokenStream {
596 let type_str = ¶m.rust_type;
597 match type_str.as_str() {
598 "String" => quote! { impl AsRef<str> },
599 "i64" => quote! { i64 },
600 "i32" => quote! { i32 },
601 "f64" => quote! { f64 },
602 "bool" => quote! { bool },
603 _ => {
604 let type_ident = syn::Ident::new(type_str, proc_macro2::Span::call_site());
605 quote! { #type_ident }
606 }
607 }
608 }
609
610 /// Generate request body serialization based on content type
611 fn generate_request_body(&self, op: &OperationInfo) -> TokenStream {
612 if let Some(ref rb) = op.request_body {
613 use crate::analysis::RequestBodyContent;
614 match rb {
615 RequestBodyContent::Json { .. } => {
616 quote! {
617 .body(serde_json::to_vec(&request).map_err(HttpError::serialization_error)?)
618 .header("content-type", "application/json")
619 }
620 }
621 RequestBodyContent::FormUrlEncoded { .. } => {
622 quote! {
623 .body(serde_urlencoded::to_string(&request).map_err(HttpError::serialization_error)?)
624 .header("content-type", "application/x-www-form-urlencoded")
625 }
626 }
627 RequestBodyContent::Multipart => {
628 quote! {
629 .multipart(form)
630 }
631 }
632 RequestBodyContent::OctetStream => {
633 quote! {
634 .body(body)
635 .header("content-type", "application/octet-stream")
636 }
637 }
638 RequestBodyContent::TextPlain => {
639 quote! {
640 .body(body)
641 .header("content-type", "text/plain")
642 }
643 }
644 }
645 } else {
646 quote! {}
647 }
648 }
649
650 /// Find the success (2xx) response schema name, if any.
651 ///
652 /// Only considers 2xx status codes. Error schemas (4xx, 5xx) are ignored
653 /// so that endpoints like 204 No Content correctly return `()` instead of
654 /// accidentally picking up the error schema (e.g. `BadRequestError`).
655 fn get_success_response_schema<'a>(&self, op: &'a OperationInfo) -> Option<&'a String> {
656 op.response_schemas
657 .get("200")
658 .or_else(|| op.response_schemas.get("201"))
659 .or_else(|| {
660 op.response_schemas
661 .iter()
662 .find(|(code, _)| code.starts_with('2'))
663 .map(|(_, v)| v)
664 })
665 }
666
667 /// Get response type
668 fn get_response_type(&self, op: &OperationInfo) -> TokenStream {
669 if let Some(response_type) = self.get_success_response_schema(op) {
670 // Convert schema name to Rust type name (handles underscores, etc.)
671 let rust_type_name = self.to_rust_type_name(response_type);
672 let response_ident = syn::Ident::new(&rust_type_name, proc_macro2::Span::call_site());
673 quote! { #response_ident }
674 } else {
675 quote! { () }
676 }
677 }
678
679 /// Generate error handling.
680 ///
681 /// When `has_response_body` is false the endpoint returns no JSON body
682 /// (e.g. 204 No Content) and we skip deserialization entirely.
683 fn generate_error_handling(&self, has_response_body: bool) -> TokenStream {
684 let success_branch = if has_response_body {
685 quote! {
686 let body = response.json().await
687 .map_err(HttpError::deserialization_error)?;
688 Ok(body)
689 }
690 } else {
691 quote! {
692 Ok(())
693 }
694 };
695
696 quote! {
697 let status = response.status();
698
699 if status.is_success() {
700 #success_branch
701 } else {
702 let status_code = status.as_u16();
703 let message = status.canonical_reason().unwrap_or("Unknown error");
704 let body = response.text().await.ok();
705 Err(HttpError::from_status(status_code, message, body))
706 }
707 }
708 }
709
710 /// Generate URL construction with path parameter substitution
711 fn generate_url_construction(&self, path: &str, op: &OperationInfo) -> TokenStream {
712 // Check if path has parameters (contains {...})
713 if path.contains('{') {
714 self.generate_url_with_params(path, op)
715 } else {
716 quote! {
717 let request_url = format!("{}{}", self.base_url, #path);
718 }
719 }
720 }
721
722 /// Generate URL with path parameters
723 fn generate_url_with_params(&self, path: &str, op: &OperationInfo) -> TokenStream {
724 // Parse path to find all parameter placeholders
725 let mut format_string = path.to_string();
726 let mut format_args = Vec::new();
727
728 // Find all path parameters in the operation
729 let path_params: Vec<_> = op
730 .parameters
731 .iter()
732 .filter(|p| p.location == "path")
733 .collect();
734
735 // Replace {paramName} with {} and collect parameter names for format args
736 for param in &path_params {
737 let placeholder = format!("{{{}}}", param.name);
738 if format_string.contains(&placeholder) {
739 format_string = format_string.replace(&placeholder, "{}");
740
741 // Use snake_case for the Rust variable name with keyword escaping
742 let param_name_snake = self.sanitize_param_name(¶m.name);
743 let param_ident = Self::to_field_ident(¶m_name_snake);
744
745 // Use .as_ref() for string types to handle impl AsRef<str>
746 if param.rust_type == "String" {
747 format_args.push(quote! { #param_ident.as_ref() });
748 } else {
749 format_args.push(quote! { #param_ident });
750 }
751 }
752 }
753
754 if format_args.is_empty() {
755 // No path parameters found, use simple format
756 quote! {
757 let request_url = format!("{}{}", self.base_url, #path);
758 }
759 } else {
760 // Build format call with path parameters
761 quote! {
762 let request_url = format!("{}{}", self.base_url, format!(#format_string, #(#format_args),*));
763 }
764 }
765 }
766
767 /// Sanitize a parameter name by escaping Rust reserved keywords with raw identifiers
768 fn sanitize_param_name(&self, name: &str) -> String {
769 let snake_case = name.to_snake_case();
770 if Self::is_rust_keyword(&snake_case) {
771 format!("r#{snake_case}")
772 } else {
773 snake_case
774 }
775 }
776}