1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Shared state handed to every REST handler through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Server-wide state taken from the gRPC `SynthService` (configuration, counters,
    /// stream control flags), so REST and gRPC observe the same values.
    pub server_state: Arc<ServerState>,
    /// Optional job queue; when `None`, the `/api/jobs*` handlers answer
    /// `503 Service Unavailable`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Per-request timeout settings for the REST router.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Maximum number of seconds a request may run before the fully-layered router
    /// answers with `408 Request Timeout`.
    pub request_timeout_secs: u64,
}

impl Default for TimeoutConfig {
    /// A five-minute (300 second) request timeout.
    fn default() -> Self {
        const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;
        TimeoutConfig {
            request_timeout_secs: DEFAULT_REQUEST_TIMEOUT_SECS,
        }
    }
}

impl TimeoutConfig {
    /// Creates a configuration whose request timeout is `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        let request_timeout_secs = timeout_secs;
        TimeoutConfig {
            request_timeout_secs,
        }
    }
}
55
/// Cross-origin resource sharing (CORS) settings for the REST router.
///
/// Derives `Debug` (like [`TimeoutConfig`]) so the effective policy can be logged
/// at startup.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Exact origins (scheme, host and port) allowed to make cross-origin requests.
    /// Entries that are not valid HTTP header values are skipped when the CORS layer
    /// is built.
    pub allowed_origins: Vec<String>,
    /// When `true`, `allowed_origins` is ignored and a permissive CORS layer is used.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Allows local development front-ends on ports 5173 and 3000 (via both
    /// `localhost` and `127.0.0.1`) plus the `tauri://localhost` origin. Any-origin
    /// mode is off.
    fn default() -> Self {
        Self {
            allowed_origins: vec![
                "http://localhost:5173".to_string(),
                "http://localhost:3000".to_string(),
                "http://127.0.0.1:5173".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "tauri://localhost".to_string(),
            ],
            allow_any_origin: false,
        }
    }
}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/jobs/submit", post(submit_job))
186 .route("/api/jobs", get(list_jobs))
187 .route("/api/jobs/{id}", get(get_job))
188 .route("/api/jobs/{id}/cancel", post(cancel_job))
189 .route("/ws/metrics", get(websocket_metrics))
191 .route("/ws/events", get(websocket_events))
192 .layer(axum::middleware::from_fn(security_headers_middleware))
195 .layer(axum::middleware::map_response(api_version_header))
196 .layer(cors)
197 .layer(axum::middleware::from_fn(request_id_middleware))
198 .layer(axum::middleware::from_fn(auth_middleware))
199 .layer(axum::Extension(auth_config))
200 .layer(axum::middleware::from_fn(request_validation_middleware))
201 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202 .layer(axum::Extension(rate_limit_backend))
203 .layer(TimeoutLayer::with_status_code(
204 StatusCode::REQUEST_TIMEOUT,
205 Duration::from_secs(timeout_config.request_timeout_secs),
206 ))
207 .with_state(state)
208}
209
210pub fn create_router_with_auth(
212 service: SynthService,
213 cors_config: CorsConfig,
214 auth_config: AuthConfig,
215) -> Router {
216 let server_state = service.state.clone();
217 let state = AppState {
218 server_state,
219 job_queue: None,
220 };
221
222 let cors = if cors_config.allow_any_origin {
223 CorsLayer::permissive()
224 } else {
225 let origins: Vec<_> = cors_config
226 .allowed_origins
227 .iter()
228 .filter_map(|o| o.parse().ok())
229 .collect();
230
231 CorsLayer::new()
232 .allow_origin(AllowOrigin::list(origins))
233 .allow_methods([
234 Method::GET,
235 Method::POST,
236 Method::PUT,
237 Method::DELETE,
238 Method::OPTIONS,
239 ])
240 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
241 };
242
243 Router::new()
244 .route("/health", get(health_check))
246 .route("/ready", get(readiness_check))
247 .route("/live", get(liveness_check))
248 .route("/api/metrics", get(get_metrics))
249 .route("/metrics", get(prometheus_metrics))
250 .route("/api/config", get(get_config))
252 .route("/api/config", post(set_config))
253 .route("/api/config/reload", post(reload_config))
254 .route("/api/generate/bulk", post(bulk_generate))
256 .route("/api/stream/start", post(start_stream))
257 .route("/api/stream/stop", post(stop_stream))
258 .route("/api/stream/pause", post(pause_stream))
259 .route("/api/stream/resume", post(resume_stream))
260 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
261 .route("/api/jobs/submit", post(submit_job))
263 .route("/api/jobs", get(list_jobs))
264 .route("/api/jobs/{id}", get(get_job))
265 .route("/api/jobs/{id}/cancel", post(cancel_job))
266 .route("/ws/metrics", get(websocket_metrics))
268 .route("/ws/events", get(websocket_events))
269 .layer(axum::middleware::from_fn(auth_middleware))
270 .layer(axum::Extension(auth_config))
271 .layer(cors)
272 .with_state(state)
273}
274
275pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
277 let server_state = service.state.clone();
278 let state = AppState {
279 server_state,
280 job_queue: None,
281 };
282
283 let cors = if cors_config.allow_any_origin {
284 CorsLayer::permissive()
286 } else {
287 let origins: Vec<_> = cors_config
289 .allowed_origins
290 .iter()
291 .filter_map(|o| o.parse().ok())
292 .collect();
293
294 CorsLayer::new()
295 .allow_origin(AllowOrigin::list(origins))
296 .allow_methods([
297 Method::GET,
298 Method::POST,
299 Method::PUT,
300 Method::DELETE,
301 Method::OPTIONS,
302 ])
303 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
304 };
305
306 Router::new()
307 .route("/health", get(health_check))
309 .route("/ready", get(readiness_check))
310 .route("/live", get(liveness_check))
311 .route("/api/metrics", get(get_metrics))
312 .route("/metrics", get(prometheus_metrics))
313 .route("/api/config", get(get_config))
315 .route("/api/config", post(set_config))
316 .route("/api/config/reload", post(reload_config))
317 .route("/api/generate/bulk", post(bulk_generate))
319 .route("/api/stream/start", post(start_stream))
320 .route("/api/stream/stop", post(stop_stream))
321 .route("/api/stream/pause", post(pause_stream))
322 .route("/api/stream/resume", post(resume_stream))
323 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
324 .route("/api/jobs/submit", post(submit_job))
326 .route("/api/jobs", get(list_jobs))
327 .route("/api/jobs/{id}", get(get_job))
328 .route("/api/jobs/{id}/cancel", post(cancel_job))
329 .route("/ws/metrics", get(websocket_metrics))
331 .route("/ws/events", get(websocket_events))
332 .layer(cors)
333 .with_state(state)
334}
335
/// Body returned by `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Set to `true` by `health_check`, which does no deeper probing; see `/ready`
    /// for dependency checks.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION`) of the running server.
    pub version: String,
    /// Server uptime in seconds, as reported by `ServerState::uptime_seconds`.
    pub uptime_seconds: u64,
}
346
/// Body returned by `GET /ready`. The handler sends it with `503 Service Unavailable`
/// when any check fails.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `true` when no individual check reported `fail`. A `degraded` check still
    /// counts as ready.
    pub ready: bool,
    /// Human-readable summary of `ready`.
    pub message: String,
    /// Per-subsystem results (currently `config`, `memory` and `disk`).
    pub checks: Vec<HealthCheck>,
}
354
/// Outcome of a single readiness probe.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Probe name, e.g. `config`, `memory` or `disk`.
    pub name: String,
    /// Probe status: `ok`, `degraded` or `fail`.
    pub status: String,
}
361
/// Body returned by `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Set to `true` whenever the handler answers.
    pub alive: bool,
    /// Current UTC time, RFC 3339 formatted.
    pub timestamp: String,
}
368
/// Body returned by `GET /api/metrics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Cumulative journal entries counted by the server (`total_entries`).
    pub total_entries_generated: u64,
    /// Cumulative anomalies counted by the server (`total_anomalies`).
    pub total_anomalies_injected: u64,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
    /// Entries for the current session. `get_metrics` currently fills this with the
    /// same value as `total_entries_generated`.
    pub session_entries: u64,
    /// Average generation rate: total entries divided by uptime (0 when uptime is 0).
    pub session_entries_per_second: f64,
    /// Number of active streams.
    pub active_streams: u32,
    /// Cumulative events sent through streams.
    pub total_stream_events: u64,
}
379
/// Response envelope for `GET /api/config` and `POST /api/config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// Whether the request succeeded.
    pub success: bool,
    /// Human-readable status or validation error message.
    pub message: String,
    /// The configuration; `None` when the request was rejected.
    pub config: Option<GenerationConfigDto>,
}
386
/// Wire representation of the user-editable subset of the generator configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector. `get_config` emits the sector's `Debug` name; `set_config`
    /// matches it case-insensitively (snake_case or concatenated spellings).
    pub industry: String,
    /// Start date of the generated period, passed through as-is (not validated by
    /// `set_config`).
    pub start_date: String,
    /// Number of months to generate.
    pub period_months: u32,
    /// Optional seed copied to `global.seed`.
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity: `small`, `medium` or `large` (case-insensitive
    /// on input).
    pub coa_complexity: String,
    /// Company definitions. An empty list in `set_config` keeps the existing companies.
    pub companies: Vec<CompanyConfigDto>,
    /// Whether fraud injection is enabled.
    pub fraud_enabled: bool,
    /// Fraud rate. Sent as `f32` on the wire and stored as `f64` internally.
    pub fraud_rate: f32,
}
398
/// Wire representation of one company in the generator configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code.
    pub code: String,
    /// Display name.
    pub name: String,
    /// Currency code.
    pub currency: String,
    /// Country code.
    pub country: String,
    /// Annual transaction count. `set_config` maps it to `TransactionVolume::Custom`;
    /// `get_config` reads it back via `count()`.
    pub annual_transaction_volume: u64,
    /// Relative volume weight. Sent as `f32` on the wire and stored as `f64` internally.
    pub volume_weight: f32,
}
408
/// Body of `POST /api/generate/bulk`.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Requested number of entries; values above 1,000,000 are rejected with 400.
    /// NOTE(review): `bulk_generate` only range-checks this value and never passes it
    /// to the orchestrator, so it does not currently control the output size.
    pub entry_count: Option<u64>,
    /// Also generate master data (defaults to `false`).
    pub include_master_data: Option<bool>,
    /// Inject anomalies into the generated entries (defaults to `false`).
    pub inject_anomalies: Option<bool>,
}
415
/// Result of a synchronous bulk generation run.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Set to `true` whenever a response body is produced; failures are returned as errors.
    pub success: bool,
    /// Number of journal entries produced.
    pub entries_generated: u64,
    /// Wall-clock time spent building the orchestrator and generating, in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced.
    pub anomaly_count: u64,
}
423
424#[derive(Debug, Deserialize)]
425#[allow(dead_code)] pub struct StreamRequest {
427 pub events_per_second: Option<u32>,
428 pub max_events: Option<u64>,
429 pub inject_anomalies: Option<bool>,
430}
431
/// Generic acknowledgement for stream-control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the requested action was accepted.
    pub success: bool,
    /// Human-readable outcome or error explanation.
    pub message: String,
}
437
438async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
444 Json(HealthResponse {
445 healthy: true,
446 version: env!("CARGO_PKG_VERSION").to_string(),
447 uptime_seconds: state.server_state.uptime_seconds(),
448 })
449}
450
451async fn readiness_check(
454 State(state): State<AppState>,
455) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
456 let mut checks = Vec::new();
457 let mut any_fail = false;
458
459 let config = state.server_state.config.read().await;
461 let config_valid = !config.companies.is_empty();
462 checks.push(HealthCheck {
463 name: "config".to_string(),
464 status: if config_valid { "ok" } else { "fail" }.to_string(),
465 });
466 if !config_valid {
467 any_fail = true;
468 }
469 drop(config);
470
471 let resource_status = state.server_state.resource_status();
473 let memory_status = if resource_status.degradation_level == "Emergency" {
474 any_fail = true;
475 "fail"
476 } else if resource_status.degradation_level != "Normal" {
477 "degraded"
478 } else {
479 "ok"
480 };
481 checks.push(HealthCheck {
482 name: "memory".to_string(),
483 status: memory_status.to_string(),
484 });
485
486 let disk_ok = resource_status.disk_available_mb > 100;
488 checks.push(HealthCheck {
489 name: "disk".to_string(),
490 status: if disk_ok { "ok" } else { "fail" }.to_string(),
491 });
492 if !disk_ok {
493 any_fail = true;
494 }
495
496 let response = ReadinessResponse {
497 ready: !any_fail,
498 message: if any_fail {
499 "Service is not ready".to_string()
500 } else {
501 "Service is ready".to_string()
502 },
503 checks,
504 };
505
506 if any_fail {
507 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
508 } else {
509 Ok(Json(response))
510 }
511}
512
513async fn liveness_check() -> Json<LivenessResponse> {
516 Json(LivenessResponse {
517 alive: true,
518 timestamp: chrono::Utc::now().to_rfc3339(),
519 })
520}
521
522async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
525 use std::sync::atomic::Ordering;
526
527 let uptime = state.server_state.uptime_seconds();
528 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
529 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
530 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
531 let total_stream_events = state
532 .server_state
533 .total_stream_events
534 .load(Ordering::Relaxed);
535
536 let entries_per_second = if uptime > 0 {
537 total_entries as f64 / uptime as f64
538 } else {
539 0.0
540 };
541
542 let metrics = format!(
543 r#"# HELP synth_entries_generated_total Total number of journal entries generated
544# TYPE synth_entries_generated_total counter
545synth_entries_generated_total {}
546
547# HELP synth_anomalies_injected_total Total number of anomalies injected
548# TYPE synth_anomalies_injected_total counter
549synth_anomalies_injected_total {}
550
551# HELP synth_uptime_seconds Server uptime in seconds
552# TYPE synth_uptime_seconds gauge
553synth_uptime_seconds {}
554
555# HELP synth_entries_per_second Rate of entry generation
556# TYPE synth_entries_per_second gauge
557synth_entries_per_second {:.2}
558
559# HELP synth_active_streams Number of active streaming connections
560# TYPE synth_active_streams gauge
561synth_active_streams {}
562
563# HELP synth_stream_events_total Total events sent through streams
564# TYPE synth_stream_events_total counter
565synth_stream_events_total {}
566
567# HELP synth_info Server version information
568# TYPE synth_info gauge
569synth_info{{version="{}"}} 1
570"#,
571 total_entries,
572 total_anomalies,
573 uptime,
574 entries_per_second,
575 active_streams,
576 total_stream_events,
577 env!("CARGO_PKG_VERSION")
578 );
579
580 (
581 StatusCode::OK,
582 [(
583 header::CONTENT_TYPE,
584 "text/plain; version=0.0.4; charset=utf-8",
585 )],
586 metrics,
587 )
588}
589
590async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
592 let uptime = state.server_state.uptime_seconds();
593 let total_entries = state
594 .server_state
595 .total_entries
596 .load(std::sync::atomic::Ordering::Relaxed);
597
598 let entries_per_second = if uptime > 0 {
599 total_entries as f64 / uptime as f64
600 } else {
601 0.0
602 };
603
604 Json(MetricsResponse {
605 total_entries_generated: total_entries,
606 total_anomalies_injected: state
607 .server_state
608 .total_anomalies
609 .load(std::sync::atomic::Ordering::Relaxed),
610 uptime_seconds: uptime,
611 session_entries: total_entries,
612 session_entries_per_second: entries_per_second,
613 active_streams: state
614 .server_state
615 .active_streams
616 .load(std::sync::atomic::Ordering::Relaxed) as u32,
617 total_stream_events: state
618 .server_state
619 .total_stream_events
620 .load(std::sync::atomic::Ordering::Relaxed),
621 })
622}
623
624async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
626 let config = state.server_state.config.read().await;
627
628 Json(ConfigResponse {
629 success: true,
630 message: "Current configuration".to_string(),
631 config: Some(GenerationConfigDto {
632 industry: format!("{:?}", config.global.industry),
633 start_date: config.global.start_date.clone(),
634 period_months: config.global.period_months,
635 seed: config.global.seed,
636 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
637 companies: config
638 .companies
639 .iter()
640 .map(|c| CompanyConfigDto {
641 code: c.code.clone(),
642 name: c.name.clone(),
643 currency: c.currency.clone(),
644 country: c.country.clone(),
645 annual_transaction_volume: c.annual_transaction_volume.count(),
646 volume_weight: c.volume_weight as f32,
647 })
648 .collect(),
649 fraud_enabled: config.fraud.enabled,
650 fraud_rate: config.fraud.fraud_rate as f32,
651 }),
652 })
653}
654
/// `POST /api/config`: validates and applies a new generation configuration.
///
/// `industry` and `coa_complexity` are matched case-insensitively. Unknown values
/// are rejected with `400 Bad Request` before any shared state is modified. On
/// success the shared configuration is updated in place under its write lock.
///
/// NOTE(review): when `companies` is empty the existing companies are kept, but the
/// response still echoes the request (with an empty company list) as the applied
/// config. Confirm clients do not expect it to mirror the stored state.
/// NOTE(review): every submitted company gets `fiscal_year_variant: "K4"`, replacing
/// any variant the previous configuration had.
async fn set_config(
    State(state): State<AppState>,
    Json(new_config): Json<GenerationConfigDto>,
) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
    use datasynth_core::models::{CoAComplexity, IndustrySector};

    info!(
        "Configuration update requested: industry={}, period_months={}",
        new_config.industry, new_config.period_months
    );

    // Both snake_case ("financial_services") and concatenated ("financialservices")
    // spellings are accepted. The concatenated form is presumably what get_config's
    // Debug-formatted name lowercases to — confirm if the enum's Debug impl changes.
    let industry = match new_config.industry.to_lowercase().as_str() {
        "manufacturing" => IndustrySector::Manufacturing,
        "retail" => IndustrySector::Retail,
        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
        "healthcare" => IndustrySector::Healthcare,
        "technology" => IndustrySector::Technology,
        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
        "energy" => IndustrySector::Energy,
        "transportation" => IndustrySector::Transportation,
        "real_estate" | "realestate" => IndustrySector::RealEstate,
        "telecommunications" => IndustrySector::Telecommunications,
        _ => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ConfigResponse {
                    success: false,
                    message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
                    config: None,
                }),
            ));
        }
    };

    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
        "small" => CoAComplexity::Small,
        "medium" => CoAComplexity::Medium,
        "large" => CoAComplexity::Large,
        _ => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ConfigResponse {
                    success: false,
                    message: format!(
                        "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
                        new_config.coa_complexity
                    ),
                    config: None,
                }),
            ));
        }
    };

    // Converted before taking the write lock so the lock is held only for the
    // assignments below.
    let companies: Vec<CompanyConfig> = new_config
        .companies
        .iter()
        .map(|c| CompanyConfig {
            code: c.code.clone(),
            name: c.name.clone(),
            currency: c.currency.clone(),
            country: c.country.clone(),
            fiscal_year_variant: "K4".to_string(),
            annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
            volume_weight: c.volume_weight as f64,
        })
        .collect();

    let mut config = state.server_state.config.write().await;
    config.global.industry = industry;
    config.global.start_date = new_config.start_date.clone();
    config.global.period_months = new_config.period_months;
    config.global.seed = new_config.seed;
    config.chart_of_accounts.complexity = complexity;
    config.fraud.enabled = new_config.fraud_enabled;
    config.fraud.fraud_rate = new_config.fraud_rate as f64;

    // An empty list means "keep the current companies", not "remove all companies".
    if !companies.is_empty() {
        config.companies = companies;
    }

    info!("Configuration updated successfully");

    Ok(Json(ConfigResponse {
        success: true,
        message: "Configuration updated and applied".to_string(),
        config: Some(new_config),
    }))
}
750
751async fn bulk_generate(
753 State(state): State<AppState>,
754 Json(req): Json<BulkGenerateRequest>,
755) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
756 const MAX_ENTRY_COUNT: u64 = 1_000_000;
758 if let Some(count) = req.entry_count {
759 if count > MAX_ENTRY_COUNT {
760 return Err((
761 StatusCode::BAD_REQUEST,
762 format!(
763 "entry_count ({}) exceeds maximum allowed value ({})",
764 count, MAX_ENTRY_COUNT
765 ),
766 ));
767 }
768 }
769
770 let config = state.server_state.config.read().await.clone();
771 let start_time = std::time::Instant::now();
772
773 let phase_config = PhaseConfig {
774 generate_master_data: req.include_master_data.unwrap_or(false),
775 generate_document_flows: false,
776 generate_journal_entries: true,
777 inject_anomalies: req.inject_anomalies.unwrap_or(false),
778 show_progress: false,
779 ..Default::default()
780 };
781
782 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
783 (
784 StatusCode::INTERNAL_SERVER_ERROR,
785 format!("Failed to create orchestrator: {}", e),
786 )
787 })?;
788
789 let result = orchestrator.generate().map_err(|e| {
790 (
791 StatusCode::INTERNAL_SERVER_ERROR,
792 format!("Generation failed: {}", e),
793 )
794 })?;
795
796 let duration_ms = start_time.elapsed().as_millis() as u64;
797 let entries_count = result.journal_entries.len() as u64;
798 let anomaly_count = result.anomaly_labels.labels.len() as u64;
799
800 state
802 .server_state
803 .total_entries
804 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
805 state
806 .server_state
807 .total_anomalies
808 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
809
810 Ok(Json(BulkGenerateResponse {
811 success: true,
812 entries_generated: entries_count,
813 duration_ms,
814 anomaly_count,
815 }))
816}
817
818async fn start_stream(
820 State(state): State<AppState>,
821 Json(req): Json<StreamRequest>,
822) -> Json<StreamResponse> {
823 if let Some(eps) = req.events_per_second {
825 info!("Stream configured: events_per_second={}", eps);
826 state
827 .server_state
828 .stream_events_per_second
829 .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
830 }
831 if let Some(max) = req.max_events {
832 info!("Stream configured: max_events={}", max);
833 state
834 .server_state
835 .stream_max_events
836 .store(max, std::sync::atomic::Ordering::Relaxed);
837 }
838 if let Some(inject) = req.inject_anomalies {
839 info!("Stream configured: inject_anomalies={}", inject);
840 state
841 .server_state
842 .stream_inject_anomalies
843 .store(inject, std::sync::atomic::Ordering::Relaxed);
844 }
845
846 state
847 .server_state
848 .stream_stopped
849 .store(false, std::sync::atomic::Ordering::Relaxed);
850 state
851 .server_state
852 .stream_paused
853 .store(false, std::sync::atomic::Ordering::Relaxed);
854
855 Json(StreamResponse {
856 success: true,
857 message: "Stream started".to_string(),
858 })
859}
860
861async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
863 state
864 .server_state
865 .stream_stopped
866 .store(true, std::sync::atomic::Ordering::Relaxed);
867
868 Json(StreamResponse {
869 success: true,
870 message: "Stream stopped".to_string(),
871 })
872}
873
874async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
876 state
877 .server_state
878 .stream_paused
879 .store(true, std::sync::atomic::Ordering::Relaxed);
880
881 Json(StreamResponse {
882 success: true,
883 message: "Stream paused".to_string(),
884 })
885}
886
887async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
889 state
890 .server_state
891 .stream_paused
892 .store(false, std::sync::atomic::Ordering::Relaxed);
893
894 Json(StreamResponse {
895 success: true,
896 message: "Stream resumed".to_string(),
897 })
898}
899
900async fn trigger_pattern(
905 State(state): State<AppState>,
906 axum::extract::Path(pattern): axum::extract::Path<String>,
907) -> Json<StreamResponse> {
908 info!("Pattern trigger requested: {}", pattern);
909
910 let valid_patterns = [
912 "year_end_spike",
913 "period_end_spike",
914 "holiday_cluster",
915 "fraud_cluster",
916 "error_cluster",
917 "uniform",
918 ];
919
920 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
921
922 if !is_valid {
923 return Json(StreamResponse {
924 success: false,
925 message: format!(
926 "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
927 pattern, valid_patterns
928 ),
929 });
930 }
931
932 match state.server_state.triggered_pattern.try_write() {
934 Ok(mut triggered) => {
935 *triggered = Some(pattern.clone());
936 Json(StreamResponse {
937 success: true,
938 message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
939 })
940 }
941 Err(_) => Json(StreamResponse {
942 success: false,
943 message: "Failed to acquire lock for pattern trigger".to_string(),
944 }),
945 }
946}
947
948async fn websocket_metrics(
950 ws: WebSocketUpgrade,
951 State(state): State<AppState>,
952) -> impl IntoResponse {
953 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
954}
955
956async fn websocket_events(
958 ws: WebSocketUpgrade,
959 State(state): State<AppState>,
960) -> impl IntoResponse {
961 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
962}
963
964async fn submit_job(
970 State(state): State<AppState>,
971 Json(request): Json<JobRequest>,
972) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
973 let queue = state.job_queue.as_ref().ok_or_else(|| {
974 (
975 StatusCode::SERVICE_UNAVAILABLE,
976 Json(serde_json::json!({"error": "Job queue not enabled"})),
977 )
978 })?;
979
980 let job_id = queue.submit(request).await;
981 info!("Job submitted: {}", job_id);
982
983 Ok((
984 StatusCode::CREATED,
985 Json(serde_json::json!({
986 "id": job_id.to_string(),
987 "status": "queued"
988 })),
989 ))
990}
991
992async fn get_job(
994 State(state): State<AppState>,
995 axum::extract::Path(id): axum::extract::Path<String>,
996) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
997 let queue = state.job_queue.as_ref().ok_or_else(|| {
998 (
999 StatusCode::SERVICE_UNAVAILABLE,
1000 Json(serde_json::json!({"error": "Job queue not enabled"})),
1001 )
1002 })?;
1003
1004 match queue.get(&id).await {
1005 Some(entry) => Ok(Json(serde_json::json!({
1006 "id": entry.id,
1007 "status": format!("{:?}", entry.status).to_lowercase(),
1008 "submitted_at": entry.submitted_at.to_rfc3339(),
1009 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1010 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1011 "result": entry.result,
1012 }))),
1013 None => Err((
1014 StatusCode::NOT_FOUND,
1015 Json(serde_json::json!({"error": "Job not found"})),
1016 )),
1017 }
1018}
1019
1020async fn list_jobs(
1022 State(state): State<AppState>,
1023) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1024 let queue = state.job_queue.as_ref().ok_or_else(|| {
1025 (
1026 StatusCode::SERVICE_UNAVAILABLE,
1027 Json(serde_json::json!({"error": "Job queue not enabled"})),
1028 )
1029 })?;
1030
1031 let summaries: Vec<_> = queue
1032 .list()
1033 .await
1034 .into_iter()
1035 .map(|s| {
1036 serde_json::json!({
1037 "id": s.id,
1038 "status": format!("{:?}", s.status).to_lowercase(),
1039 "submitted_at": s.submitted_at.to_rfc3339(),
1040 })
1041 })
1042 .collect();
1043
1044 Ok(Json(serde_json::json!({ "jobs": summaries })))
1045}
1046
1047async fn cancel_job(
1049 State(state): State<AppState>,
1050 axum::extract::Path(id): axum::extract::Path<String>,
1051) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1052 let queue = state.job_queue.as_ref().ok_or_else(|| {
1053 (
1054 StatusCode::SERVICE_UNAVAILABLE,
1055 Json(serde_json::json!({"error": "Job queue not enabled"})),
1056 )
1057 })?;
1058
1059 if queue.cancel(&id).await {
1060 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1061 } else {
1062 Err((
1063 StatusCode::CONFLICT,
1064 Json(
1065 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1066 ),
1067 ))
1068 }
1069}
1070
1071async fn reload_config(
1077 State(state): State<AppState>,
1078) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1079 let source = state.server_state.config_source.read().await.clone();
1080 match crate::config_loader::load_config(&source).await {
1081 Ok(new_config) => {
1082 let mut config = state.server_state.config.write().await;
1083 *config = new_config;
1084 info!("Configuration reloaded via REST API from {:?}", source);
1085 Ok(Json(serde_json::json!({
1086 "success": true,
1087 "message": "Configuration reloaded"
1088 })))
1089 }
1090 Err(e) => {
1091 error!("Failed to reload configuration: {}", e);
1092 Err((
1093 StatusCode::INTERNAL_SERVER_ERROR,
1094 Json(serde_json::json!({
1095 "success": false,
1096 "message": format!("Failed to reload configuration: {}", e)
1097 })),
1098 ))
1099 }
1100 }
1101}
1102
1103#[cfg(test)]
1104#[allow(clippy::unwrap_used)]
1105mod tests {
1106 use super::*;
1107
1108 #[test]
1113 fn test_health_response_serialization() {
1114 let response = HealthResponse {
1115 healthy: true,
1116 version: "0.1.0".to_string(),
1117 uptime_seconds: 100,
1118 };
1119 let json = serde_json::to_string(&response).unwrap();
1120 assert!(json.contains("healthy"));
1121 assert!(json.contains("version"));
1122 assert!(json.contains("uptime_seconds"));
1123 }
1124
1125 #[test]
1126 fn test_health_response_deserialization() {
1127 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1128 let response: HealthResponse = serde_json::from_str(json).unwrap();
1129 assert!(response.healthy);
1130 assert_eq!(response.version, "0.1.0");
1131 assert_eq!(response.uptime_seconds, 100);
1132 }
1133
1134 #[test]
1135 fn test_metrics_response_serialization() {
1136 let response = MetricsResponse {
1137 total_entries_generated: 1000,
1138 total_anomalies_injected: 10,
1139 uptime_seconds: 60,
1140 session_entries: 1000,
1141 session_entries_per_second: 16.67,
1142 active_streams: 1,
1143 total_stream_events: 500,
1144 };
1145 let json = serde_json::to_string(&response).unwrap();
1146 assert!(json.contains("total_entries_generated"));
1147 assert!(json.contains("session_entries_per_second"));
1148 }
1149
1150 #[test]
1151 fn test_metrics_response_deserialization() {
1152 let json = r#"{
1153 "total_entries_generated": 5000,
1154 "total_anomalies_injected": 50,
1155 "uptime_seconds": 300,
1156 "session_entries": 5000,
1157 "session_entries_per_second": 16.67,
1158 "active_streams": 2,
1159 "total_stream_events": 10000
1160 }"#;
1161 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1162 assert_eq!(response.total_entries_generated, 5000);
1163 assert_eq!(response.active_streams, 2);
1164 }
1165
1166 #[test]
1167 fn test_config_response_serialization() {
1168 let response = ConfigResponse {
1169 success: true,
1170 message: "Configuration loaded".to_string(),
1171 config: Some(GenerationConfigDto {
1172 industry: "manufacturing".to_string(),
1173 start_date: "2024-01-01".to_string(),
1174 period_months: 12,
1175 seed: Some(42),
1176 coa_complexity: "medium".to_string(),
1177 companies: vec![],
1178 fraud_enabled: false,
1179 fraud_rate: 0.0,
1180 }),
1181 };
1182 let json = serde_json::to_string(&response).unwrap();
1183 assert!(json.contains("success"));
1184 assert!(json.contains("config"));
1185 }
1186
#[test]
fn test_config_response_without_config() {
    // A response without a config payload must still emit the `config` key,
    // serialized as an explicit JSON `null`.
    let response = ConfigResponse {
        success: false,
        message: "No configuration available".to_string(),
        config: None,
    };
    let json = serde_json::to_string(&response).unwrap();
    // Inspect the parsed document rather than substring-matching: a bare
    // `contains("null")` would pass if *any* field serialized as null, and the
    // old `|| contains("config\":null")` disjunct was implied by the first one.
    let value: serde_json::Value = serde_json::from_str(&json).unwrap();
    assert_eq!(value.get("config"), Some(&serde_json::Value::Null));
}
1197
#[test]
fn test_generation_config_dto_roundtrip() {
    // Serializing and then deserializing must preserve every field, including the
    // nested company entries. Checking only a subset would let a renamed or
    // skipped field slip through unnoticed.
    let original = GenerationConfigDto {
        industry: "retail".to_string(),
        start_date: "2024-06-01".to_string(),
        period_months: 6,
        seed: Some(12345),
        coa_complexity: "large".to_string(),
        companies: vec![CompanyConfigDto {
            code: "1000".to_string(),
            name: "Test Corp".to_string(),
            currency: "USD".to_string(),
            country: "US".to_string(),
            annual_transaction_volume: 100000,
            volume_weight: 1.0,
        }],
        fraud_enabled: true,
        fraud_rate: 0.05,
    };

    let json = serde_json::to_string(&original).unwrap();
    let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

    assert_eq!(original.industry, deserialized.industry);
    assert_eq!(original.start_date, deserialized.start_date);
    assert_eq!(original.period_months, deserialized.period_months);
    assert_eq!(original.seed, deserialized.seed);
    assert_eq!(original.coa_complexity, deserialized.coa_complexity);
    assert_eq!(original.fraud_enabled, deserialized.fraud_enabled);
    // serde_json's default float parsing is best-effort (exact round-trips need the
    // `float_roundtrip` feature), so compare floats with a tolerance.
    assert!((original.fraud_rate - deserialized.fraud_rate).abs() < 1e-9);

    assert_eq!(original.companies.len(), deserialized.companies.len());
    for (expected, actual) in original.companies.iter().zip(&deserialized.companies) {
        assert_eq!(expected.code, actual.code);
        assert_eq!(expected.name, actual.name);
        assert_eq!(expected.currency, actual.currency);
        assert_eq!(expected.country, actual.country);
        assert_eq!(
            expected.annual_transaction_volume,
            actual.annual_transaction_volume
        );
        assert!((expected.volume_weight - actual.volume_weight).abs() < 1e-9);
    }
}
1225
#[test]
fn test_company_config_dto_serialization() {
    // The company code, currency and country values should all appear in the JSON output.
    let company = CompanyConfigDto {
        code: String::from("2000"),
        name: String::from("European Subsidiary"),
        currency: String::from("EUR"),
        country: String::from("DE"),
        annual_transaction_volume: 50000,
        volume_weight: 0.5,
    };
    let json = serde_json::to_string(&company).unwrap();
    for needle in ["2000", "EUR", "DE"] {
        assert!(json.contains(needle));
    }
}
1241
#[test]
fn test_bulk_generate_request_deserialization() {
    // Every optional field present in the payload should decode as `Some`.
    let body = r#"{
        "entry_count": 5000,
        "include_master_data": true,
        "inject_anomalies": true
    }"#;
    let parsed = serde_json::from_str::<BulkGenerateRequest>(body).unwrap();
    assert_eq!(parsed.entry_count, Some(5000));
    assert_eq!(parsed.include_master_data, Some(true));
    assert_eq!(parsed.inject_anomalies, Some(true));
}
1254
#[test]
fn test_bulk_generate_request_with_defaults() {
    // An empty object is a valid request, and every optional field stays unset.
    let request = serde_json::from_str::<BulkGenerateRequest>(r#"{}"#).unwrap();
    assert!(request.entry_count.is_none());
    assert!(request.include_master_data.is_none());
    assert!(request.inject_anomalies.is_none());
}
1263
#[test]
fn test_bulk_generate_response_serialization() {
    // The entry-count and timing keys should be present, along with the count value.
    let json = serde_json::to_string(&BulkGenerateResponse {
        success: true,
        entries_generated: 1000,
        duration_ms: 250,
        anomaly_count: 20,
    })
    .unwrap();
    for fragment in ["entries_generated", "1000", "duration_ms"] {
        assert!(json.contains(fragment));
    }
}
1277
#[test]
fn test_stream_response_serialization() {
    // A successful stream start reports `success` and includes its message.
    let started = StreamResponse {
        success: true,
        message: String::from("Stream started successfully"),
    };
    let json = serde_json::to_string(&started).unwrap();
    assert!(json.contains("success") && json.contains("Stream started"));
}
1288
#[test]
fn test_stream_response_failure() {
    // A failed start serializes a `false` flag and includes the failure message.
    let failed = StreamResponse {
        success: false,
        message: String::from("Stream failed to start"),
    };
    let json = serde_json::to_string(&failed).unwrap();
    assert!(["false", "failed"].iter().all(|s| json.contains(*s)));
}
1299
#[test]
fn test_cors_config_default() {
    // The default config must not allow any origin, yet it must still list the
    // local dev-server and Tauri origins.
    let cfg = CorsConfig::default();
    assert!(!cfg.allow_any_origin);
    assert!(!cfg.allowed_origins.is_empty());
    let allows = |origin: &str| cfg.allowed_origins.iter().any(|o| o == origin);
    assert!(allows("http://localhost:5173"));
    assert!(allows("tauri://localhost"));
}
1316
#[test]
fn test_cors_config_custom_origins() {
    // An explicit allow-list is stored exactly as given.
    let origins: Vec<String> = ["https://example.com", "https://app.example.com"]
        .iter()
        .map(|o| (*o).to_owned())
        .collect();
    let config = CorsConfig {
        allowed_origins: origins,
        allow_any_origin: false,
    };
    assert_eq!(config.allowed_origins.len(), 2);
    assert!(config
        .allowed_origins
        .iter()
        .any(|o| o == "https://example.com"));
}
1331
#[test]
fn test_cors_config_permissive() {
    // A permissive config sets `allow_any_origin` and leaves the explicit allow-list empty.
    let permissive = CorsConfig {
        allow_any_origin: true,
        allowed_origins: Vec::new(),
    };
    assert!(permissive.allow_any_origin);
}
1340
#[test]
fn test_bulk_generate_request_partial() {
    // Supplying only `entry_count` must leave every other optional field unset,
    // matching what the all-defaults test expects.
    let json = r#"{"entry_count": 100}"#;
    let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
    assert_eq!(request.entry_count, Some(100));
    assert!(request.include_master_data.is_none());
    assert!(request.inject_anomalies.is_none());
}
1352
#[test]
fn test_generation_config_no_seed() {
    // With no seed, the `seed` key must still be emitted as an explicit JSON
    // `null`, and it must read back as `None`. The previous substring check for
    // "seed" only proved that the key name appeared.
    let config = GenerationConfigDto {
        industry: "technology".to_string(),
        start_date: "2024-01-01".to_string(),
        period_months: 3,
        seed: None,
        coa_complexity: "small".to_string(),
        companies: vec![],
        fraud_enabled: false,
        fraud_rate: 0.0,
    };
    let json = serde_json::to_string(&config).unwrap();

    let value: serde_json::Value = serde_json::from_str(&json).unwrap();
    assert_eq!(value.get("seed"), Some(&serde_json::Value::Null));

    let roundtrip: GenerationConfigDto = serde_json::from_str(&json).unwrap();
    assert!(roundtrip.seed.is_none());
}
1368
#[test]
fn test_generation_config_multiple_companies() {
    // A config with several companies must survive a JSON round-trip with every
    // company intact and in order. The old assertion only checked the length of
    // the vec that had just been built, so it could never fail.
    let config = GenerationConfigDto {
        industry: "manufacturing".to_string(),
        start_date: "2024-01-01".to_string(),
        period_months: 12,
        seed: Some(42),
        coa_complexity: "large".to_string(),
        companies: vec![
            CompanyConfigDto {
                code: "1000".to_string(),
                name: "Headquarters".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            },
            CompanyConfigDto {
                code: "2000".to_string(),
                name: "European Sub".to_string(),
                currency: "EUR".to_string(),
                country: "DE".to_string(),
                annual_transaction_volume: 50000,
                volume_weight: 0.5,
            },
            CompanyConfigDto {
                code: "3000".to_string(),
                name: "APAC Sub".to_string(),
                currency: "JPY".to_string(),
                country: "JP".to_string(),
                annual_transaction_volume: 30000,
                volume_weight: 0.3,
            },
        ],
        fraud_enabled: true,
        fraud_rate: 0.02,
    };

    let json = serde_json::to_string(&config).unwrap();
    let decoded: GenerationConfigDto = serde_json::from_str(&json).unwrap();

    let codes: Vec<&str> = decoded.companies.iter().map(|c| c.code.as_str()).collect();
    assert_eq!(codes, ["1000", "2000", "3000"]);
    let currencies: Vec<&str> = decoded
        .companies
        .iter()
        .map(|c| c.currency.as_str())
        .collect();
    assert_eq!(currencies, ["USD", "EUR", "JPY"]);
}
1408
#[test]
fn test_metrics_entries_per_second_calculation() {
    // Recomputes entries-per-second locally, guarding against zero uptime.
    // NOTE(review): this does not call production code, so it only pins the
    // intended formula. 1000 entries over 60 s is about 16.67 entries/s.
    let (entries, secs): (u64, u64) = (1000, 60);
    let rate = match secs {
        0 => 0.0,
        s => entries as f64 / s as f64,
    };
    assert!((rate - 16.67).abs() < 0.1);
}
1425
#[test]
fn test_metrics_entries_per_second_zero_uptime() {
    // Zero uptime must short-circuit to 0.0 instead of dividing. For f64,
    // 1000.0 / 0.0 would produce infinity rather than panic.
    let (entries, secs): (u64, u64) = (1000, 0);
    let rate = match secs {
        0 => 0.0,
        s => entries as f64 / s as f64,
    };
    assert_eq!(rate, 0.0);
}
1437}