Skip to main content

openapi_to_rust/
config.rs

1//! TOML configuration file support for OpenAPI code generation.
2//!
3//! This module provides TOML-based configuration as an alternative to the Rust API.
4//! It enables CLI-based code generation without requiring the generator as a build dependency.
5//!
6//! # Overview
7//!
8//! The TOML configuration system provides:
9//! - Declarative configuration in `openapi-to-rust.toml` files
10//! - Comprehensive validation with helpful error messages
11//! - Support for all generator features (HTTP client, retry, tracing, Specta)
12//! - Conversion to internal [`GeneratorConfig`] for code generation
13//!
14//! # Quick Start
15//!
16//! Create an `openapi-to-rust.toml` file:
17//!
18//! ```toml
19//! [generator]
20//! spec_path = "openapi.json"
21//! output_dir = "src/generated"
22//! module_name = "api"
23//!
24//! [features]
25//! enable_async_client = true
26//!
27//! [http_client]
28//! base_url = "https://api.example.com"
29//! timeout_seconds = 30
30//!
31//! [http_client.retry]
32//! max_retries = 3
33//! initial_delay_ms = 500
34//! max_delay_ms = 16000
35//! ```
36//!
37//! Load and use the configuration:
38//!
39//! ```no_run
40//! use openapi_to_rust::config::ConfigFile;
41//! use std::path::Path;
42//!
43//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
44//! // Load configuration from TOML file
45//! let config_file = ConfigFile::load(Path::new("openapi-to-rust.toml"))?;
46//!
47//! // Convert to internal GeneratorConfig
48//! let generator_config = config_file.into_generator_config();
49//!
50//! // Use with CodeGenerator...
51//! # Ok(())
52//! # }
53//! ```
54//!
55//! # Configuration Sections
56//!
57//! ## Generator Section (Required)
58//!
59//! ```toml
60//! [generator]
61//! spec_path = "openapi.json"       # Path to OpenAPI spec
62//! output_dir = "src/generated"     # Output directory
63//! module_name = "api"              # Module name
64//! ```
65//!
66//! ## Features Section (Optional)
67//!
68//! ```toml
69//! [features]
70//! enable_sse_client = true         # Generate SSE streaming client
71//! enable_async_client = true       # Generate HTTP REST client
72//! enable_specta = false            # Add specta::Type derives
73//! ```
74//!
75//! ## HTTP Client Section (Optional)
76//!
77//! ```toml
78//! [http_client]
79//! base_url = "https://api.example.com"
80//! timeout_seconds = 30
81//!
82//! [http_client.retry]
83//! max_retries = 3                  # 0-10 retries
84//! initial_delay_ms = 500           # 100-10000ms
85//! max_delay_ms = 16000             # 1000-300000ms
86//!
87//! [http_client.tracing]
88//! enabled = true                   # Enable request tracing (default: true)
89//!
90//! [http_client.auth]
91//! type = "Bearer"                  # Bearer, ApiKey, or Custom
92//! header_name = "Authorization"
93//!
94//! [[http_client.headers]]
95//! name = "content-type"
96//! value = "application/json"
97//! ```
98//!
99//! # Validation
100//!
101//! The configuration is validated on load using the `validator` crate:
//! - The OpenAPI spec path (`generator.spec_path`) is checked for existence
103//! - Numeric ranges are enforced (timeout, retry counts, delays)
104//! - Enum values are validated (auth types, event flow types)
105//! - Required fields are checked
106//!
107//! Invalid configurations produce helpful error messages:
108//!
109//! ```text
110//! Configuration validation failed:
111//!   - generator.spec_path: OpenAPI spec file not found: missing.json
112//!   - http_client.retry.max_retries: max_retries must be between 0 and 10
113//! ```
114//!
115//! # Examples
116//!
//! See the `examples` directory of the project repository for complete examples:
118//! - `toml_config_example.rs` - Various configuration patterns
119//! - `complete_workflow.rs` - Full generation workflow with TOML
120//!
121//! # Backward Compatibility
122//!
123//! The TOML configuration is fully optional. The existing Rust API continues to work:
124//!
125//! ```no_run
126//! use openapi_to_rust::{GeneratorConfig, CodeGenerator};
127//! use std::path::PathBuf;
128//!
129//! let config = GeneratorConfig {
130//!     spec_path: PathBuf::from("openapi.json"),
131//!     enable_async_client: true,
132//!     // ... other fields
133//!     ..Default::default()
134//! };
135//!
136//! let generator = CodeGenerator::new(config);
137//! // ... generate code
138//! ```
139
140use crate::{GeneratorError, generator::GeneratorConfig};
141use serde::{Deserialize, Serialize};
142use std::collections::BTreeMap;
143use std::path::{Path, PathBuf};
144use validator::Validate;
145
146/// Root configuration loaded from TOML file
147#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
148pub struct ConfigFile {
149    #[validate(nested)]
150    pub generator: GeneratorSection,
151    #[validate(nested)]
152    pub features: FeaturesSection,
153    #[serde(default)]
154    #[validate(nested)]
155    pub http_client: Option<HttpClientSection>,
156    #[serde(default)]
157    #[validate(nested)]
158    pub streaming: Option<StreamingSection>,
159    #[serde(default)]
160    pub nullable_overrides: BTreeMap<String, bool>,
161    #[serde(default)]
162    pub type_mappings: BTreeMap<String, String>,
163}
164
/// `[generator]` table: the OpenAPI spec to read, the output location and the
/// module name. This table is required.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct GeneratorSection {
    /// Path to the OpenAPI spec. Validation fails if the path does not exist.
    /// A relative path is resolved against the current working directory.
    #[validate(custom(function = "validate_spec_path_exists"))]
    pub spec_path: PathBuf,
    /// Output directory for generated code (not validated here).
    pub output_dir: PathBuf,
    /// Name of the generated module; must be non-empty.
    #[validate(length(min = 1, message = "module_name cannot be empty"))]
    pub module_name: String,
    /// Schema extension files to merge into the main spec before codegen.
    /// Paths are relative to the working directory (same as spec_path).
    #[serde(default)]
    pub schema_extensions: Vec<PathBuf>,
}
177
/// `[features]` table: boolean switches for optional generator output.
///
/// Every flag defaults to `false` when its key is absent.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct FeaturesSection {
    /// Generate the streaming (SSE) client.
    #[serde(default)]
    pub enable_sse_client: bool,
    /// Generate the HTTP REST client.
    #[serde(default)]
    pub enable_async_client: bool,
    /// Add `specta::Type` derives to generated types.
    #[serde(default)]
    pub enable_specta: bool,
    /// Generate a static operation registry with metadata for CLI/proxy routing
    #[serde(default)]
    pub enable_registry: bool,
    /// Generate only the operation registry (skip types, client, streaming)
    #[serde(default)]
    pub registry_only: bool,
}
193
/// `[http_client]` table: settings for the generated HTTP client.
///
/// Every field is optional. `ConfigFile::into_generator_config` splits this
/// table into an `HttpClientConfig`, a `RetryConfig`, an `AuthConfig` and a
/// tracing flag.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct HttpClientSection {
    /// Base URL of the API; if present it must be a valid URL.
    #[validate(url(message = "base_url must be a valid URL"))]
    pub base_url: Option<String>,
    /// Request timeout in seconds; if present it must be in 1..=3600.
    #[validate(custom(function = "validate_timeout_seconds"))]
    pub timeout_seconds: Option<u64>,
    /// Optional `[http_client.auth]` table.
    #[validate(nested)]
    pub auth: Option<AuthConfigSection>,
    /// `[[http_client.headers]]` entries, collected into the client's default
    /// headers. Defaults to no headers.
    #[serde(default)]
    #[validate(nested)]
    pub headers: Vec<HeaderEntry>,
    /// Optional `[http_client.retry]` table; when omitted no retry config is
    /// produced.
    #[validate(nested)]
    pub retry: Option<RetryConfigSection>,
    /// Optional `[http_client.tracing]` table; when omitted tracing stays
    /// enabled.
    #[validate(nested)]
    pub tracing: Option<TracingConfigSection>,
}
210
/// `[http_client.tracing]` table.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct TracingConfigSection {
    /// Whether request tracing is generated; defaults to `true` when the key
    /// is absent.
    #[serde(default = "default_tracing_enabled")]
    pub enabled: bool,
}

/// Serde default for `TracingConfigSection::enabled`: tracing is on.
fn default_tracing_enabled() -> bool {
    true
}
220
/// `[http_client.retry]` table.
///
/// Every field is optional and falls back to the `default_*` function named
/// in its `serde(default)` attribute.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct RetryConfigSection {
    /// Maximum number of retries; must be in 0..=10 (default 3).
    #[serde(default = "default_max_retries")]
    #[validate(custom(function = "validate_max_retries"))]
    pub max_retries: u32,
    /// Initial retry delay in milliseconds; must be in 100..=10000 (default 500).
    #[serde(default = "default_initial_delay_ms")]
    #[validate(custom(function = "validate_initial_delay_ms"))]
    pub initial_delay_ms: u64,
    /// Maximum retry delay in milliseconds; must be in 1000..=300000
    /// (default 16000).
    ///
    /// NOTE(review): nothing checks that `initial_delay_ms <= max_delay_ms`;
    /// confirm how the generated client behaves when the initial delay exceeds
    /// the maximum.
    #[serde(default = "default_max_delay_ms")]
    #[validate(custom(function = "validate_max_delay_ms"))]
    pub max_delay_ms: u64,
}

/// Serde default for `RetryConfigSection::max_retries`.
fn default_max_retries() -> u32 {
    3
}
/// Serde default for `RetryConfigSection::initial_delay_ms` (milliseconds).
fn default_initial_delay_ms() -> u64 {
    500
}
/// Serde default for `RetryConfigSection::max_delay_ms` (milliseconds).
fn default_max_delay_ms() -> u64 {
    16000
}
243
/// `[http_client.auth]` table describing how the generated client authenticates.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct AuthConfigSection {
    /// Written as `type` in TOML; must be exactly `Bearer`, `ApiKey` or `Custom`.
    #[serde(rename = "type")]
    #[validate(custom(function = "validate_auth_type"))]
    pub auth_type: String,
    /// Header that carries the credential; required and must be non-empty.
    #[validate(length(min = 1, message = "header_name cannot be empty"))]
    pub header_name: String,
}
252
/// One `[[http_client.headers]]` entry, converted into a default header of the
/// generated client.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct HeaderEntry {
    /// Header name; must be non-empty.
    #[validate(length(min = 1, message = "header name cannot be empty"))]
    pub name: String,
    /// Header value; not validated (may be empty).
    pub value: String,
}
259
/// `[streaming]` table: the list of streaming endpoints to generate.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct StreamingSection {
    /// `[[streaming.endpoints]]` entries; each one is validated individually.
    #[validate(nested)]
    pub endpoints: Vec<StreamingEndpointSection>,
}
265
/// One `[[streaming.endpoints]]` entry describing a streaming operation.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct StreamingEndpointSection {
    /// Operation ID of the endpoint; must be non-empty.
    #[validate(length(min = 1))]
    pub operation_id: String,
    /// Request path of the endpoint; must be non-empty.
    #[validate(length(min = 1))]
    pub path: String,
    /// HTTP method: "GET" or "POST" (default: POST)
    ///
    /// Not validated. During conversion any value other than `GET`
    /// (case-insensitive) is treated as POST.
    #[serde(default)]
    pub http_method: Option<String>,
    /// Parameter name that controls streaming (only for POST requests)
    ///
    /// Defaults to an empty string when absent.
    #[serde(default)]
    pub stream_parameter: String,
    /// Query parameters for GET requests
    #[serde(default)]
    pub query_parameters: Vec<QueryParameterSection>,
    /// Name of the event union type; must be non-empty.
    #[validate(length(min = 1))]
    pub event_union_type: String,
    /// Optional content type, passed through unchanged.
    pub content_type: Option<String>,
    /// Optional event flow description; when absent the endpoint uses the
    /// simple flow.
    #[validate(nested)]
    pub event_flow: Option<EventFlowSection>,
}
287
/// A query parameter accepted by a streaming endpoint.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct QueryParameterSection {
    /// Parameter name; must be non-empty.
    #[validate(length(min = 1))]
    pub name: String,
    /// Whether the parameter is required; defaults to `false`.
    #[serde(default)]
    pub required: bool,
}
295
/// `event_flow` table of a streaming endpoint.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct EventFlowSection {
    /// Written as `type` in TOML; must be exactly `StartDeltaStop` or `Continuous`.
    #[serde(rename = "type")]
    #[validate(custom(function = "validate_event_flow_type"))]
    pub flow_type: String,
    /// Event names for the start/delta/stop flow. Each list is optional and
    /// becomes empty when absent.
    pub start_events: Option<Vec<String>>,
    pub delta_events: Option<Vec<String>>,
    pub stop_events: Option<Vec<String>>,
}
305
306// Custom validators
307fn validate_spec_path_exists(path: &Path) -> Result<(), validator::ValidationError> {
308    if !path.exists() {
309        let mut error = validator::ValidationError::new("file_not_found");
310        error.message = Some(
311            format!(
312                "OpenAPI spec file not found: {}. Ensure spec_path points to a valid OpenAPI JSON or YAML file.",
313                path.display()
314            )
315            .into(),
316        );
317        return Err(error);
318    }
319    Ok(())
320}
321
322fn validate_auth_type(auth_type: &str) -> Result<(), validator::ValidationError> {
323    match auth_type {
324        "Bearer" | "ApiKey" | "Custom" => Ok(()),
325        _ => {
326            let mut error = validator::ValidationError::new("invalid_auth_type");
327            error.message = Some(
328                format!(
329                    "Invalid auth type '{}'. Must be one of: Bearer, ApiKey, Custom",
330                    auth_type
331                )
332                .into(),
333            );
334            Err(error)
335        }
336    }
337}
338
339fn validate_event_flow_type(flow_type: &str) -> Result<(), validator::ValidationError> {
340    match flow_type {
341        "StartDeltaStop" | "Continuous" => Ok(()),
342        _ => {
343            let mut error = validator::ValidationError::new("invalid_event_flow_type");
344            error.message = Some(
345                format!(
346                    "Invalid event flow type '{}'. Must be one of: StartDeltaStop, Continuous",
347                    flow_type
348                )
349                .into(),
350            );
351            Err(error)
352        }
353    }
354}
355
356fn validate_timeout_seconds(timeout: u64) -> Result<(), validator::ValidationError> {
357    if !(1..=3600).contains(&timeout) {
358        let mut error = validator::ValidationError::new("out_of_range");
359        error.message = Some("timeout_seconds must be between 1 and 3600".into());
360        return Err(error);
361    }
362    Ok(())
363}
364
365fn validate_max_retries(retries: u32) -> Result<(), validator::ValidationError> {
366    if retries > 10 {
367        let mut error = validator::ValidationError::new("out_of_range");
368        error.message = Some("max_retries must be between 0 and 10".into());
369        return Err(error);
370    }
371    Ok(())
372}
373
374fn validate_initial_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
375    if !(100..=10000).contains(&delay) {
376        let mut error = validator::ValidationError::new("out_of_range");
377        error.message = Some("initial_delay_ms must be between 100 and 10000".into());
378        return Err(error);
379    }
380    Ok(())
381}
382
383fn validate_max_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
384    if !(1000..=300000).contains(&delay) {
385        let mut error = validator::ValidationError::new("out_of_range");
386        error.message = Some("max_delay_ms must be between 1000 and 300000".into());
387        return Err(error);
388    }
389    Ok(())
390}
391
392impl ConfigFile {
393    /// Load and validate configuration from TOML file
394    pub fn load(path: &Path) -> Result<Self, GeneratorError> {
395        let content = std::fs::read_to_string(path).map_err(|e| GeneratorError::FileError {
396            message: format!("Failed to read config file '{}': {}", path.display(), e),
397        })?;
398
399        let config: ConfigFile =
400            toml::from_str(&content).map_err(|e| GeneratorError::FileError {
401                message: format!(
402                    "Failed to parse TOML config: {}\n\nExample config:\n{}",
403                    e, EXAMPLE_CONFIG
404                ),
405            })?;
406
407        // Validate the configuration
408        config.validate().map_err(|e| {
409            GeneratorError::ValidationError(format!(
410                "Configuration validation failed:\n{}",
411                format_validation_errors(&e)
412            ))
413        })?;
414
415        Ok(config)
416    }
417
418    /// Convert to internal GeneratorConfig
419    pub fn into_generator_config(self) -> GeneratorConfig {
420        use crate::http_config::{AuthConfig, HttpClientConfig, RetryConfig};
421
422        // Convert HTTP client config
423        let http_client_config = self.http_client.as_ref().map(|http| HttpClientConfig {
424            base_url: http.base_url.clone(),
425            timeout_seconds: http.timeout_seconds,
426            default_headers: http
427                .headers
428                .iter()
429                .map(|h| (h.name.clone(), h.value.clone()))
430                .collect(),
431        });
432
433        // Convert retry config
434        let retry_config = self
435            .http_client
436            .as_ref()
437            .and_then(|http| http.retry.as_ref())
438            .map(|retry| RetryConfig {
439                max_retries: retry.max_retries,
440                initial_delay_ms: retry.initial_delay_ms,
441                max_delay_ms: retry.max_delay_ms,
442            });
443
444        // Convert tracing config
445        let tracing_enabled = self
446            .http_client
447            .as_ref()
448            .and_then(|http| http.tracing.as_ref())
449            .map(|tracing| tracing.enabled)
450            .unwrap_or(true);
451
452        // Convert auth config
453        let auth_config = self
454            .http_client
455            .as_ref()
456            .and_then(|http| http.auth.as_ref())
457            .map(|auth| match auth.auth_type.as_str() {
458                "Bearer" => AuthConfig::Bearer {
459                    header_name: auth.header_name.clone(),
460                },
461                "ApiKey" => AuthConfig::ApiKey {
462                    header_name: auth.header_name.clone(),
463                },
464                "Custom" => AuthConfig::Custom {
465                    header_name: auth.header_name.clone(),
466                    header_value_prefix: None,
467                },
468                _ => AuthConfig::Bearer {
469                    header_name: "Authorization".to_string(),
470                },
471            });
472
473        // Convert streaming section to StreamingConfig
474        let streaming_config = self.streaming.map(|section| {
475            use crate::streaming::{
476                EventFlow, HttpMethod, QueryParameter, StreamingConfig, StreamingEndpoint,
477            };
478
479            let endpoints = section
480                .endpoints
481                .into_iter()
482                .map(|e| {
483                    let event_flow = e
484                        .event_flow
485                        .map(|ef| match ef.flow_type.as_str() {
486                            "start_delta_stop" => EventFlow::StartDeltaStop {
487                                start_events: ef.start_events.unwrap_or_default(),
488                                delta_events: ef.delta_events.unwrap_or_default(),
489                                stop_events: ef.stop_events.unwrap_or_default(),
490                            },
491                            _ => EventFlow::Simple,
492                        })
493                        .unwrap_or(EventFlow::Simple);
494
495                    let http_method = e
496                        .http_method
497                        .map(|m| match m.to_uppercase().as_str() {
498                            "GET" => HttpMethod::Get,
499                            _ => HttpMethod::Post,
500                        })
501                        .unwrap_or(HttpMethod::Post);
502
503                    let query_parameters = e
504                        .query_parameters
505                        .into_iter()
506                        .map(|qp| QueryParameter {
507                            name: qp.name,
508                            required: qp.required,
509                        })
510                        .collect();
511
512                    StreamingEndpoint {
513                        operation_id: e.operation_id,
514                        path: e.path,
515                        http_method,
516                        stream_parameter: e.stream_parameter,
517                        query_parameters,
518                        event_union_type: e.event_union_type,
519                        content_type: e.content_type,
520                        event_flow,
521                        ..Default::default()
522                    }
523                })
524                .collect();
525
526            StreamingConfig {
527                endpoints,
528                ..Default::default()
529            }
530        });
531
532        GeneratorConfig {
533            spec_path: self.generator.spec_path,
534            output_dir: self.generator.output_dir,
535            module_name: self.generator.module_name,
536            enable_sse_client: self.features.enable_sse_client,
537            enable_async_client: self.features.enable_async_client,
538            enable_specta: self.features.enable_specta,
539            type_mappings: if self.type_mappings.is_empty() {
540                super::generator::default_type_mappings()
541            } else {
542                self.type_mappings
543            },
544            streaming_config,
545            nullable_field_overrides: self.nullable_overrides,
546            schema_extensions: self.generator.schema_extensions,
547            http_client_config,
548            retry_config,
549            tracing_enabled,
550            auth_config,
551            enable_registry: self.features.enable_registry,
552            registry_only: self.features.registry_only,
553        }
554    }
555}
556
557/// Format validator errors into a readable message
558fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
559    let mut messages = Vec::new();
560
561    // Handle direct field errors
562    for (field, field_errors) in errors.field_errors() {
563        for error in field_errors {
564            let msg = if let Some(message) = &error.message {
565                format!("  - {}: {}", field, message)
566            } else {
567                format!("  - {}: validation failed (code: {})", field, error.code)
568            };
569            messages.push(msg);
570        }
571    }
572
573    // Handle nested errors
574    for (field, nested_errors) in errors.errors() {
575        if let validator::ValidationErrorsKind::Struct(struct_errors) = nested_errors {
576            let nested_msgs = format_validation_errors(struct_errors);
577            if !nested_msgs.is_empty() {
578                messages.push(format!("  - {} (nested):\n{}", field, nested_msgs));
579            }
580        }
581    }
582
583    messages.join("\n")
584}
585
/// Minimal example configuration appended to TOML parse errors in
/// [`ConfigFile::load`] to show the expected layout.
const EXAMPLE_CONFIG: &str = r#"[generator]
spec_path = "openapi.json"
output_dir = "src/generated"
module_name = "types"

[features]
enable_async_client = true

[http_client]
base_url = "https://api.example.com"
timeout_seconds = 30

[http_client.retry]
max_retries = 3

[http_client.auth]
type = "Bearer"
header_name = "Authorization""#;