1use std::sync::Arc;
4use std::time::Duration;
5
6use axum::{
7 extract::{State, WebSocketUpgrade},
8 http::{header, Method, StatusCode},
9 response::IntoResponse,
10 routing::{get, post},
11 Json, Router,
12};
13use serde::{Deserialize, Serialize};
14use tower_http::cors::{AllowOrigin, CorsLayer};
15use tower_http::timeout::TimeoutLayer;
16use tracing::{error, info};
17
18use crate::grpc::service::{ServerState, SynthService};
19use crate::jobs::{JobQueue, JobRequest};
20use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
21
22use super::websocket;
23
/// Shared state handed to every REST handler through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Server-wide state taken from the gRPC `SynthService` (configuration,
    /// generation counters, stream control flags).
    pub server_state: Arc<ServerState>,
    /// Optional job queue; the job endpoints answer `503 Service Unavailable`
    /// while this is `None`.
    pub job_queue: Option<Arc<JobQueue>>,
}
30
/// Request-timeout settings used by `create_router_full_with_backend`.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Seconds a request may run before the server answers
    /// `408 Request Timeout`.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Timeout applied when none is configured (five minutes).
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;

    /// Creates a configuration with a request timeout of `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        Self {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self::new(Self::DEFAULT_REQUEST_TIMEOUT_SECS)
    }
}
55
56#[derive(Clone)]
58pub struct CorsConfig {
59 pub allowed_origins: Vec<String>,
61 pub allow_any_origin: bool,
63}
64
65impl Default for CorsConfig {
66 fn default() -> Self {
67 Self {
68 allowed_origins: vec![
69 "http://localhost:5173".to_string(), "http://localhost:3000".to_string(), "http://127.0.0.1:5173".to_string(),
72 "http://127.0.0.1:3000".to_string(),
73 "tauri://localhost".to_string(), ],
75 allow_any_origin: false,
76 }
77 }
78}
79
80async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
82 let (mut parts, body) = response.into_parts();
83 parts.headers.insert(
84 axum::http::HeaderName::from_static("x-api-version"),
85 axum::http::HeaderValue::from_static("v1"),
86 );
87 axum::response::Response::from_parts(parts, body)
88}
89
90use super::auth::{auth_middleware, AuthConfig};
91use super::rate_limit::RateLimitConfig;
92use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
93use super::request_id::request_id_middleware;
94use super::request_validation::request_validation_middleware;
95use super::security_headers::security_headers_middleware;
96
97pub fn create_router(service: SynthService) -> Router {
99 create_router_with_cors(service, CorsConfig::default())
100}
101
102pub fn create_router_full(
107 service: SynthService,
108 cors_config: CorsConfig,
109 auth_config: AuthConfig,
110 rate_limit_config: RateLimitConfig,
111 timeout_config: TimeoutConfig,
112) -> Router {
113 let backend = RateLimitBackend::in_memory(rate_limit_config);
114 create_router_full_with_backend(service, cors_config, auth_config, backend, timeout_config)
115}
116
117pub fn create_router_full_with_backend(
133 service: SynthService,
134 cors_config: CorsConfig,
135 auth_config: AuthConfig,
136 rate_limit_backend: RateLimitBackend,
137 timeout_config: TimeoutConfig,
138) -> Router {
139 let server_state = service.state.clone();
140 let state = AppState {
141 server_state,
142 job_queue: None,
143 };
144
145 let cors = if cors_config.allow_any_origin {
146 CorsLayer::permissive()
147 } else {
148 let origins: Vec<_> = cors_config
149 .allowed_origins
150 .iter()
151 .filter_map(|o| o.parse().ok())
152 .collect();
153
154 CorsLayer::new()
155 .allow_origin(AllowOrigin::list(origins))
156 .allow_methods([
157 Method::GET,
158 Method::POST,
159 Method::PUT,
160 Method::DELETE,
161 Method::OPTIONS,
162 ])
163 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
164 };
165
166 Router::new()
167 .route("/health", get(health_check))
169 .route("/ready", get(readiness_check))
170 .route("/live", get(liveness_check))
171 .route("/api/metrics", get(get_metrics))
172 .route("/metrics", get(prometheus_metrics))
173 .route("/api/config", get(get_config))
175 .route("/api/config", post(set_config))
176 .route("/api/config/reload", post(reload_config))
177 .route("/api/generate/bulk", post(bulk_generate))
179 .route("/api/stream/start", post(start_stream))
180 .route("/api/stream/stop", post(stop_stream))
181 .route("/api/stream/pause", post(pause_stream))
182 .route("/api/stream/resume", post(resume_stream))
183 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
184 .route("/api/stream/ndjson", get(stream_ndjson))
185 .route("/api/jobs/submit", post(submit_job))
187 .route("/api/jobs", get(list_jobs))
188 .route("/api/jobs/{id}", get(get_job))
189 .route("/api/jobs/{id}/cancel", post(cancel_job))
190 .route("/ws/metrics", get(websocket_metrics))
192 .route("/ws/events", get(websocket_events))
193 .layer(axum::middleware::from_fn(security_headers_middleware))
196 .layer(axum::middleware::map_response(api_version_header))
197 .layer(cors)
198 .layer(axum::middleware::from_fn(request_id_middleware))
199 .layer(axum::middleware::from_fn(auth_middleware))
200 .layer(axum::Extension(auth_config))
201 .layer(axum::middleware::from_fn(request_validation_middleware))
202 .layer(axum::middleware::from_fn(backend_rate_limit_middleware))
203 .layer(axum::Extension(rate_limit_backend))
204 .layer(TimeoutLayer::with_status_code(
205 StatusCode::REQUEST_TIMEOUT,
206 Duration::from_secs(timeout_config.request_timeout_secs),
207 ))
208 .with_state(state)
209}
210
211pub fn create_router_with_auth(
213 service: SynthService,
214 cors_config: CorsConfig,
215 auth_config: AuthConfig,
216) -> Router {
217 let server_state = service.state.clone();
218 let state = AppState {
219 server_state,
220 job_queue: None,
221 };
222
223 let cors = if cors_config.allow_any_origin {
224 CorsLayer::permissive()
225 } else {
226 let origins: Vec<_> = cors_config
227 .allowed_origins
228 .iter()
229 .filter_map(|o| o.parse().ok())
230 .collect();
231
232 CorsLayer::new()
233 .allow_origin(AllowOrigin::list(origins))
234 .allow_methods([
235 Method::GET,
236 Method::POST,
237 Method::PUT,
238 Method::DELETE,
239 Method::OPTIONS,
240 ])
241 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
242 };
243
244 Router::new()
245 .route("/health", get(health_check))
247 .route("/ready", get(readiness_check))
248 .route("/live", get(liveness_check))
249 .route("/api/metrics", get(get_metrics))
250 .route("/metrics", get(prometheus_metrics))
251 .route("/api/config", get(get_config))
253 .route("/api/config", post(set_config))
254 .route("/api/config/reload", post(reload_config))
255 .route("/api/generate/bulk", post(bulk_generate))
257 .route("/api/stream/start", post(start_stream))
258 .route("/api/stream/stop", post(stop_stream))
259 .route("/api/stream/pause", post(pause_stream))
260 .route("/api/stream/resume", post(resume_stream))
261 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
262 .route("/api/stream/ndjson", get(stream_ndjson))
263 .route("/api/jobs/submit", post(submit_job))
265 .route("/api/jobs", get(list_jobs))
266 .route("/api/jobs/{id}", get(get_job))
267 .route("/api/jobs/{id}/cancel", post(cancel_job))
268 .route("/ws/metrics", get(websocket_metrics))
270 .route("/ws/events", get(websocket_events))
271 .layer(axum::middleware::from_fn(auth_middleware))
272 .layer(axum::Extension(auth_config))
273 .layer(cors)
274 .with_state(state)
275}
276
277pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
279 let server_state = service.state.clone();
280 let state = AppState {
281 server_state,
282 job_queue: None,
283 };
284
285 let cors = if cors_config.allow_any_origin {
286 CorsLayer::permissive()
288 } else {
289 let origins: Vec<_> = cors_config
291 .allowed_origins
292 .iter()
293 .filter_map(|o| o.parse().ok())
294 .collect();
295
296 CorsLayer::new()
297 .allow_origin(AllowOrigin::list(origins))
298 .allow_methods([
299 Method::GET,
300 Method::POST,
301 Method::PUT,
302 Method::DELETE,
303 Method::OPTIONS,
304 ])
305 .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
306 };
307
308 Router::new()
309 .route("/health", get(health_check))
311 .route("/ready", get(readiness_check))
312 .route("/live", get(liveness_check))
313 .route("/api/metrics", get(get_metrics))
314 .route("/metrics", get(prometheus_metrics))
315 .route("/api/config", get(get_config))
317 .route("/api/config", post(set_config))
318 .route("/api/config/reload", post(reload_config))
319 .route("/api/generate/bulk", post(bulk_generate))
321 .route("/api/stream/start", post(start_stream))
322 .route("/api/stream/stop", post(stop_stream))
323 .route("/api/stream/pause", post(pause_stream))
324 .route("/api/stream/resume", post(resume_stream))
325 .route("/api/stream/trigger/{pattern}", post(trigger_pattern))
326 .route("/api/stream/ndjson", get(stream_ndjson))
327 .route("/api/jobs/submit", post(submit_job))
329 .route("/api/jobs", get(list_jobs))
330 .route("/api/jobs/{id}", get(get_job))
331 .route("/api/jobs/{id}/cancel", post(cancel_job))
332 .route("/ws/metrics", get(websocket_metrics))
334 .route("/ws/events", get(websocket_events))
335 .layer(cors)
336 .with_state(state)
337}
338
/// Body returned by `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always `true` when produced by `health_check`.
    pub healthy: bool,
    /// Crate version (`CARGO_PKG_VERSION` at compile time).
    pub version: String,
    /// Uptime in seconds as reported by `ServerState::uptime_seconds`.
    pub uptime_seconds: u64,
}

/// Body returned by `GET /ready`, with status `200` when every check passed
/// and `503` when at least one check reported `"fail"`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
    /// `false` when any check failed.
    pub ready: bool,
    /// Human-readable summary ("Service is ready" / "Service is not ready").
    pub message: String,
    /// Individual check results (`config`, `memory`, `disk`).
    pub checks: Vec<HealthCheck>,
}

/// Outcome of a single readiness check.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
    /// Check identifier, e.g. `"config"`.
    pub name: String,
    /// `"ok"`, `"degraded"` (memory only) or `"fail"`.
    pub status: String,
}

/// Body returned by `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
    /// Always `true` when produced by `liveness_check`.
    pub alive: bool,
    /// Current UTC time formatted as RFC 3339.
    pub timestamp: String,
}

/// Body returned by `GET /api/metrics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
    /// Value of the `ServerState::total_entries` counter.
    pub total_entries_generated: u64,
    /// Value of the `ServerState::total_anomalies` counter.
    pub total_anomalies_injected: u64,
    /// Server uptime in seconds.
    pub uptime_seconds: u64,
    /// Currently filled from the same counter as `total_entries_generated`.
    pub session_entries: u64,
    /// `total_entries / uptime_seconds`, or `0.0` while uptime is zero.
    pub session_entries_per_second: f64,
    /// Value of the `ServerState::active_streams` counter.
    pub active_streams: u32,
    /// Value of the `ServerState::total_stream_events` counter.
    pub total_stream_events: u64,
}
382
/// Envelope returned by the `GET`/`POST /api/config` endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
    /// Whether the request succeeded.
    pub success: bool,
    /// Human-readable status or error message.
    pub message: String,
    /// The configuration, or `None` on validation errors.
    pub config: Option<GenerationConfigDto>,
}

/// Wire representation of the editable part of the generator configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
    /// Industry sector. `set_config` accepts lower-case names such as
    /// `manufacturing` or `financial_services` (case-insensitive), while
    /// `get_config` reports the enum's `Debug` name.
    pub industry: String,
    /// Stored verbatim in `config.global.start_date`.
    pub start_date: String,
    /// Length of the generated period in months.
    pub period_months: u32,
    /// Optional RNG seed.
    pub seed: Option<u64>,
    /// Chart-of-accounts complexity: `small`, `medium` or `large` on input
    /// (case-insensitive); the enum's `Debug` name on output.
    pub coa_complexity: String,
    /// Companies to generate data for.
    pub companies: Vec<CompanyConfigDto>,
    /// Whether fraud injection is enabled.
    pub fraud_enabled: bool,
    /// Fraud rate, stored as `f64` in `config.fraud.fraud_rate`.
    pub fraud_rate: f32,
}

/// Wire representation of one company in the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
    /// Company code.
    pub code: String,
    /// Display name.
    pub name: String,
    /// Currency code string.
    pub currency: String,
    /// Country string.
    pub country: String,
    /// Annual transaction volume (`TransactionVolume::Custom` on input,
    /// `TransactionVolume::count()` on output).
    pub annual_transaction_volume: u64,
    /// Relative weight, stored as `f64` in the config.
    pub volume_weight: f32,
}

/// Body of `POST /api/generate/bulk`.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
    /// Rejected when above 1,000,000. NOTE(review): `bulk_generate` only
    /// range-checks this value and does not otherwise use it — confirm intent.
    pub entry_count: Option<u64>,
    /// Generate master data as well (default `false`).
    pub include_master_data: Option<bool>,
    /// Inject anomalies (default `false`).
    pub inject_anomalies: Option<bool>,
}

/// Result of a bulk generation run.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
    /// Always `true` in a successful response.
    pub success: bool,
    /// Number of journal entries produced.
    pub entries_generated: u64,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
    /// Number of anomaly labels produced.
    pub anomaly_count: u64,
}
426
427#[derive(Debug, Deserialize)]
428#[allow(dead_code)] pub struct StreamRequest {
430 pub events_per_second: Option<u32>,
431 pub max_events: Option<u64>,
432 pub inject_anomalies: Option<bool>,
433}
434
/// Generic result body for the stream-control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
    /// Whether the requested action was applied.
    pub success: bool,
    /// Human-readable outcome.
    pub message: String,
}
440
441async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
447 Json(HealthResponse {
448 healthy: true,
449 version: env!("CARGO_PKG_VERSION").to_string(),
450 uptime_seconds: state.server_state.uptime_seconds(),
451 })
452}
453
454async fn readiness_check(
457 State(state): State<AppState>,
458) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
459 let mut checks = Vec::new();
460 let mut any_fail = false;
461
462 let config = state.server_state.config.read().await;
464 let config_valid = !config.companies.is_empty();
465 checks.push(HealthCheck {
466 name: "config".to_string(),
467 status: if config_valid { "ok" } else { "fail" }.to_string(),
468 });
469 if !config_valid {
470 any_fail = true;
471 }
472 drop(config);
473
474 let resource_status = state.server_state.resource_status();
476 let memory_status = if resource_status.degradation_level == "Emergency" {
477 any_fail = true;
478 "fail"
479 } else if resource_status.degradation_level != "Normal" {
480 "degraded"
481 } else {
482 "ok"
483 };
484 checks.push(HealthCheck {
485 name: "memory".to_string(),
486 status: memory_status.to_string(),
487 });
488
489 let disk_ok = resource_status.disk_available_mb > 100;
491 checks.push(HealthCheck {
492 name: "disk".to_string(),
493 status: if disk_ok { "ok" } else { "fail" }.to_string(),
494 });
495 if !disk_ok {
496 any_fail = true;
497 }
498
499 let response = ReadinessResponse {
500 ready: !any_fail,
501 message: if any_fail {
502 "Service is not ready".to_string()
503 } else {
504 "Service is ready".to_string()
505 },
506 checks,
507 };
508
509 if any_fail {
510 Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
511 } else {
512 Ok(Json(response))
513 }
514}
515
516async fn liveness_check() -> Json<LivenessResponse> {
519 Json(LivenessResponse {
520 alive: true,
521 timestamp: chrono::Utc::now().to_rfc3339(),
522 })
523}
524
525async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
528 use std::sync::atomic::Ordering;
529
530 let uptime = state.server_state.uptime_seconds();
531 let total_entries = state.server_state.total_entries.load(Ordering::Relaxed);
532 let total_anomalies = state.server_state.total_anomalies.load(Ordering::Relaxed);
533 let active_streams = state.server_state.active_streams.load(Ordering::Relaxed);
534 let total_stream_events = state
535 .server_state
536 .total_stream_events
537 .load(Ordering::Relaxed);
538
539 let entries_per_second = if uptime > 0 {
540 total_entries as f64 / uptime as f64
541 } else {
542 0.0
543 };
544
545 let metrics = format!(
546 r#"# HELP synth_entries_generated_total Total number of journal entries generated
547# TYPE synth_entries_generated_total counter
548synth_entries_generated_total {}
549
550# HELP synth_anomalies_injected_total Total number of anomalies injected
551# TYPE synth_anomalies_injected_total counter
552synth_anomalies_injected_total {}
553
554# HELP synth_uptime_seconds Server uptime in seconds
555# TYPE synth_uptime_seconds gauge
556synth_uptime_seconds {}
557
558# HELP synth_entries_per_second Rate of entry generation
559# TYPE synth_entries_per_second gauge
560synth_entries_per_second {:.2}
561
562# HELP synth_active_streams Number of active streaming connections
563# TYPE synth_active_streams gauge
564synth_active_streams {}
565
566# HELP synth_stream_events_total Total events sent through streams
567# TYPE synth_stream_events_total counter
568synth_stream_events_total {}
569
570# HELP synth_info Server version information
571# TYPE synth_info gauge
572synth_info{{version="{}"}} 1
573"#,
574 total_entries,
575 total_anomalies,
576 uptime,
577 entries_per_second,
578 active_streams,
579 total_stream_events,
580 env!("CARGO_PKG_VERSION")
581 );
582
583 (
584 StatusCode::OK,
585 [(
586 header::CONTENT_TYPE,
587 "text/plain; version=0.0.4; charset=utf-8",
588 )],
589 metrics,
590 )
591}
592
593async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
595 let uptime = state.server_state.uptime_seconds();
596 let total_entries = state
597 .server_state
598 .total_entries
599 .load(std::sync::atomic::Ordering::Relaxed);
600
601 let entries_per_second = if uptime > 0 {
602 total_entries as f64 / uptime as f64
603 } else {
604 0.0
605 };
606
607 Json(MetricsResponse {
608 total_entries_generated: total_entries,
609 total_anomalies_injected: state
610 .server_state
611 .total_anomalies
612 .load(std::sync::atomic::Ordering::Relaxed),
613 uptime_seconds: uptime,
614 session_entries: total_entries,
615 session_entries_per_second: entries_per_second,
616 active_streams: state
617 .server_state
618 .active_streams
619 .load(std::sync::atomic::Ordering::Relaxed) as u32,
620 total_stream_events: state
621 .server_state
622 .total_stream_events
623 .load(std::sync::atomic::Ordering::Relaxed),
624 })
625}
626
627async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
629 let config = state.server_state.config.read().await;
630
631 Json(ConfigResponse {
632 success: true,
633 message: "Current configuration".to_string(),
634 config: Some(GenerationConfigDto {
635 industry: format!("{:?}", config.global.industry),
636 start_date: config.global.start_date.clone(),
637 period_months: config.global.period_months,
638 seed: config.global.seed,
639 coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
640 companies: config
641 .companies
642 .iter()
643 .map(|c| CompanyConfigDto {
644 code: c.code.clone(),
645 name: c.name.clone(),
646 currency: c.currency.clone(),
647 country: c.country.clone(),
648 annual_transaction_volume: c.annual_transaction_volume.count(),
649 volume_weight: c.volume_weight as f32,
650 })
651 .collect(),
652 fraud_enabled: config.fraud.enabled,
653 fraud_rate: config.fraud.fraud_rate as f32,
654 }),
655 })
656}
657
658async fn set_config(
660 State(state): State<AppState>,
661 Json(new_config): Json<GenerationConfigDto>,
662) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
663 use datasynth_config::schema::{CompanyConfig, TransactionVolume};
664 use datasynth_core::models::{CoAComplexity, IndustrySector};
665
666 info!(
667 "Configuration update requested: industry={}, period_months={}",
668 new_config.industry, new_config.period_months
669 );
670
671 let industry = match new_config.industry.to_lowercase().as_str() {
673 "manufacturing" => IndustrySector::Manufacturing,
674 "retail" => IndustrySector::Retail,
675 "financial_services" | "financialservices" => IndustrySector::FinancialServices,
676 "healthcare" => IndustrySector::Healthcare,
677 "technology" => IndustrySector::Technology,
678 "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
679 "energy" => IndustrySector::Energy,
680 "transportation" => IndustrySector::Transportation,
681 "real_estate" | "realestate" => IndustrySector::RealEstate,
682 "telecommunications" => IndustrySector::Telecommunications,
683 _ => {
684 return Err((
685 StatusCode::BAD_REQUEST,
686 Json(ConfigResponse {
687 success: false,
688 message: format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry),
689 config: None,
690 }),
691 ));
692 }
693 };
694
695 let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
697 "small" => CoAComplexity::Small,
698 "medium" => CoAComplexity::Medium,
699 "large" => CoAComplexity::Large,
700 _ => {
701 return Err((
702 StatusCode::BAD_REQUEST,
703 Json(ConfigResponse {
704 success: false,
705 message: format!(
706 "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
707 new_config.coa_complexity
708 ),
709 config: None,
710 }),
711 ));
712 }
713 };
714
715 let companies: Vec<CompanyConfig> = new_config
717 .companies
718 .iter()
719 .map(|c| CompanyConfig {
720 code: c.code.clone(),
721 name: c.name.clone(),
722 currency: c.currency.clone(),
723 functional_currency: None,
724 country: c.country.clone(),
725 fiscal_year_variant: "K4".to_string(),
726 annual_transaction_volume: TransactionVolume::Custom(c.annual_transaction_volume),
727 volume_weight: c.volume_weight as f64,
728 })
729 .collect();
730
731 let mut config = state.server_state.config.write().await;
733 config.global.industry = industry;
734 config.global.start_date = new_config.start_date.clone();
735 config.global.period_months = new_config.period_months;
736 config.global.seed = new_config.seed;
737 config.chart_of_accounts.complexity = complexity;
738 config.fraud.enabled = new_config.fraud_enabled;
739 config.fraud.fraud_rate = new_config.fraud_rate as f64;
740
741 if !companies.is_empty() {
743 config.companies = companies;
744 }
745
746 info!("Configuration updated successfully");
747
748 Ok(Json(ConfigResponse {
749 success: true,
750 message: "Configuration updated and applied".to_string(),
751 config: Some(new_config),
752 }))
753}
754
755async fn bulk_generate(
757 State(state): State<AppState>,
758 Json(req): Json<BulkGenerateRequest>,
759) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
760 const MAX_ENTRY_COUNT: u64 = 1_000_000;
762 if let Some(count) = req.entry_count {
763 if count > MAX_ENTRY_COUNT {
764 return Err((
765 StatusCode::BAD_REQUEST,
766 format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
767 ));
768 }
769 }
770
771 let config = state.server_state.config.read().await.clone();
772 let start_time = std::time::Instant::now();
773
774 let phase_config = {
775 let mut pc = PhaseConfig::from_config(&config);
776 pc.generate_master_data = req.include_master_data.unwrap_or(false);
777 pc.generate_document_flows = false;
778 pc.generate_journal_entries = true;
779 pc.inject_anomalies = req.inject_anomalies.unwrap_or(false);
780 pc.show_progress = false;
781 pc
782 };
783
784 let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
785 (
786 StatusCode::INTERNAL_SERVER_ERROR,
787 format!("Failed to create orchestrator: {e}"),
788 )
789 })?;
790
791 let result = orchestrator.generate().map_err(|e| {
792 (
793 StatusCode::INTERNAL_SERVER_ERROR,
794 format!("Generation failed: {e}"),
795 )
796 })?;
797
798 let duration_ms = start_time.elapsed().as_millis() as u64;
799 let entries_count = result.journal_entries.len() as u64;
800 let anomaly_count = result.anomaly_labels.labels.len() as u64;
801
802 state
804 .server_state
805 .total_entries
806 .fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
807 state
808 .server_state
809 .total_anomalies
810 .fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
811
812 Ok(Json(BulkGenerateResponse {
813 success: true,
814 entries_generated: entries_count,
815 duration_ms,
816 anomaly_count,
817 }))
818}
819
820async fn start_stream(
822 State(state): State<AppState>,
823 Json(req): Json<StreamRequest>,
824) -> Json<StreamResponse> {
825 if let Some(eps) = req.events_per_second {
827 info!("Stream configured: events_per_second={}", eps);
828 state
829 .server_state
830 .stream_events_per_second
831 .store(eps as u64, std::sync::atomic::Ordering::Relaxed);
832 }
833 if let Some(max) = req.max_events {
834 info!("Stream configured: max_events={}", max);
835 state
836 .server_state
837 .stream_max_events
838 .store(max, std::sync::atomic::Ordering::Relaxed);
839 }
840 if let Some(inject) = req.inject_anomalies {
841 info!("Stream configured: inject_anomalies={}", inject);
842 state
843 .server_state
844 .stream_inject_anomalies
845 .store(inject, std::sync::atomic::Ordering::Relaxed);
846 }
847
848 state
849 .server_state
850 .stream_stopped
851 .store(false, std::sync::atomic::Ordering::Relaxed);
852 state
853 .server_state
854 .stream_paused
855 .store(false, std::sync::atomic::Ordering::Relaxed);
856
857 Json(StreamResponse {
858 success: true,
859 message: "Stream started".to_string(),
860 })
861}
862
863async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
865 state
866 .server_state
867 .stream_stopped
868 .store(true, std::sync::atomic::Ordering::Relaxed);
869
870 Json(StreamResponse {
871 success: true,
872 message: "Stream stopped".to_string(),
873 })
874}
875
876async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
878 state
879 .server_state
880 .stream_paused
881 .store(true, std::sync::atomic::Ordering::Relaxed);
882
883 Json(StreamResponse {
884 success: true,
885 message: "Stream paused".to_string(),
886 })
887}
888
889async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
891 state
892 .server_state
893 .stream_paused
894 .store(false, std::sync::atomic::Ordering::Relaxed);
895
896 Json(StreamResponse {
897 success: true,
898 message: "Stream resumed".to_string(),
899 })
900}
901
902async fn trigger_pattern(
907 State(state): State<AppState>,
908 axum::extract::Path(pattern): axum::extract::Path<String>,
909) -> Json<StreamResponse> {
910 info!("Pattern trigger requested: {}", pattern);
911
912 let valid_patterns = [
914 "year_end_spike",
915 "period_end_spike",
916 "holiday_cluster",
917 "fraud_cluster",
918 "error_cluster",
919 "uniform",
920 ];
921
922 let is_valid = valid_patterns.contains(&pattern.as_str()) || pattern.starts_with("custom:");
923
924 if !is_valid {
925 return Json(StreamResponse {
926 success: false,
927 message: format!(
928 "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
929 ),
930 });
931 }
932
933 match state.server_state.triggered_pattern.try_write() {
935 Ok(mut triggered) => {
936 *triggered = Some(pattern.clone());
937 Json(StreamResponse {
938 success: true,
939 message: format!("Pattern '{pattern}' will be applied to upcoming entries"),
940 })
941 }
942 Err(_) => Json(StreamResponse {
943 success: false,
944 message: "Failed to acquire lock for pattern trigger".to_string(),
945 }),
946 }
947}
948
949struct ChannelPhaseSink {
953 tx: tokio::sync::mpsc::Sender<String>,
954 stats: Arc<std::sync::Mutex<datasynth_runtime::stream_pipeline::StreamStats>>,
955}
956
957impl ChannelPhaseSink {
958 fn new(tx: tokio::sync::mpsc::Sender<String>) -> Self {
959 Self {
960 tx,
961 stats: Arc::new(std::sync::Mutex::new(
962 datasynth_runtime::stream_pipeline::StreamStats::default(),
963 )),
964 }
965 }
966}
967
968impl datasynth_runtime::stream_pipeline::PhaseSink for ChannelPhaseSink {
969 fn emit(
970 &self,
971 phase: &str,
972 item_type: &str,
973 item: &serde_json::Value,
974 ) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
975 let envelope = serde_json::json!({
976 "phase": phase,
977 "item_type": item_type,
978 "data": item,
979 });
980 let json = serde_json::to_string(&envelope).map_err(|e| {
981 datasynth_runtime::stream_pipeline::StreamError::Serialization(e.to_string())
982 })?;
983
984 self.tx.blocking_send(json).map_err(|_| {
986 datasynth_runtime::stream_pipeline::StreamError::Connection(
987 "channel closed".to_string(),
988 )
989 })?;
990
991 if let Ok(mut stats) = self.stats.lock() {
992 stats.items_emitted += 1;
993 }
994 Ok(())
995 }
996
997 fn phase_complete(
998 &self,
999 _phase: &str,
1000 ) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
1001 if let Ok(mut stats) = self.stats.lock() {
1002 stats.phases_completed += 1;
1003 }
1004 Ok(())
1005 }
1006
1007 fn flush(&self) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
1008 Ok(())
1009 }
1010
1011 fn stats(&self) -> datasynth_runtime::stream_pipeline::StreamStats {
1012 self.stats.lock().map(|s| s.clone()).unwrap_or_default()
1013 }
1014}
1015
/// Query parameters accepted by `GET /api/stream/ndjson`; each is passed to
/// `RateLimitedPipeline`.
#[derive(Debug, Deserialize)]
struct NdjsonStreamQuery {
    /// Pipeline rate; `stream_ndjson` uses `0.0` when absent.
    /// NOTE(review): the meaning of `0.0` (presumably "unlimited") is defined
    /// by the pipeline — confirm there.
    #[serde(default)]
    rate: Option<f64>,
    /// Pipeline burst size; `stream_ndjson` uses `100` when absent.
    #[serde(default)]
    burst: Option<u32>,
    /// Pipeline progress interval; `stream_ndjson` uses `100` when absent.
    #[serde(default)]
    progress_interval: Option<u64>,
}
1029
1030async fn stream_ndjson(
1046 State(state): State<AppState>,
1047 axum::extract::Query(params): axum::extract::Query<NdjsonStreamQuery>,
1048) -> impl IntoResponse {
1049 let config = state.server_state.config.read().await.clone();
1050 let rate = params.rate.unwrap_or(0.0);
1051 let burst = params.burst.unwrap_or(100);
1052 let progress_interval = params.progress_interval.unwrap_or(100);
1053
1054 let (tx, rx) = tokio::sync::mpsc::channel::<String>(1024);
1056
1057 tokio::task::spawn_blocking(move || {
1059 use datasynth_runtime::stream_pipeline::*;
1060
1061 let channel_sink = ChannelPhaseSink::new(tx.clone());
1063
1064 let pipeline: Box<dyn PhaseSink> = Box::new(RateLimitedPipeline::new(
1066 Box::new(channel_sink),
1067 rate,
1068 burst,
1069 progress_interval,
1070 ));
1071
1072 let mut phase_config = PhaseConfig::from_config(&config);
1074 phase_config.show_progress = false;
1075
1076 match EnhancedOrchestrator::new(config, phase_config) {
1077 Ok(mut orchestrator) => {
1078 orchestrator.set_phase_sink(pipeline);
1079 match orchestrator.generate() {
1080 Ok(result) => {
1081 let summary = serde_json::json!({
1083 "type": "_complete",
1084 "summary": {
1085 "total_entries": result.statistics.total_entries,
1086 "total_line_items": result.statistics.total_line_items,
1087 "anomaly_count": result.anomaly_labels.labels.len(),
1088 }
1089 });
1090 let _ =
1091 tx.blocking_send(serde_json::to_string(&summary).unwrap_or_default());
1092 }
1093 Err(e) => {
1094 let err = serde_json::json!({
1095 "type": "_error",
1096 "message": format!("Generation failed: {e}"),
1097 });
1098 let _ = tx.blocking_send(serde_json::to_string(&err).unwrap_or_default());
1099 }
1100 }
1101 }
1102 Err(e) => {
1103 let err = serde_json::json!({
1104 "type": "_error",
1105 "message": format!("Failed to create orchestrator: {e}"),
1106 });
1107 let _ = tx.blocking_send(serde_json::to_string(&err).unwrap_or_default());
1108 }
1109 }
1110 });
1112
1113 let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
1115 let body = axum::body::Body::from_stream(tokio_stream::StreamExt::map(stream, |mut line| {
1116 line.push('\n');
1117 Ok::<_, std::convert::Infallible>(line)
1118 }));
1119
1120 axum::response::Response::builder()
1121 .header("Content-Type", "application/x-ndjson")
1122 .header("Transfer-Encoding", "chunked")
1123 .header("Cache-Control", "no-cache")
1124 .header("X-Content-Type-Options", "nosniff")
1125 .body(body)
1126 .unwrap_or_else(|_| {
1127 axum::response::Response::builder()
1128 .status(StatusCode::INTERNAL_SERVER_ERROR)
1129 .body(axum::body::Body::empty())
1130 .expect("fallback response")
1131 })
1132}
1133
1134async fn websocket_metrics(
1136 ws: WebSocketUpgrade,
1137 State(state): State<AppState>,
1138) -> impl IntoResponse {
1139 ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
1140}
1141
1142async fn websocket_events(
1144 ws: WebSocketUpgrade,
1145 State(state): State<AppState>,
1146) -> impl IntoResponse {
1147 ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
1148}
1149
1150async fn submit_job(
1156 State(state): State<AppState>,
1157 Json(request): Json<JobRequest>,
1158) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
1159 let queue = state.job_queue.as_ref().ok_or_else(|| {
1160 (
1161 StatusCode::SERVICE_UNAVAILABLE,
1162 Json(serde_json::json!({"error": "Job queue not enabled"})),
1163 )
1164 })?;
1165
1166 let job_id = queue.submit(request).await;
1167 info!("Job submitted: {}", job_id);
1168
1169 Ok((
1170 StatusCode::CREATED,
1171 Json(serde_json::json!({
1172 "id": job_id.to_string(),
1173 "status": "queued"
1174 })),
1175 ))
1176}
1177
1178async fn get_job(
1180 State(state): State<AppState>,
1181 axum::extract::Path(id): axum::extract::Path<String>,
1182) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1183 let queue = state.job_queue.as_ref().ok_or_else(|| {
1184 (
1185 StatusCode::SERVICE_UNAVAILABLE,
1186 Json(serde_json::json!({"error": "Job queue not enabled"})),
1187 )
1188 })?;
1189
1190 match queue.get(&id).await {
1191 Some(entry) => Ok(Json(serde_json::json!({
1192 "id": entry.id,
1193 "status": format!("{:?}", entry.status).to_lowercase(),
1194 "submitted_at": entry.submitted_at.to_rfc3339(),
1195 "started_at": entry.started_at.map(|t| t.to_rfc3339()),
1196 "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
1197 "result": entry.result,
1198 }))),
1199 None => Err((
1200 StatusCode::NOT_FOUND,
1201 Json(serde_json::json!({"error": "Job not found"})),
1202 )),
1203 }
1204}
1205
1206async fn list_jobs(
1208 State(state): State<AppState>,
1209) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1210 let queue = state.job_queue.as_ref().ok_or_else(|| {
1211 (
1212 StatusCode::SERVICE_UNAVAILABLE,
1213 Json(serde_json::json!({"error": "Job queue not enabled"})),
1214 )
1215 })?;
1216
1217 let summaries: Vec<_> = queue
1218 .list()
1219 .await
1220 .into_iter()
1221 .map(|s| {
1222 serde_json::json!({
1223 "id": s.id,
1224 "status": format!("{:?}", s.status).to_lowercase(),
1225 "submitted_at": s.submitted_at.to_rfc3339(),
1226 })
1227 })
1228 .collect();
1229
1230 Ok(Json(serde_json::json!({ "jobs": summaries })))
1231}
1232
1233async fn cancel_job(
1235 State(state): State<AppState>,
1236 axum::extract::Path(id): axum::extract::Path<String>,
1237) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1238 let queue = state.job_queue.as_ref().ok_or_else(|| {
1239 (
1240 StatusCode::SERVICE_UNAVAILABLE,
1241 Json(serde_json::json!({"error": "Job queue not enabled"})),
1242 )
1243 })?;
1244
1245 if queue.cancel(&id).await {
1246 Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
1247 } else {
1248 Err((
1249 StatusCode::CONFLICT,
1250 Json(
1251 serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
1252 ),
1253 ))
1254 }
1255}
1256
1257async fn reload_config(
1263 State(state): State<AppState>,
1264) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
1265 let source = state.server_state.config_source.read().await.clone();
1266 match crate::config_loader::load_config(&source).await {
1267 Ok(new_config) => {
1268 let mut config = state.server_state.config.write().await;
1269 *config = new_config;
1270 info!("Configuration reloaded via REST API from {:?}", source);
1271 Ok(Json(serde_json::json!({
1272 "success": true,
1273 "message": "Configuration reloaded"
1274 })))
1275 }
1276 Err(e) => {
1277 error!("Failed to reload configuration: {}", e);
1278 Err((
1279 StatusCode::INTERNAL_SERVER_ERROR,
1280 Json(serde_json::json!({
1281 "success": false,
1282 "message": format!("Failed to reload configuration: {}", e)
1283 })),
1284 ))
1285 }
1286 }
1287}
1288
// Unit tests for the REST DTOs and server configuration types.
#[cfg(test)]
// `unwrap` is acceptable here: a failed (de)serialization should simply fail the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // ---- Health / metrics DTOs ----

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            healthy: true,
            version: "0.1.0".to_string(),
            uptime_seconds: 100,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("healthy"));
        assert!(json.contains("version"));
        assert!(json.contains("uptime_seconds"));
    }

    #[test]
    fn test_health_response_deserialization() {
        let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
        let response: HealthResponse = serde_json::from_str(json).unwrap();
        assert!(response.healthy);
        assert_eq!(response.version, "0.1.0");
        assert_eq!(response.uptime_seconds, 100);
    }

    #[test]
    fn test_metrics_response_serialization() {
        let response = MetricsResponse {
            total_entries_generated: 1000,
            total_anomalies_injected: 10,
            uptime_seconds: 60,
            session_entries: 1000,
            session_entries_per_second: 16.67,
            active_streams: 1,
            total_stream_events: 500,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("total_entries_generated"));
        assert!(json.contains("session_entries_per_second"));
    }

    #[test]
    fn test_metrics_response_deserialization() {
        let json = r#"{
            "total_entries_generated": 5000,
            "total_anomalies_injected": 50,
            "uptime_seconds": 300,
            "session_entries": 5000,
            "session_entries_per_second": 16.67,
            "active_streams": 2,
            "total_stream_events": 10000
        }"#;
        let response: MetricsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.total_entries_generated, 5000);
        assert_eq!(response.active_streams, 2);
    }

    // ---- Configuration DTOs ----

    #[test]
    fn test_config_response_serialization() {
        let response = ConfigResponse {
            success: true,
            message: "Configuration loaded".to_string(),
            config: Some(GenerationConfigDto {
                industry: "manufacturing".to_string(),
                start_date: "2024-01-01".to_string(),
                period_months: 12,
                seed: Some(42),
                coa_complexity: "medium".to_string(),
                companies: vec![],
                fraud_enabled: false,
                fraud_rate: 0.0,
            }),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("config"));
    }

    #[test]
    fn test_config_response_without_config() {
        let response = ConfigResponse {
            success: false,
            message: "No configuration available".to_string(),
            config: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        // The absent config must be serialized as an explicit null for that field,
        // not merely contain "null" somewhere in the output.
        assert!(json.contains("\"config\":null"));
    }

    #[test]
    fn test_generation_config_dto_roundtrip() {
        let original = GenerationConfigDto {
            industry: "retail".to_string(),
            start_date: "2024-06-01".to_string(),
            period_months: 6,
            seed: Some(12345),
            coa_complexity: "large".to_string(),
            companies: vec![CompanyConfigDto {
                code: "1000".to_string(),
                name: "Test Corp".to_string(),
                currency: "USD".to_string(),
                country: "US".to_string(),
                annual_transaction_volume: 100000,
                volume_weight: 1.0,
            }],
            fraud_enabled: true,
            fraud_rate: 0.05,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

        assert_eq!(original.industry, deserialized.industry);
        assert_eq!(original.seed, deserialized.seed);
        assert_eq!(original.companies.len(), deserialized.companies.len());
    }

    #[test]
    fn test_company_config_dto_serialization() {
        let company = CompanyConfigDto {
            code: "2000".to_string(),
            name: "European Subsidiary".to_string(),
            currency: "EUR".to_string(),
            country: "DE".to_string(),
            annual_transaction_volume: 50000,
            volume_weight: 0.5,
        };
        let json = serde_json::to_string(&company).unwrap();
        assert!(json.contains("2000"));
        assert!(json.contains("EUR"));
        assert!(json.contains("DE"));
    }

    // ---- Bulk generation / streaming DTOs ----

    #[test]
    fn test_bulk_generate_request_deserialization() {
        let json = r#"{
            "entry_count": 5000,
            "include_master_data": true,
            "inject_anomalies": true
        }"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(5000));
        assert_eq!(request.include_master_data, Some(true));
        assert_eq!(request.inject_anomalies, Some(true));
    }

    #[test]
    fn test_bulk_generate_request_with_defaults() {
        let json = r#"{}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, None);
        assert_eq!(request.include_master_data, None);
        assert_eq!(request.inject_anomalies, None);
    }

    #[test]
    fn test_bulk_generate_response_serialization() {
        let response = BulkGenerateResponse {
            success: true,
            entries_generated: 1000,
            duration_ms: 250,
            anomaly_count: 20,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("entries_generated"));
        assert!(json.contains("1000"));
        assert!(json.contains("duration_ms"));
    }

    #[test]
    fn test_stream_response_serialization() {
        let response = StreamResponse {
            success: true,
            message: "Stream started successfully".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("success"));
        assert!(json.contains("Stream started"));
    }

    #[test]
    fn test_stream_response_failure() {
        let response = StreamResponse {
            success: false,
            message: "Stream failed to start".to_string(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("false"));
        assert!(json.contains("failed"));
    }

    // ---- Server configuration types ----

    #[test]
    fn test_timeout_config_default() {
        assert_eq!(TimeoutConfig::default().request_timeout_secs, 300);
    }

    #[test]
    fn test_timeout_config_new() {
        assert_eq!(TimeoutConfig::new(45).request_timeout_secs, 45);
    }

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(!config.allow_any_origin);
        assert!(!config.allowed_origins.is_empty());
        assert!(config
            .allowed_origins
            .contains(&"http://localhost:5173".to_string()));
        assert!(config
            .allowed_origins
            .contains(&"tauri://localhost".to_string()));
    }

    #[test]
    fn test_cors_config_custom_origins() {
        let config = CorsConfig {
            allowed_origins: vec![
                "https://example.com".to_string(),
                "https://app.example.com".to_string(),
            ],
            allow_any_origin: false,
        };
        assert_eq!(config.allowed_origins.len(), 2);
        assert!(config
            .allowed_origins
            .contains(&"https://example.com".to_string()));
    }

    #[test]
    fn test_cors_config_permissive() {
        let config = CorsConfig {
            allowed_origins: vec![],
            allow_any_origin: true,
        };
        assert!(config.allow_any_origin);
    }

    // ---- Edge cases ----

    #[test]
    fn test_bulk_generate_request_partial() {
        let json = r#"{"entry_count": 100}"#;
        let request: BulkGenerateRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.entry_count, Some(100));
        assert!(request.include_master_data.is_none());
    }

    #[test]
    fn test_generation_config_no_seed() {
        let config = GenerationConfigDto {
            industry: "technology".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 3,
            seed: None,
            coa_complexity: "small".to_string(),
            companies: vec![],
            fraud_enabled: false,
            fraud_rate: 0.0,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("seed"));
        // A missing seed must survive a round trip as `None`, not a default value.
        let restored: GenerationConfigDto = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.seed, None);
    }

    #[test]
    fn test_generation_config_multiple_companies() {
        let config = GenerationConfigDto {
            industry: "manufacturing".to_string(),
            start_date: "2024-01-01".to_string(),
            period_months: 12,
            seed: Some(42),
            coa_complexity: "large".to_string(),
            companies: vec![
                CompanyConfigDto {
                    code: "1000".to_string(),
                    name: "Headquarters".to_string(),
                    currency: "USD".to_string(),
                    country: "US".to_string(),
                    annual_transaction_volume: 100000,
                    volume_weight: 1.0,
                },
                CompanyConfigDto {
                    code: "2000".to_string(),
                    name: "European Sub".to_string(),
                    currency: "EUR".to_string(),
                    country: "DE".to_string(),
                    annual_transaction_volume: 50000,
                    volume_weight: 0.5,
                },
                CompanyConfigDto {
                    code: "3000".to_string(),
                    name: "APAC Sub".to_string(),
                    currency: "JPY".to_string(),
                    country: "JP".to_string(),
                    annual_transaction_volume: 30000,
                    volume_weight: 0.3,
                },
            ],
            fraud_enabled: true,
            fraud_rate: 0.02,
        };
        assert_eq!(config.companies.len(), 3);

        // Exercise the serde implementation: every company survives, in order.
        let json = serde_json::to_string(&config).unwrap();
        let restored: GenerationConfigDto = serde_json::from_str(&json).unwrap();
        let codes: Vec<&str> = restored
            .companies
            .iter()
            .map(|c| c.code.as_str())
            .collect();
        assert_eq!(codes, ["1000", "2000", "3000"]);
    }

    // ---- Metrics arithmetic ----
    // NOTE(review): these tests re-state the entries-per-second formula inline
    // (with a zero-uptime guard) rather than calling production code; consider
    // pointing them at the real helper if one exists.

    #[test]
    fn test_metrics_entries_per_second_calculation() {
        let total_entries: u64 = 1000;
        let uptime: u64 = 60;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert!((eps - 16.67).abs() < 0.1);
    }

    #[test]
    fn test_metrics_entries_per_second_zero_uptime() {
        let total_entries: u64 = 1000;
        let uptime: u64 = 0;
        let eps = if uptime > 0 {
            total_entries as f64 / uptime as f64
        } else {
            0.0
        };
        assert_eq!(eps, 0.0);
    }
}