use anyhow::Result;
use axum::extract::DefaultBodyLimit;
use axum::Router;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::signal;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use mockforge_registry_server::circuit_breaker::{self, CircuitBreaker, CircuitBreakerRegistry};
use mockforge_registry_server::config::Config;
use mockforge_registry_server::database::Database;
use mockforge_registry_server::middleware::csrf::csrf_middleware;
use mockforge_registry_server::middleware::rate_limit::RateLimiterState;
use mockforge_registry_server::middleware::request_id::request_id_middleware;
use mockforge_registry_server::redis::RedisPool;
use mockforge_registry_server::storage::PluginStorage;
use mockforge_registry_server::store::PgRegistryStore;
use mockforge_registry_server::{deployment, pillar_tracking_init, routes, workers, AppState};
use axum::response::IntoResponse;
use mockforge_observability::get_global_registry;
use std::sync::Arc;
#[tokio::main]
async fn main() -> Result<()> {
    // Process entry point: wires up logging, config, storage backends,
    // background workers, and the HTTP server, then serves until a
    // shutdown signal arrives.
    //
    // Logging: honor RUST_LOG (via EnvFilter) when set; otherwise default to
    // info for this crate and debug-level request traces from tower_http.
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "mockforge_registry_server=info,tower_http=debug".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Core configuration and the primary database connection. Failures here
    // are fatal (propagated via `?`).
    let config = Config::load()?;
    tracing::info!("Configuration loaded");
    let db = Database::connect(&config.database_url).await?;
    tracing::info!("Database connected");

    // Migrations run at startup unless explicitly disabled via config.
    if config.skip_migrations {
        tracing::info!("Skipping database migrations (SKIP_MIGRATIONS=true)");
    } else {
        db.migrate().await?;
        tracing::info!("Database migrations complete");
    }

    // Plugin artifact storage (backend chosen from config; fatal on failure).
    let storage = PluginStorage::new(&config).await?;
    tracing::info!("Storage initialized");

    // Shared Prometheus-style metrics registry, also served by /metrics below.
    let metrics = Arc::new(get_global_registry().clone());

    // Optional analytics database. Both arms are best-effort: open/migration
    // failures log and yield None rather than aborting startup.
    // NOTE(review): the two arms duplicate the open-then-migrate logic and
    // differ only in the path and the log level of the open-failure case
    // (warn vs debug) — candidate for extraction into a helper.
    let analytics_db = if let Some(analytics_db_path) = &config.analytics_db_path {
        // Explicitly configured path: failures are warned about.
        match mockforge_analytics::AnalyticsDatabase::new(std::path::Path::new(analytics_db_path))
            .await
        {
            Ok(analytics_db) => {
                if let Err(e) = analytics_db.run_migrations().await {
                    tracing::warn!("Failed to run analytics database migrations: {}", e);
                    None
                } else {
                    tracing::info!("Analytics database initialized at: {}", analytics_db_path);
                    Some(analytics_db)
                }
            }
            Err(e) => {
                tracing::warn!("Failed to initialize analytics database: {}", e);
                None
            }
        }
    } else {
        // No path configured: try a default file in the working directory.
        let default_path = std::path::Path::new("mockforge-analytics.db");
        match mockforge_analytics::AnalyticsDatabase::new(default_path).await {
            Ok(analytics_db) => {
                if let Err(e) = analytics_db.run_migrations().await {
                    tracing::warn!("Failed to run analytics database migrations: {}", e);
                    None
                } else {
                    tracing::info!(
                        "Analytics database initialized at default path: mockforge-analytics.db"
                    );
                    Some(analytics_db)
                }
            }
            Err(e) => {
                // Open failure on the implicit default path is expected in
                // some deployments, so it is only logged at debug level.
                tracing::debug!("Analytics database not available (optional): {}", e);
                None
            }
        }
    };

    // Pillar tracking is initialized only when analytics is available.
    if let Some(ref analytics_db) = analytics_db {
        let db_arc = std::sync::Arc::new(analytics_db.clone());
        pillar_tracking_init::init_pillar_tracking(Some(db_arc)).await;
    }

    // Optional Redis pool. A connection failure degrades functionality (per
    // the log, 2FA setup needs an alternative flow) but is not fatal.
    let redis = if let Some(redis_url) = &config.redis_url {
        match RedisPool::connect(redis_url).await {
            Ok(pool) => {
                tracing::info!("Redis connected");
                Some(pool)
            }
            Err(e) => {
                tracing::warn!(
                    "Failed to connect to Redis (2FA setup will require alternative flow): {}",
                    e
                );
                None
            }
        }
    } else {
        tracing::info!("Redis not configured (REDIS_URL not set)");
        None
    };

    // Per-minute rate limiter, injected into the router as an Extension.
    let rate_limiter = RateLimiterState::new(config.rate_limit_per_minute);
    tracing::info!("Rate limiter initialized: {} requests/minute", config.rate_limit_per_minute);

    // Circuit breakers for each external dependency, using per-service presets.
    let circuit_breakers = CircuitBreakerRegistry::new();
    circuit_breakers
        .register("redis", CircuitBreaker::new(circuit_breaker::presets::redis()))
        .await;
    circuit_breakers
        .register("s3", CircuitBreaker::new(circuit_breaker::presets::s3()))
        .await;
    circuit_breakers
        .register("email", CircuitBreaker::new(circuit_breaker::presets::email()))
        .await;
    circuit_breakers
        .register("database", CircuitBreaker::new(circuit_breaker::presets::database()))
        .await;
    tracing::info!("Circuit breakers initialized for external services");

    // Registry store abstraction backed by Postgres.
    let store: Arc<dyn mockforge_registry_server::store::RegistryStore> =
        Arc::new(PgRegistryStore::new(db.pool().clone()));

    // Shared application state handed to every route handler.
    let state = AppState {
        db: db.clone(),
        storage,
        config: config.clone(),
        metrics: metrics.clone(),
        analytics_db,
        redis,
        circuit_breakers,
        store,
    };

    // Background workers. Each `start_*` call spawns its own task; handles
    // are not retained here, so the workers run detached for the life of the
    // process.
    workers::saml_cleanup::start_saml_cleanup_worker(db.pool().clone());
    workers::plugin_scanner::start_plugin_scanner_worker(state.clone());
    workers::osv_sync::start_osv_sync_worker(state.clone());
    workers::runtime_logs_retention::start_runtime_logs_retention_worker(db.pool().clone());
    workers::runtime_observability_retention::start_runtime_observability_retention_worker(
        db.pool().clone(),
    );
    workers::usage_threshold_checker::start_usage_threshold_checker(state.clone());
    workers::token_rotation_reminders::start_token_rotation_reminders_worker(db.pool().clone());
    workers::test_schedule_runner::start_test_schedule_worker(
        db.pool().clone(),
        state.redis.clone(),
    );
    workers::incident_dispatcher::start_incident_dispatcher_worker(db.pool().clone());
    workers::snapshot_retention::start_snapshot_retention_worker(
        db.pool().clone(),
        state.storage.clone(),
    );

    // Deployment subsystem: orchestrator, health checks, metrics collection,
    // and cleanup. Fly.io credentials are optional env vars; the handles are
    // intentionally unused (tasks run detached).
    let flyio_token = std::env::var("FLYIO_API_TOKEN").ok();
    let flyio_org_slug = std::env::var("FLYIO_ORG_SLUG").ok();
    let orchestrator = Arc::new(deployment::orchestrator::DeploymentOrchestrator::new(
        Arc::new(db.pool().clone()),
        flyio_token,
        flyio_org_slug,
    ));
    let _orchestrator_handle = orchestrator.start();
    tracing::info!("Deployment orchestrator started");
    let health_checker =
        Arc::new(deployment::health_check::HealthCheckWorker::new(Arc::new(db.pool().clone())));
    let _health_check_handle = health_checker.start();
    tracing::info!("Health check worker started");
    let metrics_collector =
        Arc::new(deployment::metrics::MetricsCollector::new(Arc::new(db.pool().clone())));
    let _metrics_collector_handle = metrics_collector.start();
    tracing::info!("Metrics collector started");
    let cleanup_flyio_client =
        std::env::var("FLYIO_API_TOKEN").ok().map(deployment::flyio::FlyioClient::new);
    let cleanup_worker = Arc::new(deployment::cleanup::DeploymentCleanup::new(
        Arc::new(db.pool().clone()),
        cleanup_flyio_client,
    ));
    let _cleanup_handle = cleanup_worker.start();
    tracing::info!("Deployment cleanup worker started");

    // Build the router and serve until shutdown_signal resolves.
    // NOTE(review): shutdown_timeout is logged and passed to shutdown_signal,
    // but nothing visible here enforces it as a hard deadline on in-flight
    // requests — confirm whether a tokio::time::timeout wrapper was intended.
    let app = create_app(state, rate_limiter);
    let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
    let shutdown_timeout = Duration::from_secs(config.shutdown_timeout_secs);
    tracing::info!("Starting server on {}", addr);
    tracing::info!("Graceful shutdown timeout: {} seconds", config.shutdown_timeout_secs);
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal(shutdown_timeout))
        .await?;
    tracing::info!("Server shutdown complete");
    Ok(())
}
/// Resolves when a shutdown signal (Ctrl+C, or SIGTERM on Unix) is received.
///
/// `timeout` is only used for the final log line here; the graceful-shutdown
/// window itself is driven by the caller via `with_graceful_shutdown`.
///
/// If a handler cannot be installed, we log the error and park that branch
/// forever so the other signal source can still trigger shutdown.
async fn shutdown_signal(timeout: Duration) {
    // Ctrl+C / SIGINT branch.
    let interrupt = async {
        if let Err(e) = signal::ctrl_c().await {
            tracing::error!("Failed to install Ctrl+C handler: {}", e);
            std::future::pending::<()>().await;
        }
    };

    // SIGTERM branch (Unix only); on other platforms this never resolves.
    #[cfg(unix)]
    let sigterm = async {
        match signal::unix::signal(signal::unix::SignalKind::terminate()) {
            Ok(mut stream) => {
                stream.recv().await;
            }
            Err(e) => {
                tracing::error!("Failed to install SIGTERM handler: {}", e);
                std::future::pending::<()>().await;
            }
        }
    };
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    // Race the two sources; whichever fires first wins.
    tokio::select! {
        _ = interrupt => tracing::info!("Received Ctrl+C, initiating graceful shutdown"),
        _ = sigterm => tracing::info!("Received SIGTERM, initiating graceful shutdown"),
    }

    tracing::info!(
        "Stopping new connections, waiting up to {} seconds for active requests to complete",
        timeout.as_secs()
    );
}
/// Assembles the full application router: API routes, multitenant deployment
/// routes, and the metrics endpoints, wrapped in CORS, tracing, request-id,
/// CSRF, body-limit, and rate-limit middleware.
fn create_app(state: AppState, rate_limiter: RateLimiterState) -> Router {
    // CORS: a comma-separated allow-list from CORS_ALLOWED_ORIGINS, or a
    // locked-down policy (only the literal "null" origin) when unset/empty.
    let cors = match std::env::var("CORS_ALLOWED_ORIGINS").ok().filter(|v| !v.is_empty()) {
        Some(origins) => {
            // Unparseable entries are silently dropped by filter_map.
            let parsed: Vec<_> =
                origins.split(',').filter_map(|origin| origin.trim().parse().ok()).collect();
            tracing::info!("CORS configured with {} allowed origins", parsed.len());
            CorsLayer::new()
                .allow_origin(AllowOrigin::list(parsed))
                .allow_methods(Any)
                .allow_headers(Any)
        }
        None => {
            tracing::info!(
                "CORS configured with strict same-origin policy (no CORS_ALLOWED_ORIGINS set)"
            );
            CorsLayer::new()
                .allow_origin(AllowOrigin::exact(
                    "null".parse().expect("'null' is a valid header value"),
                ))
                .allow_methods(Any)
                .allow_headers(Any)
        }
    };

    // Request body cap, overridable via MAX_REQUEST_BODY_SIZE (bytes).
    const DEFAULT_BODY_LIMIT: usize = 10 * 1024 * 1024; // 10 MiB
    let max_body_size = std::env::var("MAX_REQUEST_BODY_SIZE")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .unwrap_or(DEFAULT_BODY_LIMIT);
    tracing::info!("Request body size limit: {} bytes", max_body_size);

    // Prometheus scrape endpoint plus a trivial liveness probe.
    let metrics_routes = Router::new()
        .route("/metrics", axum::routing::get(metrics_handler))
        .route("/metrics/health", axum::routing::get(|| async { "OK" }));

    // Layer order matters: layers added later wrap (run before) earlier ones.
    Router::new()
        .merge(routes::create_router(state.clone()))
        .merge(deployment::router::MultitenantRouter::create_router())
        .merge(metrics_routes)
        .fallback(deployment::router::custom_domain_fallback)
        .layer(DefaultBodyLimit::max(max_body_size))
        .layer(cors)
        .layer(TraceLayer::new_for_http())
        .layer(axum::middleware::from_fn(request_id_middleware))
        .layer(axum::middleware::from_fn(csrf_middleware))
        .layer(axum::Extension(rate_limiter))
        .with_state(state)
}
/// Serves the global metrics registry in the Prometheus text exposition
/// format (`text/plain; version=0.0.4`). Encoding or UTF-8 conversion
/// failures return a 500 with a short plain-text message.
async fn metrics_handler() -> impl axum::response::IntoResponse {
    use mockforge_observability::get_global_registry;
    use prometheus::{Encoder, TextEncoder};

    let families = get_global_registry().registry().gather();
    let mut encoded = Vec::new();
    if let Err(e) = TextEncoder::new().encode(&families, &mut encoded) {
        tracing::error!("Failed to encode metrics: {}", e);
        return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "Failed to encode metrics")
            .into_response();
    }

    match String::from_utf8(encoded) {
        Ok(body) => (
            axum::http::StatusCode::OK,
            [("content-type", "text/plain; version=0.0.4; charset=utf-8")],
            body,
        )
            .into_response(),
        Err(e) => {
            tracing::error!("Failed to convert metrics to UTF-8: {}", e);
            (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "Failed to convert metrics")
                .into_response()
        }
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): placeholder test — the body is empty, so it exercises
    // nothing and always passes. TODO: build the app via create_app with a
    // test AppState and assert a 200 from /metrics/health, or remove this
    // stub so coverage numbers aren't inflated.
    #[tokio::test]
    async fn test_health_check() {
    }
}