use crate::connection::{Connection, ConnectionManager};
use crate::control_plane::{control_plane_router, ControlPlaneState};
use crate::dashboard::dashboard_router;
use crate::enterprise_readiness::{
build_enterprise_readiness_report, EnterpriseReadinessInput, EnterpriseReadinessReport,
};
use crate::graceful_shutdown::{ShutdownManager, ShutdownPhase};
use crate::middleware::{
auth_middleware, per_key_rate_limit_middleware, rate_limit_middleware, AuthConfig,
MiddlewareState,
};
use crate::observability::{
request_tracing_middleware, ObservabilityMiddlewareState, RequestMetrics,
};
use crate::playground::playground_router;
use crate::pricing_page::pricing_router;
use crate::proxy_management::{proxy_management_router, ProxyManagementState};
use crate::rate_limit_per_key::PerKeyRateLimiter;
use crate::rest_api::{api_router, audit_prometheus_export, RestApiState};
use crate::router::{InboundMessage, MessageRouter};
use crate::streaming::{streaming_router, StreamingState};
use crate::webhook::{webhook_handler, WebhookConfig, WebhookState};
use argentor_a2a::server::{A2AServer, A2AServerState};
use argentor_agent::AgentRunner;
use argentor_security::observability::AgentMetricsCollector;
use argentor_security::RateLimiter;
use argentor_session::SessionStore;
use axum::{
extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
},
http::{header, StatusCode},
middleware as axum_mw,
response::IntoResponse,
routing::{get, post},
Router,
};
use chrono::Utc;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Shared application state handed (via `Arc`) to every route handler.
///
/// Optional fields correspond to optional subsystems: `None` means the
/// subsystem was not configured by the chosen `GatewayServer::build_*`
/// constructor, and the related routes/behavior are disabled.
pub struct AppState {
    // Dispatches inbound messages to the agent and streams replies back out.
    pub router: Arc<MessageRouter>,
    // Registry of live WebSocket connections / session ids.
    pub connections: Arc<ConnectionManager>,
    // When `Some`, the `/webhook/{name}` route is mounted.
    pub webhooks: Option<WebhookState>,
    // Agent-level metrics collector; exported on `/metrics` when present.
    pub metrics: Option<AgentMetricsCollector>,
    // Control-plane state; readiness check reports it and its router is merged.
    pub control_plane: Option<Arc<ControlPlaneState>>,
    // REST API state; also feeds uptime/skills data to the readiness report.
    pub rest_api: Option<Arc<RestApiState>>,
    // Drives the shutting_down responses from /health and /health/ready.
    pub shutdown_manager: Option<ShutdownManager>,
    // Per-API-key limiter; presence is surfaced in the readiness report.
    pub per_key_rate_limiter: Option<Arc<PerKeyRateLimiter>>,
    // HTTP request metrics recorded by the tracing middleware.
    pub request_metrics: Option<Arc<RequestMetrics>>,
    // True when the AuthConfig passed at build time had auth enabled.
    pub auth_enabled: bool,
    // Records whether the proxy-management router was merged at build time.
    pub proxy_management_mounted: bool,
    // Records whether the A2A router was merged at build time.
    pub a2a_mounted: bool,
}
/// Per-phase timeout budget for a graceful shutdown.
///
/// Used by [`GatewayServer::with_graceful_shutdown`], which registers one hook
/// per phase (PreDrain → Drain → Cleanup → Final).
#[derive(Debug, Clone)]
pub struct GracefulShutdownConfig {
    // Overall budget for the whole shutdown sequence.
    pub total_timeout: Duration,
    // Budget for waiting on in-flight requests to complete.
    pub drain_timeout: Duration,
    // Budget for flushing audit logs and closing resources.
    pub cleanup_timeout: Duration,
    // Budget for the last-chance export/teardown step.
    pub final_timeout: Duration,
}

impl Default for GracefulShutdownConfig {
    /// Defaults: 30s total, with 15s drain, 10s cleanup, 5s final.
    fn default() -> Self {
        Self {
            total_timeout: Duration::from_secs(30),
            drain_timeout: Duration::from_secs(15),
            cleanup_timeout: Duration::from_secs(10),
            final_timeout: Duration::from_secs(5),
        }
    }
}
/// Zero-sized namespace for the family of `build_*` router constructors.
pub struct GatewayServer;

impl GatewayServer {
    /// Minimal gateway router: no rate limiting, empty auth key set (auth
    /// disabled), and no optional subsystems mounted.
    pub fn build(agent: Arc<AgentRunner>, sessions: Arc<dyn SessionStore>) -> Router {
        Self::build_with_middleware(agent, sessions, None, AuthConfig::new(vec![]), None, None)
    }

    /// Router with optional global rate limiting, auth, webhooks, and metrics.
    /// Control plane and REST API are left unmounted.
    pub fn build_with_middleware(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        rate_limiter: Option<Arc<RateLimiter>>,
        auth_config: AuthConfig,
        webhooks: Option<Vec<WebhookConfig>>,
        metrics: Option<AgentMetricsCollector>,
    ) -> Router {
        Self::build_full(
            agent,
            sessions,
            rate_limiter,
            auth_config,
            webhooks,
            metrics,
            None,
            None,
        )
    }

    /// Adds optional control-plane and REST API mounting on top of
    /// [`Self::build_with_middleware`]. Proxy management and A2A stay off.
    #[allow(clippy::too_many_arguments)]
    pub fn build_full(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        rate_limiter: Option<Arc<RateLimiter>>,
        auth_config: AuthConfig,
        webhooks: Option<Vec<WebhookConfig>>,
        metrics: Option<AgentMetricsCollector>,
        control_plane: Option<Arc<ControlPlaneState>>,
        rest_api: Option<Arc<RestApiState>>,
    ) -> Router {
        Self::build_complete(
            agent,
            sessions,
            rate_limiter,
            auth_config,
            webhooks,
            metrics,
            control_plane,
            rest_api,
            None,
            None,
        )
    }

    /// Adds optional proxy-management and A2A mounting on top of
    /// [`Self::build_full`]. No shutdown manager is attached.
    #[allow(clippy::too_many_arguments)]
    pub fn build_complete(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        rate_limiter: Option<Arc<RateLimiter>>,
        auth_config: AuthConfig,
        webhooks: Option<Vec<WebhookConfig>>,
        metrics: Option<AgentMetricsCollector>,
        control_plane: Option<Arc<ControlPlaneState>>,
        rest_api: Option<Arc<RestApiState>>,
        proxy_management: Option<Arc<ProxyManagementState>>,
        a2a: Option<Arc<A2AServerState>>,
    ) -> Router {
        Self::build_complete_with_shutdown(
            agent,
            sessions,
            rate_limiter,
            auth_config,
            webhooks,
            metrics,
            control_plane,
            rest_api,
            proxy_management,
            a2a,
            None,
        )
    }

    /// Convenience constructor: everything off except auth and a per-API-key
    /// rate limiter.
    pub fn with_rate_limiter(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        per_key_limiter: PerKeyRateLimiter,
        auth_config: AuthConfig,
    ) -> Router {
        Self::build_complete_with_per_key(
            agent,
            sessions,
            None,
            auth_config,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            Some(Arc::new(per_key_limiter)),
        )
    }

    /// Fullest constructor: every optional subsystem plus an optional per-key
    /// rate limiter and shutdown manager.
    ///
    /// NOTE(review): this is nearly a superset-duplicate of
    /// [`Self::build_complete_with_shutdown`]; the two could collapse if
    /// `per_key_rate_limit_middleware` provably no-ops when the limiter is
    /// `None` — confirm against the middleware module before deduplicating.
    #[allow(clippy::too_many_arguments)]
    pub fn build_complete_with_per_key(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        rate_limiter: Option<Arc<RateLimiter>>,
        auth_config: AuthConfig,
        webhooks: Option<Vec<WebhookConfig>>,
        metrics: Option<AgentMetricsCollector>,
        control_plane: Option<Arc<ControlPlaneState>>,
        rest_api: Option<Arc<RestApiState>>,
        proxy_management: Option<Arc<ProxyManagementState>>,
        a2a: Option<Arc<A2AServerState>>,
        shutdown_manager: Option<ShutdownManager>,
        per_key_rate_limiter: Option<Arc<PerKeyRateLimiter>>,
    ) -> Router {
        // Core plumbing: connection registry + message router share ownership.
        let connections = ConnectionManager::new();
        let sessions_for_streaming = sessions.clone();
        let router = Arc::new(MessageRouter::new(agent, sessions, connections.clone()));
        let webhook_state = webhooks.map(|configs| WebhookState { webhooks: configs });
        // Capture booleans before the owning values are moved into state/middleware.
        let auth_enabled = auth_config.is_enabled();
        let proxy_management_mounted = proxy_management.is_some();
        let a2a_mounted = a2a.is_some();
        let request_metrics = Arc::new(RequestMetrics::new());
        let streaming_state = Arc::new(StreamingState {
            router: router.clone(),
            connections: connections.clone(),
            sessions: sessions_for_streaming,
            session_broadcast: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
        });
        let state = Arc::new(AppState {
            router,
            connections,
            webhooks: webhook_state,
            metrics,
            control_plane: control_plane.clone(),
            rest_api: rest_api.clone(),
            shutdown_manager,
            per_key_rate_limiter: per_key_rate_limiter.clone(),
            request_metrics: Some(Arc::clone(&request_metrics)),
            auth_enabled,
            proxy_management_mounted,
            a2a_mounted,
        });
        // Base route table shared by all constructors.
        let mut app = Router::new()
            .route("/ws", get(ws_handler))
            .route("/health", get(health_handler))
            .route("/health/live", get(health_live_handler))
            .route("/health/ready", get(health_ready_handler))
            .route("/metrics", get(prometheus_metrics_handler))
            .route("/openapi.json", get(openapi_handler))
            .route(
                "/api/v1/enterprise/readiness",
                get(enterprise_readiness_handler),
            );
        // Webhook route is only mounted when webhook configs were supplied.
        if state.webhooks.is_some() {
            app = app.route("/webhook/{name}", post(webhook_handler));
        }
        let mut app = app.with_state(state);
        // Always-on sub-routers.
        app = app.merge(streaming_router(streaming_state));
        app = app.merge(dashboard_router());
        app = app.merge(playground_router());
        app = app.merge(pricing_router());
        // Optional sub-routers, mounted only when their state was provided.
        if let Some(cp_state) = control_plane {
            app = app.merge(control_plane_router(cp_state));
        }
        if let Some(ra_state) = rest_api {
            app = app.merge(api_router(ra_state));
        }
        if let Some(pm_state) = proxy_management {
            app = app.merge(proxy_management_router(pm_state));
        }
        if let Some(a2a_state) = a2a {
            app = app.merge(A2AServer::router(a2a_state));
        }
        // Tracing/metrics middleware is layered after all merges so it wraps
        // every route registered above.
        let obs_state = Arc::new(ObservabilityMiddlewareState { request_metrics });
        app = app.layer(axum_mw::from_fn_with_state(
            obs_state,
            request_tracing_middleware,
        ));
        let has_middleware =
            rate_limiter.is_some() || auth_config.is_enabled() || per_key_rate_limiter.is_some();
        if has_middleware {
            let mw_state = Arc::new(MiddlewareState {
                // A permissive default limiter stands in when only auth or
                // per-key limiting triggered middleware installation.
                rate_limiter: rate_limiter
                    .unwrap_or_else(|| Arc::new(RateLimiter::new(1000.0, 1000.0))),
                auth: auth_config,
                per_key_rate_limiter,
            });
            // In axum the last-added layer is outermost, so on each request
            // auth runs first, then per-key limiting, then global limiting.
            app = app
                .layer(axum_mw::from_fn_with_state(
                    mw_state.clone(),
                    rate_limit_middleware,
                ))
                .layer(axum_mw::from_fn_with_state(
                    mw_state.clone(),
                    per_key_rate_limit_middleware,
                ))
                .layer(axum_mw::from_fn_with_state(mw_state, auth_middleware));
        }
        app
    }

    /// Same as [`Self::build_complete`] but attaches a `ShutdownManager` so
    /// health endpoints can report `shutting_down`. No per-key rate limiter.
    ///
    /// NOTE(review): body intentionally mirrors
    /// [`Self::build_complete_with_per_key`] minus the per-key layer — see the
    /// deduplication note on that method.
    #[allow(clippy::too_many_arguments)]
    pub fn build_complete_with_shutdown(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        rate_limiter: Option<Arc<RateLimiter>>,
        auth_config: AuthConfig,
        webhooks: Option<Vec<WebhookConfig>>,
        metrics: Option<AgentMetricsCollector>,
        control_plane: Option<Arc<ControlPlaneState>>,
        rest_api: Option<Arc<RestApiState>>,
        proxy_management: Option<Arc<ProxyManagementState>>,
        a2a: Option<Arc<A2AServerState>>,
        shutdown_manager: Option<ShutdownManager>,
    ) -> Router {
        let connections = ConnectionManager::new();
        let sessions_for_streaming = sessions.clone();
        let router = Arc::new(MessageRouter::new(agent, sessions, connections.clone()));
        let webhook_state = webhooks.map(|configs| WebhookState { webhooks: configs });
        let auth_enabled = auth_config.is_enabled();
        let proxy_management_mounted = proxy_management.is_some();
        let a2a_mounted = a2a.is_some();
        let request_metrics = Arc::new(RequestMetrics::new());
        let streaming_state = Arc::new(StreamingState {
            router: router.clone(),
            connections: connections.clone(),
            sessions: sessions_for_streaming,
            session_broadcast: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
        });
        let state = Arc::new(AppState {
            router,
            connections,
            webhooks: webhook_state,
            metrics,
            control_plane: control_plane.clone(),
            rest_api: rest_api.clone(),
            shutdown_manager,
            // This constructor never installs a per-key limiter.
            per_key_rate_limiter: None,
            request_metrics: Some(Arc::clone(&request_metrics)),
            auth_enabled,
            proxy_management_mounted,
            a2a_mounted,
        });
        let mut app = Router::new()
            .route("/ws", get(ws_handler))
            .route("/health", get(health_handler))
            .route("/health/live", get(health_live_handler))
            .route("/health/ready", get(health_ready_handler))
            .route("/metrics", get(prometheus_metrics_handler))
            .route("/openapi.json", get(openapi_handler))
            .route(
                "/api/v1/enterprise/readiness",
                get(enterprise_readiness_handler),
            );
        if state.webhooks.is_some() {
            app = app.route("/webhook/{name}", post(webhook_handler));
        }
        let mut app = app.with_state(state);
        app = app.merge(streaming_router(streaming_state));
        app = app.merge(dashboard_router());
        app = app.merge(playground_router());
        app = app.merge(pricing_router());
        if let Some(cp_state) = control_plane {
            app = app.merge(control_plane_router(cp_state));
        }
        if let Some(ra_state) = rest_api {
            app = app.merge(api_router(ra_state));
        }
        if let Some(pm_state) = proxy_management {
            app = app.merge(proxy_management_router(pm_state));
        }
        if let Some(a2a_state) = a2a {
            app = app.merge(A2AServer::router(a2a_state));
        }
        // Layered after merges so tracing wraps every mounted route.
        let obs_state = Arc::new(ObservabilityMiddlewareState { request_metrics });
        app = app.layer(axum_mw::from_fn_with_state(
            obs_state,
            request_tracing_middleware,
        ));
        if rate_limiter.is_some() || auth_config.is_enabled() {
            let mw_state = Arc::new(MiddlewareState {
                // Permissive default limiter when only auth forced middleware on.
                rate_limiter: rate_limiter
                    .unwrap_or_else(|| Arc::new(RateLimiter::new(1000.0, 1000.0))),
                auth: auth_config,
                per_key_rate_limiter: None,
            });
            // Last-added layer is outermost: auth runs before rate limiting.
            app.layer(axum_mw::from_fn_with_state(
                mw_state.clone(),
                rate_limit_middleware,
            ))
            .layer(axum_mw::from_fn_with_state(mw_state, auth_middleware))
        } else {
            app
        }
    }

    /// Minimal router plus an agent metrics collector (enables `/metrics`).
    pub fn build_with_metrics(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        metrics: AgentMetricsCollector,
    ) -> Router {
        Self::build_with_middleware(
            agent,
            sessions,
            None,
            AuthConfig::new(vec![]),
            None,
            Some(metrics),
        )
    }

    /// Minimal router plus webhook configs (mounts `/webhook/{name}`).
    pub fn with_webhooks(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        configs: Vec<WebhookConfig>,
    ) -> Router {
        Self::build_with_middleware(
            agent,
            sessions,
            None,
            AuthConfig::new(vec![]),
            Some(configs),
            None,
        )
    }

    /// Minimal router plus a fully-wired graceful-shutdown pipeline.
    ///
    /// Registers one hook per shutdown phase, spawns a Ctrl+C listener that
    /// triggers the shutdown sequence, and returns both the router and the
    /// `ShutdownManager` so callers can also trigger shutdown programmatically.
    pub async fn with_graceful_shutdown(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        config: GracefulShutdownConfig,
    ) -> (Router, ShutdownManager) {
        let shutdown_mgr = ShutdownManager::new(config.total_timeout);
        // Phase 1: stop accepting new connections.
        shutdown_mgr
            .on_shutdown("stop-accepting", ShutdownPhase::PreDrain, || {
                info!("PreDrain: no longer accepting new connections");
                Ok(())
            })
            .await;
        // Phase 2: drain in-flight requests (timeout captured for logging).
        let drain_timeout = config.drain_timeout;
        shutdown_mgr
            .on_shutdown("drain-connections", ShutdownPhase::Drain, move || {
                info!(
                    timeout_ms = drain_timeout.as_millis() as u64,
                    "Drain: waiting for in-flight requests to complete"
                );
                Ok(())
            })
            .await;
        // Phase 3: flush/cleanup resources.
        let cleanup_timeout = config.cleanup_timeout;
        shutdown_mgr
            .on_shutdown("flush-resources", ShutdownPhase::Cleanup, move || {
                info!(
                    timeout_ms = cleanup_timeout.as_millis() as u64,
                    "Cleanup: flushing audit logs and closing resources"
                );
                Ok(())
            })
            .await;
        // Phase 4: final export/teardown.
        shutdown_mgr
            .on_shutdown("final-export", ShutdownPhase::Final, || {
                info!("Final: shutdown complete");
                Ok(())
            })
            .await;
        // Background task: SIGINT triggers the full shutdown sequence.
        let shutdown_for_signal = shutdown_mgr.clone();
        tokio::spawn(async move {
            match tokio::signal::ctrl_c().await {
                Ok(()) => {
                    warn!("Received SIGINT (Ctrl+C), initiating graceful shutdown");
                    let report = shutdown_for_signal.shutdown().await;
                    info!(
                        succeeded = report.hooks_succeeded,
                        failed = report.hooks_failed,
                        total_ms = report.total_duration_ms,
                        "Graceful shutdown complete"
                    );
                }
                Err(e) => {
                    error!(error = %e, "Failed to listen for ctrl_c signal");
                }
            }
        });
        let router = Self::build_complete_with_shutdown(
            agent,
            sessions,
            None,
            AuthConfig::new(vec![]),
            None,
            None,
            None,
            None,
            None,
            None,
            Some(shutdown_mgr.clone()),
        );
        (router, shutdown_mgr)
    }

    /// Minimal router plus SSO routes; when `protect_routes` is true, wraps
    /// the whole router in the SSO auth middleware.
    pub fn with_sso(
        agent: Arc<AgentRunner>,
        sessions: Arc<dyn SessionStore>,
        sso_config: crate::sso::SsoConfig,
        protect_routes: bool,
    ) -> Router {
        let sso_manager = Arc::new(crate::sso::SsoManager::new(sso_config));
        let mut app = Self::build(agent, sessions);
        app = app.merge(crate::sso::sso_router(sso_manager.clone()));
        if protect_routes {
            let mw_state = Arc::new(crate::sso::SsoMiddlewareState {
                manager: sso_manager,
            });
            app = app.layer(axum_mw::from_fn_with_state(
                mw_state,
                crate::sso::sso_auth_middleware,
            ));
        }
        app
    }
}
/// Primary health endpoint.
///
/// Returns `503` with `"status": "shutting_down"` once a graceful shutdown has
/// begun (so load balancers stop routing here); otherwise `200` with
/// `"status": "ok"`.
async fn health_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let shutting_down = match state.shutdown_manager.as_ref() {
        Some(mgr) => mgr.is_shutting_down().await,
        None => false,
    };
    if shutting_down {
        let body = serde_json::json!({
            "status": "shutting_down",
            "service": "argentor"
        })
        .to_string();
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            [(header::CONTENT_TYPE, "application/json")],
            body,
        )
            .into_response();
    }
    let body = serde_json::json!({"status": "ok", "service": "argentor"}).to_string();
    (
        StatusCode::OK,
        [(header::CONTENT_TYPE, "application/json")],
        body,
    )
        .into_response()
}
/// Liveness probe: unconditionally `200` while the process is serving requests.
async fn health_live_handler() -> impl IntoResponse {
    let body = serde_json::json!({
        "status": "alive",
        "service": "argentor"
    })
    .to_string();
    (
        StatusCode::OK,
        [(header::CONTENT_TYPE, "application/json")],
        body,
    )
        .into_response()
}
/// Readiness probe.
///
/// Aggregates named checks (shutdown state, REST API mounted, control plane
/// mounted, connections) and returns `200 ready` only when all pass,
/// `503 not_ready` otherwise, with a per-check breakdown in the body.
async fn health_ready_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let not_shutting_down = match state.shutdown_manager.as_ref() {
        Some(mgr) => !mgr.is_shutting_down().await,
        None => true,
    };
    let checks: Vec<(&str, bool)> = vec![
        ("not_shutting_down", not_shutting_down),
        ("rest_api", state.rest_api.is_some()),
        ("control_plane", state.control_plane.is_some()),
        // Connection manager is always constructed with the router.
        ("connections", true),
    ];
    let all_ready = checks.iter().all(|&(_, ok)| ok);
    let check_map: serde_json::Value = checks
        .iter()
        .map(|&(name, ok)| {
            (
                name.to_string(),
                serde_json::json!(if ok { "ok" } else { "not_ready" }),
            )
        })
        .collect::<serde_json::Map<String, serde_json::Value>>()
        .into();
    let status_code = if all_ready {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    };
    (
        status_code,
        [(header::CONTENT_TYPE, "application/json")],
        serde_json::json!({
            "status": if all_ready { "ready" } else { "not_ready" },
            "service": "argentor",
            "checks": check_map,
        })
        .to_string(),
    )
        .into_response()
}
/// Serves the gateway's OpenAPI specification as pretty-printed JSON.
async fn openapi_handler() -> impl IntoResponse {
    let body =
        serde_json::to_string_pretty(&crate::openapi::argentor_openapi_spec()).unwrap_or_default();
    (
        StatusCode::OK,
        [(header::CONTENT_TYPE, "application/json")],
        body,
    )
        .into_response()
}
/// Builds and serves the enterprise-readiness report.
///
/// Combines live connection/session counts with which optional subsystems
/// were mounted at build time; REST-API-derived fields (skill count, uptime,
/// session-store reachability) default to `(0, 0, false)` when the REST API
/// is not mounted.
async fn enterprise_readiness_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let live_connections = state.connections.connection_count().await;
    let live_sessions = state.connections.session_ids().await.len();
    let (skill_total, uptime_secs, sessions_ok) = match &state.rest_api {
        Some(api) => {
            let uptime = Utc::now().signed_duration_since(api.started_at);
            (
                api.skills.skill_count(),
                uptime.num_seconds(),
                api.sessions.list().await.is_ok(),
            )
        }
        None => (0, 0, false),
    };
    let report: EnterpriseReadinessReport =
        build_enterprise_readiness_report(EnterpriseReadinessInput {
            skills_registered: skill_total,
            active_connections: live_connections,
            active_sessions: live_sessions,
            uptime_seconds: uptime_secs,
            sessions_reachable: sessions_ok,
            rest_api_mounted: state.rest_api.is_some(),
            auth_configured: state.auth_enabled,
            per_key_rate_limit_configured: state.per_key_rate_limiter.is_some(),
            control_plane_mounted: state.control_plane.is_some(),
            proxy_management_mounted: state.proxy_management_mounted,
            a2a_mounted: state.a2a_mounted,
            metrics_configured: state.metrics.is_some() || state.request_metrics.is_some(),
        });
    (
        StatusCode::OK,
        [(header::CONTENT_TYPE, "application/json")],
        serde_json::to_string(&report).unwrap_or_default(),
    )
        .into_response()
}
/// Prometheus scrape endpoint.
///
/// Concatenates (newline-separated) whichever exporters are configured —
/// agent metrics, HTTP request metrics, and REST-API audit metrics. Returns
/// `503` with a JSON hint when no collector is configured at all.
async fn prometheus_metrics_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let mut sections: Vec<String> = Vec::new();
    if let Some(collector) = &state.metrics {
        sections.push(collector.prometheus_export());
    }
    if let Some(req_metrics) = &state.request_metrics {
        sections.push(req_metrics.prometheus_export());
    }
    if let Some(rest_api) = &state.rest_api {
        sections.push(audit_prometheus_export(rest_api).await);
    }
    if sections.is_empty() {
        let err = serde_json::json!({
            "error": "Metrics collector not configured",
            "hint": "Use GatewayServer::build_with_metrics() to enable Prometheus metrics"
        })
        .to_string();
        (
            StatusCode::SERVICE_UNAVAILABLE,
            [(header::CONTENT_TYPE, "application/json")],
            err,
        )
            .into_response()
    } else {
        (
            StatusCode::OK,
            [(
                header::CONTENT_TYPE,
                "text/plain; version=0.0.4; charset=utf-8",
            )],
            sections.join("\n"),
        )
            .into_response()
    }
}
/// WebSocket upgrade endpoint for `/ws`.
///
/// Records a 101 (Switching Protocols) status on the tracing span, then hands
/// the upgraded socket to [`handle_socket`].
#[tracing::instrument(skip(ws, state), fields(method = "GET", path = "/ws", status_code = tracing::field::Empty))]
async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
    // 101 = Switching Protocols; recorded manually since the upgrade response
    // bypasses the normal status-recording path.
    tracing::Span::current().record("status_code", 101u16);
    ws.on_upgrade(move |socket| handle_socket(socket, state))
}
/// Drives one WebSocket connection for its entire lifetime.
///
/// Registers the connection, sends a `connected` welcome message, then runs
/// two tasks — one forwarding outbound messages from the connection's channel
/// to the socket, one routing inbound text frames through the `MessageRouter`
/// — until either side closes, after which the connection is deregistered.
async fn handle_socket(socket: WebSocket, state: Arc<AppState>) {
    // Fresh ids per socket: one connection id, one implicit session id.
    let connection_id = Uuid::new_v4();
    let session_id = Uuid::new_v4();
    let (mut ws_sender, mut ws_receiver) = socket.split();
    // Outbound channel: the ConnectionManager writes here; the send task
    // forwards to the socket.
    let (tx, mut rx) = mpsc::unbounded_channel::<String>();
    let conn = Connection {
        id: connection_id,
        session_id,
        tx,
    };
    state.connections.add(conn).await;
    info!(
        connection_id = %connection_id,
        session_id = %session_id,
        "WebSocket connected"
    );
    // Welcome frame tells the client its session/connection ids. Delivery is
    // best-effort (send result deliberately ignored).
    let welcome = serde_json::json!({
        "type": "connected",
        "session_id": session_id,
        "connection_id": connection_id,
    });
    let _ = state
        .connections
        .send_to_session(session_id, &welcome.to_string())
        .await;
    use axum::extract::ws::Message as WsMessage;
    use futures_util::SinkExt;
    // Outbound pump: channel -> socket; exits when the socket write fails.
    let send_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            if ws_sender.send(WsMessage::Text(msg.into())).await.is_err() {
                break;
            }
        }
    });
    use futures_util::StreamExt;
    let router = state.router.clone();
    // Inbound pump: socket -> router; exits on close frame or socket error.
    let recv_task = tokio::spawn(async move {
        while let Some(Ok(msg)) = ws_receiver.next().await {
            match msg {
                Message::Text(text) => {
                    // Non-JSON frames are treated as plain content addressed
                    // to this socket's session.
                    let inbound: InboundMessage = match serde_json::from_str(&text) {
                        Ok(m) => m,
                        Err(_) => InboundMessage {
                            session_id: Some(session_id),
                            content: text.to_string(),
                        },
                    };
                    if let Err(e) = router
                        .handle_message_streaming(inbound, connection_id)
                        .await
                    {
                        error!(error = %e, "Failed to handle message");
                    }
                }
                Message::Close(_) => break,
                // Ping/pong/binary frames are ignored.
                _ => {}
            }
        }
    });
    // Whichever pump finishes first ends the connection; the other task is
    // dropped/aborted by select.
    tokio::select! {
        _ = send_task => {},
        _ = recv_task => {},
    }
    state.connections.remove(connection_id).await;
    info!(connection_id = %connection_id, "WebSocket disconnected");
}