use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use slotbus::transport::{Request as ShmRequest, SlotWorker};
use slotbus::SlotBusConfig;
use crate::types::*;
/// Built-in route that `dispatch_request` answers itself with `200 {}`, before
/// any user route is consulted. `run_once` always adds it (as `GET`) to the
/// hub registration. Presumably the hub uses it to probe worker liveness —
/// NOTE(review): confirm on the hub side.
pub const WORKER_PING_ROUTE: &str = "/__ping__";
/// Response produced by a route handler.
///
/// All four fields are handed to the SHM slot writer (`send_response`) when
/// the worker answers a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandlerResponse {
    /// HTTP status code.
    pub status: u16,
    /// Raw response body bytes.
    pub body: Vec<u8>,
    /// Content type reported for `body`.
    pub content_type: String,
    /// Extra response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
}
impl HandlerResponse {
    /// Builds a JSON response by serializing `body` with `serde_json`.
    ///
    /// If `body` cannot be serialized, this returns a `500` response whose body
    /// is the fixed JSON `{"error":"failed to serialize response"}`, and the
    /// failure is logged to stderr. Previously the result was an empty body
    /// labelled `application/json` under the caller's (possibly 2xx) status.
    pub fn json(status: u16, body: &impl serde::Serialize) -> Self {
        match serde_json::to_vec(body) {
            Ok(bytes) => Self::bytes(status, bytes, "application/json"),
            Err(err) => {
                eprintln!(
                    "[HANDLER] Failed to serialize JSON response (status {status}): {err}"
                );
                // Static literal so the fallback itself cannot fail to serialize.
                Self::bytes(
                    500,
                    br#"{"error":"failed to serialize response"}"#.to_vec(),
                    "application/json",
                )
            }
        }
    }

    /// JSON response with status `200`.
    pub fn ok_json(body: &impl serde::Serialize) -> Self {
        Self::json(200, body)
    }

    /// JSON error response with body `{"error": msg}`.
    pub fn error(status: u16, msg: &str) -> Self {
        Self::json(status, &serde_json::json!({ "error": msg }))
    }

    /// Raw-bytes response with an explicit content type and no extra headers.
    pub fn bytes(status: u16, body: Vec<u8>, content_type: &str) -> Self {
        Self::bytes_with_headers(status, body, content_type, Vec::new())
    }

    /// Raw-bytes response with an explicit content type and extra headers.
    pub fn bytes_with_headers(
        status: u16,
        body: Vec<u8>,
        content_type: &str,
        headers: Vec<(String, String)>,
    ) -> Self {
        Self {
            status,
            body,
            content_type: content_type.into(),
            headers,
        }
    }

    /// `200` response with an empty JSON object `{}` as the body.
    pub fn ok() -> Self {
        Self::json(200, &serde_json::json!({}))
    }

    /// Response with the given status and an empty body, labelled `application/json`.
    pub fn status(code: u16) -> Self {
        Self::bytes(code, Vec::new(), "application/json")
    }
}
/// Boxed, `Send` future produced by one handler invocation.
type HandlerFuture = Pin<Box<dyn Future<Output = HandlerResponse> + Send>>;

/// Type-erased route handler stored in each route entry.
///
/// `dispatch_request` calls it with these arguments, in order:
/// - the shared worker state;
/// - the request's path parameters;
/// - the request's query (if any);
/// - the raw request body.
type BoxHandler<S> = Box<
    dyn Fn(Arc<S>, HashMap<String, String>, Option<String>, Vec<u8>) -> HandlerFuture
        + Send
        + Sync,
>;
/// A single registered route together with the handler that serves it.
struct RouteEntry<S> {
    /// Request method, compared verbatim against the incoming request's method.
    method: String,
    /// Route pattern, compared verbatim against the request's `route_pattern`.
    path: String,
    /// Set for routes added through `HubWorker::sse_route`. It is forwarded to
    /// the hub in the registration.
    sse: bool,
    /// Handler that `dispatch_request` invokes when method and path both match.
    handler: BoxHandler<S>,
}
/// Cloneable handle, obtained from `HubWorker::emitter`, for sending events
/// and SSE pushes to the hub on behalf of one worker.
#[derive(Clone)]
pub struct HubEmitter {
    /// HTTP client cloned from the `HubWorker` that created this emitter.
    client: reqwest::Client,
    /// Hub base URL with trailing slashes already stripped.
    hub_url: String,
    /// Worker name, sent as the `source` of emitted events.
    source: String,
}
impl HubEmitter {
pub async fn emit(&self, event_type: &str, data: &str) -> Result<(), String> {
let event = WorkerEvent {
source: self.source.clone(),
event_type: event_type.to_string(),
data: data.to_string(),
};
self.client
.post(format!("{}/internal/emit", self.hub_url))
.json(&event)
.send()
.await
.map_err(|e| format!("Failed to emit event: {e}"))?;
Ok(())
}
pub async fn sse_push(&self, path: &str, event_type: &str, data: &str) -> Result<(), String> {
let req = SsePushRequest {
path: Some(path.to_string()),
pattern: None,
params: None,
event_type: event_type.to_string(),
data: data.to_string(),
};
let resp = self
.client
.post(format!("{}/internal/sse-push", self.hub_url))
.json(&req)
.send()
.await
.map_err(|e| format!("Failed to push SSE event: {e}"))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("SSE push failed ({status}): {body}"));
}
Ok(())
}
pub async fn sse_push_by_pattern(
&self,
pattern: &str,
params: HashMap<String, String>,
event_type: &str,
data: &str,
) -> Result<(), String> {
let req = SsePushRequest {
path: None,
pattern: Some(pattern.to_string()),
params: Some(params),
event_type: event_type.to_string(),
data: data.to_string(),
};
let resp = self
.client
.post(format!("{}/internal/sse-push", self.hub_url))
.json(&req)
.send()
.await
.map_err(|e| format!("Failed to push SSE event: {e}"))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("SSE push failed ({status}): {body}"));
}
Ok(())
}
pub fn source(&self) -> &str {
&self.source
}
}
/// Builder and run loop for a worker that serves routes on behalf of the hub.
///
/// Routes are added with `route` / `sse_route`. `run` then registers them with
/// the hub and serves requests, reconnecting with backoff when a session ends.
pub struct HubWorker<S: Send + Sync + 'static> {
    /// Worker name sent at registration and used as the source of emitted events.
    name: String,
    /// Hub base URL with trailing slashes stripped.
    hub_url: String,
    /// HTTP client for registration and health checks; emitters get a clone.
    client: reqwest::Client,
    /// Shared state handed to every handler invocation.
    state: Arc<S>,
    /// Registered routes in registration order. The first match wins in
    /// `dispatch_request`.
    routes: Vec<RouteEntry<S>>,
}
impl<S: Send + Sync + 'static> HubWorker<S> {
    /// Creates a worker named `name`, with shared `state`, for the hub at `hub_url`.
    ///
    /// Trailing `/` characters are stripped from `hub_url`, so endpoint paths
    /// can be appended as `"{hub_url}/..."`.
    pub fn new(hub_url: &str, name: &str, state: S) -> Self {
        Self {
            hub_url: hub_url.trim_end_matches('/').to_string(),
            name: name.to_string(),
            state: Arc::new(state),
            routes: Vec::new(),
            client: reqwest::Client::new(),
        }
    }

    /// Registers `handler` for requests with `method` on route pattern `path`.
    pub fn route<F, Fut>(mut self, method: &str, path: &str, handler: F) -> Self
    where
        F: Fn(Arc<S>, HashMap<String, String>, Option<String>, Vec<u8>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = HandlerResponse> + Send + 'static,
    {
        self.routes.push(RouteEntry {
            method: method.to_string(),
            path: path.to_string(),
            handler: Self::box_handler(handler),
            sse: false,
        });
        self
    }

    /// Registers `handler` as a `GET` route on `path`, flagged as SSE in the
    /// hub registration.
    pub fn sse_route<F, Fut>(mut self, path: &str, handler: F) -> Self
    where
        F: Fn(Arc<S>, HashMap<String, String>, Option<String>, Vec<u8>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = HandlerResponse> + Send + 'static,
    {
        self.routes.push(RouteEntry {
            method: "GET".to_string(),
            path: path.to_string(),
            handler: Self::box_handler(handler),
            sse: true,
        });
        self
    }

    /// Returns an emitter that shares this worker's hub URL, name and HTTP client.
    pub fn emitter(&self) -> HubEmitter {
        HubEmitter {
            hub_url: self.hub_url.clone(),
            source: self.name.clone(),
            client: self.client.clone(),
        }
    }

    /// Returns a new handle to the shared worker state.
    pub fn state(&self) -> Arc<S> {
        self.state.clone()
    }

    /// Connects to the hub and serves requests, reconnecting whenever a session ends.
    ///
    /// Reconnect delays are 1, 2, 4, 8 and 16 s, then capped at 30 s. A
    /// session that stayed up for at least 60 s resets the backoff. Without
    /// that reset, failures from before a long healthy session would keep the
    /// next reconnect at the cap.
    ///
    /// # Errors
    /// This method only returns if a session returns `Ok`. Session errors are
    /// logged to stderr and retried, never returned.
    pub async fn run(self) -> Result<(), String> {
        const MAX_BACKOFF_SECS: u64 = 30;
        const HEALTHY_SESSION: std::time::Duration = std::time::Duration::from_secs(60);

        let Self {
            hub_url,
            name,
            state,
            routes,
            client,
        } = self;
        let routes: Arc<Vec<RouteEntry<S>>> = Arc::new(routes);
        let mut attempt: u32 = 0;
        loop {
            let session_start = std::time::Instant::now();
            match run_once(&hub_url, &name, &client, &routes, &state).await {
                Ok(()) => return Ok(()),
                Err(e) => {
                    if session_start.elapsed() >= HEALTHY_SESSION {
                        attempt = 0;
                    }
                    // 2^attempt seconds (exponent capped at 5), then clamped to the max.
                    let delay_secs = (1u64 << attempt.min(5)).min(MAX_BACKOFF_SECS);
                    eprintln!(
                        "[{}] Hub connection lost: {e}. Reconnecting in {delay_secs}s...",
                        name.to_uppercase()
                    );
                    tokio::time::sleep(std::time::Duration::from_secs(delay_secs)).await;
                    attempt = attempt.saturating_add(1);
                }
            }
        }
    }

    /// Type-erases a handler closure into a [`BoxHandler`] that boxes its future.
    fn box_handler<F, Fut>(handler: F) -> BoxHandler<S>
    where
        F: Fn(Arc<S>, HashMap<String, String>, Option<String>, Vec<u8>) -> Fut
            + Send
            + Sync
            + 'static,
        Fut: Future<Output = HandlerResponse> + Send + 'static,
    {
        let boxed: BoxHandler<S> = Box::new(move |state, params, query, body| {
            Box::pin(handler(state, params, query, body))
        });
        boxed
    }
}
/// Delay between consecutive `/health` checks while a hub session is active.
/// One full interval passes before the first check.
const HUB_HEALTH_POLL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Runs a single hub session for this worker.
///
/// The session proceeds in four steps:
/// 1. Register the routes with the hub.
/// 2. Open the `slotbus` SHM transport.
/// 3. Dispatch each received request on the current Tokio runtime, until the
///    hub's `/health` no longer lists this worker or the health check fails.
/// 4. Stop the transport and join its receive loop.
///
/// # Errors
/// Returns `Err` with a readable reason when registration or SHM setup fails.
/// It also returns `Err` when an established session ends; no path here
/// returns `Ok`. `HubWorker::run` treats every `Err` as "reconnect".
async fn run_once<S: Send + Sync + 'static>(
    hub_url: &str,
    name: &str,
    client: &reqwest::Client,
    routes: &Arc<Vec<RouteEntry<S>>>,
    state: &Arc<S>,
) -> Result<(), String> {
    let reg_resp = register_with_hub(hub_url, name, client, routes).await?;
    let worker_id = reg_resp.worker_id.clone();
    eprintln!(
        "[{}] Registered with hub (worker_id={}, shm={}, {} routes)",
        name.to_uppercase(),
        worker_id,
        reg_resp.shm_name,
        routes.len()
    );

    // NOTE(review): the SHM segment is opened from `name` + the "hub" prefix,
    // not from `reg_resp.shm_name`. Presumably the hub derives the same name —
    // confirm against the hub side.
    let worker_config = SlotBusConfig::builder()
        .name(name)
        .prefix("hub")
        .build();
    let transport = Arc::new(
        SlotWorker::open(worker_config)
            .map_err(|e| format!("Failed to open SHM: {e}"))?,
    );

    let rt_handle = tokio::runtime::Handle::current();
    let transport_for_loop = Arc::clone(&transport);
    let routes_for_loop = Arc::clone(routes);
    let state_for_loop = Arc::clone(state);
    let name_for_loop = name.to_string();
    let receive_join = transport_for_loop.start_receive_loop(move |worker, slot_index, request| {
        let routes = Arc::clone(&routes_for_loop);
        let state = Arc::clone(&state_for_loop);
        let worker_name = name_for_loop.clone();
        let worker = Arc::clone(&worker);
        // Each request runs as its own task so a slow handler does not stall
        // the receive loop.
        rt_handle.spawn(async move {
            let response = dispatch_request(&worker_name, &routes, state, request).await;
            if let Err(e) = worker.send_response(
                slot_index,
                response.status,
                response.body,
                &response.content_type,
                response.headers,
            ) {
                eprintln!(
                    "[{}] Failed to write response to slot {slot_index}: {e}",
                    worker_name.to_uppercase()
                );
            }
        });
    });

    let reason = watch_hub_registration(client, hub_url, &worker_id).await;
    transport.stop();
    // NOTE(review): if `receive_join` is a `std::thread::JoinHandle`, this join
    // blocks the async executor thread until the receive loop exits. Consider
    // `tokio::task::spawn_blocking` — confirm the handle type first.
    let _ = receive_join.join();
    Err(reason)
}

/// Registers `name` with the hub.
///
/// The registration lists `routes` plus the built-in `GET` [`WORKER_PING_ROUTE`].
///
/// # Errors
/// Returns `Err` in three cases:
/// - the request cannot be sent;
/// - the hub replies with a non-success status;
/// - the reply is not a valid `RegisterResponse`.
async fn register_with_hub<S>(
    hub_url: &str,
    name: &str,
    client: &reqwest::Client,
    routes: &[RouteEntry<S>],
) -> Result<RegisterResponse, String> {
    let ping = RouteRegistration {
        method: "GET".to_string(),
        path: WORKER_PING_ROUTE.to_string(),
        sse: false,
    };
    let route_regs: Vec<RouteRegistration> = routes
        .iter()
        .map(|r| RouteRegistration {
            method: r.method.clone(),
            path: r.path.clone(),
            sse: r.sse,
        })
        .chain(std::iter::once(ping))
        .collect();
    let reg_req = RegisterRequest {
        name: name.to_string(),
        routes: route_regs,
    };
    let resp = client
        .post(format!("{hub_url}/internal/register"))
        .json(&reg_req)
        .send()
        .await
        .map_err(|e| format!("Failed to register with hub: {e}"))?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        return Err(format!("Hub registration failed ({status}): {body}"));
    }
    resp.json::<RegisterResponse>()
        .await
        .map_err(|e| format!("Invalid registration response: {e}"))
}

/// Polls the hub's `/health` every [`HUB_HEALTH_POLL`] and returns the reason
/// the session should end.
///
/// The reason is either that `worker_id` is no longer listed, or that a
/// single health check failed.
async fn watch_hub_registration(
    client: &reqwest::Client,
    hub_url: &str,
    worker_id: &str,
) -> String {
    let health_url = format!("{hub_url}/health");
    loop {
        tokio::time::sleep(HUB_HEALTH_POLL).await;
        match check_worker_registered(client, &health_url, worker_id).await {
            Ok(true) => {}
            Ok(false) => {
                return format!("worker_id {worker_id} missing from /health (hub restarted?)")
            }
            Err(e) => return format!("health check error: {e}"),
        }
    }
}
/// Asks the hub's health endpoint whether `worker_id` is currently registered.
///
/// Returns `Ok(true)` when some entry of the top-level `workers` array has a
/// string `worker_id` equal to `worker_id`, and `Ok(false)` otherwise.
///
/// # Errors
/// Returns `Err` in four cases:
/// - the request fails;
/// - the status is not a success;
/// - the body is not JSON;
/// - the JSON has no `workers` array.
async fn check_worker_registered(
    client: &reqwest::Client,
    health_url: &str,
    worker_id: &str,
) -> Result<bool, String> {
    let response = match client.get(health_url).send().await {
        Ok(response) => response,
        Err(err) => return Err(format!("GET {health_url}: {err}")),
    };
    let status = response.status();
    if !status.is_success() {
        return Err(format!("health returned status {status}"));
    }
    let health: serde_json::Value = response
        .json()
        .await
        .map_err(|err| format!("invalid health JSON: {err}"))?;
    let workers = health
        .get("workers")
        .and_then(serde_json::Value::as_array)
        .ok_or_else(|| "health response missing 'workers' array".to_string())?;
    let wanted = Some(worker_id);
    let listed = workers
        .iter()
        .any(|entry| entry.get("worker_id").and_then(serde_json::Value::as_str) == wanted);
    Ok(listed)
}
async fn dispatch_request<S: Send + Sync + 'static>(
_worker_name: &str,
routes: &[RouteEntry<S>],
state: Arc<S>,
request: ShmRequest,
) -> HandlerResponse {
if request.method == "GET" && request.route_pattern == WORKER_PING_ROUTE {
return HandlerResponse::ok();
}
let handler = routes
.iter()
.find(|r| r.method == request.method && r.path == request.route_pattern);
if let Some(route) = handler {
(route.handler)(state, request.path_params, request.query, request.body).await
} else {
HandlerResponse::error(
404,
&format!(
"No handler for {} {}",
request.method, request.route_pattern
),
)
}
}