1use crate::{GeneratorError, generator::GeneratorConfig};
141use serde::{Deserialize, Serialize};
142use std::collections::BTreeMap;
143use std::path::{Path, PathBuf};
144use validator::Validate;
145
/// Top-level structure of the generator's TOML configuration file.
///
/// Loaded and validated by [`ConfigFile::load`], then turned into a
/// [`GeneratorConfig`] by [`ConfigFile::into_generator_config`]. Fields marked
/// `#[validate(nested)]` are validated recursively when `validate()` runs.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct ConfigFile {
    /// The `[generator]` table (required).
    #[validate(nested)]
    pub generator: GeneratorSection,
    /// The `[features]` table. The table itself is required (no serde default),
    /// although every flag inside it defaults to `false`.
    #[validate(nested)]
    pub features: FeaturesSection,
    /// Optional `[http_client]` table.
    #[serde(default)]
    #[validate(nested)]
    pub http_client: Option<HttpClientSection>,
    /// Optional `[streaming]` table.
    #[serde(default)]
    #[validate(nested)]
    pub streaming: Option<StreamingSection>,
    /// Nullability overrides, forwarded unchanged as
    /// `GeneratorConfig::nullable_field_overrides`.
    /// NOTE(review): the key format is defined by the generator — confirm there.
    #[serde(default)]
    pub nullable_overrides: BTreeMap<String, bool>,
    /// Custom type mappings. When empty, `into_generator_config` substitutes
    /// `generator::default_type_mappings()`; when non-empty, this map replaces
    /// the defaults entirely (it is not merged with them).
    #[serde(default)]
    pub type_mappings: BTreeMap<String, String>,
}
164
/// The `[generator]` table: input spec, output location and module name.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct GeneratorSection {
    /// Path to the OpenAPI spec. Validation fails with `file_not_found` if
    /// nothing exists at this path. A relative path is checked against the
    /// process's current working directory, not the config file's directory.
    #[validate(custom(function = "validate_spec_path_exists"))]
    pub spec_path: PathBuf,
    /// Forwarded unchanged as `GeneratorConfig::output_dir`; not validated here.
    pub output_dir: PathBuf,
    /// Name of the generated module; must be non-empty.
    #[validate(length(min = 1, message = "module_name cannot be empty"))]
    pub module_name: String,
    /// Forwarded unchanged as `GeneratorConfig::schema_extensions`; defaults to
    /// empty. NOTE(review): how these paths are consumed is defined elsewhere.
    #[serde(default)]
    pub schema_extensions: Vec<PathBuf>,
}
177
/// The `[features]` table: boolean switches, each forwarded unchanged to the
/// identically named `GeneratorConfig` field.
///
/// Every flag defaults to `false` when omitted. `Validate` is derived for
/// uniformity with the other sections, but no field carries a rule.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct FeaturesSection {
    /// Forwarded as `GeneratorConfig::enable_sse_client`.
    #[serde(default)]
    pub enable_sse_client: bool,
    /// Forwarded as `GeneratorConfig::enable_async_client`.
    #[serde(default)]
    pub enable_async_client: bool,
    /// Forwarded as `GeneratorConfig::enable_specta`.
    #[serde(default)]
    pub enable_specta: bool,
    /// Forwarded as `GeneratorConfig::enable_registry`.
    #[serde(default)]
    pub enable_registry: bool,
    /// Forwarded as `GeneratorConfig::registry_only`.
    /// NOTE(review): presumably only meaningful together with `enable_registry` — confirm.
    #[serde(default)]
    pub registry_only: bool,
}
193
/// The optional `[http_client]` table: settings for the generated HTTP client.
///
/// `Option` fields are validated only when present.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct HttpClientSection {
    /// Base URL; when set it must be a valid URL.
    #[validate(url(message = "base_url must be a valid URL"))]
    pub base_url: Option<String>,
    /// Request timeout in seconds; when set it must lie within 1..=3600.
    #[validate(custom(function = "validate_timeout_seconds"))]
    pub timeout_seconds: Option<u64>,
    /// Optional `[http_client.auth]` table.
    #[validate(nested)]
    pub auth: Option<AuthConfigSection>,
    /// Static headers, forwarded as `HttpClientConfig::default_headers`;
    /// defaults to empty.
    #[serde(default)]
    #[validate(nested)]
    pub headers: Vec<HeaderEntry>,
    /// Optional `[http_client.retry]` table.
    #[validate(nested)]
    pub retry: Option<RetryConfigSection>,
    /// Optional `[http_client.tracing]` table. When absent,
    /// `into_generator_config` treats tracing as enabled.
    #[validate(nested)]
    pub tracing: Option<TracingConfigSection>,
}
210
/// The `[http_client.tracing]` table.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct TracingConfigSection {
    /// Whether tracing is enabled; defaults to `true` when the key is omitted.
    #[serde(default = "default_tracing_enabled")]
    pub enabled: bool,
}
216
/// Serde default for `TracingConfigSection::enabled`: tracing is on unless
/// the config explicitly turns it off.
fn default_tracing_enabled() -> bool {
    let enabled_by_default = true;
    enabled_by_default
}
220
/// The `[http_client.retry]` table: retry count and backoff delay bounds.
///
/// Every field has a serde default, so an empty table is valid.
/// NOTE(review): the ranges below are checked independently, so a config with
/// `initial_delay_ms` greater than `max_delay_ms` (e.g. 10000 vs 1000) passes
/// validation; consider a struct-level check.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct RetryConfigSection {
    /// Maximum number of retries, 0..=10; defaults to 3.
    #[serde(default = "default_max_retries")]
    #[validate(custom(function = "validate_max_retries"))]
    pub max_retries: u32,
    /// First backoff delay in milliseconds, 100..=10000; defaults to 500.
    #[serde(default = "default_initial_delay_ms")]
    #[validate(custom(function = "validate_initial_delay_ms"))]
    pub initial_delay_ms: u64,
    /// Upper bound on the backoff delay in milliseconds, 1000..=300000;
    /// defaults to 16000.
    #[serde(default = "default_max_delay_ms")]
    #[validate(custom(function = "validate_max_delay_ms"))]
    pub max_delay_ms: u64,
}
233
/// Serde default for `RetryConfigSection::max_retries`.
fn default_max_retries() -> u32 {
    3_u32
}

/// Serde default for `RetryConfigSection::initial_delay_ms` (half a second).
fn default_initial_delay_ms() -> u64 {
    500_u64
}

/// Serde default for `RetryConfigSection::max_delay_ms` (sixteen seconds).
fn default_max_delay_ms() -> u64 {
    16_000_u64
}
243
/// The `[http_client.auth]` table.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct AuthConfigSection {
    /// Written as `type` in TOML; must be exactly `Bearer`, `ApiKey` or
    /// `Custom` (case-sensitive, see `validate_auth_type`).
    #[serde(rename = "type")]
    #[validate(custom(function = "validate_auth_type"))]
    pub auth_type: String,
    /// Header name passed to the selected `AuthConfig` variant. It is required
    /// and must be non-empty for every auth type.
    #[validate(length(min = 1, message = "header_name cannot be empty"))]
    pub header_name: String,
}
252
/// One static header sent by the generated client (`http_client.headers`).
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct HeaderEntry {
    /// Header name; must be non-empty.
    #[validate(length(min = 1, message = "header name cannot be empty"))]
    pub name: String,
    /// Header value; not validated (may be empty).
    pub value: String,
}
259
/// The `[streaming]` table. When the table is present, `endpoints` is
/// required (no serde default) and each entry is validated.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct StreamingSection {
    /// Streaming endpoint definitions.
    #[validate(nested)]
    pub endpoints: Vec<StreamingEndpointSection>,
}
265
266#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
267pub struct StreamingEndpointSection {
268 #[validate(length(min = 1))]
269 pub operation_id: String,
270 #[validate(length(min = 1))]
271 pub path: String,
272 #[serde(default)]
274 pub http_method: Option<String>,
275 #[serde(default)]
277 pub stream_parameter: String,
278 #[serde(default)]
280 pub query_parameters: Vec<QueryParameterSection>,
281 #[validate(length(min = 1))]
282 pub event_union_type: String,
283 pub content_type: Option<String>,
284 #[validate(nested)]
285 pub event_flow: Option<EventFlowSection>,
286}
287
288#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
289pub struct QueryParameterSection {
290 #[validate(length(min = 1))]
291 pub name: String,
292 #[serde(default)]
293 pub required: bool,
294}
295
/// The `event_flow` table of a streaming endpoint.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct EventFlowSection {
    /// Written as `type` in TOML; must be exactly `StartDeltaStop` or
    /// `Continuous` (see `validate_event_flow_type`).
    #[serde(rename = "type")]
    #[validate(custom(function = "validate_event_flow_type"))]
    pub flow_type: String,
    /// Event names that start a message; absent is treated as empty by
    /// `into_generator_config`.
    pub start_events: Option<Vec<String>>,
    /// Event names carrying incremental content; absent is treated as empty.
    pub delta_events: Option<Vec<String>>,
    /// Event names that end a message; absent is treated as empty.
    pub stop_events: Option<Vec<String>>,
}
305
306fn validate_spec_path_exists(path: &Path) -> Result<(), validator::ValidationError> {
308 if !path.exists() {
309 let mut error = validator::ValidationError::new("file_not_found");
310 error.message = Some(
311 format!(
312 "OpenAPI spec file not found: {}. Ensure spec_path points to a valid OpenAPI JSON or YAML file.",
313 path.display()
314 )
315 .into(),
316 );
317 return Err(error);
318 }
319 Ok(())
320}
321
322fn validate_auth_type(auth_type: &str) -> Result<(), validator::ValidationError> {
323 match auth_type {
324 "Bearer" | "ApiKey" | "Custom" => Ok(()),
325 _ => {
326 let mut error = validator::ValidationError::new("invalid_auth_type");
327 error.message = Some(
328 format!(
329 "Invalid auth type '{}'. Must be one of: Bearer, ApiKey, Custom",
330 auth_type
331 )
332 .into(),
333 );
334 Err(error)
335 }
336 }
337}
338
339fn validate_event_flow_type(flow_type: &str) -> Result<(), validator::ValidationError> {
340 match flow_type {
341 "StartDeltaStop" | "Continuous" => Ok(()),
342 _ => {
343 let mut error = validator::ValidationError::new("invalid_event_flow_type");
344 error.message = Some(
345 format!(
346 "Invalid event flow type '{}'. Must be one of: StartDeltaStop, Continuous",
347 flow_type
348 )
349 .into(),
350 );
351 Err(error)
352 }
353 }
354}
355
356fn validate_timeout_seconds(timeout: u64) -> Result<(), validator::ValidationError> {
357 if !(1..=3600).contains(&timeout) {
358 let mut error = validator::ValidationError::new("out_of_range");
359 error.message = Some("timeout_seconds must be between 1 and 3600".into());
360 return Err(error);
361 }
362 Ok(())
363}
364
365fn validate_max_retries(retries: u32) -> Result<(), validator::ValidationError> {
366 if retries > 10 {
367 let mut error = validator::ValidationError::new("out_of_range");
368 error.message = Some("max_retries must be between 0 and 10".into());
369 return Err(error);
370 }
371 Ok(())
372}
373
374fn validate_initial_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
375 if !(100..=10000).contains(&delay) {
376 let mut error = validator::ValidationError::new("out_of_range");
377 error.message = Some("initial_delay_ms must be between 100 and 10000".into());
378 return Err(error);
379 }
380 Ok(())
381}
382
383fn validate_max_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
384 if !(1000..=300000).contains(&delay) {
385 let mut error = validator::ValidationError::new("out_of_range");
386 error.message = Some("max_delay_ms must be between 1000 and 300000".into());
387 return Err(error);
388 }
389 Ok(())
390}
391
392impl ConfigFile {
393 pub fn load(path: &Path) -> Result<Self, GeneratorError> {
395 let content = std::fs::read_to_string(path).map_err(|e| GeneratorError::FileError {
396 message: format!("Failed to read config file '{}': {}", path.display(), e),
397 })?;
398
399 let config: ConfigFile =
400 toml::from_str(&content).map_err(|e| GeneratorError::FileError {
401 message: format!(
402 "Failed to parse TOML config: {}\n\nExample config:\n{}",
403 e, EXAMPLE_CONFIG
404 ),
405 })?;
406
407 config.validate().map_err(|e| {
409 GeneratorError::ValidationError(format!(
410 "Configuration validation failed:\n{}",
411 format_validation_errors(&e)
412 ))
413 })?;
414
415 Ok(config)
416 }
417
418 pub fn into_generator_config(self) -> GeneratorConfig {
420 use crate::http_config::{AuthConfig, HttpClientConfig, RetryConfig};
421
422 let http_client_config = self.http_client.as_ref().map(|http| HttpClientConfig {
424 base_url: http.base_url.clone(),
425 timeout_seconds: http.timeout_seconds,
426 default_headers: http
427 .headers
428 .iter()
429 .map(|h| (h.name.clone(), h.value.clone()))
430 .collect(),
431 });
432
433 let retry_config = self
435 .http_client
436 .as_ref()
437 .and_then(|http| http.retry.as_ref())
438 .map(|retry| RetryConfig {
439 max_retries: retry.max_retries,
440 initial_delay_ms: retry.initial_delay_ms,
441 max_delay_ms: retry.max_delay_ms,
442 });
443
444 let tracing_enabled = self
446 .http_client
447 .as_ref()
448 .and_then(|http| http.tracing.as_ref())
449 .map(|tracing| tracing.enabled)
450 .unwrap_or(true);
451
452 let auth_config = self
454 .http_client
455 .as_ref()
456 .and_then(|http| http.auth.as_ref())
457 .map(|auth| match auth.auth_type.as_str() {
458 "Bearer" => AuthConfig::Bearer {
459 header_name: auth.header_name.clone(),
460 },
461 "ApiKey" => AuthConfig::ApiKey {
462 header_name: auth.header_name.clone(),
463 },
464 "Custom" => AuthConfig::Custom {
465 header_name: auth.header_name.clone(),
466 header_value_prefix: None,
467 },
468 _ => AuthConfig::Bearer {
469 header_name: "Authorization".to_string(),
470 },
471 });
472
473 let streaming_config = self.streaming.map(|section| {
475 use crate::streaming::{
476 EventFlow, HttpMethod, QueryParameter, StreamingConfig, StreamingEndpoint,
477 };
478
479 let endpoints = section
480 .endpoints
481 .into_iter()
482 .map(|e| {
483 let event_flow = e
484 .event_flow
485 .map(|ef| match ef.flow_type.as_str() {
486 "start_delta_stop" => EventFlow::StartDeltaStop {
487 start_events: ef.start_events.unwrap_or_default(),
488 delta_events: ef.delta_events.unwrap_or_default(),
489 stop_events: ef.stop_events.unwrap_or_default(),
490 },
491 _ => EventFlow::Simple,
492 })
493 .unwrap_or(EventFlow::Simple);
494
495 let http_method = e
496 .http_method
497 .map(|m| match m.to_uppercase().as_str() {
498 "GET" => HttpMethod::Get,
499 _ => HttpMethod::Post,
500 })
501 .unwrap_or(HttpMethod::Post);
502
503 let query_parameters = e
504 .query_parameters
505 .into_iter()
506 .map(|qp| QueryParameter {
507 name: qp.name,
508 required: qp.required,
509 })
510 .collect();
511
512 StreamingEndpoint {
513 operation_id: e.operation_id,
514 path: e.path,
515 http_method,
516 stream_parameter: e.stream_parameter,
517 query_parameters,
518 event_union_type: e.event_union_type,
519 content_type: e.content_type,
520 event_flow,
521 ..Default::default()
522 }
523 })
524 .collect();
525
526 StreamingConfig {
527 endpoints,
528 ..Default::default()
529 }
530 });
531
532 GeneratorConfig {
533 spec_path: self.generator.spec_path,
534 output_dir: self.generator.output_dir,
535 module_name: self.generator.module_name,
536 enable_sse_client: self.features.enable_sse_client,
537 enable_async_client: self.features.enable_async_client,
538 enable_specta: self.features.enable_specta,
539 type_mappings: if self.type_mappings.is_empty() {
540 super::generator::default_type_mappings()
541 } else {
542 self.type_mappings
543 },
544 streaming_config,
545 nullable_field_overrides: self.nullable_overrides,
546 schema_extensions: self.generator.schema_extensions,
547 http_client_config,
548 retry_config,
549 tracing_enabled,
550 auth_config,
551 enable_registry: self.features.enable_registry,
552 registry_only: self.features.registry_only,
553 }
554 }
555}
556
557fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
559 let mut messages = Vec::new();
560
561 for (field, field_errors) in errors.field_errors() {
563 for error in field_errors {
564 let msg = if let Some(message) = &error.message {
565 format!(" - {}: {}", field, message)
566 } else {
567 format!(" - {}: validation failed (code: {})", field, error.code)
568 };
569 messages.push(msg);
570 }
571 }
572
573 for (field, nested_errors) in errors.errors() {
575 if let validator::ValidationErrorsKind::Struct(struct_errors) = nested_errors {
576 let nested_msgs = format_validation_errors(struct_errors);
577 if !nested_msgs.is_empty() {
578 messages.push(format!(" - {} (nested):\n{}", field, nested_msgs));
579 }
580 }
581 }
582
583 messages.join("\n")
584}
585
/// Example configuration appended to TOML parse errors by `ConfigFile::load`.
///
/// Shows the required `[generator]` and `[features]` tables plus optional
/// HTTP client, retry and auth settings. It does not show `[streaming]`.
const EXAMPLE_CONFIG: &str = r#"[generator]
spec_path = "openapi.json"
output_dir = "src/generated"
module_name = "types"

[features]
enable_async_client = true

[http_client]
base_url = "https://api.example.com"
timeout_seconds = 30

[http_client.retry]
max_retries = 3

[http_client.auth]
type = "Bearer"
header_name = "Authorization""#;