use std::sync::Arc;
use std::time::Duration;
use axum::{
extract::{State, WebSocketUpgrade},
http::{header, Method, StatusCode},
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::timeout::TimeoutLayer;
use tracing::{error, info};
use crate::grpc::service::{ServerState, SynthService};
use crate::jobs::{JobQueue, JobRequest};
use datasynth_runtime::{EnhancedOrchestrator, PhaseConfig};
use super::websocket;
/// Shared state handed to every REST and WebSocket handler in this module
/// through axum's `State` extractor.
#[derive(Clone)]
pub struct AppState {
/// Server-wide state, cloned from the `SynthService` the router is built from.
pub server_state: Arc<ServerState>,
/// Optional job queue; the `/api/jobs` handlers answer `503 Service Unavailable`
/// while this is `None`.
pub job_queue: Option<Arc<JobQueue>>,
}
/// Request-timeout settings applied by the fully configured router.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Maximum time, in seconds, a request may take before the timeout layer
    /// answers on its behalf.
    pub request_timeout_secs: u64,
}

impl TimeoutConfig {
    /// Creates a configuration with an explicit timeout of `timeout_secs` seconds.
    pub fn new(timeout_secs: u64) -> Self {
        TimeoutConfig {
            request_timeout_secs: timeout_secs,
        }
    }
}

impl Default for TimeoutConfig {
    /// Defaults to a 300-second (five-minute) request timeout.
    fn default() -> Self {
        Self::new(300)
    }
}
/// Cross-origin resource sharing settings used when building the routers.
///
/// Derives `Debug` (like [`TimeoutConfig`]) so the configuration can be logged
/// and embedded in other `Debug` types.
#[derive(Clone, Debug)]
pub struct CorsConfig {
    /// Origins that may call the API when `allow_any_origin` is `false`.
    pub allowed_origins: Vec<String>,
    /// When `true`, the router uses a fully permissive CORS layer and
    /// `allowed_origins` is ignored.
    pub allow_any_origin: bool,
}

impl Default for CorsConfig {
    /// Restricts access to a fixed set of localhost origins plus
    /// `tauri://localhost` (presumably local dev servers and a desktop shell —
    /// confirm against the front-end setup).
    fn default() -> Self {
        const DEFAULT_ORIGINS: [&str; 5] = [
            "http://localhost:5173",
            "http://localhost:3000",
            "http://127.0.0.1:5173",
            "http://127.0.0.1:3000",
            "tauri://localhost",
        ];
        Self {
            allowed_origins: DEFAULT_ORIGINS.iter().map(|origin| origin.to_string()).collect(),
            allow_any_origin: false,
        }
    }
}
/// Response-mapping middleware that stamps every response with
/// `x-api-version: v1`, replacing any existing value for that header.
async fn api_version_header(response: axum::response::Response) -> axum::response::Response {
    let mut response = response;
    let name = axum::http::HeaderName::from_static("x-api-version");
    let value = axum::http::HeaderValue::from_static("v1");
    response.headers_mut().insert(name, value);
    response
}
use super::auth::{auth_middleware, AuthConfig};
use super::rate_limit::RateLimitConfig;
use super::rate_limit_backend::{backend_rate_limit_middleware, RateLimitBackend};
use super::request_id::request_id_middleware;
use super::request_validation::request_validation_middleware;
use super::security_headers::security_headers_middleware;
/// Builds the REST router with the default CORS policy and no authentication,
/// rate limiting or timeout layers.
pub fn create_router(service: SynthService) -> Router {
    let default_cors = CorsConfig::default();
    create_router_with_cors(service, default_cors)
}
/// Builds the fully layered router, backing rate limiting with an in-memory
/// store created from `rate_limit_config`.
///
/// All remaining arguments are forwarded unchanged to
/// [`create_router_full_with_backend`].
pub fn create_router_full(
    service: SynthService,
    cors_config: CorsConfig,
    auth_config: AuthConfig,
    rate_limit_config: RateLimitConfig,
    timeout_config: TimeoutConfig,
) -> Router {
    create_router_full_with_backend(
        service,
        cors_config,
        auth_config,
        RateLimitBackend::in_memory(rate_limit_config),
        timeout_config,
    )
}
pub fn create_router_full_with_backend(
service: SynthService,
cors_config: CorsConfig,
auth_config: AuthConfig,
rate_limit_backend: RateLimitBackend,
timeout_config: TimeoutConfig,
) -> Router {
let server_state = service.state.clone();
let state = AppState {
server_state,
job_queue: None,
};
let cors = if cors_config.allow_any_origin {
CorsLayer::permissive()
} else {
let origins: Vec<_> = cors_config
.allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
};
Router::new()
.route("/health", get(health_check))
.route("/ready", get(readiness_check))
.route("/live", get(liveness_check))
.route("/api/metrics", get(get_metrics))
.route("/metrics", get(prometheus_metrics))
.route("/api/config", get(get_config))
.route("/api/config", post(set_config))
.route("/api/config/reload", post(reload_config))
.route("/api/generate/bulk", post(bulk_generate))
.route("/api/stream/start", post(start_stream))
.route("/api/stream/stop", post(stop_stream))
.route("/api/stream/pause", post(pause_stream))
.route("/api/stream/resume", post(resume_stream))
.route("/api/stream/trigger/{pattern}", post(trigger_pattern))
.route("/api/stream/ndjson", get(stream_ndjson))
.route("/api/jobs/submit", post(submit_job))
.route("/api/jobs", get(list_jobs))
.route("/api/jobs/{id}", get(get_job))
.route("/api/jobs/{id}/cancel", post(cancel_job))
.route("/ws/metrics", get(websocket_metrics))
.route("/ws/events", get(websocket_events))
.layer(axum::middleware::from_fn(security_headers_middleware))
.layer(axum::middleware::map_response(api_version_header))
.layer(cors)
.layer(axum::middleware::from_fn(request_id_middleware))
.layer(axum::middleware::from_fn(auth_middleware))
.layer(axum::Extension(auth_config))
.layer(axum::middleware::from_fn(request_validation_middleware))
.layer(axum::middleware::from_fn(backend_rate_limit_middleware))
.layer(axum::Extension(rate_limit_backend))
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
Duration::from_secs(timeout_config.request_timeout_secs),
))
.with_state(state)
}
pub fn create_router_with_auth(
service: SynthService,
cors_config: CorsConfig,
auth_config: AuthConfig,
) -> Router {
let server_state = service.state.clone();
let state = AppState {
server_state,
job_queue: None,
};
let cors = if cors_config.allow_any_origin {
CorsLayer::permissive()
} else {
let origins: Vec<_> = cors_config
.allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
};
Router::new()
.route("/health", get(health_check))
.route("/ready", get(readiness_check))
.route("/live", get(liveness_check))
.route("/api/metrics", get(get_metrics))
.route("/metrics", get(prometheus_metrics))
.route("/api/config", get(get_config))
.route("/api/config", post(set_config))
.route("/api/config/reload", post(reload_config))
.route("/api/generate/bulk", post(bulk_generate))
.route("/api/stream/start", post(start_stream))
.route("/api/stream/stop", post(stop_stream))
.route("/api/stream/pause", post(pause_stream))
.route("/api/stream/resume", post(resume_stream))
.route("/api/stream/trigger/{pattern}", post(trigger_pattern))
.route("/api/stream/ndjson", get(stream_ndjson))
.route("/api/jobs/submit", post(submit_job))
.route("/api/jobs", get(list_jobs))
.route("/api/jobs/{id}", get(get_job))
.route("/api/jobs/{id}/cancel", post(cancel_job))
.route("/ws/metrics", get(websocket_metrics))
.route("/ws/events", get(websocket_events))
.layer(axum::middleware::from_fn(auth_middleware))
.layer(axum::Extension(auth_config))
.layer(cors)
.with_state(state)
}
pub fn create_router_with_cors(service: SynthService, cors_config: CorsConfig) -> Router {
let server_state = service.state.clone();
let state = AppState {
server_state,
job_queue: None,
};
let cors = if cors_config.allow_any_origin {
CorsLayer::permissive()
} else {
let origins: Vec<_> = cors_config
.allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::ACCEPT])
};
Router::new()
.route("/health", get(health_check))
.route("/ready", get(readiness_check))
.route("/live", get(liveness_check))
.route("/api/metrics", get(get_metrics))
.route("/metrics", get(prometheus_metrics))
.route("/api/config", get(get_config))
.route("/api/config", post(set_config))
.route("/api/config/reload", post(reload_config))
.route("/api/generate/bulk", post(bulk_generate))
.route("/api/stream/start", post(start_stream))
.route("/api/stream/stop", post(stop_stream))
.route("/api/stream/pause", post(pause_stream))
.route("/api/stream/resume", post(resume_stream))
.route("/api/stream/trigger/{pattern}", post(trigger_pattern))
.route("/api/stream/ndjson", get(stream_ndjson))
.route("/api/jobs/submit", post(submit_job))
.route("/api/jobs", get(list_jobs))
.route("/api/jobs/{id}", get(get_job))
.route("/api/jobs/{id}/cancel", post(cancel_job))
.route("/ws/metrics", get(websocket_metrics))
.route("/ws/events", get(websocket_events))
.layer(cors)
.with_state(state)
}
/// JSON body returned by `GET /health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
/// Always `true` when produced by `health_check`.
pub healthy: bool,
/// Crate version (`CARGO_PKG_VERSION` at build time).
pub version: String,
/// Server uptime in seconds as reported by the shared server state.
pub uptime_seconds: u64,
}
/// JSON body returned by `GET /ready`, for both the 200 and the 503 outcome.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadinessResponse {
/// `true` only when no individual check reported `"fail"`.
pub ready: bool,
/// Human-readable summary of the overall result.
pub message: String,
/// Individual check results, in the order they were evaluated.
pub checks: Vec<HealthCheck>,
}
/// Result of one readiness check.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthCheck {
/// Check identifier; `readiness_check` emits `"config"`, `"memory"` and `"disk"`.
pub name: String,
/// Outcome; `readiness_check` emits `"ok"`, `"degraded"` or `"fail"`.
pub status: String,
}
/// JSON body returned by `GET /live`.
#[derive(Debug, Serialize, Deserialize)]
pub struct LivenessResponse {
/// Always `true` when produced by `liveness_check`.
pub alive: bool,
/// Current UTC time formatted as RFC 3339.
pub timestamp: String,
}
/// JSON body returned by `GET /api/metrics`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MetricsResponse {
/// Value of the server's `total_entries` counter.
pub total_entries_generated: u64,
/// Value of the server's `total_anomalies` counter.
pub total_anomalies_injected: u64,
/// Server uptime in seconds.
pub uptime_seconds: u64,
/// Currently filled with the same value as `total_entries_generated`.
pub session_entries: u64,
/// `total_entries / uptime_seconds`, or `0.0` while uptime is zero.
pub session_entries_per_second: f64,
/// Value of the server's `active_streams` counter, cast to `u32`.
pub active_streams: u32,
/// Value of the server's `total_stream_events` counter.
pub total_stream_events: u64,
}
/// JSON body returned by the `/api/config` GET and POST handlers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigResponse {
/// Whether the request was fulfilled.
pub success: bool,
/// Human-readable status or error description.
pub message: String,
/// The configuration on success; `None` on validation errors.
pub config: Option<GenerationConfigDto>,
}
/// Wire representation of the generation settings exposed over `/api/config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfigDto {
/// Industry sector name. `set_config` matches it case-insensitively against a
/// fixed list; `get_config` emits the `Debug` form of the stored enum value.
pub industry: String,
/// Start date, passed through as an opaque string by these handlers.
pub start_date: String,
/// Number of months to generate.
pub period_months: u32,
/// Optional seed copied to and from `config.global.seed`.
pub seed: Option<u64>,
/// Chart-of-accounts complexity; `set_config` accepts `small`, `medium` or
/// `large` (case-insensitive).
pub coa_complexity: String,
/// Companies to generate for. An empty list sent to `set_config` leaves the
/// stored companies untouched.
pub companies: Vec<CompanyConfigDto>,
/// Mirrors `config.fraud.enabled`.
pub fraud_enabled: bool,
/// Mirrors `config.fraud.fraud_rate` (stored as `f64`, transported as `f32`).
pub fraud_rate: f32,
}
/// Wire representation of one company entry inside [`GenerationConfigDto`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyConfigDto {
/// Company code.
pub code: String,
/// Company display name.
pub name: String,
/// Currency code, passed through as an opaque string.
pub currency: String,
/// Country code, passed through as an opaque string.
pub country: String,
/// Annual transaction count; `set_config` stores it as
/// `TransactionVolume::Custom`.
pub annual_transaction_volume: u64,
/// Volume weight (stored as `f64`, transported as `f32`).
pub volume_weight: f32,
}
/// JSON body accepted by `POST /api/generate/bulk`; every field is optional.
#[derive(Debug, Deserialize)]
pub struct BulkGenerateRequest {
/// Requested entry count; `bulk_generate` rejects values above 1,000,000.
pub entry_count: Option<u64>,
/// Whether to generate master data; treated as `false` when absent.
pub include_master_data: Option<bool>,
/// Whether to inject anomalies; treated as `false` when absent.
pub inject_anomalies: Option<bool>,
}
/// JSON body returned by a successful `POST /api/generate/bulk`.
#[derive(Debug, Serialize)]
pub struct BulkGenerateResponse {
/// Always `true` when produced by `bulk_generate` (failures use an error status).
pub success: bool,
/// Number of journal entries in the generation result.
pub entries_generated: u64,
/// Wall-clock generation time in milliseconds.
pub duration_ms: u64,
/// Number of anomaly labels in the generation result.
pub anomaly_count: u64,
}
/// JSON body accepted by `POST /api/stream/start`; every field is optional and
/// only the fields present are written to the shared stream settings.
// NOTE(review): `start_stream` reads all three fields, so the `dead_code`
// allowance below may no longer be needed — confirm before removing.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] pub struct StreamRequest {
/// Target event rate stored in `stream_events_per_second`.
pub events_per_second: Option<u32>,
/// Event limit stored in `stream_max_events`.
pub max_events: Option<u64>,
/// Anomaly-injection flag stored in `stream_inject_anomalies`.
pub inject_anomalies: Option<bool>,
}
/// JSON body returned by the `/api/stream/*` control endpoints.
#[derive(Debug, Serialize)]
pub struct StreamResponse {
/// Whether the requested action was carried out.
pub success: bool,
/// Human-readable description of the outcome.
pub message: String,
}
/// `GET /health` — reports the crate version and server uptime; always healthy.
async fn health_check(State(state): State<AppState>) -> Json<HealthResponse> {
    let uptime_seconds = state.server_state.uptime_seconds();
    let version = String::from(env!("CARGO_PKG_VERSION"));
    Json(HealthResponse {
        healthy: true,
        version,
        uptime_seconds,
    })
}
/// `GET /ready` — runs the config, memory and disk checks.
///
/// Returns `200` with the check list when nothing failed, otherwise `503` with
/// the same body shape. A "degraded" memory status does not fail readiness.
async fn readiness_check(
    State(state): State<AppState>,
) -> Result<Json<ReadinessResponse>, (StatusCode, Json<ReadinessResponse>)> {
    // Config check: at least one company must be configured. The read guard is
    // released at the end of this block.
    let config_valid = {
        let config = state.server_state.config.read().await;
        !config.companies.is_empty()
    };

    let resource_status = state.server_state.resource_status();
    let memory_status = if resource_status.degradation_level == "Emergency" {
        "fail"
    } else if resource_status.degradation_level != "Normal" {
        "degraded"
    } else {
        "ok"
    };
    // Disk check: more than 100 MB must be available.
    let disk_ok = resource_status.disk_available_mb > 100;

    let pass_fail = |passed: bool| if passed { "ok" } else { "fail" };
    let check = |name: &str, status: &str| HealthCheck {
        name: name.to_string(),
        status: status.to_string(),
    };
    let checks = vec![
        check("config", pass_fail(config_valid)),
        check("memory", memory_status),
        check("disk", pass_fail(disk_ok)),
    ];

    let ready = config_valid && memory_status != "fail" && disk_ok;
    let message = if ready {
        "Service is ready".to_string()
    } else {
        "Service is not ready".to_string()
    };
    let response = ReadinessResponse {
        ready,
        message,
        checks,
    };

    if ready {
        Ok(Json(response))
    } else {
        Err((StatusCode::SERVICE_UNAVAILABLE, Json(response)))
    }
}
/// `GET /live` — confirms the process is responding and returns the current
/// UTC time in RFC 3339 format.
async fn liveness_check() -> Json<LivenessResponse> {
    let timestamp = chrono::Utc::now().to_rfc3339();
    Json(LivenessResponse {
        alive: true,
        timestamp,
    })
}
/// `GET /metrics` — renders the server counters in the Prometheus text
/// exposition format (content type `text/plain; version=0.0.4`).
async fn prometheus_metrics(State(state): State<AppState>) -> impl IntoResponse {
    use std::sync::atomic::Ordering::Relaxed;

    let server = &state.server_state;
    let uptime = server.uptime_seconds();
    let entries = server.total_entries.load(Relaxed);
    let anomalies = server.total_anomalies.load(Relaxed);
    let streams = server.active_streams.load(Relaxed);
    let stream_events = server.total_stream_events.load(Relaxed);
    // Average rate since start; zero uptime would otherwise divide by zero.
    let rate = match uptime {
        0 => 0.0,
        secs => entries as f64 / secs as f64,
    };

    let body = format!(
        r#"# HELP synth_entries_generated_total Total number of journal entries generated
# TYPE synth_entries_generated_total counter
synth_entries_generated_total {}
# HELP synth_anomalies_injected_total Total number of anomalies injected
# TYPE synth_anomalies_injected_total counter
synth_anomalies_injected_total {}
# HELP synth_uptime_seconds Server uptime in seconds
# TYPE synth_uptime_seconds gauge
synth_uptime_seconds {}
# HELP synth_entries_per_second Rate of entry generation
# TYPE synth_entries_per_second gauge
synth_entries_per_second {:.2}
# HELP synth_active_streams Number of active streaming connections
# TYPE synth_active_streams gauge
synth_active_streams {}
# HELP synth_stream_events_total Total events sent through streams
# TYPE synth_stream_events_total counter
synth_stream_events_total {}
# HELP synth_info Server version information
# TYPE synth_info gauge
synth_info{{version="{}"}} 1
"#,
        entries,
        anomalies,
        uptime,
        rate,
        streams,
        stream_events,
        env!("CARGO_PKG_VERSION")
    );

    let content_type = (
        header::CONTENT_TYPE,
        "text/plain; version=0.0.4; charset=utf-8",
    );
    (StatusCode::OK, [content_type], body)
}
/// `GET /api/metrics` — returns the server counters as JSON.
async fn get_metrics(State(state): State<AppState>) -> Json<MetricsResponse> {
    use std::sync::atomic::Ordering::Relaxed;

    let server = &state.server_state;
    let uptime_seconds = server.uptime_seconds();
    let total_entries = server.total_entries.load(Relaxed);
    // Average rate since start; zero uptime would otherwise divide by zero.
    let session_entries_per_second = match uptime_seconds {
        0 => 0.0,
        secs => total_entries as f64 / secs as f64,
    };
    let total_anomalies_injected = server.total_anomalies.load(Relaxed);
    let active_streams = server.active_streams.load(Relaxed) as u32;
    let total_stream_events = server.total_stream_events.load(Relaxed);

    Json(MetricsResponse {
        total_entries_generated: total_entries,
        total_anomalies_injected,
        uptime_seconds,
        // The session figure is currently the same counter as the total.
        session_entries: total_entries,
        session_entries_per_second,
        active_streams,
        total_stream_events,
    })
}
async fn get_config(State(state): State<AppState>) -> Json<ConfigResponse> {
let config = state.server_state.config.read().await;
Json(ConfigResponse {
success: true,
message: "Current configuration".to_string(),
config: Some(GenerationConfigDto {
industry: format!("{:?}", config.global.industry),
start_date: config.global.start_date.clone(),
period_months: config.global.period_months,
seed: config.global.seed,
coa_complexity: format!("{:?}", config.chart_of_accounts.complexity),
companies: config
.companies
.iter()
.map(|c| CompanyConfigDto {
code: c.code.clone(),
name: c.name.clone(),
currency: c.currency.clone(),
country: c.country.clone(),
annual_transaction_volume: c.annual_transaction_volume.count(),
volume_weight: c.volume_weight as f32,
})
.collect(),
fraud_enabled: config.fraud.enabled,
fraud_rate: config.fraud.fraud_rate as f32,
}),
})
}
/// `POST /api/config` — validates the submitted settings and writes them into
/// the shared configuration.
///
/// Responds `400 Bad Request` when the industry or chart-of-accounts
/// complexity is not recognised (both are matched case-insensitively). An
/// empty `companies` list leaves the stored companies unchanged. On success
/// the submitted DTO is echoed back.
async fn set_config(
    State(state): State<AppState>,
    Json(new_config): Json<GenerationConfigDto>,
) -> Result<Json<ConfigResponse>, (StatusCode, Json<ConfigResponse>)> {
    use datasynth_config::schema::{CompanyConfig, TransactionVolume};
    use datasynth_core::models::{CoAComplexity, IndustrySector};

    // Shared shape of every validation failure.
    let reject = |message: String| {
        (
            StatusCode::BAD_REQUEST,
            Json(ConfigResponse {
                success: false,
                message,
                config: None,
            }),
        )
    };

    info!(
        "Configuration update requested: industry={}, period_months={}",
        new_config.industry, new_config.period_months
    );

    let industry = match new_config.industry.to_lowercase().as_str() {
        "manufacturing" => IndustrySector::Manufacturing,
        "retail" => IndustrySector::Retail,
        "financial_services" | "financialservices" => IndustrySector::FinancialServices,
        "healthcare" => IndustrySector::Healthcare,
        "technology" => IndustrySector::Technology,
        "professional_services" | "professionalservices" => IndustrySector::ProfessionalServices,
        "energy" => IndustrySector::Energy,
        "transportation" => IndustrySector::Transportation,
        "real_estate" | "realestate" => IndustrySector::RealEstate,
        "telecommunications" => IndustrySector::Telecommunications,
        _ => {
            return Err(reject(format!("Unknown industry: '{}'. Valid values: manufacturing, retail, financial_services, healthcare, technology, professional_services, energy, transportation, real_estate, telecommunications", new_config.industry)));
        }
    };

    let complexity = match new_config.coa_complexity.to_lowercase().as_str() {
        "small" => CoAComplexity::Small,
        "medium" => CoAComplexity::Medium,
        "large" => CoAComplexity::Large,
        _ => {
            return Err(reject(format!(
                "Unknown CoA complexity: '{}'. Valid values: small, medium, large",
                new_config.coa_complexity
            )));
        }
    };

    // Convert the DTO companies before taking the write lock.
    let mut companies: Vec<CompanyConfig> = Vec::with_capacity(new_config.companies.len());
    for dto in &new_config.companies {
        companies.push(CompanyConfig {
            code: dto.code.clone(),
            name: dto.name.clone(),
            currency: dto.currency.clone(),
            functional_currency: None,
            country: dto.country.clone(),
            fiscal_year_variant: "K4".to_string(),
            annual_transaction_volume: TransactionVolume::Custom(dto.annual_transaction_volume),
            volume_weight: dto.volume_weight as f64,
        });
    }

    {
        let mut config = state.server_state.config.write().await;
        config.global.industry = industry;
        config.global.start_date = new_config.start_date.clone();
        config.global.period_months = new_config.period_months;
        config.global.seed = new_config.seed;
        config.chart_of_accounts.complexity = complexity;
        config.fraud.enabled = new_config.fraud_enabled;
        config.fraud.fraud_rate = new_config.fraud_rate as f64;
        // An empty list means "keep the companies already configured".
        if !companies.is_empty() {
            config.companies = companies;
        }
        info!("Configuration updated successfully");
    }

    Ok(Json(ConfigResponse {
        success: true,
        message: "Configuration updated and applied".to_string(),
        config: Some(new_config),
    }))
}
async fn bulk_generate(
State(state): State<AppState>,
Json(req): Json<BulkGenerateRequest>,
) -> Result<Json<BulkGenerateResponse>, (StatusCode, String)> {
const MAX_ENTRY_COUNT: u64 = 1_000_000;
if let Some(count) = req.entry_count {
if count > MAX_ENTRY_COUNT {
return Err((
StatusCode::BAD_REQUEST,
format!("entry_count ({count}) exceeds maximum allowed value ({MAX_ENTRY_COUNT})"),
));
}
}
let config = state.server_state.config.read().await.clone();
let start_time = std::time::Instant::now();
let phase_config = {
let mut pc = PhaseConfig::from_config(&config);
pc.generate_master_data = req.include_master_data.unwrap_or(false);
pc.generate_document_flows = false;
pc.generate_journal_entries = true;
pc.inject_anomalies = req.inject_anomalies.unwrap_or(false);
pc.show_progress = false;
pc
};
let mut orchestrator = EnhancedOrchestrator::new(config, phase_config).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create orchestrator: {e}"),
)
})?;
let result = orchestrator.generate().map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Generation failed: {e}"),
)
})?;
let duration_ms = start_time.elapsed().as_millis() as u64;
let entries_count = result.journal_entries.len() as u64;
let anomaly_count = result.anomaly_labels.labels.len() as u64;
state
.server_state
.total_entries
.fetch_add(entries_count, std::sync::atomic::Ordering::Relaxed);
state
.server_state
.total_anomalies
.fetch_add(anomaly_count, std::sync::atomic::Ordering::Relaxed);
Ok(Json(BulkGenerateResponse {
success: true,
entries_generated: entries_count,
duration_ms,
anomaly_count,
}))
}
async fn start_stream(
State(state): State<AppState>,
Json(req): Json<StreamRequest>,
) -> Json<StreamResponse> {
if let Some(eps) = req.events_per_second {
info!("Stream configured: events_per_second={}", eps);
state
.server_state
.stream_events_per_second
.store(eps as u64, std::sync::atomic::Ordering::Relaxed);
}
if let Some(max) = req.max_events {
info!("Stream configured: max_events={}", max);
state
.server_state
.stream_max_events
.store(max, std::sync::atomic::Ordering::Relaxed);
}
if let Some(inject) = req.inject_anomalies {
info!("Stream configured: inject_anomalies={}", inject);
state
.server_state
.stream_inject_anomalies
.store(inject, std::sync::atomic::Ordering::Relaxed);
}
state
.server_state
.stream_stopped
.store(false, std::sync::atomic::Ordering::Relaxed);
state
.server_state
.stream_paused
.store(false, std::sync::atomic::Ordering::Relaxed);
Json(StreamResponse {
success: true,
message: "Stream started".to_string(),
})
}
/// `POST /api/stream/stop` — raises the shared "stopped" flag.
async fn stop_stream(State(state): State<AppState>) -> Json<StreamResponse> {
    use std::sync::atomic::Ordering::Relaxed;

    state.server_state.stream_stopped.store(true, Relaxed);
    let message = "Stream stopped".to_string();
    Json(StreamResponse {
        success: true,
        message,
    })
}
/// `POST /api/stream/pause` — raises the shared "paused" flag.
async fn pause_stream(State(state): State<AppState>) -> Json<StreamResponse> {
    use std::sync::atomic::Ordering::Relaxed;

    state.server_state.stream_paused.store(true, Relaxed);
    let message = "Stream paused".to_string();
    Json(StreamResponse {
        success: true,
        message,
    })
}
/// `POST /api/stream/resume` — clears the shared "paused" flag.
async fn resume_stream(State(state): State<AppState>) -> Json<StreamResponse> {
    use std::sync::atomic::Ordering::Relaxed;

    state.server_state.stream_paused.store(false, Relaxed);
    let message = "Stream resumed".to_string();
    Json(StreamResponse {
        success: true,
        message,
    })
}
/// `POST /api/stream/trigger/{pattern}` — records a pattern name in the shared
/// `triggered_pattern` slot.
///
/// Accepts the built-in pattern names or any name starting with `custom:`.
/// Every outcome is reported in the JSON body (`success: false` for an unknown
/// pattern or when the slot's lock cannot be taken without waiting); the
/// handler never sets an error status itself.
async fn trigger_pattern(
    State(state): State<AppState>,
    axum::extract::Path(pattern): axum::extract::Path<String>,
) -> Json<StreamResponse> {
    info!("Pattern trigger requested: {}", pattern);

    let valid_patterns = [
        "year_end_spike",
        "period_end_spike",
        "holiday_cluster",
        "fraud_cluster",
        "error_cluster",
        "uniform",
    ];

    let recognised =
        pattern.starts_with("custom:") || valid_patterns.contains(&pattern.as_str());
    if !recognised {
        let message = format!(
            "Unknown pattern '{pattern}'. Valid patterns: {valid_patterns:?}, or use 'custom:name' for custom patterns"
        );
        return Json(StreamResponse {
            success: false,
            message,
        });
    }

    // Non-blocking attempt; the guard is released as soon as the slot is written.
    let stored = state
        .server_state
        .triggered_pattern
        .try_write()
        .map(|mut slot| {
            *slot = Some(pattern.clone());
        })
        .is_ok();

    let message = if stored {
        format!("Pattern '{pattern}' will be applied to upcoming entries")
    } else {
        "Failed to acquire lock for pattern trigger".to_string()
    };
    Json(StreamResponse {
        success: stored,
        message,
    })
}
/// `PhaseSink` implementation that serialises every emitted item to one JSON
/// line and pushes it into a Tokio mpsc channel.
///
/// `emit` uses `blocking_send`, so the sink is meant to be driven from a
/// blocking (non-async) thread, as `stream_ndjson` does via `spawn_blocking`.
struct ChannelPhaseSink {
    /// Sending half of the channel feeding the HTTP response body.
    tx: tokio::sync::mpsc::Sender<String>,
    /// Running counters of emitted items and completed phases.
    stats: Arc<std::sync::Mutex<datasynth_runtime::stream_pipeline::StreamStats>>,
}

impl ChannelPhaseSink {
    /// Wraps `tx` with default (freshly initialised) statistics.
    fn new(tx: tokio::sync::mpsc::Sender<String>) -> Self {
        let stats = Arc::new(std::sync::Mutex::new(Default::default()));
        Self { tx, stats }
    }
}

impl datasynth_runtime::stream_pipeline::PhaseSink for ChannelPhaseSink {
    /// Wraps `item` in a `{phase, item_type, data}` envelope and sends it.
    ///
    /// Fails with `Serialization` if the envelope cannot be encoded and with
    /// `Connection` if the receiving side of the channel has gone away.
    fn emit(
        &self,
        phase: &str,
        item_type: &str,
        item: &serde_json::Value,
    ) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
        let envelope = serde_json::json!({
            "phase": phase,
            "item_type": item_type,
            "data": item,
        });
        let line = match serde_json::to_string(&envelope) {
            Ok(line) => line,
            Err(e) => {
                return Err(
                    datasynth_runtime::stream_pipeline::StreamError::Serialization(e.to_string()),
                );
            }
        };

        if self.tx.blocking_send(line).is_err() {
            return Err(
                datasynth_runtime::stream_pipeline::StreamError::Connection(
                    "channel closed".to_string(),
                ),
            );
        }

        // A poisoned stats mutex only loses the counter update, not the item.
        if let Ok(mut stats) = self.stats.lock() {
            stats.items_emitted += 1;
        }
        Ok(())
    }

    /// Counts a finished phase; the phase name itself is not used.
    fn phase_complete(
        &self,
        _phase: &str,
    ) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
        if let Ok(mut stats) = self.stats.lock() {
            stats.phases_completed += 1;
        }
        Ok(())
    }

    /// Nothing is buffered by this sink, so flushing is a no-op.
    fn flush(&self) -> Result<(), datasynth_runtime::stream_pipeline::StreamError> {
        Ok(())
    }

    /// Returns a copy of the counters, or default statistics if the mutex is
    /// poisoned.
    fn stats(&self) -> datasynth_runtime::stream_pipeline::StreamStats {
        match self.stats.lock() {
            Ok(guard) => guard.clone(),
            Err(_) => Default::default(),
        }
    }
}
/// Optional query-string parameters of `GET /api/stream/ndjson`; all three are
/// forwarded to `RateLimitedPipeline::new` by `stream_ndjson`.
#[derive(Debug, Deserialize)]
struct NdjsonStreamQuery {
/// Emission rate; `stream_ndjson` uses `0.0` when absent
/// (presumably meaning "unlimited" — confirm against `RateLimitedPipeline`).
#[serde(default)]
rate: Option<f64>,
/// Burst size; `stream_ndjson` uses `100` when absent.
#[serde(default)]
burst: Option<u32>,
/// Progress interval; `stream_ndjson` uses `100` when absent.
#[serde(default)]
progress_interval: Option<u64>,
}
/// `GET /api/stream/ndjson` — runs a full generation pass and streams every
/// emitted item to the client as newline-delimited JSON.
///
/// Generation runs on Tokio's blocking thread pool and hands lines to the
/// response body through a bounded channel. The final line is either a
/// `_complete` record with summary counts or an `_error` record; because the
/// response has already started by then, failures are reported in-band rather
/// than through the HTTP status.
async fn stream_ndjson(
State(state): State<AppState>,
axum::extract::Query(params): axum::extract::Query<NdjsonStreamQuery>,
) -> impl IntoResponse {
// Snapshot the configuration so the read lock is not held while streaming.
let config = state.server_state.config.read().await.clone();
// Fall back to defaults for any query parameter that was not supplied.
let rate = params.rate.unwrap_or(0.0);
let burst = params.burst.unwrap_or(100);
let progress_interval = params.progress_interval.unwrap_or(100);
// Bounded channel (1024 lines) between the generator thread and the response body.
let (tx, rx) = tokio::sync::mpsc::channel::<String>(1024);
// The task handle is not awaited: the task is detached and ends when generation
// finishes.
tokio::task::spawn_blocking(move || {
use datasynth_runtime::stream_pipeline::*;
// The sink gets its own sender; `tx` is kept for the final summary/error line.
let channel_sink = ChannelPhaseSink::new(tx.clone());
let pipeline: Box<dyn PhaseSink> = Box::new(RateLimitedPipeline::new(
Box::new(channel_sink),
rate,
burst,
progress_interval,
));
let mut phase_config = PhaseConfig::from_config(&config);
phase_config.show_progress = false;
match EnhancedOrchestrator::new(config, phase_config) {
Ok(mut orchestrator) => {
orchestrator.set_phase_sink(pipeline);
match orchestrator.generate() {
Ok(result) => {
// Terminal record telling the client the stream finished normally.
let summary = serde_json::json!({
"type": "_complete",
"summary": {
"total_entries": result.statistics.total_entries,
"total_line_items": result.statistics.total_line_items,
"anomaly_count": result.anomaly_labels.labels.len(),
}
});
// Send errors are ignored: a closed channel means the client went away.
let _ =
tx.blocking_send(serde_json::to_string(&summary).unwrap_or_default());
}
Err(e) => {
// In-band error record: generation failed part-way through.
let err = serde_json::json!({
"type": "_error",
"message": format!("Generation failed: {e}"),
});
let _ = tx.blocking_send(serde_json::to_string(&err).unwrap_or_default());
}
}
}
Err(e) => {
// In-band error record: the orchestrator could not be constructed.
let err = serde_json::json!({
"type": "_error",
"message": format!("Failed to create orchestrator: {e}"),
});
let _ = tx.blocking_send(serde_json::to_string(&err).unwrap_or_default());
}
}
});
// Turn the receiver into a body stream, terminating each JSON record with '\n'.
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let body = axum::body::Body::from_stream(tokio_stream::StreamExt::map(stream, |mut line| {
line.push('\n');
Ok::<_, std::convert::Infallible>(line)
}));
// NOTE(review): `Transfer-Encoding` is set by hand here; confirm this is wanted,
// since the HTTP server normally chooses the framing for streamed bodies itself.
axum::response::Response::builder()
.header("Content-Type", "application/x-ndjson")
.header("Transfer-Encoding", "chunked")
.header("Cache-Control", "no-cache")
.header("X-Content-Type-Options", "nosniff")
.body(body)
// If the response cannot be built, fall back to an empty 500 response.
.unwrap_or_else(|_| {
axum::response::Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(axum::body::Body::empty())
.expect("fallback response")
})
}
async fn websocket_metrics(
ws: WebSocketUpgrade,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| websocket::handle_metrics_socket(socket, state))
}
async fn websocket_events(
ws: WebSocketUpgrade,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| websocket::handle_events_socket(socket, state))
}
/// `POST /api/jobs/submit` — enqueues a generation job.
///
/// Responds `201 Created` with the new job id and status `"queued"`, or
/// `503 Service Unavailable` when no job queue is configured.
async fn submit_job(
    State(state): State<AppState>,
    Json(request): Json<JobRequest>,
) -> Result<(StatusCode, Json<serde_json::Value>), (StatusCode, Json<serde_json::Value>)> {
    let Some(queue) = state.job_queue.as_ref() else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({"error": "Job queue not enabled"})),
        ));
    };

    let job_id = queue.submit(request).await;
    info!("Job submitted: {}", job_id);

    let body = serde_json::json!({
        "id": job_id.to_string(),
        "status": "queued"
    });
    Ok((StatusCode::CREATED, Json(body)))
}
/// `GET /api/jobs/{id}` — returns the status, timestamps and result of a job.
///
/// Responds `503 Service Unavailable` when no job queue is configured and
/// `404 Not Found` when the id is unknown. The status is the lower-cased
/// `Debug` name of the job's status value.
async fn get_job(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let Some(queue) = state.job_queue.as_ref() else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({"error": "Job queue not enabled"})),
        ));
    };

    let Some(entry) = queue.get(&id).await else {
        return Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({"error": "Job not found"})),
        ));
    };

    let body = serde_json::json!({
        "id": entry.id,
        "status": format!("{:?}", entry.status).to_lowercase(),
        "submitted_at": entry.submitted_at.to_rfc3339(),
        "started_at": entry.started_at.map(|t| t.to_rfc3339()),
        "completed_at": entry.completed_at.map(|t| t.to_rfc3339()),
        "result": entry.result,
    });
    Ok(Json(body))
}
/// `GET /api/jobs` — lists every known job with its id, status and submission
/// time under a top-level `"jobs"` key.
///
/// Responds `503 Service Unavailable` when no job queue is configured.
async fn list_jobs(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let Some(queue) = state.job_queue.as_ref() else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({"error": "Job queue not enabled"})),
        ));
    };

    let mut summaries = Vec::new();
    for job in queue.list().await {
        summaries.push(serde_json::json!({
            "id": job.id,
            "status": format!("{:?}", job.status).to_lowercase(),
            "submitted_at": job.submitted_at.to_rfc3339(),
        }));
    }

    Ok(Json(serde_json::json!({ "jobs": summaries })))
}
/// `POST /api/jobs/{id}/cancel` — asks the queue to cancel a job.
///
/// Responds `503 Service Unavailable` when no job queue is configured and
/// `409 Conflict` when the queue refuses the cancellation, which per the error
/// text covers both "not in queued state" and "not found".
async fn cancel_job(
    State(state): State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let Some(queue) = state.job_queue.as_ref() else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({"error": "Job queue not enabled"})),
        ));
    };

    if !queue.cancel(&id).await {
        return Err((
            StatusCode::CONFLICT,
            Json(
                serde_json::json!({"error": "Job cannot be cancelled (not in queued state or not found)"}),
            ),
        ));
    }

    Ok(Json(serde_json::json!({"id": id, "status": "cancelled"})))
}
/// `POST /api/config/reload` — re-reads the configuration from its recorded
/// source and replaces the shared configuration with the result.
///
/// Responds `500 Internal Server Error` (with the loader's error text) when
/// loading fails; the current configuration is left untouched in that case.
async fn reload_config(
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let source = state.server_state.config_source.read().await.clone();

    let new_config = match crate::config_loader::load_config(&source).await {
        Ok(loaded) => loaded,
        Err(e) => {
            error!("Failed to reload configuration: {}", e);
            let body = serde_json::json!({
                "success": false,
                "message": format!("Failed to reload configuration: {}", e)
            });
            return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(body)));
        }
    };

    // The write lock is only taken once loading has succeeded.
    {
        let mut config = state.server_state.config.write().await;
        *config = new_config;
        info!("Configuration reloaded via REST API from {:?}", source);
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "message": "Configuration reloaded"
    })))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
#[test]
fn test_health_response_serialization() {
let response = HealthResponse {
healthy: true,
version: "0.1.0".to_string(),
uptime_seconds: 100,
};
let json = serde_json::to_string(&response).unwrap();
assert!(json.contains("healthy"));
assert!(json.contains("version"));
assert!(json.contains("uptime_seconds"));
}
#[test]
fn test_health_response_deserialization() {
let json = r#"{"healthy":true,"version":"0.1.0","uptime_seconds":100}"#;
let response: HealthResponse = serde_json::from_str(json).unwrap();
assert!(response.healthy);
assert_eq!(response.version, "0.1.0");
assert_eq!(response.uptime_seconds, 100);
}
#[test]
fn test_metrics_response_serialization() {
let response = MetricsResponse {
total_entries_generated: 1000,
total_anomalies_injected: 10,
uptime_seconds: 60,
session_entries: 1000,
session_entries_per_second: 16.67,
active_streams: 1,
total_stream_events: 500,
};
let json = serde_json::to_string(&response).unwrap();
assert!(json.contains("total_entries_generated"));
assert!(json.contains("session_entries_per_second"));
}
#[test]
fn test_metrics_response_deserialization() {
let json = r#"{
"total_entries_generated": 5000,
"total_anomalies_injected": 50,
"uptime_seconds": 300,
"session_entries": 5000,
"session_entries_per_second": 16.67,
"active_streams": 2,
"total_stream_events": 10000
}"#;
let response: MetricsResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.total_entries_generated, 5000);
assert_eq!(response.active_streams, 2);
}
#[test]
fn test_config_response_serialization() {
let response = ConfigResponse {
success: true,
message: "Configuration loaded".to_string(),
config: Some(GenerationConfigDto {
industry: "manufacturing".to_string(),
start_date: "2024-01-01".to_string(),
period_months: 12,
seed: Some(42),
coa_complexity: "medium".to_string(),
companies: vec![],
fraud_enabled: false,
fraud_rate: 0.0,
}),
};
let json = serde_json::to_string(&response).unwrap();
assert!(json.contains("success"));
assert!(json.contains("config"));
}
#[test]
fn test_config_response_without_config() {
let response = ConfigResponse {
success: false,
message: "No configuration available".to_string(),
config: None,
};
let json = serde_json::to_string(&response).unwrap();
assert!(json.contains("null") || json.contains("config\":null"));
}
/// Serializes a `GenerationConfigDto` to JSON, parses it back, and checks that
/// the decoded value matches the original field by field.
///
/// Floating-point fields (`fraud_rate`, `volume_weight`) are not compared for
/// exact equality here.
#[test]
fn test_generation_config_dto_roundtrip() {
    let original = GenerationConfigDto {
        industry: "retail".to_string(),
        start_date: "2024-06-01".to_string(),
        period_months: 6,
        seed: Some(12345),
        coa_complexity: "large".to_string(),
        companies: vec![CompanyConfigDto {
            code: "1000".to_string(),
            name: "Test Corp".to_string(),
            currency: "USD".to_string(),
            country: "US".to_string(),
            annual_transaction_volume: 100000,
            volume_weight: 1.0,
        }],
        fraud_enabled: true,
        fraud_rate: 0.05,
    };
    let json = serde_json::to_string(&original).unwrap();
    let deserialized: GenerationConfigDto = serde_json::from_str(&json).unwrap();

    // Top-level fields: previously only `industry`, `seed` and the company
    // count were checked, so a dropped or mangled field went unnoticed.
    assert_eq!(original.industry, deserialized.industry);
    assert_eq!(original.start_date, deserialized.start_date);
    assert_eq!(original.period_months, deserialized.period_months);
    assert_eq!(original.seed, deserialized.seed);
    assert_eq!(original.coa_complexity, deserialized.coa_complexity);
    assert_eq!(original.fraud_enabled, deserialized.fraud_enabled);

    // Nested companies: the lengths must match before zipping, otherwise
    // `zip` would silently ignore any extra entries.
    assert_eq!(original.companies.len(), deserialized.companies.len());
    for (before, after) in original.companies.iter().zip(deserialized.companies.iter()) {
        assert_eq!(before.code, after.code);
        assert_eq!(before.name, after.name);
        assert_eq!(before.currency, after.currency);
        assert_eq!(before.country, after.country);
        assert_eq!(
            before.annual_transaction_volume,
            after.annual_transaction_volume
        );
    }
}
/// Serializes a single `CompanyConfigDto` and checks that its code, currency
/// and country values appear in the JSON text.
#[test]
fn test_company_config_dto_serialization() {
    let subsidiary = CompanyConfigDto {
        code: String::from("2000"),
        name: String::from("European Subsidiary"),
        currency: String::from("EUR"),
        country: String::from("DE"),
        annual_transaction_volume: 50000,
        volume_weight: 0.5,
    };

    let encoded = serde_json::to_string(&subsidiary).unwrap();

    for expected_fragment in ["2000", "EUR", "DE"] {
        assert!(encoded.contains(expected_fragment));
    }
}
/// Parses a `BulkGenerateRequest` whose three fields are all present and
/// checks that each one decodes to `Some(..)` with the supplied value.
#[test]
fn test_bulk_generate_request_deserialization() {
    // The raw literal is kept exactly as written; only the surrounding code
    // is formatted.
    let payload = r#"{
"entry_count": 5000,
"include_master_data": true,
"inject_anomalies": true
}"#;
    let parsed: BulkGenerateRequest = serde_json::from_str(payload).unwrap();

    assert_eq!(parsed.inject_anomalies, Some(true));
    assert_eq!(parsed.include_master_data, Some(true));
    assert_eq!(parsed.entry_count, Some(5000));
}
/// Parses an empty JSON object into `BulkGenerateRequest` and checks that
/// every field decodes to `None`.
#[test]
fn test_bulk_generate_request_with_defaults() {
    let parsed: BulkGenerateRequest = serde_json::from_str("{}").unwrap();

    assert!(parsed.entry_count.is_none());
    assert!(parsed.include_master_data.is_none());
    assert!(parsed.inject_anomalies.is_none());
}
/// Serializes a `BulkGenerateResponse` and checks that two field names and the
/// generated-entry count appear in the JSON text.
#[test]
fn test_bulk_generate_response_serialization() {
    let outcome = BulkGenerateResponse {
        success: true,
        entries_generated: 1000,
        duration_ms: 250,
        anomaly_count: 20,
    };

    let encoded = serde_json::to_string(&outcome).unwrap();

    for expected_fragment in ["entries_generated", "1000", "duration_ms"] {
        assert!(encoded.contains(expected_fragment));
    }
}
/// Serializes a successful `StreamResponse` and checks that the `success` key
/// and the start of the message appear in the JSON text.
#[test]
fn test_stream_response_serialization() {
    let reply = StreamResponse {
        success: true,
        message: String::from("Stream started successfully"),
    };

    let encoded = serde_json::to_string(&reply).unwrap();

    for expected_fragment in ["success", "Stream started"] {
        assert!(encoded.contains(expected_fragment));
    }
}
/// Serializes a failed `StreamResponse` and checks that the text `false` and
/// the word `failed` appear in the JSON text.
#[test]
fn test_stream_response_failure() {
    let reply = StreamResponse {
        success: false,
        message: String::from("Stream failed to start"),
    };

    let encoded = serde_json::to_string(&reply).unwrap();

    for expected_fragment in ["false", "failed"] {
        assert!(encoded.contains(expected_fragment));
    }
}
/// Checks the default `CorsConfig`: wildcard origins are off, the allow-list
/// is non-empty, and it includes the two origins asserted below.
#[test]
fn test_cors_config_default() {
    let defaults = CorsConfig::default();
    let origins = &defaults.allowed_origins;

    assert!(!defaults.allow_any_origin);
    assert!(!origins.is_empty());

    for expected_origin in ["http://localhost:5173", "tauri://localhost"] {
        assert!(origins.iter().any(|origin| origin == expected_origin));
    }
}
/// Builds a `CorsConfig` with two explicit origins and checks the list length
/// and the presence of the first origin.
#[test]
fn test_cors_config_custom_origins() {
    let origins = vec![
        String::from("https://example.com"),
        String::from("https://app.example.com"),
    ];
    let custom = CorsConfig {
        allowed_origins: origins,
        allow_any_origin: false,
    };

    assert_eq!(custom.allowed_origins.len(), 2);
    assert!(custom
        .allowed_origins
        .iter()
        .any(|origin| origin == "https://example.com"));
}
/// Builds a `CorsConfig` with an empty allow-list and the wildcard flag set,
/// then checks that the flag reads back as `true`.
#[test]
fn test_cors_config_permissive() {
    let permissive = CorsConfig {
        allow_any_origin: true,
        allowed_origins: Vec::new(),
    };

    assert!(permissive.allow_any_origin);
}
/// Parses a `BulkGenerateRequest` that supplies only `entry_count` and checks
/// that `include_master_data` decodes to `None`.
#[test]
fn test_bulk_generate_request_partial() {
    let payload = r#"{"entry_count": 100}"#;
    let parsed: BulkGenerateRequest = serde_json::from_str(payload).unwrap();

    assert_eq!(parsed.include_master_data, None);
    assert_eq!(parsed.entry_count, Some(100));
}
/// Serializes a `GenerationConfigDto` whose `seed` is `None` and checks that
/// the text `seed` still appears in the JSON output.
#[test]
fn test_generation_config_no_seed() {
    let unseeded = GenerationConfigDto {
        seed: None,
        industry: String::from("technology"),
        start_date: String::from("2024-01-01"),
        period_months: 3,
        coa_complexity: String::from("small"),
        companies: Vec::new(),
        fraud_enabled: false,
        fraud_rate: 0.0,
    };

    let encoded = serde_json::to_string(&unseeded).unwrap();

    assert!(encoded.contains("seed"));
}
/// Builds a `GenerationConfigDto` holding three companies and checks that the
/// company list has length three.
#[test]
fn test_generation_config_multiple_companies() {
    // Name each company separately so the config literal stays compact.
    let headquarters = CompanyConfigDto {
        code: String::from("1000"),
        name: String::from("Headquarters"),
        currency: String::from("USD"),
        country: String::from("US"),
        annual_transaction_volume: 100000,
        volume_weight: 1.0,
    };
    let europe = CompanyConfigDto {
        code: String::from("2000"),
        name: String::from("European Sub"),
        currency: String::from("EUR"),
        country: String::from("DE"),
        annual_transaction_volume: 50000,
        volume_weight: 0.5,
    };
    let apac = CompanyConfigDto {
        code: String::from("3000"),
        name: String::from("APAC Sub"),
        currency: String::from("JPY"),
        country: String::from("JP"),
        annual_transaction_volume: 30000,
        volume_weight: 0.3,
    };

    let group = GenerationConfigDto {
        industry: String::from("manufacturing"),
        start_date: String::from("2024-01-01"),
        period_months: 12,
        seed: Some(42),
        coa_complexity: String::from("large"),
        companies: vec![headquarters, europe, apac],
        fraud_enabled: true,
        fraud_rate: 0.02,
    };

    assert_eq!(group.companies.len(), 3);
}
/// Checks the entries-per-second arithmetic on locally defined values:
/// 1000 entries over 60 seconds should come out within 0.1 of 16.67.
#[test]
fn test_metrics_entries_per_second_calculation() {
    let (entries, uptime_secs): (u64, u64) = (1000, 60);

    // A zero uptime maps to a rate of 0.0 instead of dividing by zero.
    let rate = match uptime_secs {
        0 => 0.0,
        secs => entries as f64 / secs as f64,
    };

    let drift = (rate - 16.67).abs();
    assert!(drift < 0.1);
}
/// Checks the entries-per-second arithmetic on locally defined values when
/// uptime is zero: the rate must be exactly 0.0.
#[test]
fn test_metrics_entries_per_second_zero_uptime() {
    let (entries, uptime_secs): (u64, u64) = (1000, 0);

    // A zero uptime maps to a rate of 0.0 instead of dividing by zero.
    let rate = match uptime_secs {
        0 => 0.0,
        secs => entries as f64 / secs as f64,
    };

    assert_eq!(rate, 0.0);
}
}