Skip to main content

datasynth_server/rest/
routes.rs

1//! REST API routes.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7    extract::{State, WebSocketUpgrade},
8    http::{header, Method, StatusCode},
9    response::IntoResponse,
10    routing::{get, post},
11    Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::info;
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
24/// Application state shared across handlers.
25#[derive(Clone)]
26pub struct AppState {
27    pub server_state: Arc<ServerState>,
28    pub job_queue: Option<Arc<JobQueue>>,
29}
30
/// Timeout configuration for the REST API.
///
/// The router turns `request_timeout_secs` into the duration of its request
/// timeout middleware.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Request timeout in seconds.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Default request timeout: 5 minutes, because bulk generation can take a while.
    pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Create a new timeout config with the given request timeout in seconds.
    #[must_use]
    pub const fn new(timeout_secs: u64) -> Self {
        Self {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
/// CORS configuration for the REST API.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin browser requests, e.g.
    /// `"http://localhost:5173"`. Entries that do not parse as header values are
    /// ignored. An empty list allows no cross-origin requests at all; it does
    /// NOT fall back to localhost. Use [`CorsConfig::default`] to get the local
    /// development origins.
    pub allowed_origins: Vec<String>,
    /// Allow any origin (development mode only - NOT recommended for production).
    /// When `true`, `allowed_origins` is ignored.
    pub allow_any_origin: bool,
}

impl CorsConfig {
    /// Origins used by [`CorsConfig::default`].
    const DEFAULT_DEV_ORIGINS: [&'static str; 5] = [
        "http://localhost:5173", // Vite dev server
        "http://localhost:3000", // Common dev server
        "http://127.0.0.1:5173",
        "http://127.0.0.1:3000",
        "tauri://localhost", // Tauri app
    ];
}

impl Default for CorsConfig {
    /// Restricted mode allowing only the local development origins and the Tauri app.
    fn default() -> Self {
        Self {
            allowed_origins: Self::DEFAULT_DEV_ORIGINS
                .iter()
                .map(|origin| origin.to_string())
                .collect(),
            allow_any_origin: false,
        }
    }
}
79
80/// Add API version header to responses.
81async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82    let (mut parts, body) = response.into_parts();
83    parts.headers.insert(
84        axum::http::HeaderName::from_static("x-api-version"),
85        axum::http::HeaderValue::from_static("v1"),
86    );
87    axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97/// Create the REST API router with default CORS settings.
98pub fn create_router(service: SynthService) -> Router {
99    create_router_with_cors(service, CorsConfig::default())
100}
101
102/// Create the REST API router with full configuration (CORS, auth, rate limiting, and timeout).
103///
104/// Uses in-memory rate limiting by default. For distributed rate limiting
105/// with Redis, use [`create_router_full_with_backend`] instead.
106pub fn create_router_full(
107    service: SynthService,
108    cors_config: CorsConfig,
109    auth_config: AuthConfig,
110    rate_limit_config: RateLimitConfig,
111    timeout_config: TimeoutConfig,
112) -> Router {
113    let backend = RateLimitBackend::in_memory(rate_limit_config);
114    create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117/// Create the REST API router with full configuration and a specific rate limiting backend.
118///
119/// This allows using either in-memory or Redis-backed rate limiting.
120///
121/// # Example (in-memory)
122/// ```rust,ignore
123/// let backend = RateLimitBackend::in_memory(rate_limit_config);
124/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
125/// ```
126///
127/// # Example (Redis)
128/// ```rust,ignore
129/// let backend = RateLimitBackend::redis("redis://127.0.0.1:6379", rate_limit_config).await?;
130/// let router = create_router_full_with_backend(service, cors, auth, backend, timeout);
131/// ```
132pub fn create_router_full_with_backend(
133    service: SynthService,
134    cors_config: CorsConfig,
135    auth_config: AuthConfig,
136    rate_limit_backend: RateLimitBackend,
137    timeout_config: TimeoutConfig,
138) -> Router {
139    let server_state = service.state.clone();
140    let state = AppState {
141        server_state,
142        job_queue: None,
143    };
144
145    let cors = if cors_config.allow_any_origin {
146        CorsLayer::permissive()
147    } else {
148        let origins: Vec<_> = cors_config
149            .allowed_origins
150            .iter()
151            .filter_map(|o| o.parse().ok())
152            .collect();
153
154        CorsLayer::new()
155            .allow_origin(AllowOrigin::list(origins))
156            .allow_methods([
157                Method::GET,
158                Method::POST,
159                Method::PUT,
160                Method::DELETE,
161                Method::OPTIONS,
162            ])
163            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164    };
165
166    Router::new()
167        // Health and metrics (exempt from auth and rate limiting by default)
168        .route("/health", get(health_check))
169        .route("/ready", get(readiness_check))
170        .route("/live", get(liveness_check))
171        .route("/api/metrics", get(get_metrics))
172        .route("/metrics", get(prometheus_metrics))
173        // Configuration
174        .route("/api/config", get(get_config))
175        .route("/api/config", post(set_config))
176        .route("/api/config/reload", post(reload_config))
177        // Generation
178        .route("/api/generate/bulk", post(bulk_generate))
179        .route("/api/stream/start", post(start_stream))
180        .route("/api/stream/stop", post(stop_stream))
181        .route("/api/stream/pause", post(pause_stream))
182        .route("/api/stream/resume", post(resume_stream))
183        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184        // Jobs
185        .route("/api/jobs/submit", post(submit_job))
186        .route("/api/jobs", get(list_jobs))
187        .route("/api/jobs/{id}", get(get_job))
188        .route("/api/jobs/{id}/cancel", post(cancel_job))
189        // WebSocket
190        .route("/ws/metrics", get(websocket_metrics))
191        .route("/ws/events", get(websocket_events))
192        // Middleware stack (outermost applied first, innermost last)
193        // Order: Timeout -> RateLimit -> RequestValidation -> Auth -> RequestId -> CORS -> SecurityHeaders -> APIVersion -> Router
194        .layer(axum::middleware::from_fn(security_headers_middleware))
195        .layer(axum::middleware::map_response(api_version_header))
196        .layer(cors)
197        .layer(axum::middleware::from_fn(request_id_middleware))
198        .layer(axum::middleware::from_fn(auth_middleware))
199        .layer(axum::Extension(auth_config))
200        .layer(axum::middleware::from_fn(request_validation_middleware))
201        .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202        .layer(axum::Extension(rate_limit_backend))
203        .layer(TimeoutLayer::new(Duration::from_secs(
204            timeout_config.request_timeout_secs,
205        )))
206        .with_state(state)
207}
208
209/// Create the REST API router with custom CORS and authentication settings.
210pub fn create_router_with_auth(
211    service: SynthService,
212    cors_config: CorsConfig,
213    auth_config: AuthConfig,
214) -> Router {
215    let server_state = service.state.clone();
216    let state = AppState {
217        server_state,
218        job_queue: None,
219    };
220
221    let cors = if cors_config.allow_any_origin {
222        CorsLayer::permissive()
223    } else {
224        let origins: Vec<_> = cors_config
225            .allowed_origins
226            .iter()
227            .filter_map(|o| o.parse().ok())
228            .collect();
229
230        CorsLayer::new()
231            .allow_origin(AllowOrigin::list(origins))
232            .allow_methods([
233                Method::GET,
234                Method::POST,
235                Method::PUT,
236                Method::DELETE,
237                Method::OPTIONS,
238            ])
239            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
240    };
241
242    Router::new()
243        // Health and metrics (exempt from auth by default)
244        .route("/health", get(health_check))
245        .route("/ready", get(readiness_check))
246        .route("/live", get(liveness_check))
247        .route("/api/metrics", get(get_metrics))
248        .route("/metrics", get(prometheus_metrics))
249        // Configuration
250        .route("/api/config", get(get_config))
251        .route("/api/config", post(set_config))
252        .route("/api/config/reload", post(reload_config))
253        // Generation
254        .route("/api/generate/bulk", post(bulk_generate))
255        .route("/api/stream/start", post(start_stream))
256        .route("/api/stream/stop", post(stop_stream))
257        .route("/api/stream/pause", post(pause_stream))
258        .route("/api/stream/resume", post(resume_stream))
259        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
260        // Jobs
261        .route("/api/jobs/submit", post(submit_job))
262        .route("/api/jobs", get(list_jobs))
263        .route("/api/jobs/{id}", get(get_job))
264        .route("/api/jobs/{id}/cancel", post(cancel_job))
265        // WebSocket
266        .route("/ws/metrics", get(websocket_metrics))
267        .route("/ws/events", get(websocket_events))
268        .layer(axum::middleware::from_fn(auth_middleware))
269        .layer(axum::Extension(auth_config))
270        .layer(cors)
271        .with_state(state)
272}
273
274/// Create the REST API router with custom CORS settings.
275pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
276    let server_state = service.state.clone();
277    let state = AppState {
278        server_state,
279        job_queue: None,
280    };
281
282    let cors = if cors_config.allow_any_origin {
283        // Development mode - allow any origin (use with caution)
284        CorsLayer::permissive()
285    } else {
286        // Production mode - restricted origins
287        let origins: Vec<_> = cors_config
288            .allowed_origins
289            .iter()
290            .filter_map(|o| o.parse().ok())
291            .collect();
292
293        CorsLayer::new()
294            .allow_origin(AllowOrigin::list(origins))
295            .allow_methods([
296                Method::GET,
297                Method::POST,
298                Method::PUT,
299                Method::DELETE,
300                Method::OPTIONS,
301            ])
302            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
303    };
304
305    Router::new()
306        // Health and metrics
307        .route("/health", get(health_check))
308        .route("/ready", get(readiness_check))
309        .route("/live", get(liveness_check))
310        .route("/api/metrics", get(get_metrics))
311        .route("/metrics", get(prometheus_metrics))
312        // Configuration
313        .route("/api/config", get(get_config))
314        .route("/api/config", post(set_config))
315        .route("/api/config/reload", post(reload_config))
316        // Generation
317        .route("/api/generate/bulk", post(bulk_generate))
318        .route("/api/stream/start", post(start_stream))
319        .route("/api/stream/stop", post(stop_stream))
320        .route("/api/stream/pause", post(pause_stream))
321        .route("/api/stream/resume", post(resume_stream))
322        .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
323        // Jobs
324        .route("/api/jobs/submit", post(submit_job))
325        .route("/api/jobs", get(list_jobs))
326        .route("/api/jobs/{id}", get(get_job))
327        .route("/api/jobs/{id}/cancel", post(cancel_job))
328        // WebSocket
329        .route("/ws/metrics", get(websocket_metrics))
330        .route("/ws/events", get(websocket_events))
331        .layer(cors)
332        .with_state(state)
333}
334
335// ===========================================================================
336// Request/Response types
337// ===========================================================================
338
339#[derive(Debug, Serialize, Deserialize)]
340pub struct HealthResponse {
341    pub healthy: bool,
342    pub version: String,
343    pub uptime_seconds: u64,
344}
345
346/// Readiness check response for Kubernetes.
347#[derive(Debug, Serialize, Deserialize)]
348pub struct ReadinessResponse {
349    pub ready: bool,
350    pub message: String,
351    pub checks: Vec<HealthCheck>,
352}
353
354/// Individual health check result.
355#[derive(Debug, Serialize, Deserialize)]
356pub struct HealthCheck {
357    pub name: String,
358    pub status: String,
359}
360
361/// Liveness check response for Kubernetes.
362#[derive(Debug, Serialize, Deserialize)]
363pub struct LivenessResponse {
364    pub alive: bool,
365    pub timestamp: String,
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369pub struct MetricsResponse {
370    pub total_entries_generated: u64,
371    pub total_anomalies_injected: u64,
372    pub uptime_seconds: u64,
373    pub session_entries: u64,
374    pub session_entries_per_second: f64,
375    pub active_streams: u32,
376    pub total_stream_events: u64,
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct ConfigResponse {
381    pub success: bool,
382    pub message: String,
383    pub config: Option<GenerationConfigDto>,
384}
385
386#[derive(Debug, Clone, Serialize, Deserialize)]
387pub struct GenerationConfigDto {
388    pub industry: String,
389    pub start_date: String,
390    pub period_months: u32,
391    pub seed: Option<u64>,
392    pub coa_complexity: String,
393    pub companies: Vec<CompanyConfigDto>,
394    pub fraud_enabled: bool,
395    pub fraud_rate: f32,
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize)]
399pub struct CompanyConfigDto {
400    pub code: String,
401    pub name: String,
402    pub currency: String,
403    pub country: String,
404    pub annual_transaction_volume: u64,
405    pub volume_weight: f32,
406}
407
408#[derive(Debug, Deserialize)]
409pub struct BulkGenerateRequest {
410    pub entry_count: Option<u64>,
411    pub include_master_data: Option<bool>,
412    pub inject_anomalies: Option<bool>,
413}
414
415#[derive(Debug, Serialize)]
416pub struct BulkGenerateResponse {
417    pub success: bool,
418    pub entries_generated: u64,
419    pub duration_ms: u64,
420    pub anomaly_count: u64,
421}
422
423#[derive(Debug, Deserialize)]
424#[allow(dead_code)] // Fields deserialized from request, reserved for future use
425pub struct StreamRequest {
426    pub events_per_second: Option<u32>,
427    pub max_events: Option<u64>,
428    pub inject_anomalies: Option<bool>,
429}
430
431#[derive(Debug, Serialize)]
432pub struct StreamResponse {
433    pub success: bool,
434    pub message: String,
435}
436
437// ===========================================================================
438// Handlers
439// ===========================================================================
440
441/// Health check endpoint - returns overall health status.
442async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
443    Json(HealthResponse {
444        healthy: true,
445        version: env!("CARGO_PKG_VERSION").to_string(),
446        uptime_seconds: state.server_state.uptime_seconds(),
447    })
448}
449
450/// Readiness probe - indicates the service is ready to accept traffic.
451/// Use for Kubernetes readiness probes.
452async fn readiness_check(
453    State(state): State<AppState>,
454) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
455    let mut checks = Vec::new();
456    let mut any_fail = false;
457
458    // Check if configuration is loaded and valid
459    let config = state.server_state.config.read().await;
460    let config_valid = !config.companies.is_empty();
461    checks.push(HealthCheck {
462        name: "config".to_string(),
463        status: if config_valid { "ok" } else { "fail" }.to_string(),
464    });
465    if !config_valid {
466        any_fail = true;
467    }
468    drop(config);
469
470    // Check resource guard (memory)
471    let resource_status = state.server_state.resource_status();
472    let memory_status = if resource_status.degradation_level == "Emergency" {
473        any_fail = true;
474        "fail"
475    } else if resource_status.degradation_level != "Normal" {
476        "degraded"
477    } else {
478        "ok"
479    };
480    checks.push(HealthCheck {
481        name: "memory".to_string(),
482        status: memory_status.to_string(),
483    });
484
485    // Check disk (>100MB free)
486    let disk_ok = resource_status.disk_available_mb > 100;
487    checks.push(HealthCheck {
488        name: "disk".to_string(),
489        status: if disk_ok { "ok" } else { "fail" }.to_string(),
490    });
491    if !disk_ok {
492        any_fail = true;
493    }
494
495    let response = ReadinessResponse {
496        ready: !any_fail,
497        message: if any_fail {
498            "Service is not ready".to_string()
499        } else {
500            "Service is ready".to_string()
501        },
502        checks,
503    };
504
505    if any_fail {
506        Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
507    } else {
508        Ok(Json(response))
509    }
510}
511
512/// Liveness probe - indicates the service is alive.
513/// Use for Kubernetes liveness probes.
514async fn liveness_check() -> Json<LivenessResponse> {
515    Json(LivenessResponse {
516        alive: true,
517        timestamp: chrono::Utc::now().to_rfc3339(),
518    })
519}
520
521/// Prometheus-compatible metrics endpoint.
522/// Returns metrics in Prometheus text exposition format.
523async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
524    use std::sync::atomic::Ordering;
525
526    let uptime = state.server_state.uptime_seconds();
527    let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
528    let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
529    let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
530    let total_stream_events = state
531        .server_state
532        .total_stream_events
533        .load(Ordering::Relaxed);
534
535    let entries_per_second = if uptime > 0 {
536        total_entries as f64 / uptime as f64
537    } else {
538        0.0
539    };
540
541    let metrics = format!(
542        r#"# HELP synth_entries_generated_total Total number of journal entries generated
543# TYPE synth_entries_generated_total counter
544synth_entries_generated_total {}
545
546# HELP synth_anomalies_injected_total Total number of anomalies injected
547# TYPE synth_anomalies_injected_total counter
548synth_anomalies_injected_total {}
549
550# HELP synth_uptime_seconds Server uptime in seconds
551# TYPE synth_uptime_seconds gauge
552synth_uptime_seconds {}
553
554# HELP synth_entries_per_second Rate of entry generation
555# TYPE synth_entries_per_second gauge
556synth_entries_per_second {:.2}
557
558# HELP synth_active_streams Number of active streaming connections
559# TYPE synth_active_streams gauge
560synth_active_streams {}
561
562# HELP synth_stream_events_total Total events sent through streams
563# TYPE synth_stream_events_total counter
564synth_stream_events_total {}
565
566# HELP synth_info Server version information
567# TYPE synth_info gauge
568synth_info{{version="{}"}} 1
569"#,
570        total_entries,
571        total_anomalies,
572        uptime,
573        entries_per_second,
574        active_streams,
575        total_stream_events,
576        env!("CARGO_PKG_VERSION")
577    );
578
579    (
580        StatusCode::OK,
581        [(
582            header::CONTENT_TYPE,
583            "text/plain; version=0.0.4; charset=utf-8",
584        )],
585        metrics,
586    )
587}
588
589/// Get server metrics.
590async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
591    let uptime = state.server_state.uptime_seconds();
592    let total_entries = state
593        .server_state
594        .total_entries
595        .load(std::sync::atomic::Ordering::Relaxed);
596
597    let entries_per_second = if uptime > 0 {
598        total_entries as f64 / uptime as f64
599    } else {
600        0.0
601    };
602
603    Json(MetricsResponse {
604        total_entries_generated: total_entries,
605        total_anomalies_injected: state
606            .server_state
607            .total_anomalies
608            .load(std::sync::atomic::Ordering::Relaxed),
609        uptime_seconds: uptime,
610        session_entries: total_entries,
611        session_entries_per_second: entries_per_second,
612        active_streams: state
613            .server_state
614            .active_streams
615            .load(std::sync::atomic::Ordering::Relaxed) as u32,
616        total_stream_events: state
617            .server_state
618            .total_stream_events
619            .load(std::sync::atomic::Ordering::Relaxed),
620    })
621}
622
623/// Get current configuration.
624async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
625    let config = state.server_state.config.read().await;
626
627    Json(ConfigResponse {
628        success: true,
629        message: "Current configuration".to_string(),
630        config: Some(GenerationConfigDto {
631            industry: format!("{:?}", config.global.industry),
632            start_date: config.global.start_date.clone(),
633            period_months: config.global.period_months,
634            seed: config.global.seed,
635            coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
636            companies: config
637                .companies
638                .iter()
639                .map(|c| CompanyConfigDto {
640                    code: c.code.clone(),
641                    name: c.name.clone(),
642                    currency: c.currency.clone(),
643                    country: c.country.clone(),
644                    annual_transaction_volume: c.annual_transaction_volume.count(),
645                    volume_weight: c.volume_weight as f32,
646                })
647                .collect(),
648            fraud_enabled: config.fraud.enabled,
649            fraud_rate: config.fraud.fraud_rate as f32,
650        }),
651    })
652}
653
654/// Set configuration.
655async fn set_config(
656    State(state): State<AppState>,
657    Json(new_config): Json<GenerationConfigDto>,
658) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
659    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
660    use datasynth_core::models::{CoAComplexity, IndustrySector};
661
662    info!(
663        "Configuration update requested: industry={}, period_months={}",
664        new_config.industry, new_config.period_months
665    );
666
667    // Parse industry from string
668    let industry = match new_config.industry.to_lowercase().as_str() {
669        "manufacturing" => IndustrySector::Manufacturing,
670        "retail" => IndustrySector::Retail,
671        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
672        "healthcare" => IndustrySector::Healthcare,
673        "technology" => IndustrySector::Technology,
674        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
675        "energy" => IndustrySector::Energy,
676        "transportation" => IndustrySector::Transportation,
677        "real_estate" | "realestate" => IndustrySector::RealEstate,
678        "telecommunications" => IndustrySector::Telecommunications,
679        _ => {
680            return Err((
681                StatusCode::BAD_REQUEST,
682                Json(ConfigResponse {
683                    success: false,
684                    message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
685                    config: None,
686                }),
687            ));
688        }
689    };
690
691    // Parse CoA complexity from string
692    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
693        "small" => CoAComplexity::Small,
694        "medium" => CoAComplexity::Medium,
695        "large" => CoAComplexity::Large,
696        _ => {
697            return Err((
698                StatusCode::BAD_REQUEST,
699                Json(ConfigResponse {
700                    success: false,
701                    message: format!(
702                        "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
703                        new_config.coa_complexity
704                    ),
705                    config: None,
706                }),
707            ));
708        }
709    };
710
711    // Convert CompanyConfigDto to CompanyConfig
712    let companies: Vec<CompanyConfig> = new_config
713        .companies
714        .iter()
715        .map(|c| CompanyConfig {
716            code: c.code.clone(),
717            name: c.name.clone(),
718            currency: c.currency.clone(),
719            country: c.country.clone(),
720            fiscal_year_variant: "K4".to_string(),
721            annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
722            volume_weight: c.volume_weight as f64,
723        })
724        .collect();
725
726    // Update the configuration
727    let mut config = state.server_state.config.write().await;
728    config.global.industry = industry;
729    config.global.start_date = new_config.start_date.clone();
730    config.global.period_months = new_config.period_months;
731    config.global.seed = new_config.seed;
732    config.chart_of_accounts.complexity = complexity;
733    config.fraud.enabled = new_config.fraud_enabled;
734    config.fraud.fraud_rate = new_config.fraud_rate as f64;
735
736    // Only update companies if provided
737    if !companies.is_empty() {
738        config.companies = companies;
739    }
740
741    info!("Configuration updated successfully");
742
743    Ok(Json(ConfigResponse {
744        success: true,
745        message: "Configuration updated and applied".to_string(),
746        config: Some(new_config),
747    }))
748}
749
750/// Bulk generation endpoint.
751async fn bulk_generate(
752    State(state): State<AppState>,
753    Json(req): Json<BulkGenerateRequest>,
754) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
755    // Validate entry_count bounds
756    const MAX_ENTRY_COUNT: u64 = 1_000_000;
757    if let Some(count) = req.entry_count {
758        if count > MAX_ENTRY_COUNT {
759            return Err((
760                StatusCode::BAD_REQUEST,
761                format!(
762                    "entry_count ({}) exceeds maximum allowed value ({})",
763                    count, MAX_ENTRY_COUNT
764                ),
765            ));
766        }
767    }
768
769    let config = state.server_state.config.read().await.clone();
770    let start_time = std::time::Instant::now();
771
772    let phase_config = PhaseConfig {
773        generate_master_data: req.include_master_data.unwrap_or(false),
774        generate_document_flows: false,
775        generate_journal_entries: true,
776        inject_anomalies: req.inject_anomalies.unwrap_or(false),
777        show_progress: false,
778        ..Default::default()
779    };
780
781    let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
782        (
783            StatusCode::INTERNAL_SERVER_ERROR,
784            format!("Failed to create orchestrator: {}", e),
785        )
786    })?;
787
788    let result = orchestrator.generate().map_err(|e| {
789        (
790            StatusCode::INTERNAL_SERVER_ERROR,
791            format!("Generation failed: {}", e),
792        )
793    })?;
794
795    let duration_ms = start_time.elapsed().as_millis() as u64;
796    let entries_count = result.journal_entries.len() as u64;
797    let anomaly_count = result.anomaly_labels.labels.len() as u64;
798
799    // Update metrics
800    state
801        .server_state
802        .total_entries
803        .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
804    state
805        .server_state
806        .total_anomalies
807        .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
808
809    Ok(Json(BulkGenerateResponse {
810        success: true,
811        entries_generated: entries_count,
812        duration_ms,
813        anomaly_count,
814    }))
815}
816
817/// Start streaming.
818async fn start_stream(
819    State(state): State<AppState>,
820    Json(_req): Json<StreamRequest>,
821) -> Json<StreamResponse> {
822    state
823        .server_state
824        .stream_stopped
825        .store(false, std::sync::atomic::Ordering::Relaxed);
826    state
827        .server_state
828        .stream_paused
829        .store(false, std::sync::atomic::Ordering::Relaxed);
830
831    Json(StreamResponse {
832        success: true,
833        message: "Stream started".to_string(),
834    })
835}
836
837/// Stop streaming.
838async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
839    state
840        .server_state
841        .stream_stopped
842        .store(true, std::sync::atomic::Ordering::Relaxed);
843
844    Json(StreamResponse {
845        success: true,
846        message: "Stream stopped".to_string(),
847    })
848}
849
850/// Pause streaming.
851async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
852    state
853        .server_state
854        .stream_paused
855        .store(true, std::sync::atomic::Ordering::Relaxed);
856
857    Json(StreamResponse {
858        success: true,
859        message: "Stream paused".to_string(),
860    })
861}
862
863/// Resume streaming.
864async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
865    state
866        .server_state
867        .stream_paused
868        .store(false, std::sync::atomic::Ordering::Relaxed);
869
870    Json(StreamResponse {
871        success: true,
872        message: "Stream resumed".to_string(),
873    })
874}
875
876/// Trigger a specific pattern.
877///
878/// Valid patterns: year_end_spike, period_end_spike, holiday_cluster,
879/// fraud_cluster, error_cluster, uniform, or custom:* patterns.
880async fn trigger_pattern(
881    State(state): State<AppState>,
882    axum::extract::Path(pattern): axum::extract::Path<String>,
883) -> Json<StreamResponse> {
884    info!("Pattern trigger requested: {}", pattern);
885
886    // Validate pattern name
887    let valid_patterns = [
888        "year_end_spike",
889        "period_end_spike",
890        "holiday_cluster",
891        "fraud_cluster",
892        "error_cluster",
893        "uniform",
894    ];
895
896    let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
897
898    if !is_valid {
899        return Json(StreamResponse {
900            success: false,
901            message: format!(
902                "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
903                pattern, valid_patterns
904            ),
905        });
906    }
907
908    // Store the pattern for the stream generator to pick up
909    match state.server_state.triggered_pattern.try_write() {
910        Ok(mut triggered) => {
911            *triggered = Some(pattern.clone());
912            Json(StreamResponse {
913                success: true,
914                message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
915            })
916        }
917        Err(_) => Json(StreamResponse {
918            success: false,
919            message: "Failed to acquire lock for pattern trigger".to_string(),
920        }),
921    }
922}
923
924/// WebSocket endpoint for metrics stream.
925async fn websocket_metrics(
926    ws: WebSocketUpgrade,
927    State(state): State<AppState>,
928) -> impl IntoResponse {
929    ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
930}
931
932/// WebSocket endpoint for event stream.
933async fn websocket_events(
934    ws: WebSocketUpgrade,
935    State(state): State<AppState>,
936) -> impl IntoResponse {
937    ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
938}
939
940// ===========================================================================
941// Job Queue Handlers
942// ===========================================================================
943
944/// Submit a new async generation job.
945async fn submit_job(
946    State(state): State<AppState>,
947    Json(request): Json<JobRequest>,
948) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
949    let queue = state.job_queue.as_ref().ok_or_else(|| {
950        (
951            StatusCode::SERVICE_UNAVAILABLE,
952            Json(serde_json::json!({"error": "Job queue not enabled"})),
953        )
954    })?;
955
956    let job_id = queue.submit(request).await;
957    info!("Job submitted: {}", job_id);
958
959    Ok((
960        StatusCode::CREATED,
961        Json(serde_json::json!({
962            "id": job_id.to_string(),
963            "status": "queued"
964        })),
965    ))
966}
967
968/// Get status of a specific job.
969async fn get_job(
970    State(state): State<AppState>,
971    axum::extract::Path(id): axum::extract::Path<String>,
972) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
973    let queue = state.job_queue.as_ref().ok_or_else(|| {
974        (
975            StatusCode::SERVICE_UNAVAILABLE,
976            Json(serde_json::json!({"error": "Job queue not enabled"})),
977        )
978    })?;
979
980    match queue.get(&id).await {
981        Some(entry) => Ok(Json(serde_json::json!({
982            "id": entry.id,
983            "status": format!("{:?}", entry.status).to_lowercase(),
984            "submitted_at": entry.submitted_at.to_rfc3339(),
985            "started_at": entry.started_at.map(|t| t.to_rfc3339()),
986            "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
987            "result": entry.result,
988        }))),
989        None => Err((
990            StatusCode::NOT_FOUND,
991            Json(serde_json::json!({"error": "Job not found"})),
992        )),
993    }
994}
995
996/// List all jobs.
997async fn list_jobs(
998    State(state): State<AppState>,
999) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1000    let queue = state.job_queue.as_ref().ok_or_else(|| {
1001        (
1002            StatusCode::SERVICE_UNAVAILABLE,
1003            Json(serde_json::json!({"error": "Job queue not enabled"})),
1004        )
1005    })?;
1006
1007    let summaries: Vec<_> = queue
1008        .list()
1009        .await
1010        .into_iter()
1011        .map(|s| {
1012            serde_json::json!({
1013                "id": s.id,
1014                "status": format!("{:?}", s.status).to_lowercase(),
1015                "submitted_at": s.submitted_at.to_rfc3339(),
1016            })
1017        })
1018        .collect();
1019
1020    Ok(Json(serde_json::json!({ "jobs": summaries })))
1021}
1022
1023/// Cancel a queued job.
1024async fn cancel_job(
1025    State(state): State<AppState>,
1026    axum::extract::Path(id): axum::extract::Path<String>,
1027) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1028    let queue = state.job_queue.as_ref().ok_or_else(|| {
1029        (
1030            StatusCode::SERVICE_UNAVAILABLE,
1031            Json(serde_json::json!({"error": "Job queue not enabled"})),
1032        )
1033    })?;
1034
1035    if queue.cancel(&id).await {
1036        Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1037    } else {
1038        Err((
1039            StatusCode::CONFLICT,
1040            Json(
1041                serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1042            ),
1043        ))
1044    }
1045}
1046
1047// ===========================================================================
1048// Config Reload Handler
1049// ===========================================================================
1050
1051/// Reload configuration from the configured source.
1052async fn reload_config(
1053    State(state): State<AppState>,
1054) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1055    // Reload from default config source
1056    let new_config = crate::grpc::service::default_generator_config();
1057    let mut config = state.server_state.config.write().await;
1058    *config = new_config;
1059    info!("Configuration reloaded via REST API");
1060
1061    Ok(Json(serde_json::json!({
1062        "success": true,
1063        "message": "Configuration reloaded"
1064    })))
1065}
1066
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Serializes `value` with `serde_json`, panicking if serialization fails.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Asserts that every string in `needles` occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected {:?} in serialized output {}",
                needle,
                haystack
            );
        }
    }

    /// Entries-per-second formula exercised by the metrics tests below:
    /// zero uptime yields 0.0 instead of dividing by zero.
    fn entries_per_second(total_entries: u64, uptime: u64) -> f64 {
        if uptime == 0 {
            0.0
        } else {
            total_entries as f64 / uptime as f64
        }
    }

    // ==========================================================================
    // Response Serialization Tests
    // ==========================================================================

    #[test]
    fn test_health_response_serialization() {
        let encoded = to_json(&HealthResponse {
            healthy: true,
            version: "0.1.0".to_string(),
            uptime_seconds: 100,
        });
        assert_contains_all(&encoded, &["healthy", "version", "uptime_seconds"]);
    }

    #[test]
    fn test_health_response_deserialization() {
        let parsed: HealthResponse =
            serde_json::from_str(r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#)
                .unwrap();
        assert!(parsed.healthy);
        assert_eq!(parsed.version, "0.1.0");
        assert_eq!(parsed.uptime_seconds, 100);
    }

    #[test]
    fn test_metrics_response_serialization() {
        let metrics = MetricsResponse {
            total_entries_generated: 1000,
            total_anomalies_injected: 10,
            uptime_seconds: 60,
            session_entries: 1000,
            session_entries_per_second: 16.67,
            active_streams: 1,
            total_stream_events: 500,
        };
        assert_contains_all(
            &to_json(&metrics),
            &["total_entries_generated", "session_entries_per_second"],
        );
    }

    #[test]
    fn test_metrics_response_deserialization() {
        let raw = r#"{
            "total_entries_generated": 5000,
            "total_anomalies_injected": 50,
            "uptime_seconds": 300,
            "session_entries": 5000,
            "session_entries_per_second": 16.67,
            "active_streams": 2,
            "total_stream_events": 10000
        }"#;
        let parsed: MetricsResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.total_entries_generated, 5000);
        assert_eq!(parsed.active_streams, 2);
    }

    #[test]
    fn test_config_response_serialization() {
        let dto = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "medium".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        };
        let encoded = to_json(&ConfigResponse {
            success: true,
            message: "Configuration loaded".to_string(),
            config: Some(dto),
        });
        assert_contains_all(&encoded, &["success", "config"]);
    }

    #[test]
    fn test_config_response_without_config() {
        let encoded = to_json(&ConfigResponse {
            success: false,
            message: "No configuration available".to_string(),
            config: None,
        });
        // A missing config must show up as a JSON null somewhere in the output.
        assert!(encoded.contains("null"));
    }

    #[test]
    fn test_generation_config_dto_roundtrip() {
        let original = GenerationConfigDto {
            industry: "retail".to_string(),
            start_date: "2024-06-01".to_string(),
            period_months: 6,
            seed: Some(12345),
            coa_complexity: "large".to_string(),
            companies: vec![CompanyConfigDto {
                code: "1000".to_string(),
                name: "Test Corp".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            }],
            fraud_enabled: true,
            fraud_rate: 0.05,
        };

        let restored: GenerationConfigDto = serde_json::from_str(&to_json(&original)).unwrap();

        assert_eq!(restored.industry, original.industry);
        assert_eq!(restored.seed, original.seed);
        assert_eq!(restored.companies.len(), original.companies.len());
    }

    #[test]
    fn test_company_config_dto_serialization() {
        let encoded = to_json(&CompanyConfigDto {
            code: "2000".to_string(),
            name: "European Subsidiary".to_string(),
            currency: "EUR".to_string(),
            country: "DE".to_string(),
            annual_transaction_volume: 50000,
            volume_weight: 0.5,
        });
        assert_contains_all(&encoded, &["2000", "EUR", "DE"]);
    }

    #[test]
    fn test_bulk_generate_request_deserialization() {
        let raw = r#"{
            "entry_count": 5000,
            "include_master_data": true,
            "inject_anomalies": true
        }"#;
        let parsed: BulkGenerateRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.entry_count, Some(5000));
        assert_eq!(parsed.include_master_data, Some(true));
        assert_eq!(parsed.inject_anomalies, Some(true));
    }

    #[test]
    fn test_bulk_generate_request_with_defaults() {
        let parsed: BulkGenerateRequest = serde_json::from_str("{}").unwrap();
        assert!(parsed.entry_count.is_none());
        assert!(parsed.include_master_data.is_none());
        assert!(parsed.inject_anomalies.is_none());
    }

    #[test]
    fn test_bulk_generate_response_serialization() {
        let encoded = to_json(&BulkGenerateResponse {
            success: true,
            entries_generated: 1000,
            duration_ms: 250,
            anomaly_count: 20,
        });
        assert_contains_all(&encoded, &["entries_generated", "1000", "duration_ms"]);
    }

    #[test]
    fn test_stream_response_serialization() {
        let encoded = to_json(&StreamResponse {
            success: true,
            message: "Stream started successfully".to_string(),
        });
        assert_contains_all(&encoded, &["success", "Stream started"]);
    }

    #[test]
    fn test_stream_response_failure() {
        let encoded = to_json(&StreamResponse {
            success: false,
            message: "Stream failed to start".to_string(),
        });
        assert_contains_all(&encoded, &["false", "failed"]);
    }

    // ==========================================================================
    // CORS Configuration Tests
    // ==========================================================================

    #[test]
    fn test_cors_config_default() {
        let cors = CorsConfig::default();
        assert!(!cors.allow_any_origin);
        assert!(!cors.allowed_origins.is_empty());
        for expected in ["http://localhost:5173", "tauri://localhost"] {
            assert!(
                cors.allowed_origins.iter().any(|origin| origin == expected),
                "default CORS origins should include {}",
                expected
            );
        }
    }

    #[test]
    fn test_cors_config_custom_origins() {
        let cors = CorsConfig {
            allowed_origins: vec![
                "https://example.com".to_string(),
                "https://app.example.com".to_string(),
            ],
            allow_any_origin: false,
        };
        assert_eq!(cors.allowed_origins.len(), 2);
        assert!(cors
            .allowed_origins
            .iter()
            .any(|origin| origin == "https://example.com"));
    }

    #[test]
    fn test_cors_config_permissive() {
        let cors = CorsConfig {
            allowed_origins: Vec::new(),
            allow_any_origin: true,
        };
        assert!(cors.allow_any_origin);
    }

    // ==========================================================================
    // Request Validation Tests (edge cases)
    // ==========================================================================

    #[test]
    fn test_bulk_generate_request_partial() {
        let parsed: BulkGenerateRequest = serde_json::from_str(r#"{"entry_count": 100}"#).unwrap();
        assert_eq!(parsed.entry_count, Some(100));
        assert!(parsed.include_master_data.is_none());
    }

    #[test]
    fn test_generation_config_no_seed() {
        let encoded = to_json(&GenerationConfigDto {
            industry: "technology".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 3,
            seed: None,
            coa_complexity: "small".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        });
        assert!(encoded.contains("seed"));
    }

    #[test]
    fn test_generation_config_multiple_companies() {
        let companies = vec![
            CompanyConfigDto {
                code: "1000".to_string(),
                name: "Headquarters".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            },
            CompanyConfigDto {
                code: "2000".to_string(),
                name: "European Sub".to_string(),
                currency: "EUR".to_string(),
                country: "DE".to_string(),
                annual_transaction_volume: 50000,
                volume_weight: 0.5,
            },
            CompanyConfigDto {
                code: "3000".to_string(),
                name: "APAC Sub".to_string(),
                currency: "JPY".to_string(),
                country: "JP".to_string(),
                annual_transaction_volume: 30000,
                volume_weight: 0.3,
            },
        ];
        let dto = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "large".to_string(),
            companies,
            fraud_enabled: true,
            fraud_rate: 0.02,
        };
        assert_eq!(dto.companies.len(), 3);
    }

    // ==========================================================================
    // Metrics Calculation Tests
    // ==========================================================================

    #[test]
    fn test_metrics_entries_per_second_calculation() {
        let rate = entries_per_second(1000, 60);
        assert!((rate - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_metrics_entries_per_second_zero_uptime() {
        let rate = entries_per_second(1000, 0);
        assert_eq!(rate, 0.0);
    }
}