1use crate::{GeneratorError, generator::GeneratorConfig};
141use serde::{Deserialize, Serialize};
142use std::collections::BTreeMap;
143use std::path::{Path, PathBuf};
144use validator::Validate;
145
/// Top-level shape of the generator's TOML configuration file.
///
/// Read and validated by [`ConfigFile::load`], then turned into a
/// [`GeneratorConfig`] by [`ConfigFile::into_generator_config`].
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct ConfigFile {
    /// `[generator]` table: spec input, output directory and module name.
    #[validate(nested)]
    pub generator: GeneratorSection,
    /// `[features]` table. The table itself is required (no serde default),
    /// although every flag inside it defaults to `false`.
    #[validate(nested)]
    pub features: FeaturesSection,
    /// Optional `[http_client]` table; `None` when the table is absent.
    #[serde(default)]
    #[validate(nested)]
    pub http_client: Option<HttpClientSection>,
    /// Optional `[streaming]` table; `None` when the table is absent.
    #[serde(default)]
    #[validate(nested)]
    pub streaming: Option<StreamingSection>,
    /// Nullability overrides, forwarded unchanged as
    /// `GeneratorConfig::nullable_field_overrides`. Key format is defined by the
    /// generator (presumably type/field names — TODO confirm).
    #[serde(default)]
    pub nullable_overrides: BTreeMap<String, bool>,
    /// Custom type mappings. When left empty, the generator's default mappings
    /// are used instead during conversion.
    #[serde(default)]
    pub type_mappings: BTreeMap<String, String>,
}
164
/// `[generator]` table: where the OpenAPI spec is read from and where output goes.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct GeneratorSection {
    /// Path to the OpenAPI spec. Validation fails if nothing exists at this
    /// path. The path is checked as written, so a relative path resolves
    /// against the process's current working directory, not the config file's
    /// directory.
    #[validate(custom(function = "validate_spec_path_exists"))]
    pub spec_path: PathBuf,
    /// Output directory for generated code; not validated here.
    pub output_dir: PathBuf,
    /// Module name handed to the generator; must be non-empty.
    #[validate(length(min = 1, message = "module_name cannot be empty"))]
    pub module_name: String,
}
173
/// `[features]` table of opt-in generator outputs. Every flag defaults to `false`
/// and is copied into the same-named `GeneratorConfig` field.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct FeaturesSection {
    /// Copied to `GeneratorConfig::enable_sse_client`.
    #[serde(default)]
    pub enable_sse_client: bool,
    /// Copied to `GeneratorConfig::enable_async_client`.
    #[serde(default)]
    pub enable_async_client: bool,
    /// Copied to `GeneratorConfig::enable_specta`.
    #[serde(default)]
    pub enable_specta: bool,
}
183
/// `[http_client]` table describing the generated HTTP client.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct HttpClientSection {
    /// Optional base URL; when present it must parse as a URL.
    #[validate(url(message = "base_url must be a valid URL"))]
    pub base_url: Option<String>,
    /// Optional request timeout; when present it must be within 1..=3600.
    #[validate(custom(function = "validate_timeout_seconds"))]
    pub timeout_seconds: Option<u64>,
    /// Optional `[http_client.auth]` table.
    #[validate(nested)]
    pub auth: Option<AuthConfigSection>,
    /// Extra `{ name, value }` header entries; defaults to an empty list.
    /// Each entry is validated individually.
    #[serde(default)]
    #[validate(nested)]
    pub headers: Vec<HeaderEntry>,
    /// Optional `[http_client.retry]` table.
    #[validate(nested)]
    pub retry: Option<RetryConfigSection>,
    /// Optional `[http_client.tracing]` table.
    #[validate(nested)]
    pub tracing: Option<TracingConfigSection>,
}
200
/// `[http_client.tracing]` table.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct TracingConfigSection {
    /// Whether tracing is enabled; an omitted key falls back to
    /// `default_tracing_enabled()`.
    #[serde(default = "default_tracing_enabled")]
    pub enabled: bool,
}
206
/// Serde default for `TracingConfigSection::enabled`: tracing stays on unless
/// the config explicitly turns it off.
fn default_tracing_enabled() -> bool {
    const TRACING_ON_BY_DEFAULT: bool = true;
    TRACING_ON_BY_DEFAULT
}
210
211#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
212pub struct RetryConfigSection {
213 #[serde(default = "default_max_retries")]
214 #[validate(custom(function = "validate_max_retries"))]
215 pub max_retries: u32,
216 #[serde(default = "default_initial_delay_ms")]
217 #[validate(custom(function = "validate_initial_delay_ms"))]
218 pub initial_delay_ms: u64,
219 #[serde(default = "default_max_delay_ms")]
220 #[validate(custom(function = "validate_max_delay_ms"))]
221 pub max_delay_ms: u64,
222}
223
/// Serde default for `RetryConfigSection::max_retries`: three attempts.
fn default_max_retries() -> u32 {
    3_u32
}

/// Serde default for `RetryConfigSection::initial_delay_ms`: half a second.
fn default_initial_delay_ms() -> u64 {
    500_u64
}

/// Serde default for `RetryConfigSection::max_delay_ms`: sixteen seconds.
fn default_max_delay_ms() -> u64 {
    16_000_u64
}
233
/// `[http_client.auth]` table.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct AuthConfigSection {
    /// Read from the TOML key `type`; must be one of `Bearer`, `ApiKey`, `Custom`.
    #[serde(rename = "type")]
    #[validate(custom(function = "validate_auth_type"))]
    pub auth_type: String,
    /// Header that carries the credential; must be non-empty.
    #[validate(length(min = 1, message = "header_name cannot be empty"))]
    pub header_name: String,
}
242
/// One extra header in `http_client.headers`.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct HeaderEntry {
    /// Header name; must be non-empty.
    #[validate(length(min = 1, message = "header name cannot be empty"))]
    pub name: String,
    /// Header value; not validated (an empty value is accepted).
    pub value: String,
}
249
/// `[streaming]` table: the list of streaming endpoints to generate.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct StreamingSection {
    /// `[[streaming.endpoints]]` entries; each one is validated individually.
    #[validate(nested)]
    pub endpoints: Vec<StreamingEndpointSection>,
}
255
256#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
257pub struct StreamingEndpointSection {
258 #[validate(length(min = 1))]
259 pub operation_id: String,
260 #[validate(length(min = 1))]
261 pub path: String,
262 #[serde(default)]
264 pub http_method: Option<String>,
265 #[serde(default)]
267 pub stream_parameter: String,
268 #[serde(default)]
270 pub query_parameters: Vec<QueryParameterSection>,
271 #[validate(length(min = 1))]
272 pub event_union_type: String,
273 pub content_type: Option<String>,
274 #[validate(nested)]
275 pub event_flow: Option<EventFlowSection>,
276}
277
/// One query parameter of a streaming endpoint.
///
/// The `name` rule below only takes effect when the containing `Vec` field is
/// marked `#[validate(nested)]`.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct QueryParameterSection {
    /// Parameter name; must be non-empty.
    #[validate(length(min = 1))]
    pub name: String,
    /// Whether the parameter is required; defaults to `false`.
    #[serde(default)]
    pub required: bool,
}
285
/// Event-flow description of a streaming endpoint.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
pub struct EventFlowSection {
    /// Read from the TOML key `type`; must be `StartDeltaStop` or `Continuous`.
    #[serde(rename = "type")]
    #[validate(custom(function = "validate_event_flow_type"))]
    pub flow_type: String,
    /// Optional string lists for the start / delta / stop phases (presumably
    /// event type names — confirm against `EventFlow`). They are not validated here.
    pub start_events: Option<Vec<String>>,
    pub delta_events: Option<Vec<String>>,
    pub stop_events: Option<Vec<String>>,
}
295
296fn validate_spec_path_exists(path: &Path) -> Result<(), validator::ValidationError> {
298 if !path.exists() {
299 let mut error = validator::ValidationError::new("file_not_found");
300 error.message = Some(
301 format!(
302 "OpenAPI spec file not found: {}. Ensure spec_path points to a valid OpenAPI JSON or YAML file.",
303 path.display()
304 )
305 .into(),
306 );
307 return Err(error);
308 }
309 Ok(())
310}
311
312fn validate_auth_type(auth_type: &str) -> Result<(), validator::ValidationError> {
313 match auth_type {
314 "Bearer" | "ApiKey" | "Custom" => Ok(()),
315 _ => {
316 let mut error = validator::ValidationError::new("invalid_auth_type");
317 error.message = Some(
318 format!(
319 "Invalid auth type '{}'. Must be one of: Bearer, ApiKey, Custom",
320 auth_type
321 )
322 .into(),
323 );
324 Err(error)
325 }
326 }
327}
328
329fn validate_event_flow_type(flow_type: &str) -> Result<(), validator::ValidationError> {
330 match flow_type {
331 "StartDeltaStop" | "Continuous" => Ok(()),
332 _ => {
333 let mut error = validator::ValidationError::new("invalid_event_flow_type");
334 error.message = Some(
335 format!(
336 "Invalid event flow type '{}'. Must be one of: StartDeltaStop, Continuous",
337 flow_type
338 )
339 .into(),
340 );
341 Err(error)
342 }
343 }
344}
345
346fn validate_timeout_seconds(timeout: u64) -> Result<(), validator::ValidationError> {
347 if !(1..=3600).contains(&timeout) {
348 let mut error = validator::ValidationError::new("out_of_range");
349 error.message = Some("timeout_seconds must be between 1 and 3600".into());
350 return Err(error);
351 }
352 Ok(())
353}
354
355fn validate_max_retries(retries: u32) -> Result<(), validator::ValidationError> {
356 if retries > 10 {
357 let mut error = validator::ValidationError::new("out_of_range");
358 error.message = Some("max_retries must be between 0 and 10".into());
359 return Err(error);
360 }
361 Ok(())
362}
363
364fn validate_initial_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
365 if !(100..=10000).contains(&delay) {
366 let mut error = validator::ValidationError::new("out_of_range");
367 error.message = Some("initial_delay_ms must be between 100 and 10000".into());
368 return Err(error);
369 }
370 Ok(())
371}
372
373fn validate_max_delay_ms(delay: u64) -> Result<(), validator::ValidationError> {
374 if !(1000..=300000).contains(&delay) {
375 let mut error = validator::ValidationError::new("out_of_range");
376 error.message = Some("max_delay_ms must be between 1000 and 300000".into());
377 return Err(error);
378 }
379 Ok(())
380}
381
382impl ConfigFile {
383 pub fn load(path: &Path) -> Result<Self, GeneratorError> {
385 let content = std::fs::read_to_string(path).map_err(|e| GeneratorError::FileError {
386 message: format!("Failed to read config file '{}': {}", path.display(), e),
387 })?;
388
389 let config: ConfigFile =
390 toml::from_str(&content).map_err(|e| GeneratorError::FileError {
391 message: format!(
392 "Failed to parse TOML config: {}\n\nExample config:\n{}",
393 e, EXAMPLE_CONFIG
394 ),
395 })?;
396
397 config.validate().map_err(|e| {
399 GeneratorError::ValidationError(format!(
400 "Configuration validation failed:\n{}",
401 format_validation_errors(&e)
402 ))
403 })?;
404
405 Ok(config)
406 }
407
408 pub fn into_generator_config(self) -> GeneratorConfig {
410 use crate::http_config::{AuthConfig, HttpClientConfig, RetryConfig};
411
412 let http_client_config = self.http_client.as_ref().map(|http| HttpClientConfig {
414 base_url: http.base_url.clone(),
415 timeout_seconds: http.timeout_seconds,
416 default_headers: http
417 .headers
418 .iter()
419 .map(|h| (h.name.clone(), h.value.clone()))
420 .collect(),
421 });
422
423 let retry_config = self
425 .http_client
426 .as_ref()
427 .and_then(|http| http.retry.as_ref())
428 .map(|retry| RetryConfig {
429 max_retries: retry.max_retries,
430 initial_delay_ms: retry.initial_delay_ms,
431 max_delay_ms: retry.max_delay_ms,
432 });
433
434 let tracing_enabled = self
436 .http_client
437 .as_ref()
438 .and_then(|http| http.tracing.as_ref())
439 .map(|tracing| tracing.enabled)
440 .unwrap_or(true);
441
442 let auth_config = self
444 .http_client
445 .as_ref()
446 .and_then(|http| http.auth.as_ref())
447 .map(|auth| match auth.auth_type.as_str() {
448 "Bearer" => AuthConfig::Bearer {
449 header_name: auth.header_name.clone(),
450 },
451 "ApiKey" => AuthConfig::ApiKey {
452 header_name: auth.header_name.clone(),
453 },
454 "Custom" => AuthConfig::Custom {
455 header_name: auth.header_name.clone(),
456 header_value_prefix: None,
457 },
458 _ => AuthConfig::Bearer {
459 header_name: "Authorization".to_string(),
460 },
461 });
462
463 let streaming_config = self.streaming.map(|section| {
465 use crate::streaming::{
466 EventFlow, HttpMethod, QueryParameter, StreamingConfig, StreamingEndpoint,
467 };
468
469 let endpoints = section
470 .endpoints
471 .into_iter()
472 .map(|e| {
473 let event_flow = e
474 .event_flow
475 .map(|ef| match ef.flow_type.as_str() {
476 "start_delta_stop" => EventFlow::StartDeltaStop {
477 start_events: ef.start_events.unwrap_or_default(),
478 delta_events: ef.delta_events.unwrap_or_default(),
479 stop_events: ef.stop_events.unwrap_or_default(),
480 },
481 _ => EventFlow::Simple,
482 })
483 .unwrap_or(EventFlow::Simple);
484
485 let http_method = e
486 .http_method
487 .map(|m| match m.to_uppercase().as_str() {
488 "GET" => HttpMethod::Get,
489 _ => HttpMethod::Post,
490 })
491 .unwrap_or(HttpMethod::Post);
492
493 let query_parameters = e
494 .query_parameters
495 .into_iter()
496 .map(|qp| QueryParameter {
497 name: qp.name,
498 required: qp.required,
499 })
500 .collect();
501
502 StreamingEndpoint {
503 operation_id: e.operation_id,
504 path: e.path,
505 http_method,
506 stream_parameter: e.stream_parameter,
507 query_parameters,
508 event_union_type: e.event_union_type,
509 content_type: e.content_type,
510 event_flow,
511 ..Default::default()
512 }
513 })
514 .collect();
515
516 StreamingConfig {
517 endpoints,
518 ..Default::default()
519 }
520 });
521
522 GeneratorConfig {
523 spec_path: self.generator.spec_path,
524 output_dir: self.generator.output_dir,
525 module_name: self.generator.module_name,
526 enable_sse_client: self.features.enable_sse_client,
527 enable_async_client: self.features.enable_async_client,
528 enable_specta: self.features.enable_specta,
529 type_mappings: if self.type_mappings.is_empty() {
530 super::generator::default_type_mappings()
531 } else {
532 self.type_mappings
533 },
534 streaming_config,
535 nullable_field_overrides: self.nullable_overrides,
536 schema_extensions: vec![],
537 http_client_config,
538 retry_config,
539 tracing_enabled,
540 auth_config,
541 }
542 }
543}
544
545fn format_validation_errors(errors: &validator::ValidationErrors) -> String {
547 let mut messages = Vec::new();
548
549 for (field, field_errors) in errors.field_errors() {
551 for error in field_errors {
552 let msg = if let Some(message) = &error.message {
553 format!(" - {}: {}", field, message)
554 } else {
555 format!(" - {}: validation failed (code: {})", field, error.code)
556 };
557 messages.push(msg);
558 }
559 }
560
561 for (field, nested_errors) in errors.errors() {
563 if let validator::ValidationErrorsKind::Struct(struct_errors) = nested_errors {
564 let nested_msgs = format_validation_errors(struct_errors);
565 if !nested_msgs.is_empty() {
566 messages.push(format!(" - {} (nested):\n{}", field, nested_msgs));
567 }
568 }
569 }
570
571 messages.join("\n")
572}
573
/// Minimal example configuration appended to the TOML parse error produced by
/// `ConfigFile::load`. It contains the required `[generator]` and `[features]`
/// tables plus a sample `[http_client]` setup.
const EXAMPLE_CONFIG: &str = r#"[generator]
spec_path = "openapi.json"
output_dir = "src/generated"
module_name = "types"

[features]
enable_async_client = true

[http_client]
base_url = "https://api.example.com"
timeout_seconds = 30

[http_client.retry]
max_retries = 3

[http_client.auth]
type = "Bearer"
header_name = "Authorization""#;