Skip to main content

openapi_to_rust/
config.rs

1//! TOML configuration file support for OpenAPI code generation.
2//!
3//! This module provides TOML-based configuration as an alternative to the Rust API.
4//! It enables CLI-based code generation without requiring the generator as a build dependency.
5//!
6//! # Overview
7//!
8//! The TOML configuration system provides:
9//! - Declarative configuration in `openapi-to-rust.toml` files
10//! - Comprehensive validation with helpful error messages
11//! - Support for all generator features (HTTP client, retry, tracing, Specta)
12//! - Conversion to internal [`GeneratorConfig`] for code generation
13//!
14//! # Quick Start
15//!
16//! Create an `openapi-to-rust.toml` file:
17//!
18//! ```toml
19//! [generator]
20//! spec_path = "openapi.json"
21//! output_dir = "src/generated"
22//! module_name = "api"
23//!
24//! [features]
25//! enable_async_client = true
26//!
27//! [http_client]
28//! base_url = "https://api.example.com"
29//! timeout_seconds = 30
30//!
31//! [http_client.retry]
32//! max_retries = 3
33//! initial_delay_ms = 500
34//! max_delay_ms = 16000
35//! ```
36//!
37//! Load and use the configuration:
38//!
39//! ```no_run
40//! use openapi_to_rust::config::ConfigFile;
41//! use std::path::Path;
42//!
43//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
44//! // Load configuration from TOML file
45//! let config_file = ConfigFile::load(Path::new("openapi-to-rust.toml"))?;
46//!
47//! // Convert to internal GeneratorConfig
48//! let generator_config = config_file.into_generator_config();
49//!
50//! // Use with CodeGenerator...
51//! # Ok(())
52//! # }
53//! ```
54//!
55//! # Configuration Sections
56//!
57//! ## Generator Section (Required)
58//!
59//! ```toml
60//! [generator]
61//! spec_path = "openapi.json"       # Path to OpenAPI spec
62//! output_dir = "src/generated"     # Output directory
63//! module_name = "api"              # Module name
64//! ```
65//!
66//! ## Features Section (Optional)
67//!
68//! ```toml
69//! [features]
70//! enable_sse_client = true         # Generate SSE streaming client
71//! enable_async_client = true       # Generate HTTP REST client
72//! enable_specta = false            # Add specta::Type derives
73//! ```
74//!
75//! ## HTTP Client Section (Optional)
76//!
77//! ```toml
78//! [http_client]
79//! base_url = "https://api.example.com"
80//! timeout_seconds = 30
81//!
82//! [http_client.retry]
83//! max_retries = 3                  # 0-10 retries
84//! initial_delay_ms = 500           # 100-10000ms
85//! max_delay_ms = 16000             # 1000-300000ms
86//!
87//! [http_client.tracing]
88//! enabled = true                   # Enable request tracing (default: true)
89//!
90//! [http_client.auth]
91//! type = "Bearer"                  # Bearer, ApiKey, or Custom
92//! header_name = "Authorization"
93//!
94//! [[http_client.headers]]
95//! name = "content-type"
96//! value = "application/json"
97//! ```
98//!
99//! # Validation
100//!
101//! The configuration is validated on load using the `validator` crate:
102//! - File paths are checked for existence
103//! - Numeric ranges are enforced (timeout, retry counts, delays)
104//! - Enum values are validated (auth types, event flow types)
105//! - Required fields are checked
106//!
107//! Invalid configurations produce helpful error messages:
108//!
109//! ```text
110//! Configuration validation failed:
111//!   - generator.spec_path: OpenAPI spec file not found: missing.json
112//!   - http_client.retry.max_retries: max_retries must be between 0 and 10
113//! ```
114//!
115//! # Examples
116//!
117//! See the [examples](https://github.com/your-repo/examples) directory for complete examples:
118//! - `toml_config_example.rs` - Various configuration patterns
119//! - `complete_workflow.rs` - Full generation workflow with TOML
120//!
121//! # Backward Compatibility
122//!
123//! The TOML configuration is fully optional. The existing Rust API continues to work:
124//!
125//! ```no_run
126//! use openapi_to_rust::{GeneratorConfig, CodeGenerator};
127//! use std::path::PathBuf;
128//!
129//! let config = GeneratorConfig {
130//!     spec_path: PathBuf::from("openapi.json"),
131//!     enable_async_client: true,
132//!     // ... other fields
133//!     ..Default::default()
134//! };
135//!
136//! let generator = CodeGenerator::new(config);
137//! // ... generate code
138//! ```
139
140use crate::{GeneratorError, generator::GeneratorConfig};
141use serde::{Deserialize, Serialize};
142use std::collections::BTreeMap;
143use std::path::{Path, PathBuf};
144use validator::Validate;
145
146/// Root configuration loaded from TOML file
147#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
148pub struct ConfigFile {
149    #[validate(nested)]
150    pub generator: GeneratorSection,
151    #[validate(nested)]
152    pub features: FeaturesSection,
153    #[serde(default)]
154    #[validate(nested)]
155    pub http_client: Option<HttpClientSection>,
156    #[serde(default)]
157    #[validate(nested)]
158    pub streaming: Option<StreamingSection>,
159    #[serde(default)]
160    pub nullable_overrides: BTreeMap<String, bool>,
161    #[serde(default)]
162    pub type_mappings: BTreeMap<String, String>,
163}
164
165#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
166pub struct GeneratorSection {
167    #[validate(custom(function = "validate_spec_path_exists"))]
168    pub spec_path: PathBuf,
169    pub output_dir: PathBuf,
170    #[validate(length(min = 1, message = "module_name cannot be empty"))]
171    pub module_name: String,
172}
173
174#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
175pub struct FeaturesSection {
176    #[serde(default)]
177    pub enable_sse_client: bool,
178    #[serde(default)]
179    pub enable_async_client: bool,
180    #[serde(default)]
181    pub enable_specta: bool,
182}
183
184#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
185pub struct HttpClientSection {
186    #[validate(url(message = "base_url must be a valid URL"))]
187    pub base_url: Option<String>,
188    #[validate(custom(function = "validate_timeout_seconds"))]
189    pub timeout_seconds: Option<u64>,
190    #[validate(nested)]
191    pub auth: Option<AuthConfigSection>,
192    #[serde(default)]
193    #[validate(nested)]
194    pub headers: Vec<HeaderEntry>,
195    #[validate(nested)]
196    pub retry: Option<RetryConfigSection>,
197    #[validate(nested)]
198    pub tracing: Option<TracingConfigSection>,
199}
200
201#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
202pub struct TracingConfigSection {
203    #[serde(default = "default_tracing_enabled")]
204    pub enabled: bool,
205}
206
/// Serde default for `TracingConfigSection::enabled`: tracing is on unless disabled.
fn default_tracing_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
210
211#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
212pub struct RetryConfigSection {
213    #[serde(default = "default_max_retries")]
214    #[validate(custom(function = "validate_max_retries"))]
215    pub max_retries: u32,
216    #[serde(default = "default_initial_delay_ms")]
217    #[validate(custom(function = "validate_initial_delay_ms"))]
218    pub initial_delay_ms: u64,
219    #[serde(default = "default_max_delay_ms")]
220    #[validate(custom(function = "validate_max_delay_ms"))]
221    pub max_delay_ms: u64,
222}
223
/// Serde default for `RetryConfigSection::max_retries`.
fn default_max_retries() -> u32 {
    const MAX_RETRIES: u32 = 3;
    MAX_RETRIES
}

/// Serde default for `RetryConfigSection::initial_delay_ms`.
fn default_initial_delay_ms() -> u64 {
    const INITIAL_DELAY_MS: u64 = 500;
    INITIAL_DELAY_MS
}

/// Serde default for `RetryConfigSection::max_delay_ms`.
fn default_max_delay_ms() -> u64 {
    const MAX_DELAY_MS: u64 = 16_000;
    MAX_DELAY_MS
}
233
234#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
235pub struct AuthConfigSection {
236    #[serde(rename = "type")]
237    #[validate(custom(function = "validate_auth_type"))]
238    pub auth_type: String,
239    #[validate(length(min = 1, message = "header_name cannot be empty"))]
240    pub header_name: String,
241}
242
243#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
244pub struct HeaderEntry {
245    #[validate(length(min = 1, message = "header name cannot be empty"))]
246    pub name: String,
247    pub value: String,
248}
249
250#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
251pub struct StreamingSection {
252    #[validate(nested)]
253    pub endpoints: Vec<StreamingEndpointSection>,
254}
255
256#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
257pub struct StreamingEndpointSection {
258    #[validate(length(min = 1))]
259    pub operation_id: String,
260    #[validate(length(min = 1))]
261    pub path: String,
262    /// HTTP method: "GET" or "POST" (default: POST)
263    #[serde(default)]
264    pub http_method: Option<String>,
265    /// Parameter name that controls streaming (only for POST requests)
266    #[serde(default)]
267    pub stream_parameter: String,
268    /// Query parameters for GET requests
269    #[serde(default)]
270    pub query_parameters: Vec<QueryParameterSection>,
271    #[validate(length(min = 1))]
272    pub event_union_type: String,
273    pub content_type: Option<String>,
274    #[validate(nested)]
275    pub event_flow: Option<EventFlowSection>,
276}
277
278#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
279pub struct QueryParameterSection {
280    #[validate(length(min = 1))]
281    pub name: String,
282    #[serde(default)]
283    pub required: bool,
284}
285
286#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
287pub struct EventFlowSection {
288    #[serde(rename = "type")]
289    #[validate(custom(function = "validate_event_flow_type"))]
290    pub flow_type: String,
291    pub start_events: Option<Vec<String>>,
292    pub delta_events: Option<Vec<String>>,
293    pub stop_events: Option<Vec<String>>,
294}
295
296// Custom validators
297fn validate_spec_path_exists(path: &Path) -> Result<(), validator::ValidationError> {
298    if !path.exists() {
299        let mut error = validator::ValidationError::new("file_not_found");
300        error.message = Some(
301            format!(
302                "OpenAPI spec file not found: {}. Ensure spec_path points to a valid OpenAPI JSON or YAML file.",
303                path.display()
304            )
305            .into(),
306        );
307        return Err(error);
308    }
309    Ok(())
310}
311
312fn validate_auth_type(auth_type: &str) -> Result<(), validator::ValidationError> {
313    match auth_type {
314        "Bearer" | "ApiKey" | "Custom" => Ok(()),
315        _ => {
316            let mut error = validator::ValidationError::new("invalid_auth_type");
317            error.message = Some(
318                format!(
319                    "Invalid auth type '{}'. Must be one of: Bearer, ApiKey, Custom",
320                    auth_type
321                )
322                .into(),
323            );
324            Err(error)
325        }
326    }
327}
328
329fn validate_event_flow_type(flow_type: &str) -> Result<(), validator::ValidationError> {
330    match flow_type {
331        "StartDeltaStop" | "Continuous" => Ok(()),
332        _ => {
333            let mut error = validator::ValidationError::new("invalid_event_flow_type");
334            error.message = Some(
335                format!(
336                    "Invalid event flow type '{}'. Must be one of: StartDeltaStop, Continuous",
337                    flow_type
338                )
339                .into(),
340            );
341            Err(error)
342        }
343    }
344}
345
346fn validate_timeout_seconds(timeout: u64) -> Result<(), validator::ValidationError> {
347    if !(1..=3600).contains(&timeout) {
348        let mut error = validator::ValidationError::new("out_of_range");
349        error.message = Some("timeout_seconds must be between 1 and 3600".into());
350        return Err(error);
351    }
352    Ok(())
353}
354
355fn validate_max_retries(retries: u32) -> Result<(), validator::ValidationError> {
356    if retries > 10 {
357        let mut error = validator::ValidationError::new("out_of_range");
358        error.message = Some("max_retries must be between 0 and 10".into());
359        return Err(error);
360    }
361    Ok(())
362}
363
364fn validate_initial_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
365    if !(100..=10000).contains(&delay) {
366        let mut error = validator::ValidationError::new("out_of_range");
367        error.message = Some("initial_delay_ms must be between 100 and 10000".into());
368        return Err(error);
369    }
370    Ok(())
371}
372
373fn validate_max_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
374    if !(1000..=300000).contains(&delay) {
375        let mut error = validator::ValidationError::new("out_of_range");
376        error.message = Some("max_delay_ms must be between 1000 and 300000".into());
377        return Err(error);
378    }
379    Ok(())
380}
381
382impl ConfigFile {
383    /// Load and validate configuration from TOML file
384    pub fn load(path: &Path) -> Result<Self, GeneratorError> {
385        let content = std::fs::read_to_string(path).map_err(|e| GeneratorError::FileError {
386            message: format!("Failed to read config file '{}': {}", path.display(), e),
387        })?;
388
389        let config: ConfigFile =
390            toml::from_str(&content).map_err(|e| GeneratorError::FileError {
391                message: format!(
392                    "Failed to parse TOML config: {}\n\nExample config:\n{}",
393                    e, EXAMPLE_CONFIG
394                ),
395            })?;
396
397        // Validate the configuration
398        config.validate().map_err(|e| {
399            GeneratorError::ValidationError(format!(
400                "Configuration validation failed:\n{}",
401                format_validation_errors(&e)
402            ))
403        })?;
404
405        Ok(config)
406    }
407
408    /// Convert to internal GeneratorConfig
409    pub fn into_generator_config(self) -> GeneratorConfig {
410        use crate::http_config::{AuthConfig, HttpClientConfig, RetryConfig};
411
412        // Convert HTTP client config
413        let http_client_config = self.http_client.as_ref().map(|http| HttpClientConfig {
414            base_url: http.base_url.clone(),
415            timeout_seconds: http.timeout_seconds,
416            default_headers: http
417                .headers
418                .iter()
419                .map(|h| (h.name.clone(), h.value.clone()))
420                .collect(),
421        });
422
423        // Convert retry config
424        let retry_config = self
425            .http_client
426            .as_ref()
427            .and_then(|http| http.retry.as_ref())
428            .map(|retry| RetryConfig {
429                max_retries: retry.max_retries,
430                initial_delay_ms: retry.initial_delay_ms,
431                max_delay_ms: retry.max_delay_ms,
432            });
433
434        // Convert tracing config
435        let tracing_enabled = self
436            .http_client
437            .as_ref()
438            .and_then(|http| http.tracing.as_ref())
439            .map(|tracing| tracing.enabled)
440            .unwrap_or(true);
441
442        // Convert auth config
443        let auth_config = self
444            .http_client
445            .as_ref()
446            .and_then(|http| http.auth.as_ref())
447            .map(|auth| match auth.auth_type.as_str() {
448                "Bearer" => AuthConfig::Bearer {
449                    header_name: auth.header_name.clone(),
450                },
451                "ApiKey" => AuthConfig::ApiKey {
452                    header_name: auth.header_name.clone(),
453                },
454                "Custom" => AuthConfig::Custom {
455                    header_name: auth.header_name.clone(),
456                    header_value_prefix: None,
457                },
458                _ => AuthConfig::Bearer {
459                    header_name: "Authorization".to_string(),
460                },
461            });
462
463        // Convert streaming section to StreamingConfig
464        let streaming_config = self.streaming.map(|section| {
465            use crate::streaming::{
466                EventFlow, HttpMethod, QueryParameter, StreamingConfig, StreamingEndpoint,
467            };
468
469            let endpoints = section
470                .endpoints
471                .into_iter()
472                .map(|e| {
473                    let event_flow = e
474                        .event_flow
475                        .map(|ef| match ef.flow_type.as_str() {
476                            "start_delta_stop" => EventFlow::StartDeltaStop {
477                                start_events: ef.start_events.unwrap_or_default(),
478                                delta_events: ef.delta_events.unwrap_or_default(),
479                                stop_events: ef.stop_events.unwrap_or_default(),
480                            },
481                            _ => EventFlow::Simple,
482                        })
483                        .unwrap_or(EventFlow::Simple);
484
485                    let http_method = e
486                        .http_method
487                        .map(|m| match m.to_uppercase().as_str() {
488                            "GET" => HttpMethod::Get,
489                            _ => HttpMethod::Post,
490                        })
491                        .unwrap_or(HttpMethod::Post);
492
493                    let query_parameters = e
494                        .query_parameters
495                        .into_iter()
496                        .map(|qp| QueryParameter {
497                            name: qp.name,
498                            required: qp.required,
499                        })
500                        .collect();
501
502                    StreamingEndpoint {
503                        operation_id: e.operation_id,
504                        path: e.path,
505                        http_method,
506                        stream_parameter: e.stream_parameter,
507                        query_parameters,
508                        event_union_type: e.event_union_type,
509                        content_type: e.content_type,
510                        event_flow,
511                        ..Default::default()
512                    }
513                })
514                .collect();
515
516            StreamingConfig {
517                endpoints,
518                ..Default::default()
519            }
520        });
521
522        GeneratorConfig {
523            spec_path: self.generator.spec_path,
524            output_dir: self.generator.output_dir,
525            module_name: self.generator.module_name,
526            enable_sse_client: self.features.enable_sse_client,
527            enable_async_client: self.features.enable_async_client,
528            enable_specta: self.features.enable_specta,
529            type_mappings: if self.type_mappings.is_empty() {
530                super::generator::default_type_mappings()
531            } else {
532                self.type_mappings
533            },
534            streaming_config,
535            nullable_field_overrides: self.nullable_overrides,
536            schema_extensions: vec![],
537            http_client_config,
538            retry_config,
539            tracing_enabled,
540            auth_config,
541        }
542    }
543}
544
545/// Format validator errors into a readable message
546fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
547    let mut messages = Vec::new();
548
549    // Handle direct field errors
550    for (field, field_errors) in errors.field_errors() {
551        for error in field_errors {
552            let msg = if let Some(message) = &error.message {
553                format!("  - {}: {}", field, message)
554            } else {
555                format!("  - {}: validation failed (code: {})", field, error.code)
556            };
557            messages.push(msg);
558        }
559    }
560
561    // Handle nested errors
562    for (field, nested_errors) in errors.errors() {
563        if let validator::ValidationErrorsKind::Struct(struct_errors) = nested_errors {
564            let nested_msgs = format_validation_errors(struct_errors);
565            if !nested_msgs.is_empty() {
566                messages.push(format!("  - {} (nested):\n{}", field, nested_msgs));
567            }
568        }
569    }
570
571    messages.join("\n")
572}
573
574const EXAMPLE_CONFIG: &str = r#"[generator]
575spec_path = "openapi.json"
576output_dir = "src/generated"
577module_name = "types"
578
579[features]
580enable_async_client = true
581
582[http_client]
583base_url = "https://api.example.com"
584timeout_seconds = 30
585
586[http_client.retry]
587max_retries = 3
588
589[http_client.auth]
590type = "Bearer"
591header_name = "Authorization""#;