1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Shared state handed to every REST handler through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Server-wide state (configuration, counters, stream control flags)
    /// shared with the gRPC service this router was built from.
    pub server_state: Arc<ServerState>,
    /// Optional job queue backing the `/api/jobs*` endpoints. When `None`,
    /// those handlers respond with `503 Service Unavailable`. The router
    /// constructors in this module all set it to `None`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Per-request timeout settings used by the full router constructors.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Maximum time, in seconds, a request may run before the timeout layer
    /// ends it.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Timeout applied by [`TimeoutConfig::default`]: five minutes.
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Creates a configuration whose request timeout is `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
/// Cross-origin policy applied by the router constructors in this module.
///
/// When `allow_any_origin` is `true` a permissive CORS layer is used and
/// `allowed_origins` is ignored; otherwise only the listed origins are allowed.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests. Entries that do not
    /// parse as an HTTP header value are not added to the CORS allow-list.
    pub allowed_origins: Vec<String>,
    /// Allow every origin (`CorsLayer::permissive()`), ignoring
    /// `allowed_origins`.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Local development origins: `localhost` and `127.0.0.1` on ports 5173
    /// and 3000, plus `tauri://localhost`. Wildcard origins are disabled.
    fn default() -> Self {
        const DEV_ORIGINS: [&str; 5] = [
            "http://localhost:5173",
            "http://localhost:3000",
            "http://127.0.0.1:5173",
            "http://127.0.0.1:3000",
            "tauri://localhost",
        ];

        Self {
            allowed_origins: DEV_ORIGINS.iter().map(|&origin| String::from(origin)).collect(),
            allow_any_origin: false,
        }
    }
}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/jobs/submit", post(submit_job))
186 .route("/api/jobs", get(list_jobs))
187 .route("/api/jobs/{id}", get(get_job))
188 .route("/api/jobs/{id}/cancel", post(cancel_job))
189 .route("/ws/metrics", get(websocket_metrics))
191 .route("/ws/events", get(websocket_events))
192 .layer(axum::middleware::from_fn(security_headers_middleware))
195 .layer(axum::middleware::map_response(api_version_header))
196 .layer(cors)
197 .layer(axum::middleware::from_fn(request_id_middleware))
198 .layer(axum::middleware::from_fn(auth_middleware))
199 .layer(axum::Extension(auth_config))
200 .layer(axum::middleware::from_fn(request_validation_middleware))
201 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202 .layer(axum::Extension(rate_limit_backend))
203 .layer(TimeoutLayer::new(Duration::from_secs(
204 timeout_config.request_timeout_secs,
205 )))
206 .with_state(state)
207}
208
209pub fn create_router_with_auth(
211 service: SynthService,
212 cors_config: CorsConfig,
213 auth_config: AuthConfig,
214) -> Router {
215 let server_state = service.state.clone();
216 let state = AppState {
217 server_state,
218 job_queue: None,
219 };
220
221 let cors = if cors_config.allow_any_origin {
222 CorsLayer::permissive()
223 } else {
224 let origins: Vec<_> = cors_config
225 .allowed_origins
226 .iter()
227 .filter_map(|o| o.parse().ok())
228 .collect();
229
230 CorsLayer::new()
231 .allow_origin(AllowOrigin::list(origins))
232 .allow_methods([
233 Method::GET,
234 Method::POST,
235 Method::PUT,
236 Method::DELETE,
237 Method::OPTIONS,
238 ])
239 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
240 };
241
242 Router::new()
243 .route("/health", get(health_check))
245 .route("/ready", get(readiness_check))
246 .route("/live", get(liveness_check))
247 .route("/api/metrics", get(get_metrics))
248 .route("/metrics", get(prometheus_metrics))
249 .route("/api/config", get(get_config))
251 .route("/api/config", post(set_config))
252 .route("/api/config/reload", post(reload_config))
253 .route("/api/generate/bulk", post(bulk_generate))
255 .route("/api/stream/start", post(start_stream))
256 .route("/api/stream/stop", post(stop_stream))
257 .route("/api/stream/pause", post(pause_stream))
258 .route("/api/stream/resume", post(resume_stream))
259 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
260 .route("/api/jobs/submit", post(submit_job))
262 .route("/api/jobs", get(list_jobs))
263 .route("/api/jobs/{id}", get(get_job))
264 .route("/api/jobs/{id}/cancel", post(cancel_job))
265 .route("/ws/metrics", get(websocket_metrics))
267 .route("/ws/events", get(websocket_events))
268 .layer(axum::middleware::from_fn(auth_middleware))
269 .layer(axum::Extension(auth_config))
270 .layer(cors)
271 .with_state(state)
272}
273
274pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
276 let server_state = service.state.clone();
277 let state = AppState {
278 server_state,
279 job_queue: None,
280 };
281
282 let cors = if cors_config.allow_any_origin {
283 CorsLayer::permissive()
285 } else {
286 let origins: Vec<_> = cors_config
288 .allowed_origins
289 .iter()
290 .filter_map(|o| o.parse().ok())
291 .collect();
292
293 CorsLayer::new()
294 .allow_origin(AllowOrigin::list(origins))
295 .allow_methods([
296 Method::GET,
297 Method::POST,
298 Method::PUT,
299 Method::DELETE,
300 Method::OPTIONS,
301 ])
302 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
303 };
304
305 Router::new()
306 .route("/health", get(health_check))
308 .route("/ready", get(readiness_check))
309 .route("/live", get(liveness_check))
310 .route("/api/metrics", get(get_metrics))
311 .route("/metrics", get(prometheus_metrics))
312 .route("/api/config", get(get_config))
314 .route("/api/config", post(set_config))
315 .route("/api/config/reload", post(reload_config))
316 .route("/api/generate/bulk", post(bulk_generate))
318 .route("/api/stream/start", post(start_stream))
319 .route("/api/stream/stop", post(stop_stream))
320 .route("/api/stream/pause", post(pause_stream))
321 .route("/api/stream/resume", post(resume_stream))
322 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
323 .route("/api/jobs/submit", post(submit_job))
325 .route("/api/jobs", get(list_jobs))
326 .route("/api/jobs/{id}", get(get_job))
327 .route("/api/jobs/{id}/cancel", post(cancel_job))
328 .route("/ws/metrics", get(websocket_metrics))
330 .route("/ws/events", get(websocket_events))
331 .layer(cors)
332 .with_state(state)
333}
334
/// Body returned by `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always `true` when produced by the health handler.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION`) of the running server.
    pub version: String,
    /// Server uptime in seconds, from `ServerState::uptime_seconds()`.
    pub uptime_seconds: u64,
}
345
/// Body returned by `GET /ready`, sent with `200 OK` when ready and
/// `503 Service Unavailable` otherwise.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `true` when no individual check reported `"fail"`.
    pub ready: bool,
    /// Human-readable summary ("Service is ready" / "Service is not ready").
    pub message: String,
    /// Individual check results, in the order they were evaluated.
    pub checks: Vec<HealthCheck>,
}
353
/// A single readiness check result.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check identifier; the readiness handler emits "config", "memory" and "disk".
    pub name: String,
    /// One of "ok", "degraded" (memory only) or "fail".
    pub status: String,
}
360
/// Body returned by `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when produced by the liveness handler.
    pub alive: bool,
    /// Current UTC time formatted as RFC 3339.
    pub timestamp: String,
}
367
/// Body returned by `GET /api/metrics` (JSON counterpart of `/metrics`).
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Total journal entries generated since the server started.
    pub total_entries_generated: u64,
    /// Total anomalies injected since the server started.
    pub total_anomalies_injected: u64,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
    /// Currently reported with the same value as `total_entries_generated`.
    pub session_entries: u64,
    /// `total_entries_generated / uptime_seconds`, or 0.0 while uptime is 0.
    pub session_entries_per_second: f64,
    /// Number of active streams.
    pub active_streams: u32,
    /// Total events sent through streams.
    pub total_stream_events: u64,
}
378
/// Body returned by `GET /api/config` and `POST /api/config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// `false` when a configuration update was rejected.
    pub success: bool,
    /// Human-readable status or validation error message.
    pub message: String,
    /// The configuration; `None` on validation failure.
    pub config: Option<GenerationConfigDto>,
}
385
/// Wire representation of the generation configuration used by `/api/config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector. Reads render the enum with `{:?}`; writes accept
    /// snake_case names matched case-insensitively.
    pub industry: String,
    /// Copied verbatim into the config; not parsed by the REST handlers.
    pub start_date: String,
    /// Length of the generation period in months.
    pub period_months: u32,
    /// Optional RNG seed.
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity; writes accept "small", "medium" or
    /// "large" (case-insensitive).
    pub coa_complexity: String,
    /// Companies to generate for. An empty list on write leaves the existing
    /// companies unchanged.
    pub companies: Vec<CompanyConfigDto>,
    /// Whether fraud injection is enabled.
    pub fraud_enabled: bool,
    /// Fraud rate (stored as `f64` in the underlying config).
    pub fraud_rate: f32,
}
397
/// Wire representation of one company in [`GenerationConfigDto`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code.
    pub code: String,
    /// Company display name.
    pub name: String,
    /// Currency code (e.g. "USD").
    pub currency: String,
    /// Country code (e.g. "US").
    pub country: String,
    /// Annual transaction count. Reads use `TransactionVolume::count()`;
    /// writes store it as `TransactionVolume::Custom`.
    pub annual_transaction_volume: u64,
    /// Relative volume weight (stored as `f64` in the underlying config).
    pub volume_weight: f32,
}
407
/// Request body for `POST /api/generate/bulk`.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Requested entry count; rejected when above 1_000_000.
    /// NOTE(review): the bulk handler only validates this value and does not
    /// pass it to the orchestrator — confirm whether it should be applied.
    pub entry_count: Option<u64>,
    /// Generate master data as well (defaults to `false`).
    pub include_master_data: Option<bool>,
    /// Inject anomalies into the generated data (defaults to `false`).
    pub inject_anomalies: Option<bool>,
}
414
/// Response body for a successful `POST /api/generate/bulk`.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Always `true` (failures are reported as HTTP errors).
    pub success: bool,
    /// Number of journal entries produced.
    pub entries_generated: u64,
    /// Wall-clock time spent building and running the orchestrator.
    pub duration_ms: u64,
    /// Number of anomaly labels produced.
    pub anomaly_count: u64,
}
422
423#[derive(Debug, Deserialize)]
424#[allow(dead_code)] pub struct StreamRequest {
426 pub events_per_second: Option<u32>,
427 pub max_events: Option<u64>,
428 pub inject_anomalies: Option<bool>,
429}
430
/// Response body for the stream control and pattern-trigger endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the requested action was applied.
    pub success: bool,
    /// Human-readable outcome or error description.
    pub message: String,
}
436
437async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
443 Json(HealthResponse {
444 healthy: true,
445 version: env!("CARGO_PKG_VERSION").to_string(),
446 uptime_seconds: state.server_state.uptime_seconds(),
447 })
448}
449
450async fn readiness_check(
453 State(state): State<AppState>,
454) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
455 let mut checks = Vec::new();
456 let mut any_fail = false;
457
458 let config = state.server_state.config.read().await;
460 let config_valid = !config.companies.is_empty();
461 checks.push(HealthCheck {
462 name: "config".to_string(),
463 status: if config_valid { "ok" } else { "fail" }.to_string(),
464 });
465 if !config_valid {
466 any_fail = true;
467 }
468 drop(config);
469
470 let resource_status = state.server_state.resource_status();
472 let memory_status = if resource_status.degradation_level == "Emergency" {
473 any_fail = true;
474 "fail"
475 } else if resource_status.degradation_level != "Normal" {
476 "degraded"
477 } else {
478 "ok"
479 };
480 checks.push(HealthCheck {
481 name: "memory".to_string(),
482 status: memory_status.to_string(),
483 });
484
485 let disk_ok = resource_status.disk_available_mb > 100;
487 checks.push(HealthCheck {
488 name: "disk".to_string(),
489 status: if disk_ok { "ok" } else { "fail" }.to_string(),
490 });
491 if !disk_ok {
492 any_fail = true;
493 }
494
495 let response = ReadinessResponse {
496 ready: !any_fail,
497 message: if any_fail {
498 "Service is not ready".to_string()
499 } else {
500 "Service is ready".to_string()
501 },
502 checks,
503 };
504
505 if any_fail {
506 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
507 } else {
508 Ok(Json(response))
509 }
510}
511
512async fn liveness_check() -> Json<LivenessResponse> {
515 Json(LivenessResponse {
516 alive: true,
517 timestamp: chrono::Utc::now().to_rfc3339(),
518 })
519}
520
521async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
524 use std::sync::atomic::Ordering;
525
526 let uptime = state.server_state.uptime_seconds();
527 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
528 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
529 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
530 let total_stream_events = state
531 .server_state
532 .total_stream_events
533 .load(Ordering::Relaxed);
534
535 let entries_per_second = if uptime > 0 {
536 total_entries as f64 / uptime as f64
537 } else {
538 0.0
539 };
540
541 let metrics = format!(
542 r#"# HELP synth_entries_generated_total Total number of journal entries generated
543# TYPE synth_entries_generated_total counter
544synth_entries_generated_total {}
545
546# HELP synth_anomalies_injected_total Total number of anomalies injected
547# TYPE synth_anomalies_injected_total counter
548synth_anomalies_injected_total {}
549
550# HELP synth_uptime_seconds Server uptime in seconds
551# TYPE synth_uptime_seconds gauge
552synth_uptime_seconds {}
553
554# HELP synth_entries_per_second Rate of entry generation
555# TYPE synth_entries_per_second gauge
556synth_entries_per_second {:.2}
557
558# HELP synth_active_streams Number of active streaming connections
559# TYPE synth_active_streams gauge
560synth_active_streams {}
561
562# HELP synth_stream_events_total Total events sent through streams
563# TYPE synth_stream_events_total counter
564synth_stream_events_total {}
565
566# HELP synth_info Server version information
567# TYPE synth_info gauge
568synth_info{{version="{}"}} 1
569"#,
570 total_entries,
571 total_anomalies,
572 uptime,
573 entries_per_second,
574 active_streams,
575 total_stream_events,
576 env!("CARGO_PKG_VERSION")
577 );
578
579 (
580 StatusCode::OK,
581 [(
582 header::CONTENT_TYPE,
583 "text/plain; version=0.0.4; charset=utf-8",
584 )],
585 metrics,
586 )
587}
588
589async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
591 let uptime = state.server_state.uptime_seconds();
592 let total_entries = state
593 .server_state
594 .total_entries
595 .load(std::sync::atomic::Ordering::Relaxed);
596
597 let entries_per_second = if uptime > 0 {
598 total_entries as f64 / uptime as f64
599 } else {
600 0.0
601 };
602
603 Json(MetricsResponse {
604 total_entries_generated: total_entries,
605 total_anomalies_injected: state
606 .server_state
607 .total_anomalies
608 .load(std::sync::atomic::Ordering::Relaxed),
609 uptime_seconds: uptime,
610 session_entries: total_entries,
611 session_entries_per_second: entries_per_second,
612 active_streams: state
613 .server_state
614 .active_streams
615 .load(std::sync::atomic::Ordering::Relaxed) as u32,
616 total_stream_events: state
617 .server_state
618 .total_stream_events
619 .load(std::sync::atomic::Ordering::Relaxed),
620 })
621}
622
623async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
625 let config = state.server_state.config.read().await;
626
627 Json(ConfigResponse {
628 success: true,
629 message: "Current configuration".to_string(),
630 config: Some(GenerationConfigDto {
631 industry: format!("{:?}", config.global.industry),
632 start_date: config.global.start_date.clone(),
633 period_months: config.global.period_months,
634 seed: config.global.seed,
635 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
636 companies: config
637 .companies
638 .iter()
639 .map(|c| CompanyConfigDto {
640 code: c.code.clone(),
641 name: c.name.clone(),
642 currency: c.currency.clone(),
643 country: c.country.clone(),
644 annual_transaction_volume: c.annual_transaction_volume.count(),
645 volume_weight: c.volume_weight as f32,
646 })
647 .collect(),
648 fraud_enabled: config.fraud.enabled,
649 fraud_rate: config.fraud.fraud_rate as f32,
650 }),
651 })
652}
653
654async fn set_config(
656 State(state): State<AppState>,
657 Json(new_config): Json<GenerationConfigDto>,
658) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
659 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
660 use datasynth_core::models::{CoAComplexity, IndustrySector};
661
662 info!(
663 "Configuration update requested: industry={}, period_months={}",
664 new_config.industry, new_config.period_months
665 );
666
667 let industry = match new_config.industry.to_lowercase().as_str() {
669 "manufacturing" => IndustrySector::Manufacturing,
670 "retail" => IndustrySector::Retail,
671 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
672 "healthcare" => IndustrySector::Healthcare,
673 "technology" => IndustrySector::Technology,
674 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
675 "energy" => IndustrySector::Energy,
676 "transportation" => IndustrySector::Transportation,
677 "real_estate" | "realestate" => IndustrySector::RealEstate,
678 "telecommunications" => IndustrySector::Telecommunications,
679 _ => {
680 return Err((
681 StatusCode::BAD_REQUEST,
682 Json(ConfigResponse {
683 success: false,
684 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
685 config: None,
686 }),
687 ));
688 }
689 };
690
691 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
693 "small" => CoAComplexity::Small,
694 "medium" => CoAComplexity::Medium,
695 "large" => CoAComplexity::Large,
696 _ => {
697 return Err((
698 StatusCode::BAD_REQUEST,
699 Json(ConfigResponse {
700 success: false,
701 message: format!(
702 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
703 new_config.coa_complexity
704 ),
705 config: None,
706 }),
707 ));
708 }
709 };
710
711 let companies: Vec<CompanyConfig> = new_config
713 .companies
714 .iter()
715 .map(|c| CompanyConfig {
716 code: c.code.clone(),
717 name: c.name.clone(),
718 currency: c.currency.clone(),
719 country: c.country.clone(),
720 fiscal_year_variant: "K4".to_string(),
721 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
722 volume_weight: c.volume_weight as f64,
723 })
724 .collect();
725
726 let mut config = state.server_state.config.write().await;
728 config.global.industry = industry;
729 config.global.start_date = new_config.start_date.clone();
730 config.global.period_months = new_config.period_months;
731 config.global.seed = new_config.seed;
732 config.chart_of_accounts.complexity = complexity;
733 config.fraud.enabled = new_config.fraud_enabled;
734 config.fraud.fraud_rate = new_config.fraud_rate as f64;
735
736 if !companies.is_empty() {
738 config.companies = companies;
739 }
740
741 info!("Configuration updated successfully");
742
743 Ok(Json(ConfigResponse {
744 success: true,
745 message: "Configuration updated and applied".to_string(),
746 config: Some(new_config),
747 }))
748}
749
750async fn bulk_generate(
752 State(state): State<AppState>,
753 Json(req): Json<BulkGenerateRequest>,
754) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
755 const MAX_ENTRY_COUNT: u64 = 1_000_000;
757 if let Some(count) = req.entry_count {
758 if count > MAX_ENTRY_COUNT {
759 return Err((
760 StatusCode::BAD_REQUEST,
761 format!(
762 "entry_count ({}) exceeds maximum allowed value ({})",
763 count, MAX_ENTRY_COUNT
764 ),
765 ));
766 }
767 }
768
769 let config = state.server_state.config.read().await.clone();
770 let start_time = std::time::Instant::now();
771
772 let phase_config = PhaseConfig {
773 generate_master_data: req.include_master_data.unwrap_or(false),
774 generate_document_flows: false,
775 generate_journal_entries: true,
776 inject_anomalies: req.inject_anomalies.unwrap_or(false),
777 show_progress: false,
778 ..Default::default()
779 };
780
781 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
782 (
783 StatusCode::INTERNAL_SERVER_ERROR,
784 format!("Failed to create orchestrator: {}", e),
785 )
786 })?;
787
788 let result = orchestrator.generate().map_err(|e| {
789 (
790 StatusCode::INTERNAL_SERVER_ERROR,
791 format!("Generation failed: {}", e),
792 )
793 })?;
794
795 let duration_ms = start_time.elapsed().as_millis() as u64;
796 let entries_count = result.journal_entries.len() as u64;
797 let anomaly_count = result.anomaly_labels.labels.len() as u64;
798
799 state
801 .server_state
802 .total_entries
803 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
804 state
805 .server_state
806 .total_anomalies
807 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
808
809 Ok(Json(BulkGenerateResponse {
810 success: true,
811 entries_generated: entries_count,
812 duration_ms,
813 anomaly_count,
814 }))
815}
816
817async fn start_stream(
819 State(state): State<AppState>,
820 Json(req): Json<StreamRequest>,
821) -> Json<StreamResponse> {
822 if let Some(eps) = req.events_per_second {
824 info!("Stream configured: events_per_second={}", eps);
825 state
826 .server_state
827 .stream_events_per_second
828 .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
829 }
830 if let Some(max) = req.max_events {
831 info!("Stream configured: max_events={}", max);
832 state
833 .server_state
834 .stream_max_events
835 .store(max, std::sync::atomic::Ordering::Relaxed);
836 }
837 if let Some(inject) = req.inject_anomalies {
838 info!("Stream configured: inject_anomalies={}", inject);
839 state
840 .server_state
841 .stream_inject_anomalies
842 .store(inject, std::sync::atomic::Ordering::Relaxed);
843 }
844
845 state
846 .server_state
847 .stream_stopped
848 .store(false, std::sync::atomic::Ordering::Relaxed);
849 state
850 .server_state
851 .stream_paused
852 .store(false, std::sync::atomic::Ordering::Relaxed);
853
854 Json(StreamResponse {
855 success: true,
856 message: "Stream started".to_string(),
857 })
858}
859
860async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
862 state
863 .server_state
864 .stream_stopped
865 .store(true, std::sync::atomic::Ordering::Relaxed);
866
867 Json(StreamResponse {
868 success: true,
869 message: "Stream stopped".to_string(),
870 })
871}
872
873async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
875 state
876 .server_state
877 .stream_paused
878 .store(true, std::sync::atomic::Ordering::Relaxed);
879
880 Json(StreamResponse {
881 success: true,
882 message: "Stream paused".to_string(),
883 })
884}
885
886async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
888 state
889 .server_state
890 .stream_paused
891 .store(false, std::sync::atomic::Ordering::Relaxed);
892
893 Json(StreamResponse {
894 success: true,
895 message: "Stream resumed".to_string(),
896 })
897}
898
899async fn trigger_pattern(
904 State(state): State<AppState>,
905 axum::extract::Path(pattern): axum::extract::Path<String>,
906) -> Json<StreamResponse> {
907 info!("Pattern trigger requested: {}", pattern);
908
909 let valid_patterns = [
911 "year_end_spike",
912 "period_end_spike",
913 "holiday_cluster",
914 "fraud_cluster",
915 "error_cluster",
916 "uniform",
917 ];
918
919 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
920
921 if !is_valid {
922 return Json(StreamResponse {
923 success: false,
924 message: format!(
925 "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
926 pattern, valid_patterns
927 ),
928 });
929 }
930
931 match state.server_state.triggered_pattern.try_write() {
933 Ok(mut triggered) => {
934 *triggered = Some(pattern.clone());
935 Json(StreamResponse {
936 success: true,
937 message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
938 })
939 }
940 Err(_) => Json(StreamResponse {
941 success: false,
942 message: "Failed to acquire lock for pattern trigger".to_string(),
943 }),
944 }
945}
946
947async fn websocket_metrics(
949 ws: WebSocketUpgrade,
950 State(state): State<AppState>,
951) -> impl IntoResponse {
952 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
953}
954
955async fn websocket_events(
957 ws: WebSocketUpgrade,
958 State(state): State<AppState>,
959) -> impl IntoResponse {
960 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
961}
962
963async fn submit_job(
969 State(state): State<AppState>,
970 Json(request): Json<JobRequest>,
971) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
972 let queue = state.job_queue.as_ref().ok_or_else(|| {
973 (
974 StatusCode::SERVICE_UNAVAILABLE,
975 Json(serde_json::json!({"error": "Job queue not enabled"})),
976 )
977 })?;
978
979 let job_id = queue.submit(request).await;
980 info!("Job submitted: {}", job_id);
981
982 Ok((
983 StatusCode::CREATED,
984 Json(serde_json::json!({
985 "id": job_id.to_string(),
986 "status": "queued"
987 })),
988 ))
989}
990
991async fn get_job(
993 State(state): State<AppState>,
994 axum::extract::Path(id): axum::extract::Path<String>,
995) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
996 let queue = state.job_queue.as_ref().ok_or_else(|| {
997 (
998 StatusCode::SERVICE_UNAVAILABLE,
999 Json(serde_json::json!({"error": "Job queue not enabled"})),
1000 )
1001 })?;
1002
1003 match queue.get(&id).await {
1004 Some(entry) => Ok(Json(serde_json::json!({
1005 "id": entry.id,
1006 "status": format!("{:?}", entry.status).to_lowercase(),
1007 "submitted_at": entry.submitted_at.to_rfc3339(),
1008 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1009 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1010 "result": entry.result,
1011 }))),
1012 None => Err((
1013 StatusCode::NOT_FOUND,
1014 Json(serde_json::json!({"error": "Job not found"})),
1015 )),
1016 }
1017}
1018
1019async fn list_jobs(
1021 State(state): State<AppState>,
1022) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1023 let queue = state.job_queue.as_ref().ok_or_else(|| {
1024 (
1025 StatusCode::SERVICE_UNAVAILABLE,
1026 Json(serde_json::json!({"error": "Job queue not enabled"})),
1027 )
1028 })?;
1029
1030 let summaries: Vec<_> = queue
1031 .list()
1032 .await
1033 .into_iter()
1034 .map(|s| {
1035 serde_json::json!({
1036 "id": s.id,
1037 "status": format!("{:?}", s.status).to_lowercase(),
1038 "submitted_at": s.submitted_at.to_rfc3339(),
1039 })
1040 })
1041 .collect();
1042
1043 Ok(Json(serde_json::json!({ "jobs": summaries })))
1044}
1045
1046async fn cancel_job(
1048 State(state): State<AppState>,
1049 axum::extract::Path(id): axum::extract::Path<String>,
1050) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1051 let queue = state.job_queue.as_ref().ok_or_else(|| {
1052 (
1053 StatusCode::SERVICE_UNAVAILABLE,
1054 Json(serde_json::json!({"error": "Job queue not enabled"})),
1055 )
1056 })?;
1057
1058 if queue.cancel(&id).await {
1059 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1060 } else {
1061 Err((
1062 StatusCode::CONFLICT,
1063 Json(
1064 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1065 ),
1066 ))
1067 }
1068}
1069
1070async fn reload_config(
1076 State(state): State<AppState>,
1077) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1078 let source = state.server_state.config_source.read().await.clone();
1079 match crate::config_loader::load_config(&source).await {
1080 Ok(new_config) => {
1081 let mut config = state.server_state.config.write().await;
1082 *config = new_config;
1083 info!("Configuration reloaded via REST API from {:?}", source);
1084 Ok(Json(serde_json::json!({
1085 "success": true,
1086 "message": "Configuration reloaded"
1087 })))
1088 }
1089 Err(e) => {
1090 error!("Failed to reload configuration: {}", e);
1091 Err((
1092 StatusCode::INTERNAL_SERVER_ERROR,
1093 Json(serde_json::json!({
1094 "success": false,
1095 "message": format!("Failed to reload configuration: {}", e)
1096 })),
1097 ))
1098 }
1099 }
1100}
1101
1102#[cfg(test)]
1103#[allow(clippy::unwrap_used)]
1104mod tests {
1105 use super::*;
1106
1107 #[test]
1112 fn test_health_response_serialization() {
1113 let response = HealthResponse {
1114 healthy: true,
1115 version: "0.1.0".to_string(),
1116 uptime_seconds: 100,
1117 };
1118 let json = serde_json::to_string(&response).unwrap();
1119 assert!(json.contains("healthy"));
1120 assert!(json.contains("version"));
1121 assert!(json.contains("uptime_seconds"));
1122 }
1123
1124 #[test]
1125 fn test_health_response_deserialization() {
1126 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1127 let response: HealthResponse = serde_json::from_str(json).unwrap();
1128 assert!(response.healthy);
1129 assert_eq!(response.version, "0.1.0");
1130 assert_eq!(response.uptime_seconds, 100);
1131 }
1132
1133 #[test]
1134 fn test_metrics_response_serialization() {
1135 let response = MetricsResponse {
1136 total_entries_generated: 1000,
1137 total_anomalies_injected: 10,
1138 uptime_seconds: 60,
1139 session_entries: 1000,
1140 session_entries_per_second: 16.67,
1141 active_streams: 1,
1142 total_stream_events: 500,
1143 };
1144 let json = serde_json::to_string(&response).unwrap();
1145 assert!(json.contains("total_entries_generated"));
1146 assert!(json.contains("session_entries_per_second"));
1147 }
1148
1149 #[test]
1150 fn test_metrics_response_deserialization() {
1151 let json = r#"{
1152 "total_entries_generated": 5000,
1153 "total_anomalies_injected": 50,
1154 "uptime_seconds": 300,
1155 "session_entries": 5000,
1156 "session_entries_per_second": 16.67,
1157 "active_streams": 2,
1158 "total_stream_events": 10000
1159 }"#;
1160 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1161 assert_eq!(response.total_entries_generated, 5000);
1162 assert_eq!(response.active_streams, 2);
1163 }
1164
1165 #[test]
1166 fn test_config_response_serialization() {
1167 let response = ConfigResponse {
1168 success: true,
1169 message: "Configuration loaded".to_string(),
1170 config: Some(GenerationConfigDto {
1171 industry: "manufacturing".to_string(),
1172 start_date: "2024-01-01".to_string(),
1173 period_months: 12,
1174 seed: Some(42),
1175 coa_complexity: "medium".to_string(),
1176 companies: vec![],
1177 fraud_enabled: false,
1178 fraud_rate: 0.0,
1179 }),
1180 };
1181 let json = serde_json::to_string(&response).unwrap();
1182 assert!(json.contains("success"));
1183 assert!(json.contains("config"));
1184 }
1185
1186 #[test]
1187 fn test_config_response_without_config() {
1188 let response = ConfigResponse {
1189 success: false,
1190 message: "No configuration available".to_string(),
1191 config: None,
1192 };
1193 let json = serde_json::to_string(&response).unwrap();
1194 assert!(json.contains("null") || json.contains("config\":null"));
1195 }
1196
1197 #[test]
1198 fn test_generation_config_dto_roundtrip() {
1199 let original = GenerationConfigDto {
1200 industry: "retail".to_string(),
1201 start_date: "2024-06-01".to_string(),
1202 period_months: 6,
1203 seed: Some(12345),
1204 coa_complexity: "large".to_string(),
1205 companies: vec![CompanyConfigDto {
1206 code: "1000".to_string(),
1207 name: "Test Corp".to_string(),
1208 currency: "USD".to_string(),
1209 country: "US".to_string(),
1210 annual_transaction_volume: 100000,
1211 volume_weight: 1.0,
1212 }],
1213 fraud_enabled: true,
1214 fraud_rate: 0.05,
1215 };
1216
1217 let json = serde_json::to_string(&original).unwrap();
1218 let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();
1219
1220 assert_eq!(original.industry, deserialized.industry);
1221 assert_eq!(original.seed, deserialized.seed);
1222 assert_eq!(original.companies.len(), deserialized.companies.len());
1223 }
1224
1225 #[test]
1226 fn test_company_config_dto_serialization() {
1227 let company = CompanyConfigDto {
1228 code: "2000".to_string(),
1229 name: "European Subsidiary".to_string(),
1230 currency: "EUR".to_string(),
1231 country: "DE".to_string(),
1232 annual_transaction_volume: 50000,
1233 volume_weight: 0.5,
1234 };
1235 let json = serde_json::to_string(&company).unwrap();
1236 assert!(json.contains("2000"));
1237 assert!(json.contains("EUR"));
1238 assert!(json.contains("DE"));
1239 }
1240
1241 #[test]
1242 fn test_bulk_generate_request_deserialization() {
1243 let json = r#"{
1244 "entry_count": 5000,
1245 "include_master_data": true,
1246 "inject_anomalies": true
1247 }"#;
1248 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1249 assert_eq!(request.entry_count, Some(5000));
1250 assert_eq!(request.include_master_data, Some(true));
1251 assert_eq!(request.inject_anomalies, Some(true));
1252 }
1253
1254 #[test]
1255 fn test_bulk_generate_request_with_defaults() {
1256 let json = r#"{}"#;
1257 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1258 assert_eq!(request.entry_count, None);
1259 assert_eq!(request.include_master_data, None);
1260 assert_eq!(request.inject_anomalies, None);
1261 }
1262
1263 #[test]
1264 fn test_bulk_generate_response_serialization() {
1265 let response = BulkGenerateResponse {
1266 success: true,
1267 entries_generated: 1000,
1268 duration_ms: 250,
1269 anomaly_count: 20,
1270 };
1271 let json = serde_json::to_string(&response).unwrap();
1272 assert!(json.contains("entries_generated"));
1273 assert!(json.contains("1000"));
1274 assert!(json.contains("duration_ms"));
1275 }
1276
1277 #[test]
1278 fn test_stream_response_serialization() {
1279 let response = StreamResponse {
1280 success: true,
1281 message: "Stream started successfully".to_string(),
1282 };
1283 let json = serde_json::to_string(&response).unwrap();
1284 assert!(json.contains("success"));
1285 assert!(json.contains("Stream started"));
1286 }
1287
1288 #[test]
1289 fn test_stream_response_failure() {
1290 let response = StreamResponse {
1291 success: false,
1292 message: "Stream failed to start".to_string(),
1293 };
1294 let json = serde_json::to_string(&response).unwrap();
1295 assert!(json.contains("false"));
1296 assert!(json.contains("failed"));
1297 }
1298
1299 #[test]
1304 fn test_cors_config_default() {
1305 let config = CorsConfig::default();
1306 assert!(!config.allow_any_origin);
1307 assert!(!config.allowed_origins.is_empty());
1308 assert!(config
1309 .allowed_origins
1310 .contains(&"http://localhost:5173".to_string()));
1311 assert!(config
1312 .allowed_origins
1313 .contains(&"tauri://localhost".to_string()));
1314 }
1315
1316 #[test]
1317 fn test_cors_config_custom_origins() {
1318 let config = CorsConfig {
1319 allowed_origins: vec![
1320 "https://example.com".to_string(),
1321 "https://app.example.com".to_string(),
1322 ],
1323 allow_any_origin: false,
1324 };
1325 assert_eq!(config.allowed_origins.len(), 2);
1326 assert!(config
1327 .allowed_origins
1328 .contains(&"https://example.com".to_string()));
1329 }
1330
1331 #[test]
1332 fn test_cors_config_permissive() {
1333 let config = CorsConfig {
1334 allowed_origins: vec![],
1335 allow_any_origin: true,
1336 };
1337 assert!(config.allow_any_origin);
1338 }
1339
1340 #[test]
1345 fn test_bulk_generate_request_partial() {
1346 let json = r#"{"entry_count": 100}"#;
1347 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1348 assert_eq!(request.entry_count, Some(100));
1349 assert!(request.include_master_data.is_none());
1350 }
1351
1352 #[test]
1353 fn test_generation_config_no_seed() {
1354 let config = GenerationConfigDto {
1355 industry: "technology".to_string(),
1356 start_date: "2024-01-01".to_string(),
1357 period_months: 3,
1358 seed: None,
1359 coa_complexity: "small".to_string(),
1360 companies: vec![],
1361 fraud_enabled: false,
1362 fraud_rate: 0.0,
1363 };
1364 let json = serde_json::to_string(&config).unwrap();
1365 assert!(json.contains("seed"));
1366 }
1367
1368 #[test]
1369 fn test_generation_config_multiple_companies() {
1370 let config = GenerationConfigDto {
1371 industry: "manufacturing".to_string(),
1372 start_date: "2024-01-01".to_string(),
1373 period_months: 12,
1374 seed: Some(42),
1375 coa_complexity: "large".to_string(),
1376 companies: vec![
1377 CompanyConfigDto {
1378 code: "1000".to_string(),
1379 name: "Headquarters".to_string(),
1380 currency: "USD".to_string(),
1381 country: "US".to_string(),
1382 annual_transaction_volume: 100000,
1383 volume_weight: 1.0,
1384 },
1385 CompanyConfigDto {
1386 code: "2000".to_string(),
1387 name: "European Sub".to_string(),
1388 currency: "EUR".to_string(),
1389 country: "DE".to_string(),
1390 annual_transaction_volume: 50000,
1391 volume_weight: 0.5,
1392 },
1393 CompanyConfigDto {
1394 code: "3000".to_string(),
1395 name: "APAC Sub".to_string(),
1396 currency: "JPY".to_string(),
1397 country: "JP".to_string(),
1398 annual_transaction_volume: 30000,
1399 volume_weight: 0.3,
1400 },
1401 ],
1402 fraud_enabled: true,
1403 fraud_rate: 0.02,
1404 };
1405 assert_eq!(config.companies.len(), 3);
1406 }
1407
1408 #[test]
1413 fn test_metrics_entries_per_second_calculation() {
1414 let total_entries: u64 = 1000;
1416 let uptime: u64 = 60;
1417 let eps = if uptime > 0 {
1418 total_entries as f64 / uptime as f64
1419 } else {
1420 0.0
1421 };
1422 assert!((eps - 16.67).abs() < 0.1);
1423 }
1424
1425 #[test]
1426 fn test_metrics_entries_per_second_zero_uptime() {
1427 let total_entries: u64 = 1000;
1428 let uptime: u64 = 0;
1429 let eps = if uptime > 0 {
1430 total_entries as f64 / uptime as f64
1431 } else {
1432 0.0
1433 };
1434 assert_eq!(eps, 0.0);
1435 }
1436}