use axum::{middleware, Router};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::signal;
use tower::Service;
use tower_http::compression::CompressionLayer;
use crate::auth::JwtAuth;
use crate::error::ApiError;
use crate::middleware::{
auth_middleware, body_limit_layer, cors_layer, rate_limit_middleware, request_id_middleware,
timeout_layer, tracing_middleware,
};
use crate::routes::api_router;
use vex_llm::{Metrics, RateLimitConfig};
/// Filesystem locations of the PEM-encoded TLS certificate and private key.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Path to the PEM certificate (chain) file.
    pub cert_path: String,
    /// Path to the PEM private-key file.
    pub key_path: String,
}
impl TlsConfig {
    /// Builds a config from borrowed paths, taking owned copies.
    pub fn new(cert_path: &str, key_path: &str) -> Self {
        Self {
            cert_path: cert_path.to_string(),
            key_path: key_path.to_string(),
        }
    }

    /// Reads `VEX_TLS_CERT` / `VEX_TLS_KEY` from the environment.
    ///
    /// Returns `None` unless BOTH variables are set.
    pub fn from_env() -> Option<Self> {
        let cert_path = std::env::var("VEX_TLS_CERT").ok()?;
        let key_path = std::env::var("VEX_TLS_KEY").ok()?;
        // `env::var` already yields owned Strings; move them directly instead
        // of round-tripping through `new`, which would re-allocate both.
        Some(Self {
            cert_path,
            key_path,
        })
    }
}
/// Runtime configuration for the VEX HTTP(S) server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Socket address the listener binds to.
    pub addr: SocketAddr,
    /// Per-request timeout applied via the timeout middleware layer.
    pub timeout: Duration,
    /// Maximum accepted request-body size in bytes.
    pub max_body_size: usize,
    /// Whether response compression should be enabled.
    /// NOTE(review): `VexServer::router()` applies `CompressionLayer`
    /// unconditionally — this flag is not consulted there; confirm intent.
    pub compression: bool,
    /// Rate-limit settings.
    /// NOTE(review): the rate-limit middleware is wired from `AppState`, not
    /// from this field — verify this config is actually propagated.
    pub rate_limit: RateLimitConfig,
    /// Optional TLS material; when `Some`, the server runs HTTPS.
    pub tls: Option<TlsConfig>,
    /// When true and no TLS config is present, startup fails instead of
    /// falling back to plain HTTP.
    pub enforce_https: bool,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
addr: "0.0.0.0:8080".parse().unwrap(),
timeout: Duration::from_secs(30),
max_body_size: 1024 * 1024, compression: true,
rate_limit: RateLimitConfig::default(),
tls: None,
enforce_https: false,
}
}
}
impl ServerConfig {
    /// Builds a config from environment variables, falling back to defaults
    /// for anything unset or unparseable.
    ///
    /// Recognized variables: `VEX_PORT` (u16, default 8080),
    /// `VEX_TIMEOUT_SECS` (u64, default 30), `VEX_ENFORCE_HTTPS` (presence
    /// alone enables enforcement), and `VEX_ENV` (the exact value
    /// `production` also enables enforcement).
    pub fn from_env() -> Self {
        // Parse an env var, silently falling back when missing or malformed.
        fn parsed_env<T: std::str::FromStr>(name: &str, fallback: T) -> T {
            std::env::var(name)
                .ok()
                .and_then(|raw| raw.parse().ok())
                .unwrap_or(fallback)
        }
        let port = parsed_env::<u16>("VEX_PORT", 8080);
        let timeout_secs = parsed_env::<u64>("VEX_TIMEOUT_SECS", 30);
        // HTTPS enforcement: either the explicit flag is present (any value)
        // or we are running in the production environment.
        let flag_present = std::env::var("VEX_ENFORCE_HTTPS").is_ok();
        let is_production = matches!(std::env::var("VEX_ENV").as_deref(), Ok("production"));
        Self {
            addr: SocketAddr::from(([0, 0, 0, 0], port)),
            timeout: Duration::from_secs(timeout_secs),
            enforce_https: flag_present || is_production,
            ..Default::default()
        }
    }
}
use crate::state::AppState;
/// The VEX API server: owns the middleware stack, shared state, and the
/// HTTP(S) listener lifecycle.
pub struct VexServer {
    // Network/runtime settings this server was constructed with.
    config: ServerConfig,
    // Shared application state handed to routers and middleware.
    app_state: AppState,
}
impl VexServer {
    /// Assembles the full application stack and returns a ready-to-run server.
    ///
    /// Wiring performed here, in order: JWT auth from env, tenant rate
    /// limiter, metrics, storage backend (Postgres or SQLite chosen from
    /// `DATABASE_URL`), job worker pool, LLM provider + routing layer,
    /// hardware-backed identity, ZK prover/verifier, Chora gate (HTTP client
    /// or mock), orchestrator, the `agent_execution` job factory, and the
    /// shared `AppState`.
    ///
    /// # Errors
    /// Returns `ApiError` when JWT configuration is missing/invalid, when
    /// hardware keystore/identity initialization fails, or when a Postgres
    /// URL is supplied but the `postgres` feature was not compiled in.
    /// NOTE(review): database connect/migrate failures use `expect()` and
    /// panic instead of returning `Err` — confirm that is intended at boot.
    pub async fn new(config: ServerConfig) -> Result<Self, ApiError> {
        use crate::jobs::agent::{AgentExecutionJob, AgentJobPayload};
        use crate::tenant_rate_limiter::{RateLimitTier, TenantRateLimiter};
        use vex_llm::{
            CachedProvider, DeepSeekProvider, LlmProvider, MockProvider, ResilientProvider,
        };
        use vex_queue::{QueueBackend, WorkerConfig, WorkerPool};
        let jwt_auth = JwtAuth::from_env()?;
        let rate_limiter = Arc::new(TenantRateLimiter::new(RateLimitTier::Standard));
        let metrics = Arc::new(Metrics::new());
        // Default to a local SQLite file when DATABASE_URL is unset.
        let db_url =
            std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:vex.db?mode=rwc".to_string());
        let is_postgres = db_url.starts_with("postgres://") || db_url.starts_with("postgresql://");
        // Select storage, evolution store, and queue backend as trait objects
        // so the remaining wiring is backend-agnostic.
        let (db, evolution_store, queue_backend): (
            Arc<dyn vex_persist::StorageBackend>,
            Arc<dyn vex_persist::EvolutionStore>,
            Arc<dyn QueueBackend>,
        ) = if is_postgres {
            #[cfg(feature = "postgres")]
            {
                tracing::info!("DATABASE_URL: PostgreSQL backend selected (Railway Managed DB)");
                let pg_backend = vex_persist::PostgresBackend::new(&db_url)
                    .await
                    .expect("Failed to initialize Postgres backend");
                pg_backend
                    .migrate()
                    .await
                    .expect("Failed to migrate Postgres backend");
                // All three stores share the same connection pool.
                let pg_pool = pg_backend.pool().clone();
                (
                    Arc::new(pg_backend),
                    Arc::new(vex_persist::PostgresEvolutionStore::new(pg_pool.clone())),
                    Arc::new(vex_persist::PostgresQueueBackend::new(pg_pool))
                        as Arc<dyn QueueBackend>,
                )
            }
            #[cfg(not(feature = "postgres"))]
            {
                // Fail loudly instead of silently falling back to SQLite.
                return Err(ApiError::Internal(
                    "Postgres DATABASE_URL detected but vex-persist was compiled without the 'postgres' feature. \
                    Rebuild with: cargo build --features vex-persist/postgres".to_string(),
                ));
            }
        } else {
            tracing::info!("DATABASE_URL: SQLite backend selected");
            let sqlite_backend = vex_persist::sqlite::SqliteBackend::new(&db_url)
                .await
                .expect("Failed to initialize SQLite backend");
            sqlite_backend
                .migrate()
                .await
                .expect("Failed to migrate SQLite backend");
            let sqlite_pool = sqlite_backend.pool().clone();
            (
                Arc::new(sqlite_backend),
                Arc::new(vex_persist::SqliteEvolutionStore::new(sqlite_pool.clone())),
                Arc::new(vex_persist::queue::SqliteQueueBackend::new(sqlite_pool))
                    as Arc<dyn QueueBackend>,
            )
        };
        let worker_pool = WorkerPool::new_with_arc(queue_backend, WorkerConfig::default());
        // Direct LLM provider stack (circuit breaker + cache around DeepSeek,
        // or a mock when no API key is present).
        // NOTE(review): this value is bound to `_base_llm` and never used
        // below — the vex_router::Router built next is what actually serves
        // requests. Confirm whether this stack should feed into the router.
        let _base_llm: Arc<dyn LlmProvider> = if let Ok(key) = std::env::var("DEEPSEEK_API_KEY") {
            tracing::info!("Initializing Resilient+Cached DeepSeek Provider");
            let base = DeepSeekProvider::chat(&key);
            let resilient = ResilientProvider::new(base, vex_llm::LlmCircuitConfig::conservative());
            let cached = CachedProvider::wrap(resilient);
            Arc::new(cached)
        } else {
            tracing::warn!("DEEPSEEK_API_KEY not found. Using Mock Provider.");
            Arc::new(MockProvider::smart())
        };
        // The model router doubles as the `LlmProvider` handed to the app.
        let router = vex_router::Router::builder()
            .strategy(vex_router::RoutingStrategy::Auto)
            .build();
        let router_arc = Arc::new(router);
        let llm: Arc<dyn LlmProvider> = router_arc.clone();
        let result_store = crate::jobs::new_result_store();
        // Hardware-backed identity used (with the audit store) to sign
        // orchestrator audit records.
        let hardware_keystore = vex_hardware::api::HardwareKeystore::new()
            .await
            .map_err(|e| ApiError::Internal(format!("Hardware init failed: {}", e)))?;
        let identity = Arc::new(
            hardware_keystore
                .get_identity(&[])
                .await
                .map_err(|e| ApiError::Internal(format!("Hardware identity failed: {}", e)))?,
        );
        let verifier: Arc<dyn vex_core::zk::ZkVerifier> = Arc::new(attest_rs::zk::AuditProver);
        let prover = Arc::new(attest_rs::zk::AuditProver);
        // Chora gate: real HTTP gate when CHORA_GATE_URL is set, mock otherwise.
        let gate_url = std::env::var("CHORA_GATE_URL").ok();
        let api_key = std::env::var("CHORA_API_KEY").unwrap_or_default();
        let (gate, bridge): (Arc<dyn vex_runtime::Gate>, Arc<vex_chora::AuthorityBridge>) =
            if let Some(url) = gate_url {
                let client = Arc::new(vex_chora::client::HttpChoraClient::new(url, api_key));
                let http_gate = vex_runtime::HttpGate::new(client).with_prover(prover);
                let bridge = http_gate.inner.bridge.clone();
                (Arc::new(http_gate), bridge)
            } else {
                let mock_gate = Arc::new(vex_runtime::GenericGateMock);
                let bridge = Arc::new(vex_chora::AuthorityBridge::new(Arc::new(
                    vex_chora::client::MockChoraClient,
                )));
                (mock_gate, bridge)
            };
        let audit_store = Arc::new(vex_persist::AuditStore::new(db.clone()));
        let base_orchestrator = vex_runtime::Orchestrator::new(
            llm.clone(),
            vex_runtime::OrchestratorConfig::default(),
            Some(evolution_store.clone()),
            gate.clone(),
        );
        let orchestrator = Arc::new(
            base_orchestrator
                .with_identity(identity.clone(), audit_store.clone())
                .with_verifier(verifier.clone()),
        );
        // Clones captured by the job-factory closure registered below.
        let llm_clone = llm.clone();
        let result_store_clone = result_store.clone();
        let db_for_factory = db.clone();
        let evolution_store_clone = evolution_store.clone();
        let gate_clone = gate.clone();
        let orchestrator_clone = orchestrator.clone();
        worker_pool.register_job_factory("agent_execution", move |payload| {
            // NOTE(review): a payload that fails to deserialize degrades to a
            // placeholder job (agent_id "unknown") instead of rejecting the
            // job — confirm this best-effort behavior is intended.
            let job_payload: AgentJobPayload =
                serde_json::from_value(payload).unwrap_or_else(|_| AgentJobPayload {
                    agent_id: "unknown".to_string(),
                    prompt: "payload error".to_string(),
                    context_id: None,
                    enable_adversarial: false,
                    enable_self_correction: false,
                    max_debate_rounds: 3,
                    tenant_id: None,
                    capabilities: vec![],
                });
            // Fresh id per execution; results are keyed by it in the store.
            let job_id = uuid::Uuid::new_v4();
            let db_concrete = db_for_factory.clone();
            let evo_store = evolution_store_clone.clone();
            Box::new(AgentExecutionJob::new(
                job_id,
                job_payload,
                llm_clone.clone(),
                result_store_clone.clone(),
                db_concrete as Arc<dyn vex_persist::StorageBackend>,
                None, // presumably an optional tenant/context slot — TODO confirm against AgentExecutionJob::new
                evo_store,
                gate_clone.clone(),
                orchestrator_clone.clone(),
            ))
        });
        let a2a_state = Arc::new(crate::a2a::handler::A2aState::default());
        let app_state = AppState::new(
            jwt_auth,
            rate_limiter,
            metrics,
            db as Arc<dyn vex_persist::StorageBackend>,
            evolution_store,
            Arc::new(worker_pool),
            a2a_state,
            llm.clone(),
            Some(router_arc),
            gate.clone(),
            orchestrator.clone(),
            bridge,
        );
        Ok(Self { config, app_state })
    }

    /// Builds the axum application with the full middleware stack.
    ///
    /// Layer ordering: each `.layer()` call wraps everything added before it,
    /// so the LAST layer added is the OUTERMOST. Requests therefore pass
    /// through auth → rate-limit → tracing → request-id → CORS → timeout →
    /// body-limit → compression before reaching a handler.
    /// NOTE(review): `CompressionLayer` is applied unconditionally —
    /// `self.config.compression` is never consulted here; confirm intent.
    pub fn router(&self) -> Router {
        let mut app = api_router(self.app_state.clone());
        app = app
            .layer(CompressionLayer::new())
            .layer(body_limit_layer(self.config.max_body_size))
            .layer(timeout_layer(self.config.timeout))
            .layer(cors_layer())
            .layer(middleware::from_fn(request_id_middleware))
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                tracing_middleware,
            ))
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                rate_limit_middleware,
            ))
            .layer(middleware::from_fn_with_state(
                self.app_state.clone(),
                auth_middleware,
            ));
        app
    }

    /// Starts the background queue workers, then serves HTTP or HTTPS
    /// depending on whether TLS material was configured.
    ///
    /// # Errors
    /// Fails when TLS files cannot be read or parsed, when binding the
    /// listener fails, when the TLS accept loop hits an accept error, or when
    /// `enforce_https` is set without TLS material.
    ///
    /// NOTE(review): the TLS branch advertises `h2` via ALPN but serves every
    /// connection with the HTTP/1 builder — a client that negotiates h2 will
    /// fail; confirm which is intended. The TLS accept loop also has no
    /// graceful-shutdown path (only the plain-HTTP branch awaits
    /// `shutdown_signal`), and a single `accept()` error aborts the server.
    pub async fn run(self) -> Result<(), ApiError> {
        let app = self.router();
        let addr = self.config.addr;
        // Queue workers run on a detached task for the process lifetime.
        let queue = self.app_state.queue();
        tokio::spawn(async move {
            queue.start().await;
        });
        if let Some(tls_config) = &self.config.tls {
            tracing::info!("🔒 Starting VEX API server with HTTPS on {}", addr);
            use rustls_pki_types::pem::PemObject;
            use rustls_pki_types::{CertificateDer, PrivateKeyDer};
            use std::io::Read;
            // NOTE: this import shadows the module-level `ServerConfig`
            // within the remainder of this scope.
            use tokio_rustls::rustls::ServerConfig;
            // Read both PEM files fully into memory before parsing.
            let mut cert_file = std::fs::File::open(&tls_config.cert_path)
                .map_err(|e| ApiError::Internal(format!("Failed to open cert file: {}", e)))?;
            let mut key_file = std::fs::File::open(&tls_config.key_path)
                .map_err(|e| ApiError::Internal(format!("Failed to open key file: {}", e)))?;
            let mut cert_pem = Vec::new();
            cert_file
                .read_to_end(&mut cert_pem)
                .map_err(|e| ApiError::Internal(format!("Failed to read cert file: {}", e)))?;
            let mut key_pem = Vec::new();
            key_file
                .read_to_end(&mut key_pem)
                .map_err(|e| ApiError::Internal(format!("Failed to read key file: {}", e)))?;
            // The cert file may contain a chain; collect every certificate.
            let certs = CertificateDer::pem_slice_iter(&cert_pem)
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| ApiError::Internal(format!("Failed to parse certs: {}", e)))?;
            let mut keys = PrivateKeyDer::pem_slice_iter(&key_pem)
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| ApiError::Internal(format!("Failed to parse key: {}", e)))?;
            if keys.is_empty() {
                return Err(ApiError::Internal("No private keys found".to_string()));
            }
            // Only the first private key in the file is used.
            let mut server_config = ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(certs, keys.remove(0))
                .map_err(|e| ApiError::Internal(format!("Failed to build TLS config: {}", e)))?;
            server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
            let tcp_listener = tokio::net::TcpListener::bind(addr).await?;
            tracing::info!("✅ VEX API listening on https://{}", addr);
            loop {
                let (tcp_stream, remote_addr) = tcp_listener
                    .accept()
                    .await
                    .map_err(|e| ApiError::Internal(format!("Accept error: {}", e)))?;
                let tls_acceptor = tls_acceptor.clone();
                let app = app.clone();
                // One task per connection: TLS handshake, then HTTP serving.
                tokio::spawn(async move {
                    let tls_stream = match tls_acceptor.accept(tcp_stream).await {
                        Ok(s) => s,
                        Err(e) => {
                            // A failed handshake only drops this connection.
                            tracing::error!("TLS handshake failed: {}", e);
                            return;
                        }
                    };
                    // Bridge the axum Router (a tower Service) into hyper.
                    let tower_service = app.clone();
                    let hyper_service = hyper::service::service_fn(
                        move |request: hyper::Request<hyper::body::Incoming>| {
                            tower_service.clone().call(request)
                        },
                    );
                    if let Err(e) = hyper::server::conn::http1::Builder::new()
                        .serve_connection(hyper_util::rt::TokioIo::new(tls_stream), hyper_service)
                        .await
                    {
                        tracing::error!(
                            "Error serving HTTPS connection from {}: {}",
                            remote_addr,
                            e
                        );
                    }
                });
            }
        } else {
            // Refuse to start plaintext when HTTPS enforcement is requested.
            if self.config.enforce_https {
                tracing::error!("FATAL: HTTPS enforcement is enabled but TLS certificates are missing (VEX_TLS_CERT/VEX_TLS_KEY)");
                return Err(ApiError::Internal("HTTPS enforcement error".to_string()));
            }
            tracing::warn!(
                "⚠️ Starting VEX API server WITHOUT HTTPS on {} - NOT for production!",
                addr
            );
            let listener = tokio::net::TcpListener::bind(addr).await?;
            // Plain-HTTP path: axum serves with graceful shutdown on signal.
            axum::serve(
                listener,
                app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
            )
            .with_graceful_shutdown(shutdown_signal())
            .await
            .map_err(|e| ApiError::Internal(format!("Server error: {}", e)))?;
        }
        tracing::info!("Server shutdown complete");
        Ok(())
    }

    /// Shared metrics handle (e.g. for tests and health endpoints).
    pub fn metrics(&self) -> Arc<Metrics> {
        self.app_state.metrics()
    }
}
/// Resolves once the process receives Ctrl+C (all platforms) or SIGTERM
/// (Unix only), allowing `axum::serve` to drain in-flight requests.
async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };
    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler");
        stream.recv().await;
    };
    // On non-Unix targets there is no SIGTERM; substitute a future that
    // never resolves so the select below reduces to Ctrl+C alone.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();
    tokio::select! {
        _ = interrupt => {
            tracing::info!("Received Ctrl+C, starting graceful shutdown");
        }
        _ = sigterm => {
            tracing::info!("Received SIGTERM, starting graceful shutdown");
        }
    }
}
/// Installs the global tracing subscriber: env-driven filtering (via
/// `RUST_LOG`) with a verbose fallback, plus formatted output that includes
/// event targets.
pub fn init_tracing() {
    use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
    // Honor RUST_LOG when present; otherwise default to info with extra
    // detail for this crate and tower_http.
    let env_filter = match EnvFilter::try_from_default_env() {
        Ok(filter) => filter,
        Err(_) => EnvFilter::new("info,vex_api=debug,tower_http=debug"),
    };
    let fmt_layer = tracing_subscriber::fmt::layer().with_target(true);
    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .init();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must match the documented defaults: port 8080,
    /// 30 s timeout, 1 MiB body cap, compression on, no TLS, HTTPS not
    /// enforced.
    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.addr.port(), 8080);
        assert_eq!(config.timeout, Duration::from_secs(30));
        // Previously untested fields of the Default impl:
        assert_eq!(config.max_body_size, 1024 * 1024);
        assert!(config.compression);
        assert!(config.tls.is_none());
        assert!(!config.enforce_https);
    }
}