Skip to main content

datasynth_server/rest/
routes.rs

1//! REST API routes.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7    extract::{State, WebSocketUpgrade},
8    http::{header, Method, StatusCode},
9    response::IntoResponse,
10    routing::{get, post},
11    Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Application state shared across handlers.
///
/// Handed to every route via axum's `State` extractor (see `.with_state(state)`
/// in the router constructors), which requires it to be `Clone`; both fields are
/// `Arc` handles, so cloning is cheap.
#[derive(Clone)]
pub struct AppState {
    /// Shared server state (configuration, counters, stream controls), taken
    /// from the gRPC `SynthService` so both APIs observe the same data.
    pub server_state: Arc<ServerState>,
    /// Optional background job queue; the router constructors in this file
    /// initialise it to `None`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Timeout settings for the REST API.
///
/// `request_timeout_secs` bounds how long a single request may run before the
/// timeout middleware answers on its behalf.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Request timeout in seconds.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Fallback timeout of five minutes, because bulk generation can take a while.
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Build a config whose request timeout is `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    /// Five-minute request timeout.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
/// CORS configuration for the REST API.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests (e.g. `http://localhost:5173`).
    ///
    /// Only consulted when `allow_any_origin` is `false`. Entries that do not
    /// parse as HTTP header values are ignored. An empty list allows NO
    /// cross-origin requests at all; [`CorsConfig::default`] pre-populates
    /// common localhost development origins and the Tauri app origin.
    pub allowed_origins: Vec<String>,
    /// Allow any origin (development mode only - NOT recommended for production).
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Local development origins only; `allow_any_origin` is off.
    fn default() -> Self {
        Self {
            allowed_origins: vec![
                "http://localhost:5173".to_string(), // Vite dev server
                "http://localhost:3000".to_string(), // Common dev server
                "http://127.0.0.1:5173".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "tauri://localhost".to_string(), // Tauri app
            ],
            allow_any_origin: false,
        }
    }
}
79
80/// Add API version header to responses.
81async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82    let (mut parts, body) = response.into_parts();
83    parts.headers.insert(
84        axum::http::HeaderName::from_static("x-api-version"),
85        axum::http::HeaderValue::from_static("v1"),
86    );
87    axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97/// Create the REST API router with default CORS settings.
98pub fn create_router(service: SynthService) -> Router {
99    create_router_with_cors(service, CorsConfig::default())
100}
101
102/// Create the REST API router with full configuration (CORS, auth, rate limiting, and timeout).
103///
104/// Uses in-memory rate limiting by default. For distributed rate limiting
105/// with Redis, use [`create_router_full_with_backend`] instead.
106pub fn create_router_full(
107    service: SynthService,
108    cors_config: CorsConfig,
109    auth_config: AuthConfig,
110    rate_limit_config: RateLimitConfig,
111    timeout_config: TimeoutConfig,
112) -> Router {
113    let backend = RateLimitBackend::in_memory(rate_limit_config);
114    create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117/// Create the REST API router with full configuration and a specific rate limiting backend.
118///
119/// This allows using either in-memory or Redis-backed rate limiting.
120///
121/// # Example (in-memory)
122/// ```rust,ignore
123/// let backend = RateLimitBackend::in_memory(rate_limit_config);
124/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
125/// ```
126///
127/// # Example (Redis)
128/// ```rust,ignore
129/// let backend = RateLimitBackend::redis("redis://127.0.0.1:6379", rate_limit_config).await?;
130/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
131/// ```
132pub fn create_router_full_with_backend(
133    service: SynthService,
134    cors_config: CorsConfig,
135    auth_config: AuthConfig,
136    rate_limit_backend: RateLimitBackend,
137    timeout_config: TimeoutConfig,
138) -> Router {
139    let server_state = service.state.clone();
140    let state = AppState {
141        server_state,
142        job_queue: None,
143    };
144
145    let cors = if cors_config.allow_any_origin {
146        CorsLayer::permissive()
147    } else {
148        let origins: Vec<_> = cors_config
149            .allowed_origins
150            .iter()
151            .filter_map(|o| o.parse().ok())
152            .collect();
153
154        CorsLayer::new()
155            .allow_origin(AllowOrigin::list(origins))
156            .allow_methods([
157                Method::GET,
158                Method::POST,
159                Method::PUT,
160                Method::DELETE,
161                Method::OPTIONS,
162            ])
163            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164    };
165
166    Router::new()
167        // Health and metrics (exempt from auth and rate limiting by default)
168        .route("/health", get(health_check))
169        .route("/ready", get(readiness_check))
170        .route("/live", get(liveness_check))
171        .route("/api/metrics", get(get_metrics))
172        .route("/metrics", get(prometheus_metrics))
173        // Configuration
174        .route("/api/config", get(get_config))
175        .route("/api/config", post(set_config))
176        .route("/api/config/reload", post(reload_config))
177        // Generation
178        .route("/api/generate/bulk", post(bulk_generate))
179        .route("/api/stream/start", post(start_stream))
180        .route("/api/stream/stop", post(stop_stream))
181        .route("/api/stream/pause", post(pause_stream))
182        .route("/api/stream/resume", post(resume_stream))
183        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184        // Jobs
185        .route("/api/jobs/submit", post(submit_job))
186        .route("/api/jobs", get(list_jobs))
187        .route("/api/jobs/{id}", get(get_job))
188        .route("/api/jobs/{id}/cancel", post(cancel_job))
189        // WebSocket
190        .route("/ws/metrics", get(websocket_metrics))
191        .route("/ws/events", get(websocket_events))
192        // Middleware stack (outermost applied first, innermost last)
193        // Order: Timeout -> RateLimit -> RequestValidation -> Auth -> RequestId -> CORS -> SecurityHeaders -> APIVersion -> Router
194        .layer(axum::middleware::from_fn(security_headers_middleware))
195        .layer(axum::middleware::map_response(api_version_header))
196        .layer(cors)
197        .layer(axum::middleware::from_fn(request_id_middleware))
198        .layer(axum::middleware::from_fn(auth_middleware))
199        .layer(axum::Extension(auth_config))
200        .layer(axum::middleware::from_fn(request_validation_middleware))
201        .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202        .layer(axum::Extension(rate_limit_backend))
203        .layer(TimeoutLayer::with_status_code(
204            StatusCode::REQUEST_TIMEOUT,
205            Duration::from_secs(timeout_config.request_timeout_secs),
206        ))
207        .with_state(state)
208}
209
210/// Create the REST API router with custom CORS and authentication settings.
211pub fn create_router_with_auth(
212    service: SynthService,
213    cors_config: CorsConfig,
214    auth_config: AuthConfig,
215) -> Router {
216    let server_state = service.state.clone();
217    let state = AppState {
218        server_state,
219        job_queue: None,
220    };
221
222    let cors = if cors_config.allow_any_origin {
223        CorsLayer::permissive()
224    } else {
225        let origins: Vec<_> = cors_config
226            .allowed_origins
227            .iter()
228            .filter_map(|o| o.parse().ok())
229            .collect();
230
231        CorsLayer::new()
232            .allow_origin(AllowOrigin::list(origins))
233            .allow_methods([
234                Method::GET,
235                Method::POST,
236                Method::PUT,
237                Method::DELETE,
238                Method::OPTIONS,
239            ])
240            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
241    };
242
243    Router::new()
244        // Health and metrics (exempt from auth by default)
245        .route("/health", get(health_check))
246        .route("/ready", get(readiness_check))
247        .route("/live", get(liveness_check))
248        .route("/api/metrics", get(get_metrics))
249        .route("/metrics", get(prometheus_metrics))
250        // Configuration
251        .route("/api/config", get(get_config))
252        .route("/api/config", post(set_config))
253        .route("/api/config/reload", post(reload_config))
254        // Generation
255        .route("/api/generate/bulk", post(bulk_generate))
256        .route("/api/stream/start", post(start_stream))
257        .route("/api/stream/stop", post(stop_stream))
258        .route("/api/stream/pause", post(pause_stream))
259        .route("/api/stream/resume", post(resume_stream))
260        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
261        // Jobs
262        .route("/api/jobs/submit", post(submit_job))
263        .route("/api/jobs", get(list_jobs))
264        .route("/api/jobs/{id}", get(get_job))
265        .route("/api/jobs/{id}/cancel", post(cancel_job))
266        // WebSocket
267        .route("/ws/metrics", get(websocket_metrics))
268        .route("/ws/events", get(websocket_events))
269        .layer(axum::middleware::from_fn(auth_middleware))
270        .layer(axum::Extension(auth_config))
271        .layer(cors)
272        .with_state(state)
273}
274
275/// Create the REST API router with custom CORS settings.
276pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
277    let server_state = service.state.clone();
278    let state = AppState {
279        server_state,
280        job_queue: None,
281    };
282
283    let cors = if cors_config.allow_any_origin {
284        // Development mode - allow any origin (use with caution)
285        CorsLayer::permissive()
286    } else {
287        // Production mode - restricted origins
288        let origins: Vec<_> = cors_config
289            .allowed_origins
290            .iter()
291            .filter_map(|o| o.parse().ok())
292            .collect();
293
294        CorsLayer::new()
295            .allow_origin(AllowOrigin::list(origins))
296            .allow_methods([
297                Method::GET,
298                Method::POST,
299                Method::PUT,
300                Method::DELETE,
301                Method::OPTIONS,
302            ])
303            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
304    };
305
306    Router::new()
307        // Health and metrics
308        .route("/health", get(health_check))
309        .route("/ready", get(readiness_check))
310        .route("/live", get(liveness_check))
311        .route("/api/metrics", get(get_metrics))
312        .route("/metrics", get(prometheus_metrics))
313        // Configuration
314        .route("/api/config", get(get_config))
315        .route("/api/config", post(set_config))
316        .route("/api/config/reload", post(reload_config))
317        // Generation
318        .route("/api/generate/bulk", post(bulk_generate))
319        .route("/api/stream/start", post(start_stream))
320        .route("/api/stream/stop", post(stop_stream))
321        .route("/api/stream/pause", post(pause_stream))
322        .route("/api/stream/resume", post(resume_stream))
323        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
324        // Jobs
325        .route("/api/jobs/submit", post(submit_job))
326        .route("/api/jobs", get(list_jobs))
327        .route("/api/jobs/{id}", get(get_job))
328        .route("/api/jobs/{id}/cancel", post(cancel_job))
329        // WebSocket
330        .route("/ws/metrics", get(websocket_metrics))
331        .route("/ws/events", get(websocket_events))
332        .layer(cors)
333        .with_state(state)
334}
335
336// ===========================================================================
337// Request/Response types
338// ===========================================================================
339
340#[derive(Debug, Serialize, Deserialize)]
341pub struct HealthResponse {
342    pub healthy: bool,
343    pub version: String,
344    pub uptime_seconds: u64,
345}
346
347/// Readiness check response for Kubernetes.
348#[derive(Debug, Serialize, Deserialize)]
349pub struct ReadinessResponse {
350    pub ready: bool,
351    pub message: String,
352    pub checks: Vec<HealthCheck>,
353}
354
355/// Individual health check result.
356#[derive(Debug, Serialize, Deserialize)]
357pub struct HealthCheck {
358    pub name: String,
359    pub status: String,
360}
361
362/// Liveness check response for Kubernetes.
363#[derive(Debug, Serialize, Deserialize)]
364pub struct LivenessResponse {
365    pub alive: bool,
366    pub timestamp: String,
367}
368
369#[derive(Debug, Serialize, Deserialize)]
370pub struct MetricsResponse {
371    pub total_entries_generated: u64,
372    pub total_anomalies_injected: u64,
373    pub uptime_seconds: u64,
374    pub session_entries: u64,
375    pub session_entries_per_second: f64,
376    pub active_streams: u32,
377    pub total_stream_events: u64,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct ConfigResponse {
382    pub success: bool,
383    pub message: String,
384    pub config: Option<GenerationConfigDto>,
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct GenerationConfigDto {
389    pub industry: String,
390    pub start_date: String,
391    pub period_months: u32,
392    pub seed: Option<u64>,
393    pub coa_complexity: String,
394    pub companies: Vec<CompanyConfigDto>,
395    pub fraud_enabled: bool,
396    pub fraud_rate: f32,
397}
398
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct CompanyConfigDto {
401    pub code: String,
402    pub name: String,
403    pub currency: String,
404    pub country: String,
405    pub annual_transaction_volume: u64,
406    pub volume_weight: f32,
407}
408
409#[derive(Debug, Deserialize)]
410pub struct BulkGenerateRequest {
411    pub entry_count: Option<u64>,
412    pub include_master_data: Option<bool>,
413    pub inject_anomalies: Option<bool>,
414}
415
416#[derive(Debug, Serialize)]
417pub struct BulkGenerateResponse {
418    pub success: bool,
419    pub entries_generated: u64,
420    pub duration_ms: u64,
421    pub anomaly_count: u64,
422}
423
424#[derive(Debug, Deserialize)]
425#[allow(dead_code)] // Fields deserialized from request, reserved for future use
426pub struct StreamRequest {
427    pub events_per_second: Option<u32>,
428    pub max_events: Option<u64>,
429    pub inject_anomalies: Option<bool>,
430}
431
432#[derive(Debug, Serialize)]
433pub struct StreamResponse {
434    pub success: bool,
435    pub message: String,
436}
437
438// ===========================================================================
439// Handlers
440// ===========================================================================
441
442/// Health check endpoint - returns overall health status.
443async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
444    Json(HealthResponse {
445        healthy: true,
446        version: env!("CARGO_PKG_VERSION").to_string(),
447        uptime_seconds: state.server_state.uptime_seconds(),
448    })
449}
450
451/// Readiness probe - indicates the service is ready to accept traffic.
452/// Use for Kubernetes readiness probes.
453async fn readiness_check(
454    State(state): State<AppState>,
455) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
456    let mut checks = Vec::new();
457    let mut any_fail = false;
458
459    // Check if configuration is loaded and valid
460    let config = state.server_state.config.read().await;
461    let config_valid = !config.companies.is_empty();
462    checks.push(HealthCheck {
463        name: "config".to_string(),
464        status: if config_valid { "ok" } else { "fail" }.to_string(),
465    });
466    if !config_valid {
467        any_fail = true;
468    }
469    drop(config);
470
471    // Check resource guard (memory)
472    let resource_status = state.server_state.resource_status();
473    let memory_status = if resource_status.degradation_level == "Emergency" {
474        any_fail = true;
475        "fail"
476    } else if resource_status.degradation_level != "Normal" {
477        "degraded"
478    } else {
479        "ok"
480    };
481    checks.push(HealthCheck {
482        name: "memory".to_string(),
483        status: memory_status.to_string(),
484    });
485
486    // Check disk (>100MB free)
487    let disk_ok = resource_status.disk_available_mb > 100;
488    checks.push(HealthCheck {
489        name: "disk".to_string(),
490        status: if disk_ok { "ok" } else { "fail" }.to_string(),
491    });
492    if !disk_ok {
493        any_fail = true;
494    }
495
496    let response = ReadinessResponse {
497        ready: !any_fail,
498        message: if any_fail {
499            "Service is not ready".to_string()
500        } else {
501            "Service is ready".to_string()
502        },
503        checks,
504    };
505
506    if any_fail {
507        Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
508    } else {
509        Ok(Json(response))
510    }
511}
512
513/// Liveness probe - indicates the service is alive.
514/// Use for Kubernetes liveness probes.
515async fn liveness_check() -> Json<LivenessResponse> {
516    Json(LivenessResponse {
517        alive: true,
518        timestamp: chrono::Utc::now().to_rfc3339(),
519    })
520}
521
522/// Prometheus-compatible metrics endpoint.
523/// Returns metrics in Prometheus text exposition format.
524async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
525    use std::sync::atomic::Ordering;
526
527    let uptime = state.server_state.uptime_seconds();
528    let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
529    let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
530    let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
531    let total_stream_events = state
532        .server_state
533        .total_stream_events
534        .load(Ordering::Relaxed);
535
536    let entries_per_second = if uptime > 0 {
537        total_entries as f64 / uptime as f64
538    } else {
539        0.0
540    };
541
542    let metrics = format!(
543        r#"# HELP synth_entries_generated_total Total number of journal entries generated
544# TYPE synth_entries_generated_total counter
545synth_entries_generated_total {}
546
547# HELP synth_anomalies_injected_total Total number of anomalies injected
548# TYPE synth_anomalies_injected_total counter
549synth_anomalies_injected_total {}
550
551# HELP synth_uptime_seconds Server uptime in seconds
552# TYPE synth_uptime_seconds gauge
553synth_uptime_seconds {}
554
555# HELP synth_entries_per_second Rate of entry generation
556# TYPE synth_entries_per_second gauge
557synth_entries_per_second {:.2}
558
559# HELP synth_active_streams Number of active streaming connections
560# TYPE synth_active_streams gauge
561synth_active_streams {}
562
563# HELP synth_stream_events_total Total events sent through streams
564# TYPE synth_stream_events_total counter
565synth_stream_events_total {}
566
567# HELP synth_info Server version information
568# TYPE synth_info gauge
569synth_info{{version="{}"}} 1
570"#,
571        total_entries,
572        total_anomalies,
573        uptime,
574        entries_per_second,
575        active_streams,
576        total_stream_events,
577        env!("CARGO_PKG_VERSION")
578    );
579
580    (
581        StatusCode::OK,
582        [(
583            header::CONTENT_TYPE,
584            "text/plain; version=0.0.4; charset=utf-8",
585        )],
586        metrics,
587    )
588}
589
590/// Get server metrics.
591async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
592    let uptime = state.server_state.uptime_seconds();
593    let total_entries = state
594        .server_state
595        .total_entries
596        .load(std::sync::atomic::Ordering::Relaxed);
597
598    let entries_per_second = if uptime > 0 {
599        total_entries as f64 / uptime as f64
600    } else {
601        0.0
602    };
603
604    Json(MetricsResponse {
605        total_entries_generated: total_entries,
606        total_anomalies_injected: state
607            .server_state
608            .total_anomalies
609            .load(std::sync::atomic::Ordering::Relaxed),
610        uptime_seconds: uptime,
611        session_entries: total_entries,
612        session_entries_per_second: entries_per_second,
613        active_streams: state
614            .server_state
615            .active_streams
616            .load(std::sync::atomic::Ordering::Relaxed) as u32,
617        total_stream_events: state
618            .server_state
619            .total_stream_events
620            .load(std::sync::atomic::Ordering::Relaxed),
621    })
622}
623
624/// Get current configuration.
625async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
626    let config = state.server_state.config.read().await;
627
628    Json(ConfigResponse {
629        success: true,
630        message: "Current configuration".to_string(),
631        config: Some(GenerationConfigDto {
632            industry: format!("{:?}", config.global.industry),
633            start_date: config.global.start_date.clone(),
634            period_months: config.global.period_months,
635            seed: config.global.seed,
636            coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
637            companies: config
638                .companies
639                .iter()
640                .map(|c| CompanyConfigDto {
641                    code: c.code.clone(),
642                    name: c.name.clone(),
643                    currency: c.currency.clone(),
644                    country: c.country.clone(),
645                    annual_transaction_volume: c.annual_transaction_volume.count(),
646                    volume_weight: c.volume_weight as f32,
647                })
648                .collect(),
649            fraud_enabled: config.fraud.enabled,
650            fraud_rate: config.fraud.fraud_rate as f32,
651        }),
652    })
653}
654
655/// Set configuration.
656async fn set_config(
657    State(state): State<AppState>,
658    Json(new_config): Json<GenerationConfigDto>,
659) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
660    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
661    use datasynth_core::models::{CoAComplexity, IndustrySector};
662
663    info!(
664        "Configuration update requested: industry={}, period_months={}",
665        new_config.industry, new_config.period_months
666    );
667
668    // Parse industry from string
669    let industry = match new_config.industry.to_lowercase().as_str() {
670        "manufacturing" => IndustrySector::Manufacturing,
671        "retail" => IndustrySector::Retail,
672        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
673        "healthcare" => IndustrySector::Healthcare,
674        "technology" => IndustrySector::Technology,
675        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
676        "energy" => IndustrySector::Energy,
677        "transportation" => IndustrySector::Transportation,
678        "real_estate" | "realestate" => IndustrySector::RealEstate,
679        "telecommunications" => IndustrySector::Telecommunications,
680        _ => {
681            return Err((
682                StatusCode::BAD_REQUEST,
683                Json(ConfigResponse {
684                    success: false,
685                    message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
686                    config: None,
687                }),
688            ));
689        }
690    };
691
692    // Parse CoA complexity from string
693    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
694        "small" => CoAComplexity::Small,
695        "medium" => CoAComplexity::Medium,
696        "large" => CoAComplexity::Large,
697        _ => {
698            return Err((
699                StatusCode::BAD_REQUEST,
700                Json(ConfigResponse {
701                    success: false,
702                    message: format!(
703                        "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
704                        new_config.coa_complexity
705                    ),
706                    config: None,
707                }),
708            ));
709        }
710    };
711
712    // Convert CompanyConfigDto to CompanyConfig
713    let companies: Vec<CompanyConfig> = new_config
714        .companies
715        .iter()
716        .map(|c| CompanyConfig {
717            code: c.code.clone(),
718            name: c.name.clone(),
719            currency: c.currency.clone(),
720            functional_currency: None,
721            country: c.country.clone(),
722            fiscal_year_variant: "K4".to_string(),
723            annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
724            volume_weight: c.volume_weight as f64,
725        })
726        .collect();
727
728    // Update the configuration
729    let mut config = state.server_state.config.write().await;
730    config.global.industry = industry;
731    config.global.start_date = new_config.start_date.clone();
732    config.global.period_months = new_config.period_months;
733    config.global.seed = new_config.seed;
734    config.chart_of_accounts.complexity = complexity;
735    config.fraud.enabled = new_config.fraud_enabled;
736    config.fraud.fraud_rate = new_config.fraud_rate as f64;
737
738    // Only update companies if provided
739    if !companies.is_empty() {
740        config.companies = companies;
741    }
742
743    info!("Configuration updated successfully");
744
745    Ok(Json(ConfigResponse {
746        success: true,
747        message: "Configuration updated and applied".to_string(),
748        config: Some(new_config),
749    }))
750}
751
752/// Bulk generation endpoint.
753async fn bulk_generate(
754    State(state): State<AppState>,
755    Json(req): Json<BulkGenerateRequest>,
756) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
757    // Validate entry_count bounds
758    const MAX_ENTRY_COUNT: u64 = 1_000_000;
759    if let Some(count) = req.entry_count {
760        if count > MAX_ENTRY_COUNT {
761            return Err((
762                StatusCode::BAD_REQUEST,
763                format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
764            ));
765        }
766    }
767
768    let config = state.server_state.config.read().await.clone();
769    let start_time = std::time::Instant::now();
770
771    let phase_config = {
772        let mut pc = PhaseConfig::from_config(&config);
773        pc.generate_master_data = req.include_master_data.unwrap_or(false);
774        pc.generate_document_flows = false;
775        pc.generate_journal_entries = true;
776        pc.inject_anomalies = req.inject_anomalies.unwrap_or(false);
777        pc.show_progress = false;
778        pc
779    };
780
781    let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
782        (
783            StatusCode::INTERNAL_SERVER_ERROR,
784            format!("Failed to create orchestrator: {e}"),
785        )
786    })?;
787
788    let result = orchestrator.generate().map_err(|e| {
789        (
790            StatusCode::INTERNAL_SERVER_ERROR,
791            format!("Generation failed: {e}"),
792        )
793    })?;
794
795    let duration_ms = start_time.elapsed().as_millis() as u64;
796    let entries_count = result.journal_entries.len() as u64;
797    let anomaly_count = result.anomaly_labels.labels.len() as u64;
798
799    // Update metrics
800    state
801        .server_state
802        .total_entries
803        .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
804    state
805        .server_state
806        .total_anomalies
807        .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
808
809    Ok(Json(BulkGenerateResponse {
810        success: true,
811        entries_generated: entries_count,
812        duration_ms,
813        anomaly_count,
814    }))
815}
816
817/// Start streaming.
818async fn start_stream(
819    State(state): State<AppState>,
820    Json(req): Json<StreamRequest>,
821) -> Json<StreamResponse> {
822    // Apply stream request parameters to server state
823    if let Some(eps) = req.events_per_second {
824        info!("Stream configured: events_per_second={}", eps);
825        state
826            .server_state
827            .stream_events_per_second
828            .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
829    }
830    if let Some(max) = req.max_events {
831        info!("Stream configured: max_events={}", max);
832        state
833            .server_state
834            .stream_max_events
835            .store(max, std::sync::atomic::Ordering::Relaxed);
836    }
837    if let Some(inject) = req.inject_anomalies {
838        info!("Stream configured: inject_anomalies={}", inject);
839        state
840            .server_state
841            .stream_inject_anomalies
842            .store(inject, std::sync::atomic::Ordering::Relaxed);
843    }
844
845    state
846        .server_state
847        .stream_stopped
848        .store(false, std::sync::atomic::Ordering::Relaxed);
849    state
850        .server_state
851        .stream_paused
852        .store(false, std::sync::atomic::Ordering::Relaxed);
853
854    Json(StreamResponse {
855        success: true,
856        message: "Stream started".to_string(),
857    })
858}
859
860/// Stop streaming.
861async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
862    state
863        .server_state
864        .stream_stopped
865        .store(true, std::sync::atomic::Ordering::Relaxed);
866
867    Json(StreamResponse {
868        success: true,
869        message: "Stream stopped".to_string(),
870    })
871}
872
873/// Pause streaming.
874async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
875    state
876        .server_state
877        .stream_paused
878        .store(true, std::sync::atomic::Ordering::Relaxed);
879
880    Json(StreamResponse {
881        success: true,
882        message: "Stream paused".to_string(),
883    })
884}
885
886/// Resume streaming.
887async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
888    state
889        .server_state
890        .stream_paused
891        .store(false, std::sync::atomic::Ordering::Relaxed);
892
893    Json(StreamResponse {
894        success: true,
895        message: "Stream resumed".to_string(),
896    })
897}
898
899/// Trigger a specific pattern.
900///
901/// Valid patterns: year_end_spike, period_end_spike, holiday_cluster,
902/// fraud_cluster, error_cluster, uniform, or custom:* patterns.
903async fn trigger_pattern(
904    State(state): State<AppState>,
905    axum::extract::Path(pattern): axum::extract::Path<String>,
906) -> Json<StreamResponse> {
907    info!("Pattern trigger requested: {}", pattern);
908
909    // Validate pattern name
910    let valid_patterns = [
911        "year_end_spike",
912        "period_end_spike",
913        "holiday_cluster",
914        "fraud_cluster",
915        "error_cluster",
916        "uniform",
917    ];
918
919    let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
920
921    if !is_valid {
922        return Json(StreamResponse {
923            success: false,
924            message: format!(
925                "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
926            ),
927        });
928    }
929
930    // Store the pattern for the stream generator to pick up
931    match state.server_state.triggered_pattern.try_write() {
932        Ok(mut triggered) => {
933            *triggered = Some(pattern.clone());
934            Json(StreamResponse {
935                success: true,
936                message: format!("Pattern '{pattern}' will be applied to upcoming entries"),
937            })
938        }
939        Err(_) => Json(StreamResponse {
940            success: false,
941            message: "Failed to acquire lock for pattern trigger".to_string(),
942        }),
943    }
944}
945
946/// WebSocket endpoint for metrics stream.
947async fn websocket_metrics(
948    ws: WebSocketUpgrade,
949    State(state): State<AppState>,
950) -> impl IntoResponse {
951    ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
952}
953
954/// WebSocket endpoint for event stream.
955async fn websocket_events(
956    ws: WebSocketUpgrade,
957    State(state): State<AppState>,
958) -> impl IntoResponse {
959    ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
960}
961
962// ===========================================================================
963// Job Queue Handlers
964// ===========================================================================
965
966/// Submit a new async generation job.
967async fn submit_job(
968    State(state): State<AppState>,
969    Json(request): Json<JobRequest>,
970) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
971    let queue = state.job_queue.as_ref().ok_or_else(|| {
972        (
973            StatusCode::SERVICE_UNAVAILABLE,
974            Json(serde_json::json!({"error": "Job queue not enabled"})),
975        )
976    })?;
977
978    let job_id = queue.submit(request).await;
979    info!("Job submitted: {}", job_id);
980
981    Ok((
982        StatusCode::CREATED,
983        Json(serde_json::json!({
984            "id": job_id.to_string(),
985            "status": "queued"
986        })),
987    ))
988}
989
990/// Get status of a specific job.
991async fn get_job(
992    State(state): State<AppState>,
993    axum::extract::Path(id): axum::extract::Path<String>,
994) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
995    let queue = state.job_queue.as_ref().ok_or_else(|| {
996        (
997            StatusCode::SERVICE_UNAVAILABLE,
998            Json(serde_json::json!({"error": "Job queue not enabled"})),
999        )
1000    })?;
1001
1002    match queue.get(&id).await {
1003        Some(entry) => Ok(Json(serde_json::json!({
1004            "id": entry.id,
1005            "status": format!("{:?}", entry.status).to_lowercase(),
1006            "submitted_at": entry.submitted_at.to_rfc3339(),
1007            "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1008            "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1009            "result": entry.result,
1010        }))),
1011        None => Err((
1012            StatusCode::NOT_FOUND,
1013            Json(serde_json::json!({"error": "Job not found"})),
1014        )),
1015    }
1016}
1017
1018/// List all jobs.
1019async fn list_jobs(
1020    State(state): State<AppState>,
1021) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1022    let queue = state.job_queue.as_ref().ok_or_else(|| {
1023        (
1024            StatusCode::SERVICE_UNAVAILABLE,
1025            Json(serde_json::json!({"error": "Job queue not enabled"})),
1026        )
1027    })?;
1028
1029    let summaries: Vec<_> = queue
1030        .list()
1031        .await
1032        .into_iter()
1033        .map(|s| {
1034            serde_json::json!({
1035                "id": s.id,
1036                "status": format!("{:?}", s.status).to_lowercase(),
1037                "submitted_at": s.submitted_at.to_rfc3339(),
1038            })
1039        })
1040        .collect();
1041
1042    Ok(Json(serde_json::json!({ "jobs": summaries })))
1043}
1044
1045/// Cancel a queued job.
1046async fn cancel_job(
1047    State(state): State<AppState>,
1048    axum::extract::Path(id): axum::extract::Path<String>,
1049) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1050    let queue = state.job_queue.as_ref().ok_or_else(|| {
1051        (
1052            StatusCode::SERVICE_UNAVAILABLE,
1053            Json(serde_json::json!({"error": "Job queue not enabled"})),
1054        )
1055    })?;
1056
1057    if queue.cancel(&id).await {
1058        Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1059    } else {
1060        Err((
1061            StatusCode::CONFLICT,
1062            Json(
1063                serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1064            ),
1065        ))
1066    }
1067}
1068
1069// ===========================================================================
1070// Config Reload Handler
1071// ===========================================================================
1072
1073/// Reload configuration from the configured source.
1074async fn reload_config(
1075    State(state): State<AppState>,
1076) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1077    let source = state.server_state.config_source.read().await.clone();
1078    match crate::config_loader::load_config(&source).await {
1079        Ok(new_config) => {
1080            let mut config = state.server_state.config.write().await;
1081            *config = new_config;
1082            info!("Configuration reloaded via REST API from {:?}", source);
1083            Ok(Json(serde_json::json!({
1084                "success": true,
1085                "message": "Configuration reloaded"
1086            })))
1087        }
1088        Err(e) => {
1089            error!("Failed to reload configuration: {}", e);
1090            Err((
1091                StatusCode::INTERNAL_SERVER_ERROR,
1092                Json(serde_json::json!({
1093                    "success": false,
1094                    "message": format!("Failed to reload configuration: {}", e)
1095                })),
1096            ))
1097        }
1098    }
1099}
1100
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ==========================================================================
    // Response Serialization Tests
    // ==========================================================================
    //
    // Serialization tests convert to `serde_json::Value` and assert on individual
    // fields. This pins both key names and values; a substring check on the raw
    // JSON string would also pass if the value appeared under a different key.

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            healthy: true,
            version: "0.1.0".to_string(),
            uptime_seconds: 100,
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["healthy"], true);
        assert_eq!(value["version"], "0.1.0");
        assert_eq!(value["uptime_seconds"], 100);
    }

    #[test]
    fn test_health_response_deserialization() {
        let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
        let response: HealthResponse = serde_json::from_str(json).unwrap();
        assert!(response.healthy);
        assert_eq!(response.version, "0.1.0");
        assert_eq!(response.uptime_seconds, 100);
    }

    #[test]
    fn test_metrics_response_serialization() {
        let response = MetricsResponse {
            total_entries_generated: 1000,
            total_anomalies_injected: 10,
            uptime_seconds: 60,
            session_entries: 1000,
            session_entries_per_second: 16.67,
            active_streams: 1,
            total_stream_events: 500,
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["total_entries_generated"], 1000);
        assert_eq!(value["total_anomalies_injected"], 10);
        assert_eq!(value["active_streams"], 1);
        assert_eq!(value["total_stream_events"], 500);
        // Compare floats with a tolerance rather than exact equality.
        let eps = value["session_entries_per_second"].as_f64().unwrap();
        assert!((eps - 16.67).abs() < 1e-4);
    }

    #[test]
    fn test_metrics_response_deserialization() {
        let json = r#"{
            "total_entries_generated": 5000,
            "total_anomalies_injected": 50,
            "uptime_seconds": 300,
            "session_entries": 5000,
            "session_entries_per_second": 16.67,
            "active_streams": 2,
            "total_stream_events": 10000
        }"#;
        let response: MetricsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.total_entries_generated, 5000);
        assert_eq!(response.active_streams, 2);
        assert_eq!(response.total_stream_events, 10000);
        assert!((response.session_entries_per_second - 16.67).abs() < 1e-4);
    }

    #[test]
    fn test_config_response_serialization() {
        let response = ConfigResponse {
            success: true,
            message: "Configuration loaded".to_string(),
            config: Some(GenerationConfigDto {
                industry: "manufacturing".to_string(),
                start_date: "2024-01-01".to_string(),
                period_months: 12,
                seed: Some(42),
                coa_complexity: "medium".to_string(),
                companies: vec![],
                fraud_enabled: false,
                fraud_rate: 0.0,
            }),
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["success"], true);
        assert_eq!(value["message"], "Configuration loaded");
        assert_eq!(value["config"]["industry"], "manufacturing");
        assert_eq!(value["config"]["seed"], 42);
    }

    #[test]
    fn test_config_response_without_config() {
        let response = ConfigResponse {
            success: false,
            message: "No configuration available".to_string(),
            config: None,
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["success"], false);
        // `None` is emitted as an explicit `"config": null`, not omitted.
        assert_eq!(value.get("config"), Some(&serde_json::Value::Null));
    }

    #[test]
    fn test_generation_config_dto_roundtrip() {
        let original = GenerationConfigDto {
            industry: "retail".to_string(),
            start_date: "2024-06-01".to_string(),
            period_months: 6,
            seed: Some(12345),
            coa_complexity: "large".to_string(),
            companies: vec![CompanyConfigDto {
                code: "1000".to_string(),
                name: "Test Corp".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            }],
            fraud_enabled: true,
            fraud_rate: 0.05,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

        assert_eq!(original.industry, deserialized.industry);
        assert_eq!(original.start_date, deserialized.start_date);
        assert_eq!(original.period_months, deserialized.period_months);
        assert_eq!(original.seed, deserialized.seed);
        assert_eq!(original.fraud_enabled, deserialized.fraud_enabled);
        assert!((original.fraud_rate - deserialized.fraud_rate).abs() < 1e-9);
        assert_eq!(original.companies.len(), deserialized.companies.len());
        assert_eq!(original.companies[0].code, deserialized.companies[0].code);
    }

    #[test]
    fn test_company_config_dto_serialization() {
        let company = CompanyConfigDto {
            code: "2000".to_string(),
            name: "European Subsidiary".to_string(),
            currency: "EUR".to_string(),
            country: "DE".to_string(),
            annual_transaction_volume: 50000,
            volume_weight: 0.5,
        };
        let value = serde_json::to_value(&company).unwrap();
        assert_eq!(value["code"], "2000");
        assert_eq!(value["name"], "European Subsidiary");
        assert_eq!(value["currency"], "EUR");
        assert_eq!(value["country"], "DE");
    }

    #[test]
    fn test_bulk_generate_request_deserialization() {
        let json = r#"{
            "entry_count": 5000,
            "include_master_data": true,
            "inject_anomalies": true
        }"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(5000));
        assert_eq!(request.include_master_data, Some(true));
        assert_eq!(request.inject_anomalies, Some(true));
    }

    #[test]
    fn test_bulk_generate_request_with_defaults() {
        let json = r#"{}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, None);
        assert_eq!(request.include_master_data, None);
        assert_eq!(request.inject_anomalies, None);
    }

    #[test]
    fn test_bulk_generate_response_serialization() {
        let response = BulkGenerateResponse {
            success: true,
            entries_generated: 1000,
            duration_ms: 250,
            anomaly_count: 20,
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["success"], true);
        assert_eq!(value["entries_generated"], 1000);
        assert_eq!(value["duration_ms"], 250);
    }

    #[test]
    fn test_stream_response_serialization() {
        let response = StreamResponse {
            success: true,
            message: "Stream started successfully".to_string(),
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["success"], true);
        assert_eq!(value["message"], "Stream started successfully");
    }

    #[test]
    fn test_stream_response_failure() {
        let response = StreamResponse {
            success: false,
            message: "Stream failed to start".to_string(),
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["success"], false);
        assert_eq!(value["message"], "Stream failed to start");
    }

    // ==========================================================================
    // CORS Configuration Tests
    // ==========================================================================

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(!config.allow_any_origin);
        assert!(!config.allowed_origins.is_empty());
        assert!(config
            .allowed_origins
            .contains(&"http://localhost:5173".to_string()));
        assert!(config
            .allowed_origins
            .contains(&"tauri://localhost".to_string()));
    }

    #[test]
    fn test_cors_config_custom_origins() {
        let config = CorsConfig {
            allowed_origins: vec![
                "https://example.com".to_string(),
                "https://app.example.com".to_string(),
            ],
            allow_any_origin: false,
        };
        assert_eq!(config.allowed_origins.len(), 2);
        assert!(config
            .allowed_origins
            .contains(&"https://example.com".to_string()));
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig {
            allowed_origins: vec![],
            allow_any_origin: true,
        };
        assert!(config.allow_any_origin);
    }

    // ==========================================================================
    // Request Validation Tests (edge cases)
    // ==========================================================================

    #[test]
    fn test_bulk_generate_request_partial() {
        let json = r#"{"entry_count": 100}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(100));
        assert!(request.include_master_data.is_none());
    }

    #[test]
    fn test_generation_config_no_seed() {
        let config = GenerationConfigDto {
            industry: "technology".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 3,
            seed: None,
            coa_complexity: "small".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        };
        let value = serde_json::to_value(&config).unwrap();
        // A missing seed is serialized as an explicit `null`.
        assert_eq!(value.get("seed"), Some(&serde_json::Value::Null));
    }

    #[test]
    fn test_generation_config_multiple_companies() {
        let config = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "large".to_string(),
            companies: vec![
                CompanyConfigDto {
                    code: "1000".to_string(),
                    name: "Headquarters".to_string(),
                    currency: "USD".to_string(),
                    country: "US".to_string(),
                    annual_transaction_volume: 100000,
                    volume_weight: 1.0,
                },
                CompanyConfigDto {
                    code: "2000".to_string(),
                    name: "European Sub".to_string(),
                    currency: "EUR".to_string(),
                    country: "DE".to_string(),
                    annual_transaction_volume: 50000,
                    volume_weight: 0.5,
                },
                CompanyConfigDto {
                    code: "3000".to_string(),
                    name: "APAC Sub".to_string(),
                    currency: "JPY".to_string(),
                    country: "JP".to_string(),
                    annual_transaction_volume: 30000,
                    volume_weight: 0.3,
                },
            ],
            fraud_enabled: true,
            fraud_rate: 0.02,
        };
        assert_eq!(config.companies.len(), 3);

        // Round-trip to make sure every company survives serialization in order.
        let json = serde_json::to_string(&config).unwrap();
        let decoded: GenerationConfigDto = serde_json::from_str(&json).unwrap();
        let codes: Vec<&str> = decoded.companies.iter().map(|c| c.code.as_str()).collect();
        assert_eq!(codes, ["1000", "2000", "3000"]);
    }

    // ==========================================================================
    // Metrics Calculation Tests
    // ==========================================================================
    //
    // These tests pin a guarded entries-per-second division computed locally;
    // they do not call any handler in this module.

    #[test]
    fn test_metrics_entries_per_second_calculation() {
        // Test that we can represent the expected calculation
        let total_entries: u64 = 1000;
        let uptime: u64 = 60;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert!((eps - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_metrics_entries_per_second_zero_uptime() {
        let total_entries: u64 = 1000;
        let uptime: u64 = 0;
        // Zero uptime must yield 0.0 rather than dividing by zero (inf/NaN).
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert_eq!(eps, 0.0);
    }
}