Skip to main content

datasynth_server/rest/
routes.rs

1//! REST API routes.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7    extract::{State, WebSocketUpgrade},
8    http::{header, Method, StatusCode},
9    response::IntoResponse,
10    routing::{get, post},
11    Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Application state shared across handlers.
///
/// Cheap to clone: both fields are `Arc`-backed handles, so each handler
/// invocation gets a reference-counted view of the same underlying state.
#[derive(Clone)]
pub struct AppState {
    /// Shared server state (configuration, metric counters, stream control flags).
    pub server_state: Arc<ServerState>,
    /// Background job queue; `None` when the jobs subsystem is not wired in
    /// (all routers in this module construct the state with `job_queue: None`).
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Timeout configuration for the REST API.
///
/// Consumed by [`create_router_full_with_backend`], which applies it to every
/// request via a `TimeoutLayer`.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Request timeout in seconds. Requests exceeding this are answered with
    /// `408 Request Timeout`.
    pub request_timeout_secs: u64,
}
37
38impl Default for TimeoutConfig {
39    fn default() -> Self {
40        Self {
41            // 5 minutes default - bulk generation can take a while
42            request_timeout_secs: 300,
43        }
44    }
45}
46
47impl TimeoutConfig {
48    /// Create a new timeout config.
49    pub fn new(timeout_secs: u64) -> Self {
50        Self {
51            request_timeout_secs: timeout_secs,
52        }
53    }
54}
55
/// CORS configuration for the REST API.
///
/// The [`Default`] implementation allows common localhost dev-server origins
/// plus the Tauri app scheme.
#[derive(Clone)]
pub struct CorsConfig {
    /// Allowed origins. If empty, only localhost is allowed.
    pub allowed_origins: Vec<String>,
    /// Allow any origin (development mode only - NOT recommended for production).
    /// When `true`, `allowed_origins` is ignored and a permissive CORS layer is used.
    pub allow_any_origin: bool,
}
64
65impl Default for CorsConfig {
66    fn default() -> Self {
67        Self {
68            allowed_origins: vec![
69                "http://localhost:5173".to_string(), // Vite dev server
70                "http://localhost:3000".to_string(), // Common dev server
71                "http://127.0.0.1:5173".to_string(),
72                "http://127.0.0.1:3000".to_string(),
73                "tauri://localhost".to_string(), // Tauri app
74            ],
75            allow_any_origin: false,
76        }
77    }
78}
79
80/// Add API version header to responses.
81async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82    let (mut parts, body) = response.into_parts();
83    parts.headers.insert(
84        axum::http::HeaderName::from_static("x-api-version"),
85        axum::http::HeaderValue::from_static("v1"),
86    );
87    axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
/// Create the REST API router with default CORS settings.
///
/// Convenience wrapper around [`create_router_with_cors`] using
/// [`CorsConfig::default`] (localhost dev origins plus the Tauri scheme).
pub fn create_router(service: SynthService) -> Router {
    create_router_with_cors(service, CorsConfig::default())
}
101
102/// Create the REST API router with full configuration (CORS, auth, rate limiting, and timeout).
103///
104/// Uses in-memory rate limiting by default. For distributed rate limiting
105/// with Redis, use [`create_router_full_with_backend`] instead.
106pub fn create_router_full(
107    service: SynthService,
108    cors_config: CorsConfig,
109    auth_config: AuthConfig,
110    rate_limit_config: RateLimitConfig,
111    timeout_config: TimeoutConfig,
112) -> Router {
113    let backend = RateLimitBackend::in_memory(rate_limit_config);
114    create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117/// Create the REST API router with full configuration and a specific rate limiting backend.
118///
119/// This allows using either in-memory or Redis-backed rate limiting.
120///
121/// # Example (in-memory)
122/// ```rust,ignore
123/// let backend = RateLimitBackend::in_memory(rate_limit_config);
124/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
125/// ```
126///
127/// # Example (Redis)
128/// ```rust,ignore
129/// let backend = RateLimitBackend::redis("redis://127.0.0.1:6379", rate_limit_config).await?;
130/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
131/// ```
132pub fn create_router_full_with_backend(
133    service: SynthService,
134    cors_config: CorsConfig,
135    auth_config: AuthConfig,
136    rate_limit_backend: RateLimitBackend,
137    timeout_config: TimeoutConfig,
138) -> Router {
139    let server_state = service.state.clone();
140    let state = AppState {
141        server_state,
142        job_queue: None,
143    };
144
145    let cors = if cors_config.allow_any_origin {
146        CorsLayer::permissive()
147    } else {
148        let origins: Vec<_> = cors_config
149            .allowed_origins
150            .iter()
151            .filter_map(|o| o.parse().ok())
152            .collect();
153
154        CorsLayer::new()
155            .allow_origin(AllowOrigin::list(origins))
156            .allow_methods([
157                Method::GET,
158                Method::POST,
159                Method::PUT,
160                Method::DELETE,
161                Method::OPTIONS,
162            ])
163            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164    };
165
166    Router::new()
167        // Health and metrics (exempt from auth and rate limiting by default)
168        .route("/health", get(health_check))
169        .route("/ready", get(readiness_check))
170        .route("/live", get(liveness_check))
171        .route("/api/metrics", get(get_metrics))
172        .route("/metrics", get(prometheus_metrics))
173        // Configuration
174        .route("/api/config", get(get_config))
175        .route("/api/config", post(set_config))
176        .route("/api/config/reload", post(reload_config))
177        // Generation
178        .route("/api/generate/bulk", post(bulk_generate))
179        .route("/api/stream/start", post(start_stream))
180        .route("/api/stream/stop", post(stop_stream))
181        .route("/api/stream/pause", post(pause_stream))
182        .route("/api/stream/resume", post(resume_stream))
183        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184        .route("/api/stream/ndjson", get(stream_ndjson))
185        // Jobs
186        .route("/api/jobs/submit", post(submit_job))
187        .route("/api/jobs", get(list_jobs))
188        .route("/api/jobs/{id}", get(get_job))
189        .route("/api/jobs/{id}/cancel", post(cancel_job))
190        // Scenario templates (sector-specific DAG catalog).
191        // Expose on both /v1/ (SDK canonical path) and /api/ for consistency.
192        .route("/v1/scenarios/templates", get(list_scenario_templates))
193        .route("/api/scenarios/templates", get(list_scenario_templates))
194        // WebSocket
195        .route("/ws/metrics", get(websocket_metrics))
196        .route("/ws/events", get(websocket_events))
197        // Middleware stack (outermost applied first, innermost last)
198        // Order: Timeout -> RateLimit -> RequestValidation -> Auth -> RequestId -> CORS -> SecurityHeaders -> APIVersion -> Router
199        .layer(axum::middleware::from_fn(security_headers_middleware))
200        .layer(axum::middleware::map_response(api_version_header))
201        .layer(cors)
202        .layer(axum::middleware::from_fn(request_id_middleware))
203        .layer(axum::middleware::from_fn(auth_middleware))
204        .layer(axum::Extension(auth_config))
205        .layer(axum::middleware::from_fn(request_validation_middleware))
206        .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
207        .layer(axum::Extension(rate_limit_backend))
208        .layer(TimeoutLayer::with_status_code(
209            StatusCode::REQUEST_TIMEOUT,
210            Duration::from_secs(timeout_config.request_timeout_secs),
211        ))
212        .with_state(state)
213}
214
215/// Create the REST API router with custom CORS and authentication settings.
216pub fn create_router_with_auth(
217    service: SynthService,
218    cors_config: CorsConfig,
219    auth_config: AuthConfig,
220) -> Router {
221    let server_state = service.state.clone();
222    let state = AppState {
223        server_state,
224        job_queue: None,
225    };
226
227    let cors = if cors_config.allow_any_origin {
228        CorsLayer::permissive()
229    } else {
230        let origins: Vec<_> = cors_config
231            .allowed_origins
232            .iter()
233            .filter_map(|o| o.parse().ok())
234            .collect();
235
236        CorsLayer::new()
237            .allow_origin(AllowOrigin::list(origins))
238            .allow_methods([
239                Method::GET,
240                Method::POST,
241                Method::PUT,
242                Method::DELETE,
243                Method::OPTIONS,
244            ])
245            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
246    };
247
248    Router::new()
249        // Health and metrics (exempt from auth by default)
250        .route("/health", get(health_check))
251        .route("/ready", get(readiness_check))
252        .route("/live", get(liveness_check))
253        .route("/api/metrics", get(get_metrics))
254        .route("/metrics", get(prometheus_metrics))
255        // Configuration
256        .route("/api/config", get(get_config))
257        .route("/api/config", post(set_config))
258        .route("/api/config/reload", post(reload_config))
259        // Generation
260        .route("/api/generate/bulk", post(bulk_generate))
261        .route("/api/stream/start", post(start_stream))
262        .route("/api/stream/stop", post(stop_stream))
263        .route("/api/stream/pause", post(pause_stream))
264        .route("/api/stream/resume", post(resume_stream))
265        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
266        .route("/api/stream/ndjson", get(stream_ndjson))
267        // Jobs
268        .route("/api/jobs/submit", post(submit_job))
269        .route("/api/jobs", get(list_jobs))
270        .route("/api/jobs/{id}", get(get_job))
271        .route("/api/jobs/{id}/cancel", post(cancel_job))
272        // Scenario templates (sector-specific DAG catalog).
273        // Expose on both /v1/ (SDK canonical path) and /api/ for consistency.
274        .route("/v1/scenarios/templates", get(list_scenario_templates))
275        .route("/api/scenarios/templates", get(list_scenario_templates))
276        // WebSocket
277        .route("/ws/metrics", get(websocket_metrics))
278        .route("/ws/events", get(websocket_events))
279        .layer(axum::middleware::from_fn(auth_middleware))
280        .layer(axum::Extension(auth_config))
281        .layer(cors)
282        .with_state(state)
283}
284
285/// Create the REST API router with custom CORS settings.
286pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
287    let server_state = service.state.clone();
288    let state = AppState {
289        server_state,
290        job_queue: None,
291    };
292
293    let cors = if cors_config.allow_any_origin {
294        // Development mode - allow any origin (use with caution)
295        CorsLayer::permissive()
296    } else {
297        // Production mode - restricted origins
298        let origins: Vec<_> = cors_config
299            .allowed_origins
300            .iter()
301            .filter_map(|o| o.parse().ok())
302            .collect();
303
304        CorsLayer::new()
305            .allow_origin(AllowOrigin::list(origins))
306            .allow_methods([
307                Method::GET,
308                Method::POST,
309                Method::PUT,
310                Method::DELETE,
311                Method::OPTIONS,
312            ])
313            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
314    };
315
316    Router::new()
317        // Health and metrics
318        .route("/health", get(health_check))
319        .route("/ready", get(readiness_check))
320        .route("/live", get(liveness_check))
321        .route("/api/metrics", get(get_metrics))
322        .route("/metrics", get(prometheus_metrics))
323        // Configuration
324        .route("/api/config", get(get_config))
325        .route("/api/config", post(set_config))
326        .route("/api/config/reload", post(reload_config))
327        // Generation
328        .route("/api/generate/bulk", post(bulk_generate))
329        .route("/api/stream/start", post(start_stream))
330        .route("/api/stream/stop", post(stop_stream))
331        .route("/api/stream/pause", post(pause_stream))
332        .route("/api/stream/resume", post(resume_stream))
333        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
334        .route("/api/stream/ndjson", get(stream_ndjson))
335        // Jobs
336        .route("/api/jobs/submit", post(submit_job))
337        .route("/api/jobs", get(list_jobs))
338        .route("/api/jobs/{id}", get(get_job))
339        .route("/api/jobs/{id}/cancel", post(cancel_job))
340        // Scenario templates (sector-specific DAG catalog).
341        // Expose on both /v1/ (SDK canonical path) and /api/ for consistency.
342        .route("/v1/scenarios/templates", get(list_scenario_templates))
343        .route("/api/scenarios/templates", get(list_scenario_templates))
344        // WebSocket
345        .route("/ws/metrics", get(websocket_metrics))
346        .route("/ws/events", get(websocket_events))
347        .layer(cors)
348        .with_state(state)
349}
350
351// ===========================================================================
352// Request/Response types
353// ===========================================================================
354
/// Overall health response returned by `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always `true` when the handler responds at all.
    pub healthy: bool,
    /// Server crate version (from `CARGO_PKG_VERSION`).
    pub version: String,
    /// Seconds since the server started.
    pub uptime_seconds: u64,
}

/// Readiness check response for Kubernetes.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `true` only when every individual check passed.
    pub ready: bool,
    /// Human-readable summary ("Service is ready" / "Service is not ready").
    pub message: String,
    /// Individual check results (config, memory, disk).
    pub checks: Vec<HealthCheck>,
}

/// Individual health check result.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check identifier: "config", "memory", or "disk".
    pub name: String,
    /// Check outcome: "ok", "degraded", or "fail".
    pub status: String,
}

/// Liveness check response for Kubernetes.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when the handler responds.
    pub alive: bool,
    /// RFC 3339 timestamp of the probe response.
    pub timestamp: String,
}

/// Server metrics returned by `GET /api/metrics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Total journal entries generated since startup.
    pub total_entries_generated: u64,
    /// Total anomalies injected since startup.
    pub total_anomalies_injected: u64,
    /// Seconds since the server started.
    pub uptime_seconds: u64,
    /// Entries generated this session (currently equal to the startup total).
    pub session_entries: u64,
    /// Average entries per second over the whole uptime.
    pub session_entries_per_second: f64,
    /// Number of currently active streaming connections.
    pub active_streams: u32,
    /// Total events sent through streams.
    pub total_stream_events: u64,
}
394
/// Envelope for configuration reads/writes (`GET`/`POST /api/config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Human-readable status or error message.
    pub message: String,
    /// The effective configuration, when available (absent on errors).
    pub config: Option<GenerationConfigDto>,
}

/// Wire representation of the generation configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector name, e.g. "manufacturing"; parsed case-insensitively on write.
    pub industry: String,
    /// Simulation start date.
    pub start_date: String,
    /// Number of months to simulate.
    pub period_months: u32,
    /// Optional RNG seed for reproducible output.
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity: "small", "medium", or "large".
    pub coa_complexity: String,
    /// Companies to generate data for. An empty list on write keeps the
    /// currently configured companies (see `set_config`).
    pub companies: Vec<CompanyConfigDto>,
    /// Whether fraud injection is enabled.
    pub fraud_enabled: bool,
    /// Fraud rate; stored as `f64` internally.
    pub fraud_rate: f32,
}

/// Wire representation of a single company configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code.
    pub code: String,
    /// Display name.
    pub name: String,
    /// Transaction currency code.
    pub currency: String,
    /// Country code.
    pub country: String,
    /// Annual journal-entry volume; mapped to `TransactionVolume::Custom` on write.
    pub annual_transaction_volume: u64,
    /// Volume weighting; stored as `f64` internally.
    /// NOTE(review): presumably a relative share across companies — confirm.
    pub volume_weight: f32,
}
423
/// Request body for `POST /api/generate/bulk`.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Requested entry count; validated against a server-side maximum.
    pub entry_count: Option<u64>,
    /// Whether to also generate master data (defaults to `false`).
    pub include_master_data: Option<bool>,
    /// Whether to inject anomalies (defaults to `false`).
    pub inject_anomalies: Option<bool>,
}

/// Response body for `POST /api/generate/bulk`.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Whether generation completed successfully.
    pub success: bool,
    /// Number of journal entries generated.
    pub entries_generated: u64,
    /// Wall-clock generation time in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced.
    pub anomaly_count: u64,
}

/// Request body for `POST /api/stream/start`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // Fields deserialized from request, reserved for future use
pub struct StreamRequest {
    /// Target event rate for the stream.
    pub events_per_second: Option<u32>,
    /// Maximum number of events to emit.
    pub max_events: Option<u64>,
    /// Whether anomalies should be injected into the stream.
    pub inject_anomalies: Option<bool>,
}

/// Generic acknowledgement for stream control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the request was accepted.
    pub success: bool,
    /// Human-readable status message.
    pub message: String,
}
452
453// ===========================================================================
454// Handlers
455// ===========================================================================
456
457/// Health check endpoint - returns overall health status.
458async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
459    Json(HealthResponse {
460        healthy: true,
461        version: env!("CARGO_PKG_VERSION").to_string(),
462        uptime_seconds: state.server_state.uptime_seconds(),
463    })
464}
465
466/// Readiness probe - indicates the service is ready to accept traffic.
467/// Use for Kubernetes readiness probes.
468async fn readiness_check(
469    State(state): State<AppState>,
470) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
471    let mut checks = Vec::new();
472    let mut any_fail = false;
473
474    // Check if configuration is loaded and valid
475    let config = state.server_state.config.read().await;
476    let config_valid = !config.companies.is_empty();
477    checks.push(HealthCheck {
478        name: "config".to_string(),
479        status: if config_valid { "ok" } else { "fail" }.to_string(),
480    });
481    if !config_valid {
482        any_fail = true;
483    }
484    drop(config);
485
486    // Check resource guard (memory)
487    let resource_status = state.server_state.resource_status();
488    let memory_status = if resource_status.degradation_level == "Emergency" {
489        any_fail = true;
490        "fail"
491    } else if resource_status.degradation_level != "Normal" {
492        "degraded"
493    } else {
494        "ok"
495    };
496    checks.push(HealthCheck {
497        name: "memory".to_string(),
498        status: memory_status.to_string(),
499    });
500
501    // Check disk (>100MB free)
502    let disk_ok = resource_status.disk_available_mb > 100;
503    checks.push(HealthCheck {
504        name: "disk".to_string(),
505        status: if disk_ok { "ok" } else { "fail" }.to_string(),
506    });
507    if !disk_ok {
508        any_fail = true;
509    }
510
511    let response = ReadinessResponse {
512        ready: !any_fail,
513        message: if any_fail {
514            "Service is not ready".to_string()
515        } else {
516            "Service is ready".to_string()
517        },
518        checks,
519    };
520
521    if any_fail {
522        Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
523    } else {
524        Ok(Json(response))
525    }
526}
527
528/// Liveness probe - indicates the service is alive.
529/// Use for Kubernetes liveness probes.
530async fn liveness_check() -> Json<LivenessResponse> {
531    Json(LivenessResponse {
532        alive: true,
533        timestamp: chrono::Utc::now().to_rfc3339(),
534    })
535}
536
537/// Prometheus-compatible metrics endpoint.
538/// Returns metrics in Prometheus text exposition format.
539async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
540    use std::sync::atomic::Ordering;
541
542    let uptime = state.server_state.uptime_seconds();
543    let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
544    let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
545    let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
546    let total_stream_events = state
547        .server_state
548        .total_stream_events
549        .load(Ordering::Relaxed);
550
551    let entries_per_second = if uptime > 0 {
552        total_entries as f64 / uptime as f64
553    } else {
554        0.0
555    };
556
557    let metrics = format!(
558        r#"# HELP synth_entries_generated_total Total number of journal entries generated
559# TYPE synth_entries_generated_total counter
560synth_entries_generated_total {}
561
562# HELP synth_anomalies_injected_total Total number of anomalies injected
563# TYPE synth_anomalies_injected_total counter
564synth_anomalies_injected_total {}
565
566# HELP synth_uptime_seconds Server uptime in seconds
567# TYPE synth_uptime_seconds gauge
568synth_uptime_seconds {}
569
570# HELP synth_entries_per_second Rate of entry generation
571# TYPE synth_entries_per_second gauge
572synth_entries_per_second {:.2}
573
574# HELP synth_active_streams Number of active streaming connections
575# TYPE synth_active_streams gauge
576synth_active_streams {}
577
578# HELP synth_stream_events_total Total events sent through streams
579# TYPE synth_stream_events_total counter
580synth_stream_events_total {}
581
582# HELP synth_info Server version information
583# TYPE synth_info gauge
584synth_info{{version="{}"}} 1
585"#,
586        total_entries,
587        total_anomalies,
588        uptime,
589        entries_per_second,
590        active_streams,
591        total_stream_events,
592        env!("CARGO_PKG_VERSION")
593    );
594
595    (
596        StatusCode::OK,
597        [(
598            header::CONTENT_TYPE,
599            "text/plain; version=0.0.4; charset=utf-8",
600        )],
601        metrics,
602    )
603}
604
605/// Get server metrics.
606async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
607    let uptime = state.server_state.uptime_seconds();
608    let total_entries = state
609        .server_state
610        .total_entries
611        .load(std::sync::atomic::Ordering::Relaxed);
612
613    let entries_per_second = if uptime > 0 {
614        total_entries as f64 / uptime as f64
615    } else {
616        0.0
617    };
618
619    Json(MetricsResponse {
620        total_entries_generated: total_entries,
621        total_anomalies_injected: state
622            .server_state
623            .total_anomalies
624            .load(std::sync::atomic::Ordering::Relaxed),
625        uptime_seconds: uptime,
626        session_entries: total_entries,
627        session_entries_per_second: entries_per_second,
628        active_streams: state
629            .server_state
630            .active_streams
631            .load(std::sync::atomic::Ordering::Relaxed) as u32,
632        total_stream_events: state
633            .server_state
634            .total_stream_events
635            .load(std::sync::atomic::Ordering::Relaxed),
636    })
637}
638
639/// Get current configuration.
640async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
641    let config = state.server_state.config.read().await;
642
643    Json(ConfigResponse {
644        success: true,
645        message: "Current configuration".to_string(),
646        config: Some(GenerationConfigDto {
647            industry: format!("{:?}", config.global.industry),
648            start_date: config.global.start_date.clone(),
649            period_months: config.global.period_months,
650            seed: config.global.seed,
651            coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
652            companies: config
653                .companies
654                .iter()
655                .map(|c| CompanyConfigDto {
656                    code: c.code.clone(),
657                    name: c.name.clone(),
658                    currency: c.currency.clone(),
659                    country: c.country.clone(),
660                    annual_transaction_volume: c.annual_transaction_volume.count(),
661                    volume_weight: c.volume_weight as f32,
662                })
663                .collect(),
664            fraud_enabled: config.fraud.enabled,
665            fraud_rate: config.fraud.fraud_rate as f32,
666        }),
667    })
668}
669
/// Set configuration (`POST /api/config`).
///
/// Parses the incoming DTO into strongly-typed config values and applies them
/// to the shared server configuration under a write lock.
///
/// # Errors
/// Returns `400 Bad Request` (with a `ConfigResponse` describing the problem)
/// when `industry` or `coa_complexity` is not a recognized value.
async fn set_config(
    State(state): State<AppState>,
    Json(new_config): Json<GenerationConfigDto>,
) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
    use datasynth_core::models::{CoAComplexity, IndustrySector};

    info!(
        "Configuration update requested: industry={}, period_months={}",
        new_config.industry, new_config.period_months
    );

    // Parse industry from string (case-insensitive; accepts both snake_case
    // and concatenated spellings for multi-word sectors).
    let industry = match new_config.industry.to_lowercase().as_str() {
        "manufacturing" => IndustrySector::Manufacturing,
        "retail" => IndustrySector::Retail,
        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
        "healthcare" => IndustrySector::Healthcare,
        "technology" => IndustrySector::Technology,
        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
        "energy" => IndustrySector::Energy,
        "transportation" => IndustrySector::Transportation,
        "real_estate" | "realestate" => IndustrySector::RealEstate,
        "telecommunications" => IndustrySector::Telecommunications,
        _ => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ConfigResponse {
                    success: false,
                    message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
                    config: None,
                }),
            ));
        }
    };

    // Parse CoA complexity from string
    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
        "small" => CoAComplexity::Small,
        "medium" => CoAComplexity::Medium,
        "large" => CoAComplexity::Large,
        _ => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ConfigResponse {
                    success: false,
                    message: format!(
                        "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
                        new_config.coa_complexity
                    ),
                    config: None,
                }),
            ));
        }
    };

    // Convert CompanyConfigDto to CompanyConfig.
    // NOTE(review): `functional_currency` is always None and
    // `fiscal_year_variant` is hard-coded to "K4" — confirm these defaults are
    // intended for API-driven updates.
    let companies: Vec<CompanyConfig> = new_config
        .companies
        .iter()
        .map(|c| CompanyConfig {
            code: c.code.clone(),
            name: c.name.clone(),
            currency: c.currency.clone(),
            functional_currency: None,
            country: c.country.clone(),
            fiscal_year_variant: "K4".to_string(),
            annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
            volume_weight: c.volume_weight as f64,
        })
        .collect();

    // Update the configuration (write lock held until the end of the function).
    let mut config = state.server_state.config.write().await;
    config.global.industry = industry;
    config.global.start_date = new_config.start_date.clone();
    config.global.period_months = new_config.period_months;
    config.global.seed = new_config.seed;
    config.chart_of_accounts.complexity = complexity;
    config.fraud.enabled = new_config.fraud_enabled;
    config.fraud.fraud_rate = new_config.fraud_rate as f64;

    // Only update companies if provided; an empty list keeps the existing set.
    if !companies.is_empty() {
        config.companies = companies;
    }

    info!("Configuration updated successfully");

    Ok(Json(ConfigResponse {
        success: true,
        message: "Configuration updated and applied".to_string(),
        config: Some(new_config),
    }))
}
766
767/// Bulk generation endpoint.
768async fn bulk_generate(
769    State(state): State<AppState>,
770    Json(req): Json<BulkGenerateRequest>,
771) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
772    // Validate entry_count bounds
773    const MAX_ENTRY_COUNT: u64 = 1_000_000;
774    if let Some(count) = req.entry_count {
775        if count > MAX_ENTRY_COUNT {
776            return Err((
777                StatusCode::BAD_REQUEST,
778                format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
779            ));
780        }
781    }
782
783    let config = state.server_state.config.read().await.clone();
784    let start_time = std::time::Instant::now();
785
786    let phase_config = {
787        let mut pc = PhaseConfig::from_config(&config);
788        pc.generate_master_data = req.include_master_data.unwrap_or(false);
789        pc.generate_document_flows = false;
790        pc.generate_journal_entries = true;
791        pc.inject_anomalies = req.inject_anomalies.unwrap_or(false);
792        pc.show_progress = false;
793        pc
794    };
795
796    let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
797        (
798            StatusCode::INTERNAL_SERVER_ERROR,
799            format!("Failed to create orchestrator: {e}"),
800        )
801    })?;
802
803    let result = orchestrator.generate().map_err(|e| {
804        (
805            StatusCode::INTERNAL_SERVER_ERROR,
806            format!("Generation failed: {e}"),
807        )
808    })?;
809
810    let duration_ms = start_time.elapsed().as_millis() as u64;
811    let entries_count = result.journal_entries.len() as u64;
812    let anomaly_count = result.anomaly_labels.labels.len() as u64;
813
814    // Update metrics
815    state
816        .server_state
817        .total_entries
818        .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
819    state
820        .server_state
821        .total_anomalies
822        .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
823
824    Ok(Json(BulkGenerateResponse {
825        success: true,
826        entries_generated: entries_count,
827        duration_ms,
828        anomaly_count,
829    }))
830}
831
832/// Start streaming.
833async fn start_stream(
834    State(state): State<AppState>,
835    Json(req): Json<StreamRequest>,
836) -> Json<StreamResponse> {
837    // Apply stream request parameters to server state
838    if let Some(eps) = req.events_per_second {
839        info!("Stream configured: events_per_second={}", eps);
840        state
841            .server_state
842            .stream_events_per_second
843            .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
844    }
845    if let Some(max) = req.max_events {
846        info!("Stream configured: max_events={}", max);
847        state
848            .server_state
849            .stream_max_events
850            .store(max, std::sync::atomic::Ordering::Relaxed);
851    }
852    if let Some(inject) = req.inject_anomalies {
853        info!("Stream configured: inject_anomalies={}", inject);
854        state
855            .server_state
856            .stream_inject_anomalies
857            .store(inject, std::sync::atomic::Ordering::Relaxed);
858    }
859
860    state
861        .server_state
862        .stream_stopped
863        .store(false, std::sync::atomic::Ordering::Relaxed);
864    state
865        .server_state
866        .stream_paused
867        .store(false, std::sync::atomic::Ordering::Relaxed);
868
869    Json(StreamResponse {
870        success: true,
871        message: "Stream started".to_string(),
872    })
873}
874
875/// Stop streaming.
876async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
877    state
878        .server_state
879        .stream_stopped
880        .store(true, std::sync::atomic::Ordering::Relaxed);
881
882    Json(StreamResponse {
883        success: true,
884        message: "Stream stopped".to_string(),
885    })
886}
887
888/// Pause streaming.
889async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
890    state
891        .server_state
892        .stream_paused
893        .store(true, std::sync::atomic::Ordering::Relaxed);
894
895    Json(StreamResponse {
896        success: true,
897        message: "Stream paused".to_string(),
898    })
899}
900
901/// Resume streaming.
902async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
903    state
904        .server_state
905        .stream_paused
906        .store(false, std::sync::atomic::Ordering::Relaxed);
907
908    Json(StreamResponse {
909        success: true,
910        message: "Stream resumed".to_string(),
911    })
912}
913
914/// Trigger a specific pattern.
915///
916/// Valid patterns: year_end_spike, period_end_spike, holiday_cluster,
917/// fraud_cluster, error_cluster, uniform, or custom:* patterns.
918async fn trigger_pattern(
919    State(state): State<AppState>,
920    axum::extract::Path(pattern): axum::extract::Path<String>,
921) -> Json<StreamResponse> {
922    info!("Pattern trigger requested: {}", pattern);
923
924    // Validate pattern name
925    let valid_patterns = [
926        "year_end_spike",
927        "period_end_spike",
928        "holiday_cluster",
929        "fraud_cluster",
930        "error_cluster",
931        "uniform",
932    ];
933
934    let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
935
936    if !is_valid {
937        return Json(StreamResponse {
938            success: false,
939            message: format!(
940                "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
941            ),
942        });
943    }
944
945    // Store the pattern for the stream generator to pick up
946    match state.server_state.triggered_pattern.try_write() {
947        Ok(mut triggered) => {
948            *triggered = Some(pattern.clone());
949            Json(StreamResponse {
950                success: true,
951                message: format!("Pattern '{pattern}' will be applied to upcoming entries"),
952            })
953        }
954        Err(_) => Json(StreamResponse {
955            success: false,
956            message: "Failed to acquire lock for pattern trigger".to_string(),
957        }),
958    }
959}
960
/// A [`PhaseSink`](datasynth_runtime::stream_pipeline::PhaseSink) that sends
/// NDJSON lines through a `tokio::sync::mpsc::Sender`. Bridges the synchronous
/// generation pipeline to an async HTTP streaming response.
struct ChannelPhaseSink {
    // Sending half of the channel; the HTTP response task drains the receiver.
    tx: tokio::sync::mpsc::Sender<String>,
    // Emission counters, behind Arc<Mutex<..>> so `stats()` can read them
    // through `&self` while `emit`/`phase_complete` update them.
    stats: Arc<std::sync::Mutex<datasynth_runtime::stream_pipeline::StreamStats>>,
}
968
969impl ChannelPhaseSink {
970    fn new(tx: tokio::sync::mpsc::Sender<String>) -> Self {
971        Self {
972            tx,
973            stats: Arc::new(std::sync::Mutex::new(
974                datasynth_runtime::stream_pipeline::StreamStats::default(),
975            )),
976        }
977    }
978}
979
980impl datasynth_runtime::stream_pipeline::PhaseSink for ChannelPhaseSink {
981    fn emit(
982        &self,
983        phase: &str,
984        item_type: &str,
985        item: &serde_json::Value,
986    ) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
987        let envelope = serde_json::json!({
988            "phase": phase,
989            "item_type": item_type,
990            "data": item,
991        });
992        let json = serde_json::to_string(&envelope).map_err(|e| {
993            datasynth_runtime::stream_pipeline::StreamError::Serialization(e.to_string())
994        })?;
995
996        // blocking_send: we're on a spawn_blocking thread
997        self.tx.blocking_send(json).map_err(|_| {
998            datasynth_runtime::stream_pipeline::StreamError::Connection(
999                "channel closed".to_string(),
1000            )
1001        })?;
1002
1003        if let Ok(mut stats) = self.stats.lock() {
1004            stats.items_emitted += 1;
1005        }
1006        Ok(())
1007    }
1008
1009    fn phase_complete(
1010        &self,
1011        _phase: &str,
1012    ) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
1013        if let Ok(mut stats) = self.stats.lock() {
1014            stats.phases_completed += 1;
1015        }
1016        Ok(())
1017    }
1018
1019    fn flush(&self) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
1020        Ok(())
1021    }
1022
1023    fn stats(&self) -> datasynth_runtime::stream_pipeline::StreamStats {
1024        self.stats.lock().map(|s| s.clone()).unwrap_or_default()
1025    }
1026}
1027
/// Query parameters for the NDJSON streaming endpoint.
///
/// Every field is optional; the handler substitutes the defaults noted below
/// when a parameter is absent.
#[derive(Debug, Deserialize)]
struct NdjsonStreamQuery {
    /// Target events per second (0 or absent = unlimited).
    #[serde(default)]
    rate: Option<f64>,
    /// Token bucket burst size (default 100).
    #[serde(default)]
    burst: Option<u32>,
    /// Emit a _progress event every N items (default 100, 0 = disabled).
    #[serde(default)]
    progress_interval: Option<u64>,
}
1041
1042/// NDJSON streaming endpoint.
1043///
1044/// Runs a full generation and streams every phase (master data, document flows,
1045/// journal entries, anomalies, OCPM, etc.) as newline-delimited JSON.
1046///
1047/// Each line is a self-describing NDJSON envelope:
1048/// ```json
1049/// {"phase":"journal_entries","item_type":"JournalEntry","data":{...}}
1050/// ```
1051/// Progress events: `{"phase":"_progress","item_type":"StreamProgress","data":{...}}`
1052/// Completion: `{"type":"_complete","summary":{...}}`
1053///
1054/// Rate-controlled via the `rate` query parameter (events/sec, 0 = unlimited).
1055///
1056/// Example: `GET /api/stream/ndjson?rate=100&progress_interval=50`
1057async fn stream_ndjson(
1058    State(state): State<AppState>,
1059    axum::extract::Query(params): axum::extract::Query<NdjsonStreamQuery>,
1060) -> impl IntoResponse {
1061    let config = state.server_state.config.read().await.clone();
1062    let rate = params.rate.unwrap_or(0.0);
1063    let burst = params.burst.unwrap_or(100);
1064    let progress_interval = params.progress_interval.unwrap_or(100);
1065
1066    // Channel: generation thread sends NDJSON lines, HTTP response reads them
1067    let (tx, rx) = tokio::sync::mpsc::channel::<String>(1024);
1068
1069    // Spawn generation on a blocking thread
1070    tokio::task::spawn_blocking(move || {
1071        use datasynth_runtime::stream_pipeline::*;
1072
1073        // Create a PhaseSink that sends NDJSON through the channel
1074        let channel_sink = ChannelPhaseSink::new(tx.clone());
1075
1076        // Wrap with rate limiting
1077        let pipeline: Box<dyn PhaseSink> = Box::new(RateLimitedPipeline::new(
1078            Box::new(channel_sink),
1079            rate,
1080            burst,
1081            progress_interval,
1082        ));
1083
1084        // Configure generation with all phases
1085        let mut phase_config = PhaseConfig::from_config(&config);
1086        phase_config.show_progress = false;
1087
1088        match EnhancedOrchestrator::new(config, phase_config) {
1089            Ok(mut orchestrator) => {
1090                orchestrator.set_phase_sink(pipeline);
1091                match orchestrator.generate() {
1092                    Ok(result) => {
1093                        // Send completion summary
1094                        let summary = serde_json::json!({
1095                            "type": "_complete",
1096                            "summary": {
1097                                "total_entries": result.statistics.total_entries,
1098                                "total_line_items": result.statistics.total_line_items,
1099                                "anomaly_count": result.anomaly_labels.labels.len(),
1100                            }
1101                        });
1102                        let _ =
1103                            tx.blocking_send(serde_json::to_string(&summary).unwrap_or_default());
1104                    }
1105                    Err(e) => {
1106                        let err = serde_json::json!({
1107                            "type": "_error",
1108                            "message": format!("Generation failed: {e}"),
1109                        });
1110                        let _ = tx.blocking_send(serde_json::to_string(&err).unwrap_or_default());
1111                    }
1112                }
1113            }
1114            Err(e) => {
1115                let err = serde_json::json!({
1116                    "type": "_error",
1117                    "message": format!("Failed to create orchestrator: {e}"),
1118                });
1119                let _ = tx.blocking_send(serde_json::to_string(&err).unwrap_or_default());
1120            }
1121        }
1122        // tx is dropped here, closing the channel → stream ends
1123    });
1124
1125    // Convert the receiver into an axum streaming response
1126    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
1127    let body = axum::body::Body::from_stream(tokio_stream::StreamExt::map(stream, |mut line| {
1128        line.push('\n');
1129        Ok::<_, std::convert::Infallible>(line)
1130    }));
1131
1132    axum::response::Response::builder()
1133        .header("Content-Type", "application/x-ndjson")
1134        .header("Transfer-Encoding", "chunked")
1135        .header("Cache-Control", "no-cache")
1136        .header("X-Content-Type-Options", "nosniff")
1137        .body(body)
1138        .unwrap_or_else(|_| {
1139            axum::response::Response::builder()
1140                .status(StatusCode::INTERNAL_SERVER_ERROR)
1141                .body(axum::body::Body::empty())
1142                .expect("fallback response")
1143        })
1144}
1145
1146/// WebSocket endpoint for metrics stream.
1147async fn websocket_metrics(
1148    ws: WebSocketUpgrade,
1149    State(state): State<AppState>,
1150) -> impl IntoResponse {
1151    ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
1152}
1153
1154/// WebSocket endpoint for event stream.
1155async fn websocket_events(
1156    ws: WebSocketUpgrade,
1157    State(state): State<AppState>,
1158) -> impl IntoResponse {
1159    ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
1160}
1161
1162// ===========================================================================
1163// Scenario Template Handlers
1164// ===========================================================================
1165
1166/// Returns the catalog of scenario DAG templates the server can run.
1167///
1168/// Includes both sector-specific canonical templates (manufacturing,
1169/// retail, financial_services) and the ISA 315 financial process default
1170/// (`tpl_financial_process_17`). Before v3.1.1 this endpoint did not
1171/// exist on the server, so SDK clients fell back to a hard-coded single
1172/// template id — making sector-specific scenario DAGs invisible.
1173async fn list_scenario_templates() -> Json<serde_json::Value> {
1174    // Each entry lists the template id, human-readable name,
1175    // description, target industry tag(s), and a summary of the
1176    // intervention DAG (type + count). The concrete intervention
1177    // parameters live in the YAML files under
1178    // `crates/datasynth-config/src/templates/scenarios/` and are loaded
1179    // at job-submission time.
1180    let templates = serde_json::json!([
1181        {
1182            "template_id": "tpl_financial_process_17",
1183            "name": "ISA 315 Financial Reporting Process",
1184            "description": "Generic financial reporting process with 17 key risk nodes per ISA 315 (revised 2019)",
1185            "industry": "generic",
1186            "tags": ["audit", "isa_315", "financial_reporting"],
1187            "intervention_count": 0,
1188            "yaml_source": null,
1189            "is_default": true
1190        },
1191        {
1192            "template_id": "tpl_manufacturing_supply_disruption",
1193            "name": "Manufacturing Supply Chain Disruption",
1194            "description": "Critical component shortage cascades through BOMs, production orders, quality inspections, and COGS",
1195            "industry": "manufacturing",
1196            "tags": ["manufacturing", "supply_chain", "disruption"],
1197            "intervention_count": 2,
1198            "yaml_source": "manufacturing_supply_disruption.yaml",
1199            "is_default": false
1200        },
1201        {
1202            "template_id": "tpl_retail_seasonal_revenue",
1203            "name": "Retail Seasonal Revenue Swing",
1204            "description": "Q4 holiday surge + Q1 post-holiday slump drives revenue, inventory, and accrual volatility",
1205            "industry": "retail",
1206            "tags": ["retail", "seasonality", "revenue"],
1207            "intervention_count": 2,
1208            "yaml_source": "retail_seasonal_revenue.yaml",
1209            "is_default": false
1210        },
1211        {
1212            "template_id": "tpl_financial_services_credit_risk",
1213            "name": "Financial Services Credit Risk Shock",
1214            "description": "Macro credit downturn: ECL model reweighting, provision matrix changes, going concern assessment",
1215            "industry": "financial_services",
1216            "tags": ["financial_services", "credit_risk", "ifrs9"],
1217            "intervention_count": 2,
1218            "yaml_source": "financial_services_credit_risk.yaml",
1219            "is_default": false
1220        },
1221        {
1222            "template_id": "tpl_control_failure_cascade",
1223            "name": "Control Failure Cascade",
1224            "description": "Significant control failure in revenue cycle, cascading through audit risk assessment",
1225            "industry": "generic",
1226            "tags": ["audit", "control_failure"],
1227            "intervention_count": 1,
1228            "yaml_source": "control_failure_cascade.yaml",
1229            "is_default": false
1230        },
1231        {
1232            "template_id": "tpl_audit_scope_change",
1233            "name": "Audit Scope Change",
1234            "description": "Regulatory change triggering materiality reduction mid-engagement",
1235            "industry": "generic",
1236            "tags": ["audit", "regulatory"],
1237            "intervention_count": 1,
1238            "yaml_source": "audit_scope_change.yaml",
1239            "is_default": false
1240        },
1241        {
1242            "template_id": "tpl_going_concern_trigger",
1243            "name": "Going Concern Trigger",
1244            "description": "Credit crunch macro shock driving ISA 570 going concern assessment",
1245            "industry": "generic",
1246            "tags": ["audit", "going_concern", "isa_570"],
1247            "intervention_count": 1,
1248            "yaml_source": "going_concern_trigger.yaml",
1249            "is_default": false
1250        }
1251    ]);
1252    Json(serde_json::json!({
1253        "templates": templates,
1254        "total": 7,
1255        "schema_version": "1.0"
1256    }))
1257}
1258
1259// ===========================================================================
1260// Job Queue Handlers
1261// ===========================================================================
1262
1263/// Submit a new async generation job.
1264async fn submit_job(
1265    State(state): State<AppState>,
1266    Json(request): Json<JobRequest>,
1267) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
1268    let queue = state.job_queue.as_ref().ok_or_else(|| {
1269        (
1270            StatusCode::SERVICE_UNAVAILABLE,
1271            Json(serde_json::json!({"error": "Job queue not enabled"})),
1272        )
1273    })?;
1274
1275    let job_id = queue.submit(request).await;
1276    info!("Job submitted: {}", job_id);
1277
1278    Ok((
1279        StatusCode::CREATED,
1280        Json(serde_json::json!({
1281            "id": job_id.to_string(),
1282            "status": "queued"
1283        })),
1284    ))
1285}
1286
1287/// Get status of a specific job.
1288async fn get_job(
1289    State(state): State<AppState>,
1290    axum::extract::Path(id): axum::extract::Path<String>,
1291) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1292    let queue = state.job_queue.as_ref().ok_or_else(|| {
1293        (
1294            StatusCode::SERVICE_UNAVAILABLE,
1295            Json(serde_json::json!({"error": "Job queue not enabled"})),
1296        )
1297    })?;
1298
1299    match queue.get(&id).await {
1300        Some(entry) => Ok(Json(serde_json::json!({
1301            "id": entry.id,
1302            "status": format!("{:?}", entry.status).to_lowercase(),
1303            "submitted_at": entry.submitted_at.to_rfc3339(),
1304            "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1305            "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1306            "result": entry.result,
1307        }))),
1308        None => Err((
1309            StatusCode::NOT_FOUND,
1310            Json(serde_json::json!({"error": "Job not found"})),
1311        )),
1312    }
1313}
1314
1315/// List all jobs.
1316async fn list_jobs(
1317    State(state): State<AppState>,
1318) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1319    let queue = state.job_queue.as_ref().ok_or_else(|| {
1320        (
1321            StatusCode::SERVICE_UNAVAILABLE,
1322            Json(serde_json::json!({"error": "Job queue not enabled"})),
1323        )
1324    })?;
1325
1326    let summaries: Vec<_> = queue
1327        .list()
1328        .await
1329        .into_iter()
1330        .map(|s| {
1331            serde_json::json!({
1332                "id": s.id,
1333                "status": format!("{:?}", s.status).to_lowercase(),
1334                "submitted_at": s.submitted_at.to_rfc3339(),
1335            })
1336        })
1337        .collect();
1338
1339    Ok(Json(serde_json::json!({ "jobs": summaries })))
1340}
1341
1342/// Cancel a queued job.
1343async fn cancel_job(
1344    State(state): State<AppState>,
1345    axum::extract::Path(id): axum::extract::Path<String>,
1346) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1347    let queue = state.job_queue.as_ref().ok_or_else(|| {
1348        (
1349            StatusCode::SERVICE_UNAVAILABLE,
1350            Json(serde_json::json!({"error": "Job queue not enabled"})),
1351        )
1352    })?;
1353
1354    if queue.cancel(&id).await {
1355        Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1356    } else {
1357        Err((
1358            StatusCode::CONFLICT,
1359            Json(
1360                serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1361            ),
1362        ))
1363    }
1364}
1365
1366// ===========================================================================
1367// Config Reload Handler
1368// ===========================================================================
1369
1370/// Reload configuration from the configured source.
1371async fn reload_config(
1372    State(state): State<AppState>,
1373) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1374    let source = state.server_state.config_source.read().await.clone();
1375    match crate::config_loader::load_config(&source).await {
1376        Ok(new_config) => {
1377            let mut config = state.server_state.config.write().await;
1378            *config = new_config;
1379            info!("Configuration reloaded via REST API from {:?}", source);
1380            Ok(Json(serde_json::json!({
1381                "success": true,
1382                "message": "Configuration reloaded"
1383            })))
1384        }
1385        Err(e) => {
1386            error!("Failed to reload configuration: {}", e);
1387            Err((
1388                StatusCode::INTERNAL_SERVER_ERROR,
1389                Json(serde_json::json!({
1390                    "success": false,
1391                    "message": format!("Failed to reload configuration: {}", e)
1392                })),
1393            ))
1394        }
1395    }
1396}
1397
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // These tests exercise the serde derive behavior of the REST DTOs
    // (request/response shapes) and the CORS configuration defaults.
    // No HTTP server is started; handlers themselves are not invoked.
    use super::*;

    // ==========================================================================
    // Response Serialization Tests
    // ==========================================================================

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            healthy: true,
            version: "0.1.0".to_string(),
            uptime_seconds: 100,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("healthy"));
        assert!(json.contains("version"));
        assert!(json.contains("uptime_seconds"));
    }

    #[test]
    fn test_health_response_deserialization() {
        let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
        let response: HealthResponse = serde_json::from_str(json).unwrap();
        assert!(response.healthy);
        assert_eq!(response.version, "0.1.0");
        assert_eq!(response.uptime_seconds, 100);
    }

    #[test]
    fn test_metrics_response_serialization() {
        let response = MetricsResponse {
            total_entries_generated: 1000,
            total_anomalies_injected: 10,
            uptime_seconds: 60,
            session_entries: 1000,
            session_entries_per_second: 16.67,
            active_streams: 1,
            total_stream_events: 500,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("total_entries_generated"));
        assert!(json.contains("session_entries_per_second"));
    }

    #[test]
    fn test_metrics_response_deserialization() {
        let json = r#"{
            "total_entries_generated": 5000,
            "total_anomalies_injected": 50,
            "uptime_seconds": 300,
            "session_entries": 5000,
            "session_entries_per_second": 16.67,
            "active_streams": 2,
            "total_stream_events": 10000
        }"#;
        let response: MetricsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.total_entries_generated, 5000);
        assert_eq!(response.active_streams, 2);
    }

    #[test]
    fn test_config_response_serialization() {
        let response = ConfigResponse {
            success: true,
            message: "Configuration loaded".to_string(),
            config: Some(GenerationConfigDto {
                industry: "manufacturing".to_string(),
                start_date: "2024-01-01".to_string(),
                period_months: 12,
                seed: Some(42),
                coa_complexity: "medium".to_string(),
                companies: vec![],
                fraud_enabled: false,
                fraud_rate: 0.0,
            }),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("config"));
    }

    #[test]
    fn test_config_response_without_config() {
        let response = ConfigResponse {
            success: false,
            message: "No configuration available".to_string(),
            config: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        // A None config should serialize as an explicit JSON null, not be omitted.
        assert!(json.contains("null") || json.contains("config\":null"));
    }

    #[test]
    fn test_generation_config_dto_roundtrip() {
        let original = GenerationConfigDto {
            industry: "retail".to_string(),
            start_date: "2024-06-01".to_string(),
            period_months: 6,
            seed: Some(12345),
            coa_complexity: "large".to_string(),
            companies: vec![CompanyConfigDto {
                code: "1000".to_string(),
                name: "Test Corp".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            }],
            fraud_enabled: true,
            fraud_rate: 0.05,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

        assert_eq!(original.industry, deserialized.industry);
        assert_eq!(original.seed, deserialized.seed);
        assert_eq!(original.companies.len(), deserialized.companies.len());
    }

    #[test]
    fn test_company_config_dto_serialization() {
        let company = CompanyConfigDto {
            code: "2000".to_string(),
            name: "European Subsidiary".to_string(),
            currency: "EUR".to_string(),
            country: "DE".to_string(),
            annual_transaction_volume: 50000,
            volume_weight: 0.5,
        };
        let json = serde_json::to_string(&company).unwrap();
        assert!(json.contains("2000"));
        assert!(json.contains("EUR"));
        assert!(json.contains("DE"));
    }

    #[test]
    fn test_bulk_generate_request_deserialization() {
        let json = r#"{
            "entry_count": 5000,
            "include_master_data": true,
            "inject_anomalies": true
        }"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(5000));
        assert_eq!(request.include_master_data, Some(true));
        assert_eq!(request.inject_anomalies, Some(true));
    }

    #[test]
    fn test_bulk_generate_request_with_defaults() {
        // All fields are optional: an empty JSON object must deserialize cleanly.
        let json = r#"{}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, None);
        assert_eq!(request.include_master_data, None);
        assert_eq!(request.inject_anomalies, None);
    }

    #[test]
    fn test_bulk_generate_response_serialization() {
        let response = BulkGenerateResponse {
            success: true,
            entries_generated: 1000,
            duration_ms: 250,
            anomaly_count: 20,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("entries_generated"));
        assert!(json.contains("1000"));
        assert!(json.contains("duration_ms"));
    }

    #[test]
    fn test_stream_response_serialization() {
        let response = StreamResponse {
            success: true,
            message: "Stream started successfully".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("Stream started"));
    }

    #[test]
    fn test_stream_response_failure() {
        let response = StreamResponse {
            success: false,
            message: "Stream failed to start".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("false"));
        assert!(json.contains("failed"));
    }

    // ==========================================================================
    // CORS Configuration Tests
    // ==========================================================================

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(!config.allow_any_origin);
        assert!(!config.allowed_origins.is_empty());
        assert!(config
            .allowed_origins
            .contains(&"http://localhost:5173".to_string()));
        assert!(config
            .allowed_origins
            .contains(&"tauri://localhost".to_string()));
    }

    #[test]
    fn test_cors_config_custom_origins() {
        let config = CorsConfig {
            allowed_origins: vec![
                "https://example.com".to_string(),
                "https://app.example.com".to_string(),
            ],
            allow_any_origin: false,
        };
        assert_eq!(config.allowed_origins.len(), 2);
        assert!(config
            .allowed_origins
            .contains(&"https://example.com".to_string()));
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig {
            allowed_origins: vec![],
            allow_any_origin: true,
        };
        assert!(config.allow_any_origin);
    }

    // ==========================================================================
    // Request Validation Tests (edge cases)
    // ==========================================================================

    #[test]
    fn test_bulk_generate_request_partial() {
        let json = r#"{"entry_count": 100}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(100));
        assert!(request.include_master_data.is_none());
    }

    #[test]
    fn test_generation_config_no_seed() {
        let config = GenerationConfigDto {
            industry: "technology".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 3,
            seed: None,
            coa_complexity: "small".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        };
        let json = serde_json::to_string(&config).unwrap();
        // The "seed" key is still emitted when the value is None.
        assert!(json.contains("seed"));
    }

    #[test]
    fn test_generation_config_multiple_companies() {
        let config = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "large".to_string(),
            companies: vec![
                CompanyConfigDto {
                    code: "1000".to_string(),
                    name: "Headquarters".to_string(),
                    currency: "USD".to_string(),
                    country: "US".to_string(),
                    annual_transaction_volume: 100000,
                    volume_weight: 1.0,
                },
                CompanyConfigDto {
                    code: "2000".to_string(),
                    name: "European Sub".to_string(),
                    currency: "EUR".to_string(),
                    country: "DE".to_string(),
                    annual_transaction_volume: 50000,
                    volume_weight: 0.5,
                },
                CompanyConfigDto {
                    code: "3000".to_string(),
                    name: "APAC Sub".to_string(),
                    currency: "JPY".to_string(),
                    country: "JP".to_string(),
                    annual_transaction_volume: 30000,
                    volume_weight: 0.3,
                },
            ],
            fraud_enabled: true,
            fraud_rate: 0.02,
        };
        assert_eq!(config.companies.len(), 3);
    }

    // ==========================================================================
    // Metrics Calculation Tests
    // ==========================================================================

    #[test]
    fn test_metrics_entries_per_second_calculation() {
        // Test that we can represent the expected calculation
        let total_entries: u64 = 1000;
        let uptime: u64 = 60;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert!((eps - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_metrics_entries_per_second_zero_uptime() {
        // Guard against division by zero right after server start.
        let total_entries: u64 = 1000;
        let uptime: u64 = 0;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert_eq!(eps, 0.0);
    }
}