Skip to main content

datasynth_server/rest/
routes.rs

1//! REST API routes.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7    extract::{State, WebSocketUpgrade},
8    http::{header, Method, StatusCode},
9    response::IntoResponse,
10    routing::{get, post},
11    Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
24/// Application state shared across handlers.
25#[derive(Clone)]
26pub struct AppState {
27    pub server_state: Arc<ServerState>,
28    pub job_queue: Option<Arc<JobQueue>>,
29}
30
/// Request timeout settings for the REST API.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Upper bound, in seconds, on how long a single request may run.
    /// `create_router_full_with_backend` turns this into its timeout layer.
    pub request_timeout_secs: u64,
}

impl Default for TimeoutConfig {
    /// Five minutes, because bulk generation requests can legitimately run long.
    fn default() -> Self {
        Self::new(5 * 60)
    }
}

impl TimeoutConfig {
    /// Build a config whose request timeout is `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}
55
/// CORS configuration for the REST API.
///
/// [`CorsConfig::default`] allows the common local development origins
/// (Vite, port 3000 dev servers, and the Tauri app) and does not allow any
/// other origin.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests, e.g. `http://localhost:5173`.
    ///
    /// Only consulted when `allow_any_origin` is `false`. The router builders
    /// silently skip entries that do not parse as header values. An empty list
    /// allows *no* cross-origin requests; it does not fall back to localhost.
    /// Do not put `*` here. Set `allow_any_origin` instead, because tower-http
    /// rejects a wildcard inside an explicit origin list.
    pub allowed_origins: Vec<String>,
    /// Allow any origin (development mode only - NOT recommended for production).
    /// When `true`, `allowed_origins` is ignored and a permissive CORS layer is used.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    fn default() -> Self {
        Self {
            allowed_origins: vec![
                "http://localhost:5173".to_string(), // Vite dev server
                "http://localhost:3000".to_string(), // Common dev server
                "http://127.0.0.1:5173".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "tauri://localhost".to_string(), // Tauri app
            ],
            allow_any_origin: false,
        }
    }
}
79
80/// Add API version header to responses.
81async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82    let (mut parts, body) = response.into_parts();
83    parts.headers.insert(
84        axum::http::HeaderName::from_static("x-api-version"),
85        axum::http::HeaderValue::from_static("v1"),
86    );
87    axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97/// Create the REST API router with default CORS settings.
98pub fn create_router(service: SynthService) -> Router {
99    create_router_with_cors(service, CorsConfig::default())
100}
101
102/// Create the REST API router with full configuration (CORS, auth, rate limiting, and timeout).
103///
104/// Uses in-memory rate limiting by default. For distributed rate limiting
105/// with Redis, use [`create_router_full_with_backend`] instead.
106pub fn create_router_full(
107    service: SynthService,
108    cors_config: CorsConfig,
109    auth_config: AuthConfig,
110    rate_limit_config: RateLimitConfig,
111    timeout_config: TimeoutConfig,
112) -> Router {
113    let backend = RateLimitBackend::in_memory(rate_limit_config);
114    create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117/// Create the REST API router with full configuration and a specific rate limiting backend.
118///
119/// This allows using either in-memory or Redis-backed rate limiting.
120///
121/// # Example (in-memory)
122/// ```rust,ignore
123/// let backend = RateLimitBackend::in_memory(rate_limit_config);
124/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
125/// ```
126///
127/// # Example (Redis)
128/// ```rust,ignore
129/// let backend = RateLimitBackend::redis("redis://127.0.0.1:6379", rate_limit_config).await?;
130/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
131/// ```
132pub fn create_router_full_with_backend(
133    service: SynthService,
134    cors_config: CorsConfig,
135    auth_config: AuthConfig,
136    rate_limit_backend: RateLimitBackend,
137    timeout_config: TimeoutConfig,
138) -> Router {
139    let server_state = service.state.clone();
140    let state = AppState {
141        server_state,
142        job_queue: None,
143    };
144
145    let cors = if cors_config.allow_any_origin {
146        CorsLayer::permissive()
147    } else {
148        let origins: Vec<_> = cors_config
149            .allowed_origins
150            .iter()
151            .filter_map(|o| o.parse().ok())
152            .collect();
153
154        CorsLayer::new()
155            .allow_origin(AllowOrigin::list(origins))
156            .allow_methods([
157                Method::GET,
158                Method::POST,
159                Method::PUT,
160                Method::DELETE,
161                Method::OPTIONS,
162            ])
163            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164    };
165
166    Router::new()
167        // Health and metrics (exempt from auth and rate limiting by default)
168        .route("/health", get(health_check))
169        .route("/ready", get(readiness_check))
170        .route("/live", get(liveness_check))
171        .route("/api/metrics", get(get_metrics))
172        .route("/metrics", get(prometheus_metrics))
173        // Configuration
174        .route("/api/config", get(get_config))
175        .route("/api/config", post(set_config))
176        .route("/api/config/reload", post(reload_config))
177        // Generation
178        .route("/api/generate/bulk", post(bulk_generate))
179        .route("/api/stream/start", post(start_stream))
180        .route("/api/stream/stop", post(stop_stream))
181        .route("/api/stream/pause", post(pause_stream))
182        .route("/api/stream/resume", post(resume_stream))
183        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184        // Jobs
185        .route("/api/jobs/submit", post(submit_job))
186        .route("/api/jobs", get(list_jobs))
187        .route("/api/jobs/{id}", get(get_job))
188        .route("/api/jobs/{id}/cancel", post(cancel_job))
189        // WebSocket
190        .route("/ws/metrics", get(websocket_metrics))
191        .route("/ws/events", get(websocket_events))
192        // Middleware stack (outermost applied first, innermost last)
193        // Order: Timeout -> RateLimit -> RequestValidation -> Auth -> RequestId -> CORS -> SecurityHeaders -> APIVersion -> Router
194        .layer(axum::middleware::from_fn(security_headers_middleware))
195        .layer(axum::middleware::map_response(api_version_header))
196        .layer(cors)
197        .layer(axum::middleware::from_fn(request_id_middleware))
198        .layer(axum::middleware::from_fn(auth_middleware))
199        .layer(axum::Extension(auth_config))
200        .layer(axum::middleware::from_fn(request_validation_middleware))
201        .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202        .layer(axum::Extension(rate_limit_backend))
203        .layer(TimeoutLayer::new(Duration::from_secs(
204            timeout_config.request_timeout_secs,
205        )))
206        .with_state(state)
207}
208
209/// Create the REST API router with custom CORS and authentication settings.
210pub fn create_router_with_auth(
211    service: SynthService,
212    cors_config: CorsConfig,
213    auth_config: AuthConfig,
214) -> Router {
215    let server_state = service.state.clone();
216    let state = AppState {
217        server_state,
218        job_queue: None,
219    };
220
221    let cors = if cors_config.allow_any_origin {
222        CorsLayer::permissive()
223    } else {
224        let origins: Vec<_> = cors_config
225            .allowed_origins
226            .iter()
227            .filter_map(|o| o.parse().ok())
228            .collect();
229
230        CorsLayer::new()
231            .allow_origin(AllowOrigin::list(origins))
232            .allow_methods([
233                Method::GET,
234                Method::POST,
235                Method::PUT,
236                Method::DELETE,
237                Method::OPTIONS,
238            ])
239            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
240    };
241
242    Router::new()
243        // Health and metrics (exempt from auth by default)
244        .route("/health", get(health_check))
245        .route("/ready", get(readiness_check))
246        .route("/live", get(liveness_check))
247        .route("/api/metrics", get(get_metrics))
248        .route("/metrics", get(prometheus_metrics))
249        // Configuration
250        .route("/api/config", get(get_config))
251        .route("/api/config", post(set_config))
252        .route("/api/config/reload", post(reload_config))
253        // Generation
254        .route("/api/generate/bulk", post(bulk_generate))
255        .route("/api/stream/start", post(start_stream))
256        .route("/api/stream/stop", post(stop_stream))
257        .route("/api/stream/pause", post(pause_stream))
258        .route("/api/stream/resume", post(resume_stream))
259        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
260        // Jobs
261        .route("/api/jobs/submit", post(submit_job))
262        .route("/api/jobs", get(list_jobs))
263        .route("/api/jobs/{id}", get(get_job))
264        .route("/api/jobs/{id}/cancel", post(cancel_job))
265        // WebSocket
266        .route("/ws/metrics", get(websocket_metrics))
267        .route("/ws/events", get(websocket_events))
268        .layer(axum::middleware::from_fn(auth_middleware))
269        .layer(axum::Extension(auth_config))
270        .layer(cors)
271        .with_state(state)
272}
273
274/// Create the REST API router with custom CORS settings.
275pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
276    let server_state = service.state.clone();
277    let state = AppState {
278        server_state,
279        job_queue: None,
280    };
281
282    let cors = if cors_config.allow_any_origin {
283        // Development mode - allow any origin (use with caution)
284        CorsLayer::permissive()
285    } else {
286        // Production mode - restricted origins
287        let origins: Vec<_> = cors_config
288            .allowed_origins
289            .iter()
290            .filter_map(|o| o.parse().ok())
291            .collect();
292
293        CorsLayer::new()
294            .allow_origin(AllowOrigin::list(origins))
295            .allow_methods([
296                Method::GET,
297                Method::POST,
298                Method::PUT,
299                Method::DELETE,
300                Method::OPTIONS,
301            ])
302            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
303    };
304
305    Router::new()
306        // Health and metrics
307        .route("/health", get(health_check))
308        .route("/ready", get(readiness_check))
309        .route("/live", get(liveness_check))
310        .route("/api/metrics", get(get_metrics))
311        .route("/metrics", get(prometheus_metrics))
312        // Configuration
313        .route("/api/config", get(get_config))
314        .route("/api/config", post(set_config))
315        .route("/api/config/reload", post(reload_config))
316        // Generation
317        .route("/api/generate/bulk", post(bulk_generate))
318        .route("/api/stream/start", post(start_stream))
319        .route("/api/stream/stop", post(stop_stream))
320        .route("/api/stream/pause", post(pause_stream))
321        .route("/api/stream/resume", post(resume_stream))
322        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
323        // Jobs
324        .route("/api/jobs/submit", post(submit_job))
325        .route("/api/jobs", get(list_jobs))
326        .route("/api/jobs/{id}", get(get_job))
327        .route("/api/jobs/{id}/cancel", post(cancel_job))
328        // WebSocket
329        .route("/ws/metrics", get(websocket_metrics))
330        .route("/ws/events", get(websocket_events))
331        .layer(cors)
332        .with_state(state)
333}
334
335// ===========================================================================
336// Request/Response types
337// ===========================================================================
338
339#[derive(Debug, Serialize, Deserialize)]
340pub struct HealthResponse {
341    pub healthy: bool,
342    pub version: String,
343    pub uptime_seconds: u64,
344}
345
346/// Readiness check response for Kubernetes.
347#[derive(Debug, Serialize, Deserialize)]
348pub struct ReadinessResponse {
349    pub ready: bool,
350    pub message: String,
351    pub checks: Vec<HealthCheck>,
352}
353
354/// Individual health check result.
355#[derive(Debug, Serialize, Deserialize)]
356pub struct HealthCheck {
357    pub name: String,
358    pub status: String,
359}
360
361/// Liveness check response for Kubernetes.
362#[derive(Debug, Serialize, Deserialize)]
363pub struct LivenessResponse {
364    pub alive: bool,
365    pub timestamp: String,
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369pub struct MetricsResponse {
370    pub total_entries_generated: u64,
371    pub total_anomalies_injected: u64,
372    pub uptime_seconds: u64,
373    pub session_entries: u64,
374    pub session_entries_per_second: f64,
375    pub active_streams: u32,
376    pub total_stream_events: u64,
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct ConfigResponse {
381    pub success: bool,
382    pub message: String,
383    pub config: Option<GenerationConfigDto>,
384}
385
386#[derive(Debug, Clone, Serialize, Deserialize)]
387pub struct GenerationConfigDto {
388    pub industry: String,
389    pub start_date: String,
390    pub period_months: u32,
391    pub seed: Option<u64>,
392    pub coa_complexity: String,
393    pub companies: Vec<CompanyConfigDto>,
394    pub fraud_enabled: bool,
395    pub fraud_rate: f32,
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize)]
399pub struct CompanyConfigDto {
400    pub code: String,
401    pub name: String,
402    pub currency: String,
403    pub country: String,
404    pub annual_transaction_volume: u64,
405    pub volume_weight: f32,
406}
407
408#[derive(Debug, Deserialize)]
409pub struct BulkGenerateRequest {
410    pub entry_count: Option<u64>,
411    pub include_master_data: Option<bool>,
412    pub inject_anomalies: Option<bool>,
413}
414
415#[derive(Debug, Serialize)]
416pub struct BulkGenerateResponse {
417    pub success: bool,
418    pub entries_generated: u64,
419    pub duration_ms: u64,
420    pub anomaly_count: u64,
421}
422
423#[derive(Debug, Deserialize)]
424#[allow(dead_code)] // Fields deserialized from request, reserved for future use
425pub struct StreamRequest {
426    pub events_per_second: Option<u32>,
427    pub max_events: Option<u64>,
428    pub inject_anomalies: Option<bool>,
429}
430
431#[derive(Debug, Serialize)]
432pub struct StreamResponse {
433    pub success: bool,
434    pub message: String,
435}
436
437// ===========================================================================
438// Handlers
439// ===========================================================================
440
441/// Health check endpoint - returns overall health status.
442async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
443    Json(HealthResponse {
444        healthy: true,
445        version: env!("CARGO_PKG_VERSION").to_string(),
446        uptime_seconds: state.server_state.uptime_seconds(),
447    })
448}
449
450/// Readiness probe - indicates the service is ready to accept traffic.
451/// Use for Kubernetes readiness probes.
452async fn readiness_check(
453    State(state): State<AppState>,
454) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
455    let mut checks = Vec::new();
456    let mut any_fail = false;
457
458    // Check if configuration is loaded and valid
459    let config = state.server_state.config.read().await;
460    let config_valid = !config.companies.is_empty();
461    checks.push(HealthCheck {
462        name: "config".to_string(),
463        status: if config_valid { "ok" } else { "fail" }.to_string(),
464    });
465    if !config_valid {
466        any_fail = true;
467    }
468    drop(config);
469
470    // Check resource guard (memory)
471    let resource_status = state.server_state.resource_status();
472    let memory_status = if resource_status.degradation_level == "Emergency" {
473        any_fail = true;
474        "fail"
475    } else if resource_status.degradation_level != "Normal" {
476        "degraded"
477    } else {
478        "ok"
479    };
480    checks.push(HealthCheck {
481        name: "memory".to_string(),
482        status: memory_status.to_string(),
483    });
484
485    // Check disk (>100MB free)
486    let disk_ok = resource_status.disk_available_mb > 100;
487    checks.push(HealthCheck {
488        name: "disk".to_string(),
489        status: if disk_ok { "ok" } else { "fail" }.to_string(),
490    });
491    if !disk_ok {
492        any_fail = true;
493    }
494
495    let response = ReadinessResponse {
496        ready: !any_fail,
497        message: if any_fail {
498            "Service is not ready".to_string()
499        } else {
500            "Service is ready".to_string()
501        },
502        checks,
503    };
504
505    if any_fail {
506        Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
507    } else {
508        Ok(Json(response))
509    }
510}
511
512/// Liveness probe - indicates the service is alive.
513/// Use for Kubernetes liveness probes.
514async fn liveness_check() -> Json<LivenessResponse> {
515    Json(LivenessResponse {
516        alive: true,
517        timestamp: chrono::Utc::now().to_rfc3339(),
518    })
519}
520
521/// Prometheus-compatible metrics endpoint.
522/// Returns metrics in Prometheus text exposition format.
523async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
524    use std::sync::atomic::Ordering;
525
526    let uptime = state.server_state.uptime_seconds();
527    let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
528    let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
529    let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
530    let total_stream_events = state
531        .server_state
532        .total_stream_events
533        .load(Ordering::Relaxed);
534
535    let entries_per_second = if uptime > 0 {
536        total_entries as f64 / uptime as f64
537    } else {
538        0.0
539    };
540
541    let metrics = format!(
542        r#"# HELP synth_entries_generated_total Total number of journal entries generated
543# TYPE synth_entries_generated_total counter
544synth_entries_generated_total {}
545
546# HELP synth_anomalies_injected_total Total number of anomalies injected
547# TYPE synth_anomalies_injected_total counter
548synth_anomalies_injected_total {}
549
550# HELP synth_uptime_seconds Server uptime in seconds
551# TYPE synth_uptime_seconds gauge
552synth_uptime_seconds {}
553
554# HELP synth_entries_per_second Rate of entry generation
555# TYPE synth_entries_per_second gauge
556synth_entries_per_second {:.2}
557
558# HELP synth_active_streams Number of active streaming connections
559# TYPE synth_active_streams gauge
560synth_active_streams {}
561
562# HELP synth_stream_events_total Total events sent through streams
563# TYPE synth_stream_events_total counter
564synth_stream_events_total {}
565
566# HELP synth_info Server version information
567# TYPE synth_info gauge
568synth_info{{version="{}"}} 1
569"#,
570        total_entries,
571        total_anomalies,
572        uptime,
573        entries_per_second,
574        active_streams,
575        total_stream_events,
576        env!("CARGO_PKG_VERSION")
577    );
578
579    (
580        StatusCode::OK,
581        [(
582            header::CONTENT_TYPE,
583            "text/plain; version=0.0.4; charset=utf-8",
584        )],
585        metrics,
586    )
587}
588
589/// Get server metrics.
590async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
591    let uptime = state.server_state.uptime_seconds();
592    let total_entries = state
593        .server_state
594        .total_entries
595        .load(std::sync::atomic::Ordering::Relaxed);
596
597    let entries_per_second = if uptime > 0 {
598        total_entries as f64 / uptime as f64
599    } else {
600        0.0
601    };
602
603    Json(MetricsResponse {
604        total_entries_generated: total_entries,
605        total_anomalies_injected: state
606            .server_state
607            .total_anomalies
608            .load(std::sync::atomic::Ordering::Relaxed),
609        uptime_seconds: uptime,
610        session_entries: total_entries,
611        session_entries_per_second: entries_per_second,
612        active_streams: state
613            .server_state
614            .active_streams
615            .load(std::sync::atomic::Ordering::Relaxed) as u32,
616        total_stream_events: state
617            .server_state
618            .total_stream_events
619            .load(std::sync::atomic::Ordering::Relaxed),
620    })
621}
622
623/// Get current configuration.
624async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
625    let config = state.server_state.config.read().await;
626
627    Json(ConfigResponse {
628        success: true,
629        message: "Current configuration".to_string(),
630        config: Some(GenerationConfigDto {
631            industry: format!("{:?}", config.global.industry),
632            start_date: config.global.start_date.clone(),
633            period_months: config.global.period_months,
634            seed: config.global.seed,
635            coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
636            companies: config
637                .companies
638                .iter()
639                .map(|c| CompanyConfigDto {
640                    code: c.code.clone(),
641                    name: c.name.clone(),
642                    currency: c.currency.clone(),
643                    country: c.country.clone(),
644                    annual_transaction_volume: c.annual_transaction_volume.count(),
645                    volume_weight: c.volume_weight as f32,
646                })
647                .collect(),
648            fraud_enabled: config.fraud.enabled,
649            fraud_rate: config.fraud.fraud_rate as f32,
650        }),
651    })
652}
653
654/// Set configuration.
655async fn set_config(
656    State(state): State<AppState>,
657    Json(new_config): Json<GenerationConfigDto>,
658) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
659    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
660    use datasynth_core::models::{CoAComplexity, IndustrySector};
661
662    info!(
663        "Configuration update requested: industry={}, period_months={}",
664        new_config.industry, new_config.period_months
665    );
666
667    // Parse industry from string
668    let industry = match new_config.industry.to_lowercase().as_str() {
669        "manufacturing" => IndustrySector::Manufacturing,
670        "retail" => IndustrySector::Retail,
671        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
672        "healthcare" => IndustrySector::Healthcare,
673        "technology" => IndustrySector::Technology,
674        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
675        "energy" => IndustrySector::Energy,
676        "transportation" => IndustrySector::Transportation,
677        "real_estate" | "realestate" => IndustrySector::RealEstate,
678        "telecommunications" => IndustrySector::Telecommunications,
679        _ => {
680            return Err((
681                StatusCode::BAD_REQUEST,
682                Json(ConfigResponse {
683                    success: false,
684                    message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
685                    config: None,
686                }),
687            ));
688        }
689    };
690
691    // Parse CoA complexity from string
692    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
693        "small" => CoAComplexity::Small,
694        "medium" => CoAComplexity::Medium,
695        "large" => CoAComplexity::Large,
696        _ => {
697            return Err((
698                StatusCode::BAD_REQUEST,
699                Json(ConfigResponse {
700                    success: false,
701                    message: format!(
702                        "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
703                        new_config.coa_complexity
704                    ),
705                    config: None,
706                }),
707            ));
708        }
709    };
710
711    // Convert CompanyConfigDto to CompanyConfig
712    let companies: Vec<CompanyConfig> = new_config
713        .companies
714        .iter()
715        .map(|c| CompanyConfig {
716            code: c.code.clone(),
717            name: c.name.clone(),
718            currency: c.currency.clone(),
719            country: c.country.clone(),
720            fiscal_year_variant: "K4".to_string(),
721            annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
722            volume_weight: c.volume_weight as f64,
723        })
724        .collect();
725
726    // Update the configuration
727    let mut config = state.server_state.config.write().await;
728    config.global.industry = industry;
729    config.global.start_date = new_config.start_date.clone();
730    config.global.period_months = new_config.period_months;
731    config.global.seed = new_config.seed;
732    config.chart_of_accounts.complexity = complexity;
733    config.fraud.enabled = new_config.fraud_enabled;
734    config.fraud.fraud_rate = new_config.fraud_rate as f64;
735
736    // Only update companies if provided
737    if !companies.is_empty() {
738        config.companies = companies;
739    }
740
741    info!("Configuration updated successfully");
742
743    Ok(Json(ConfigResponse {
744        success: true,
745        message: "Configuration updated and applied".to_string(),
746        config: Some(new_config),
747    }))
748}
749
750/// Bulk generation endpoint.
751async fn bulk_generate(
752    State(state): State<AppState>,
753    Json(req): Json<BulkGenerateRequest>,
754) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
755    // Validate entry_count bounds
756    const MAX_ENTRY_COUNT: u64 = 1_000_000;
757    if let Some(count) = req.entry_count {
758        if count > MAX_ENTRY_COUNT {
759            return Err((
760                StatusCode::BAD_REQUEST,
761                format!(
762                    "entry_count ({}) exceeds maximum allowed value ({})",
763                    count, MAX_ENTRY_COUNT
764                ),
765            ));
766        }
767    }
768
769    let config = state.server_state.config.read().await.clone();
770    let start_time = std::time::Instant::now();
771
772    let phase_config = PhaseConfig {
773        generate_master_data: req.include_master_data.unwrap_or(false),
774        generate_document_flows: false,
775        generate_journal_entries: true,
776        inject_anomalies: req.inject_anomalies.unwrap_or(false),
777        show_progress: false,
778        ..Default::default()
779    };
780
781    let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
782        (
783            StatusCode::INTERNAL_SERVER_ERROR,
784            format!("Failed to create orchestrator: {}", e),
785        )
786    })?;
787
788    let result = orchestrator.generate().map_err(|e| {
789        (
790            StatusCode::INTERNAL_SERVER_ERROR,
791            format!("Generation failed: {}", e),
792        )
793    })?;
794
795    let duration_ms = start_time.elapsed().as_millis() as u64;
796    let entries_count = result.journal_entries.len() as u64;
797    let anomaly_count = result.anomaly_labels.labels.len() as u64;
798
799    // Update metrics
800    state
801        .server_state
802        .total_entries
803        .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
804    state
805        .server_state
806        .total_anomalies
807        .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
808
809    Ok(Json(BulkGenerateResponse {
810        success: true,
811        entries_generated: entries_count,
812        duration_ms,
813        anomaly_count,
814    }))
815}
816
817/// Start streaming.
818async fn start_stream(
819    State(state): State<AppState>,
820    Json(req): Json<StreamRequest>,
821) -> Json<StreamResponse> {
822    // Apply stream request parameters to server state
823    if let Some(eps) = req.events_per_second {
824        info!("Stream configured: events_per_second={}", eps);
825        state
826            .server_state
827            .stream_events_per_second
828            .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
829    }
830    if let Some(max) = req.max_events {
831        info!("Stream configured: max_events={}", max);
832        state
833            .server_state
834            .stream_max_events
835            .store(max, std::sync::atomic::Ordering::Relaxed);
836    }
837    if let Some(inject) = req.inject_anomalies {
838        info!("Stream configured: inject_anomalies={}", inject);
839        state
840            .server_state
841            .stream_inject_anomalies
842            .store(inject, std::sync::atomic::Ordering::Relaxed);
843    }
844
845    state
846        .server_state
847        .stream_stopped
848        .store(false, std::sync::atomic::Ordering::Relaxed);
849    state
850        .server_state
851        .stream_paused
852        .store(false, std::sync::atomic::Ordering::Relaxed);
853
854    Json(StreamResponse {
855        success: true,
856        message: "Stream started".to_string(),
857    })
858}
859
860/// Stop streaming.
861async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
862    state
863        .server_state
864        .stream_stopped
865        .store(true, std::sync::atomic::Ordering::Relaxed);
866
867    Json(StreamResponse {
868        success: true,
869        message: "Stream stopped".to_string(),
870    })
871}
872
873/// Pause streaming.
874async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
875    state
876        .server_state
877        .stream_paused
878        .store(true, std::sync::atomic::Ordering::Relaxed);
879
880    Json(StreamResponse {
881        success: true,
882        message: "Stream paused".to_string(),
883    })
884}
885
886/// Resume streaming.
887async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
888    state
889        .server_state
890        .stream_paused
891        .store(false, std::sync::atomic::Ordering::Relaxed);
892
893    Json(StreamResponse {
894        success: true,
895        message: "Stream resumed".to_string(),
896    })
897}
898
899/// Trigger a specific pattern.
900///
901/// Valid patterns: year_end_spike, period_end_spike, holiday_cluster,
902/// fraud_cluster, error_cluster, uniform, or custom:* patterns.
903async fn trigger_pattern(
904    State(state): State<AppState>,
905    axum::extract::Path(pattern): axum::extract::Path<String>,
906) -> Json<StreamResponse> {
907    info!("Pattern trigger requested: {}", pattern);
908
909    // Validate pattern name
910    let valid_patterns = [
911        "year_end_spike",
912        "period_end_spike",
913        "holiday_cluster",
914        "fraud_cluster",
915        "error_cluster",
916        "uniform",
917    ];
918
919    let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
920
921    if !is_valid {
922        return Json(StreamResponse {
923            success: false,
924            message: format!(
925                "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
926                pattern, valid_patterns
927            ),
928        });
929    }
930
931    // Store the pattern for the stream generator to pick up
932    match state.server_state.triggered_pattern.try_write() {
933        Ok(mut triggered) => {
934            *triggered = Some(pattern.clone());
935            Json(StreamResponse {
936                success: true,
937                message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
938            })
939        }
940        Err(_) => Json(StreamResponse {
941            success: false,
942            message: "Failed to acquire lock for pattern trigger".to_string(),
943        }),
944    }
945}
946
947/// WebSocket endpoint for metrics stream.
948async fn websocket_metrics(
949    ws: WebSocketUpgrade,
950    State(state): State<AppState>,
951) -> impl IntoResponse {
952    ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
953}
954
955/// WebSocket endpoint for event stream.
956async fn websocket_events(
957    ws: WebSocketUpgrade,
958    State(state): State<AppState>,
959) -> impl IntoResponse {
960    ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
961}
962
963// ===========================================================================
964// Job Queue Handlers
965// ===========================================================================
966
967/// Submit a new async generation job.
968async fn submit_job(
969    State(state): State<AppState>,
970    Json(request): Json<JobRequest>,
971) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
972    let queue = state.job_queue.as_ref().ok_or_else(|| {
973        (
974            StatusCode::SERVICE_UNAVAILABLE,
975            Json(serde_json::json!({"error": "Job queue not enabled"})),
976        )
977    })?;
978
979    let job_id = queue.submit(request).await;
980    info!("Job submitted: {}", job_id);
981
982    Ok((
983        StatusCode::CREATED,
984        Json(serde_json::json!({
985            "id": job_id.to_string(),
986            "status": "queued"
987        })),
988    ))
989}
990
991/// Get status of a specific job.
992async fn get_job(
993    State(state): State<AppState>,
994    axum::extract::Path(id): axum::extract::Path<String>,
995) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
996    let queue = state.job_queue.as_ref().ok_or_else(|| {
997        (
998            StatusCode::SERVICE_UNAVAILABLE,
999            Json(serde_json::json!({"error": "Job queue not enabled"})),
1000        )
1001    })?;
1002
1003    match queue.get(&id).await {
1004        Some(entry) => Ok(Json(serde_json::json!({
1005            "id": entry.id,
1006            "status": format!("{:?}", entry.status).to_lowercase(),
1007            "submitted_at": entry.submitted_at.to_rfc3339(),
1008            "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1009            "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1010            "result": entry.result,
1011        }))),
1012        None => Err((
1013            StatusCode::NOT_FOUND,
1014            Json(serde_json::json!({"error": "Job not found"})),
1015        )),
1016    }
1017}
1018
1019/// List all jobs.
1020async fn list_jobs(
1021    State(state): State<AppState>,
1022) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1023    let queue = state.job_queue.as_ref().ok_or_else(|| {
1024        (
1025            StatusCode::SERVICE_UNAVAILABLE,
1026            Json(serde_json::json!({"error": "Job queue not enabled"})),
1027        )
1028    })?;
1029
1030    let summaries: Vec<_> = queue
1031        .list()
1032        .await
1033        .into_iter()
1034        .map(|s| {
1035            serde_json::json!({
1036                "id": s.id,
1037                "status": format!("{:?}", s.status).to_lowercase(),
1038                "submitted_at": s.submitted_at.to_rfc3339(),
1039            })
1040        })
1041        .collect();
1042
1043    Ok(Json(serde_json::json!({ "jobs": summaries })))
1044}
1045
1046/// Cancel a queued job.
1047async fn cancel_job(
1048    State(state): State<AppState>,
1049    axum::extract::Path(id): axum::extract::Path<String>,
1050) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1051    let queue = state.job_queue.as_ref().ok_or_else(|| {
1052        (
1053            StatusCode::SERVICE_UNAVAILABLE,
1054            Json(serde_json::json!({"error": "Job queue not enabled"})),
1055        )
1056    })?;
1057
1058    if queue.cancel(&id).await {
1059        Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1060    } else {
1061        Err((
1062            StatusCode::CONFLICT,
1063            Json(
1064                serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1065            ),
1066        ))
1067    }
1068}
1069
1070// ===========================================================================
1071// Config Reload Handler
1072// ===========================================================================
1073
1074/// Reload configuration from the configured source.
1075async fn reload_config(
1076    State(state): State<AppState>,
1077) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1078    let source = state.server_state.config_source.read().await.clone();
1079    match crate::config_loader::load_config(&source).await {
1080        Ok(new_config) => {
1081            let mut config = state.server_state.config.write().await;
1082            *config = new_config;
1083            info!("Configuration reloaded via REST API from {:?}", source);
1084            Ok(Json(serde_json::json!({
1085                "success": true,
1086                "message": "Configuration reloaded"
1087            })))
1088        }
1089        Err(e) => {
1090            error!("Failed to reload configuration: {}", e);
1091            Err((
1092                StatusCode::INTERNAL_SERVER_ERROR,
1093                Json(serde_json::json!({
1094                    "success": false,
1095                    "message": format!("Failed to reload configuration: {}", e)
1096                })),
1097            ))
1098        }
1099    }
1100}
1101
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ==========================================================================
    // Response Serialization Tests
    // ==========================================================================

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            healthy: true,
            version: "0.1.0".to_string(),
            uptime_seconds: 100,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("healthy"));
        assert!(json.contains("version"));
        assert!(json.contains("uptime_seconds"));
    }

    #[test]
    fn test_health_response_deserialization() {
        let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
        let response: HealthResponse = serde_json::from_str(json).unwrap();
        assert!(response.healthy);
        assert_eq!(response.version, "0.1.0");
        assert_eq!(response.uptime_seconds, 100);
    }

    #[test]
    fn test_metrics_response_serialization() {
        let response = MetricsResponse {
            total_entries_generated: 1000,
            total_anomalies_injected: 10,
            uptime_seconds: 60,
            session_entries: 1000,
            session_entries_per_second: 16.67,
            active_streams: 1,
            total_stream_events: 500,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("total_entries_generated"));
        assert!(json.contains("session_entries_per_second"));
    }

    #[test]
    fn test_metrics_response_deserialization() {
        let json = r#"{
            "total_entries_generated": 5000,
            "total_anomalies_injected": 50,
            "uptime_seconds": 300,
            "session_entries": 5000,
            "session_entries_per_second": 16.67,
            "active_streams": 2,
            "total_stream_events": 10000
        }"#;
        let response: MetricsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.total_entries_generated, 5000);
        assert_eq!(response.active_streams, 2);
    }

    #[test]
    fn test_config_response_serialization() {
        let response = ConfigResponse {
            success: true,
            message: "Configuration loaded".to_string(),
            config: Some(GenerationConfigDto {
                industry: "manufacturing".to_string(),
                start_date: "2024-01-01".to_string(),
                period_months: 12,
                seed: Some(42),
                coa_complexity: "medium".to_string(),
                companies: vec![],
                fraud_enabled: false,
                fraud_rate: 0.0,
            }),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("config"));
    }

    #[test]
    fn test_config_response_without_config() {
        let response = ConfigResponse {
            success: false,
            message: "No configuration available".to_string(),
            config: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        // Pin the exact key/value so an unrelated `null` elsewhere cannot satisfy it.
        assert!(json.contains("\"config\":null"));
    }

    #[test]
    fn test_generation_config_dto_roundtrip() {
        let original = GenerationConfigDto {
            industry: "retail".to_string(),
            start_date: "2024-06-01".to_string(),
            period_months: 6,
            seed: Some(12345),
            coa_complexity: "large".to_string(),
            companies: vec![CompanyConfigDto {
                code: "1000".to_string(),
                name: "Test Corp".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            }],
            fraud_enabled: true,
            fraud_rate: 0.05,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

        assert_eq!(original.industry, deserialized.industry);
        assert_eq!(original.start_date, deserialized.start_date);
        assert_eq!(original.period_months, deserialized.period_months);
        assert_eq!(original.seed, deserialized.seed);
        assert_eq!(original.coa_complexity, deserialized.coa_complexity);
        assert_eq!(original.fraud_enabled, deserialized.fraud_enabled);
        assert_eq!(original.companies.len(), deserialized.companies.len());
        assert_eq!(original.companies[0].code, deserialized.companies[0].code);
        assert_eq!(original.companies[0].currency, deserialized.companies[0].currency);
        assert_eq!(
            original.companies[0].annual_transaction_volume,
            deserialized.companies[0].annual_transaction_volume
        );
    }

    #[test]
    fn test_company_config_dto_serialization() {
        let company = CompanyConfigDto {
            code: "2000".to_string(),
            name: "European Subsidiary".to_string(),
            currency: "EUR".to_string(),
            country: "DE".to_string(),
            annual_transaction_volume: 50000,
            volume_weight: 0.5,
        };
        let json = serde_json::to_string(&company).unwrap();
        assert!(json.contains("2000"));
        assert!(json.contains("EUR"));
        assert!(json.contains("DE"));
    }

    #[test]
    fn test_bulk_generate_request_deserialization() {
        let json = r#"{
            "entry_count": 5000,
            "include_master_data": true,
            "inject_anomalies": true
        }"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(5000));
        assert_eq!(request.include_master_data, Some(true));
        assert_eq!(request.inject_anomalies, Some(true));
    }

    #[test]
    fn test_bulk_generate_request_with_defaults() {
        let json = r#"{}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, None);
        assert_eq!(request.include_master_data, None);
        assert_eq!(request.inject_anomalies, None);
    }

    #[test]
    fn test_bulk_generate_response_serialization() {
        let response = BulkGenerateResponse {
            success: true,
            entries_generated: 1000,
            duration_ms: 250,
            anomaly_count: 20,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("entries_generated"));
        assert!(json.contains("1000"));
        assert!(json.contains("duration_ms"));
    }

    #[test]
    fn test_stream_response_serialization() {
        let response = StreamResponse {
            success: true,
            message: "Stream started successfully".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("Stream started"));
    }

    #[test]
    fn test_stream_response_failure() {
        let response = StreamResponse {
            success: false,
            message: "Stream failed to start".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("false"));
        assert!(json.contains("failed"));
    }

    // ==========================================================================
    // CORS Configuration Tests
    // ==========================================================================

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(!config.allow_any_origin);
        assert!(!config.allowed_origins.is_empty());
        assert!(config
            .allowed_origins
            .contains(&"http://localhost:5173".to_string()));
        assert!(config
            .allowed_origins
            .contains(&"tauri://localhost".to_string()));
    }

    #[test]
    fn test_cors_config_custom_origins() {
        let config = CorsConfig {
            allowed_origins: vec![
                "https://example.com".to_string(),
                "https://app.example.com".to_string(),
            ],
            allow_any_origin: false,
        };
        assert_eq!(config.allowed_origins.len(), 2);
        assert!(config
            .allowed_origins
            .contains(&"https://example.com".to_string()));
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig {
            allowed_origins: vec![],
            allow_any_origin: true,
        };
        assert!(config.allow_any_origin);
    }

    // ==========================================================================
    // Request Validation Tests (edge cases)
    // ==========================================================================

    #[test]
    fn test_bulk_generate_request_partial() {
        let json = r#"{"entry_count": 100}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(100));
        assert!(request.include_master_data.is_none());
    }

    #[test]
    fn test_generation_config_no_seed() {
        let config = GenerationConfigDto {
            industry: "technology".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 3,
            seed: None,
            coa_complexity: "small".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("seed"));
    }

    #[test]
    fn test_generation_config_multiple_companies() {
        let config = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "large".to_string(),
            companies: vec![
                CompanyConfigDto {
                    code: "1000".to_string(),
                    name: "Headquarters".to_string(),
                    currency: "USD".to_string(),
                    country: "US".to_string(),
                    annual_transaction_volume: 100000,
                    volume_weight: 1.0,
                },
                CompanyConfigDto {
                    code: "2000".to_string(),
                    name: "European Sub".to_string(),
                    currency: "EUR".to_string(),
                    country: "DE".to_string(),
                    annual_transaction_volume: 50000,
                    volume_weight: 0.5,
                },
                CompanyConfigDto {
                    code: "3000".to_string(),
                    name: "APAC Sub".to_string(),
                    currency: "JPY".to_string(),
                    country: "JP".to_string(),
                    annual_transaction_volume: 30000,
                    volume_weight: 0.3,
                },
            ],
            fraud_enabled: true,
            fraud_rate: 0.02,
        };
        assert_eq!(config.companies.len(), 3);

        // Roundtrip through JSON so the test exercises serde, not just `vec!`,
        // and check that company order and identity survive.
        let json = serde_json::to_string(&config).unwrap();
        let restored: GenerationConfigDto = serde_json::from_str(&json).unwrap();
        let codes: Vec<&str> = restored.companies.iter().map(|c| c.code.as_str()).collect();
        assert_eq!(codes, ["1000", "2000", "3000"]);
        let currencies: Vec<&str> = restored
            .companies
            .iter()
            .map(|c| c.currency.as_str())
            .collect();
        assert_eq!(currencies, ["USD", "EUR", "JPY"]);
    }

    // ==========================================================================
    // Metrics Calculation Tests
    // ==========================================================================
    // NOTE(review): these exercise a local copy of the entries-per-second
    // formula, not the metrics handler itself.

    #[test]
    fn test_metrics_entries_per_second_calculation() {
        // Test that we can represent the expected calculation
        let total_entries: u64 = 1000;
        let uptime: u64 = 60;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert!((eps - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_metrics_entries_per_second_zero_uptime() {
        let total_entries: u64 = 1000;
        let uptime: u64 = 0;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert_eq!(eps, 0.0);
    }
}