1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
24#[derive(Clone)]
26pub struct AppState {
27 pub server_state: Arc<ServerState>,
28 pub job_queue: Option<Arc<JobQueue>>,
29}
30
/// Per-request timeout settings for the REST router.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Maximum number of seconds a request may run before the fully-layered
    /// router answers `408 Request Timeout`.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Timeout used by `TimeoutConfig::default()`: 300 seconds (five minutes).
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Creates a configuration with the given request timeout, in seconds.
    pub fn new(timeout_secs: u64) -> Self {
        Self {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
/// Cross-origin resource sharing (CORS) policy for the REST API.
#[derive(Clone)]
pub struct CorsConfig {
    /// Exact origins accepted when `allow_any_origin` is `false`.
    pub allowed_origins: Vec<String>,
    /// When `true`, a permissive CORS policy is used and `allowed_origins`
    /// is ignored.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Allows the loopback origins on ports 5173 and 3000 (both `localhost`
    /// and `127.0.0.1`) plus `tauri://localhost`; any-origin is disabled.
    fn default() -> Self {
        const DEFAULT_ORIGINS: [&str; 5] = [
            "http://localhost:5173",
            "http://localhost:3000",
            "http://127.0.0.1:5173",
            "http://127.0.0.1:3000",
            "tauri://localhost",
        ];

        Self {
            allowed_origins: DEFAULT_ORIGINS
                .iter()
                .map(|origin| (*origin).to_owned())
                .collect(),
            allow_any_origin: false,
        }
    }
}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/jobs/submit", post(submit_job))
186 .route("/api/jobs", get(list_jobs))
187 .route("/api/jobs/{id}", get(get_job))
188 .route("/api/jobs/{id}/cancel", post(cancel_job))
189 .route("/ws/metrics", get(websocket_metrics))
191 .route("/ws/events", get(websocket_events))
192 .layer(axum::middleware::from_fn(security_headers_middleware))
195 .layer(axum::middleware::map_response(api_version_header))
196 .layer(cors)
197 .layer(axum::middleware::from_fn(request_id_middleware))
198 .layer(axum::middleware::from_fn(auth_middleware))
199 .layer(axum::Extension(auth_config))
200 .layer(axum::middleware::from_fn(request_validation_middleware))
201 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202 .layer(axum::Extension(rate_limit_backend))
203 .layer(TimeoutLayer::with_status_code(
204 StatusCode::REQUEST_TIMEOUT,
205 Duration::from_secs(timeout_config.request_timeout_secs),
206 ))
207 .with_state(state)
208}
209
210pub fn create_router_with_auth(
212 service: SynthService,
213 cors_config: CorsConfig,
214 auth_config: AuthConfig,
215) -> Router {
216 let server_state = service.state.clone();
217 let state = AppState {
218 server_state,
219 job_queue: None,
220 };
221
222 let cors = if cors_config.allow_any_origin {
223 CorsLayer::permissive()
224 } else {
225 let origins: Vec<_> = cors_config
226 .allowed_origins
227 .iter()
228 .filter_map(|o| o.parse().ok())
229 .collect();
230
231 CorsLayer::new()
232 .allow_origin(AllowOrigin::list(origins))
233 .allow_methods([
234 Method::GET,
235 Method::POST,
236 Method::PUT,
237 Method::DELETE,
238 Method::OPTIONS,
239 ])
240 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
241 };
242
243 Router::new()
244 .route("/health", get(health_check))
246 .route("/ready", get(readiness_check))
247 .route("/live", get(liveness_check))
248 .route("/api/metrics", get(get_metrics))
249 .route("/metrics", get(prometheus_metrics))
250 .route("/api/config", get(get_config))
252 .route("/api/config", post(set_config))
253 .route("/api/config/reload", post(reload_config))
254 .route("/api/generate/bulk", post(bulk_generate))
256 .route("/api/stream/start", post(start_stream))
257 .route("/api/stream/stop", post(stop_stream))
258 .route("/api/stream/pause", post(pause_stream))
259 .route("/api/stream/resume", post(resume_stream))
260 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
261 .route("/api/jobs/submit", post(submit_job))
263 .route("/api/jobs", get(list_jobs))
264 .route("/api/jobs/{id}", get(get_job))
265 .route("/api/jobs/{id}/cancel", post(cancel_job))
266 .route("/ws/metrics", get(websocket_metrics))
268 .route("/ws/events", get(websocket_events))
269 .layer(axum::middleware::from_fn(auth_middleware))
270 .layer(axum::Extension(auth_config))
271 .layer(cors)
272 .with_state(state)
273}
274
275pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
277 let server_state = service.state.clone();
278 let state = AppState {
279 server_state,
280 job_queue: None,
281 };
282
283 let cors = if cors_config.allow_any_origin {
284 CorsLayer::permissive()
286 } else {
287 let origins: Vec<_> = cors_config
289 .allowed_origins
290 .iter()
291 .filter_map(|o| o.parse().ok())
292 .collect();
293
294 CorsLayer::new()
295 .allow_origin(AllowOrigin::list(origins))
296 .allow_methods([
297 Method::GET,
298 Method::POST,
299 Method::PUT,
300 Method::DELETE,
301 Method::OPTIONS,
302 ])
303 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
304 };
305
306 Router::new()
307 .route("/health", get(health_check))
309 .route("/ready", get(readiness_check))
310 .route("/live", get(liveness_check))
311 .route("/api/metrics", get(get_metrics))
312 .route("/metrics", get(prometheus_metrics))
313 .route("/api/config", get(get_config))
315 .route("/api/config", post(set_config))
316 .route("/api/config/reload", post(reload_config))
317 .route("/api/generate/bulk", post(bulk_generate))
319 .route("/api/stream/start", post(start_stream))
320 .route("/api/stream/stop", post(stop_stream))
321 .route("/api/stream/pause", post(pause_stream))
322 .route("/api/stream/resume", post(resume_stream))
323 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
324 .route("/api/jobs/submit", post(submit_job))
326 .route("/api/jobs", get(list_jobs))
327 .route("/api/jobs/{id}", get(get_job))
328 .route("/api/jobs/{id}/cancel", post(cancel_job))
329 .route("/ws/metrics", get(websocket_metrics))
331 .route("/ws/events", get(websocket_events))
332 .layer(cors)
333 .with_state(state)
334}
335
336#[derive(Debug, Serialize, Deserialize)]
341pub struct HealthResponse {
342 pub healthy: bool,
343 pub version: String,
344 pub uptime_seconds: u64,
345}
346
347#[derive(Debug, Serialize, Deserialize)]
349pub struct ReadinessResponse {
350 pub ready: bool,
351 pub message: String,
352 pub checks: Vec<HealthCheck>,
353}
354
355#[derive(Debug, Serialize, Deserialize)]
357pub struct HealthCheck {
358 pub name: String,
359 pub status: String,
360}
361
362#[derive(Debug, Serialize, Deserialize)]
364pub struct LivenessResponse {
365 pub alive: bool,
366 pub timestamp: String,
367}
368
369#[derive(Debug, Serialize, Deserialize)]
370pub struct MetricsResponse {
371 pub total_entries_generated: u64,
372 pub total_anomalies_injected: u64,
373 pub uptime_seconds: u64,
374 pub session_entries: u64,
375 pub session_entries_per_second: f64,
376 pub active_streams: u32,
377 pub total_stream_events: u64,
378}
379
380#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct ConfigResponse {
382 pub success: bool,
383 pub message: String,
384 pub config: Option<GenerationConfigDto>,
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct GenerationConfigDto {
389 pub industry: String,
390 pub start_date: String,
391 pub period_months: u32,
392 pub seed: Option<u64>,
393 pub coa_complexity: String,
394 pub companies: Vec<CompanyConfigDto>,
395 pub fraud_enabled: bool,
396 pub fraud_rate: f32,
397}
398
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct CompanyConfigDto {
401 pub code: String,
402 pub name: String,
403 pub currency: String,
404 pub country: String,
405 pub annual_transaction_volume: u64,
406 pub volume_weight: f32,
407}
408
409#[derive(Debug, Deserialize)]
410pub struct BulkGenerateRequest {
411 pub entry_count: Option<u64>,
412 pub include_master_data: Option<bool>,
413 pub inject_anomalies: Option<bool>,
414}
415
416#[derive(Debug, Serialize)]
417pub struct BulkGenerateResponse {
418 pub success: bool,
419 pub entries_generated: u64,
420 pub duration_ms: u64,
421 pub anomaly_count: u64,
422}
423
424#[derive(Debug, Deserialize)]
425#[allow(dead_code)] pub struct StreamRequest {
427 pub events_per_second: Option<u32>,
428 pub max_events: Option<u64>,
429 pub inject_anomalies: Option<bool>,
430}
431
432#[derive(Debug, Serialize)]
433pub struct StreamResponse {
434 pub success: bool,
435 pub message: String,
436}
437
438async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
444 Json(HealthResponse {
445 healthy: true,
446 version: env!("CARGO_PKG_VERSION").to_string(),
447 uptime_seconds: state.server_state.uptime_seconds(),
448 })
449}
450
451async fn readiness_check(
454 State(state): State<AppState>,
455) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
456 let mut checks = Vec::new();
457 let mut any_fail = false;
458
459 let config = state.server_state.config.read().await;
461 let config_valid = !config.companies.is_empty();
462 checks.push(HealthCheck {
463 name: "config".to_string(),
464 status: if config_valid { "ok" } else { "fail" }.to_string(),
465 });
466 if !config_valid {
467 any_fail = true;
468 }
469 drop(config);
470
471 let resource_status = state.server_state.resource_status();
473 let memory_status = if resource_status.degradation_level == "Emergency" {
474 any_fail = true;
475 "fail"
476 } else if resource_status.degradation_level != "Normal" {
477 "degraded"
478 } else {
479 "ok"
480 };
481 checks.push(HealthCheck {
482 name: "memory".to_string(),
483 status: memory_status.to_string(),
484 });
485
486 let disk_ok = resource_status.disk_available_mb > 100;
488 checks.push(HealthCheck {
489 name: "disk".to_string(),
490 status: if disk_ok { "ok" } else { "fail" }.to_string(),
491 });
492 if !disk_ok {
493 any_fail = true;
494 }
495
496 let response = ReadinessResponse {
497 ready: !any_fail,
498 message: if any_fail {
499 "Service is not ready".to_string()
500 } else {
501 "Service is ready".to_string()
502 },
503 checks,
504 };
505
506 if any_fail {
507 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
508 } else {
509 Ok(Json(response))
510 }
511}
512
513async fn liveness_check() -> Json<LivenessResponse> {
516 Json(LivenessResponse {
517 alive: true,
518 timestamp: chrono::Utc::now().to_rfc3339(),
519 })
520}
521
522async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
525 use std::sync::atomic::Ordering;
526
527 let uptime = state.server_state.uptime_seconds();
528 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
529 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
530 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
531 let total_stream_events = state
532 .server_state
533 .total_stream_events
534 .load(Ordering::Relaxed);
535
536 let entries_per_second = if uptime > 0 {
537 total_entries as f64 / uptime as f64
538 } else {
539 0.0
540 };
541
542 let metrics = format!(
543 r#"# HELP synth_entries_generated_total Total number of journal entries generated
544# TYPE synth_entries_generated_total counter
545synth_entries_generated_total {}
546
547# HELP synth_anomalies_injected_total Total number of anomalies injected
548# TYPE synth_anomalies_injected_total counter
549synth_anomalies_injected_total {}
550
551# HELP synth_uptime_seconds Server uptime in seconds
552# TYPE synth_uptime_seconds gauge
553synth_uptime_seconds {}
554
555# HELP synth_entries_per_second Rate of entry generation
556# TYPE synth_entries_per_second gauge
557synth_entries_per_second {:.2}
558
559# HELP synth_active_streams Number of active streaming connections
560# TYPE synth_active_streams gauge
561synth_active_streams {}
562
563# HELP synth_stream_events_total Total events sent through streams
564# TYPE synth_stream_events_total counter
565synth_stream_events_total {}
566
567# HELP synth_info Server version information
568# TYPE synth_info gauge
569synth_info{{version="{}"}} 1
570"#,
571 total_entries,
572 total_anomalies,
573 uptime,
574 entries_per_second,
575 active_streams,
576 total_stream_events,
577 env!("CARGO_PKG_VERSION")
578 );
579
580 (
581 StatusCode::OK,
582 [(
583 header::CONTENT_TYPE,
584 "text/plain; version=0.0.4; charset=utf-8",
585 )],
586 metrics,
587 )
588}
589
590async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
592 let uptime = state.server_state.uptime_seconds();
593 let total_entries = state
594 .server_state
595 .total_entries
596 .load(std::sync::atomic::Ordering::Relaxed);
597
598 let entries_per_second = if uptime > 0 {
599 total_entries as f64 / uptime as f64
600 } else {
601 0.0
602 };
603
604 Json(MetricsResponse {
605 total_entries_generated: total_entries,
606 total_anomalies_injected: state
607 .server_state
608 .total_anomalies
609 .load(std::sync::atomic::Ordering::Relaxed),
610 uptime_seconds: uptime,
611 session_entries: total_entries,
612 session_entries_per_second: entries_per_second,
613 active_streams: state
614 .server_state
615 .active_streams
616 .load(std::sync::atomic::Ordering::Relaxed) as u32,
617 total_stream_events: state
618 .server_state
619 .total_stream_events
620 .load(std::sync::atomic::Ordering::Relaxed),
621 })
622}
623
624async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
626 let config = state.server_state.config.read().await;
627
628 Json(ConfigResponse {
629 success: true,
630 message: "Current configuration".to_string(),
631 config: Some(GenerationConfigDto {
632 industry: format!("{:?}", config.global.industry),
633 start_date: config.global.start_date.clone(),
634 period_months: config.global.period_months,
635 seed: config.global.seed,
636 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
637 companies: config
638 .companies
639 .iter()
640 .map(|c| CompanyConfigDto {
641 code: c.code.clone(),
642 name: c.name.clone(),
643 currency: c.currency.clone(),
644 country: c.country.clone(),
645 annual_transaction_volume: c.annual_transaction_volume.count(),
646 volume_weight: c.volume_weight as f32,
647 })
648 .collect(),
649 fraud_enabled: config.fraud.enabled,
650 fraud_rate: config.fraud.fraud_rate as f32,
651 }),
652 })
653}
654
655async fn set_config(
657 State(state): State<AppState>,
658 Json(new_config): Json<GenerationConfigDto>,
659) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
660 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
661 use datasynth_core::models::{CoAComplexity, IndustrySector};
662
663 info!(
664 "Configuration update requested: industry={}, period_months={}",
665 new_config.industry, new_config.period_months
666 );
667
668 let industry = match new_config.industry.to_lowercase().as_str() {
670 "manufacturing" => IndustrySector::Manufacturing,
671 "retail" => IndustrySector::Retail,
672 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
673 "healthcare" => IndustrySector::Healthcare,
674 "technology" => IndustrySector::Technology,
675 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
676 "energy" => IndustrySector::Energy,
677 "transportation" => IndustrySector::Transportation,
678 "real_estate" | "realestate" => IndustrySector::RealEstate,
679 "telecommunications" => IndustrySector::Telecommunications,
680 _ => {
681 return Err((
682 StatusCode::BAD_REQUEST,
683 Json(ConfigResponse {
684 success: false,
685 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
686 config: None,
687 }),
688 ));
689 }
690 };
691
692 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
694 "small" => CoAComplexity::Small,
695 "medium" => CoAComplexity::Medium,
696 "large" => CoAComplexity::Large,
697 _ => {
698 return Err((
699 StatusCode::BAD_REQUEST,
700 Json(ConfigResponse {
701 success: false,
702 message: format!(
703 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
704 new_config.coa_complexity
705 ),
706 config: None,
707 }),
708 ));
709 }
710 };
711
712 let companies: Vec<CompanyConfig> = new_config
714 .companies
715 .iter()
716 .map(|c| CompanyConfig {
717 code: c.code.clone(),
718 name: c.name.clone(),
719 currency: c.currency.clone(),
720 functional_currency: None,
721 country: c.country.clone(),
722 fiscal_year_variant: "K4".to_string(),
723 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
724 volume_weight: c.volume_weight as f64,
725 })
726 .collect();
727
728 let mut config = state.server_state.config.write().await;
730 config.global.industry = industry;
731 config.global.start_date = new_config.start_date.clone();
732 config.global.period_months = new_config.period_months;
733 config.global.seed = new_config.seed;
734 config.chart_of_accounts.complexity = complexity;
735 config.fraud.enabled = new_config.fraud_enabled;
736 config.fraud.fraud_rate = new_config.fraud_rate as f64;
737
738 if !companies.is_empty() {
740 config.companies = companies;
741 }
742
743 info!("Configuration updated successfully");
744
745 Ok(Json(ConfigResponse {
746 success: true,
747 message: "Configuration updated and applied".to_string(),
748 config: Some(new_config),
749 }))
750}
751
752async fn bulk_generate(
754 State(state): State<AppState>,
755 Json(req): Json<BulkGenerateRequest>,
756) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
757 const MAX_ENTRY_COUNT: u64 = 1_000_000;
759 if let Some(count) = req.entry_count {
760 if count > MAX_ENTRY_COUNT {
761 return Err((
762 StatusCode::BAD_REQUEST,
763 format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
764 ));
765 }
766 }
767
768 let config = state.server_state.config.read().await.clone();
769 let start_time = std::time::Instant::now();
770
771 let phase_config = {
772 let mut pc = PhaseConfig::from_config(&config);
773 pc.generate_master_data = req.include_master_data.unwrap_or(false);
774 pc.generate_document_flows = false;
775 pc.generate_journal_entries = true;
776 pc.inject_anomalies = req.inject_anomalies.unwrap_or(false);
777 pc.show_progress = false;
778 pc
779 };
780
781 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
782 (
783 StatusCode::INTERNAL_SERVER_ERROR,
784 format!("Failed to create orchestrator: {e}"),
785 )
786 })?;
787
788 let result = orchestrator.generate().map_err(|e| {
789 (
790 StatusCode::INTERNAL_SERVER_ERROR,
791 format!("Generation failed: {e}"),
792 )
793 })?;
794
795 let duration_ms = start_time.elapsed().as_millis() as u64;
796 let entries_count = result.journal_entries.len() as u64;
797 let anomaly_count = result.anomaly_labels.labels.len() as u64;
798
799 state
801 .server_state
802 .total_entries
803 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
804 state
805 .server_state
806 .total_anomalies
807 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
808
809 Ok(Json(BulkGenerateResponse {
810 success: true,
811 entries_generated: entries_count,
812 duration_ms,
813 anomaly_count,
814 }))
815}
816
817async fn start_stream(
819 State(state): State<AppState>,
820 Json(req): Json<StreamRequest>,
821) -> Json<StreamResponse> {
822 if let Some(eps) = req.events_per_second {
824 info!("Stream configured: events_per_second={}", eps);
825 state
826 .server_state
827 .stream_events_per_second
828 .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
829 }
830 if let Some(max) = req.max_events {
831 info!("Stream configured: max_events={}", max);
832 state
833 .server_state
834 .stream_max_events
835 .store(max, std::sync::atomic::Ordering::Relaxed);
836 }
837 if let Some(inject) = req.inject_anomalies {
838 info!("Stream configured: inject_anomalies={}", inject);
839 state
840 .server_state
841 .stream_inject_anomalies
842 .store(inject, std::sync::atomic::Ordering::Relaxed);
843 }
844
845 state
846 .server_state
847 .stream_stopped
848 .store(false, std::sync::atomic::Ordering::Relaxed);
849 state
850 .server_state
851 .stream_paused
852 .store(false, std::sync::atomic::Ordering::Relaxed);
853
854 Json(StreamResponse {
855 success: true,
856 message: "Stream started".to_string(),
857 })
858}
859
860async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
862 state
863 .server_state
864 .stream_stopped
865 .store(true, std::sync::atomic::Ordering::Relaxed);
866
867 Json(StreamResponse {
868 success: true,
869 message: "Stream stopped".to_string(),
870 })
871}
872
873async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
875 state
876 .server_state
877 .stream_paused
878 .store(true, std::sync::atomic::Ordering::Relaxed);
879
880 Json(StreamResponse {
881 success: true,
882 message: "Stream paused".to_string(),
883 })
884}
885
886async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
888 state
889 .server_state
890 .stream_paused
891 .store(false, std::sync::atomic::Ordering::Relaxed);
892
893 Json(StreamResponse {
894 success: true,
895 message: "Stream resumed".to_string(),
896 })
897}
898
899async fn trigger_pattern(
904 State(state): State<AppState>,
905 axum::extract::Path(pattern): axum::extract::Path<String>,
906) -> Json<StreamResponse> {
907 info!("Pattern trigger requested: {}", pattern);
908
909 let valid_patterns = [
911 "year_end_spike",
912 "period_end_spike",
913 "holiday_cluster",
914 "fraud_cluster",
915 "error_cluster",
916 "uniform",
917 ];
918
919 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
920
921 if !is_valid {
922 return Json(StreamResponse {
923 success: false,
924 message: format!(
925 "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
926 ),
927 });
928 }
929
930 match state.server_state.triggered_pattern.try_write() {
932 Ok(mut triggered) => {
933 *triggered = Some(pattern.clone());
934 Json(StreamResponse {
935 success: true,
936 message: format!("Pattern '{pattern}' will be applied to upcoming entries"),
937 })
938 }
939 Err(_) => Json(StreamResponse {
940 success: false,
941 message: "Failed to acquire lock for pattern trigger".to_string(),
942 }),
943 }
944}
945
946async fn websocket_metrics(
948 ws: WebSocketUpgrade,
949 State(state): State<AppState>,
950) -> impl IntoResponse {
951 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
952}
953
954async fn websocket_events(
956 ws: WebSocketUpgrade,
957 State(state): State<AppState>,
958) -> impl IntoResponse {
959 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
960}
961
962async fn submit_job(
968 State(state): State<AppState>,
969 Json(request): Json<JobRequest>,
970) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
971 let queue = state.job_queue.as_ref().ok_or_else(|| {
972 (
973 StatusCode::SERVICE_UNAVAILABLE,
974 Json(serde_json::json!({"error": "Job queue not enabled"})),
975 )
976 })?;
977
978 let job_id = queue.submit(request).await;
979 info!("Job submitted: {}", job_id);
980
981 Ok((
982 StatusCode::CREATED,
983 Json(serde_json::json!({
984 "id": job_id.to_string(),
985 "status": "queued"
986 })),
987 ))
988}
989
990async fn get_job(
992 State(state): State<AppState>,
993 axum::extract::Path(id): axum::extract::Path<String>,
994) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
995 let queue = state.job_queue.as_ref().ok_or_else(|| {
996 (
997 StatusCode::SERVICE_UNAVAILABLE,
998 Json(serde_json::json!({"error": "Job queue not enabled"})),
999 )
1000 })?;
1001
1002 match queue.get(&id).await {
1003 Some(entry) => Ok(Json(serde_json::json!({
1004 "id": entry.id,
1005 "status": format!("{:?}", entry.status).to_lowercase(),
1006 "submitted_at": entry.submitted_at.to_rfc3339(),
1007 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1008 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1009 "result": entry.result,
1010 }))),
1011 None => Err((
1012 StatusCode::NOT_FOUND,
1013 Json(serde_json::json!({"error": "Job not found"})),
1014 )),
1015 }
1016}
1017
1018async fn list_jobs(
1020 State(state): State<AppState>,
1021) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1022 let queue = state.job_queue.as_ref().ok_or_else(|| {
1023 (
1024 StatusCode::SERVICE_UNAVAILABLE,
1025 Json(serde_json::json!({"error": "Job queue not enabled"})),
1026 )
1027 })?;
1028
1029 let summaries: Vec<_> = queue
1030 .list()
1031 .await
1032 .into_iter()
1033 .map(|s| {
1034 serde_json::json!({
1035 "id": s.id,
1036 "status": format!("{:?}", s.status).to_lowercase(),
1037 "submitted_at": s.submitted_at.to_rfc3339(),
1038 })
1039 })
1040 .collect();
1041
1042 Ok(Json(serde_json::json!({ "jobs": summaries })))
1043}
1044
1045async fn cancel_job(
1047 State(state): State<AppState>,
1048 axum::extract::Path(id): axum::extract::Path<String>,
1049) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1050 let queue = state.job_queue.as_ref().ok_or_else(|| {
1051 (
1052 StatusCode::SERVICE_UNAVAILABLE,
1053 Json(serde_json::json!({"error": "Job queue not enabled"})),
1054 )
1055 })?;
1056
1057 if queue.cancel(&id).await {
1058 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1059 } else {
1060 Err((
1061 StatusCode::CONFLICT,
1062 Json(
1063 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1064 ),
1065 ))
1066 }
1067}
1068
1069async fn reload_config(
1075 State(state): State<AppState>,
1076) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1077 let source = state.server_state.config_source.read().await.clone();
1078 match crate::config_loader::load_config(&source).await {
1079 Ok(new_config) => {
1080 let mut config = state.server_state.config.write().await;
1081 *config = new_config;
1082 info!("Configuration reloaded via REST API from {:?}", source);
1083 Ok(Json(serde_json::json!({
1084 "success": true,
1085 "message": "Configuration reloaded"
1086 })))
1087 }
1088 Err(e) => {
1089 error!("Failed to reload configuration: {}", e);
1090 Err((
1091 StatusCode::INTERNAL_SERVER_ERROR,
1092 Json(serde_json::json!({
1093 "success": false,
1094 "message": format!("Failed to reload configuration: {}", e)
1095 })),
1096 ))
1097 }
1098 }
1099}
1100
1101#[cfg(test)]
1102#[allow(clippy::unwrap_used)]
1103mod tests {
1104 use super::*;
1105
1106 #[test]
1111 fn test_health_response_serialization() {
1112 let response = HealthResponse {
1113 healthy: true,
1114 version: "0.1.0".to_string(),
1115 uptime_seconds: 100,
1116 };
1117 let json = serde_json::to_string(&response).unwrap();
1118 assert!(json.contains("healthy"));
1119 assert!(json.contains("version"));
1120 assert!(json.contains("uptime_seconds"));
1121 }
1122
1123 #[test]
1124 fn test_health_response_deserialization() {
1125 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1126 let response: HealthResponse = serde_json::from_str(json).unwrap();
1127 assert!(response.healthy);
1128 assert_eq!(response.version, "0.1.0");
1129 assert_eq!(response.uptime_seconds, 100);
1130 }
1131
1132 #[test]
1133 fn test_metrics_response_serialization() {
1134 let response = MetricsResponse {
1135 total_entries_generated: 1000,
1136 total_anomalies_injected: 10,
1137 uptime_seconds: 60,
1138 session_entries: 1000,
1139 session_entries_per_second: 16.67,
1140 active_streams: 1,
1141 total_stream_events: 500,
1142 };
1143 let json = serde_json::to_string(&response).unwrap();
1144 assert!(json.contains("total_entries_generated"));
1145 assert!(json.contains("session_entries_per_second"));
1146 }
1147
1148 #[test]
1149 fn test_metrics_response_deserialization() {
1150 let json = r#"{
1151 "total_entries_generated": 5000,
1152 "total_anomalies_injected": 50,
1153 "uptime_seconds": 300,
1154 "session_entries": 5000,
1155 "session_entries_per_second": 16.67,
1156 "active_streams": 2,
1157 "total_stream_events": 10000
1158 }"#;
1159 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1160 assert_eq!(response.total_entries_generated, 5000);
1161 assert_eq!(response.active_streams, 2);
1162 }
1163
1164 #[test]
1165 fn test_config_response_serialization() {
1166 let response = ConfigResponse {
1167 success: true,
1168 message: "Configuration loaded".to_string(),
1169 config: Some(GenerationConfigDto {
1170 industry: "manufacturing".to_string(),
1171 start_date: "2024-01-01".to_string(),
1172 period_months: 12,
1173 seed: Some(42),
1174 coa_complexity: "medium".to_string(),
1175 companies: vec![],
1176 fraud_enabled: false,
1177 fraud_rate: 0.0,
1178 }),
1179 };
1180 let json = serde_json::to_string(&response).unwrap();
1181 assert!(json.contains("success"));
1182 assert!(json.contains("config"));
1183 }
1184
1185 #[test]
1186 fn test_config_response_without_config() {
1187 let response = ConfigResponse {
1188 success: false,
1189 message: "No configuration available".to_string(),
1190 config: None,
1191 };
1192 let json = serde_json::to_string(&response).unwrap();
1193 assert!(json.contains("null") || json.contains("config\":null"));
1194 }
1195
1196 #[test]
1197 fn test_generation_config_dto_roundtrip() {
1198 let original = GenerationConfigDto {
1199 industry: "retail".to_string(),
1200 start_date: "2024-06-01".to_string(),
1201 period_months: 6,
1202 seed: Some(12345),
1203 coa_complexity: "large".to_string(),
1204 companies: vec![CompanyConfigDto {
1205 code: "1000".to_string(),
1206 name: "Test Corp".to_string(),
1207 currency: "USD".to_string(),
1208 country: "US".to_string(),
1209 annual_transaction_volume: 100000,
1210 volume_weight: 1.0,
1211 }],
1212 fraud_enabled: true,
1213 fraud_rate: 0.05,
1214 };
1215
1216 let json = serde_json::to_string(&original).unwrap();
1217 let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();
1218
1219 assert_eq!(original.industry, deserialized.industry);
1220 assert_eq!(original.seed, deserialized.seed);
1221 assert_eq!(original.companies.len(), deserialized.companies.len());
1222 }
1223
1224 #[test]
1225 fn test_company_config_dto_serialization() {
1226 let company = CompanyConfigDto {
1227 code: "2000".to_string(),
1228 name: "European Subsidiary".to_string(),
1229 currency: "EUR".to_string(),
1230 country: "DE".to_string(),
1231 annual_transaction_volume: 50000,
1232 volume_weight: 0.5,
1233 };
1234 let json = serde_json::to_string(&company).unwrap();
1235 assert!(json.contains("2000"));
1236 assert!(json.contains("EUR"));
1237 assert!(json.contains("DE"));
1238 }
1239
1240 #[test]
1241 fn test_bulk_generate_request_deserialization() {
1242 let json = r#"{
1243 "entry_count": 5000,
1244 "include_master_data": true,
1245 "inject_anomalies": true
1246 }"#;
1247 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1248 assert_eq!(request.entry_count, Some(5000));
1249 assert_eq!(request.include_master_data, Some(true));
1250 assert_eq!(request.inject_anomalies, Some(true));
1251 }
1252
1253 #[test]
1254 fn test_bulk_generate_request_with_defaults() {
1255 let json = r#"{}"#;
1256 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1257 assert_eq!(request.entry_count, None);
1258 assert_eq!(request.include_master_data, None);
1259 assert_eq!(request.inject_anomalies, None);
1260 }
1261
1262 #[test]
1263 fn test_bulk_generate_response_serialization() {
1264 let response = BulkGenerateResponse {
1265 success: true,
1266 entries_generated: 1000,
1267 duration_ms: 250,
1268 anomaly_count: 20,
1269 };
1270 let json = serde_json::to_string(&response).unwrap();
1271 assert!(json.contains("entries_generated"));
1272 assert!(json.contains("1000"));
1273 assert!(json.contains("duration_ms"));
1274 }
1275
1276 #[test]
1277 fn test_stream_response_serialization() {
1278 let response = StreamResponse {
1279 success: true,
1280 message: "Stream started successfully".to_string(),
1281 };
1282 let json = serde_json::to_string(&response).unwrap();
1283 assert!(json.contains("success"));
1284 assert!(json.contains("Stream started"));
1285 }
1286
1287 #[test]
1288 fn test_stream_response_failure() {
1289 let response = StreamResponse {
1290 success: false,
1291 message: "Stream failed to start".to_string(),
1292 };
1293 let json = serde_json::to_string(&response).unwrap();
1294 assert!(json.contains("false"));
1295 assert!(json.contains("failed"));
1296 }
1297
1298 #[test]
1303 fn test_cors_config_default() {
1304 let config = CorsConfig::default();
1305 assert!(!config.allow_any_origin);
1306 assert!(!config.allowed_origins.is_empty());
1307 assert!(config
1308 .allowed_origins
1309 .contains(&"http://localhost:5173".to_string()));
1310 assert!(config
1311 .allowed_origins
1312 .contains(&"tauri://localhost".to_string()));
1313 }
1314
1315 #[test]
1316 fn test_cors_config_custom_origins() {
1317 let config = CorsConfig {
1318 allowed_origins: vec![
1319 "https://example.com".to_string(),
1320 "https://app.example.com".to_string(),
1321 ],
1322 allow_any_origin: false,
1323 };
1324 assert_eq!(config.allowed_origins.len(), 2);
1325 assert!(config
1326 .allowed_origins
1327 .contains(&"https://example.com".to_string()));
1328 }
1329
1330 #[test]
1331 fn test_cors_config_permissive() {
1332 let config = CorsConfig {
1333 allowed_origins: vec![],
1334 allow_any_origin: true,
1335 };
1336 assert!(config.allow_any_origin);
1337 }
1338
1339 #[test]
1344 fn test_bulk_generate_request_partial() {
1345 let json = r#"{"entry_count": 100}"#;
1346 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1347 assert_eq!(request.entry_count, Some(100));
1348 assert!(request.include_master_data.is_none());
1349 }
1350
1351 #[test]
1352 fn test_generation_config_no_seed() {
1353 let config = GenerationConfigDto {
1354 industry: "technology".to_string(),
1355 start_date: "2024-01-01".to_string(),
1356 period_months: 3,
1357 seed: None,
1358 coa_complexity: "small".to_string(),
1359 companies: vec![],
1360 fraud_enabled: false,
1361 fraud_rate: 0.0,
1362 };
1363 let json = serde_json::to_string(&config).unwrap();
1364 assert!(json.contains("seed"));
1365 }
1366
1367 #[test]
1368 fn test_generation_config_multiple_companies() {
1369 let config = GenerationConfigDto {
1370 industry: "manufacturing".to_string(),
1371 start_date: "2024-01-01".to_string(),
1372 period_months: 12,
1373 seed: Some(42),
1374 coa_complexity: "large".to_string(),
1375 companies: vec![
1376 CompanyConfigDto {
1377 code: "1000".to_string(),
1378 name: "Headquarters".to_string(),
1379 currency: "USD".to_string(),
1380 country: "US".to_string(),
1381 annual_transaction_volume: 100000,
1382 volume_weight: 1.0,
1383 },
1384 CompanyConfigDto {
1385 code: "2000".to_string(),
1386 name: "European Sub".to_string(),
1387 currency: "EUR".to_string(),
1388 country: "DE".to_string(),
1389 annual_transaction_volume: 50000,
1390 volume_weight: 0.5,
1391 },
1392 CompanyConfigDto {
1393 code: "3000".to_string(),
1394 name: "APAC Sub".to_string(),
1395 currency: "JPY".to_string(),
1396 country: "JP".to_string(),
1397 annual_transaction_volume: 30000,
1398 volume_weight: 0.3,
1399 },
1400 ],
1401 fraud_enabled: true,
1402 fraud_rate: 0.02,
1403 };
1404 assert_eq!(config.companies.len(), 3);
1405 }
1406
1407 #[test]
1412 fn test_metrics_entries_per_second_calculation() {
1413 let total_entries: u64 = 1000;
1415 let uptime: u64 = 60;
1416 let eps = if uptime > 0 {
1417 total_entries as f64 / uptime as f64
1418 } else {
1419 0.0
1420 };
1421 assert!((eps - 16.67).abs() < 0.1);
1422 }
1423
1424 #[test]
1425 fn test_metrics_entries_per_second_zero_uptime() {
1426 let total_entries: u64 = 1000;
1427 let uptime: u64 = 0;
1428 let eps = if uptime > 0 {
1429 total_entries as f64 / uptime as f64
1430 } else {
1431 0.0
1432 };
1433 assert_eq!(eps, 0.0);
1434 }
1435}