use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use breaker_machines::CircuitBreaker;
use rama::{
Layer, Service,
extensions::ExtensionsRef,
http::{
Body, Request, Response, StatusCode, Uri,
client::{EasyHttpWebClient, HttpConnector},
header::{HOST, HeaderName, HeaderValue},
io::upgrade,
layer::{compression::CompressionLayer, trace::TraceLayer},
server::HttpServer,
service::fs::ServeDir,
ws::{
AsyncWebSocket,
handshake::{client::HttpClientWebSocketExt, server::WebSocketMatcher},
protocol::Role,
},
},
matcher::Matcher,
net::{
client::{ConnectorService, EstablishedClientConnection},
fingerprint::Ja4H,
stream::{SocketInfo, layer::http::BodyLimitLayer},
},
proxy::haproxy::server::HaProxyLayer,
rt::Executor,
service::service_fn,
tcp::server::TcpListener,
ua::{UserAgent, layer::classifier::UserAgentClassifierLayer},
unix::client::UnixConnector,
};
use regex::Regex;
use throttle_machines::token_bucket;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, info, warn};
use super::cors_cache::{CorsCache, CorsCacheKey};
use super::sensors::MetricsRegistry;
use crate::charter::{Bind, Manifest, StaticDirConfig, UaFilter};
use crate::docking::{Boarding, Cargo, Disembark, DockingConnector, next_conn_id};
use crate::payload::{ModuleAction, ModuleRuntime, RequestInfo, ResponseInfo};
fn build_block_response(
status: u16,
body: String,
headers: Option<HashMap<String, String>>,
) -> Response {
let mut builder =
Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::FORBIDDEN));
if let Some(hdrs) = headers {
for (key, value) in hdrs {
if let Ok(name) = key.parse::<HeaderName>() {
builder = builder.header(name, value);
}
}
}
builder.body(Body::from(body)).unwrap()
}
/// Destination a matched route forwards traffic to.
#[derive(Clone)]
enum BackendTarget {
    /// Backend reached over TCP; displayed as the bare `base_uri`.
    Tcp { base_uri: String },
    /// Backend reached through a Unix domain socket; displayed as `unix://<path>`.
    Unix { socket_path: String },
    /// Backend identified by a bay name; displayed as `docked://<name>`.
    Docked { bay_name: String },
}

impl std::fmt::Display for BackendTarget {
    /// Writes the target as a URI-like string used in log fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the scheme prefix and the payload first, then emit them in one write.
        let (scheme, location) = match self {
            Self::Tcp { base_uri } => ("", base_uri),
            Self::Unix { socket_path } => ("unix://", socket_path),
            Self::Docked { bay_name } => ("docked://", bay_name),
        };
        write!(f, "{scheme}{location}")
    }
}
/// One routing rule on a bind: which requests it captures and where they go.
#[derive(Clone)]
struct Route {
/// Regex tested against the request path when selecting a route.
pattern: Regex,
/// Backend that receives requests matched by this route.
backend: BackendTarget,
/// Name of the ship or bay that owns the route; used in logs and metrics.
ship_name: String,
/// Metrics key, built as `"<bind>:<pattern>"` when the route is registered.
route_key: String,
/// Circuit breaker guarding this route's backend; shared by clones of the route.
circuit_breaker: Arc<Mutex<CircuitBreaker>>,
/// Optional path prefix removed from the request path before forwarding.
strip_prefix: Option<String>,
/// Optional user-agent filter; when set, the route only matches requests it accepts.
ua_filter: Option<UaFilter>,
}
/// Creates the circuit breaker attached to a route.
///
/// Builder settings: `failure_threshold(5)`, `failure_window_secs(60.0)`,
/// `half_open_timeout_secs(30.0)` and `success_threshold(2)`. Opening is
/// logged at warn level and closing at info level, both tagged with the
/// breaker name.
fn create_circuit_breaker(name: &str) -> CircuitBreaker {
    let breaker = CircuitBreaker::builder(name)
        .failure_threshold(5)
        .failure_window_secs(60.0)
        .half_open_timeout_secs(30.0)
        .success_threshold(2)
        .on_open(|breaker_name| warn!(backend = %breaker_name, "Circuit breaker opened"))
        .on_close(|breaker_name| info!(backend = %breaker_name, "Circuit breaker closed"))
        .build();
    breaker
}
/// Mutable state of one token bucket.
///
/// `Default` is derived: both fields start at `0.0`, exactly what the
/// previous hand-written implementation produced.
#[derive(Debug, Clone, Default)]
struct RateLimitState {
    /// Tokens currently available in the bucket.
    tokens: f64,
    /// Timestamp of the last update, on the same clock as the `now` value
    /// handed to the token-bucket check.
    last_refill: f64,
}
/// Two-level token-bucket rate limiter: one global bucket plus one bucket per client IP.
struct RateLimiter {
/// State of the global bucket shared by all clients.
global_state: Arc<Mutex<RateLimitState>>,
/// Per-client bucket states, keyed by the client IP rendered as a string.
per_ip_states: Arc<dashmap::DashMap<String, RateLimitState>>,
/// Capacity of the global bucket (`f64::INFINITY` when no global limit is configured).
global_capacity: f64,
/// Refill rate of the global bucket, in tokens per second.
global_refill_rate: f64,
/// Capacity of each per-IP bucket (`f64::INFINITY` when no per-IP limit is configured).
per_ip_capacity: f64,
/// Refill rate of each per-IP bucket, in tokens per second (configured per-minute value / 60).
per_ip_refill_rate: f64,
}
impl RateLimiter {
    /// Creates a limiter from the optional configured limits.
    ///
    /// * `global_rps` - global limit in requests per second; used as both the
    ///   capacity and the refill rate of the global bucket.
    /// * `per_ip_rpm` - per-client limit in requests per minute; used as the
    ///   per-IP capacity, with a refill rate of `rpm / 60` tokens per second.
    ///
    /// A missing limit is represented by `f64::INFINITY`.
    fn new(global_rps: Option<f64>, per_ip_rpm: Option<f64>) -> Self {
        let global_capacity = global_rps.unwrap_or(f64::INFINITY);
        let global_refill_rate = global_rps.unwrap_or(f64::INFINITY);
        let per_ip_capacity = per_ip_rpm.unwrap_or(f64::INFINITY);
        let per_ip_refill_rate = per_ip_rpm.map(|rpm| rpm / 60.0).unwrap_or(f64::INFINITY);
        Self {
            global_state: Arc::new(Mutex::new(RateLimitState {
                tokens: global_capacity,
                last_refill: 0.0,
            })),
            per_ip_states: Arc::new(dashmap::DashMap::new()),
            global_capacity,
            global_refill_rate,
            per_ip_capacity,
            per_ip_refill_rate,
        }
    }

    /// Current wall-clock time as seconds since the Unix epoch.
    ///
    /// Falls back to `0.0` when the system clock reports a time before the
    /// epoch, so a misconfigured clock cannot panic the request path.
    fn unix_now_secs() -> f64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs_f64())
            .unwrap_or(0.0)
    }

    /// Checks a request from `client_ip` against the global and per-IP buckets.
    ///
    /// Returns `Ok(())` when both buckets admit the request, otherwise
    /// `Err(retry_after)` with the `retry_after` value reported by the bucket
    /// that rejected it.
    ///
    /// Tokens are only committed once *both* checks pass. A request rejected
    /// by the per-IP bucket therefore no longer drains the global bucket,
    /// which previously let a single throttled client eat the global budget.
    async fn check_request(&self, client_ip: &str) -> Result<(), f64> {
        let now = Self::unix_now_secs();

        // The global guard is held until both decisions are made so the two
        // commits below are applied together. Nothing is awaited after this
        // point, so the map entry guard is never held across an await.
        let mut global = self.global_state.lock().await;
        let global_result = token_bucket::check(
            global.tokens,
            global.last_refill,
            now,
            self.global_capacity,
            self.global_refill_rate,
        );
        if !global_result.allowed {
            warn!(
                retry_after = global_result.retry_after,
                "Global rate limit exceeded"
            );
            return Err(global_result.retry_after);
        }

        let mut ip_state = self
            .per_ip_states
            .entry(client_ip.to_string())
            .or_insert_with(|| RateLimitState {
                tokens: self.per_ip_capacity,
                last_refill: now,
            });
        let ip_result = token_bucket::check(
            ip_state.tokens,
            ip_state.last_refill,
            now,
            self.per_ip_capacity,
            self.per_ip_refill_rate,
        );
        if !ip_result.allowed {
            warn!(
                client_ip = %client_ip,
                retry_after = ip_result.retry_after,
                "Per-IP rate limit exceeded"
            );
            return Err(ip_result.retry_after);
        }

        global.tokens = global_result.new_tokens;
        global.last_refill = now;
        ip_state.tokens = ip_result.new_tokens;
        ip_state.last_refill = now;
        Ok(())
    }

    /// Drops per-IP entries that have been idle for more than 300 seconds,
    /// but only once the map holds more than `max_entries` entries.
    ///
    /// At most `max(len - max_entries, max_entries / 10)` entries are removed
    /// per call. The debug log reports how many entries were actually removed.
    // No caller is visible in this part of the file, hence the lint allowance.
    #[allow(dead_code)]
    async fn cleanup_stale_entries(&self, max_entries: usize) {
        // Read the size once: a second read could observe a smaller map and
        // make the subtraction below underflow.
        let tracked = self.per_ip_states.len();
        if tracked <= max_entries {
            return;
        }
        let to_remove = (tracked - max_entries).max(max_entries / 10);
        let stale_threshold = Self::unix_now_secs() - 300.0;

        // Collect the keys first so no map iterator is alive while removing.
        let stale_ips: Vec<String> = self
            .per_ip_states
            .iter()
            .filter(|entry| entry.value().last_refill < stale_threshold)
            .map(|entry| entry.key().clone())
            .take(to_remove)
            .collect();
        let removed = stale_ips
            .into_iter()
            .filter(|ip| self.per_ip_states.remove(ip).is_some())
            .count();
        debug!(removed = removed, "Cleaned up stale rate limit entries");
    }
}
/// Routes registered for a single bind, in registration order.
/// Route selection takes the first entry that matches, so order matters.
#[derive(Clone, Default)]
struct BindRoutes {
routes: Vec<Route>,
}
/// HTTP exposure layer: the configured listeners together with everything
/// needed to route, filter and proxy the requests they accept.
pub struct HttpExposure {
/// Listener definitions, keyed by bind name.
binds: HashMap<String, Bind>,
/// Routes per bind name.
routes: HashMap<String, BindRoutes>,
/// Module runtime; `None` until `run` has loaded the configured modules.
module_runtime: Arc<RwLock<Option<ModuleRuntime>>>,
/// Module definitions copied from the manifest.
module_configs: Vec<crate::charter::Module>,
/// Static directory configurations, sorted longest prefix first.
static_dirs: Vec<StaticDirConfig>,
/// Whether the response compression layer is installed.
compression: bool,
/// Connectors for docked bays, keyed by bay name.
bay_connectors: Arc<RwLock<HashMap<String, Arc<DockingConnector>>>>,
/// CORS preflight cache, present only when enabled in the manifest.
cors_cache: Option<Arc<CorsCache>>,
/// Rate limiter, present only when rate limiting is configured.
rate_limiter: Option<Arc<RateLimiter>>,
}
impl HttpExposure {
/// Builds the exposure layer from a manifest.
///
/// Returns `None` when the manifest declares no binds. Otherwise:
/// - every bind gets an (initially empty) route list;
/// - each ship or bay route is compiled to a regex and appended to the route
///   list of the bind it names, with its own circuit breaker;
/// - static directories are sorted so the longest prefix comes first;
/// - the CORS cache and the rate limiter are created when configured.
///
/// Routes that name an unknown bind, routes with an invalid pattern, and
/// ships that declare routes but no bind address are logged and skipped.
pub fn from_manifest(manifest: &Manifest) -> Option<Self> {
if manifest.mothership.bind.is_empty() {
return None;
}
let binds = manifest.mothership.bind.clone();
let mut routes: HashMap<String, BindRoutes> = HashMap::new();
for bind_name in binds.keys() {
routes.insert(bind_name.clone(), BindRoutes::default());
}
for vessel in &manifest.vessels {
match vessel {
crate::charter::Vessel::Ship { config, .. } => {
if config.routes.is_empty() {
continue;
}
// Derive the backend from the ship's own bind address.
let backend = match &config.bind {
Some(Bind::Tcp { host, port, .. }) => {
// A ship bound to 0.0.0.0 is dialed through loopback.
let h = if host == "0.0.0.0" { "127.0.0.1" } else { host };
BackendTarget::Tcp {
base_uri: format!("http://{}:{}", h, port),
}
}
Some(Bind::Unix { path }) => BackendTarget::Unix {
socket_path: path.clone(),
},
None => {
warn!(
ship = %config.name,
"Ship has routes but no bind address, skipping"
);
continue;
}
};
for route_config in &config.routes {
if !binds.contains_key(&route_config.bind) {
warn!(
ship = %config.name,
bind = %route_config.bind,
"Route references unknown bind, skipping"
);
continue;
}
match Regex::new(&route_config.pattern) {
Ok(regex) => {
let bind_routes =
routes.entry(route_config.bind.clone()).or_default();
// Breaker name is "<ship>:<bind>"; every route still gets its own breaker instance.
let breaker_name = format!("{}:{}", config.name, route_config.bind);
bind_routes.routes.push(Route {
pattern: regex,
backend: backend.clone(),
ship_name: config.name.clone(),
route_key: format!(
"{}:{}",
route_config.bind, route_config.pattern
),
circuit_breaker: Arc::new(Mutex::new(create_circuit_breaker(
&breaker_name,
))),
strip_prefix: route_config.strip_prefix.clone(),
ua_filter: route_config.ua_filter.clone(),
});
info!(
ship = %config.name,
bind = %route_config.bind,
pattern = %route_config.pattern,
backend = %backend,
strip_prefix = ?route_config.strip_prefix,
ua_filter = ?route_config.ua_filter,
"Registered route with circuit breaker"
);
}
Err(e) => {
error!(
ship = %config.name,
bind = %route_config.bind,
pattern = %route_config.pattern,
error = %e,
"Invalid route pattern"
);
}
}
}
}
crate::charter::Vessel::Bay { config, .. } => {
if config.routes.is_empty() {
continue;
}
// Bay routes always target the docked backend named after the bay.
let backend = BackendTarget::Docked {
bay_name: config.name.clone(),
};
// NOTE(review): the loop below mirrors the ship route registration above; keep the two in sync.
for route_config in &config.routes {
if !binds.contains_key(&route_config.bind) {
warn!(
bay = %config.name,
bind = %route_config.bind,
"Bay route references unknown bind, skipping"
);
continue;
}
match Regex::new(&route_config.pattern) {
Ok(regex) => {
let bind_routes =
routes.entry(route_config.bind.clone()).or_default();
let breaker_name = format!("{}:{}", config.name, route_config.bind);
bind_routes.routes.push(Route {
pattern: regex,
backend: backend.clone(),
ship_name: config.name.clone(),
route_key: format!(
"{}:{}",
route_config.bind, route_config.pattern
),
circuit_breaker: Arc::new(Mutex::new(create_circuit_breaker(
&breaker_name,
))),
strip_prefix: route_config.strip_prefix.clone(),
ua_filter: route_config.ua_filter.clone(),
});
info!(
bay = %config.name,
bind = %route_config.bind,
pattern = %route_config.pattern,
backend = %backend,
ua_filter = ?route_config.ua_filter,
"Registered docked bay route"
);
}
Err(e) => {
error!(
bay = %config.name,
bind = %route_config.bind,
pattern = %route_config.pattern,
error = %e,
"Invalid route pattern"
);
}
}
}
}
}
}
// Longest prefix first, so the most specific static directory is tried first.
let mut static_dirs = manifest.mothership.static_dirs.clone();
static_dirs.sort_by(|a, b| b.prefix.len().cmp(&a.prefix.len()));
let cors_cache = if manifest.mothership.cors_cache.is_enabled() {
let ttl = Duration::from_secs(manifest.mothership.cors_cache.default_ttl());
let max_entries = manifest.mothership.cors_cache.max_entries();
info!(
ttl_secs = ttl.as_secs(),
max_entries = max_entries,
"CORS preflight cache enabled"
);
Some(Arc::new(CorsCache::new(ttl, max_entries)))
} else {
None
};
// The limiter is created whenever a rate_limiting section exists, even if it sets no limits.
let rate_limiter = manifest.mothership.rate_limiting.as_ref().map(|config| {
let limiter = Arc::new(RateLimiter::new(config.global_rps, config.per_ip_rpm));
match (config.global_rps, config.per_ip_rpm) {
(Some(global), Some(per_ip)) => {
info!(
global_limit = format!("{} req/s", global),
per_ip_limit = format!("{} req/min", per_ip),
"Rate limiting enabled"
);
}
(Some(global), None) => {
info!(
global_limit = format!("{} req/s", global),
"Rate limiting enabled (global only)"
);
}
(None, Some(per_ip)) => {
info!(
per_ip_limit = format!("{} req/min", per_ip),
"Rate limiting enabled (per-IP only)"
);
}
(None, None) => {
info!("Rate limiting configured but unlimited");
}
}
limiter
});
if rate_limiter.is_none() {
info!("Rate limiting disabled (unlimited)");
}
// The module runtime starts empty; `run` loads the configured modules.
Some(Self {
binds,
routes,
module_runtime: Arc::new(RwLock::new(None)),
module_configs: manifest.modules.clone(),
static_dirs,
compression: manifest.mothership.compression,
bay_connectors: Arc::new(RwLock::new(HashMap::new())),
cors_cache,
rate_limiter,
})
}
/// Registers `connector` as the docking connector for `bay_name`,
/// replacing any connector previously stored under that name.
pub async fn register_bay_connector(&self, bay_name: String, connector: Arc<DockingConnector>) {
    let mut connectors = self.bay_connectors.write().await;
    connectors.insert(bay_name, connector);
}
/// Returns a shared handle to the bay connector registry (a new `Arc`
/// pointing at the same map, not a copy of its contents).
pub fn bay_connectors(&self) -> Arc<RwLock<HashMap<String, Arc<DockingConnector>>>> {
    Arc::clone(&self.bay_connectors)
}
/// Starts one listener task per TCP bind and serves until shutdown.
///
/// Configured modules are loaded first; a failure to create the runtime or
/// to load modules aborts with an error. Unix binds are logged and skipped.
/// A listener that fails to bind logs an error and exits its task; the other
/// listeners keep running.
///
/// The function then waits until `shutdown` carries `true` or the graceful
/// guard is cancelled, and gives in-flight work up to 30 seconds to finish.
///
/// If the `shutdown` sender is dropped while the value is still `false`, the
/// channel is no longer polled and only the guard cancellation can end the
/// wait. Previously `changed()` returned `Err` immediately on every call in
/// that state, so the wait loop spun without ever yielding.
pub async fn run(self, shutdown: tokio::sync::watch::Receiver<bool>) -> anyhow::Result<()> {
    info!(binds = self.binds.len(), "Starting HTTP exposure layer");
    if !self.module_configs.is_empty() {
        let runtime = ModuleRuntime::new().map_err(|e| {
            anyhow::anyhow!(
                "failed to create module runtime with modules configured: {}",
                e
            )
        })?;
        runtime
            .load_modules(&self.module_configs)
            .await
            .map_err(|e| anyhow::anyhow!("failed to load configured modules: {}", e))?;
        let names = runtime.module_names().await;
        if !names.is_empty() {
            info!(modules = ?names, "Loaded WASM modules");
        }
        *self.module_runtime.write().await = Some(runtime);
    }
    let graceful = rama::graceful::Shutdown::default();
    let module_runtime = self.module_runtime.clone();
    for static_cfg in &self.static_dirs {
        info!(
            path = %static_cfg.path,
            prefix = %static_cfg.prefix,
            bind = ?static_cfg.bind,
            "Static file serving enabled"
        );
    }
    if self.compression {
        info!("Response compression enabled");
    }
    let bay_connectors = self.bay_connectors.clone();
    let metrics_registry = super::sensors::global_metrics_registry();
    let cors_cache = self.cors_cache.clone();
    let rate_limiter = self.rate_limiter.clone();
    for (bind_name, bind_addr) in &self.binds {
        let (addr, use_proxy_protocol) = match bind_addr {
            Bind::Tcp {
                host,
                port,
                proxy_protocol,
            } => (format!("{}:{}", host, port), *proxy_protocol),
            Bind::Unix { path } => {
                warn!(bind = %bind_name, path = %path, "Unix sockets not yet supported");
                continue;
            }
        };
        // Everything the listener task needs is cloned here because the task
        // takes ownership of its captures.
        let routes = self.routes.get(bind_name).cloned().unwrap_or_default();
        let bind_name_clone = bind_name.clone();
        let module_runtime = module_runtime.clone();
        let bay_connectors = bay_connectors.clone();
        let metrics_registry = metrics_registry.clone();
        let cors_cache = cors_cache.clone();
        let rate_limiter_clone = rate_limiter.clone();
        // A static dir without an explicit bind applies to every bind.
        let static_configs: Vec<StaticDirConfig> = self
            .static_dirs
            .iter()
            .filter(|cfg| cfg.bind.as_ref().is_none_or(|b| b == bind_name))
            .cloned()
            .collect();
        let compression_enabled = self.compression;
        let proxy_protocol_enabled = use_proxy_protocol;
        info!(bind = %bind_name, addr = %addr, routes = routes.routes.len(), proxy_protocol = proxy_protocol_enabled, "Starting listener");
        graceful.spawn_task_fn(async move |guard| {
            let bind_name = bind_name_clone;
            let tcp_listener = match TcpListener::build().bind(&addr).await {
                Ok(l) => {
                    info!(bind = %bind_name, addr = %addr, "Listener bound successfully");
                    l
                }
                Err(e) => {
                    error!(bind = %bind_name, addr = %addr, error = %e, "Failed to bind - check permissions for port < 1024");
                    return;
                }
            };
            let exec = Executor::graceful(guard.clone());
            let ctx = RequestContext {
                bind_name: bind_name.clone(),
                routes: Arc::new(routes),
                module_runtime: module_runtime.clone(),
                static_configs: Arc::new(static_configs),
                bay_connectors: bay_connectors.clone(),
                metrics_registry: metrics_registry.clone(),
                cors_cache: cors_cache.clone(),
                rate_limiter: rate_limiter_clone,
            };
            let core_service = service_fn(move |req: Request| {
                let ctx = ctx.clone();
                async move { handle_request(req, &ctx).await }
            });
            info!(bind = %bind_name, addr = %addr, "Serving HTTP requests");
            // The two branches build differently typed service stacks, which
            // is why the serving code is repeated rather than shared.
            if compression_enabled {
                let http_service = HttpServer::auto(exec).service(
                    (
                        UserAgentClassifierLayer::new(),
                        TraceLayer::new_for_http(),
                        CompressionLayer::new(),
                    )
                        .into_layer(core_service),
                );
                let body_limited =
                    BodyLimitLayer::symmetric(10 * 1024 * 1024).into_layer(http_service);
                if proxy_protocol_enabled {
                    tcp_listener
                        .serve_graceful(
                            guard,
                            HaProxyLayer::new().with_peek(true).into_layer(body_limited),
                        )
                        .await;
                } else {
                    tcp_listener.serve_graceful(guard, body_limited).await;
                }
            } else {
                let http_service = HttpServer::auto(exec).service(
                    (
                        UserAgentClassifierLayer::new(),
                        TraceLayer::new_for_http(),
                    )
                        .into_layer(core_service),
                );
                let body_limited =
                    BodyLimitLayer::symmetric(10 * 1024 * 1024).into_layer(http_service);
                if proxy_protocol_enabled {
                    tcp_listener
                        .serve_graceful(
                            guard,
                            HaProxyLayer::new().with_peek(true).into_layer(body_limited),
                        )
                        .await;
                } else {
                    tcp_listener.serve_graceful(guard, body_limited).await;
                }
            }
            info!(bind = %bind_name, "Listener stopped");
        });
    }
    let mut shutdown = shutdown;
    let guard = graceful.guard();
    tokio::select! {
        _ = async {
            loop {
                let sender_dropped = shutdown.changed().await.is_err();
                if *shutdown.borrow() {
                    break;
                }
                if sender_dropped {
                    // No further value can arrive; polling again would return
                    // the same error immediately and spin. Park this branch
                    // and leave termination to the guard branch below.
                    warn!("Shutdown channel closed without a shutdown signal; waiting for SIGINT only");
                    std::future::pending::<()>().await;
                }
            }
        } => {
            info!("Received external shutdown signal");
        }
        _ = guard.cancelled() => {
            info!("Received SIGINT");
        }
    }
    drop(guard);
    if let Err(e) = graceful.shutdown_with_limit(Duration::from_secs(30)).await {
        warn!(error = %e, "Graceful shutdown timed out");
    }
    info!("HTTP exposure layer stopped");
    Ok(())
}
}
/// Per-listener state handed to every request handled on one bind.
/// Cloned once per request; the fields are shared handles or small values.
#[derive(Clone)]
struct RequestContext {
/// Name of the bind this listener serves.
bind_name: String,
/// Routes registered for this bind.
routes: Arc<BindRoutes>,
/// Module runtime shared by all listeners; `None` when no modules are loaded.
module_runtime: Arc<RwLock<Option<ModuleRuntime>>>,
/// Static directory configurations that apply to this bind.
static_configs: Arc<Vec<StaticDirConfig>>,
/// Connectors for docked bays, keyed by bay name.
bay_connectors: Arc<RwLock<HashMap<String, Arc<DockingConnector>>>>,
/// Metrics sink; nothing is recorded when `None`.
metrics_registry: Option<Arc<MetricsRegistry>>,
/// CORS preflight cache, when enabled.
cors_cache: Option<Arc<CorsCache>>,
/// Rate limiter, when configured.
rate_limiter: Option<Arc<RateLimiter>>,
}
/// Returns the part of `path` below `prefix` when `prefix` matches on a
/// path-segment boundary, and `None` otherwise.
///
/// A prefix matches when it ends with `/`, when it equals the whole path, or
/// when the remainder starts with `/`. Prefix `/static` therefore matches
/// `/static` and `/static/app.js` but not `/staticfoo/app.js`.
fn strip_static_prefix<'a>(path: &'a str, prefix: &str) -> Option<&'a str> {
    let rest = path.strip_prefix(prefix)?;
    (prefix.ends_with('/') || rest.is_empty() || rest.starts_with('/')).then_some(rest)
}

/// Entry point for every request accepted on a bind.
///
/// 1. When a rate limiter is configured, the request is checked against it
///    and answered with `429 Too Many Requests` (plus `Retry-After`) on
///    rejection. Requests without socket information share the client key
///    `"unknown"`.
/// 2. `GET` and `HEAD` requests whose path falls under a configured static
///    prefix are tried against that static directory. Any response other than
///    `404 Not Found` is returned as is.
/// 3. Everything else, including static misses and static errors, goes to
///    `reverse_proxy`.
///
/// Never returns `Err`; failures are expressed as HTTP responses.
async fn handle_request(req: Request, ctx: &RequestContext) -> Result<Response, Infallible> {
    use rama::http::Method;

    if let Some(limiter) = &ctx.rate_limiter {
        let client_ip = req
            .extensions()
            .get::<SocketInfo>()
            .map(|s| s.peer_addr().ip().to_string())
            .unwrap_or_else(|| "unknown".to_string());
        if let Err(retry_after) = limiter.check_request(&client_ip).await {
            let retry_after_secs = retry_after.ceil() as u64;
            warn!(
                client_ip = %client_ip,
                retry_after = retry_after_secs,
                "Rate limit exceeded, returning 429"
            );
            // NOTE(review): "1000" is a fixed value and is not derived from
            // the configured limits - confirm whether clients rely on it.
            return Ok(Response::builder()
                .status(StatusCode::TOO_MANY_REQUESTS)
                .header("Retry-After", retry_after_secs.to_string())
                .header("X-RateLimit-Limit", "1000")
                .header("X-RateLimit-Remaining", "0")
                .body(Body::from("Rate limit exceeded. Please try again later."))
                .unwrap());
        }
    }

    let path = req.uri().path().to_string();
    let try_static = matches!(req.method(), &Method::GET | &Method::HEAD);

    // First static directory whose prefix matches on a segment boundary,
    // together with the path remainder below that prefix.
    let static_match = if try_static {
        ctx.static_configs
            .iter()
            .find_map(|cfg| strip_static_prefix(&path, &cfg.prefix).map(|rest| (cfg, rest)))
    } else {
        None
    };
    let Some((cfg, rest)) = static_match else {
        return reverse_proxy(req, ctx).await;
    };

    let file_path = rest.trim_start_matches('/');
    let serve_dir = ServeDir::new(&cfg.path);

    // The original parts and body are kept untouched so the request can still
    // be proxied if the static lookup misses; the static lookup gets a copy
    // with the rewritten URI and an empty body.
    let (parts, body) = req.into_parts();
    let mut static_parts = parts.clone();
    if let Ok(uri) = format!("/{}", file_path).parse::<Uri>() {
        static_parts.uri = uri;
    }
    let static_req = Request::from_parts(static_parts, Body::empty());
    match serve_dir.serve(static_req).await {
        Ok(resp) => {
            if resp.status() != StatusCode::NOT_FOUND {
                debug!(path = %path, "Served static file");
                return Ok(resp);
            }
            debug!(path = %path, "Static file not found, falling through to routes");
        }
        Err(e) => {
            warn!(path = %path, error = %e, "Static file error, falling through to routes");
        }
    }
    reverse_proxy(Request::from_parts(parts, body), ctx).await
}
/// Proxies one request to the backend of the first matching route on this bind.
///
/// Processing order:
/// 1. For `OPTIONS` requests with a CORS cache configured, return the cached
///    response when one exists; otherwise remember the cache key.
/// 2. Run the request through the loaded modules, which may block it or
///    rewrite its path, query and headers.
/// 3. Select the first route whose pattern matches the (possibly rewritten)
///    path and whose UA filter, if any, accepts the client. No match yields
///    `404 No route matched`.
/// 4. WebSocket upgrade requests are handed to the WebSocket proxy functions.
/// 5. Docked backends reject plain HTTP; an open circuit breaker yields `503`.
/// 6. The request URI is rebuilt for the backend (after `strip_prefix`),
///    `x-forwarded-*` / `x-real-ip` headers are set, and the request is sent
///    over TCP or a Unix socket with a 30 second timeout. The outcome is
///    recorded on the route's circuit breaker.
/// 7. The response is run through the modules, which may block it or set
///    response headers.
/// 8. A successful response to a cacheable preflight is stored in the CORS
///    cache, and request metrics are recorded.
///
/// Never returns `Err`; backend failures become `500`, `502` or `504` responses.
async fn reverse_proxy(mut req: Request, ctx: &RequestContext) -> Result<Response, Infallible> {
use rama::http::Method;
let bind_name = &ctx.bind_name;
let routes = &*ctx.routes;
let module_runtime = &ctx.module_runtime;
let bay_connectors = &ctx.bay_connectors;
let metrics_registry = &ctx.metrics_registry;
let cors_cache = &ctx.cors_cache;
let mut path = req.uri().path().to_string();
let method = req.method().to_string();
let query = req.uri().query().map(|s| s.to_string());
let is_options = req.method() == Method::OPTIONS;
// User-agent classification, present only when a UserAgent extension is attached to the request.
let ua_info = req.extensions().get::<UserAgent>().map(|ua| {
let kind = ua.info().map(|i| i.kind.to_string()).unwrap_or_default();
let platform = ua.platform().map(|p| p.to_string()).unwrap_or_default();
(kind, platform)
});
// JA4H fingerprint of the request; only used in the debug log below.
let shields = Ja4H::compute(&req).ok().map(|fp| format!("{fp}"));
if let Some((kind, platform)) = &ua_info {
debug!(
method = %method,
path = %path,
ua_kind = %kind,
ua_platform = %platform,
shields = ?shields,
"Request"
);
}
// CORS preflight cache lookup. A hit returns immediately, before modules and route matching run.
// The key is kept so a later successful response can be stored under it.
let cors_cache_key = if is_options {
if let Some(cache) = cors_cache {
if let Some(key) = CorsCacheKey::from_request(&path, bind_name, req.headers()) {
if let Some(cached) = cache.get(&key) {
return Ok(cached);
}
Some(key)
} else {
None
}
} else {
None
}
} else {
None
};
// Header snapshot handed to modules; values that are not valid visible ASCII become empty strings.
let headers: std::collections::HashMap<String, String> = req
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let request_info = RequestInfo {
method: method.clone(),
path: path.clone(),
query: query.clone(),
headers: headers.clone(),
};
// Request phase of the module pipeline; the read lock is held until the explicit drop below.
let modules = module_runtime.read().await;
if let Some(ref runtime) = *modules {
match runtime.process_request(&path, &request_info).await {
ModuleAction::Block {
status,
body,
headers,
} => {
return Ok(build_block_response(status, body, headers));
}
ModuleAction::Modify {
path: new_path,
headers: new_headers,
} => {
let has_mods = new_path.is_some() || new_headers.is_some();
if has_mods {
info!(new_path = ?new_path, "Module modified request");
}
if let Some(new_path) = new_path {
// The module may return "path?query"; split it, force a leading '/', and rebuild the URI.
let (path_part, query_part) = match new_path.split_once('?') {
Some((p, q)) => (p, Some(q)),
None => (new_path.as_str(), None),
};
let normalized_path = if path_part.starts_with('/') {
path_part.to_string()
} else {
format!("/{}", path_part)
};
let path_and_query = match query_part {
Some(q) => format!("{}?{}", normalized_path, q),
None => normalized_path,
};
// An unparsable rewrite is logged and the original URI is kept.
match path_and_query.parse::<rama::http::uri::PathAndQuery>() {
Ok(pq) => {
let mut parts = req.uri().clone().into_parts();
parts.path_and_query = Some(pq);
match Uri::from_parts(parts) {
Ok(uri) => {
*req.uri_mut() = uri;
}
Err(e) => {
warn!(error = %e, path = %path_and_query, "Invalid module-modified URI");
}
}
}
Err(e) => {
warn!(error = %e, path = %path_and_query, "Invalid module-modified path");
}
}
}
if let Some(headers_map) = new_headers {
// Each valid header replaces any existing header of the same name; invalid ones are logged and skipped.
for (key, value) in headers_map {
match (key.parse::<HeaderName>(), value.parse::<HeaderValue>()) {
(Ok(name), Ok(val)) => {
req.headers_mut().insert(name, val);
}
(Err(e), _) => {
warn!(error = %e, header = %key, "Invalid module-modified header name");
}
(_, Err(e)) => {
warn!(error = %e, header = %key, "Invalid module-modified header value");
}
}
}
}
}
ModuleAction::Continue => {}
}
}
drop(modules);
// Re-read the path: a module may have rewritten the URI above.
path = req.uri().path().to_string();
let ua_header = req
.headers()
.get("user-agent")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let ua_kind = req
.extensions()
.get::<UserAgent>()
.and_then(|ua| ua.info().map(|i| i.kind.to_string()));
// First route (in registration order) whose pattern matches and whose UA filter, if any, accepts the client.
let route = routes.routes.iter().find(|r| {
if !r.pattern.is_match(&path) {
return false;
}
if let Some(ref filter) = r.ua_filter {
filter.matches(ua_header, ua_kind.as_deref())
} else {
true
}
});
let route = match route {
Some(r) => r,
None => {
return Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("No route matched"))
.unwrap());
}
};
let backend = &route.backend;
let ship_name = &route.ship_name;
let route_key = route.route_key.clone();
let circuit_breaker = &route.circuit_breaker;
let strip_prefix = &route.strip_prefix;
// WebSocket upgrades take a separate path; the metric is recorded whatever the proxy call returned.
if WebSocketMatcher::new().matches(None, &req) {
debug!(ship = %ship_name, path = %path, backend = %backend, "WebSocket upgrade request detected");
let resp = if let BackendTarget::Docked { bay_name } = backend {
proxy_websocket_docked(
req,
bay_name,
ship_name,
circuit_breaker,
bay_connectors.clone(),
strip_prefix,
)
.await
} else {
proxy_websocket(req, backend, ship_name, circuit_breaker, strip_prefix).await
};
if let Some(registry) = metrics_registry {
registry.record_websocket(&route_key, ship_name);
}
return resp;
}
let start = std::time::Instant::now();
// Records the request duration in whole milliseconds when a metrics registry is configured.
let record_metrics = |elapsed: std::time::Duration| {
if let Some(registry) = metrics_registry {
registry.record_request(&route_key, ship_name, elapsed.as_millis() as u64);
}
};
// Builds a plain-text response and records metrics for it in one step.
let respond = |status: StatusCode, body: &'static str| -> Result<Response, Infallible> {
record_metrics(start.elapsed());
Ok(Response::builder()
.status(status)
.body(Body::from(body))
.unwrap())
};
// NOTE(review): this early return makes the `BackendTarget::Docked` arm of the backend match below
// unreachable; that arm would answer 501 while this one answers 404.
if matches!(backend, BackendTarget::Docked { .. }) {
return respond(
StatusCode::NOT_FOUND,
"Docked backends only support WebSocket connections",
);
}
// Scoped so the breaker lock is released before the backend call.
{
let breaker = circuit_breaker.lock().await;
if breaker.is_open() {
warn!(ship = %ship_name, backend = %backend, "Circuit breaker open, rejecting request");
return respond(
StatusCode::SERVICE_UNAVAILABLE,
"Service temporarily unavailable",
);
}
}
debug!(ship = %ship_name, path = %path, backend = %backend, "Routing request");
// Build the backend path: drop `strip_prefix` when the path starts with it, keep a leading '/',
// and re-append the original query string.
let uri = req.uri();
let original_path = uri.path();
let stripped_path = match strip_prefix {
Some(prefix) => original_path
.strip_prefix(prefix.as_str())
.unwrap_or(original_path),
None => original_path,
};
let final_path = if stripped_path.is_empty() || !stripped_path.starts_with('/') {
format!("/{}", stripped_path.trim_start_matches('/'))
} else {
stripped_path.to_string()
};
let path_and_query = format!(
"{}{}",
final_path,
uri.query().map(|q| format!("?{}", q)).unwrap_or_default()
);
let timeout = Duration::from_secs(30);
let client_ip = req
.extensions()
.get::<SocketInfo>()
.map(|s: &SocketInfo| s.peer_addr().ip().to_string());
let original_host = req
.headers()
.get(HOST)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let resp = match backend {
BackendTarget::Tcp { base_uri } => {
let new_uri = format!("{}{}", base_uri, path_and_query);
let new_uri: Uri = match new_uri.parse() {
Ok(u) => u,
Err(e) => {
error!(error = %e, uri = %new_uri, "Failed to parse backend URI");
return respond(StatusCode::INTERNAL_SERVER_ERROR, "Invalid backend URI");
}
};
let (mut parts, body) = req.into_parts();
parts.uri = new_uri;
// `insert` replaces any x-forwarded-* / x-real-ip value the client sent itself.
if let Some(ref ip) = client_ip
&& let Ok(val) = ip.parse::<rama::http::HeaderValue>()
{
parts.headers.insert("x-forwarded-for", val);
if let Ok(val2) = ip.parse::<rama::http::HeaderValue>() {
parts.headers.insert("x-real-ip", val2);
}
}
if let Some(ref host) = original_host
&& let Ok(val) = host.parse::<rama::http::HeaderValue>()
{
parts.headers.insert("x-forwarded-host", val);
}
// NOTE(review): the forwarded protocol is always reported as "http" - confirm no TLS termination happens in front of this code.
if let Ok(proto) = "http".parse::<rama::http::HeaderValue>() {
parts.headers.insert("x-forwarded-proto", proto);
}
let req = Request::from_parts(parts, body);
// NOTE(review): a new client is constructed for every request - check whether connections could be reused.
let client = EasyHttpWebClient::default();
match tokio::time::timeout(timeout, client.serve(req)).await {
// Any response counts as a breaker success here; the status code is not inspected.
Ok(Ok(resp)) => {
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
let breaker = circuit_breaker.lock().await;
breaker.record_success(duration_ms);
resp
}
Ok(Err(e)) => {
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
let breaker = circuit_breaker.lock().await;
breaker.record_failure(duration_ms);
drop(breaker);
error!(error = %e, backend = %backend, "Backend request failed");
return respond(StatusCode::BAD_GATEWAY, "Backend unavailable");
}
// Timeout: recorded as a failure with the full 30 000 ms budget.
Err(_) => {
let breaker = circuit_breaker.lock().await;
breaker.record_failure(30_000.0);
drop(breaker);
error!(backend = %backend, timeout_secs = 30, "Backend request timed out");
return respond(StatusCode::GATEWAY_TIMEOUT, "Backend timeout");
}
}
}
BackendTarget::Unix { socket_path } => {
// The authority is a placeholder; the connection itself goes to the socket path.
let new_uri: Uri = match format!("http://localhost{}", path_and_query).parse() {
Ok(u) => u,
Err(e) => {
error!(error = %e, "Failed to parse Unix socket URI");
return respond(StatusCode::INTERNAL_SERVER_ERROR, "Invalid URI");
}
};
let (mut parts, body) = req.into_parts();
parts.uri = new_uri;
// Same forwarding headers as the TCP branch above.
if let Some(ref ip) = client_ip
&& let Ok(val) = ip.parse::<rama::http::HeaderValue>()
{
parts.headers.insert("x-forwarded-for", val);
if let Ok(val2) = ip.parse::<rama::http::HeaderValue>() {
parts.headers.insert("x-real-ip", val2);
}
}
if let Some(ref host) = original_host
&& let Ok(val) = host.parse::<rama::http::HeaderValue>()
{
parts.headers.insert("x-forwarded-host", val);
}
if let Ok(proto) = "http".parse::<rama::http::HeaderValue>() {
parts.headers.insert("x-forwarded-proto", proto);
}
let request = Request::from_parts(parts, body);
let connector = HttpConnector::new(UnixConnector::fixed(socket_path));
// Connecting and serving each get their own 30 second timeout.
let connect_result = tokio::time::timeout(timeout, connector.connect(request)).await;
match connect_result {
Ok(Ok(EstablishedClientConnection { conn, input, .. })) => {
match tokio::time::timeout(timeout, conn.serve(input)).await {
Ok(Ok(resp)) => {
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
let breaker = circuit_breaker.lock().await;
breaker.record_success(duration_ms);
resp
}
Ok(Err(e)) => {
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
let breaker = circuit_breaker.lock().await;
breaker.record_failure(duration_ms);
drop(breaker);
error!(error = %e, backend = %backend, "Unix socket request failed");
return respond(StatusCode::BAD_GATEWAY, "Backend unavailable");
}
Err(_) => {
let breaker = circuit_breaker.lock().await;
breaker.record_failure(30_000.0);
drop(breaker);
error!(backend = %backend, timeout_secs = 30, "Unix socket request timed out");
return respond(StatusCode::GATEWAY_TIMEOUT, "Backend timeout");
}
}
}
Ok(Err(e)) => {
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
let breaker = circuit_breaker.lock().await;
breaker.record_failure(duration_ms);
drop(breaker);
error!(error = %e, backend = %backend, "Failed to connect to Unix socket");
return respond(StatusCode::BAD_GATEWAY, "Backend unavailable");
}
Err(_) => {
let breaker = circuit_breaker.lock().await;
breaker.record_failure(30_000.0);
drop(breaker);
error!(backend = %backend, timeout_secs = 30, "Unix socket connect timed out");
return respond(StatusCode::GATEWAY_TIMEOUT, "Backend timeout");
}
}
}
// Not reachable in practice: docked backends already returned above. Kept so the match stays exhaustive.
BackendTarget::Docked { bay_name } => {
error!(bay = %bay_name, path = %path_and_query, "HTTP request to docked bay - only WebSocket is supported");
return respond(
StatusCode::NOT_IMPLEMENTED,
"This endpoint only supports WebSocket connections",
);
}
};
// Response phase of the module pipeline: modules see the status and headers, not the body.
let response_headers: std::collections::HashMap<String, String> = resp
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let response_info = ResponseInfo {
status: resp.status().as_u16(),
headers: response_headers,
};
let modules = module_runtime.read().await;
let resp = if let Some(ref runtime) = *modules {
match runtime.process_response(&path, &response_info).await {
ModuleAction::Block {
status,
body,
headers,
} => {
record_metrics(start.elapsed());
return Ok(build_block_response(status, body, headers));
}
ModuleAction::Modify {
headers: new_headers,
..
} => {
if let Some(headers_map) = new_headers {
let (mut parts, body) = resp.into_parts();
// Unlike the request phase, invalid header names or values are skipped without a log entry.
for (key, value) in headers_map {
if let (Ok(name), Ok(val)) = (key.parse::<HeaderName>(), value.parse()) {
parts.headers.insert(name, val);
}
}
Response::from_parts(parts, body)
} else {
resp
}
}
ModuleAction::Continue => resp,
}
} else {
resp
};
// Store successful responses to cacheable preflight requests (key computed at the top).
if let Some(key) = cors_cache_key
&& resp.status().is_success()
&& let Some(cache) = cors_cache
{
cache.insert(key, &resp);
}
record_metrics(start.elapsed());
Ok(resp)
}
async fn proxy_websocket(
req: Request,
backend: &BackendTarget,
ship_name: &str,
circuit_breaker: &Arc<Mutex<CircuitBreaker>>,
strip_prefix: &Option<String>,
) -> Result<Response, Infallible> {
let path = req.uri().path().to_string();
let (parts, body) = req.into_parts();
let parts_clone = parts.clone();
let mut req = Request::from_parts(parts, body);
let executor = req
.extensions()
.get::<Executor>()
.cloned()
.unwrap_or_default();
let egress_socket = match backend {
BackendTarget::Tcp { base_uri } => {
let original_path = req.uri().path();
let stripped_path = match strip_prefix {
Some(prefix) => original_path
.strip_prefix(prefix.as_str())
.unwrap_or(original_path),
None => original_path,
};
let final_path = if stripped_path.is_empty() || !stripped_path.starts_with('/') {
format!("/{}", stripped_path.trim_start_matches('/'))
} else {
stripped_path.to_string()
};
let path_and_query = format!(
"{}{}",
final_path,
req.uri()
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default()
);
let backend_uri: Uri = match format!("{}{}", base_uri, path_and_query).parse() {
Ok(uri) => uri,
Err(e) => {
error!(error = %e, "Failed to parse WebSocket backend URI");
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Invalid backend URI"))
.unwrap());
}
};
*req.uri_mut() = backend_uri;
let client = EasyHttpWebClient::default();
let handshake = match client
.websocket_with_request(req)
.initiate_handshake(rama::extensions::Extensions::default())
.await
{
Ok(hs) => hs,
Err(e) => {
error!(
ship = %ship_name,
path = %path,
error = %e,
"Failed to initiate WebSocket handshake with TCP backend"
);
let breaker = circuit_breaker.lock().await;
breaker.record_failure(0.0);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("Backend WebSocket handshake failed"))
.unwrap());
}
};
match handshake.complete().await {
Ok(socket) => socket,
Err(e) => {
error!(
ship = %ship_name,
path = %path,
error = %e,
"Failed to complete WebSocket handshake with TCP backend"
);
let breaker = circuit_breaker.lock().await;
breaker.record_failure(0.0);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("Backend WebSocket connection failed"))
.unwrap());
}
}
}
BackendTarget::Unix { socket_path } => {
let original_path = req.uri().path();
let stripped_path = match strip_prefix {
Some(prefix) => original_path
.strip_prefix(prefix.as_str())
.unwrap_or(original_path),
None => original_path,
};
let final_path = if stripped_path.is_empty() || !stripped_path.starts_with('/') {
format!("/{}", stripped_path.trim_start_matches('/'))
} else {
stripped_path.to_string()
};
let path_and_query = format!(
"{}{}",
final_path,
req.uri()
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default()
);
let backend_uri: Uri = match format!("http://localhost{}", path_and_query).parse() {
Ok(uri) => uri,
Err(e) => {
error!(error = %e, "Failed to parse Unix WebSocket URI");
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Invalid backend URI"))
.unwrap());
}
};
*req.uri_mut() = backend_uri;
let connector = HttpConnector::new(UnixConnector::fixed(socket_path));
let EstablishedClientConnection { conn: client, .. } = match connector.serve(req).await
{
Ok(established) => established,
Err(e) => {
error!(
ship = %ship_name,
path = %path,
socket = %socket_path,
error = %e,
"Failed to connect to Unix socket for WebSocket"
);
let breaker = circuit_breaker.lock().await;
breaker.record_failure(0.0);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("Backend connection failed"))
.unwrap());
}
};
let ws_req = Request::builder()
.uri(format!("http://localhost{}", path_and_query))
.body(Body::empty())
.unwrap();
let handshake = match client
.websocket_with_request(ws_req)
.initiate_handshake(rama::extensions::Extensions::default())
.await
{
Ok(hs) => hs,
Err(e) => {
error!(
ship = %ship_name,
path = %path,
socket = %socket_path,
error = %e,
"Failed to initiate WebSocket handshake with Unix backend"
);
let breaker = circuit_breaker.lock().await;
breaker.record_failure(0.0);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("Backend WebSocket handshake failed"))
.unwrap());
}
};
match handshake.complete().await {
Ok(socket) => socket,
Err(e) => {
error!(
ship = %ship_name,
path = %path,
socket = %socket_path,
error = %e,
"Failed to complete WebSocket handshake with Unix backend"
);
let breaker = circuit_breaker.lock().await;
breaker.record_failure(0.0);
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("Backend WebSocket connection failed"))
.unwrap());
}
}
}
BackendTarget::Docked { .. } => {
error!(ship = %ship_name, path = %path, "Unexpected Docked backend in proxy_websocket");
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Internal routing error"))
.unwrap());
}
};
{
let breaker = circuit_breaker.lock().await;
breaker.record_success(0.0);
}
let (egress_socket, response_parts, _) = egress_socket.into_parts();
let response = Response::from_parts(response_parts, Body::empty());
let ship_name_clone = ship_name.to_string();
let path_clone = path.clone();
executor.spawn_task(async move {
let ship_name = ship_name_clone;
debug!(
ship = %ship_name,
path = %path_clone,
"WebSocket backend connected, waiting for client upgrade"
);
let request = Request::from_parts(parts_clone, Body::empty());
let ingress_socket = match upgrade::handle_upgrade(&request).await {
Ok(upgraded) => AsyncWebSocket::from_raw_socket(upgraded, Role::Server, None).await,
Err(e) => {
error!(
ship = %ship_name,
path = %path_clone,
error = %e,
"Failed to upgrade client WebSocket connection"
);
return;
}
};
debug!(
ship = %ship_name,
path = %path_clone,
"WebSocket relay started between client and backend"
);
relay_websockets(ingress_socket, egress_socket, &ship_name, &path_clone).await;
debug!(
ship = %ship_name,
path = %path_clone,
"WebSocket connection closed"
);
});
info!(ship = %ship_name, path = %path, "WebSocket upgrade initiated");
Ok(response)
}
/// Shuttles WebSocket messages between the client (`ingress`) and the backend
/// (`egress`) until one side goes away.
///
/// Every message received on one socket is forwarded unchanged to the other.
/// The loop ends when:
/// * receiving from either side fails (logged at `debug` for connection-type
///   errors and resets without a closing handshake, at `error` otherwise), or
/// * forwarding fails with a connection error.
///
/// A forwarding failure that is *not* a connection error is logged at `error`
/// and the relay keeps running.
///
/// `ship_name` and `path` are only used as log fields.
async fn relay_websockets(
    mut ingress: AsyncWebSocket,
    mut egress: AsyncWebSocket,
    ship_name: &str,
    path: &str,
) {
    use rama::futures::SinkExt;

    loop {
        tokio::select! {
            // Client -> backend direction.
            from_client = ingress.recv_message() => {
                let frame = match from_client {
                    Ok(frame) => frame,
                    Err(e) => {
                        let peer_gone = e.is_connection_error()
                            || matches!(e, rama::http::ws::ProtocolError::ResetWithoutClosingHandshake);
                        if peer_gone {
                            debug!(ship = %ship_name, path = %path, "Client disconnected");
                        } else {
                            error!(ship = %ship_name, path = %path, error = %e, "Client WebSocket error");
                        }
                        return;
                    }
                };
                match egress.send(frame).await {
                    Ok(_) => {}
                    Err(e) if e.is_connection_error() => {
                        debug!(ship = %ship_name, path = %path, "Backend disconnected");
                        return;
                    }
                    // Any other send failure is reported but does not stop the relay.
                    Err(e) => {
                        error!(ship = %ship_name, path = %path, error = %e, "Failed to send to backend");
                    }
                }
            }
            // Backend -> client direction.
            from_backend = egress.recv_message() => {
                let frame = match from_backend {
                    Ok(frame) => frame,
                    Err(e) => {
                        let peer_gone = e.is_connection_error()
                            || matches!(e, rama::http::ws::ProtocolError::ResetWithoutClosingHandshake);
                        if peer_gone {
                            debug!(ship = %ship_name, path = %path, "Backend disconnected");
                        } else {
                            error!(ship = %ship_name, path = %path, error = %e, "Backend WebSocket error");
                        }
                        return;
                    }
                };
                match ingress.send(frame).await {
                    Ok(_) => {}
                    Err(e) if e.is_connection_error() => {
                        debug!(ship = %ship_name, path = %path, "Client disconnected");
                        return;
                    }
                    // Any other send failure is reported but does not stop the relay.
                    Err(e) => {
                        error!(ship = %ship_name, path = %path, error = %e, "Failed to send to client");
                    }
                }
            }
        }
    }
}
async fn proxy_websocket_docked(
req: Request,
bay_name: &str,
ship_name: &str,
circuit_breaker: &Arc<Mutex<CircuitBreaker>>,
bay_connectors: Arc<RwLock<HashMap<String, Arc<DockingConnector>>>>,
strip_prefix: &Option<String>,
) -> Result<Response, Infallible> {
let path = req.uri().path().to_string();
let connector = {
let connectors = bay_connectors.read().await;
connectors.get(bay_name).cloned()
};
let connector = match connector {
Some(c) => c,
None => {
error!(bay = %bay_name, "Bay connector not found");
let breaker = circuit_breaker.lock().await;
breaker.record_failure(0.0);
return Ok(Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Body::from("Bay not docked"))
.unwrap());
}
};
let conn_id = next_conn_id();
let mut cargo_rx = connector.register_connection(conn_id).await;
let original_path = req.uri().path();
let stripped_path = match strip_prefix {
Some(prefix) => original_path
.strip_prefix(prefix.as_str())
.unwrap_or(original_path),
None => original_path,
};
let final_path = if stripped_path.is_empty() || !stripped_path.starts_with('/') {
format!("/{}", stripped_path.trim_start_matches('/'))
} else {
stripped_path.to_string()
};
let path_with_query = format!(
"{}{}",
final_path,
req.uri()
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default()
);
let remote_addr = req
.extensions()
.get::<SocketInfo>()
.map(|s| s.peer_addr().to_string())
.unwrap_or_else(|| "unknown".to_string());
let headers: HashMap<String, String> = req
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let (parts, body) = req.into_parts();
let parts_clone = parts.clone();
let req = Request::from_parts(parts, body);
let executor = req
.extensions()
.get::<Executor>()
.cloned()
.unwrap_or_default();
let response = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header(
"Sec-WebSocket-Accept",
compute_websocket_accept_key(parts_clone.headers.get("Sec-WebSocket-Key")),
)
.body(Body::empty())
.unwrap();
let bay_name_clone = bay_name.to_string();
let ship_name_clone = ship_name.to_string();
let path_clone = path_with_query.clone();
let connector_clone = connector.clone();
let circuit_breaker_clone = circuit_breaker.clone();
executor.spawn_task(async move {
let bay_name = bay_name_clone;
let ship_name = ship_name_clone;
let connector = connector_clone;
let circuit_breaker = circuit_breaker_clone;
debug!(
bay = %bay_name,
conn_id = conn_id,
path = %path_clone,
"Waiting for client WebSocket upgrade"
);
let request = Request::from_parts(parts_clone, Body::empty());
let ingress_socket = match upgrade::handle_upgrade(&request).await {
Ok(upgraded) => AsyncWebSocket::from_raw_socket(upgraded, Role::Server, None).await,
Err(e) => {
error!(
bay = %bay_name,
conn_id = conn_id,
error = %e,
"Failed to upgrade client WebSocket connection"
);
connector.unregister_connection(conn_id).await;
return;
}
};
let boarding = Boarding {
conn_id,
path: path_clone.clone(),
remote_addr,
headers,
};
if let Err(e) = connector.send_boarding(boarding).await {
error!(
bay = %bay_name,
conn_id = conn_id,
error = %e,
"Failed to send boarding message"
);
connector.unregister_connection(conn_id).await;
return;
}
info!(
bay = %bay_name,
ship = %ship_name,
conn_id = conn_id,
path = %path_clone,
"WebSocket connection boarded via docking protocol"
);
{
let breaker = circuit_breaker.lock().await;
breaker.record_success(0.0);
}
relay_websocket_docked(
ingress_socket,
&connector,
conn_id,
&mut cargo_rx,
&bay_name,
&path_clone,
)
.await;
let disembark = Disembark {
conn_id,
code: 1000,
reason: "normal".to_string(),
};
let _ = connector.send_disembark(disembark).await;
connector.unregister_connection(conn_id).await;
debug!(
bay = %bay_name,
conn_id = conn_id,
path = %path_clone,
"WebSocket connection closed"
);
});
info!(bay = %bay_name, ship = %ship_name, path = %path, conn_id = conn_id, "WebSocket upgrade initiated via docking");
Ok(response)
}
/// Derives the `Sec-WebSocket-Accept` header value for a handshake response.
///
/// The result is the Base64 (standard alphabet) encoding of the SHA-1 digest
/// of the client's key bytes followed by the fixed WebSocket GUID.
///
/// `key` is the request's `Sec-WebSocket-Key` header. When it is absent, or
/// when `to_str` fails on it, the empty string is hashed in its place.
fn compute_websocket_accept_key(key: Option<&rama::http::HeaderValue>) -> String {
    use base64::Engine;
    use sha1::{Digest, Sha1};

    const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let client_key = match key {
        Some(value) => value.to_str().unwrap_or(""),
        None => "",
    };

    let mut sha = Sha1::new();
    sha.update(client_key.as_bytes());
    sha.update(WEBSOCKET_GUID.as_bytes());

    base64::engine::general_purpose::STANDARD.encode(sha.finalize())
}
/// Relays traffic between an upgraded client WebSocket and a docked bay.
///
/// Client -> bay: the payload of every Text or Binary message is wrapped in a
/// `Cargo` tagged with `conn_id` and handed to `connector`. Ping messages are
/// answered with a Pong carrying the same payload; Pong and raw Frame messages
/// are ignored. A Close message, a receive error, or a failure to send cargo
/// ends the relay.
///
/// Bay -> client: every byte buffer received on `cargo_rx` is forwarded as a
/// single WebSocket message. Buffers that are valid UTF-8 are sent as Text
/// frames, everything else as Binary frames. The relay ends when the channel
/// is closed or when sending to the client fails with a connection error;
/// other send failures are logged and the relay keeps running.
///
/// `bay_name` and `path` are only used as log fields.
async fn relay_websocket_docked(
    mut ingress: AsyncWebSocket,
    connector: &Arc<DockingConnector>,
    conn_id: u32,
    cargo_rx: &mut tokio::sync::mpsc::Receiver<Vec<u8>>,
    bay_name: &str,
    path: &str,
) {
    use rama::futures::SinkExt;
    use rama::http::ws::Message;
    loop {
        tokio::select! {
            result = ingress.recv_message() => {
                match result {
                    Ok(msg) => {
                        // Only data frames carry a payload for the bay; control
                        // frames are handled here and never forwarded.
                        let data = match &msg {
                            Message::Text(text) => text.as_bytes().to_vec(),
                            Message::Binary(data) => data.to_vec(),
                            Message::Ping(data) => {
                                // Best effort: a failed Pong is not treated as fatal.
                                let _ = ingress.send(Message::Pong(data.clone())).await;
                                continue;
                            }
                            Message::Pong(_) => continue,
                            Message::Close(_) => {
                                debug!(bay = %bay_name, conn_id = conn_id, "Client sent close frame");
                                return;
                            }
                            Message::Frame(_) => {
                                continue;
                            }
                        };
                        let cargo = Cargo { conn_id, data };
                        if let Err(e) = connector.send_cargo(cargo).await {
                            error!(bay = %bay_name, conn_id = conn_id, error = %e, "Failed to send cargo to bay");
                            return;
                        }
                    }
                    Err(e) => {
                        if e.is_connection_error() || matches!(e, rama::http::ws::ProtocolError::ResetWithoutClosingHandshake) {
                            debug!(bay = %bay_name, conn_id = conn_id, path = %path, "Client disconnected");
                        } else {
                            error!(bay = %bay_name, conn_id = conn_id, path = %path, error = %e, "Client WebSocket error");
                        }
                        return;
                    }
                }
            }
            result = cargo_rx.recv() => {
                match result {
                    Some(data) => {
                        // The cargo channel carries untyped bytes, so the frame
                        // type has to be inferred. Text frames are defined as
                        // UTF-8, so classify by UTF-8 validity rather than by
                        // ASCII-only content; otherwise text containing any
                        // non-ASCII character would reach the client as binary.
                        // On failure the original buffer is recovered unchanged.
                        let msg = match String::from_utf8(data) {
                            Ok(text) => Message::Text(text.into()),
                            Err(not_utf8) => Message::Binary(not_utf8.into_bytes().into()),
                        };
                        if let Err(e) = ingress.send(msg).await {
                            if e.is_connection_error() {
                                debug!(bay = %bay_name, conn_id = conn_id, path = %path, "Client disconnected");
                                return;
                            }
                            error!(bay = %bay_name, conn_id = conn_id, path = %path, error = %e, "Failed to send to client");
                        }
                    }
                    None => {
                        debug!(bay = %bay_name, conn_id = conn_id, path = %path, "Bay cargo channel closed");
                        return;
                    }
                }
            }
        }
    }
}