1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Shared state handed to every REST handler through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Server-wide state (configuration, counters, stream controls) taken from the
    /// `SynthService` the router was built from.
    pub server_state: Arc<ServerState>,
    /// Optional job queue; the `/api/jobs` handlers respond `503 Service Unavailable`
    /// when this is `None`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Request-timeout settings used by the full router builders.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum time, in seconds, a request may take before the server answers
    /// `408 Request Timeout`.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Timeout used by `TimeoutConfig::default()` (five minutes).
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Creates a configuration with the given request timeout, in seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
/// Cross-origin resource sharing (CORS) settings used when building the router.
///
/// Derives `Debug` (like [`TimeoutConfig`]) so the effective CORS policy can be logged.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests when `allow_any_origin` is `false`.
    /// The router builders skip entries that do not parse as a header value.
    pub allowed_origins: Vec<String>,
    /// When `true`, the router uses `CorsLayer::permissive()` and `allowed_origins` is ignored.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Allows `http://localhost` and `http://127.0.0.1` on ports 5173 and 3000, plus
    /// `tauri://localhost`. Any-origin mode is off.
    fn default() -> Self {
        Self {
            allowed_origins: vec![
                "http://localhost:5173".to_string(),
                "http://localhost:3000".to_string(),
                "http://127.0.0.1:5173".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "tauri://localhost".to_string(),
            ],
            allow_any_origin: false,
        }
    }
}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/jobs/submit", post(submit_job))
186 .route("/api/jobs", get(list_jobs))
187 .route("/api/jobs/{id}", get(get_job))
188 .route("/api/jobs/{id}/cancel", post(cancel_job))
189 .route("/ws/metrics", get(websocket_metrics))
191 .route("/ws/events", get(websocket_events))
192 .layer(axum::middleware::from_fn(security_headers_middleware))
195 .layer(axum::middleware::map_response(api_version_header))
196 .layer(cors)
197 .layer(axum::middleware::from_fn(request_id_middleware))
198 .layer(axum::middleware::from_fn(auth_middleware))
199 .layer(axum::Extension(auth_config))
200 .layer(axum::middleware::from_fn(request_validation_middleware))
201 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202 .layer(axum::Extension(rate_limit_backend))
203 .layer(TimeoutLayer::with_status_code(
204 StatusCode::REQUEST_TIMEOUT,
205 Duration::from_secs(timeout_config.request_timeout_secs),
206 ))
207 .with_state(state)
208}
209
210pub fn create_router_with_auth(
212 service: SynthService,
213 cors_config: CorsConfig,
214 auth_config: AuthConfig,
215) -> Router {
216 let server_state = service.state.clone();
217 let state = AppState {
218 server_state,
219 job_queue: None,
220 };
221
222 let cors = if cors_config.allow_any_origin {
223 CorsLayer::permissive()
224 } else {
225 let origins: Vec<_> = cors_config
226 .allowed_origins
227 .iter()
228 .filter_map(|o| o.parse().ok())
229 .collect();
230
231 CorsLayer::new()
232 .allow_origin(AllowOrigin::list(origins))
233 .allow_methods([
234 Method::GET,
235 Method::POST,
236 Method::PUT,
237 Method::DELETE,
238 Method::OPTIONS,
239 ])
240 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
241 };
242
243 Router::new()
244 .route("/health", get(health_check))
246 .route("/ready", get(readiness_check))
247 .route("/live", get(liveness_check))
248 .route("/api/metrics", get(get_metrics))
249 .route("/metrics", get(prometheus_metrics))
250 .route("/api/config", get(get_config))
252 .route("/api/config", post(set_config))
253 .route("/api/config/reload", post(reload_config))
254 .route("/api/generate/bulk", post(bulk_generate))
256 .route("/api/stream/start", post(start_stream))
257 .route("/api/stream/stop", post(stop_stream))
258 .route("/api/stream/pause", post(pause_stream))
259 .route("/api/stream/resume", post(resume_stream))
260 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
261 .route("/api/jobs/submit", post(submit_job))
263 .route("/api/jobs", get(list_jobs))
264 .route("/api/jobs/{id}", get(get_job))
265 .route("/api/jobs/{id}/cancel", post(cancel_job))
266 .route("/ws/metrics", get(websocket_metrics))
268 .route("/ws/events", get(websocket_events))
269 .layer(axum::middleware::from_fn(auth_middleware))
270 .layer(axum::Extension(auth_config))
271 .layer(cors)
272 .with_state(state)
273}
274
275pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
277 let server_state = service.state.clone();
278 let state = AppState {
279 server_state,
280 job_queue: None,
281 };
282
283 let cors = if cors_config.allow_any_origin {
284 CorsLayer::permissive()
286 } else {
287 let origins: Vec<_> = cors_config
289 .allowed_origins
290 .iter()
291 .filter_map(|o| o.parse().ok())
292 .collect();
293
294 CorsLayer::new()
295 .allow_origin(AllowOrigin::list(origins))
296 .allow_methods([
297 Method::GET,
298 Method::POST,
299 Method::PUT,
300 Method::DELETE,
301 Method::OPTIONS,
302 ])
303 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
304 };
305
306 Router::new()
307 .route("/health", get(health_check))
309 .route("/ready", get(readiness_check))
310 .route("/live", get(liveness_check))
311 .route("/api/metrics", get(get_metrics))
312 .route("/metrics", get(prometheus_metrics))
313 .route("/api/config", get(get_config))
315 .route("/api/config", post(set_config))
316 .route("/api/config/reload", post(reload_config))
317 .route("/api/generate/bulk", post(bulk_generate))
319 .route("/api/stream/start", post(start_stream))
320 .route("/api/stream/stop", post(stop_stream))
321 .route("/api/stream/pause", post(pause_stream))
322 .route("/api/stream/resume", post(resume_stream))
323 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
324 .route("/api/jobs/submit", post(submit_job))
326 .route("/api/jobs", get(list_jobs))
327 .route("/api/jobs/{id}", get(get_job))
328 .route("/api/jobs/{id}/cancel", post(cancel_job))
329 .route("/ws/metrics", get(websocket_metrics))
331 .route("/ws/events", get(websocket_events))
332 .layer(cors)
333 .with_state(state)
334}
335
/// Body returned by `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always `true` when produced by the `/health` handler.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION`) of the running server.
    pub version: String,
    /// Server uptime in seconds, as reported by `ServerState::uptime_seconds`.
    pub uptime_seconds: u64,
}
346
/// Body returned by `GET /ready`, with both the `200` and the `503` status.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `true` when no individual check reported `"fail"`.
    pub ready: bool,
    /// Human-readable summary ("Service is ready" / "Service is not ready").
    pub message: String,
    /// Per-subsystem results (`config`, `memory`, `disk`).
    pub checks: Vec<HealthCheck>,
}
354
/// One named readiness check inside a [`ReadinessResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check identifier, e.g. `"config"`, `"memory"` or `"disk"`.
    pub name: String,
    /// `"ok"`, `"degraded"` (memory only) or `"fail"`.
    pub status: String,
}
361
/// Body returned by `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when produced by the `/live` handler.
    pub alive: bool,
    /// Current UTC time as an RFC 3339 string.
    pub timestamp: String,
}
368
/// JSON counters returned by `GET /api/metrics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Journal entries generated since the server started.
    pub total_entries_generated: u64,
    /// Anomalies injected since the server started.
    pub total_anomalies_injected: u64,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
    /// Currently filled with the same value as `total_entries_generated`.
    pub session_entries: u64,
    /// Average generation rate: total entries divided by uptime (0 when uptime is 0).
    pub session_entries_per_second: f64,
    /// Number of active streams, read from the server's `active_streams` counter.
    pub active_streams: u32,
    /// Events sent through streams since the server started.
    pub total_stream_events: u64,
}
379
/// Envelope returned by `GET /api/config` and `POST /api/config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// `false` when the request was rejected.
    pub success: bool,
    /// Human-readable outcome or error description.
    pub message: String,
    /// The configuration, or `None` on error.
    pub config: Option<GenerationConfigDto>,
}
386
/// Wire representation of the user-editable subset of the generator configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector name. `POST /api/config` matches it case-insensitively against
    /// manufacturing, retail, financial_services, healthcare, technology,
    /// professional_services, energy, transportation, real_estate and telecommunications (the
    /// underscore-free spellings are also accepted). `GET /api/config` reports the sector's
    /// `Debug` name.
    pub industry: String,
    /// Start date of the generated period, copied verbatim into the config. Its format is
    /// not checked by the REST handler.
    pub start_date: String,
    /// Length of the generated period in months.
    pub period_months: u32,
    /// Optional RNG seed. NOTE(review): `None` presumably means non-deterministic; confirm.
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity: `small`, `medium` or `large` (case-insensitive on input).
    pub coa_complexity: String,
    /// Companies to generate for. An empty list on `POST` keeps the existing companies.
    pub companies: Vec<CompanyConfigDto>,
    /// Whether fraud injection is enabled.
    pub fraud_enabled: bool,
    /// Fraud injection rate. Stored as `f64` in the server configuration.
    pub fraud_rate: f32,
}
398
/// Wire representation of one company in the generator configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code.
    pub code: String,
    /// Display name.
    pub name: String,
    /// Currency code, copied verbatim into the configuration.
    pub currency: String,
    /// Country code, copied verbatim into the configuration.
    pub country: String,
    /// Yearly transaction count. Applied as `TransactionVolume::Custom` and reported back via
    /// `TransactionVolume::count()`.
    pub annual_transaction_volume: u64,
    /// Relative volume weight. `f32` on the wire, `f64` in the configuration.
    pub volume_weight: f32,
}
408
/// Body of `POST /api/generate/bulk`.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Requested number of entries. Values above 1,000,000 are rejected with `400`.
    /// NOTE(review): the handler validates this but does not pass it to the orchestrator;
    /// confirm whether it is meant to limit output.
    pub entry_count: Option<u64>,
    /// Whether to run the master-data phase. Defaults to `false`.
    pub include_master_data: Option<bool>,
    /// Whether to inject anomalies. Defaults to `false`.
    pub inject_anomalies: Option<bool>,
}
415
/// Result of a successful `POST /api/generate/bulk`.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Always `true` (failures are returned as an error status instead).
    pub success: bool,
    /// Number of journal entries produced.
    pub entries_generated: u64,
    /// Wall-clock time spent on generation, in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced.
    pub anomaly_count: u64,
}
423
424#[derive(Debug, Deserialize)]
425#[allow(dead_code)] pub struct StreamRequest {
427 pub events_per_second: Option<u32>,
428 pub max_events: Option<u64>,
429 pub inject_anomalies: Option<bool>,
430}
431
/// Generic outcome body for the stream-control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the requested action was applied.
    pub success: bool,
    /// Human-readable outcome or error description.
    pub message: String,
}
437
438async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
444 Json(HealthResponse {
445 healthy: true,
446 version: env!("CARGO_PKG_VERSION").to_string(),
447 uptime_seconds: state.server_state.uptime_seconds(),
448 })
449}
450
451async fn readiness_check(
454 State(state): State<AppState>,
455) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
456 let mut checks = Vec::new();
457 let mut any_fail = false;
458
459 let config = state.server_state.config.read().await;
461 let config_valid = !config.companies.is_empty();
462 checks.push(HealthCheck {
463 name: "config".to_string(),
464 status: if config_valid { "ok" } else { "fail" }.to_string(),
465 });
466 if !config_valid {
467 any_fail = true;
468 }
469 drop(config);
470
471 let resource_status = state.server_state.resource_status();
473 let memory_status = if resource_status.degradation_level == "Emergency" {
474 any_fail = true;
475 "fail"
476 } else if resource_status.degradation_level != "Normal" {
477 "degraded"
478 } else {
479 "ok"
480 };
481 checks.push(HealthCheck {
482 name: "memory".to_string(),
483 status: memory_status.to_string(),
484 });
485
486 let disk_ok = resource_status.disk_available_mb > 100;
488 checks.push(HealthCheck {
489 name: "disk".to_string(),
490 status: if disk_ok { "ok" } else { "fail" }.to_string(),
491 });
492 if !disk_ok {
493 any_fail = true;
494 }
495
496 let response = ReadinessResponse {
497 ready: !any_fail,
498 message: if any_fail {
499 "Service is not ready".to_string()
500 } else {
501 "Service is ready".to_string()
502 },
503 checks,
504 };
505
506 if any_fail {
507 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
508 } else {
509 Ok(Json(response))
510 }
511}
512
513async fn liveness_check() -> Json<LivenessResponse> {
516 Json(LivenessResponse {
517 alive: true,
518 timestamp: chrono::Utc::now().to_rfc3339(),
519 })
520}
521
522async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
525 use std::sync::atomic::Ordering;
526
527 let uptime = state.server_state.uptime_seconds();
528 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
529 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
530 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
531 let total_stream_events = state
532 .server_state
533 .total_stream_events
534 .load(Ordering::Relaxed);
535
536 let entries_per_second = if uptime > 0 {
537 total_entries as f64 / uptime as f64
538 } else {
539 0.0
540 };
541
542 let metrics = format!(
543 r#"# HELP synth_entries_generated_total Total number of journal entries generated
544# TYPE synth_entries_generated_total counter
545synth_entries_generated_total {}
546
547# HELP synth_anomalies_injected_total Total number of anomalies injected
548# TYPE synth_anomalies_injected_total counter
549synth_anomalies_injected_total {}
550
551# HELP synth_uptime_seconds Server uptime in seconds
552# TYPE synth_uptime_seconds gauge
553synth_uptime_seconds {}
554
555# HELP synth_entries_per_second Rate of entry generation
556# TYPE synth_entries_per_second gauge
557synth_entries_per_second {:.2}
558
559# HELP synth_active_streams Number of active streaming connections
560# TYPE synth_active_streams gauge
561synth_active_streams {}
562
563# HELP synth_stream_events_total Total events sent through streams
564# TYPE synth_stream_events_total counter
565synth_stream_events_total {}
566
567# HELP synth_info Server version information
568# TYPE synth_info gauge
569synth_info{{version="{}"}} 1
570"#,
571 total_entries,
572 total_anomalies,
573 uptime,
574 entries_per_second,
575 active_streams,
576 total_stream_events,
577 env!("CARGO_PKG_VERSION")
578 );
579
580 (
581 StatusCode::OK,
582 [(
583 header::CONTENT_TYPE,
584 "text/plain; version=0.0.4; charset=utf-8",
585 )],
586 metrics,
587 )
588}
589
590async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
592 let uptime = state.server_state.uptime_seconds();
593 let total_entries = state
594 .server_state
595 .total_entries
596 .load(std::sync::atomic::Ordering::Relaxed);
597
598 let entries_per_second = if uptime > 0 {
599 total_entries as f64 / uptime as f64
600 } else {
601 0.0
602 };
603
604 Json(MetricsResponse {
605 total_entries_generated: total_entries,
606 total_anomalies_injected: state
607 .server_state
608 .total_anomalies
609 .load(std::sync::atomic::Ordering::Relaxed),
610 uptime_seconds: uptime,
611 session_entries: total_entries,
612 session_entries_per_second: entries_per_second,
613 active_streams: state
614 .server_state
615 .active_streams
616 .load(std::sync::atomic::Ordering::Relaxed) as u32,
617 total_stream_events: state
618 .server_state
619 .total_stream_events
620 .load(std::sync::atomic::Ordering::Relaxed),
621 })
622}
623
624async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
626 let config = state.server_state.config.read().await;
627
628 Json(ConfigResponse {
629 success: true,
630 message: "Current configuration".to_string(),
631 config: Some(GenerationConfigDto {
632 industry: format!("{:?}", config.global.industry),
633 start_date: config.global.start_date.clone(),
634 period_months: config.global.period_months,
635 seed: config.global.seed,
636 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
637 companies: config
638 .companies
639 .iter()
640 .map(|c| CompanyConfigDto {
641 code: c.code.clone(),
642 name: c.name.clone(),
643 currency: c.currency.clone(),
644 country: c.country.clone(),
645 annual_transaction_volume: c.annual_transaction_volume.count(),
646 volume_weight: c.volume_weight as f32,
647 })
648 .collect(),
649 fraud_enabled: config.fraud.enabled,
650 fraud_rate: config.fraud.fraud_rate as f32,
651 }),
652 })
653}
654
655async fn set_config(
657 State(state): State<AppState>,
658 Json(new_config): Json<GenerationConfigDto>,
659) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
660 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
661 use datasynth_core::models::{CoAComplexity, IndustrySector};
662
663 info!(
664 "Configuration update requested: industry={}, period_months={}",
665 new_config.industry, new_config.period_months
666 );
667
668 let industry = match new_config.industry.to_lowercase().as_str() {
670 "manufacturing" => IndustrySector::Manufacturing,
671 "retail" => IndustrySector::Retail,
672 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
673 "healthcare" => IndustrySector::Healthcare,
674 "technology" => IndustrySector::Technology,
675 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
676 "energy" => IndustrySector::Energy,
677 "transportation" => IndustrySector::Transportation,
678 "real_estate" | "realestate" => IndustrySector::RealEstate,
679 "telecommunications" => IndustrySector::Telecommunications,
680 _ => {
681 return Err((
682 StatusCode::BAD_REQUEST,
683 Json(ConfigResponse {
684 success: false,
685 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
686 config: None,
687 }),
688 ));
689 }
690 };
691
692 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
694 "small" => CoAComplexity::Small,
695 "medium" => CoAComplexity::Medium,
696 "large" => CoAComplexity::Large,
697 _ => {
698 return Err((
699 StatusCode::BAD_REQUEST,
700 Json(ConfigResponse {
701 success: false,
702 message: format!(
703 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
704 new_config.coa_complexity
705 ),
706 config: None,
707 }),
708 ));
709 }
710 };
711
712 let companies: Vec<CompanyConfig> = new_config
714 .companies
715 .iter()
716 .map(|c| CompanyConfig {
717 code: c.code.clone(),
718 name: c.name.clone(),
719 currency: c.currency.clone(),
720 country: c.country.clone(),
721 fiscal_year_variant: "K4".to_string(),
722 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
723 volume_weight: c.volume_weight as f64,
724 })
725 .collect();
726
727 let mut config = state.server_state.config.write().await;
729 config.global.industry = industry;
730 config.global.start_date = new_config.start_date.clone();
731 config.global.period_months = new_config.period_months;
732 config.global.seed = new_config.seed;
733 config.chart_of_accounts.complexity = complexity;
734 config.fraud.enabled = new_config.fraud_enabled;
735 config.fraud.fraud_rate = new_config.fraud_rate as f64;
736
737 if !companies.is_empty() {
739 config.companies = companies;
740 }
741
742 info!("Configuration updated successfully");
743
744 Ok(Json(ConfigResponse {
745 success: true,
746 message: "Configuration updated and applied".to_string(),
747 config: Some(new_config),
748 }))
749}
750
751async fn bulk_generate(
753 State(state): State<AppState>,
754 Json(req): Json<BulkGenerateRequest>,
755) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
756 const MAX_ENTRY_COUNT: u64 = 1_000_000;
758 if let Some(count) = req.entry_count {
759 if count > MAX_ENTRY_COUNT {
760 return Err((
761 StatusCode::BAD_REQUEST,
762 format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
763 ));
764 }
765 }
766
767 let config = state.server_state.config.read().await.clone();
768 let start_time = std::time::Instant::now();
769
770 let phase_config = PhaseConfig {
771 generate_master_data: req.include_master_data.unwrap_or(false),
772 generate_document_flows: false,
773 generate_journal_entries: true,
774 inject_anomalies: req.inject_anomalies.unwrap_or(false),
775 show_progress: false,
776 ..Default::default()
777 };
778
779 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
780 (
781 StatusCode::INTERNAL_SERVER_ERROR,
782 format!("Failed to create orchestrator: {e}"),
783 )
784 })?;
785
786 let result = orchestrator.generate().map_err(|e| {
787 (
788 StatusCode::INTERNAL_SERVER_ERROR,
789 format!("Generation failed: {e}"),
790 )
791 })?;
792
793 let duration_ms = start_time.elapsed().as_millis() as u64;
794 let entries_count = result.journal_entries.len() as u64;
795 let anomaly_count = result.anomaly_labels.labels.len() as u64;
796
797 state
799 .server_state
800 .total_entries
801 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
802 state
803 .server_state
804 .total_anomalies
805 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
806
807 Ok(Json(BulkGenerateResponse {
808 success: true,
809 entries_generated: entries_count,
810 duration_ms,
811 anomaly_count,
812 }))
813}
814
815async fn start_stream(
817 State(state): State<AppState>,
818 Json(req): Json<StreamRequest>,
819) -> Json<StreamResponse> {
820 if let Some(eps) = req.events_per_second {
822 info!("Stream configured: events_per_second={}", eps);
823 state
824 .server_state
825 .stream_events_per_second
826 .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
827 }
828 if let Some(max) = req.max_events {
829 info!("Stream configured: max_events={}", max);
830 state
831 .server_state
832 .stream_max_events
833 .store(max, std::sync::atomic::Ordering::Relaxed);
834 }
835 if let Some(inject) = req.inject_anomalies {
836 info!("Stream configured: inject_anomalies={}", inject);
837 state
838 .server_state
839 .stream_inject_anomalies
840 .store(inject, std::sync::atomic::Ordering::Relaxed);
841 }
842
843 state
844 .server_state
845 .stream_stopped
846 .store(false, std::sync::atomic::Ordering::Relaxed);
847 state
848 .server_state
849 .stream_paused
850 .store(false, std::sync::atomic::Ordering::Relaxed);
851
852 Json(StreamResponse {
853 success: true,
854 message: "Stream started".to_string(),
855 })
856}
857
858async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
860 state
861 .server_state
862 .stream_stopped
863 .store(true, std::sync::atomic::Ordering::Relaxed);
864
865 Json(StreamResponse {
866 success: true,
867 message: "Stream stopped".to_string(),
868 })
869}
870
871async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
873 state
874 .server_state
875 .stream_paused
876 .store(true, std::sync::atomic::Ordering::Relaxed);
877
878 Json(StreamResponse {
879 success: true,
880 message: "Stream paused".to_string(),
881 })
882}
883
884async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
886 state
887 .server_state
888 .stream_paused
889 .store(false, std::sync::atomic::Ordering::Relaxed);
890
891 Json(StreamResponse {
892 success: true,
893 message: "Stream resumed".to_string(),
894 })
895}
896
897async fn trigger_pattern(
902 State(state): State<AppState>,
903 axum::extract::Path(pattern): axum::extract::Path<String>,
904) -> Json<StreamResponse> {
905 info!("Pattern trigger requested: {}", pattern);
906
907 let valid_patterns = [
909 "year_end_spike",
910 "period_end_spike",
911 "holiday_cluster",
912 "fraud_cluster",
913 "error_cluster",
914 "uniform",
915 ];
916
917 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
918
919 if !is_valid {
920 return Json(StreamResponse {
921 success: false,
922 message: format!(
923 "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
924 ),
925 });
926 }
927
928 match state.server_state.triggered_pattern.try_write() {
930 Ok(mut triggered) => {
931 *triggered = Some(pattern.clone());
932 Json(StreamResponse {
933 success: true,
934 message: format!("Pattern '{pattern}' will be applied to upcoming entries"),
935 })
936 }
937 Err(_) => Json(StreamResponse {
938 success: false,
939 message: "Failed to acquire lock for pattern trigger".to_string(),
940 }),
941 }
942}
943
944async fn websocket_metrics(
946 ws: WebSocketUpgrade,
947 State(state): State<AppState>,
948) -> impl IntoResponse {
949 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
950}
951
952async fn websocket_events(
954 ws: WebSocketUpgrade,
955 State(state): State<AppState>,
956) -> impl IntoResponse {
957 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
958}
959
960async fn submit_job(
966 State(state): State<AppState>,
967 Json(request): Json<JobRequest>,
968) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
969 let queue = state.job_queue.as_ref().ok_or_else(|| {
970 (
971 StatusCode::SERVICE_UNAVAILABLE,
972 Json(serde_json::json!({"error": "Job queue not enabled"})),
973 )
974 })?;
975
976 let job_id = queue.submit(request).await;
977 info!("Job submitted: {}", job_id);
978
979 Ok((
980 StatusCode::CREATED,
981 Json(serde_json::json!({
982 "id": job_id.to_string(),
983 "status": "queued"
984 })),
985 ))
986}
987
988async fn get_job(
990 State(state): State<AppState>,
991 axum::extract::Path(id): axum::extract::Path<String>,
992) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
993 let queue = state.job_queue.as_ref().ok_or_else(|| {
994 (
995 StatusCode::SERVICE_UNAVAILABLE,
996 Json(serde_json::json!({"error": "Job queue not enabled"})),
997 )
998 })?;
999
1000 match queue.get(&id).await {
1001 Some(entry) => Ok(Json(serde_json::json!({
1002 "id": entry.id,
1003 "status": format!("{:?}", entry.status).to_lowercase(),
1004 "submitted_at": entry.submitted_at.to_rfc3339(),
1005 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1006 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1007 "result": entry.result,
1008 }))),
1009 None => Err((
1010 StatusCode::NOT_FOUND,
1011 Json(serde_json::json!({"error": "Job not found"})),
1012 )),
1013 }
1014}
1015
1016async fn list_jobs(
1018 State(state): State<AppState>,
1019) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1020 let queue = state.job_queue.as_ref().ok_or_else(|| {
1021 (
1022 StatusCode::SERVICE_UNAVAILABLE,
1023 Json(serde_json::json!({"error": "Job queue not enabled"})),
1024 )
1025 })?;
1026
1027 let summaries: Vec<_> = queue
1028 .list()
1029 .await
1030 .into_iter()
1031 .map(|s| {
1032 serde_json::json!({
1033 "id": s.id,
1034 "status": format!("{:?}", s.status).to_lowercase(),
1035 "submitted_at": s.submitted_at.to_rfc3339(),
1036 })
1037 })
1038 .collect();
1039
1040 Ok(Json(serde_json::json!({ "jobs": summaries })))
1041}
1042
1043async fn cancel_job(
1045 State(state): State<AppState>,
1046 axum::extract::Path(id): axum::extract::Path<String>,
1047) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1048 let queue = state.job_queue.as_ref().ok_or_else(|| {
1049 (
1050 StatusCode::SERVICE_UNAVAILABLE,
1051 Json(serde_json::json!({"error": "Job queue not enabled"})),
1052 )
1053 })?;
1054
1055 if queue.cancel(&id).await {
1056 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1057 } else {
1058 Err((
1059 StatusCode::CONFLICT,
1060 Json(
1061 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1062 ),
1063 ))
1064 }
1065}
1066
1067async fn reload_config(
1073 State(state): State<AppState>,
1074) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1075 let source = state.server_state.config_source.read().await.clone();
1076 match crate::config_loader::load_config(&source).await {
1077 Ok(new_config) => {
1078 let mut config = state.server_state.config.write().await;
1079 *config = new_config;
1080 info!("Configuration reloaded via REST API from {:?}", source);
1081 Ok(Json(serde_json::json!({
1082 "success": true,
1083 "message": "Configuration reloaded"
1084 })))
1085 }
1086 Err(e) => {
1087 error!("Failed to reload configuration: {}", e);
1088 Err((
1089 StatusCode::INTERNAL_SERVER_ERROR,
1090 Json(serde_json::json!({
1091 "success": false,
1092 "message": format!("Failed to reload configuration: {}", e)
1093 })),
1094 ))
1095 }
1096 }
1097}
1098
1099#[cfg(test)]
1100#[allow(clippy::unwrap_used)]
1101mod tests {
1102 use super::*;
1103
1104 #[test]
1109 fn test_health_response_serialization() {
1110 let response = HealthResponse {
1111 healthy: true,
1112 version: "0.1.0".to_string(),
1113 uptime_seconds: 100,
1114 };
1115 let json = serde_json::to_string(&response).unwrap();
1116 assert!(json.contains("healthy"));
1117 assert!(json.contains("version"));
1118 assert!(json.contains("uptime_seconds"));
1119 }
1120
1121 #[test]
1122 fn test_health_response_deserialization() {
1123 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1124 let response: HealthResponse = serde_json::from_str(json).unwrap();
1125 assert!(response.healthy);
1126 assert_eq!(response.version, "0.1.0");
1127 assert_eq!(response.uptime_seconds, 100);
1128 }
1129
1130 #[test]
1131 fn test_metrics_response_serialization() {
1132 let response = MetricsResponse {
1133 total_entries_generated: 1000,
1134 total_anomalies_injected: 10,
1135 uptime_seconds: 60,
1136 session_entries: 1000,
1137 session_entries_per_second: 16.67,
1138 active_streams: 1,
1139 total_stream_events: 500,
1140 };
1141 let json = serde_json::to_string(&response).unwrap();
1142 assert!(json.contains("total_entries_generated"));
1143 assert!(json.contains("session_entries_per_second"));
1144 }
1145
1146 #[test]
1147 fn test_metrics_response_deserialization() {
1148 let json = r#"{
1149 "total_entries_generated": 5000,
1150 "total_anomalies_injected": 50,
1151 "uptime_seconds": 300,
1152 "session_entries": 5000,
1153 "session_entries_per_second": 16.67,
1154 "active_streams": 2,
1155 "total_stream_events": 10000
1156 }"#;
1157 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1158 assert_eq!(response.total_entries_generated, 5000);
1159 assert_eq!(response.active_streams, 2);
1160 }
1161
1162 #[test]
1163 fn test_config_response_serialization() {
1164 let response = ConfigResponse {
1165 success: true,
1166 message: "Configuration loaded".to_string(),
1167 config: Some(GenerationConfigDto {
1168 industry: "manufacturing".to_string(),
1169 start_date: "2024-01-01".to_string(),
1170 period_months: 12,
1171 seed: Some(42),
1172 coa_complexity: "medium".to_string(),
1173 companies: vec![],
1174 fraud_enabled: false,
1175 fraud_rate: 0.0,
1176 }),
1177 };
1178 let json = serde_json::to_string(&response).unwrap();
1179 assert!(json.contains("success"));
1180 assert!(json.contains("config"));
1181 }
1182
1183 #[test]
1184 fn test_config_response_without_config() {
1185 let response = ConfigResponse {
1186 success: false,
1187 message: "No configuration available".to_string(),
1188 config: None,
1189 };
1190 let json = serde_json::to_string(&response).unwrap();
1191 assert!(json.contains("null") || json.contains("config\":null"));
1192 }
1193
1194 #[test]
1195 fn test_generation_config_dto_roundtrip() {
1196 let original = GenerationConfigDto {
1197 industry: "retail".to_string(),
1198 start_date: "2024-06-01".to_string(),
1199 period_months: 6,
1200 seed: Some(12345),
1201 coa_complexity: "large".to_string(),
1202 companies: vec![CompanyConfigDto {
1203 code: "1000".to_string(),
1204 name: "Test Corp".to_string(),
1205 currency: "USD".to_string(),
1206 country: "US".to_string(),
1207 annual_transaction_volume: 100000,
1208 volume_weight: 1.0,
1209 }],
1210 fraud_enabled: true,
1211 fraud_rate: 0.05,
1212 };
1213
1214 let json = serde_json::to_string(&original).unwrap();
1215 let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();
1216
1217 assert_eq!(original.industry, deserialized.industry);
1218 assert_eq!(original.seed, deserialized.seed);
1219 assert_eq!(original.companies.len(), deserialized.companies.len());
1220 }
1221
1222 #[test]
1223 fn test_company_config_dto_serialization() {
1224 let company = CompanyConfigDto {
1225 code: "2000".to_string(),
1226 name: "European Subsidiary".to_string(),
1227 currency: "EUR".to_string(),
1228 country: "DE".to_string(),
1229 annual_transaction_volume: 50000,
1230 volume_weight: 0.5,
1231 };
1232 let json = serde_json::to_string(&company).unwrap();
1233 assert!(json.contains("2000"));
1234 assert!(json.contains("EUR"));
1235 assert!(json.contains("DE"));
1236 }
1237
1238 #[test]
1239 fn test_bulk_generate_request_deserialization() {
1240 let json = r#"{
1241 "entry_count": 5000,
1242 "include_master_data": true,
1243 "inject_anomalies": true
1244 }"#;
1245 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1246 assert_eq!(request.entry_count, Some(5000));
1247 assert_eq!(request.include_master_data, Some(true));
1248 assert_eq!(request.inject_anomalies, Some(true));
1249 }
1250
1251 #[test]
1252 fn test_bulk_generate_request_with_defaults() {
1253 let json = r#"{}"#;
1254 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1255 assert_eq!(request.entry_count, None);
1256 assert_eq!(request.include_master_data, None);
1257 assert_eq!(request.inject_anomalies, None);
1258 }
1259
1260 #[test]
1261 fn test_bulk_generate_response_serialization() {
1262 let response = BulkGenerateResponse {
1263 success: true,
1264 entries_generated: 1000,
1265 duration_ms: 250,
1266 anomaly_count: 20,
1267 };
1268 let json = serde_json::to_string(&response).unwrap();
1269 assert!(json.contains("entries_generated"));
1270 assert!(json.contains("1000"));
1271 assert!(json.contains("duration_ms"));
1272 }
1273
1274 #[test]
1275 fn test_stream_response_serialization() {
1276 let response = StreamResponse {
1277 success: true,
1278 message: "Stream started successfully".to_string(),
1279 };
1280 let json = serde_json::to_string(&response).unwrap();
1281 assert!(json.contains("success"));
1282 assert!(json.contains("Stream started"));
1283 }
1284
1285 #[test]
1286 fn test_stream_response_failure() {
1287 let response = StreamResponse {
1288 success: false,
1289 message: "Stream failed to start".to_string(),
1290 };
1291 let json = serde_json::to_string(&response).unwrap();
1292 assert!(json.contains("false"));
1293 assert!(json.contains("failed"));
1294 }
1295
1296 #[test]
1301 fn test_cors_config_default() {
1302 let config = CorsConfig::default();
1303 assert!(!config.allow_any_origin);
1304 assert!(!config.allowed_origins.is_empty());
1305 assert!(config
1306 .allowed_origins
1307 .contains(&"http://localhost:5173".to_string()));
1308 assert!(config
1309 .allowed_origins
1310 .contains(&"tauri://localhost".to_string()));
1311 }
1312
1313 #[test]
1314 fn test_cors_config_custom_origins() {
1315 let config = CorsConfig {
1316 allowed_origins: vec![
1317 "https://example.com".to_string(),
1318 "https://app.example.com".to_string(),
1319 ],
1320 allow_any_origin: false,
1321 };
1322 assert_eq!(config.allowed_origins.len(), 2);
1323 assert!(config
1324 .allowed_origins
1325 .contains(&"https://example.com".to_string()));
1326 }
1327
1328 #[test]
1329 fn test_cors_config_permissive() {
1330 let config = CorsConfig {
1331 allowed_origins: vec![],
1332 allow_any_origin: true,
1333 };
1334 assert!(config.allow_any_origin);
1335 }
1336
1337 #[test]
1342 fn test_bulk_generate_request_partial() {
1343 let json = r#"{"entry_count": 100}"#;
1344 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1345 assert_eq!(request.entry_count, Some(100));
1346 assert!(request.include_master_data.is_none());
1347 }
1348
1349 #[test]
1350 fn test_generation_config_no_seed() {
1351 let config = GenerationConfigDto {
1352 industry: "technology".to_string(),
1353 start_date: "2024-01-01".to_string(),
1354 period_months: 3,
1355 seed: None,
1356 coa_complexity: "small".to_string(),
1357 companies: vec![],
1358 fraud_enabled: false,
1359 fraud_rate: 0.0,
1360 };
1361 let json = serde_json::to_string(&config).unwrap();
1362 assert!(json.contains("seed"));
1363 }
1364
1365 #[test]
1366 fn test_generation_config_multiple_companies() {
1367 let config = GenerationConfigDto {
1368 industry: "manufacturing".to_string(),
1369 start_date: "2024-01-01".to_string(),
1370 period_months: 12,
1371 seed: Some(42),
1372 coa_complexity: "large".to_string(),
1373 companies: vec![
1374 CompanyConfigDto {
1375 code: "1000".to_string(),
1376 name: "Headquarters".to_string(),
1377 currency: "USD".to_string(),
1378 country: "US".to_string(),
1379 annual_transaction_volume: 100000,
1380 volume_weight: 1.0,
1381 },
1382 CompanyConfigDto {
1383 code: "2000".to_string(),
1384 name: "European Sub".to_string(),
1385 currency: "EUR".to_string(),
1386 country: "DE".to_string(),
1387 annual_transaction_volume: 50000,
1388 volume_weight: 0.5,
1389 },
1390 CompanyConfigDto {
1391 code: "3000".to_string(),
1392 name: "APAC Sub".to_string(),
1393 currency: "JPY".to_string(),
1394 country: "JP".to_string(),
1395 annual_transaction_volume: 30000,
1396 volume_weight: 0.3,
1397 },
1398 ],
1399 fraud_enabled: true,
1400 fraud_rate: 0.02,
1401 };
1402 assert_eq!(config.companies.len(), 3);
1403 }
1404
1405 #[test]
1410 fn test_metrics_entries_per_second_calculation() {
1411 let total_entries: u64 = 1000;
1413 let uptime: u64 = 60;
1414 let eps = if uptime > 0 {
1415 total_entries as f64 / uptime as f64
1416 } else {
1417 0.0
1418 };
1419 assert!((eps - 16.67).abs() < 0.1);
1420 }
1421
1422 #[test]
1423 fn test_metrics_entries_per_second_zero_uptime() {
1424 let total_entries: u64 = 1000;
1425 let uptime: u64 = 0;
1426 let eps = if uptime > 0 {
1427 total_entries as f64 / uptime as f64
1428 } else {
1429 0.0
1430 };
1431 assert_eq!(eps, 0.0);
1432 }
1433}