1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::info;
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
24#[derive(Clone)]
26pub struct AppState {
27 #[allow(dead_code)] pub service: Arc<SynthService>,
29 pub server_state: Arc<ServerState>,
30 pub job_queue: Option<Arc<JobQueue>>,
31}
32
/// Per-request timeout applied by [`create_router_full`] and
/// [`create_router_full_with_backend`].
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Maximum time, in seconds, a request may take before the timeout layer fails it.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Timeout used by the `Default` implementation: five minutes.
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Builds a configuration whose request timeout is `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    /// A five-minute (300 s) request timeout.
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
57
/// Cross-origin (CORS) policy used by the router builders.
#[derive(Clone)]
pub struct CorsConfig {
    /// Exact origins allowed to call the API, e.g. `http://localhost:5173`.
    /// Entries that do not parse as a header value are skipped by the router builders.
    pub allowed_origins: Vec<String>,
    /// When `true`, a permissive policy is used and `allowed_origins` is ignored.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Allows ports 5173 and 3000 on both `localhost` and `127.0.0.1`, plus
    /// `tauri://localhost`; any-origin mode is off.
    fn default() -> Self {
        const DEFAULT_ORIGINS: [&str; 5] = [
            "http://localhost:5173",
            "http://localhost:3000",
            "http://127.0.0.1:5173",
            "http://127.0.0.1:3000",
            "tauri://localhost",
        ];

        let allowed_origins = DEFAULT_ORIGINS
            .iter()
            .map(|origin| origin.to_string())
            .collect();

        CorsConfig {
            allowed_origins,
            allow_any_origin: false,
        }
    }
}
81
82async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
84 let (mut parts, body) = response.into_parts();
85 parts.headers.insert(
86 axum::http::HeaderName::from_static("x-api-version"),
87 axum::http::HeaderValue::from_static("v1"),
88 );
89 axum::response::Response::from_parts(parts, body)
90}
91
92use super::auth::{auth_middleware, AuthConfig};
93use super::rate_limit::RateLimitConfig;
94use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
95use super::request_id::request_id_middleware;
96use super::request_validation::request_validation_middleware;
97use super::security_headers::security_headers_middleware;
98
99pub fn create_router(service: SynthService) -> Router {
101 create_router_with_cors(service, CorsConfig::default())
102}
103
104pub fn create_router_full(
109 service: SynthService,
110 cors_config: CorsConfig,
111 auth_config: AuthConfig,
112 rate_limit_config: RateLimitConfig,
113 timeout_config: TimeoutConfig,
114) -> Router {
115 let backend = RateLimitBackend::in_memory(rate_limit_config);
116 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
117}
118
119pub fn create_router_full_with_backend(
135 service: SynthService,
136 cors_config: CorsConfig,
137 auth_config: AuthConfig,
138 rate_limit_backend: RateLimitBackend,
139 timeout_config: TimeoutConfig,
140) -> Router {
141 let server_state = service.state.clone();
142 let state = AppState {
143 service: Arc::new(service),
144 server_state,
145 job_queue: None,
146 };
147
148 let cors = if cors_config.allow_any_origin {
149 CorsLayer::permissive()
150 } else {
151 let origins: Vec<_> = cors_config
152 .allowed_origins
153 .iter()
154 .filter_map(|o| o.parse().ok())
155 .collect();
156
157 CorsLayer::new()
158 .allow_origin(AllowOrigin::list(origins))
159 .allow_methods([
160 Method::GET,
161 Method::POST,
162 Method::PUT,
163 Method::DELETE,
164 Method::OPTIONS,
165 ])
166 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
167 };
168
169 Router::new()
170 .route("/health", get(health_check))
172 .route("/ready", get(readiness_check))
173 .route("/live", get(liveness_check))
174 .route("/api/metrics", get(get_metrics))
175 .route("/metrics", get(prometheus_metrics))
176 .route("/api/config", get(get_config))
178 .route("/api/config", post(set_config))
179 .route("/api/config/reload", post(reload_config))
180 .route("/api/generate/bulk", post(bulk_generate))
182 .route("/api/stream/start", post(start_stream))
183 .route("/api/stream/stop", post(stop_stream))
184 .route("/api/stream/pause", post(pause_stream))
185 .route("/api/stream/resume", post(resume_stream))
186 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
187 .route("/api/jobs/submit", post(submit_job))
189 .route("/api/jobs", get(list_jobs))
190 .route("/api/jobs/{id}", get(get_job))
191 .route("/api/jobs/{id}/cancel", post(cancel_job))
192 .route("/ws/metrics", get(websocket_metrics))
194 .route("/ws/events", get(websocket_events))
195 .layer(axum::middleware::from_fn(security_headers_middleware))
198 .layer(axum::middleware::map_response(api_version_header))
199 .layer(cors)
200 .layer(axum::middleware::from_fn(request_id_middleware))
201 .layer(axum::middleware::from_fn(auth_middleware))
202 .layer(axum::Extension(auth_config))
203 .layer(axum::middleware::from_fn(request_validation_middleware))
204 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
205 .layer(axum::Extension(rate_limit_backend))
206 .layer(TimeoutLayer::new(Duration::from_secs(
207 timeout_config.request_timeout_secs,
208 )))
209 .with_state(state)
210}
211
212pub fn create_router_with_auth(
214 service: SynthService,
215 cors_config: CorsConfig,
216 auth_config: AuthConfig,
217) -> Router {
218 let server_state = service.state.clone();
219 let state = AppState {
220 service: Arc::new(service),
221 server_state,
222 job_queue: None,
223 };
224
225 let cors = if cors_config.allow_any_origin {
226 CorsLayer::permissive()
227 } else {
228 let origins: Vec<_> = cors_config
229 .allowed_origins
230 .iter()
231 .filter_map(|o| o.parse().ok())
232 .collect();
233
234 CorsLayer::new()
235 .allow_origin(AllowOrigin::list(origins))
236 .allow_methods([
237 Method::GET,
238 Method::POST,
239 Method::PUT,
240 Method::DELETE,
241 Method::OPTIONS,
242 ])
243 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
244 };
245
246 Router::new()
247 .route("/health", get(health_check))
249 .route("/ready", get(readiness_check))
250 .route("/live", get(liveness_check))
251 .route("/api/metrics", get(get_metrics))
252 .route("/metrics", get(prometheus_metrics))
253 .route("/api/config", get(get_config))
255 .route("/api/config", post(set_config))
256 .route("/api/config/reload", post(reload_config))
257 .route("/api/generate/bulk", post(bulk_generate))
259 .route("/api/stream/start", post(start_stream))
260 .route("/api/stream/stop", post(stop_stream))
261 .route("/api/stream/pause", post(pause_stream))
262 .route("/api/stream/resume", post(resume_stream))
263 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
264 .route("/api/jobs/submit", post(submit_job))
266 .route("/api/jobs", get(list_jobs))
267 .route("/api/jobs/{id}", get(get_job))
268 .route("/api/jobs/{id}/cancel", post(cancel_job))
269 .route("/ws/metrics", get(websocket_metrics))
271 .route("/ws/events", get(websocket_events))
272 .layer(axum::middleware::from_fn(auth_middleware))
273 .layer(axum::Extension(auth_config))
274 .layer(cors)
275 .with_state(state)
276}
277
278pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
280 let server_state = service.state.clone();
281 let state = AppState {
282 service: Arc::new(service),
283 server_state,
284 job_queue: None,
285 };
286
287 let cors = if cors_config.allow_any_origin {
288 CorsLayer::permissive()
290 } else {
291 let origins: Vec<_> = cors_config
293 .allowed_origins
294 .iter()
295 .filter_map(|o| o.parse().ok())
296 .collect();
297
298 CorsLayer::new()
299 .allow_origin(AllowOrigin::list(origins))
300 .allow_methods([
301 Method::GET,
302 Method::POST,
303 Method::PUT,
304 Method::DELETE,
305 Method::OPTIONS,
306 ])
307 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
308 };
309
310 Router::new()
311 .route("/health", get(health_check))
313 .route("/ready", get(readiness_check))
314 .route("/live", get(liveness_check))
315 .route("/api/metrics", get(get_metrics))
316 .route("/metrics", get(prometheus_metrics))
317 .route("/api/config", get(get_config))
319 .route("/api/config", post(set_config))
320 .route("/api/config/reload", post(reload_config))
321 .route("/api/generate/bulk", post(bulk_generate))
323 .route("/api/stream/start", post(start_stream))
324 .route("/api/stream/stop", post(stop_stream))
325 .route("/api/stream/pause", post(pause_stream))
326 .route("/api/stream/resume", post(resume_stream))
327 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
328 .route("/api/jobs/submit", post(submit_job))
330 .route("/api/jobs", get(list_jobs))
331 .route("/api/jobs/{id}", get(get_job))
332 .route("/api/jobs/{id}/cancel", post(cancel_job))
333 .route("/ws/metrics", get(websocket_metrics))
335 .route("/ws/events", get(websocket_events))
336 .layer(cors)
337 .with_state(state)
338}
339
/// Body of `GET /health` (produced by `health_check`).
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// `health_check` always sets this to `true`; it never reports `false`.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION`) captured at compile time.
    pub version: String,
    /// Server uptime in seconds, as reported by `ServerState::uptime_seconds()`.
    pub uptime_seconds: u64,
}

/// Body of `GET /ready`. It is sent with 200 when `ready` is `true` and 503 otherwise.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `false` when any entry in `checks` has status `"fail"`.
    pub ready: bool,
    /// Summary text: "Service is ready" or "Service is not ready".
    pub message: String,
    /// Individual results; `readiness_check` reports `config`, `memory` and `disk`.
    pub checks: Vec<HealthCheck>,
}

/// Outcome of a single readiness check.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check identifier, e.g. `"config"`, `"memory"` or `"disk"`.
    pub name: String,
    /// `"ok"`, `"fail"`, or (for the memory check only) `"degraded"`.
    pub status: String,
}

/// Body of `GET /live` (produced by `liveness_check`).
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when produced by `liveness_check`.
    pub alive: bool,
    /// Current UTC time formatted as RFC 3339.
    pub timestamp: String,
}
372
/// Body of `GET /api/metrics`: server-wide counters as JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Journal entries generated since start-up (`ServerState::total_entries`).
    pub total_entries_generated: u64,
    /// Anomalies injected since start-up (`ServerState::total_anomalies`).
    pub total_anomalies_injected: u64,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
    /// Currently the same value as `total_entries_generated` (see `get_metrics`).
    pub session_entries: u64,
    /// Total entries divided by uptime seconds; 0.0 while uptime is 0.
    pub session_entries_per_second: f64,
    /// Current value of `ServerState::active_streams`, narrowed to `u32`.
    pub active_streams: u32,
    /// Value of `ServerState::total_stream_events`.
    pub total_stream_events: u64,
}
383
/// Envelope returned by `GET /api/config` and `POST /api/config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// `false` when a `POST` was rejected.
    pub success: bool,
    /// Human-readable outcome or validation error.
    pub message: String,
    /// The configuration; `None` when a `POST` was rejected.
    pub config: Option<GenerationConfigDto>,
}

/// REST (wire) representation of the generator configuration.
///
/// Enum-valued fields differ by direction:
/// - `GET` renders them with their Rust `Debug` names (e.g. `Manufacturing`).
/// - `POST` matches them case-insensitively (see `set_config` for the accepted spellings).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector name (maps to `config.global.industry`).
    pub industry: String,
    /// Start date string copied verbatim to `config.global.start_date`
    /// (the tests use `YYYY-MM-DD`; NOTE(review): not validated by this module).
    pub start_date: String,
    /// Period length in months (`config.global.period_months`).
    pub period_months: u32,
    /// Optional seed (`config.global.seed`).
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity: small, medium or large.
    pub coa_complexity: String,
    /// Companies to generate for. On `POST`, an empty list keeps the existing companies.
    pub companies: Vec<CompanyConfigDto>,
    /// Maps to `config.fraud.enabled`.
    pub fraud_enabled: bool,
    /// Maps to `config.fraud.fraud_rate`. That field is `f64`, so precision is lost as `f32`.
    pub fraud_rate: f32,
}

/// One company entry inside [`GenerationConfigDto`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code, e.g. `"1000"`.
    pub code: String,
    /// Display name of the company.
    pub name: String,
    /// Currency code, e.g. `"USD"`.
    pub currency: String,
    /// Country code, e.g. `"US"`.
    pub country: String,
    /// Yearly transaction count. On `POST` it is stored as `TransactionVolume::Custom`.
    pub annual_transaction_volume: u64,
    /// Relative volume weight; converted to/from the config's `f64`.
    pub volume_weight: f32,
}
412
/// Body of `POST /api/generate/bulk`. Every field is optional.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Requested entry count; `bulk_generate` rejects values above 1,000,000.
    /// NOTE(review): the value is only range-checked and not otherwise used.
    pub entry_count: Option<u64>,
    /// Also generate master data (defaults to `false`).
    pub include_master_data: Option<bool>,
    /// Inject anomalies into the generated entries (defaults to `false`).
    pub inject_anomalies: Option<bool>,
}

/// Result of a successful bulk generation run.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Always `true` when produced by `bulk_generate` (failures return an error instead).
    pub success: bool,
    /// Number of journal entries produced by the run.
    pub entries_generated: u64,
    /// Wall-clock duration of orchestrator creation plus generation, in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced by the run.
    pub anomaly_count: u64,
}
427
/// Body of `POST /api/stream/start`.
#[derive(Debug, Deserialize)]
// The fields are never read: `start_stream` currently parses this body and discards it.
#[allow(dead_code)] pub struct StreamRequest {
    /// Requested stream rate (not applied by `start_stream`).
    pub events_per_second: Option<u32>,
    /// Requested cap on streamed events (not applied by `start_stream`).
    pub max_events: Option<u64>,
    /// Whether to inject anomalies (not applied by `start_stream`).
    pub inject_anomalies: Option<bool>,
}

/// Generic `success` + `message` reply used by the stream-control and pattern-trigger endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the requested action was carried out.
    pub success: bool,
    /// Human-readable outcome.
    pub message: String,
}
441
442async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
448 Json(HealthResponse {
449 healthy: true,
450 version: env!("CARGO_PKG_VERSION").to_string(),
451 uptime_seconds: state.server_state.uptime_seconds(),
452 })
453}
454
455async fn readiness_check(
458 State(state): State<AppState>,
459) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
460 let mut checks = Vec::new();
461 let mut any_fail = false;
462
463 let config = state.server_state.config.read().await;
465 let config_valid = !config.companies.is_empty();
466 checks.push(HealthCheck {
467 name: "config".to_string(),
468 status: if config_valid { "ok" } else { "fail" }.to_string(),
469 });
470 if !config_valid {
471 any_fail = true;
472 }
473 drop(config);
474
475 let resource_status = state.server_state.resource_status();
477 let memory_status = if resource_status.degradation_level == "Emergency" {
478 any_fail = true;
479 "fail"
480 } else if resource_status.degradation_level != "Normal" {
481 "degraded"
482 } else {
483 "ok"
484 };
485 checks.push(HealthCheck {
486 name: "memory".to_string(),
487 status: memory_status.to_string(),
488 });
489
490 let disk_ok = resource_status.disk_available_mb > 100;
492 checks.push(HealthCheck {
493 name: "disk".to_string(),
494 status: if disk_ok { "ok" } else { "fail" }.to_string(),
495 });
496 if !disk_ok {
497 any_fail = true;
498 }
499
500 let response = ReadinessResponse {
501 ready: !any_fail,
502 message: if any_fail {
503 "Service is not ready".to_string()
504 } else {
505 "Service is ready".to_string()
506 },
507 checks,
508 };
509
510 if any_fail {
511 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
512 } else {
513 Ok(Json(response))
514 }
515}
516
517async fn liveness_check() -> Json<LivenessResponse> {
520 Json(LivenessResponse {
521 alive: true,
522 timestamp: chrono::Utc::now().to_rfc3339(),
523 })
524}
525
526async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
529 use std::sync::atomic::Ordering;
530
531 let uptime = state.server_state.uptime_seconds();
532 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
533 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
534 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
535 let total_stream_events = state
536 .server_state
537 .total_stream_events
538 .load(Ordering::Relaxed);
539
540 let entries_per_second = if uptime > 0 {
541 total_entries as f64 / uptime as f64
542 } else {
543 0.0
544 };
545
546 let metrics = format!(
547 r#"# HELP synth_entries_generated_total Total number of journal entries generated
548# TYPE synth_entries_generated_total counter
549synth_entries_generated_total {}
550
551# HELP synth_anomalies_injected_total Total number of anomalies injected
552# TYPE synth_anomalies_injected_total counter
553synth_anomalies_injected_total {}
554
555# HELP synth_uptime_seconds Server uptime in seconds
556# TYPE synth_uptime_seconds gauge
557synth_uptime_seconds {}
558
559# HELP synth_entries_per_second Rate of entry generation
560# TYPE synth_entries_per_second gauge
561synth_entries_per_second {:.2}
562
563# HELP synth_active_streams Number of active streaming connections
564# TYPE synth_active_streams gauge
565synth_active_streams {}
566
567# HELP synth_stream_events_total Total events sent through streams
568# TYPE synth_stream_events_total counter
569synth_stream_events_total {}
570
571# HELP synth_info Server version information
572# TYPE synth_info gauge
573synth_info{{version="{}"}} 1
574"#,
575 total_entries,
576 total_anomalies,
577 uptime,
578 entries_per_second,
579 active_streams,
580 total_stream_events,
581 env!("CARGO_PKG_VERSION")
582 );
583
584 (
585 StatusCode::OK,
586 [(
587 header::CONTENT_TYPE,
588 "text/plain; version=0.0.4; charset=utf-8",
589 )],
590 metrics,
591 )
592}
593
594async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
596 let uptime = state.server_state.uptime_seconds();
597 let total_entries = state
598 .server_state
599 .total_entries
600 .load(std::sync::atomic::Ordering::Relaxed);
601
602 let entries_per_second = if uptime > 0 {
603 total_entries as f64 / uptime as f64
604 } else {
605 0.0
606 };
607
608 Json(MetricsResponse {
609 total_entries_generated: total_entries,
610 total_anomalies_injected: state
611 .server_state
612 .total_anomalies
613 .load(std::sync::atomic::Ordering::Relaxed),
614 uptime_seconds: uptime,
615 session_entries: total_entries,
616 session_entries_per_second: entries_per_second,
617 active_streams: state
618 .server_state
619 .active_streams
620 .load(std::sync::atomic::Ordering::Relaxed) as u32,
621 total_stream_events: state
622 .server_state
623 .total_stream_events
624 .load(std::sync::atomic::Ordering::Relaxed),
625 })
626}
627
628async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
630 let config = state.server_state.config.read().await;
631
632 Json(ConfigResponse {
633 success: true,
634 message: "Current configuration".to_string(),
635 config: Some(GenerationConfigDto {
636 industry: format!("{:?}", config.global.industry),
637 start_date: config.global.start_date.clone(),
638 period_months: config.global.period_months,
639 seed: config.global.seed,
640 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
641 companies: config
642 .companies
643 .iter()
644 .map(|c| CompanyConfigDto {
645 code: c.code.clone(),
646 name: c.name.clone(),
647 currency: c.currency.clone(),
648 country: c.country.clone(),
649 annual_transaction_volume: c.annual_transaction_volume.count(),
650 volume_weight: c.volume_weight as f32,
651 })
652 .collect(),
653 fraud_enabled: config.fraud.enabled,
654 fraud_rate: config.fraud.fraud_rate as f32,
655 }),
656 })
657}
658
659async fn set_config(
661 State(state): State<AppState>,
662 Json(new_config): Json<GenerationConfigDto>,
663) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
664 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
665 use datasynth_core::models::{CoAComplexity, IndustrySector};
666
667 info!(
668 "Configuration update requested: industry={}, period_months={}",
669 new_config.industry, new_config.period_months
670 );
671
672 let industry = match new_config.industry.to_lowercase().as_str() {
674 "manufacturing" => IndustrySector::Manufacturing,
675 "retail" => IndustrySector::Retail,
676 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
677 "healthcare" => IndustrySector::Healthcare,
678 "technology" => IndustrySector::Technology,
679 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
680 "energy" => IndustrySector::Energy,
681 "transportation" => IndustrySector::Transportation,
682 "real_estate" | "realestate" => IndustrySector::RealEstate,
683 "telecommunications" => IndustrySector::Telecommunications,
684 _ => {
685 return Err((
686 StatusCode::BAD_REQUEST,
687 Json(ConfigResponse {
688 success: false,
689 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
690 config: None,
691 }),
692 ));
693 }
694 };
695
696 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
698 "small" => CoAComplexity::Small,
699 "medium" => CoAComplexity::Medium,
700 "large" => CoAComplexity::Large,
701 _ => {
702 return Err((
703 StatusCode::BAD_REQUEST,
704 Json(ConfigResponse {
705 success: false,
706 message: format!(
707 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
708 new_config.coa_complexity
709 ),
710 config: None,
711 }),
712 ));
713 }
714 };
715
716 let companies: Vec<CompanyConfig> = new_config
718 .companies
719 .iter()
720 .map(|c| CompanyConfig {
721 code: c.code.clone(),
722 name: c.name.clone(),
723 currency: c.currency.clone(),
724 country: c.country.clone(),
725 fiscal_year_variant: "K4".to_string(),
726 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
727 volume_weight: c.volume_weight as f64,
728 })
729 .collect();
730
731 let mut config = state.server_state.config.write().await;
733 config.global.industry = industry;
734 config.global.start_date = new_config.start_date.clone();
735 config.global.period_months = new_config.period_months;
736 config.global.seed = new_config.seed;
737 config.chart_of_accounts.complexity = complexity;
738 config.fraud.enabled = new_config.fraud_enabled;
739 config.fraud.fraud_rate = new_config.fraud_rate as f64;
740
741 if !companies.is_empty() {
743 config.companies = companies;
744 }
745
746 info!("Configuration updated successfully");
747
748 Ok(Json(ConfigResponse {
749 success: true,
750 message: "Configuration updated and applied".to_string(),
751 config: Some(new_config),
752 }))
753}
754
755async fn bulk_generate(
757 State(state): State<AppState>,
758 Json(req): Json<BulkGenerateRequest>,
759) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
760 const MAX_ENTRY_COUNT: u64 = 1_000_000;
762 if let Some(count) = req.entry_count {
763 if count > MAX_ENTRY_COUNT {
764 return Err((
765 StatusCode::BAD_REQUEST,
766 format!(
767 "entry_count ({}) exceeds maximum allowed value ({})",
768 count, MAX_ENTRY_COUNT
769 ),
770 ));
771 }
772 }
773
774 let config = state.server_state.config.read().await.clone();
775 let start_time = std::time::Instant::now();
776
777 let phase_config = PhaseConfig {
778 generate_master_data: req.include_master_data.unwrap_or(false),
779 generate_document_flows: false,
780 generate_journal_entries: true,
781 inject_anomalies: req.inject_anomalies.unwrap_or(false),
782 show_progress: false,
783 ..Default::default()
784 };
785
786 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
787 (
788 StatusCode::INTERNAL_SERVER_ERROR,
789 format!("Failed to create orchestrator: {}", e),
790 )
791 })?;
792
793 let result = orchestrator.generate().map_err(|e| {
794 (
795 StatusCode::INTERNAL_SERVER_ERROR,
796 format!("Generation failed: {}", e),
797 )
798 })?;
799
800 let duration_ms = start_time.elapsed().as_millis() as u64;
801 let entries_count = result.journal_entries.len() as u64;
802 let anomaly_count = result.anomaly_labels.labels.len() as u64;
803
804 state
806 .server_state
807 .total_entries
808 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
809 state
810 .server_state
811 .total_anomalies
812 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
813
814 Ok(Json(BulkGenerateResponse {
815 success: true,
816 entries_generated: entries_count,
817 duration_ms,
818 anomaly_count,
819 }))
820}
821
822async fn start_stream(
824 State(state): State<AppState>,
825 Json(_req): Json<StreamRequest>,
826) -> Json<StreamResponse> {
827 state
828 .server_state
829 .stream_stopped
830 .store(false, std::sync::atomic::Ordering::Relaxed);
831 state
832 .server_state
833 .stream_paused
834 .store(false, std::sync::atomic::Ordering::Relaxed);
835
836 Json(StreamResponse {
837 success: true,
838 message: "Stream started".to_string(),
839 })
840}
841
842async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
844 state
845 .server_state
846 .stream_stopped
847 .store(true, std::sync::atomic::Ordering::Relaxed);
848
849 Json(StreamResponse {
850 success: true,
851 message: "Stream stopped".to_string(),
852 })
853}
854
855async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
857 state
858 .server_state
859 .stream_paused
860 .store(true, std::sync::atomic::Ordering::Relaxed);
861
862 Json(StreamResponse {
863 success: true,
864 message: "Stream paused".to_string(),
865 })
866}
867
868async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
870 state
871 .server_state
872 .stream_paused
873 .store(false, std::sync::atomic::Ordering::Relaxed);
874
875 Json(StreamResponse {
876 success: true,
877 message: "Stream resumed".to_string(),
878 })
879}
880
881async fn trigger_pattern(
886 State(state): State<AppState>,
887 axum::extract::Path(pattern): axum::extract::Path<String>,
888) -> Json<StreamResponse> {
889 info!("Pattern trigger requested: {}", pattern);
890
891 let valid_patterns = [
893 "year_end_spike",
894 "period_end_spike",
895 "holiday_cluster",
896 "fraud_cluster",
897 "error_cluster",
898 "uniform",
899 ];
900
901 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
902
903 if !is_valid {
904 return Json(StreamResponse {
905 success: false,
906 message: format!(
907 "Unknown pattern '{}'. Valid patterns: {:?}, or use 'custom:name' for custom patterns",
908 pattern, valid_patterns
909 ),
910 });
911 }
912
913 match state.server_state.triggered_pattern.try_write() {
915 Ok(mut triggered) => {
916 *triggered = Some(pattern.clone());
917 Json(StreamResponse {
918 success: true,
919 message: format!("Pattern '{}' will be applied to upcoming entries", pattern),
920 })
921 }
922 Err(_) => Json(StreamResponse {
923 success: false,
924 message: "Failed to acquire lock for pattern trigger".to_string(),
925 }),
926 }
927}
928
929async fn websocket_metrics(
931 ws: WebSocketUpgrade,
932 State(state): State<AppState>,
933) -> impl IntoResponse {
934 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
935}
936
937async fn websocket_events(
939 ws: WebSocketUpgrade,
940 State(state): State<AppState>,
941) -> impl IntoResponse {
942 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
943}
944
945async fn submit_job(
951 State(state): State<AppState>,
952 Json(request): Json<JobRequest>,
953) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
954 let queue = state.job_queue.as_ref().ok_or_else(|| {
955 (
956 StatusCode::SERVICE_UNAVAILABLE,
957 Json(serde_json::json!({"error": "Job queue not enabled"})),
958 )
959 })?;
960
961 let job_id = queue.submit(request).await;
962 info!("Job submitted: {}", job_id);
963
964 Ok((
965 StatusCode::CREATED,
966 Json(serde_json::json!({
967 "id": job_id.to_string(),
968 "status": "queued"
969 })),
970 ))
971}
972
973async fn get_job(
975 State(state): State<AppState>,
976 axum::extract::Path(id): axum::extract::Path<String>,
977) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
978 let queue = state.job_queue.as_ref().ok_or_else(|| {
979 (
980 StatusCode::SERVICE_UNAVAILABLE,
981 Json(serde_json::json!({"error": "Job queue not enabled"})),
982 )
983 })?;
984
985 match queue.get(&id).await {
986 Some(entry) => Ok(Json(serde_json::json!({
987 "id": entry.id,
988 "status": format!("{:?}", entry.status).to_lowercase(),
989 "submitted_at": entry.submitted_at.to_rfc3339(),
990 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
991 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
992 "result": entry.result,
993 }))),
994 None => Err((
995 StatusCode::NOT_FOUND,
996 Json(serde_json::json!({"error": "Job not found"})),
997 )),
998 }
999}
1000
1001async fn list_jobs(
1003 State(state): State<AppState>,
1004) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1005 let queue = state.job_queue.as_ref().ok_or_else(|| {
1006 (
1007 StatusCode::SERVICE_UNAVAILABLE,
1008 Json(serde_json::json!({"error": "Job queue not enabled"})),
1009 )
1010 })?;
1011
1012 let summaries: Vec<_> = queue
1013 .list()
1014 .await
1015 .into_iter()
1016 .map(|s| {
1017 serde_json::json!({
1018 "id": s.id,
1019 "status": format!("{:?}", s.status).to_lowercase(),
1020 "submitted_at": s.submitted_at.to_rfc3339(),
1021 })
1022 })
1023 .collect();
1024
1025 Ok(Json(serde_json::json!({ "jobs": summaries })))
1026}
1027
1028async fn cancel_job(
1030 State(state): State<AppState>,
1031 axum::extract::Path(id): axum::extract::Path<String>,
1032) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1033 let queue = state.job_queue.as_ref().ok_or_else(|| {
1034 (
1035 StatusCode::SERVICE_UNAVAILABLE,
1036 Json(serde_json::json!({"error": "Job queue not enabled"})),
1037 )
1038 })?;
1039
1040 if queue.cancel(&id).await {
1041 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1042 } else {
1043 Err((
1044 StatusCode::CONFLICT,
1045 Json(
1046 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1047 ),
1048 ))
1049 }
1050}
1051
1052async fn reload_config(
1058 State(state): State<AppState>,
1059) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1060 let new_config = crate::grpc::service::default_generator_config();
1062 let mut config = state.server_state.config.write().await;
1063 *config = new_config;
1064 info!("Configuration reloaded via REST API");
1065
1066 Ok(Json(serde_json::json!({
1067 "success": true,
1068 "message": "Configuration reloaded"
1069 })))
1070}
1071
1072#[cfg(test)]
1073#[allow(clippy::unwrap_used)]
1074mod tests {
1075 use super::*;
1076
1077 #[test]
1082 fn test_health_response_serialization() {
1083 let response = HealthResponse {
1084 healthy: true,
1085 version: "0.1.0".to_string(),
1086 uptime_seconds: 100,
1087 };
1088 let json = serde_json::to_string(&response).unwrap();
1089 assert!(json.contains("healthy"));
1090 assert!(json.contains("version"));
1091 assert!(json.contains("uptime_seconds"));
1092 }
1093
1094 #[test]
1095 fn test_health_response_deserialization() {
1096 let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
1097 let response: HealthResponse = serde_json::from_str(json).unwrap();
1098 assert!(response.healthy);
1099 assert_eq!(response.version, "0.1.0");
1100 assert_eq!(response.uptime_seconds, 100);
1101 }
1102
1103 #[test]
1104 fn test_metrics_response_serialization() {
1105 let response = MetricsResponse {
1106 total_entries_generated: 1000,
1107 total_anomalies_injected: 10,
1108 uptime_seconds: 60,
1109 session_entries: 1000,
1110 session_entries_per_second: 16.67,
1111 active_streams: 1,
1112 total_stream_events: 500,
1113 };
1114 let json = serde_json::to_string(&response).unwrap();
1115 assert!(json.contains("total_entries_generated"));
1116 assert!(json.contains("session_entries_per_second"));
1117 }
1118
1119 #[test]
1120 fn test_metrics_response_deserialization() {
1121 let json = r#"{
1122 "total_entries_generated": 5000,
1123 "total_anomalies_injected": 50,
1124 "uptime_seconds": 300,
1125 "session_entries": 5000,
1126 "session_entries_per_second": 16.67,
1127 "active_streams": 2,
1128 "total_stream_events": 10000
1129 }"#;
1130 let response: MetricsResponse = serde_json::from_str(json).unwrap();
1131 assert_eq!(response.total_entries_generated, 5000);
1132 assert_eq!(response.active_streams, 2);
1133 }
1134
1135 #[test]
1136 fn test_config_response_serialization() {
1137 let response = ConfigResponse {
1138 success: true,
1139 message: "Configuration loaded".to_string(),
1140 config: Some(GenerationConfigDto {
1141 industry: "manufacturing".to_string(),
1142 start_date: "2024-01-01".to_string(),
1143 period_months: 12,
1144 seed: Some(42),
1145 coa_complexity: "medium".to_string(),
1146 companies: vec![],
1147 fraud_enabled: false,
1148 fraud_rate: 0.0,
1149 }),
1150 };
1151 let json = serde_json::to_string(&response).unwrap();
1152 assert!(json.contains("success"));
1153 assert!(json.contains("config"));
1154 }
1155
1156 #[test]
1157 fn test_config_response_without_config() {
1158 let response = ConfigResponse {
1159 success: false,
1160 message: "No configuration available".to_string(),
1161 config: None,
1162 };
1163 let json = serde_json::to_string(&response).unwrap();
1164 assert!(json.contains("null") || json.contains("config\":null"));
1165 }
1166
1167 #[test]
1168 fn test_generation_config_dto_roundtrip() {
1169 let original = GenerationConfigDto {
1170 industry: "retail".to_string(),
1171 start_date: "2024-06-01".to_string(),
1172 period_months: 6,
1173 seed: Some(12345),
1174 coa_complexity: "large".to_string(),
1175 companies: vec![CompanyConfigDto {
1176 code: "1000".to_string(),
1177 name: "Test Corp".to_string(),
1178 currency: "USD".to_string(),
1179 country: "US".to_string(),
1180 annual_transaction_volume: 100000,
1181 volume_weight: 1.0,
1182 }],
1183 fraud_enabled: true,
1184 fraud_rate: 0.05,
1185 };
1186
1187 let json = serde_json::to_string(&original).unwrap();
1188 let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();
1189
1190 assert_eq!(original.industry, deserialized.industry);
1191 assert_eq!(original.seed, deserialized.seed);
1192 assert_eq!(original.companies.len(), deserialized.companies.len());
1193 }
1194
1195 #[test]
1196 fn test_company_config_dto_serialization() {
1197 let company = CompanyConfigDto {
1198 code: "2000".to_string(),
1199 name: "European Subsidiary".to_string(),
1200 currency: "EUR".to_string(),
1201 country: "DE".to_string(),
1202 annual_transaction_volume: 50000,
1203 volume_weight: 0.5,
1204 };
1205 let json = serde_json::to_string(&company).unwrap();
1206 assert!(json.contains("2000"));
1207 assert!(json.contains("EUR"));
1208 assert!(json.contains("DE"));
1209 }
1210
1211 #[test]
1212 fn test_bulk_generate_request_deserialization() {
1213 let json = r#"{
1214 "entry_count": 5000,
1215 "include_master_data": true,
1216 "inject_anomalies": true
1217 }"#;
1218 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1219 assert_eq!(request.entry_count, Some(5000));
1220 assert_eq!(request.include_master_data, Some(true));
1221 assert_eq!(request.inject_anomalies, Some(true));
1222 }
1223
1224 #[test]
1225 fn test_bulk_generate_request_with_defaults() {
1226 let json = r#"{}"#;
1227 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1228 assert_eq!(request.entry_count, None);
1229 assert_eq!(request.include_master_data, None);
1230 assert_eq!(request.inject_anomalies, None);
1231 }
1232
1233 #[test]
1234 fn test_bulk_generate_response_serialization() {
1235 let response = BulkGenerateResponse {
1236 success: true,
1237 entries_generated: 1000,
1238 duration_ms: 250,
1239 anomaly_count: 20,
1240 };
1241 let json = serde_json::to_string(&response).unwrap();
1242 assert!(json.contains("entries_generated"));
1243 assert!(json.contains("1000"));
1244 assert!(json.contains("duration_ms"));
1245 }
1246
1247 #[test]
1248 fn test_stream_response_serialization() {
1249 let response = StreamResponse {
1250 success: true,
1251 message: "Stream started successfully".to_string(),
1252 };
1253 let json = serde_json::to_string(&response).unwrap();
1254 assert!(json.contains("success"));
1255 assert!(json.contains("Stream started"));
1256 }
1257
1258 #[test]
1259 fn test_stream_response_failure() {
1260 let response = StreamResponse {
1261 success: false,
1262 message: "Stream failed to start".to_string(),
1263 };
1264 let json = serde_json::to_string(&response).unwrap();
1265 assert!(json.contains("false"));
1266 assert!(json.contains("failed"));
1267 }
1268
1269 #[test]
1274 fn test_cors_config_default() {
1275 let config = CorsConfig::default();
1276 assert!(!config.allow_any_origin);
1277 assert!(!config.allowed_origins.is_empty());
1278 assert!(config
1279 .allowed_origins
1280 .contains(&"http://localhost:5173".to_string()));
1281 assert!(config
1282 .allowed_origins
1283 .contains(&"tauri://localhost".to_string()));
1284 }
1285
1286 #[test]
1287 fn test_cors_config_custom_origins() {
1288 let config = CorsConfig {
1289 allowed_origins: vec![
1290 "https://example.com".to_string(),
1291 "https://app.example.com".to_string(),
1292 ],
1293 allow_any_origin: false,
1294 };
1295 assert_eq!(config.allowed_origins.len(), 2);
1296 assert!(config
1297 .allowed_origins
1298 .contains(&"https://example.com".to_string()));
1299 }
1300
1301 #[test]
1302 fn test_cors_config_permissive() {
1303 let config = CorsConfig {
1304 allowed_origins: vec![],
1305 allow_any_origin: true,
1306 };
1307 assert!(config.allow_any_origin);
1308 }
1309
1310 #[test]
1315 fn test_bulk_generate_request_partial() {
1316 let json = r#"{"entry_count": 100}"#;
1317 let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
1318 assert_eq!(request.entry_count, Some(100));
1319 assert!(request.include_master_data.is_none());
1320 }
1321
1322 #[test]
1323 fn test_generation_config_no_seed() {
1324 let config = GenerationConfigDto {
1325 industry: "technology".to_string(),
1326 start_date: "2024-01-01".to_string(),
1327 period_months: 3,
1328 seed: None,
1329 coa_complexity: "small".to_string(),
1330 companies: vec![],
1331 fraud_enabled: false,
1332 fraud_rate: 0.0,
1333 };
1334 let json = serde_json::to_string(&config).unwrap();
1335 assert!(json.contains("seed"));
1336 }
1337
1338 #[test]
1339 fn test_generation_config_multiple_companies() {
1340 let config = GenerationConfigDto {
1341 industry: "manufacturing".to_string(),
1342 start_date: "2024-01-01".to_string(),
1343 period_months: 12,
1344 seed: Some(42),
1345 coa_complexity: "large".to_string(),
1346 companies: vec![
1347 CompanyConfigDto {
1348 code: "1000".to_string(),
1349 name: "Headquarters".to_string(),
1350 currency: "USD".to_string(),
1351 country: "US".to_string(),
1352 annual_transaction_volume: 100000,
1353 volume_weight: 1.0,
1354 },
1355 CompanyConfigDto {
1356 code: "2000".to_string(),
1357 name: "European Sub".to_string(),
1358 currency: "EUR".to_string(),
1359 country: "DE".to_string(),
1360 annual_transaction_volume: 50000,
1361 volume_weight: 0.5,
1362 },
1363 CompanyConfigDto {
1364 code: "3000".to_string(),
1365 name: "APAC Sub".to_string(),
1366 currency: "JPY".to_string(),
1367 country: "JP".to_string(),
1368 annual_transaction_volume: 30000,
1369 volume_weight: 0.3,
1370 },
1371 ],
1372 fraud_enabled: true,
1373 fraud_rate: 0.02,
1374 };
1375 assert_eq!(config.companies.len(), 3);
1376 }
1377
1378 #[test]
1383 fn test_metrics_entries_per_second_calculation() {
1384 let total_entries: u64 = 1000;
1386 let uptime: u64 = 60;
1387 let eps = if uptime > 0 {
1388 total_entries as f64 / uptime as f64
1389 } else {
1390 0.0
1391 };
1392 assert!((eps - 16.67).abs() < 0.1);
1393 }
1394
1395 #[test]
1396 fn test_metrics_entries_per_second_zero_uptime() {
1397 let total_entries: u64 = 1000;
1398 let uptime: u64 = 0;
1399 let eps = if uptime > 0 {
1400 total_entries as f64 / uptime as f64
1401 } else {
1402 0.0
1403 };
1404 assert_eq!(eps, 0.0);
1405 }
1406}