1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Shared state handed to every REST handler through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Server-wide state taken from the gRPC `SynthService` (generation config,
    /// counters, stream control flags).
    pub server_state: Arc<ServerState>,
    /// Optional job queue; the `/api/jobs*` handlers answer `503 Service Unavailable`
    /// while this is `None`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Request-timeout settings for the HTTP router.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Upper bound, in seconds, on how long a single request may take.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Timeout used by [`TimeoutConfig::default`]: five minutes.
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Creates a config whose requests time out after `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        Self {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
/// Cross-origin (CORS) policy for the REST API.
#[derive(Clone)]
pub struct CorsConfig {
    /// Origins that may issue cross-origin requests when `allow_any_origin` is `false`.
    pub allowed_origins: Vec<String>,
    /// When `true`, the routers install a permissive CORS layer and ignore `allowed_origins`.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Allows `localhost` / `127.0.0.1` on ports 5173 and 3000, plus `tauri://localhost`.
    fn default() -> Self {
        const LOCAL_ORIGINS: [&str; 5] = [
            "http://localhost:5173",
            "http://localhost:3000",
            "http://127.0.0.1:5173",
            "http://127.0.0.1:3000",
            "tauri://localhost",
        ];

        let allowed_origins = LOCAL_ORIGINS
            .iter()
            .map(|origin| (*origin).to_owned())
            .collect();

        Self {
            allowed_origins,
            allow_any_origin: false,
        }
    }
}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/jobs/submit", post(submit_job))
186 .route("/api/jobs", get(list_jobs))
187 .route("/api/jobs/{id}", get(get_job))
188 .route("/api/jobs/{id}/cancel", post(cancel_job))
189 .route("/ws/metrics", get(websocket_metrics))
191 .route("/ws/events", get(websocket_events))
192 .layer(axum::middleware::from_fn(security_headers_middleware))
195 .layer(axum::middleware::map_response(api_version_header))
196 .layer(cors)
197 .layer(axum::middleware::from_fn(request_id_middleware))
198 .layer(axum::middleware::from_fn(auth_middleware))
199 .layer(axum::Extension(auth_config))
200 .layer(axum::middleware::from_fn(request_validation_middleware))
201 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
202 .layer(axum::Extension(rate_limit_backend))
203 .layer(TimeoutLayer::with_status_code(
204 StatusCode::REQUEST_TIMEOUT,
205 Duration::from_secs(timeout_config.request_timeout_secs),
206 ))
207 .with_state(state)
208}
209
210pub fn create_router_with_auth(
212 service: SynthService,
213 cors_config: CorsConfig,
214 auth_config: AuthConfig,
215) -> Router {
216 let server_state = service.state.clone();
217 let state = AppState {
218 server_state,
219 job_queue: None,
220 };
221
222 let cors = if cors_config.allow_any_origin {
223 CorsLayer::permissive()
224 } else {
225 let origins: Vec<_> = cors_config
226 .allowed_origins
227 .iter()
228 .filter_map(|o| o.parse().ok())
229 .collect();
230
231 CorsLayer::new()
232 .allow_origin(AllowOrigin::list(origins))
233 .allow_methods([
234 Method::GET,
235 Method::POST,
236 Method::PUT,
237 Method::DELETE,
238 Method::OPTIONS,
239 ])
240 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
241 };
242
243 Router::new()
244 .route("/health", get(health_check))
246 .route("/ready", get(readiness_check))
247 .route("/live", get(liveness_check))
248 .route("/api/metrics", get(get_metrics))
249 .route("/metrics", get(prometheus_metrics))
250 .route("/api/config", get(get_config))
252 .route("/api/config", post(set_config))
253 .route("/api/config/reload", post(reload_config))
254 .route("/api/generate/bulk", post(bulk_generate))
256 .route("/api/stream/start", post(start_stream))
257 .route("/api/stream/stop", post(stop_stream))
258 .route("/api/stream/pause", post(pause_stream))
259 .route("/api/stream/resume", post(resume_stream))
260 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
261 .route("/api/jobs/submit", post(submit_job))
263 .route("/api/jobs", get(list_jobs))
264 .route("/api/jobs/{id}", get(get_job))
265 .route("/api/jobs/{id}/cancel", post(cancel_job))
266 .route("/ws/metrics", get(websocket_metrics))
268 .route("/ws/events", get(websocket_events))
269 .layer(axum::middleware::from_fn(auth_middleware))
270 .layer(axum::Extension(auth_config))
271 .layer(cors)
272 .with_state(state)
273}
274
275pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
277 let server_state = service.state.clone();
278 let state = AppState {
279 server_state,
280 job_queue: None,
281 };
282
283 let cors = if cors_config.allow_any_origin {
284 CorsLayer::permissive()
286 } else {
287 let origins: Vec<_> = cors_config
289 .allowed_origins
290 .iter()
291 .filter_map(|o| o.parse().ok())
292 .collect();
293
294 CorsLayer::new()
295 .allow_origin(AllowOrigin::list(origins))
296 .allow_methods([
297 Method::GET,
298 Method::POST,
299 Method::PUT,
300 Method::DELETE,
301 Method::OPTIONS,
302 ])
303 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
304 };
305
306 Router::new()
307 .route("/health", get(health_check))
309 .route("/ready", get(readiness_check))
310 .route("/live", get(liveness_check))
311 .route("/api/metrics", get(get_metrics))
312 .route("/metrics", get(prometheus_metrics))
313 .route("/api/config", get(get_config))
315 .route("/api/config", post(set_config))
316 .route("/api/config/reload", post(reload_config))
317 .route("/api/generate/bulk", post(bulk_generate))
319 .route("/api/stream/start", post(start_stream))
320 .route("/api/stream/stop", post(stop_stream))
321 .route("/api/stream/pause", post(pause_stream))
322 .route("/api/stream/resume", post(resume_stream))
323 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
324 .route("/api/jobs/submit", post(submit_job))
326 .route("/api/jobs", get(list_jobs))
327 .route("/api/jobs/{id}", get(get_job))
328 .route("/api/jobs/{id}/cancel", post(cancel_job))
329 .route("/ws/metrics", get(websocket_metrics))
331 .route("/ws/events", get(websocket_events))
332 .layer(cors)
333 .with_state(state)
334}
335
/// Body of `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always `true` when produced by `health_check`.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION`).
    pub version: String,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
}

/// Body of `GET /ready`. Sent with `200 OK` when ready and with
/// `503 Service Unavailable` otherwise.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `true` when no individual check reported `"fail"`.
    pub ready: bool,
    /// Human-readable summary ("Service is ready" / "Service is not ready").
    pub message: String,
    /// Outcome of each individual readiness check.
    pub checks: Vec<HealthCheck>,
}

/// A single named readiness check.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check name, e.g. `"config"`, `"memory"` or `"disk"`.
    pub name: String,
    /// `"ok"`, `"degraded"` or `"fail"`.
    pub status: String,
}

/// Body of `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when produced by `liveness_check`.
    pub alive: bool,
    /// Current UTC time in RFC 3339 format.
    pub timestamp: String,
}
368
/// Body of `GET /api/metrics`: a JSON snapshot of the server counters.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Journal entries generated since start-up.
    pub total_entries_generated: u64,
    /// Anomalies injected since start-up.
    pub total_anomalies_injected: u64,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
    /// Entries in the current session; `get_metrics` currently reports the same
    /// value as `total_entries_generated`.
    pub session_entries: u64,
    /// `session_entries` divided by `uptime_seconds` (0 when uptime is 0).
    pub session_entries_per_second: f64,
    /// Number of active streams.
    pub active_streams: u32,
    /// Events sent through streams since start-up.
    pub total_stream_events: u64,
}
379
/// Response of the `/api/config` endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// Whether the request succeeded.
    pub success: bool,
    /// Human-readable status or error message.
    pub message: String,
    /// The configuration, when available (`None` on validation errors).
    pub config: Option<GenerationConfigDto>,
}

/// Wire representation of the generation configuration (used by both GET and POST).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector. `get_config` reports the enum's `Debug` name; `set_config`
    /// accepts it case-insensitively, including snake_case spellings.
    pub industry: String,
    /// Generation start date; NOTE(review): format is not validated here — presumably `YYYY-MM-DD`.
    pub start_date: String,
    /// Number of months to generate.
    pub period_months: u32,
    /// Optional seed; presumably makes generation reproducible — confirm.
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity: `small`, `medium` or `large` (case-insensitive on input).
    pub coa_complexity: String,
    /// Company definitions; an empty list on POST keeps the existing companies.
    pub companies: Vec<CompanyConfigDto>,
    /// Whether fraud injection is enabled.
    pub fraud_enabled: bool,
    /// Fraud rate; presumably a fraction in `0.0..=1.0` — not range-checked here.
    pub fraud_rate: f32,
}

/// Wire representation of one company in the generation configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code.
    pub code: String,
    /// Company display name.
    pub name: String,
    /// Currency code (e.g. `"USD"`).
    pub currency: String,
    /// Country code (e.g. `"US"`).
    pub country: String,
    /// Annual transaction volume as a plain count.
    pub annual_transaction_volume: u64,
    /// Relative volume weight of this company.
    pub volume_weight: f32,
}
408
/// Body of `POST /api/generate/bulk`.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Requested entry count. `bulk_generate` rejects values above 1,000,000.
    /// NOTE(review): it is otherwise not used by `bulk_generate`.
    pub entry_count: Option<u64>,
    /// Also generate master data (defaults to `false`).
    pub include_master_data: Option<bool>,
    /// Inject anomalies (defaults to `false`).
    pub inject_anomalies: Option<bool>,
}

/// Result of a bulk generation run.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Always `true` on the success path.
    pub success: bool,
    /// Number of journal entries produced.
    pub entries_generated: u64,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced.
    pub anomaly_count: u64,
}
423
424#[derive(Debug, Deserialize)]
425#[allow(dead_code)] pub struct StreamRequest {
427 pub events_per_second: Option<u32>,
428 pub max_events: Option<u64>,
429 pub inject_anomalies: Option<bool>,
430}
431
/// Generic response of the stream-control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the requested action was applied.
    pub success: bool,
    /// Human-readable outcome or error description.
    pub message: String,
}
437
438async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
444 Json(HealthResponse {
445 healthy: true,
446 version: env!("CARGO_PKG_VERSION").to_string(),
447 uptime_seconds: state.server_state.uptime_seconds(),
448 })
449}
450
451async fn readiness_check(
454 State(state): State<AppState>,
455) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
456 let mut checks = Vec::new();
457 let mut any_fail = false;
458
459 let config = state.server_state.config.read().await;
461 let config_valid = !config.companies.is_empty();
462 checks.push(HealthCheck {
463 name: "config".to_string(),
464 status: if config_valid { "ok" } else { "fail" }.to_string(),
465 });
466 if !config_valid {
467 any_fail = true;
468 }
469 drop(config);
470
471 let resource_status = state.server_state.resource_status();
473 let memory_status = if resource_status.degradation_level == "Emergency" {
474 any_fail = true;
475 "fail"
476 } else if resource_status.degradation_level != "Normal" {
477 "degraded"
478 } else {
479 "ok"
480 };
481 checks.push(HealthCheck {
482 name: "memory".to_string(),
483 status: memory_status.to_string(),
484 });
485
486 let disk_ok = resource_status.disk_available_mb > 100;
488 checks.push(HealthCheck {
489 name: "disk".to_string(),
490 status: if disk_ok { "ok" } else { "fail" }.to_string(),
491 });
492 if !disk_ok {
493 any_fail = true;
494 }
495
496 let response = ReadinessResponse {
497 ready: !any_fail,
498 message: if any_fail {
499 "Service is not ready".to_string()
500 } else {
501 "Service is ready".to_string()
502 },
503 checks,
504 };
505
506 if any_fail {
507 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
508 } else {
509 Ok(Json(response))
510 }
511}
512
513async fn liveness_check() -> Json<LivenessResponse> {
516 Json(LivenessResponse {
517 alive: true,
518 timestamp: chrono::Utc::now().to_rfc3339(),
519 })
520}
521
522async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
525 use std::sync::atomic::Ordering;
526
527 let uptime = state.server_state.uptime_seconds();
528 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
529 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
530 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
531 let total_stream_events = state
532 .server_state
533 .total_stream_events
534 .load(Ordering::Relaxed);
535
536 let entries_per_second = if uptime > 0 {
537 total_entries as f64 / uptime as f64
538 } else {
539 0.0
540 };
541
542 let metrics = format!(
543 r#"# HELP synth_entries_generated_total Total number of journal entries generated
544# TYPE synth_entries_generated_total counter
545synth_entries_generated_total {}
546
547# HELP synth_anomalies_injected_total Total number of anomalies injected
548# TYPE synth_anomalies_injected_total counter
549synth_anomalies_injected_total {}
550
551# HELP synth_uptime_seconds Server uptime in seconds
552# TYPE synth_uptime_seconds gauge
553synth_uptime_seconds {}
554
555# HELP synth_entries_per_second Rate of entry generation
556# TYPE synth_entries_per_second gauge
557synth_entries_per_second {:.2}
558
559# HELP synth_active_streams Number of active streaming connections
560# TYPE synth_active_streams gauge
561synth_active_streams {}
562
563# HELP synth_stream_events_total Total events sent through streams
564# TYPE synth_stream_events_total counter
565synth_stream_events_total {}
566
567# HELP synth_info Server version information
568# TYPE synth_info gauge
569synth_info{{version="{}"}} 1
570"#,
571 total_entries,
572 total_anomalies,
573 uptime,
574 entries_per_second,
575 active_streams,
576 total_stream_events,
577 env!("CARGO_PKG_VERSION")
578 );
579
580 (
581 StatusCode::OK,
582 [(
583 header::CONTENT_TYPE,
584 "text/plain; version=0.0.4; charset=utf-8",
585 )],
586 metrics,
587 )
588}
589
590async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
592 let uptime = state.server_state.uptime_seconds();
593 let total_entries = state
594 .server_state
595 .total_entries
596 .load(std::sync::atomic::Ordering::Relaxed);
597
598 let entries_per_second = if uptime > 0 {
599 total_entries as f64 / uptime as f64
600 } else {
601 0.0
602 };
603
604 Json(MetricsResponse {
605 total_entries_generated: total_entries,
606 total_anomalies_injected: state
607 .server_state
608 .total_anomalies
609 .load(std::sync::atomic::Ordering::Relaxed),
610 uptime_seconds: uptime,
611 session_entries: total_entries,
612 session_entries_per_second: entries_per_second,
613 active_streams: state
614 .server_state
615 .active_streams
616 .load(std::sync::atomic::Ordering::Relaxed) as u32,
617 total_stream_events: state
618 .server_state
619 .total_stream_events
620 .load(std::sync::atomic::Ordering::Relaxed),
621 })
622}
623
624async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
626 let config = state.server_state.config.read().await;
627
628 Json(ConfigResponse {
629 success: true,
630 message: "Current configuration".to_string(),
631 config: Some(GenerationConfigDto {
632 industry: format!("{:?}", config.global.industry),
633 start_date: config.global.start_date.clone(),
634 period_months: config.global.period_months,
635 seed: config.global.seed,
636 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
637 companies: config
638 .companies
639 .iter()
640 .map(|c| CompanyConfigDto {
641 code: c.code.clone(),
642 name: c.name.clone(),
643 currency: c.currency.clone(),
644 country: c.country.clone(),
645 annual_transaction_volume: c.annual_transaction_volume.count(),
646 volume_weight: c.volume_weight as f32,
647 })
648 .collect(),
649 fraud_enabled: config.fraud.enabled,
650 fraud_rate: config.fraud.fraud_rate as f32,
651 }),
652 })
653}
654
655async fn set_config(
657 State(state): State<AppState>,
658 Json(new_config): Json<GenerationConfigDto>,
659) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
660 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
661 use datasynth_core::models::{CoAComplexity, IndustrySector};
662
663 info!(
664 "Configuration update requested: industry={}, period_months={}",
665 new_config.industry, new_config.period_months
666 );
667
668 let industry = match new_config.industry.to_lowercase().as_str() {
670 "manufacturing" => IndustrySector::Manufacturing,
671 "retail" => IndustrySector::Retail,
672 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
673 "healthcare" => IndustrySector::Healthcare,
674 "technology" => IndustrySector::Technology,
675 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
676 "energy" => IndustrySector::Energy,
677 "transportation" => IndustrySector::Transportation,
678 "real_estate" | "realestate" => IndustrySector::RealEstate,
679 "telecommunications" => IndustrySector::Telecommunications,
680 _ => {
681 return Err((
682 StatusCode::BAD_REQUEST,
683 Json(ConfigResponse {
684 success: false,
685 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
686 config: None,
687 }),
688 ));
689 }
690 };
691
692 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
694 "small" => CoAComplexity::Small,
695 "medium" => CoAComplexity::Medium,
696 "large" => CoAComplexity::Large,
697 _ => {
698 return Err((
699 StatusCode::BAD_REQUEST,
700 Json(ConfigResponse {
701 success: false,
702 message: format!(
703 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
704 new_config.coa_complexity
705 ),
706 config: None,
707 }),
708 ));
709 }
710 };
711
712 let companies: Vec<CompanyConfig> = new_config
714 .companies
715 .iter()
716 .map(|c| CompanyConfig {
717 code: c.code.clone(),
718 name: c.name.clone(),
719 currency: c.currency.clone(),
720 functional_currency: None,
721 country: c.country.clone(),
722 fiscal_year_variant: "K4".to_string(),
723 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
724 volume_weight: c.volume_weight as f64,
725 })
726 .collect();
727
728 let mut config = state.server_state.config.write().await;
730 config.global.industry = industry;
731 config.global.start_date = new_config.start_date.clone();
732 config.global.period_months = new_config.period_months;
733 config.global.seed = new_config.seed;
734 config.chart_of_accounts.complexity = complexity;
735 config.fraud.enabled = new_config.fraud_enabled;
736 config.fraud.fraud_rate = new_config.fraud_rate as f64;
737
738 if !companies.is_empty() {
740 config.companies = companies;
741 }
742
743 info!("Configuration updated successfully");
744
745 Ok(Json(ConfigResponse {
746 success: true,
747 message: "Configuration updated and applied".to_string(),
748 config: Some(new_config),
749 }))
750}
751
752async fn bulk_generate(
754 State(state): State<AppState>,
755 Json(req): Json<BulkGenerateRequest>,
756) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
757 const MAX_ENTRY_COUNT: u64 = 1_000_000;
759 if let Some(count) = req.entry_count {
760 if count > MAX_ENTRY_COUNT {
761 return Err((
762 StatusCode::BAD_REQUEST,
763 format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
764 ));
765 }
766 }
767
768 let config = state.server_state.config.read().await.clone();
769 let start_time = std::time::Instant::now();
770
771 let phase_config = PhaseConfig {
772 generate_master_data: req.include_master_data.unwrap_or(false),
773 generate_document_flows: false,
774 generate_journal_entries: true,
775 inject_anomalies: req.inject_anomalies.unwrap_or(false),
776 show_progress: false,
777 ..Default::default()
778 };
779
780 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
781 (
782 StatusCode::INTERNAL_SERVER_ERROR,
783 format!("Failed to create orchestrator: {e}"),
784 )
785 })?;
786
787 let result = orchestrator.generate().map_err(|e| {
788 (
789 StatusCode::INTERNAL_SERVER_ERROR,
790 format!("Generation failed: {e}"),
791 )
792 })?;
793
794 let duration_ms = start_time.elapsed().as_millis() as u64;
795 let entries_count = result.journal_entries.len() as u64;
796 let anomaly_count = result.anomaly_labels.labels.len() as u64;
797
798 state
800 .server_state
801 .total_entries
802 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
803 state
804 .server_state
805 .total_anomalies
806 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
807
808 Ok(Json(BulkGenerateResponse {
809 success: true,
810 entries_generated: entries_count,
811 duration_ms,
812 anomaly_count,
813 }))
814}
815
816async fn start_stream(
818 State(state): State<AppState>,
819 Json(req): Json<StreamRequest>,
820) -> Json<StreamResponse> {
821 if let Some(eps) = req.events_per_second {
823 info!("Stream configured: events_per_second={}", eps);
824 state
825 .server_state
826 .stream_events_per_second
827 .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
828 }
829 if let Some(max) = req.max_events {
830 info!("Stream configured: max_events={}", max);
831 state
832 .server_state
833 .stream_max_events
834 .store(max, std::sync::atomic::Ordering::Relaxed);
835 }
836 if let Some(inject) = req.inject_anomalies {
837 info!("Stream configured: inject_anomalies={}", inject);
838 state
839 .server_state
840 .stream_inject_anomalies
841 .store(inject, std::sync::atomic::Ordering::Relaxed);
842 }
843
844 state
845 .server_state
846 .stream_stopped
847 .store(false, std::sync::atomic::Ordering::Relaxed);
848 state
849 .server_state
850 .stream_paused
851 .store(false, std::sync::atomic::Ordering::Relaxed);
852
853 Json(StreamResponse {
854 success: true,
855 message: "Stream started".to_string(),
856 })
857}
858
859async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
861 state
862 .server_state
863 .stream_stopped
864 .store(true, std::sync::atomic::Ordering::Relaxed);
865
866 Json(StreamResponse {
867 success: true,
868 message: "Stream stopped".to_string(),
869 })
870}
871
872async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
874 state
875 .server_state
876 .stream_paused
877 .store(true, std::sync::atomic::Ordering::Relaxed);
878
879 Json(StreamResponse {
880 success: true,
881 message: "Stream paused".to_string(),
882 })
883}
884
885async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
887 state
888 .server_state
889 .stream_paused
890 .store(false, std::sync::atomic::Ordering::Relaxed);
891
892 Json(StreamResponse {
893 success: true,
894 message: "Stream resumed".to_string(),
895 })
896}
897
898async fn trigger_pattern(
903 State(state): State<AppState>,
904 axum::extract::Path(pattern): axum::extract::Path<String>,
905) -> Json<StreamResponse> {
906 info!("Pattern trigger requested: {}", pattern);
907
908 let valid_patterns = [
910 "year_end_spike",
911 "period_end_spike",
912 "holiday_cluster",
913 "fraud_cluster",
914 "error_cluster",
915 "uniform",
916 ];
917
918 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
919
920 if !is_valid {
921 return Json(StreamResponse {
922 success: false,
923 message: format!(
924 "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
925 ),
926 });
927 }
928
929 match state.server_state.triggered_pattern.try_write() {
931 Ok(mut triggered) => {
932 *triggered = Some(pattern.clone());
933 Json(StreamResponse {
934 success: true,
935 message: format!("Pattern '{pattern}' will be applied to upcoming entries"),
936 })
937 }
938 Err(_) => Json(StreamResponse {
939 success: false,
940 message: "Failed to acquire lock for pattern trigger".to_string(),
941 }),
942 }
943}
944
945async fn websocket_metrics(
947 ws: WebSocketUpgrade,
948 State(state): State<AppState>,
949) -> impl IntoResponse {
950 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
951}
952
953async fn websocket_events(
955 ws: WebSocketUpgrade,
956 State(state): State<AppState>,
957) -> impl IntoResponse {
958 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
959}
960
961async fn submit_job(
967 State(state): State<AppState>,
968 Json(request): Json<JobRequest>,
969) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
970 let queue = state.job_queue.as_ref().ok_or_else(|| {
971 (
972 StatusCode::SERVICE_UNAVAILABLE,
973 Json(serde_json::json!({"error": "Job queue not enabled"})),
974 )
975 })?;
976
977 let job_id = queue.submit(request).await;
978 info!("Job submitted: {}", job_id);
979
980 Ok((
981 StatusCode::CREATED,
982 Json(serde_json::json!({
983 "id": job_id.to_string(),
984 "status": "queued"
985 })),
986 ))
987}
988
989async fn get_job(
991 State(state): State<AppState>,
992 axum::extract::Path(id): axum::extract::Path<String>,
993) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
994 let queue = state.job_queue.as_ref().ok_or_else(|| {
995 (
996 StatusCode::SERVICE_UNAVAILABLE,
997 Json(serde_json::json!({"error": "Job queue not enabled"})),
998 )
999 })?;
1000
1001 match queue.get(&id).await {
1002 Some(entry) => Ok(Json(serde_json::json!({
1003 "id": entry.id,
1004 "status": format!("{:?}", entry.status).to_lowercase(),
1005 "submitted_at": entry.submitted_at.to_rfc3339(),
1006 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1007 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1008 "result": entry.result,
1009 }))),
1010 None => Err((
1011 StatusCode::NOT_FOUND,
1012 Json(serde_json::json!({"error": "Job not found"})),
1013 )),
1014 }
1015}
1016
1017async fn list_jobs(
1019 State(state): State<AppState>,
1020) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1021 let queue = state.job_queue.as_ref().ok_or_else(|| {
1022 (
1023 StatusCode::SERVICE_UNAVAILABLE,
1024 Json(serde_json::json!({"error": "Job queue not enabled"})),
1025 )
1026 })?;
1027
1028 let summaries: Vec<_> = queue
1029 .list()
1030 .await
1031 .into_iter()
1032 .map(|s| {
1033 serde_json::json!({
1034 "id": s.id,
1035 "status": format!("{:?}", s.status).to_lowercase(),
1036 "submitted_at": s.submitted_at.to_rfc3339(),
1037 })
1038 })
1039 .collect();
1040
1041 Ok(Json(serde_json::json!({ "jobs": summaries })))
1042}
1043
1044async fn cancel_job(
1046 State(state): State<AppState>,
1047 axum::extract::Path(id): axum::extract::Path<String>,
1048) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1049 let queue = state.job_queue.as_ref().ok_or_else(|| {
1050 (
1051 StatusCode::SERVICE_UNAVAILABLE,
1052 Json(serde_json::json!({"error": "Job queue not enabled"})),
1053 )
1054 })?;
1055
1056 if queue.cancel(&id).await {
1057 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1058 } else {
1059 Err((
1060 StatusCode::CONFLICT,
1061 Json(
1062 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1063 ),
1064 ))
1065 }
1066}
1067
1068async fn reload_config(
1074 State(state): State<AppState>,
1075) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1076 let source = state.server_state.config_source.read().await.clone();
1077 match crate::config_loader::load_config(&source).await {
1078 Ok(new_config) => {
1079 let mut config = state.server_state.config.write().await;
1080 *config = new_config;
1081 info!("Configuration reloaded via REST API from {:?}", source);
1082 Ok(Json(serde_json::json!({
1083 "success": true,
1084 "message": "Configuration reloaded"
1085 })))
1086 }
1087 Err(e) => {
1088 error!("Failed to reload configuration: {}", e);
1089 Err((
1090 StatusCode::INTERNAL_SERVER_ERROR,
1091 Json(serde_json::json!({
1092 "success": false,
1093 "message": format!("Failed to reload configuration: {}", e)
1094 })),
1095 ))
1096 }
1097 }
1098}
1099
1100#[cfg(test)]
1101#[allow(clippy::unwrap_used)]
1102mod tests {
1103 use super::*;
1104
1105 #[test]
1110 fn test_health_response_serialization() {
1111 let response = HealthResponse {
1112 healthy: true,
1113 version: "0.1.0".to_string(),
1114 uptime_seconds: 100,
1115 };
1116 let json = serde_json::to_string(&response).unwrap();
1117 assert!(json.contains("healthy"));
1118 assert!(json.contains("version"));
1119 assert!(json.contains("uptime_seconds"));
1120 }
1121
1122 #[test]
1123 fn test_health_response_deserialization() {
1124 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1125 let response: HealthResponse = serde_json::from_str(json).unwrap();
1126 assert!(response.healthy);
1127 assert_eq!(response.version, "0.1.0");
1128 assert_eq!(response.uptime_seconds, 100);
1129 }
1130
1131 #[test]
1132 fn test_metrics_response_serialization() {
1133 let response = MetricsResponse {
1134 total_entries_generated: 1000,
1135 total_anomalies_injected: 10,
1136 uptime_seconds: 60,
1137 session_entries: 1000,
1138 session_entries_per_second: 16.67,
1139 active_streams: 1,
1140 total_stream_events: 500,
1141 };
1142 let json = serde_json::to_string(&response).unwrap();
1143 assert!(json.contains("total_entries_generated"));
1144 assert!(json.contains("session_entries_per_second"));
1145 }
1146
1147 #[test]
1148 fn test_metrics_response_deserialization() {
1149 let json = r#"{
1150 "total_entries_generated": 5000,
1151 "total_anomalies_injected": 50,
1152 "uptime_seconds": 300,
1153 "session_entries": 5000,
1154 "session_entries_per_second": 16.67,
1155 "active_streams": 2,
1156 "total_stream_events": 10000
1157 }"#;
1158 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1159 assert_eq!(response.total_entries_generated, 5000);
1160 assert_eq!(response.active_streams, 2);
1161 }
1162
1163 #[test]
1164 fn test_config_response_serialization() {
1165 let response = ConfigResponse {
1166 success: true,
1167 message: "Configuration loaded".to_string(),
1168 config: Some(GenerationConfigDto {
1169 industry: "manufacturing".to_string(),
1170 start_date: "2024-01-01".to_string(),
1171 period_months: 12,
1172 seed: Some(42),
1173 coa_complexity: "medium".to_string(),
1174 companies: vec![],
1175 fraud_enabled: false,
1176 fraud_rate: 0.0,
1177 }),
1178 };
1179 let json = serde_json::to_string(&response).unwrap();
1180 assert!(json.contains("success"));
1181 assert!(json.contains("config"));
1182 }
1183
1184 #[test]
1185 fn test_config_response_without_config() {
1186 let response = ConfigResponse {
1187 success: false,
1188 message: "No configuration available".to_string(),
1189 config: None,
1190 };
1191 let json = serde_json::to_string(&response).unwrap();
1192 assert!(json.contains("null") || json.contains("config\":null"));
1193 }
1194
1195 #[test]
1196 fn test_generation_config_dto_roundtrip() {
1197 let original = GenerationConfigDto {
1198 industry: "retail".to_string(),
1199 start_date: "2024-06-01".to_string(),
1200 period_months: 6,
1201 seed: Some(12345),
1202 coa_complexity: "large".to_string(),
1203 companies: vec![CompanyConfigDto {
1204 code: "1000".to_string(),
1205 name: "Test Corp".to_string(),
1206 currency: "USD".to_string(),
1207 country: "US".to_string(),
1208 annual_transaction_volume: 100000,
1209 volume_weight: 1.0,
1210 }],
1211 fraud_enabled: true,
1212 fraud_rate: 0.05,
1213 };
1214
1215 let json = serde_json::to_string(&original).unwrap();
1216 let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();
1217
1218 assert_eq!(original.industry, deserialized.industry);
1219 assert_eq!(original.seed, deserialized.seed);
1220 assert_eq!(original.companies.len(), deserialized.companies.len());
1221 }
1222
1223 #[test]
1224 fn test_company_config_dto_serialization() {
1225 let company = CompanyConfigDto {
1226 code: "2000".to_string(),
1227 name: "European Subsidiary".to_string(),
1228 currency: "EUR".to_string(),
1229 country: "DE".to_string(),
1230 annual_transaction_volume: 50000,
1231 volume_weight: 0.5,
1232 };
1233 let json = serde_json::to_string(&company).unwrap();
1234 assert!(json.contains("2000"));
1235 assert!(json.contains("EUR"));
1236 assert!(json.contains("DE"));
1237 }
1238
1239 #[test]
1240 fn test_bulk_generate_request_deserialization() {
1241 let json = r#"{
1242 "entry_count": 5000,
1243 "include_master_data": true,
1244 "inject_anomalies": true
1245 }"#;
1246 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1247 assert_eq!(request.entry_count, Some(5000));
1248 assert_eq!(request.include_master_data, Some(true));
1249 assert_eq!(request.inject_anomalies, Some(true));
1250 }
1251
1252 #[test]
1253 fn test_bulk_generate_request_with_defaults() {
1254 let json = r#"{}"#;
1255 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1256 assert_eq!(request.entry_count, None);
1257 assert_eq!(request.include_master_data, None);
1258 assert_eq!(request.inject_anomalies, None);
1259 }
1260
1261 #[test]
1262 fn test_bulk_generate_response_serialization() {
1263 let response = BulkGenerateResponse {
1264 success: true,
1265 entries_generated: 1000,
1266 duration_ms: 250,
1267 anomaly_count: 20,
1268 };
1269 let json = serde_json::to_string(&response).unwrap();
1270 assert!(json.contains("entries_generated"));
1271 assert!(json.contains("1000"));
1272 assert!(json.contains("duration_ms"));
1273 }
1274
1275 #[test]
1276 fn test_stream_response_serialization() {
1277 let response = StreamResponse {
1278 success: true,
1279 message: "Stream started successfully".to_string(),
1280 };
1281 let json = serde_json::to_string(&response).unwrap();
1282 assert!(json.contains("success"));
1283 assert!(json.contains("Stream started"));
1284 }
1285
1286 #[test]
1287 fn test_stream_response_failure() {
1288 let response = StreamResponse {
1289 success: false,
1290 message: "Stream failed to start".to_string(),
1291 };
1292 let json = serde_json::to_string(&response).unwrap();
1293 assert!(json.contains("false"));
1294 assert!(json.contains("failed"));
1295 }
1296
1297 #[test]
1302 fn test_cors_config_default() {
1303 let config = CorsConfig::default();
1304 assert!(!config.allow_any_origin);
1305 assert!(!config.allowed_origins.is_empty());
1306 assert!(config
1307 .allowed_origins
1308 .contains(&"http://localhost:5173".to_string()));
1309 assert!(config
1310 .allowed_origins
1311 .contains(&"tauri://localhost".to_string()));
1312 }
1313
1314 #[test]
1315 fn test_cors_config_custom_origins() {
1316 let config = CorsConfig {
1317 allowed_origins: vec![
1318 "https://example.com".to_string(),
1319 "https://app.example.com".to_string(),
1320 ],
1321 allow_any_origin: false,
1322 };
1323 assert_eq!(config.allowed_origins.len(), 2);
1324 assert!(config
1325 .allowed_origins
1326 .contains(&"https://example.com".to_string()));
1327 }
1328
1329 #[test]
1330 fn test_cors_config_permissive() {
1331 let config = CorsConfig {
1332 allowed_origins: vec![],
1333 allow_any_origin: true,
1334 };
1335 assert!(config.allow_any_origin);
1336 }
1337
1338 #[test]
1343 fn test_bulk_generate_request_partial() {
1344 let json = r#"{"entry_count": 100}"#;
1345 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1346 assert_eq!(request.entry_count, Some(100));
1347 assert!(request.include_master_data.is_none());
1348 }
1349
1350 #[test]
1351 fn test_generation_config_no_seed() {
1352 let config = GenerationConfigDto {
1353 industry: "technology".to_string(),
1354 start_date: "2024-01-01".to_string(),
1355 period_months: 3,
1356 seed: None,
1357 coa_complexity: "small".to_string(),
1358 companies: vec![],
1359 fraud_enabled: false,
1360 fraud_rate: 0.0,
1361 };
1362 let json = serde_json::to_string(&config).unwrap();
1363 assert!(json.contains("seed"));
1364 }
1365
1366 #[test]
1367 fn test_generation_config_multiple_companies() {
1368 let config = GenerationConfigDto {
1369 industry: "manufacturing".to_string(),
1370 start_date: "2024-01-01".to_string(),
1371 period_months: 12,
1372 seed: Some(42),
1373 coa_complexity: "large".to_string(),
1374 companies: vec![
1375 CompanyConfigDto {
1376 code: "1000".to_string(),
1377 name: "Headquarters".to_string(),
1378 currency: "USD".to_string(),
1379 country: "US".to_string(),
1380 annual_transaction_volume: 100000,
1381 volume_weight: 1.0,
1382 },
1383 CompanyConfigDto {
1384 code: "2000".to_string(),
1385 name: "European Sub".to_string(),
1386 currency: "EUR".to_string(),
1387 country: "DE".to_string(),
1388 annual_transaction_volume: 50000,
1389 volume_weight: 0.5,
1390 },
1391 CompanyConfigDto {
1392 code: "3000".to_string(),
1393 name: "APAC Sub".to_string(),
1394 currency: "JPY".to_string(),
1395 country: "JP".to_string(),
1396 annual_transaction_volume: 30000,
1397 volume_weight: 0.3,
1398 },
1399 ],
1400 fraud_enabled: true,
1401 fraud_rate: 0.02,
1402 };
1403 assert_eq!(config.companies.len(), 3);
1404 }
1405
1406 #[test]
1411 fn test_metrics_entries_per_second_calculation() {
1412 let total_entries: u64 = 1000;
1414 let uptime: u64 = 60;
1415 let eps = if uptime > 0 {
1416 total_entries as f64 / uptime as f64
1417 } else {
1418 0.0
1419 };
1420 assert!((eps - 16.67).abs() < 0.1);
1421 }
1422
1423 #[test]
1424 fn test_metrics_entries_per_second_zero_uptime() {
1425 let total_entries: u64 = 1000;
1426 let uptime: u64 = 0;
1427 let eps = if uptime > 0 {
1428 total_entries as f64 / uptime as f64
1429 } else {
1430 0.0
1431 };
1432 assert_eq!(eps, 0.0);
1433 }
1434}