1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::info;
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Shared state handed to every REST handler through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Server-wide state taken from the gRPC `SynthService` (generation config, counters,
    /// stream control flags, triggered pattern).
    pub server_state: Arc<ServerState>,
    /// Background job queue; when `None`, the `/api/jobs*` endpoints answer
    /// `503 Service Unavailable`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Per-request timeout settings applied by the full REST router.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Upper bound, in seconds, on how long a single request may run before the router's
    /// timeout layer answers on its behalf.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Creates a configuration whose request timeout is `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    /// A five-minute (300 second) request timeout.
    fn default() -> Self {
        TimeoutConfig::new(300)
    }
}
55
/// Cross-origin resource sharing (CORS) policy for the REST API.
///
/// Derives `Debug` (like [`TimeoutConfig`]) so the effective policy can be logged or
/// inspected during startup and in tests.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Exact origins (scheme + host + port) allowed to call the API when
    /// `allow_any_origin` is `false`.
    pub allowed_origins: Vec<String>,
    /// When `true`, every origin is accepted and `allowed_origins` is ignored.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Allows `localhost` / `127.0.0.1` on ports 5173 and 3000 plus `tauri://localhost`;
    /// any-origin access is disabled.
    fn default() -> Self {
        Self {
            allowed_origins: [
                "http://localhost:5173",
                "http://localhost:3000",
                "http://127.0.0.1:5173",
                "http://127.0.0.1:3000",
                "tauri://localhost",
            ]
            .iter()
            .map(|origin| origin.to_string())
            .collect(),
            allow_any_origin: false,
        }
    }
}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/jobs/submit", post(submit_job))
186 .route("/api/jobs", get(list_jobs))
187 .route("/api/jobs/{id}", get(get_job))
188 .route("/api/jobs/{id}/cancel", post(cancel_job))
189 .route("/ws/metrics", get(websocket_metrics))
191 .route("/ws/events", get(websocket_events))
192 .layer(axum::middleware::from_fn(security_headers_middleware))
195 .layer(axum::middleware::map_response(api_version_header))
196 .layer(cors)
197 .layer(axum::middleware::from_fn(request_id_middleware))
198 .layer(axum::middleware::from_fn(auth_middleware))
199 .layer(axum::Extension(auth_config))
200 .layer(axum::middleware::from_fn(request_validation_middleware))
201 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202 .layer(axum::Extension(rate_limit_backend))
203 .layer(TimeoutLayer::new(Duration::from_secs(
204 timeout_config.request_timeout_secs,
205 )))
206 .with_state(state)
207}
208
209pub fn create_router_with_auth(
211 service: SynthService,
212 cors_config: CorsConfig,
213 auth_config: AuthConfig,
214) -> Router {
215 let server_state = service.state.clone();
216 let state = AppState {
217 server_state,
218 job_queue: None,
219 };
220
221 let cors = if cors_config.allow_any_origin {
222 CorsLayer::permissive()
223 } else {
224 let origins: Vec<_> = cors_config
225 .allowed_origins
226 .iter()
227 .filter_map(|o| o.parse().ok())
228 .collect();
229
230 CorsLayer::new()
231 .allow_origin(AllowOrigin::list(origins))
232 .allow_methods([
233 Method::GET,
234 Method::POST,
235 Method::PUT,
236 Method::DELETE,
237 Method::OPTIONS,
238 ])
239 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
240 };
241
242 Router::new()
243 .route("/health", get(health_check))
245 .route("/ready", get(readiness_check))
246 .route("/live", get(liveness_check))
247 .route("/api/metrics", get(get_metrics))
248 .route("/metrics", get(prometheus_metrics))
249 .route("/api/config", get(get_config))
251 .route("/api/config", post(set_config))
252 .route("/api/config/reload", post(reload_config))
253 .route("/api/generate/bulk", post(bulk_generate))
255 .route("/api/stream/start", post(start_stream))
256 .route("/api/stream/stop", post(stop_stream))
257 .route("/api/stream/pause", post(pause_stream))
258 .route("/api/stream/resume", post(resume_stream))
259 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
260 .route("/api/jobs/submit", post(submit_job))
262 .route("/api/jobs", get(list_jobs))
263 .route("/api/jobs/{id}", get(get_job))
264 .route("/api/jobs/{id}/cancel", post(cancel_job))
265 .route("/ws/metrics", get(websocket_metrics))
267 .route("/ws/events", get(websocket_events))
268 .layer(axum::middleware::from_fn(auth_middleware))
269 .layer(axum::Extension(auth_config))
270 .layer(cors)
271 .with_state(state)
272}
273
274pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
276 let server_state = service.state.clone();
277 let state = AppState {
278 server_state,
279 job_queue: None,
280 };
281
282 let cors = if cors_config.allow_any_origin {
283 CorsLayer::permissive()
285 } else {
286 let origins: Vec<_> = cors_config
288 .allowed_origins
289 .iter()
290 .filter_map(|o| o.parse().ok())
291 .collect();
292
293 CorsLayer::new()
294 .allow_origin(AllowOrigin::list(origins))
295 .allow_methods([
296 Method::GET,
297 Method::POST,
298 Method::PUT,
299 Method::DELETE,
300 Method::OPTIONS,
301 ])
302 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
303 };
304
305 Router::new()
306 .route("/health", get(health_check))
308 .route("/ready", get(readiness_check))
309 .route("/live", get(liveness_check))
310 .route("/api/metrics", get(get_metrics))
311 .route("/metrics", get(prometheus_metrics))
312 .route("/api/config", get(get_config))
314 .route("/api/config", post(set_config))
315 .route("/api/config/reload", post(reload_config))
316 .route("/api/generate/bulk", post(bulk_generate))
318 .route("/api/stream/start", post(start_stream))
319 .route("/api/stream/stop", post(stop_stream))
320 .route("/api/stream/pause", post(pause_stream))
321 .route("/api/stream/resume", post(resume_stream))
322 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
323 .route("/api/jobs/submit", post(submit_job))
325 .route("/api/jobs", get(list_jobs))
326 .route("/api/jobs/{id}", get(get_job))
327 .route("/api/jobs/{id}/cancel", post(cancel_job))
328 .route("/ws/metrics", get(websocket_metrics))
330 .route("/ws/events", get(websocket_events))
331 .layer(cors)
332 .with_state(state)
333}
334
/// Body of `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always `true` when the handler responds.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Seconds since the server state was created.
    pub uptime_seconds: u64,
}

/// Body of `GET /ready`; returned with 200 when ready and 503 otherwise.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `true` when no individual check reported `"fail"`.
    pub ready: bool,
    /// Human-readable summary ("Service is ready" / "Service is not ready").
    pub message: String,
    /// Per-subsystem results (`config`, `memory`, `disk`).
    pub checks: Vec<HealthCheck>,
}

/// Result of one readiness check.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check name, e.g. `"config"`, `"memory"`, `"disk"`.
    pub name: String,
    /// `"ok"`, `"degraded"` (memory only) or `"fail"`.
    pub status: String,
}

/// Body of `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when the handler responds.
    pub alive: bool,
    /// Current UTC time in RFC 3339 format.
    pub timestamp: String,
}

/// Body of `GET /api/metrics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Journal entries generated since startup.
    pub total_entries_generated: u64,
    /// Anomalies injected since startup.
    pub total_anomalies_injected: u64,
    /// Seconds since startup.
    pub uptime_seconds: u64,
    /// Currently populated with the same value as `total_entries_generated`.
    pub session_entries: u64,
    /// Entries per second averaged over the whole uptime (0.0 when uptime is 0).
    pub session_entries_per_second: f64,
    /// Number of active streams.
    pub active_streams: u32,
    /// Events sent through streams since startup.
    pub total_stream_events: u64,
}

/// Response of the configuration endpoints; `config` is `None` on validation errors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    pub success: bool,
    pub message: String,
    pub config: Option<GenerationConfigDto>,
}

/// Wire representation of the generation configuration, used for both reading and updating.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry name; matched case-insensitively on update (e.g. `"manufacturing"`),
    /// reported as the enum's `Debug` name on read.
    pub industry: String,
    /// Start date string, stored as-is.
    pub start_date: String,
    pub period_months: u32,
    /// Optional RNG seed.
    pub seed: Option<u64>,
    /// Chart-of-accounts size: `small`, `medium` or `large` (case-insensitive on update).
    pub coa_complexity: String,
    /// Companies to generate for; an empty list on update keeps the current companies.
    pub companies: Vec<CompanyConfigDto>,
    pub fraud_enabled: bool,
    /// Fraud rate; widened to `f64` when stored in the server configuration.
    pub fraud_rate: f32,
}

/// Wire representation of one company in [`GenerationConfigDto`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    pub code: String,
    pub name: String,
    pub currency: String,
    pub country: String,
    /// Stored as `TransactionVolume::Custom(n)` on update; read back via `count()`.
    pub annual_transaction_volume: u64,
    /// Relative weight; stored as `f64` in the server configuration.
    pub volume_weight: f32,
}

/// Request body of `POST /api/generate/bulk`; every field is optional.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Rejected when above 1,000,000.
    /// NOTE(review): not otherwise passed to the generator by `bulk_generate`.
    pub entry_count: Option<u64>,
    /// Generate master data as well; defaults to `false`.
    pub include_master_data: Option<bool>,
    /// Inject anomalies; defaults to `false`.
    pub inject_anomalies: Option<bool>,
}

/// Response body of `POST /api/generate/bulk`.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    pub success: bool,
    /// Number of journal entries produced by this run.
    pub entries_generated: u64,
    /// Wall-clock generation time in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced by this run.
    pub anomaly_count: u64,
}

/// Request body of `POST /api/stream/start`.
#[derive(Debug, Deserialize)]
// Fields are deserialized from the request body but not yet read by `start_stream`.
#[allow(dead_code)]
pub struct StreamRequest {
    pub events_per_second: Option<u32>,
    pub max_events: Option<u64>,
    pub inject_anomalies: Option<bool>,
}

/// Generic success/failure response for the stream-control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    pub success: bool,
    pub message: String,
}
436
437async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
443 Json(HealthResponse {
444 healthy: true,
445 version: env!("CARGO_PKG_VERSION").to_string(),
446 uptime_seconds: state.server_state.uptime_seconds(),
447 })
448}
449
450async fn readiness_check(
453 State(state): State<AppState>,
454) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
455 let mut checks = Vec::new();
456 let mut any_fail = false;
457
458 let config = state.server_state.config.read().await;
460 let config_valid = !config.companies.is_empty();
461 checks.push(HealthCheck {
462 name: "config".to_string(),
463 status: if config_valid { "ok" } else { "fail" }.to_string(),
464 });
465 if !config_valid {
466 any_fail = true;
467 }
468 drop(config);
469
470 let resource_status = state.server_state.resource_status();
472 let memory_status = if resource_status.degradation_level == "Emergency" {
473 any_fail = true;
474 "fail"
475 } else if resource_status.degradation_level != "Normal" {
476 "degraded"
477 } else {
478 "ok"
479 };
480 checks.push(HealthCheck {
481 name: "memory".to_string(),
482 status: memory_status.to_string(),
483 });
484
485 let disk_ok = resource_status.disk_available_mb > 100;
487 checks.push(HealthCheck {
488 name: "disk".to_string(),
489 status: if disk_ok { "ok" } else { "fail" }.to_string(),
490 });
491 if !disk_ok {
492 any_fail = true;
493 }
494
495 let response = ReadinessResponse {
496 ready: !any_fail,
497 message: if any_fail {
498 "Service is not ready".to_string()
499 } else {
500 "Service is ready".to_string()
501 },
502 checks,
503 };
504
505 if any_fail {
506 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
507 } else {
508 Ok(Json(response))
509 }
510}
511
512async fn liveness_check() -> Json<LivenessResponse> {
515 Json(LivenessResponse {
516 alive: true,
517 timestamp: chrono::Utc::now().to_rfc3339(),
518 })
519}
520
521async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
524 use std::sync::atomic::Ordering;
525
526 let uptime = state.server_state.uptime_seconds();
527 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
528 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
529 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
530 let total_stream_events = state
531 .server_state
532 .total_stream_events
533 .load(Ordering::Relaxed);
534
535 let entries_per_second = if uptime > 0 {
536 total_entries as f64 / uptime as f64
537 } else {
538 0.0
539 };
540
541 let metrics = format!(
542 r#"# HELP synth_entries_generated_total Total number of journal entries generated
543# TYPE synth_entries_generated_total counter
544synth_entries_generated_total {}
545
546# HELP synth_anomalies_injected_total Total number of anomalies injected
547# TYPE synth_anomalies_injected_total counter
548synth_anomalies_injected_total {}
549
550# HELP synth_uptime_seconds Server uptime in seconds
551# TYPE synth_uptime_seconds gauge
552synth_uptime_seconds {}
553
554# HELP synth_entries_per_second Rate of entry generation
555# TYPE synth_entries_per_second gauge
556synth_entries_per_second {:.2}
557
558# HELP synth_active_streams Number of active streaming connections
559# TYPE synth_active_streams gauge
560synth_active_streams {}
561
562# HELP synth_stream_events_total Total events sent through streams
563# TYPE synth_stream_events_total counter
564synth_stream_events_total {}
565
566# HELP synth_info Server version information
567# TYPE synth_info gauge
568synth_info{{version="{}"}} 1
569"#,
570 total_entries,
571 total_anomalies,
572 uptime,
573 entries_per_second,
574 active_streams,
575 total_stream_events,
576 env!("CARGO_PKG_VERSION")
577 );
578
579 (
580 StatusCode::OK,
581 [(
582 header::CONTENT_TYPE,
583 "text/plain; version=0.0.4; charset=utf-8",
584 )],
585 metrics,
586 )
587}
588
589async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
591 let uptime = state.server_state.uptime_seconds();
592 let total_entries = state
593 .server_state
594 .total_entries
595 .load(std::sync::atomic::Ordering::Relaxed);
596
597 let entries_per_second = if uptime > 0 {
598 total_entries as f64 / uptime as f64
599 } else {
600 0.0
601 };
602
603 Json(MetricsResponse {
604 total_entries_generated: total_entries,
605 total_anomalies_injected: state
606 .server_state
607 .total_anomalies
608 .load(std::sync::atomic::Ordering::Relaxed),
609 uptime_seconds: uptime,
610 session_entries: total_entries,
611 session_entries_per_second: entries_per_second,
612 active_streams: state
613 .server_state
614 .active_streams
615 .load(std::sync::atomic::Ordering::Relaxed) as u32,
616 total_stream_events: state
617 .server_state
618 .total_stream_events
619 .load(std::sync::atomic::Ordering::Relaxed),
620 })
621}
622
623async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
625 let config = state.server_state.config.read().await;
626
627 Json(ConfigResponse {
628 success: true,
629 message: "Current configuration".to_string(),
630 config: Some(GenerationConfigDto {
631 industry: format!("{:?}", config.global.industry),
632 start_date: config.global.start_date.clone(),
633 period_months: config.global.period_months,
634 seed: config.global.seed,
635 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
636 companies: config
637 .companies
638 .iter()
639 .map(|c| CompanyConfigDto {
640 code: c.code.clone(),
641 name: c.name.clone(),
642 currency: c.currency.clone(),
643 country: c.country.clone(),
644 annual_transaction_volume: c.annual_transaction_volume.count(),
645 volume_weight: c.volume_weight as f32,
646 })
647 .collect(),
648 fraud_enabled: config.fraud.enabled,
649 fraud_rate: config.fraud.fraud_rate as f32,
650 }),
651 })
652}
653
654async fn set_config(
656 State(state): State<AppState>,
657 Json(new_config): Json<GenerationConfigDto>,
658) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
659 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
660 use datasynth_core::models::{CoAComplexity, IndustrySector};
661
662 info!(
663 "Configuration update requested: industry={}, period_months={}",
664 new_config.industry, new_config.period_months
665 );
666
667 let industry = match new_config.industry.to_lowercase().as_str() {
669 "manufacturing" => IndustrySector::Manufacturing,
670 "retail" => IndustrySector::Retail,
671 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
672 "healthcare" => IndustrySector::Healthcare,
673 "technology" => IndustrySector::Technology,
674 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
675 "energy" => IndustrySector::Energy,
676 "transportation" => IndustrySector::Transportation,
677 "real_estate" | "realestate" => IndustrySector::RealEstate,
678 "telecommunications" => IndustrySector::Telecommunications,
679 _ => {
680 return Err((
681 StatusCode::BAD_REQUEST,
682 Json(ConfigResponse {
683 success: false,
684 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
685 config: None,
686 }),
687 ));
688 }
689 };
690
691 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
693 "small" => CoAComplexity::Small,
694 "medium" => CoAComplexity::Medium,
695 "large" => CoAComplexity::Large,
696 _ => {
697 return Err((
698 StatusCode::BAD_REQUEST,
699 Json(ConfigResponse {
700 success: false,
701 message: format!(
702 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
703 new_config.coa_complexity
704 ),
705 config: None,
706 }),
707 ));
708 }
709 };
710
711 let companies: Vec<CompanyConfig> = new_config
713 .companies
714 .iter()
715 .map(|c| CompanyConfig {
716 code: c.code.clone(),
717 name: c.name.clone(),
718 currency: c.currency.clone(),
719 country: c.country.clone(),
720 fiscal_year_variant: "K4".to_string(),
721 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
722 volume_weight: c.volume_weight as f64,
723 })
724 .collect();
725
726 let mut config = state.server_state.config.write().await;
728 config.global.industry = industry;
729 config.global.start_date = new_config.start_date.clone();
730 config.global.period_months = new_config.period_months;
731 config.global.seed = new_config.seed;
732 config.chart_of_accounts.complexity = complexity;
733 config.fraud.enabled = new_config.fraud_enabled;
734 config.fraud.fraud_rate = new_config.fraud_rate as f64;
735
736 if !companies.is_empty() {
738 config.companies = companies;
739 }
740
741 info!("Configuration updated successfully");
742
743 Ok(Json(ConfigResponse {
744 success: true,
745 message: "Configuration updated and applied".to_string(),
746 config: Some(new_config),
747 }))
748}
749
750async fn bulk_generate(
752 State(state): State<AppState>,
753 Json(req): Json<BulkGenerateRequest>,
754) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
755 const MAX_ENTRY_COUNT: u64 = 1_000_000;
757 if let Some(count) = req.entry_count {
758 if count > MAX_ENTRY_COUNT {
759 return Err((
760 StatusCode::BAD_REQUEST,
761 format!(
762 "entry_count ({}) exceeds maximum allowed value ({})",
763 count, MAX_ENTRY_COUNT
764 ),
765 ));
766 }
767 }
768
769 let config = state.server_state.config.read().await.clone();
770 let start_time = std::time::Instant::now();
771
772 let phase_config = PhaseConfig {
773 generate_master_data: req.include_master_data.unwrap_or(false),
774 generate_document_flows: false,
775 generate_journal_entries: true,
776 inject_anomalies: req.inject_anomalies.unwrap_or(false),
777 show_progress: false,
778 ..Default::default()
779 };
780
781 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
782 (
783 StatusCode::INTERNAL_SERVER_ERROR,
784 format!("Failed to create orchestrator: {}", e),
785 )
786 })?;
787
788 let result = orchestrator.generate().map_err(|e| {
789 (
790 StatusCode::INTERNAL_SERVER_ERROR,
791 format!("Generation failed: {}", e),
792 )
793 })?;
794
795 let duration_ms = start_time.elapsed().as_millis() as u64;
796 let entries_count = result.journal_entries.len() as u64;
797 let anomaly_count = result.anomaly_labels.labels.len() as u64;
798
799 state
801 .server_state
802 .total_entries
803 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
804 state
805 .server_state
806 .total_anomalies
807 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
808
809 Ok(Json(BulkGenerateResponse {
810 success: true,
811 entries_generated: entries_count,
812 duration_ms,
813 anomaly_count,
814 }))
815}
816
817async fn start_stream(
819 State(state): State<AppState>,
820 Json(_req): Json<StreamRequest>,
821) -> Json<StreamResponse> {
822 state
823 .server_state
824 .stream_stopped
825 .store(false, std::sync::atomic::Ordering::Relaxed);
826 state
827 .server_state
828 .stream_paused
829 .store(false, std::sync::atomic::Ordering::Relaxed);
830
831 Json(StreamResponse {
832 success: true,
833 message: "Stream started".to_string(),
834 })
835}
836
837async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
839 state
840 .server_state
841 .stream_stopped
842 .store(true, std::sync::atomic::Ordering::Relaxed);
843
844 Json(StreamResponse {
845 success: true,
846 message: "Stream stopped".to_string(),
847 })
848}
849
850async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
852 state
853 .server_state
854 .stream_paused
855 .store(true, std::sync::atomic::Ordering::Relaxed);
856
857 Json(StreamResponse {
858 success: true,
859 message: "Stream paused".to_string(),
860 })
861}
862
863async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
865 state
866 .server_state
867 .stream_paused
868 .store(false, std::sync::atomic::Ordering::Relaxed);
869
870 Json(StreamResponse {
871 success: true,
872 message: "Stream resumed".to_string(),
873 })
874}
875
876async fn trigger_pattern(
881 State(state): State<AppState>,
882 axum::extract::Path(pattern): axum::extract::Path<String>,
883) -> Json<StreamResponse> {
884 info!("Pattern trigger requested: {}", pattern);
885
886 let valid_patterns = [
888 "year_end_spike",
889 "period_end_spike",
890 "holiday_cluster",
891 "fraud_cluster",
892 "error_cluster",
893 "uniform",
894 ];
895
896 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
897
898 if !is_valid {
899 return Json(StreamResponse {
900 success: false,
901 message: format!(
902 "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
903 pattern, valid_patterns
904 ),
905 });
906 }
907
908 match state.server_state.triggered_pattern.try_write() {
910 Ok(mut triggered) => {
911 *triggered = Some(pattern.clone());
912 Json(StreamResponse {
913 success: true,
914 message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
915 })
916 }
917 Err(_) => Json(StreamResponse {
918 success: false,
919 message: "Failed to acquire lock for pattern trigger".to_string(),
920 }),
921 }
922}
923
924async fn websocket_metrics(
926 ws: WebSocketUpgrade,
927 State(state): State<AppState>,
928) -> impl IntoResponse {
929 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
930}
931
932async fn websocket_events(
934 ws: WebSocketUpgrade,
935 State(state): State<AppState>,
936) -> impl IntoResponse {
937 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
938}
939
940async fn submit_job(
946 State(state): State<AppState>,
947 Json(request): Json<JobRequest>,
948) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
949 let queue = state.job_queue.as_ref().ok_or_else(|| {
950 (
951 StatusCode::SERVICE_UNAVAILABLE,
952 Json(serde_json::json!({"error": "Job queue not enabled"})),
953 )
954 })?;
955
956 let job_id = queue.submit(request).await;
957 info!("Job submitted: {}", job_id);
958
959 Ok((
960 StatusCode::CREATED,
961 Json(serde_json::json!({
962 "id": job_id.to_string(),
963 "status": "queued"
964 })),
965 ))
966}
967
968async fn get_job(
970 State(state): State<AppState>,
971 axum::extract::Path(id): axum::extract::Path<String>,
972) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
973 let queue = state.job_queue.as_ref().ok_or_else(|| {
974 (
975 StatusCode::SERVICE_UNAVAILABLE,
976 Json(serde_json::json!({"error": "Job queue not enabled"})),
977 )
978 })?;
979
980 match queue.get(&id).await {
981 Some(entry) => Ok(Json(serde_json::json!({
982 "id": entry.id,
983 "status": format!("{:?}", entry.status).to_lowercase(),
984 "submitted_at": entry.submitted_at.to_rfc3339(),
985 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
986 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
987 "result": entry.result,
988 }))),
989 None => Err((
990 StatusCode::NOT_FOUND,
991 Json(serde_json::json!({"error": "Job not found"})),
992 )),
993 }
994}
995
996async fn list_jobs(
998 State(state): State<AppState>,
999) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1000 let queue = state.job_queue.as_ref().ok_or_else(|| {
1001 (
1002 StatusCode::SERVICE_UNAVAILABLE,
1003 Json(serde_json::json!({"error": "Job queue not enabled"})),
1004 )
1005 })?;
1006
1007 let summaries: Vec<_> = queue
1008 .list()
1009 .await
1010 .into_iter()
1011 .map(|s| {
1012 serde_json::json!({
1013 "id": s.id,
1014 "status": format!("{:?}", s.status).to_lowercase(),
1015 "submitted_at": s.submitted_at.to_rfc3339(),
1016 })
1017 })
1018 .collect();
1019
1020 Ok(Json(serde_json::json!({ "jobs": summaries })))
1021}
1022
1023async fn cancel_job(
1025 State(state): State<AppState>,
1026 axum::extract::Path(id): axum::extract::Path<String>,
1027) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1028 let queue = state.job_queue.as_ref().ok_or_else(|| {
1029 (
1030 StatusCode::SERVICE_UNAVAILABLE,
1031 Json(serde_json::json!({"error": "Job queue not enabled"})),
1032 )
1033 })?;
1034
1035 if queue.cancel(&id).await {
1036 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1037 } else {
1038 Err((
1039 StatusCode::CONFLICT,
1040 Json(
1041 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1042 ),
1043 ))
1044 }
1045}
1046
1047async fn reload_config(
1053 State(state): State<AppState>,
1054) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1055 let new_config = crate::grpc::service::default_generator_config();
1057 let mut config = state.server_state.config.write().await;
1058 *config = new_config;
1059 info!("Configuration reloaded via REST API");
1060
1061 Ok(Json(serde_json::json!({
1062 "success": true,
1063 "message": "Configuration reloaded"
1064 })))
1065}
1066
1067#[cfg(test)]
1068#[allow(clippy::unwrap_used)]
1069mod tests {
1070 use super::*;
1071
1072 #[test]
1077 fn test_health_response_serialization() {
1078 let response = HealthResponse {
1079 healthy: true,
1080 version: "0.1.0".to_string(),
1081 uptime_seconds: 100,
1082 };
1083 let json = serde_json::to_string(&response).unwrap();
1084 assert!(json.contains("healthy"));
1085 assert!(json.contains("version"));
1086 assert!(json.contains("uptime_seconds"));
1087 }
1088
1089 #[test]
1090 fn test_health_response_deserialization() {
1091 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1092 let response: HealthResponse = serde_json::from_str(json).unwrap();
1093 assert!(response.healthy);
1094 assert_eq!(response.version, "0.1.0");
1095 assert_eq!(response.uptime_seconds, 100);
1096 }
1097
1098 #[test]
1099 fn test_metrics_response_serialization() {
1100 let response = MetricsResponse {
1101 total_entries_generated: 1000,
1102 total_anomalies_injected: 10,
1103 uptime_seconds: 60,
1104 session_entries: 1000,
1105 session_entries_per_second: 16.67,
1106 active_streams: 1,
1107 total_stream_events: 500,
1108 };
1109 let json = serde_json::to_string(&response).unwrap();
1110 assert!(json.contains("total_entries_generated"));
1111 assert!(json.contains("session_entries_per_second"));
1112 }
1113
1114 #[test]
1115 fn test_metrics_response_deserialization() {
1116 let json = r#"{
1117 "total_entries_generated": 5000,
1118 "total_anomalies_injected": 50,
1119 "uptime_seconds": 300,
1120 "session_entries": 5000,
1121 "session_entries_per_second": 16.67,
1122 "active_streams": 2,
1123 "total_stream_events": 10000
1124 }"#;
1125 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1126 assert_eq!(response.total_entries_generated, 5000);
1127 assert_eq!(response.active_streams, 2);
1128 }
1129
1130 #[test]
1131 fn test_config_response_serialization() {
1132 let response = ConfigResponse {
1133 success: true,
1134 message: "Configuration loaded".to_string(),
1135 config: Some(GenerationConfigDto {
1136 industry: "manufacturing".to_string(),
1137 start_date: "2024-01-01".to_string(),
1138 period_months: 12,
1139 seed: Some(42),
1140 coa_complexity: "medium".to_string(),
1141 companies: vec![],
1142 fraud_enabled: false,
1143 fraud_rate: 0.0,
1144 }),
1145 };
1146 let json = serde_json::to_string(&response).unwrap();
1147 assert!(json.contains("success"));
1148 assert!(json.contains("config"));
1149 }
1150
1151 #[test]
1152 fn test_config_response_without_config() {
1153 let response = ConfigResponse {
1154 success: false,
1155 message: "No configuration available".to_string(),
1156 config: None,
1157 };
1158 let json = serde_json::to_string(&response).unwrap();
1159 assert!(json.contains("null") || json.contains("config\":null"));
1160 }
1161
1162 #[test]
1163 fn test_generation_config_dto_roundtrip() {
1164 let original = GenerationConfigDto {
1165 industry: "retail".to_string(),
1166 start_date: "2024-06-01".to_string(),
1167 period_months: 6,
1168 seed: Some(12345),
1169 coa_complexity: "large".to_string(),
1170 companies: vec![CompanyConfigDto {
1171 code: "1000".to_string(),
1172 name: "Test Corp".to_string(),
1173 currency: "USD".to_string(),
1174 country: "US".to_string(),
1175 annual_transaction_volume: 100000,
1176 volume_weight: 1.0,
1177 }],
1178 fraud_enabled: true,
1179 fraud_rate: 0.05,
1180 };
1181
1182 let json = serde_json::to_string(&original).unwrap();
1183 let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();
1184
1185 assert_eq!(original.industry, deserialized.industry);
1186 assert_eq!(original.seed, deserialized.seed);
1187 assert_eq!(original.companies.len(), deserialized.companies.len());
1188 }
1189
/// The JSON form of a `CompanyConfigDto` includes its code, currency and
/// country values.
#[test]
fn test_company_config_dto_serialization() {
    let subsidiary = CompanyConfigDto {
        code: "2000".to_string(),
        name: "European Subsidiary".to_string(),
        currency: "EUR".to_string(),
        country: "DE".to_string(),
        annual_transaction_volume: 50000,
        volume_weight: 0.5,
    };
    let serialized = serde_json::to_string(&subsidiary).unwrap();
    for expected in ["2000", "EUR", "DE"] {
        assert!(serialized.contains(expected));
    }
}
1205
/// A fully populated bulk-generation request body parses into `Some(..)`
/// for every field.
#[test]
fn test_bulk_generate_request_deserialization() {
    let payload = r#"{"entry_count": 5000, "include_master_data": true, "inject_anomalies": true}"#;
    let parsed: BulkGenerateRequest = serde_json::from_str(payload).unwrap();
    assert_eq!(
        (
            parsed.entry_count,
            parsed.include_master_data,
            parsed.inject_anomalies
        ),
        (Some(5000), Some(true), Some(true))
    );
}
1218
/// An empty JSON object is a valid bulk request: every optional field stays unset.
#[test]
fn test_bulk_generate_request_with_defaults() {
    let parsed: BulkGenerateRequest = serde_json::from_str("{}").unwrap();
    assert!(parsed.entry_count.is_none());
    assert!(parsed.include_master_data.is_none());
    assert!(parsed.inject_anomalies.is_none());
}
1227
/// The serialized bulk response exposes the generated-entry count and the
/// duration field.
#[test]
fn test_bulk_generate_response_serialization() {
    let body = serde_json::to_string(&BulkGenerateResponse {
        success: true,
        entries_generated: 1000,
        duration_ms: 250,
        anomaly_count: 20,
    })
    .unwrap();
    for fragment in ["entries_generated", "1000", "duration_ms"] {
        assert!(body.contains(fragment));
    }
}
1241
/// A successful stream response carries the `success` key and its message text.
#[test]
fn test_stream_response_serialization() {
    let started = StreamResponse {
        success: true,
        message: String::from("Stream started successfully"),
    };
    let encoded = serde_json::to_string(&started).unwrap();
    for needle in ["success", "Stream started"] {
        assert!(encoded.contains(needle));
    }
}
1252
/// A failed stream response serializes `false` and keeps its failure message.
#[test]
fn test_stream_response_failure() {
    let failed = StreamResponse {
        success: false,
        message: String::from("Stream failed to start"),
    };
    let encoded = serde_json::to_string(&failed).unwrap();
    for needle in ["false", "failed"] {
        assert!(encoded.contains(needle));
    }
}
1263
/// The default CORS policy is restrictive: a non-empty allow-list that
/// includes the local dev server and the Tauri origin.
#[test]
fn test_cors_config_default() {
    let defaults = CorsConfig::default();
    assert!(!defaults.allow_any_origin);
    assert!(!defaults.allowed_origins.is_empty());

    let has_origin = |wanted: &str| {
        defaults
            .allowed_origins
            .iter()
            .any(|origin| origin.as_str() == wanted)
    };
    assert!(has_origin("http://localhost:5173"));
    assert!(has_origin("tauri://localhost"));
}
1280
/// A hand-built CORS config keeps exactly the origins it was given.
#[test]
fn test_cors_config_custom_origins() {
    let custom = CorsConfig {
        allowed_origins: ["https://example.com", "https://app.example.com"]
            .iter()
            .copied()
            .map(String::from)
            .collect(),
        allow_any_origin: false,
    };
    assert_eq!(custom.allowed_origins.len(), 2);
    assert!(custom
        .allowed_origins
        .iter()
        .any(|origin| origin.as_str() == "https://example.com"));
}
1295
/// A permissive CORS config (empty allow-list, any origin) keeps its flag set.
#[test]
fn test_cors_config_permissive() {
    let permissive = CorsConfig {
        allow_any_origin: true,
        allowed_origins: Vec::new(),
    };
    assert!(permissive.allow_any_origin);
}
1304
/// A request that sets only `entry_count` leaves both boolean options unset.
#[test]
fn test_bulk_generate_request_partial() {
    let json = r#"{"entry_count": 100}"#;
    let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
    assert_eq!(request.entry_count, Some(100));
    assert!(request.include_master_data.is_none());
    // Previously unchecked: the other omitted field must also default to None.
    assert!(request.inject_anomalies.is_none());
}
1316
/// A config without a seed still emits the `seed` key, explicitly set to `null`.
#[test]
fn test_generation_config_no_seed() {
    let config = GenerationConfigDto {
        industry: "technology".to_string(),
        start_date: "2024-01-01".to_string(),
        period_months: 3,
        seed: None,
        coa_complexity: "small".to_string(),
        companies: vec![],
        fraud_enabled: false,
        fraud_rate: 0.0,
    };
    let json = serde_json::to_string(&config).unwrap();
    let value: serde_json::Value = serde_json::from_str(&json).unwrap();
    // `get` distinguishes "present and null" from "absent"; a substring
    // search for "seed" could not tell whether the value was null.
    assert_eq!(value.get("seed"), Some(&serde_json::Value::Null));
}
1332
/// A multi-company config survives a JSON round-trip with every company
/// intact and in its original order.
#[test]
fn test_generation_config_multiple_companies() {
    let config = GenerationConfigDto {
        industry: "manufacturing".to_string(),
        start_date: "2024-01-01".to_string(),
        period_months: 12,
        seed: Some(42),
        coa_complexity: "large".to_string(),
        companies: vec![
            CompanyConfigDto {
                code: "1000".to_string(),
                name: "Headquarters".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            },
            CompanyConfigDto {
                code: "2000".to_string(),
                name: "European Sub".to_string(),
                currency: "EUR".to_string(),
                country: "DE".to_string(),
                annual_transaction_volume: 50000,
                volume_weight: 0.5,
            },
            CompanyConfigDto {
                code: "3000".to_string(),
                name: "APAC Sub".to_string(),
                currency: "JPY".to_string(),
                country: "JP".to_string(),
                annual_transaction_volume: 30000,
                volume_weight: 0.3,
            },
        ],
        fraud_enabled: true,
        fraud_rate: 0.02,
    };
    assert_eq!(config.companies.len(), 3);

    // Asserting on the literal alone proves nothing, so round-trip through
    // serde and check that the companies come back in order.
    let json = serde_json::to_string(&config).unwrap();
    let restored: GenerationConfigDto = serde_json::from_str(&json).unwrap();

    let codes: Vec<&str> = restored
        .companies
        .iter()
        .map(|company| company.code.as_str())
        .collect();
    assert_eq!(codes, ["1000", "2000", "3000"]);

    let currencies: Vec<&str> = restored
        .companies
        .iter()
        .map(|company| company.currency.as_str())
        .collect();
    assert_eq!(currencies, ["USD", "EUR", "JPY"]);
}
1372
/// Entries-per-second is total entries divided by uptime seconds
/// (1000 entries over 60 s is about 16.67/s).
#[test]
fn test_metrics_entries_per_second_calculation() {
    let total_entries: u64 = 1000;
    let uptime: u64 = 60;
    let eps = match uptime {
        0 => 0.0,
        secs => total_entries as f64 / secs as f64,
    };
    assert!((eps - 16.67).abs() < 0.1);
}
1389
/// With zero uptime the rate falls back to 0.0 instead of dividing by zero.
#[test]
fn test_metrics_entries_per_second_zero_uptime() {
    let total_entries: u64 = 1000;
    let uptime: u64 = 0;
    let eps = (uptime != 0)
        .then(|| total_entries as f64 / uptime as f64)
        .unwrap_or(0.0);
    assert_eq!(eps, 0.0);
}
1401}