Skip to main content

datasynth_server/rest/
routes.rs

1//! REST API routes.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7    extract::{State, WebSocketUpgrade},
8    http::{header, Method, StatusCode},
9    response::IntoResponse,
10    routing::{get, post},
11    Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::info;
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
24/// Application state shared across handlers.
25#[derive(Clone)]
26pub struct AppState {
27    #[allow(dead_code)] // Reserved for future use
28    pub service: Arc<SynthService>,
29    pub server_state: Arc<ServerState>,
30    pub job_queue: Option<Arc<JobQueue>>,
31}
32
/// Request-timeout settings for the REST API.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Maximum duration of a single request, in seconds, enforced by the router's timeout layer.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Five minutes: bulk generation requests can legitimately take a while.
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Build a config that allows `timeout_secs` seconds per request.
    pub fn new(timeout_secs: u64) -> Self {
        Self {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    /// A 300-second (five minute) request timeout.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
57
/// CORS configuration for the REST API.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests (e.g. `http://localhost:5173`).
    ///
    /// Ignored when `allow_any_origin` is `true`. An empty list allows *no*
    /// cross-origin requests at all. Use [`CorsConfig::default`] to get the
    /// standard localhost development origins.
    pub allowed_origins: Vec<String>,
    /// Allow any origin (development mode only - NOT recommended for production).
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Restrict CORS to common local development origins (Vite, generic dev
    /// servers on port 3000, and the Tauri desktop app).
    fn default() -> Self {
        Self {
            allowed_origins: vec![
                "http://localhost:5173".to_string(), // Vite dev server
                "http://localhost:3000".to_string(), // Common dev server
                "http://127.0.0.1:5173".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "tauri://localhost".to_string(), // Tauri app
            ],
            allow_any_origin: false,
        }
    }
}
81
82/// Add API version header to responses.
83async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
84    let (mut parts, body) = response.into_parts();
85    parts.headers.insert(
86        axum::http::HeaderName::from_static("x-api-version"),
87        axum::http::HeaderValue::from_static("v1"),
88    );
89    axum::response::Response::from_parts(parts, body)
90}
91
92use super::auth::{auth_middleware, AuthConfig};
93use super::rate_limit::RateLimitConfig;
94use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
95use super::request_id::request_id_middleware;
96use super::request_validation::request_validation_middleware;
97use super::security_headers::security_headers_middleware;
98
99/// Create the REST API router with default CORS settings.
100pub fn create_router(service: SynthService) -> Router {
101    create_router_with_cors(service, CorsConfig::default())
102}
103
104/// Create the REST API router with full configuration (CORS, auth, rate limiting, and timeout).
105///
106/// Uses in-memory rate limiting by default. For distributed rate limiting
107/// with Redis, use [`create_router_full_with_backend`] instead.
108pub fn create_router_full(
109    service: SynthService,
110    cors_config: CorsConfig,
111    auth_config: AuthConfig,
112    rate_limit_config: RateLimitConfig,
113    timeout_config: TimeoutConfig,
114) -> Router {
115    let backend = RateLimitBackend::in_memory(rate_limit_config);
116    create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
117}
118
119/// Create the REST API router with full configuration and a specific rate limiting backend.
120///
121/// This allows using either in-memory or Redis-backed rate limiting.
122///
123/// # Example (in-memory)
124/// ```rust,ignore
125/// let backend = RateLimitBackend::in_memory(rate_limit_config);
126/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
127/// ```
128///
129/// # Example (Redis)
130/// ```rust,ignore
131/// let backend = RateLimitBackend::redis("redis://127.0.0.1:6379", rate_limit_config).await?;
132/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
133/// ```
134pub fn create_router_full_with_backend(
135    service: SynthService,
136    cors_config: CorsConfig,
137    auth_config: AuthConfig,
138    rate_limit_backend: RateLimitBackend,
139    timeout_config: TimeoutConfig,
140) -> Router {
141    let server_state = service.state.clone();
142    let state = AppState {
143        service: Arc::new(service),
144        server_state,
145        job_queue: None,
146    };
147
148    let cors = if cors_config.allow_any_origin {
149        CorsLayer::permissive()
150    } else {
151        let origins: Vec<_> = cors_config
152            .allowed_origins
153            .iter()
154            .filter_map(|o| o.parse().ok())
155            .collect();
156
157        CorsLayer::new()
158            .allow_origin(AllowOrigin::list(origins))
159            .allow_methods([
160                Method::GET,
161                Method::POST,
162                Method::PUT,
163                Method::DELETE,
164                Method::OPTIONS,
165            ])
166            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
167    };
168
169    Router::new()
170        // Health and metrics (exempt from auth and rate limiting by default)
171        .route("/health", get(health_check))
172        .route("/ready", get(readiness_check))
173        .route("/live", get(liveness_check))
174        .route("/api/metrics", get(get_metrics))
175        .route("/metrics", get(prometheus_metrics))
176        // Configuration
177        .route("/api/config", get(get_config))
178        .route("/api/config", post(set_config))
179        .route("/api/config/reload", post(reload_config))
180        // Generation
181        .route("/api/generate/bulk", post(bulk_generate))
182        .route("/api/stream/start", post(start_stream))
183        .route("/api/stream/stop", post(stop_stream))
184        .route("/api/stream/pause", post(pause_stream))
185        .route("/api/stream/resume", post(resume_stream))
186        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
187        // Jobs
188        .route("/api/jobs/submit", post(submit_job))
189        .route("/api/jobs", get(list_jobs))
190        .route("/api/jobs/{id}", get(get_job))
191        .route("/api/jobs/{id}/cancel", post(cancel_job))
192        // WebSocket
193        .route("/ws/metrics", get(websocket_metrics))
194        .route("/ws/events", get(websocket_events))
195        // Middleware stack (outermost applied first, innermost last)
196        // Order: Timeout -> RateLimit -> RequestValidation -> Auth -> RequestId -> CORS -> SecurityHeaders -> APIVersion -> Router
197        .layer(axum::middleware::from_fn(security_headers_middleware))
198        .layer(axum::middleware::map_response(api_version_header))
199        .layer(cors)
200        .layer(axum::middleware::from_fn(request_id_middleware))
201        .layer(axum::middleware::from_fn(auth_middleware))
202        .layer(axum::Extension(auth_config))
203        .layer(axum::middleware::from_fn(request_validation_middleware))
204        .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
205        .layer(axum::Extension(rate_limit_backend))
206        .layer(TimeoutLayer::new(Duration::from_secs(
207            timeout_config.request_timeout_secs,
208        )))
209        .with_state(state)
210}
211
212/// Create the REST API router with custom CORS and authentication settings.
213pub fn create_router_with_auth(
214    service: SynthService,
215    cors_config: CorsConfig,
216    auth_config: AuthConfig,
217) -> Router {
218    let server_state = service.state.clone();
219    let state = AppState {
220        service: Arc::new(service),
221        server_state,
222        job_queue: None,
223    };
224
225    let cors = if cors_config.allow_any_origin {
226        CorsLayer::permissive()
227    } else {
228        let origins: Vec<_> = cors_config
229            .allowed_origins
230            .iter()
231            .filter_map(|o| o.parse().ok())
232            .collect();
233
234        CorsLayer::new()
235            .allow_origin(AllowOrigin::list(origins))
236            .allow_methods([
237                Method::GET,
238                Method::POST,
239                Method::PUT,
240                Method::DELETE,
241                Method::OPTIONS,
242            ])
243            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
244    };
245
246    Router::new()
247        // Health and metrics (exempt from auth by default)
248        .route("/health", get(health_check))
249        .route("/ready", get(readiness_check))
250        .route("/live", get(liveness_check))
251        .route("/api/metrics", get(get_metrics))
252        .route("/metrics", get(prometheus_metrics))
253        // Configuration
254        .route("/api/config", get(get_config))
255        .route("/api/config", post(set_config))
256        .route("/api/config/reload", post(reload_config))
257        // Generation
258        .route("/api/generate/bulk", post(bulk_generate))
259        .route("/api/stream/start", post(start_stream))
260        .route("/api/stream/stop", post(stop_stream))
261        .route("/api/stream/pause", post(pause_stream))
262        .route("/api/stream/resume", post(resume_stream))
263        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
264        // Jobs
265        .route("/api/jobs/submit", post(submit_job))
266        .route("/api/jobs", get(list_jobs))
267        .route("/api/jobs/{id}", get(get_job))
268        .route("/api/jobs/{id}/cancel", post(cancel_job))
269        // WebSocket
270        .route("/ws/metrics", get(websocket_metrics))
271        .route("/ws/events", get(websocket_events))
272        .layer(axum::middleware::from_fn(auth_middleware))
273        .layer(axum::Extension(auth_config))
274        .layer(cors)
275        .with_state(state)
276}
277
278/// Create the REST API router with custom CORS settings.
279pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
280    let server_state = service.state.clone();
281    let state = AppState {
282        service: Arc::new(service),
283        server_state,
284        job_queue: None,
285    };
286
287    let cors = if cors_config.allow_any_origin {
288        // Development mode - allow any origin (use with caution)
289        CorsLayer::permissive()
290    } else {
291        // Production mode - restricted origins
292        let origins: Vec<_> = cors_config
293            .allowed_origins
294            .iter()
295            .filter_map(|o| o.parse().ok())
296            .collect();
297
298        CorsLayer::new()
299            .allow_origin(AllowOrigin::list(origins))
300            .allow_methods([
301                Method::GET,
302                Method::POST,
303                Method::PUT,
304                Method::DELETE,
305                Method::OPTIONS,
306            ])
307            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
308    };
309
310    Router::new()
311        // Health and metrics
312        .route("/health", get(health_check))
313        .route("/ready", get(readiness_check))
314        .route("/live", get(liveness_check))
315        .route("/api/metrics", get(get_metrics))
316        .route("/metrics", get(prometheus_metrics))
317        // Configuration
318        .route("/api/config", get(get_config))
319        .route("/api/config", post(set_config))
320        .route("/api/config/reload", post(reload_config))
321        // Generation
322        .route("/api/generate/bulk", post(bulk_generate))
323        .route("/api/stream/start", post(start_stream))
324        .route("/api/stream/stop", post(stop_stream))
325        .route("/api/stream/pause", post(pause_stream))
326        .route("/api/stream/resume", post(resume_stream))
327        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
328        // Jobs
329        .route("/api/jobs/submit", post(submit_job))
330        .route("/api/jobs", get(list_jobs))
331        .route("/api/jobs/{id}", get(get_job))
332        .route("/api/jobs/{id}/cancel", post(cancel_job))
333        // WebSocket
334        .route("/ws/metrics", get(websocket_metrics))
335        .route("/ws/events", get(websocket_events))
336        .layer(cors)
337        .with_state(state)
338}
339
340// ===========================================================================
341// Request/Response types
342// ===========================================================================
343
344#[derive(Debug, Serialize, Deserialize)]
345pub struct HealthResponse {
346    pub healthy: bool,
347    pub version: String,
348    pub uptime_seconds: u64,
349}
350
351/// Readiness check response for Kubernetes.
352#[derive(Debug, Serialize, Deserialize)]
353pub struct ReadinessResponse {
354    pub ready: bool,
355    pub message: String,
356    pub checks: Vec<HealthCheck>,
357}
358
359/// Individual health check result.
360#[derive(Debug, Serialize, Deserialize)]
361pub struct HealthCheck {
362    pub name: String,
363    pub status: String,
364}
365
366/// Liveness check response for Kubernetes.
367#[derive(Debug, Serialize, Deserialize)]
368pub struct LivenessResponse {
369    pub alive: bool,
370    pub timestamp: String,
371}
372
373#[derive(Debug, Serialize, Deserialize)]
374pub struct MetricsResponse {
375    pub total_entries_generated: u64,
376    pub total_anomalies_injected: u64,
377    pub uptime_seconds: u64,
378    pub session_entries: u64,
379    pub session_entries_per_second: f64,
380    pub active_streams: u32,
381    pub total_stream_events: u64,
382}
383
384#[derive(Debug, Clone, Serialize, Deserialize)]
385pub struct ConfigResponse {
386    pub success: bool,
387    pub message: String,
388    pub config: Option<GenerationConfigDto>,
389}
390
391#[derive(Debug, Clone, Serialize, Deserialize)]
392pub struct GenerationConfigDto {
393    pub industry: String,
394    pub start_date: String,
395    pub period_months: u32,
396    pub seed: Option<u64>,
397    pub coa_complexity: String,
398    pub companies: Vec<CompanyConfigDto>,
399    pub fraud_enabled: bool,
400    pub fraud_rate: f32,
401}
402
403#[derive(Debug, Clone, Serialize, Deserialize)]
404pub struct CompanyConfigDto {
405    pub code: String,
406    pub name: String,
407    pub currency: String,
408    pub country: String,
409    pub annual_transaction_volume: u64,
410    pub volume_weight: f32,
411}
412
413#[derive(Debug, Deserialize)]
414pub struct BulkGenerateRequest {
415    pub entry_count: Option<u64>,
416    pub include_master_data: Option<bool>,
417    pub inject_anomalies: Option<bool>,
418}
419
420#[derive(Debug, Serialize)]
421pub struct BulkGenerateResponse {
422    pub success: bool,
423    pub entries_generated: u64,
424    pub duration_ms: u64,
425    pub anomaly_count: u64,
426}
427
428#[derive(Debug, Deserialize)]
429#[allow(dead_code)] // Fields deserialized from request, reserved for future use
430pub struct StreamRequest {
431    pub events_per_second: Option<u32>,
432    pub max_events: Option<u64>,
433    pub inject_anomalies: Option<bool>,
434}
435
436#[derive(Debug, Serialize)]
437pub struct StreamResponse {
438    pub success: bool,
439    pub message: String,
440}
441
442// ===========================================================================
443// Handlers
444// ===========================================================================
445
446/// Health check endpoint - returns overall health status.
447async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
448    Json(HealthResponse {
449        healthy: true,
450        version: env!("CARGO_PKG_VERSION").to_string(),
451        uptime_seconds: state.server_state.uptime_seconds(),
452    })
453}
454
455/// Readiness probe - indicates the service is ready to accept traffic.
456/// Use for Kubernetes readiness probes.
457async fn readiness_check(
458    State(state): State<AppState>,
459) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
460    let mut checks = Vec::new();
461    let mut any_fail = false;
462
463    // Check if configuration is loaded and valid
464    let config = state.server_state.config.read().await;
465    let config_valid = !config.companies.is_empty();
466    checks.push(HealthCheck {
467        name: "config".to_string(),
468        status: if config_valid { "ok" } else { "fail" }.to_string(),
469    });
470    if !config_valid {
471        any_fail = true;
472    }
473    drop(config);
474
475    // Check resource guard (memory)
476    let resource_status = state.server_state.resource_status();
477    let memory_status = if resource_status.degradation_level == "Emergency" {
478        any_fail = true;
479        "fail"
480    } else if resource_status.degradation_level != "Normal" {
481        "degraded"
482    } else {
483        "ok"
484    };
485    checks.push(HealthCheck {
486        name: "memory".to_string(),
487        status: memory_status.to_string(),
488    });
489
490    // Check disk (>100MB free)
491    let disk_ok = resource_status.disk_available_mb > 100;
492    checks.push(HealthCheck {
493        name: "disk".to_string(),
494        status: if disk_ok { "ok" } else { "fail" }.to_string(),
495    });
496    if !disk_ok {
497        any_fail = true;
498    }
499
500    let response = ReadinessResponse {
501        ready: !any_fail,
502        message: if any_fail {
503            "Service is not ready".to_string()
504        } else {
505            "Service is ready".to_string()
506        },
507        checks,
508    };
509
510    if any_fail {
511        Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
512    } else {
513        Ok(Json(response))
514    }
515}
516
517/// Liveness probe - indicates the service is alive.
518/// Use for Kubernetes liveness probes.
519async fn liveness_check() -> Json<LivenessResponse> {
520    Json(LivenessResponse {
521        alive: true,
522        timestamp: chrono::Utc::now().to_rfc3339(),
523    })
524}
525
526/// Prometheus-compatible metrics endpoint.
527/// Returns metrics in Prometheus text exposition format.
528async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
529    use std::sync::atomic::Ordering;
530
531    let uptime = state.server_state.uptime_seconds();
532    let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
533    let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
534    let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
535    let total_stream_events = state
536        .server_state
537        .total_stream_events
538        .load(Ordering::Relaxed);
539
540    let entries_per_second = if uptime > 0 {
541        total_entries as f64 / uptime as f64
542    } else {
543        0.0
544    };
545
546    let metrics = format!(
547        r#"# HELP synth_entries_generated_total Total number of journal entries generated
548# TYPE synth_entries_generated_total counter
549synth_entries_generated_total {}
550
551# HELP synth_anomalies_injected_total Total number of anomalies injected
552# TYPE synth_anomalies_injected_total counter
553synth_anomalies_injected_total {}
554
555# HELP synth_uptime_seconds Server uptime in seconds
556# TYPE synth_uptime_seconds gauge
557synth_uptime_seconds {}
558
559# HELP synth_entries_per_second Rate of entry generation
560# TYPE synth_entries_per_second gauge
561synth_entries_per_second {:.2}
562
563# HELP synth_active_streams Number of active streaming connections
564# TYPE synth_active_streams gauge
565synth_active_streams {}
566
567# HELP synth_stream_events_total Total events sent through streams
568# TYPE synth_stream_events_total counter
569synth_stream_events_total {}
570
571# HELP synth_info Server version information
572# TYPE synth_info gauge
573synth_info{{version="{}"}} 1
574"#,
575        total_entries,
576        total_anomalies,
577        uptime,
578        entries_per_second,
579        active_streams,
580        total_stream_events,
581        env!("CARGO_PKG_VERSION")
582    );
583
584    (
585        StatusCode::OK,
586        [(
587            header::CONTENT_TYPE,
588            "text/plain; version=0.0.4; charset=utf-8",
589        )],
590        metrics,
591    )
592}
593
594/// Get server metrics.
595async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
596    let uptime = state.server_state.uptime_seconds();
597    let total_entries = state
598        .server_state
599        .total_entries
600        .load(std::sync::atomic::Ordering::Relaxed);
601
602    let entries_per_second = if uptime > 0 {
603        total_entries as f64 / uptime as f64
604    } else {
605        0.0
606    };
607
608    Json(MetricsResponse {
609        total_entries_generated: total_entries,
610        total_anomalies_injected: state
611            .server_state
612            .total_anomalies
613            .load(std::sync::atomic::Ordering::Relaxed),
614        uptime_seconds: uptime,
615        session_entries: total_entries,
616        session_entries_per_second: entries_per_second,
617        active_streams: state
618            .server_state
619            .active_streams
620            .load(std::sync::atomic::Ordering::Relaxed) as u32,
621        total_stream_events: state
622            .server_state
623            .total_stream_events
624            .load(std::sync::atomic::Ordering::Relaxed),
625    })
626}
627
628/// Get current configuration.
629async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
630    let config = state.server_state.config.read().await;
631
632    Json(ConfigResponse {
633        success: true,
634        message: "Current configuration".to_string(),
635        config: Some(GenerationConfigDto {
636            industry: format!("{:?}", config.global.industry),
637            start_date: config.global.start_date.clone(),
638            period_months: config.global.period_months,
639            seed: config.global.seed,
640            coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
641            companies: config
642                .companies
643                .iter()
644                .map(|c| CompanyConfigDto {
645                    code: c.code.clone(),
646                    name: c.name.clone(),
647                    currency: c.currency.clone(),
648                    country: c.country.clone(),
649                    annual_transaction_volume: c.annual_transaction_volume.count(),
650                    volume_weight: c.volume_weight as f32,
651                })
652                .collect(),
653            fraud_enabled: config.fraud.enabled,
654            fraud_rate: config.fraud.fraud_rate as f32,
655        }),
656    })
657}
658
659/// Set configuration.
660async fn set_config(
661    State(state): State<AppState>,
662    Json(new_config): Json<GenerationConfigDto>,
663) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
664    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
665    use datasynth_core::models::{CoAComplexity, IndustrySector};
666
667    info!(
668        "Configuration update requested: industry={}, period_months={}",
669        new_config.industry, new_config.period_months
670    );
671
672    // Parse industry from string
673    let industry = match new_config.industry.to_lowercase().as_str() {
674        "manufacturing" => IndustrySector::Manufacturing,
675        "retail" => IndustrySector::Retail,
676        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
677        "healthcare" => IndustrySector::Healthcare,
678        "technology" => IndustrySector::Technology,
679        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
680        "energy" => IndustrySector::Energy,
681        "transportation" => IndustrySector::Transportation,
682        "real_estate" | "realestate" => IndustrySector::RealEstate,
683        "telecommunications" => IndustrySector::Telecommunications,
684        _ => {
685            return Err((
686                StatusCode::BAD_REQUEST,
687                Json(ConfigResponse {
688                    success: false,
689                    message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
690                    config: None,
691                }),
692            ));
693        }
694    };
695
696    // Parse CoA complexity from string
697    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
698        "small" => CoAComplexity::Small,
699        "medium" => CoAComplexity::Medium,
700        "large" => CoAComplexity::Large,
701        _ => {
702            return Err((
703                StatusCode::BAD_REQUEST,
704                Json(ConfigResponse {
705                    success: false,
706                    message: format!(
707                        "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
708                        new_config.coa_complexity
709                    ),
710                    config: None,
711                }),
712            ));
713        }
714    };
715
716    // Convert CompanyConfigDto to CompanyConfig
717    let companies: Vec<CompanyConfig> = new_config
718        .companies
719        .iter()
720        .map(|c| CompanyConfig {
721            code: c.code.clone(),
722            name: c.name.clone(),
723            currency: c.currency.clone(),
724            country: c.country.clone(),
725            fiscal_year_variant: "K4".to_string(),
726            annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
727            volume_weight: c.volume_weight as f64,
728        })
729        .collect();
730
731    // Update the configuration
732    let mut config = state.server_state.config.write().await;
733    config.global.industry = industry;
734    config.global.start_date = new_config.start_date.clone();
735    config.global.period_months = new_config.period_months;
736    config.global.seed = new_config.seed;
737    config.chart_of_accounts.complexity = complexity;
738    config.fraud.enabled = new_config.fraud_enabled;
739    config.fraud.fraud_rate = new_config.fraud_rate as f64;
740
741    // Only update companies if provided
742    if !companies.is_empty() {
743        config.companies = companies;
744    }
745
746    info!("Configuration updated successfully");
747
748    Ok(Json(ConfigResponse {
749        success: true,
750        message: "Configuration updated and applied".to_string(),
751        config: Some(new_config),
752    }))
753}
754
755/// Bulk generation endpoint.
756async fn bulk_generate(
757    State(state): State<AppState>,
758    Json(req): Json<BulkGenerateRequest>,
759) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
760    // Validate entry_count bounds
761    const MAX_ENTRY_COUNT: u64 = 1_000_000;
762    if let Some(count) = req.entry_count {
763        if count > MAX_ENTRY_COUNT {
764            return Err((
765                StatusCode::BAD_REQUEST,
766                format!(
767                    "entry_count ({}) exceeds maximum allowed value ({})",
768                    count, MAX_ENTRY_COUNT
769                ),
770            ));
771        }
772    }
773
774    let config = state.server_state.config.read().await.clone();
775    let start_time = std::time::Instant::now();
776
777    let phase_config = PhaseConfig {
778        generate_master_data: req.include_master_data.unwrap_or(false),
779        generate_document_flows: false,
780        generate_journal_entries: true,
781        inject_anomalies: req.inject_anomalies.unwrap_or(false),
782        show_progress: false,
783        ..Default::default()
784    };
785
786    let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
787        (
788            StatusCode::INTERNAL_SERVER_ERROR,
789            format!("Failed to create orchestrator: {}", e),
790        )
791    })?;
792
793    let result = orchestrator.generate().map_err(|e| {
794        (
795            StatusCode::INTERNAL_SERVER_ERROR,
796            format!("Generation failed: {}", e),
797        )
798    })?;
799
800    let duration_ms = start_time.elapsed().as_millis() as u64;
801    let entries_count = result.journal_entries.len() as u64;
802    let anomaly_count = result.anomaly_labels.labels.len() as u64;
803
804    // Update metrics
805    state
806        .server_state
807        .total_entries
808        .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
809    state
810        .server_state
811        .total_anomalies
812        .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
813
814    Ok(Json(BulkGenerateResponse {
815        success: true,
816        entries_generated: entries_count,
817        duration_ms,
818        anomaly_count,
819    }))
820}
821
822/// Start streaming.
823async fn start_stream(
824    State(state): State<AppState>,
825    Json(_req): Json<StreamRequest>,
826) -> Json<StreamResponse> {
827    state
828        .server_state
829        .stream_stopped
830        .store(false, std::sync::atomic::Ordering::Relaxed);
831    state
832        .server_state
833        .stream_paused
834        .store(false, std::sync::atomic::Ordering::Relaxed);
835
836    Json(StreamResponse {
837        success: true,
838        message: "Stream started".to_string(),
839    })
840}
841
842/// Stop streaming.
843async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
844    state
845        .server_state
846        .stream_stopped
847        .store(true, std::sync::atomic::Ordering::Relaxed);
848
849    Json(StreamResponse {
850        success: true,
851        message: "Stream stopped".to_string(),
852    })
853}
854
855/// Pause streaming.
856async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
857    state
858        .server_state
859        .stream_paused
860        .store(true, std::sync::atomic::Ordering::Relaxed);
861
862    Json(StreamResponse {
863        success: true,
864        message: "Stream paused".to_string(),
865    })
866}
867
868/// Resume streaming.
869async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
870    state
871        .server_state
872        .stream_paused
873        .store(false, std::sync::atomic::Ordering::Relaxed);
874
875    Json(StreamResponse {
876        success: true,
877        message: "Stream resumed".to_string(),
878    })
879}
880
881/// Trigger a specific pattern.
882///
883/// Valid patterns: year_end_spike, period_end_spike, holiday_cluster,
884/// fraud_cluster, error_cluster, uniform, or custom:* patterns.
885async fn trigger_pattern(
886    State(state): State<AppState>,
887    axum::extract::Path(pattern): axum::extract::Path<String>,
888) -> Json<StreamResponse> {
889    info!("Pattern trigger requested: {}", pattern);
890
891    // Validate pattern name
892    let valid_patterns = [
893        "year_end_spike",
894        "period_end_spike",
895        "holiday_cluster",
896        "fraud_cluster",
897        "error_cluster",
898        "uniform",
899    ];
900
901    let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
902
903    if !is_valid {
904        return Json(StreamResponse {
905            success: false,
906            message: format!(
907                "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
908                pattern, valid_patterns
909            ),
910        });
911    }
912
913    // Store the pattern for the stream generator to pick up
914    match state.server_state.triggered_pattern.try_write() {
915        Ok(mut triggered) => {
916            *triggered = Some(pattern.clone());
917            Json(StreamResponse {
918                success: true,
919                message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
920            })
921        }
922        Err(_) => Json(StreamResponse {
923            success: false,
924            message: "Failed to acquire lock for pattern trigger".to_string(),
925        }),
926    }
927}
928
929/// WebSocket endpoint for metrics stream.
930async fn websocket_metrics(
931    ws: WebSocketUpgrade,
932    State(state): State<AppState>,
933) -> impl IntoResponse {
934    ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
935}
936
937/// WebSocket endpoint for event stream.
938async fn websocket_events(
939    ws: WebSocketUpgrade,
940    State(state): State<AppState>,
941) -> impl IntoResponse {
942    ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
943}
944
// ===========================================================================
// Job Queue Handlers
// ===========================================================================
948
949/// Submit a new async generation job.
950async fn submit_job(
951    State(state): State<AppState>,
952    Json(request): Json<JobRequest>,
953) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
954    let queue = state.job_queue.as_ref().ok_or_else(|| {
955        (
956            StatusCode::SERVICE_UNAVAILABLE,
957            Json(serde_json::json!({"error": "Job queue not enabled"})),
958        )
959    })?;
960
961    let job_id = queue.submit(request).await;
962    info!("Job submitted: {}", job_id);
963
964    Ok((
965        StatusCode::CREATED,
966        Json(serde_json::json!({
967            "id": job_id.to_string(),
968            "status": "queued"
969        })),
970    ))
971}
972
973/// Get status of a specific job.
974async fn get_job(
975    State(state): State<AppState>,
976    axum::extract::Path(id): axum::extract::Path<String>,
977) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
978    let queue = state.job_queue.as_ref().ok_or_else(|| {
979        (
980            StatusCode::SERVICE_UNAVAILABLE,
981            Json(serde_json::json!({"error": "Job queue not enabled"})),
982        )
983    })?;
984
985    match queue.get(&id).await {
986        Some(entry) => Ok(Json(serde_json::json!({
987            "id": entry.id,
988            "status": format!("{:?}", entry.status).to_lowercase(),
989            "submitted_at": entry.submitted_at.to_rfc3339(),
990            "started_at": entry.started_at.map(|t| t.to_rfc3339()),
991            "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
992            "result": entry.result,
993        }))),
994        None => Err((
995            StatusCode::NOT_FOUND,
996            Json(serde_json::json!({"error": "Job not found"})),
997        )),
998    }
999}
1000
1001/// List all jobs.
1002async fn list_jobs(
1003    State(state): State<AppState>,
1004) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1005    let queue = state.job_queue.as_ref().ok_or_else(|| {
1006        (
1007            StatusCode::SERVICE_UNAVAILABLE,
1008            Json(serde_json::json!({"error": "Job queue not enabled"})),
1009        )
1010    })?;
1011
1012    let summaries: Vec<_> = queue
1013        .list()
1014        .await
1015        .into_iter()
1016        .map(|s| {
1017            serde_json::json!({
1018                "id": s.id,
1019                "status": format!("{:?}", s.status).to_lowercase(),
1020                "submitted_at": s.submitted_at.to_rfc3339(),
1021            })
1022        })
1023        .collect();
1024
1025    Ok(Json(serde_json::json!({ "jobs": summaries })))
1026}
1027
1028/// Cancel a queued job.
1029async fn cancel_job(
1030    State(state): State<AppState>,
1031    axum::extract::Path(id): axum::extract::Path<String>,
1032) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1033    let queue = state.job_queue.as_ref().ok_or_else(|| {
1034        (
1035            StatusCode::SERVICE_UNAVAILABLE,
1036            Json(serde_json::json!({"error": "Job queue not enabled"})),
1037        )
1038    })?;
1039
1040    if queue.cancel(&id).await {
1041        Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1042    } else {
1043        Err((
1044            StatusCode::CONFLICT,
1045            Json(
1046                serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1047            ),
1048        ))
1049    }
1050}
1051
// ===========================================================================
// Config Reload Handler
// ===========================================================================
1055
1056/// Reload configuration from the configured source.
1057async fn reload_config(
1058    State(state): State<AppState>,
1059) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1060    // Reload from default config source
1061    let new_config = crate::grpc::service::default_generator_config();
1062    let mut config = state.server_state.config.write().await;
1063    *config = new_config;
1064    info!("Configuration reloaded via REST API");
1065
1066    Ok(Json(serde_json::json!({
1067        "success": true,
1068        "message": "Configuration reloaded"
1069    })))
1070}
1071
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is exactly the failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ==========================================================================
    // Response Serialization Tests
    // ==========================================================================

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            healthy: true,
            version: "0.1.0".to_string(),
            uptime_seconds: 100,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("healthy"));
        assert!(json.contains("version"));
        assert!(json.contains("uptime_seconds"));
    }

    #[test]
    fn test_health_response_deserialization() {
        let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
        let response: HealthResponse = serde_json::from_str(json).unwrap();
        assert!(response.healthy);
        assert_eq!(response.version, "0.1.0");
        assert_eq!(response.uptime_seconds, 100);
    }

    #[test]
    fn test_metrics_response_serialization() {
        let response = MetricsResponse {
            total_entries_generated: 1000,
            total_anomalies_injected: 10,
            uptime_seconds: 60,
            session_entries: 1000,
            session_entries_per_second: 16.67,
            active_streams: 1,
            total_stream_events: 500,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("total_entries_generated"));
        assert!(json.contains("session_entries_per_second"));
    }

    #[test]
    fn test_metrics_response_deserialization() {
        let json = r#"{
            "total_entries_generated": 5000,
            "total_anomalies_injected": 50,
            "uptime_seconds": 300,
            "session_entries": 5000,
            "session_entries_per_second": 16.67,
            "active_streams": 2,
            "total_stream_events": 10000
        }"#;
        let response: MetricsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.total_entries_generated, 5000);
        assert_eq!(response.active_streams, 2);
    }

    #[test]
    fn test_config_response_serialization() {
        let response = ConfigResponse {
            success: true,
            message: "Configuration loaded".to_string(),
            config: Some(GenerationConfigDto {
                industry: "manufacturing".to_string(),
                start_date: "2024-01-01".to_string(),
                period_months: 12,
                seed: Some(42),
                coa_complexity: "medium".to_string(),
                companies: vec![],
                fraud_enabled: false,
                fraud_rate: 0.0,
            }),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("config"));
    }

    #[test]
    fn test_config_response_without_config() {
        let response = ConfigResponse {
            success: false,
            message: "No configuration available".to_string(),
            config: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        // Pin the exact key/value pair. A bare "null" search would also match
        // an unrelated null field.
        assert!(json.contains("\"config\":null"));
    }

    #[test]
    fn test_generation_config_dto_roundtrip() {
        let original = GenerationConfigDto {
            industry: "retail".to_string(),
            start_date: "2024-06-01".to_string(),
            period_months: 6,
            seed: Some(12345),
            coa_complexity: "large".to_string(),
            companies: vec![CompanyConfigDto {
                code: "1000".to_string(),
                name: "Test Corp".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            }],
            fraud_enabled: true,
            fraud_rate: 0.05,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

        assert_eq!(original.industry, deserialized.industry);
        assert_eq!(original.start_date, deserialized.start_date);
        assert_eq!(original.period_months, deserialized.period_months);
        assert_eq!(original.seed, deserialized.seed);
        assert_eq!(original.coa_complexity, deserialized.coa_complexity);
        assert_eq!(original.fraud_enabled, deserialized.fraud_enabled);
        assert_eq!(original.companies.len(), deserialized.companies.len());
        assert_eq!(original.companies[0].code, deserialized.companies[0].code);
        assert_eq!(
            original.companies[0].currency,
            deserialized.companies[0].currency
        );
    }

    #[test]
    fn test_company_config_dto_serialization() {
        let company = CompanyConfigDto {
            code: "2000".to_string(),
            name: "European Subsidiary".to_string(),
            currency: "EUR".to_string(),
            country: "DE".to_string(),
            annual_transaction_volume: 50000,
            volume_weight: 0.5,
        };
        let json = serde_json::to_string(&company).unwrap();
        assert!(json.contains("2000"));
        assert!(json.contains("EUR"));
        assert!(json.contains("DE"));
    }

    #[test]
    fn test_bulk_generate_request_deserialization() {
        let json = r#"{
            "entry_count": 5000,
            "include_master_data": true,
            "inject_anomalies": true
        }"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(5000));
        assert_eq!(request.include_master_data, Some(true));
        assert_eq!(request.inject_anomalies, Some(true));
    }

    #[test]
    fn test_bulk_generate_request_with_defaults() {
        let json = r#"{}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, None);
        assert_eq!(request.include_master_data, None);
        assert_eq!(request.inject_anomalies, None);
    }

    #[test]
    fn test_bulk_generate_response_serialization() {
        let response = BulkGenerateResponse {
            success: true,
            entries_generated: 1000,
            duration_ms: 250,
            anomaly_count: 20,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("entries_generated"));
        assert!(json.contains("1000"));
        assert!(json.contains("duration_ms"));
    }

    #[test]
    fn test_stream_response_serialization() {
        let response = StreamResponse {
            success: true,
            message: "Stream started successfully".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("Stream started"));
    }

    #[test]
    fn test_stream_response_failure() {
        let response = StreamResponse {
            success: false,
            message: "Stream failed to start".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("false"));
        assert!(json.contains("failed"));
    }

    // ==========================================================================
    // Timeout Configuration Tests
    // ==========================================================================

    #[test]
    fn test_timeout_config_default_and_new() {
        assert_eq!(TimeoutConfig::default().request_timeout_secs, 300);
        assert_eq!(TimeoutConfig::new(30).request_timeout_secs, 30);
    }

    // ==========================================================================
    // CORS Configuration Tests
    // ==========================================================================

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(!config.allow_any_origin);
        assert!(!config.allowed_origins.is_empty());
        assert!(config
            .allowed_origins
            .contains(&"http://localhost:5173".to_string()));
        assert!(config
            .allowed_origins
            .contains(&"tauri://localhost".to_string()));
    }

    #[test]
    fn test_cors_config_custom_origins() {
        let config = CorsConfig {
            allowed_origins: vec![
                "https://example.com".to_string(),
                "https://app.example.com".to_string(),
            ],
            allow_any_origin: false,
        };
        assert_eq!(config.allowed_origins.len(), 2);
        assert!(config
            .allowed_origins
            .contains(&"https://example.com".to_string()));
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig {
            allowed_origins: vec![],
            allow_any_origin: true,
        };
        assert!(config.allow_any_origin);
    }

    // ==========================================================================
    // Request Validation Tests (edge cases)
    // ==========================================================================

    #[test]
    fn test_bulk_generate_request_partial() {
        let json = r#"{"entry_count": 100}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(100));
        assert!(request.include_master_data.is_none());
    }

    #[test]
    fn test_generation_config_no_seed() {
        let config = GenerationConfigDto {
            industry: "technology".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 3,
            seed: None,
            coa_complexity: "small".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        };
        let json = serde_json::to_string(&config).unwrap();
        // A missing seed must still be emitted, explicitly as null.
        assert!(json.contains("\"seed\":null"));
    }

    #[test]
    fn test_generation_config_multiple_companies() {
        let config = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "large".to_string(),
            companies: vec![
                CompanyConfigDto {
                    code: "1000".to_string(),
                    name: "Headquarters".to_string(),
                    currency: "USD".to_string(),
                    country: "US".to_string(),
                    annual_transaction_volume: 100000,
                    volume_weight: 1.0,
                },
                CompanyConfigDto {
                    code: "2000".to_string(),
                    name: "European Sub".to_string(),
                    currency: "EUR".to_string(),
                    country: "DE".to_string(),
                    annual_transaction_volume: 50000,
                    volume_weight: 0.5,
                },
                CompanyConfigDto {
                    code: "3000".to_string(),
                    name: "APAC Sub".to_string(),
                    currency: "JPY".to_string(),
                    country: "JP".to_string(),
                    annual_transaction_volume: 30000,
                    volume_weight: 0.3,
                },
            ],
            fraud_enabled: true,
            fraud_rate: 0.02,
        };
        assert_eq!(config.companies.len(), 3);
    }

    // ==========================================================================
    // Metrics Calculation Tests
    //
    // These only exercise local arithmetic that mirrors an entries-per-second
    // formula. They do not call any handler.
    // ==========================================================================

    #[test]
    fn test_metrics_entries_per_second_calculation() {
        // Test that we can represent the expected calculation
        let total_entries: u64 = 1000;
        let uptime: u64 = 60;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert!((eps - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_metrics_entries_per_second_zero_uptime() {
        let total_entries: u64 = 1000;
        let uptime: u64 = 0;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert_eq!(eps, 0.0);
    }
}