use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use axum::extract::{Request, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tower_http::cors::CorsLayer;
use slotbus::types::RequestMeta;
use slotbus::{SlotBus, SlotBusConfig};
use crate::events;
use slotbus_hub::types::*;
/// Sweep interval in milliseconds (5 seconds).
/// NOTE(review): presumably the default for `HubConfig::sse_sweep_interval_ms`;
/// it is not referenced in this file, so confirm at the call sites.
pub const DEFAULT_SSE_SWEEP_INTERVAL_MS: u64 = 5_000;
/// Path and route pattern of the probe request `ping_worker` dispatches to a worker's bus.
const WORKER_PING_ROUTE: &str = "/__ping__";
/// Period of the liveness reaper loop; every registered worker is probed once per period.
const LIVENESS_REAP_INTERVAL: Duration = Duration::from_secs(10);
/// Maximum time `ping_worker` waits for a worker to answer a single probe.
const LIVENESS_PING_TIMEOUT: Duration = Duration::from_secs(2);
/// Number of consecutive failed probes after which the reaper evicts a worker.
/// A successful probe clears the worker's miss counter.
const LIVENESS_MAX_MISSES: u32 = 3;
/// Tunables for the hub, handed once to `build_router` and stored in `HubState`.
pub struct HubConfig {
/// Seconds `proxy_handler` waits for a worker's response before answering 504.
pub timeout_secs: u64,
/// Forwarded to the `SlotBusConfig` builder as `num_slots` for every worker's bus.
pub num_slots: usize,
/// Forwarded to the `SlotBusConfig` builder as `region_size` for every worker's bus.
pub region_size: usize,
/// Forwarded to the `SlotBusConfig` builder as `instrumentation`.
pub instrumentation: bool,
/// Milliseconds between SSE sweeper passes; `0` means no sweeper task is spawned.
pub sse_sweep_interval_ms: u64,
}
/// One event queued on a delegated SSE connection's channel.
pub struct SseEvent {
/// Used as the SSE event name (passed to `Event::event` when streamed).
pub event_type: String,
/// Used as the SSE payload (passed to `Event::data` when streamed).
pub data: String,
}
/// Registry entry for one delegated SSE client, stored in `HubState::sse_connections`
/// under the concrete request path.
pub struct SseConnectionInfo {
/// Random UUID minted per connection; cleanup only removes the entry when this matches,
/// so a stale cleanup cannot evict a newer connection on the same path.
pub connection_id: String,
/// Worker that receives this connection's lifecycle notifications. Rewritten when a
/// worker with the same name re-registers.
pub worker_id: String,
/// Route pattern that matched the request path.
pub path_pattern: String,
/// Path parameters extracted from the request path.
pub params: HashMap<String, String>,
/// Sending half of the channel whose receiver backs the client's SSE stream.
pub sender: mpsc::Sender<SseEvent>,
/// Connection time: Unix seconds followed by a literal `Z` (as produced by
/// `handle_sse_delegation`).
pub connected_at: String,
}
/// Shared hub state, wrapped in an `Arc` and handed to every handler and background task.
pub struct HubState {
/// Registered workers keyed by worker id.
workers: RwLock<HashMap<String, WorkerRecord>>,
/// Flat route list; `proxy_handler` picks the first entry whose method and pattern match.
route_table: RwLock<Vec<RouteEntry>>,
/// Broadcast sender for `HubEvent`s.
/// NOTE(review): presumably consumed by the `events` handlers; not used in this file.
pub event_tx: broadcast::Sender<HubEvent>,
/// Delegated SSE connections keyed by concrete request path (one entry per path).
pub sse_connections: RwLock<HashMap<String, SseConnectionInfo>>,
/// Configuration supplied to `build_router`.
config: HubConfig,
}
/// Everything the hub keeps about one registered worker.
struct WorkerRecord {
/// Name the worker registered with; a later registration with the same name replaces it.
name: String,
/// Routes the worker announced at registration (only their count is reported by `health`).
routes: Vec<RouteRegistration>,
/// Bus used to dispatch requests to this worker.
bus: Arc<SlotBus>,
}
/// One row of the hub's routing table.
struct RouteEntry {
/// HTTP method, compared for exact string equality with the incoming request's method.
method: String,
/// Path pattern; `{name}` segments capture a path parameter (see `match_route`).
pattern: String,
/// Id of the worker that serves this route.
worker_id: String,
/// When true the request is handled as a delegated SSE stream instead of being proxied.
sse: bool,
}
/// Builds the hub's axum router and starts its background tasks.
///
/// Creates the shared [`HubState`], spawns the SSE sweeper when
/// `config.sse_sweep_interval_ms` is non-zero, always spawns the liveness
/// reaper, and wires the internal, event and health routes. Any request that
/// matches none of them falls through to `proxy_handler`.
///
/// The background tasks are started with `tokio::spawn`, so this is expected
/// to be called from within a Tokio runtime.
pub fn build_router(config: HubConfig) -> Router {
    // Read the interval before `config` is moved into the shared state.
    let sweep_every = match config.sse_sweep_interval_ms {
        0 => None,
        ms => Some(Duration::from_millis(ms)),
    };

    // Only the sender is kept; the initial receiver is dropped right away.
    let event_tx = broadcast::channel::<HubEvent>(256).0;

    let shared = Arc::new(HubState {
        workers: RwLock::new(HashMap::new()),
        route_table: RwLock::new(Vec::new()),
        event_tx,
        sse_connections: RwLock::new(HashMap::new()),
        config,
    });

    if let Some(interval) = sweep_every {
        spawn_sse_sweeper(Arc::clone(&shared), interval);
    }
    spawn_liveness_reaper(Arc::clone(&shared));

    let routes = Router::new()
        .route("/internal/register", post(register_worker))
        .route("/internal/emit", post(events::emit_event))
        .route("/internal/sse-push", post(events::sse_push))
        .route("/internal/slots", get(slot_diagnostics))
        .route("/events", get(events::unified_sse))
        .route("/events/{channel}", get(events::scoped_sse))
        .route("/health", get(health))
        .fallback(proxy_handler);

    routes.layer(CorsLayer::permissive()).with_state(shared)
}
/// Starts a detached task that runs [`sweep_once`] every `interval`.
///
/// Ticks that are missed because a sweep overran are skipped rather than
/// replayed. The task loops forever and is never joined.
fn spawn_sse_sweeper(state: Arc<HubState>, interval: Duration) {
    let sweeper = async move {
        let mut timer = tokio::time::interval(interval);
        timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // Consume one tick up front so the first sweep happens a full interval from now.
        timer.tick().await;
        tracing::info!(
            interval_ms = interval.as_millis() as u64,
            "SSE sweeper started"
        );
        loop {
            timer.tick().await;
            sweep_once(&state).await;
        }
    };
    tokio::spawn(sweeper);
}
async fn sweep_once(state: &Arc<HubState>) {
let probes: Vec<(String, String, mpsc::Sender<SseEvent>)> = {
let sse = state.sse_connections.read().await;
sse.iter()
.map(|(path, info)| {
(
path.clone(),
info.connection_id.clone(),
info.sender.clone(),
)
})
.collect()
};
if probes.is_empty() {
return;
}
let mut closed: Vec<(String, String)> = Vec::new();
for (path, connection_id, sender) in &probes {
match sender.try_send(SseEvent {
event_type: "sse_ping".to_string(),
data: String::new(),
}) {
Ok(()) => {
}
Err(mpsc::error::TrySendError::Full(_)) => {
}
Err(mpsc::error::TrySendError::Closed(_)) => {
closed.push((path.clone(), connection_id.clone()));
}
}
}
if closed.is_empty() {
return;
}
tracing::debug!(count = closed.len(), "SSE sweeper found closed channels");
for (path, connection_id) in closed {
let state = Arc::clone(state);
tokio::spawn(async move {
cleanup_sse_connection(&state, &path, "sweeper", &connection_id).await;
});
}
}
/// Starts a detached task that periodically probes every worker and evicts
/// the ones that stop answering.
///
/// Every `LIVENESS_REAP_INTERVAL` the task snapshots the worker registry and
/// pings each worker in turn. A successful ping clears the worker's miss
/// counter; a failed one increments it, and at `LIVENESS_MAX_MISSES`
/// consecutive failures the worker and its routes are removed.
///
/// Missed ticks are skipped. Workers are probed sequentially, so a round with
/// several unresponsive workers can outlast the interval; replaying the missed
/// ticks back-to-back would let a worker collect all its misses within a few
/// seconds instead of over several intervals.
fn spawn_liveness_reaper(state: Arc<HubState>) {
    tokio::spawn(async move {
        let mut misses: HashMap<String, u32> = HashMap::new();
        let mut ticker = tokio::time::interval(LIVENESS_REAP_INTERVAL);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // Consume one tick up front so the first round starts one interval from now.
        ticker.tick().await;
        loop {
            ticker.tick().await;
            // Snapshot under the read lock; probing happens without holding it.
            let snapshot: Vec<(String, String, Arc<SlotBus>)> = {
                let workers = state.workers.read().await;
                workers
                    .iter()
                    .map(|(id, r)| (id.clone(), r.name.clone(), Arc::clone(&r.bus)))
                    .collect()
            };
            // Forget counters of workers that are no longer registered.
            misses.retain(|id, _| snapshot.iter().any(|(wid, _, _)| wid == id));
            for (worker_id, name, bus) in snapshot {
                if ping_worker(&bus).await {
                    misses.remove(&worker_id);
                    continue;
                }
                let count = misses.entry(worker_id.clone()).or_insert(0);
                *count += 1;
                if *count >= LIVENESS_MAX_MISSES {
                    tracing::warn!(
                        name,
                        worker_id,
                        misses = *count,
                        "evicting unresponsive worker (SHM liveness probe failed)"
                    );
                    evict_worker(&state, &worker_id).await;
                    misses.remove(&worker_id);
                } else {
                    tracing::debug!(name, worker_id, misses = *count, "worker liveness probe miss");
                }
            }
        }
    });
}
/// Sends one liveness probe over `bus` and reports whether it was answered.
///
/// Dispatches an empty `GET` for `WORKER_PING_ROUTE` and returns `true` only
/// when a response arrives within `LIVENESS_PING_TIMEOUT`. A dispatch error, a
/// dropped response channel and a timeout all count as a failed probe.
async fn ping_worker(bus: &Arc<SlotBus>) -> bool {
    let probe = RequestMeta {
        path: WORKER_PING_ROUTE.to_string(),
        route_pattern: WORKER_PING_ROUTE.to_string(),
        path_params: Vec::new(),
        query: None,
        headers: Vec::new(),
    };
    let probe_id = uuid::Uuid::new_v4().to_string();

    let Ok(reply) = bus.dispatch(&probe_id, "GET", &probe, &[]) else {
        return false;
    };

    match tokio::time::timeout(LIVENESS_PING_TIMEOUT, reply).await {
        Ok(Ok(_)) => true,
        Ok(Err(_)) | Err(_) => false,
    }
}
/// Removes a worker and every route that points at it.
///
/// The registry and the route table are locked one after the other, never
/// both at once. Unknown ids are a no-op.
async fn evict_worker(state: &Arc<HubState>, worker_id: &str) {
    // Each write guard is a temporary that is released at the end of its statement.
    state.workers.write().await.remove(worker_id);
    state
        .route_table
        .write()
        .await
        .retain(|route| route.worker_id != worker_id);
}
async fn register_worker(
State(state): State<Arc<HubState>>,
Json(req): Json<RegisterRequest>,
) -> impl IntoResponse {
let worker_id = uuid::Uuid::new_v4().to_string();
let route_count = req.routes.len();
let stale_ids: Vec<String>;
{
let mut workers = state.workers.write().await;
stale_ids = workers
.iter()
.filter(|(_, record)| record.name == req.name)
.map(|(id, _)| id.clone())
.collect();
if !stale_ids.is_empty() {
let mut table = state.route_table.write().await;
for stale_id in &stale_ids {
workers.remove(stale_id);
table.retain(|entry| entry.worker_id != *stale_id);
}
tracing::info!(
name = req.name,
count = stale_ids.len(),
"removed stale worker(s)"
);
}
}
let config = SlotBusConfig::builder()
.name(&req.name)
.prefix("hub")
.num_slots(state.config.num_slots)
.region_size(state.config.region_size)
.instrumentation(state.config.instrumentation)
.build();
let bus = match SlotBus::create(config) {
Ok(bus) => Arc::new(bus),
Err(e) => {
tracing::error!(name = req.name, error = %e, "failed to create SlotBus");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create SHM: {e}"),
)
.into_response();
}
};
bus.start_response_watcher();
let shm_name = bus.region_name();
{
let mut table = state.route_table.write().await;
for route in &req.routes {
table.push(RouteEntry {
method: route.method.clone(),
pattern: route.path.clone(),
worker_id: worker_id.clone(),
sse: route.sse,
});
}
}
{
let mut workers = state.workers.write().await;
workers.insert(
worker_id.clone(),
WorkerRecord {
name: req.name.clone(),
routes: req.routes,
bus,
},
);
}
if !stale_ids.is_empty() {
let replay = {
let mut sse = state.sse_connections.write().await;
let mut to_replay = Vec::new();
for (path, info) in sse.iter_mut() {
if stale_ids.contains(&info.worker_id) {
to_replay.push((
path.clone(),
info.params.clone(),
info.path_pattern.clone(),
));
info.worker_id = worker_id.clone();
}
}
to_replay
};
if !replay.is_empty() {
let replay_count = replay.len();
let replay_state = Arc::clone(&state);
let replay_worker_id = worker_id.clone();
let replay_name = req.name.clone();
tracing::info!(
name = req.name,
worker_id,
count = replay_count,
"replaying SSE connections to new worker (background)"
);
tokio::spawn(async move {
let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
for (path, params, pattern) in replay {
if tokio::time::Instant::now() >= deadline {
tracing::warn!(
name = replay_name,
worker_id = replay_worker_id,
"SSE replay timed out (30s overall), skipping remaining"
);
break;
}
dispatch_sse_lifecycle(
&replay_state,
&replay_worker_id,
"connect",
&path,
&pattern,
¶ms,
)
.await;
}
});
}
}
tracing::info!(
name = req.name,
worker_id,
route_count,
shm_name,
"registered worker"
);
Json(RegisterResponse {
worker_id,
shm_name,
})
.into_response()
}
/// `GET /health` — lists every registered worker.
///
/// Always reports `ok: true`; each worker is described by its name, id, number
/// of registered routes and a `shm:<region name>` transport string.
async fn health(State(state): State<Arc<HubState>>) -> impl IntoResponse {
    let registry = state.workers.read().await;

    let mut listing: Vec<WorkerInfo> = Vec::with_capacity(registry.len());
    for (worker_id, record) in registry.iter() {
        listing.push(WorkerInfo {
            name: record.name.clone(),
            worker_id: worker_id.clone(),
            route_count: record.routes.len(),
            transport: format!("shm:{}", record.bus.region_name()),
        });
    }

    Json(HealthResponse {
        ok: true,
        workers: listing,
    })
}
/// `GET /internal/slots` — per-worker snapshot of the bus slot states.
///
/// For every worker the response carries the total slot count, how many slots
/// are free (raw state `0`) versus in use (anything else), and a per-slot list
/// with a readable label. Raw states outside `0..=4` are labelled `unknown`.
async fn slot_diagnostics(State(state): State<Arc<HubState>>) -> impl IntoResponse {
    let registry = state.workers.read().await;
    let mut report = Vec::new();

    for record in registry.values() {
        let diag = record.bus.slot_diagnostics();
        let total = diag.len();
        let (mut free, mut in_use) = (0usize, 0usize);
        let mut slots = Vec::with_capacity(total);

        for (index, raw) in &diag {
            let (label, is_free) = match *raw {
                0 => ("free", true),
                1 => ("ready", false),
                2 => ("claimed", false),
                3 => ("done", false),
                4 => ("writing", false),
                _ => ("unknown", false),
            };
            if is_free {
                free += 1;
            } else {
                in_use += 1;
            }
            slots.push(serde_json::json!({ "index": index, "state": label }));
        }

        report.push(serde_json::json!({
            "name": record.name,
            "total_slots": total,
            "free": free,
            "in_use": in_use,
            "slots": slots,
        }));
    }

    Json(serde_json::json!({ "workers": report }))
}
async fn proxy_handler(State(state): State<Arc<HubState>>, request: Request) -> impl IntoResponse {
let method = request.method().to_string();
let path = request.uri().path().to_string();
let query = request.uri().query().map(|q| q.to_string());
let headers: Vec<(String, String)> = request
.headers()
.iter()
.filter_map(|(k, v)| {
let key = k.as_str().to_string();
let val = v.to_str().ok()?.to_string();
Some((key, val))
})
.collect();
let body_bytes = match axum::body::to_bytes(request.into_body(), 10 * 1024 * 1024).await {
Ok(b) => b,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Failed to read request body: {e}"),
)
.into_response();
}
};
let route_table = state.route_table.read().await;
let matched = route_table
.iter()
.find(|entry| entry.method == method && match_route(&entry.pattern, &path).is_some());
let (route_pattern, worker_id, is_sse) = match matched {
Some(entry) => (entry.pattern.clone(), entry.worker_id.clone(), entry.sse),
None => {
return (
StatusCode::BAD_GATEWAY,
format!("No worker registered for {method} {path}"),
)
.into_response();
}
};
drop(route_table);
let path_params = match_route(&route_pattern, &path).unwrap_or_default();
if is_sse {
return handle_sse_delegation(state, worker_id, path, route_pattern, path_params)
.await
.into_response();
}
let workers = state.workers.read().await;
let bus = match workers.get(&worker_id) {
Some(record) => Arc::clone(&record.bus),
None => {
return (
StatusCode::BAD_GATEWAY,
format!("Worker {worker_id} not found"),
)
.into_response();
}
};
drop(workers);
let meta = RequestMeta {
path: path.clone(),
route_pattern,
path_params: path_params.into_iter().collect(),
query,
headers,
};
let req_id = uuid::Uuid::new_v4().to_string();
let rx = match bus.dispatch(&req_id, &method, &meta, &body_bytes) {
Ok(rx) => rx,
Err(e) => {
return (StatusCode::BAD_GATEWAY, format!("SHM dispatch failed: {e}")).into_response();
}
};
let timeout = Duration::from_secs(state.config.timeout_secs);
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(resp)) => {
let status =
StatusCode::from_u16(resp.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut response = axum::response::Response::builder().status(status);
response = response.header("content-type", &resp.content_type);
for (key, value) in &resp.headers {
response = response.header(key.as_str(), value.as_str());
}
response
.body(axum::body::Body::from(resp.body))
.unwrap()
.into_response()
}
Ok(Err(_)) => (
StatusCode::BAD_GATEWAY,
"Worker dropped the request".to_string(),
)
.into_response(),
Err(_) => (
StatusCode::GATEWAY_TIMEOUT,
format!(
"Request to {path} timed out after {}s",
state.config.timeout_secs
),
)
.into_response(),
}
}
/// Guard moved into a delegated SSE stream; when it is dropped, its `Drop` impl spawns
/// `cleanup_sse_connection` for the connection it describes.
struct SseDropGuard {
/// Id of the connection to clean up; cleanup is skipped if the registry entry has a
/// different id.
connection_id: String,
/// Request path under which the connection is registered.
path: String,
/// Shared hub state used by the cleanup task.
state: Arc<HubState>,
}
impl Drop for SseDropGuard {
    /// Schedules cleanup of the guarded connection.
    ///
    /// `drop` cannot be async, so the cleanup runs in a spawned task and
    /// completes some time after the guard is gone.
    fn drop(&mut self) {
        let hub = Arc::clone(&self.state);
        // The guard is going away, so its strings can be moved out instead of copied.
        let conn = std::mem::take(&mut self.connection_id);
        let route_path = std::mem::take(&mut self.path);
        tokio::spawn(async move {
            cleanup_sse_connection(&hub, &route_path, "drop_guard", &conn).await;
        });
    }
}
/// Opens a delegated SSE stream for `path` on behalf of `worker_id`.
///
/// Registers the connection under `path` (replacing any existing entry for the
/// same path), notifies the worker with a `connect` lifecycle message, and
/// returns an SSE response fed by the connection's channel (capacity 256).
///
/// The drop guard is armed before the worker is notified. The notification may
/// wait up to 5 s for an acknowledgement; if the request is cancelled during
/// that wait, the guard still removes the registry entry and sends the
/// matching `disconnect`. Once the stream is built the guard lives inside it,
/// so cleanup happens when the stream is dropped.
async fn handle_sse_delegation(
    state: Arc<HubState>,
    worker_id: String,
    path: String,
    route_pattern: String,
    path_params: HashMap<String, String>,
) -> impl IntoResponse {
    let (tx, rx) = mpsc::channel::<SseEvent>(256);
    let connection_id = uuid::Uuid::new_v4().to_string();

    // Unix seconds with a trailing `Z`; a clock before the epoch yields `0Z`.
    let connected_at = {
        let secs = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        format!("{secs}Z")
    };

    {
        let mut sse = state.sse_connections.write().await;
        sse.insert(
            path.clone(),
            SseConnectionInfo {
                connection_id: connection_id.clone(),
                worker_id: worker_id.clone(),
                path_pattern: route_pattern.clone(),
                params: path_params.clone(),
                sender: tx,
                connected_at,
            },
        );
    }
    tracing::info!(path, worker_id, connection_id, "SSE client connected (delegated)");

    // Arm cleanup before the first await that follows registration.
    let drop_guard = SseDropGuard {
        connection_id,
        path: path.clone(),
        state: Arc::clone(&state),
    };

    dispatch_sse_lifecycle(&state, &worker_id, "connect", &path, &route_pattern, &path_params)
        .await;

    let stream = ReceiverStream::new(rx).map(move |sse_event| {
        // Referencing the guard makes the `move` closure own it for the stream's lifetime.
        let _ = &drop_guard;
        Ok::<_, std::convert::Infallible>(
            Event::default()
                .event(&sse_event.event_type)
                .data(sse_event.data),
        )
    });
    Sse::new(stream).keep_alive(KeepAlive::default())
}
/// Removes the SSE connection registered under `path` and tells its worker the
/// client disconnected.
///
/// The entry is removed only when its connection id equals
/// `expected_connection_id`. If a newer connection has taken over the same
/// path, a late cleanup for the old one is skipped and the new entry is left
/// alone. A missing entry is a no-op.
///
/// * `reason` — free-form tag recorded in the logs (e.g. `"sweeper"`).
///
/// The registry lock is released before the worker is notified.
pub(crate) async fn cleanup_sse_connection(
    state: &Arc<HubState>,
    path: &str,
    reason: &'static str,
    expected_connection_id: &str,
) {
    let (worker_id, path_pattern, params) = {
        let mut sse = state.sse_connections.write().await;
        let Some(info) = sse.get(path) else {
            return;
        };
        if info.connection_id != expected_connection_id {
            tracing::debug!(
                path,
                old_conn = expected_connection_id,
                current_conn = %info.connection_id,
                reason,
                "SSE cleanup skipped: entry belongs to a newer connection"
            );
            return;
        }
        let snapshot = (
            info.worker_id.clone(),
            info.path_pattern.clone(),
            info.params.clone(),
        );
        // Removing the entry drops its sender while the lock is still held.
        sse.remove(path);
        snapshot
    };
    tracing::info!(path, worker_id, reason, "SSE client disconnected");
    dispatch_sse_lifecycle(
        state,
        &worker_id,
        "disconnect",
        path,
        &path_pattern,
        &params,
    )
    .await;
}
/// Sends an SSE lifecycle notification (`lifecycle`, e.g. `"connect"` or
/// `"disconnect"`) for `path` to the given worker over its bus.
///
/// The notification is a `GET` on the original path whose body is the
/// JSON-encoded `SseLifecycle` message. For `"connect"` the call waits up to
/// 5 s for the worker's acknowledgement and logs the outcome; for any other
/// lifecycle the response is not awaited. Failures are logged, never returned.
async fn dispatch_sse_lifecycle(
    state: &HubState,
    worker_id: &str,
    lifecycle: &str,
    path: &str,
    route_pattern: &str,
    params: &HashMap<String, String>,
) {
    // Look the bus up and release the registry lock before dispatching.
    let registry = state.workers.read().await;
    let Some(record) = registry.get(worker_id) else {
        tracing::warn!(
            worker_id,
            lifecycle,
            path,
            "cannot send SSE lifecycle: worker not found"
        );
        return;
    };
    let bus = Arc::clone(&record.bus);
    drop(registry);

    let message = SseLifecycle {
        sse_lifecycle: lifecycle.to_string(),
        params: params.clone(),
    };
    // A serialization failure degrades to an empty body.
    let payload = serde_json::to_vec(&message).unwrap_or_default();
    let meta = RequestMeta {
        path: path.to_string(),
        route_pattern: route_pattern.to_string(),
        path_params: params.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
        query: None,
        headers: Vec::new(),
    };
    let request_id = uuid::Uuid::new_v4().to_string();

    let ack = match bus.dispatch(&request_id, "GET", &meta, &payload) {
        Ok(ack) => ack,
        Err(e) => {
            tracing::warn!(
                path,
                lifecycle,
                error = %e,
                "failed to dispatch SSE lifecycle via SHM"
            );
            return;
        }
    };

    if lifecycle != "connect" {
        // Fire-and-forget: nobody waits for the worker's reply.
        drop(ack);
        return;
    }

    match tokio::time::timeout(Duration::from_secs(5), ack).await {
        Ok(Ok(resp)) => {
            tracing::debug!(
                path,
                status = resp.status,
                "SSE connect acknowledged by worker"
            );
        }
        Ok(Err(_)) => {
            tracing::warn!(path, "Worker dropped SSE connect notification");
        }
        Err(_) => {
            tracing::warn!(path, "SSE connect notification timed out (5s)");
        }
    }
}
/// Matches a concrete `path` against a route `pattern`, segment by segment.
///
/// Both strings are split on `/` and must have the same number of segments.
/// A pattern segment of the form `{name}` matches any path segment and
/// captures it under `name`; every other segment must be equal to the path's.
///
/// Returns the captured parameters (possibly empty) on a match, `None` otherwise.
fn match_route(pattern: &str, path: &str) -> Option<HashMap<String, String>> {
    let expected: Vec<&str> = pattern.split('/').collect();
    let given: Vec<&str> = path.split('/').collect();
    if expected.len() != given.len() {
        return None;
    }

    let mut captured = HashMap::new();
    for (segment, value) in expected.into_iter().zip(given) {
        let placeholder = segment
            .strip_prefix('{')
            .and_then(|rest| rest.strip_suffix('}'));
        match placeholder {
            Some(name) => {
                captured.insert(name.to_string(), value.to_string());
            }
            None if segment == value => {}
            None => return None,
        }
    }
    Some(captured)
}
/// Determines the concrete SSE path a message refers to.
///
/// * If `path` is given it is returned unchanged.
/// * Otherwise, if both `pattern` and `params` are given, every `{name}`
///   placeholder in the pattern is replaced by `params[name]`.
/// * Otherwise `None`.
///
/// Returns `None` when a placeholder has no matching parameter or when the
/// pattern contains an unterminated `{`.
///
/// The pattern is substituted in a single left-to-right pass, so parameter
/// values are inserted verbatim: a value that itself contains braces is never
/// re-interpreted as a placeholder, and the result does not depend on the
/// map's iteration order.
pub fn resolve_sse_path(
    path: Option<&str>,
    pattern: Option<&str>,
    params: Option<&HashMap<String, String>>,
) -> Option<String> {
    if let Some(p) = path {
        return Some(p.to_string());
    }
    let (pattern, params) = (pattern?, params?);

    let mut resolved = String::with_capacity(pattern.len());
    let mut rest = pattern;
    while let Some(open) = rest.find('{') {
        resolved.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];
        // No closing brace: the placeholder cannot be resolved.
        let close = after_open.find('}')?;
        let value = params.get(&after_open[..close])?;
        resolved.push_str(value);
        rest = &after_open[close + 1..];
    }
    resolved.push_str(rest);
    Some(resolved)
}
#[cfg(test)]
mod tests {
    //! Tests for SSE connection cleanup and the SSE sweeper.
    //!
    //! No worker is ever registered here, so the lifecycle notification that
    //! cleanup sends finds no worker and returns after logging.

    use super::*;

    /// Capacity of the channel created by `insert_conn`.
    const TEST_CHANNEL_CAPACITY: usize = 16;

    /// Builds an empty hub state with the sweeper disabled.
    fn make_state() -> Arc<HubState> {
        let (event_tx, _rx) = broadcast::channel::<HubEvent>(16);
        Arc::new(HubState {
            workers: RwLock::new(HashMap::new()),
            route_table: RwLock::new(Vec::new()),
            event_tx,
            sse_connections: RwLock::new(HashMap::new()),
            config: HubConfig {
                timeout_secs: 30,
                num_slots: 32,
                region_size: 1024 * 1024,
                instrumentation: false,
                sse_sweep_interval_ms: 0,
            },
        })
    }

    /// Registers a connection under `path` and returns the receiving half of
    /// its channel; dropping the receiver simulates a vanished client.
    async fn insert_conn(
        state: &HubState,
        path: &str,
        connection_id: &str,
    ) -> mpsc::Receiver<SseEvent> {
        let (tx, rx) = mpsc::channel::<SseEvent>(TEST_CHANNEL_CAPACITY);
        let mut sse = state.sse_connections.write().await;
        sse.insert(
            path.to_string(),
            SseConnectionInfo {
                connection_id: connection_id.to_string(),
                worker_id: "test-worker".to_string(),
                path_pattern: path.to_string(),
                params: HashMap::new(),
                sender: tx,
                connected_at: "0Z".to_string(),
            },
        );
        rx
    }

    /// Fills the channel of the connection at `path` to capacity, simulating a
    /// client that has stopped reading but is still connected.
    async fn fill_buffer(state: &HubState, path: &str) {
        let sse = state.sse_connections.read().await;
        let sender = sse
            .get(path)
            .expect("connection must be registered before filling")
            .sender
            .clone();
        drop(sse);
        for i in 0..TEST_CHANNEL_CAPACITY {
            sender
                .try_send(SseEvent {
                    event_type: "filler".into(),
                    data: format!("{i}"),
                })
                .expect("pre-fill should succeed");
        }
    }

    async fn assert_present(state: &HubState, path: &str) {
        let sse = state.sse_connections.read().await;
        assert!(
            sse.contains_key(path),
            "expected entry present for {path}, but was missing"
        );
    }

    async fn assert_missing(state: &HubState, path: &str) {
        let sse = state.sse_connections.read().await;
        assert!(
            !sse.contains_key(path),
            "expected entry absent for {path}, but it was present"
        );
    }

    /// Gives the cleanup tasks spawned by `sweep_once` time to run.
    async fn drain_spawned_cleanup() {
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
    }

    #[tokio::test]
    async fn cleanup_with_matching_connection_id_removes_entry() {
        let state = make_state();
        let path = "/agent/events/session-a";
        let _rx = insert_conn(&state, path, "conn-1").await;
        cleanup_sse_connection(&state, path, "test", "conn-1").await;
        assert_missing(&state, path).await;
    }

    #[tokio::test]
    async fn cleanup_with_mismatched_connection_id_preserves_entry() {
        let state = make_state();
        let path = "/agent/events/session-collision";
        let _rx = insert_conn(&state, path, "conn-B").await;
        // A late cleanup for the previous connection must not evict the new one.
        cleanup_sse_connection(&state, path, "test", "conn-A").await;
        assert_present(&state, path).await;
        let sse = state.sse_connections.read().await;
        assert_eq!(sse.get(path).unwrap().connection_id, "conn-B");
    }

    #[tokio::test]
    async fn cleanup_on_missing_path_is_noop() {
        let state = make_state();
        cleanup_sse_connection(&state, "/nothing/here", "test", "conn-X").await;
        let sse = state.sse_connections.read().await;
        assert!(sse.is_empty());
    }

    #[tokio::test]
    async fn sweep_removes_closed_channel() {
        let state = make_state();
        let path = "/agent/events/dead-client";
        let rx = insert_conn(&state, path, "conn-dead").await;
        drop(rx);
        sweep_once(&state).await;
        drain_spawned_cleanup().await;
        assert_missing(&state, path).await;
    }

    #[tokio::test]
    async fn sweep_keeps_live_channel() {
        let state = make_state();
        let path = "/agent/events/live-client";
        let _rx = insert_conn(&state, path, "conn-live").await;
        sweep_once(&state).await;
        drain_spawned_cleanup().await;
        assert_present(&state, path).await;
    }

    #[tokio::test]
    async fn sweep_ignores_full_buffer_as_alive() {
        let state = make_state();
        let path = "/agent/events/slow-client";
        let _rx = insert_conn(&state, path, "conn-slow").await;
        fill_buffer(&state, path).await;
        sweep_once(&state).await;
        drain_spawned_cleanup().await;
        assert_present(&state, path).await;
    }

    #[tokio::test]
    async fn sweep_cleans_only_closed_entries() {
        let state = make_state();
        let _alive_rx = insert_conn(&state, "/agent/events/alive", "conn-alive").await;
        let dead_rx = insert_conn(&state, "/agent/events/dead", "conn-dead").await;
        let _slow_rx = insert_conn(&state, "/agent/events/slow", "conn-slow").await;
        drop(dead_rx);
        // Make the "slow" connection genuinely slow so the sweep meets all three
        // outcomes (delivered, buffer full, closed) in one pass.
        fill_buffer(&state, "/agent/events/slow").await;
        sweep_once(&state).await;
        drain_spawned_cleanup().await;
        assert_present(&state, "/agent/events/alive").await;
        assert_missing(&state, "/agent/events/dead").await;
        assert_present(&state, "/agent/events/slow").await;
    }
}