mod proxy;
mod websocket;
use actix_web::dev::ServiceRequest;
use actix_web::error::ErrorInternalServerError;
use actix_web::http::{Uri, header, uri::PathAndQuery};
use actix_web::{Error, HttpMessage, HttpResponse, web};
use tracing::{debug, warn};
use crate::AppState;
use crate::api::headers::request_context::{
ResolvedAthenaClient, ResolvedAthenaClientSource, set_disallow_jdbc_routing,
};
use crate::api::response::{forbidden, service_unavailable};
use crate::data::service_routes::{PublicServiceRouteRecord, get_public_service_route_by_key};
/// Environment variable holding the fallback upstream URL for the `db` service.
const DEFAULT_DB_TARGET_URL_ENV: &str = "ATHENA_ROUTER_DEFAULT_DB_TARGET_URL";
/// Environment variable holding the fallback upstream URL for the `auth` service.
const DEFAULT_AUTH_TARGET_URL_ENV: &str = "ATHENA_ROUTER_DEFAULT_AUTH_TARGET_URL";
/// Environment variable holding the fallback upstream URL for the `storage` service.
const DEFAULT_STORAGE_TARGET_URL_ENV: &str = "ATHENA_ROUTER_DEFAULT_STORAGE_TARGET_URL";
/// Environment variable holding the fallback upstream URL for the `realtime` service.
const DEFAULT_REALTIME_TARGET_URL_ENV: &str = "ATHENA_ROUTER_DEFAULT_REALTIME_TARGET_URL";

/// A public service namespace addressable through the router, either as the first
/// request-path segment (`/auth`, `/storage`, `/db`, `/realtime`) or via a host route.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(super) enum ServiceKey {
    Auth,
    Storage,
    Db,
    Realtime,
}

impl ServiceKey {
    /// Parses a path segment or host-route key into a service.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive;
    /// `real-time` is accepted as an alias of `realtime`.
    fn parse(segment: &str) -> Option<Self> {
        const ALIASES: [(&str, ServiceKey); 5] = [
            ("auth", ServiceKey::Auth),
            ("storage", ServiceKey::Storage),
            ("db", ServiceKey::Db),
            ("realtime", ServiceKey::Realtime),
            ("real-time", ServiceKey::Realtime),
        ];
        let candidate = segment.trim();
        ALIASES
            .iter()
            .find_map(|&(alias, key)| candidate.eq_ignore_ascii_case(alias).then_some(key))
    }

    /// Canonical lowercase name, used in route lookups, log fields and messages.
    fn as_str(self) -> &'static str {
        match self {
            Self::Auth => "auth",
            Self::Storage => "storage",
            Self::Db => "db",
            Self::Realtime => "realtime",
        }
    }

    /// Built-in local Athena path that serves this service when no upstream is
    /// configured; `auth` has no local fallback.
    fn default_local_path_prefix(self) -> Option<&'static str> {
        let prefix = match self {
            Self::Auth => return None,
            Self::Db => "/gateway",
            Self::Storage => "/storage",
            Self::Realtime => "/wss",
        };
        Some(prefix)
    }

    /// Name of the environment variable that may supply a default upstream URL.
    fn default_upstream_env_var(self) -> Option<&'static str> {
        Some(match self {
            Self::Auth => DEFAULT_AUTH_TARGET_URL_ENV,
            Self::Storage => DEFAULT_STORAGE_TARGET_URL_ENV,
            Self::Db => DEFAULT_DB_TARGET_URL_ENV,
            Self::Realtime => DEFAULT_REALTIME_TARGET_URL_ENV,
        })
    }
}
/// How a stored service-route record's target is interpreted (its `target_kind` field).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ServiceRouteTargetKind {
    /// `local_path`: rewrite onto a local Athena path prefix.
    LocalPath,
    /// `http_upstream`: proxy to an HTTP upstream URL.
    HttpUpstream,
}

impl ServiceRouteTargetKind {
    /// Parses a stored `target_kind` value, ignoring surrounding whitespace and ASCII case.
    /// Unknown values yield `None`.
    fn parse(value: &str) -> Option<Self> {
        let kind = value.trim();
        if kind.eq_ignore_ascii_case("local_path") {
            Some(Self::LocalPath)
        } else if kind.eq_ignore_ascii_case("http_upstream") {
            Some(Self::HttpUpstream)
        } else {
            None
        }
    }
}
/// The routing decision for a request addressed to a service.
enum ServiceRouteAction {
    /// The tenant's route record is inactive; the router answers with a 403.
    Disabled,
    /// Rewrite the request path onto a local Athena path rooted at `prefix`.
    LocalPath { prefix: String },
    /// Forward the request (or websocket upgrade) to the upstream base URL `target_url`.
    HttpUpstream { target_url: String },
}
/// Delegates service configuration to the `websocket` submodule.
// NOTE(review): presumably this registers the handler behind
// `websocket::INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH` — confirm in websocket.rs.
pub fn services(cfg: &mut web::ServiceConfig) {
    websocket::services(cfg);
}
pub async fn maybe_route_service_request(
req: &mut ServiceRequest,
) -> Result<Option<HttpResponse>, Error> {
if req.path() == websocket::INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH {
return Ok(None);
}
let Some((service_key, tail_path)) =
resolve_service_request_path(req.path(), request_host_service_key(req).as_deref())
else {
return Ok(None);
};
let Some(action) = (match resolve_service_route_action(req, service_key).await {
Ok(action) => action,
Err(response) => return Ok(Some(response)),
}) else {
return Ok(Some(unconfigured_service_response(req, service_key)));
};
set_disallow_jdbc_routing(req.request());
if service_key == ServiceKey::Realtime && websocket_upgrade_requested(req) {
return match action {
ServiceRouteAction::Disabled => Ok(Some(forbidden(
"Service route inactive",
format!(
"The '{}' service route is explicitly disabled for this tenant.",
service_key.as_str()
),
))),
ServiceRouteAction::LocalPath { .. } => Ok(Some(service_unavailable(
"Realtime websocket upstream unavailable",
"Realtime HTTP discovery can use the local /wss scaffold, but websocket upgrades under /realtime require ATHENA_ROUTER_DEFAULT_REALTIME_TARGET_URL or an active public_service_routes target_url.",
))),
ServiceRouteAction::HttpUpstream { target_url } => {
websocket::prepare_realtime_websocket_proxy(
req,
&target_url,
&tail_path,
resolved_route_key(req).as_deref(),
resolved_client_name(req).as_deref(),
)?;
rewrite_request_path(req, websocket::INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH)?;
Ok(None)
}
};
}
match action {
ServiceRouteAction::Disabled => Ok(Some(forbidden(
"Service route inactive",
format!(
"The '{}' service route is explicitly disabled for this tenant.",
service_key.as_str()
),
))),
ServiceRouteAction::LocalPath { prefix } => {
let rewritten_path = rewrite_local_service_path(service_key, &prefix, &tail_path);
if rewritten_path != req.path() {
debug!(
service_key = %service_key.as_str(),
original_path = %req.path(),
rewritten_path = %rewritten_path,
"Rewriting service route to local Athena path"
);
rewrite_request_path(req, &rewritten_path)?;
}
Ok(None)
}
ServiceRouteAction::HttpUpstream { target_url } => {
let Some(app_state) = req.app_data::<web::Data<AppState>>().cloned() else {
return Ok(Some(service_unavailable(
"Service router unavailable",
"AppState is not registered on this request.",
)));
};
let body = req.extract::<web::Bytes>().await?;
let route_key = resolved_route_key(req);
let client_name = resolved_client_name(req);
let response = proxy::proxy_service_request(
req.request(),
app_state.get_ref(),
&target_url,
&normalized_service_tail(service_key, &tail_path),
service_key.as_str(),
route_key.as_deref(),
client_name.as_deref(),
body,
)
.await;
Ok(Some(response))
}
}
}
/// Splits `/<service>/<tail>` into the recognised service and its tail (without
/// the separating slash). Returns `None` when the first segment is not a service.
fn split_service_path(path: &str) -> Option<(ServiceKey, String)> {
    let relative = path.trim_start_matches('/');
    let (head, rest) = relative.split_once('/').unwrap_or((relative, ""));
    ServiceKey::parse(head).map(|key| (key, rest.to_owned()))
}
/// Resolves which service a request targets.
///
/// An explicit `/<service>/...` path prefix wins. Otherwise, when the host maps to a
/// service, the whole path (minus leading slashes) becomes that service's tail.
fn resolve_service_request_path(
    path: &str,
    host_service_key: Option<&str>,
) -> Option<(ServiceKey, String)> {
    split_service_path(path).or_else(|| {
        let key = ServiceKey::parse(host_service_key?)?;
        Some((key, path.trim_start_matches('/').to_owned()))
    })
}
/// Returns the service key configured for the request's public host route, if any.
fn request_host_service_key(req: &ServiceRequest) -> Option<String> {
    let host_route = crate::api::host_routing::resolve_public_host_route(req.headers())?;
    host_route.service_key
}
/// Joins `prefix` and `tail_path` with exactly one `/` between them.
///
/// Trailing slashes on the prefix and surrounding slashes on the tail are dropped;
/// when both are empty the result is the root path `/`.
fn join_rewritten_path(prefix: &str, tail_path: &str) -> String {
    let base = prefix.trim_end_matches('/');
    let rest = tail_path.trim_matches('/');
    match (base.is_empty(), rest.is_empty()) {
        (true, true) => "/".to_string(),
        (false, true) => base.to_string(),
        (true, false) => format!("/{rest}"),
        (false, false) => format!("{base}/{rest}"),
    }
}
fn join_rewritten_path_without_duplicate_prefix_segment(prefix: &str, tail_path: &str) -> String {
let trimmed_tail = tail_path.trim_matches('/');
let prefix_segment = prefix
.trim_matches('/')
.rsplit('/')
.next()
.unwrap_or_default()
.trim();
if let Some((segment, rest)) = first_path_segment(trimmed_tail)
&& !prefix_segment.is_empty()
&& segment == prefix_segment
{
return join_rewritten_path(prefix, rest);
}
join_rewritten_path(prefix, trimmed_tail)
}
/// Splits a path (ignoring surrounding slashes) into its first segment and the rest.
/// Returns `None` for an empty or all-slash path.
fn first_path_segment(path: &str) -> Option<(&str, &str)> {
    let inner = path.trim_matches('/');
    if inner.is_empty() {
        None
    } else {
        Some(inner.split_once('/').unwrap_or((inner, "")))
    }
}
fn rewrite_local_service_path(service_key: ServiceKey, prefix: &str, tail_path: &str) -> String {
match service_key {
ServiceKey::Db => rewrite_db_local_path(prefix, tail_path),
ServiceKey::Realtime => rewrite_realtime_local_path(prefix, tail_path),
ServiceKey::Auth | ServiceKey::Storage => join_rewritten_path(prefix, tail_path),
}
}
/// Returns the slash-trimmed tail forwarded to an upstream, normalised for the
/// `db` and `realtime` services; other services pass through unchanged.
fn normalized_service_tail(service_key: ServiceKey, tail_path: &str) -> String {
    let tail = tail_path.trim_matches('/');
    match service_key {
        ServiceKey::Auth | ServiceKey::Storage => tail.to_owned(),
        ServiceKey::Realtime => normalize_realtime_service_tail(tail),
        ServiceKey::Db => normalize_db_service_tail(tail),
    }
}
fn normalize_db_service_tail(trimmed_tail: &str) -> String {
if trimmed_tail.is_empty() {
return "gateway".to_string();
}
if let Some((segment, _)) = first_path_segment(trimmed_tail)
&& matches!(segment, "gateway" | "query" | "rest")
{
return trimmed_tail.to_string();
}
format!("gateway/{trimmed_tail}")
}
fn rewrite_db_local_path(prefix: &str, tail_path: &str) -> String {
let normalized_tail = normalize_db_service_tail(tail_path.trim_matches('/'));
if let Some((segment, _)) = first_path_segment(&normalized_tail)
&& matches!(segment, "query" | "rest")
{
return format!("/{normalized_tail}");
}
join_rewritten_path_without_duplicate_prefix_segment(prefix, &normalized_tail)
}
/// Canonicalises a `realtime` tail: empty becomes `info`, `socket` becomes
/// `gateway`, anything else is kept verbatim.
fn normalize_realtime_service_tail(trimmed_tail: &str) -> String {
    let canonical = if trimmed_tail.is_empty() {
        "info"
    } else if trimmed_tail == "socket" {
        "gateway"
    } else {
        trimmed_tail
    };
    canonical.to_owned()
}
/// Builds the local path for a `realtime` request under `prefix` (e.g. `/wss`).
fn rewrite_realtime_local_path(prefix: &str, tail_path: &str) -> String {
    let canonical_tail = normalize_realtime_service_tail(tail_path.trim_matches('/'));
    join_rewritten_path(prefix, &canonical_tail)
}
/// Reads `env_key` and returns its trimmed value.
///
/// Returns `None` when the variable is unset, not valid Unicode, or blank.
fn env_target_url(env_key: &str) -> Option<String> {
    let raw = std::env::var(env_key).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
async fn resolve_service_route_action(
req: &ServiceRequest,
service_key: ServiceKey,
) -> Result<Option<ServiceRouteAction>, HttpResponse> {
if let Some(route_key) = resolved_route_key(req) {
match lookup_service_route_record(req, &route_key, service_key).await {
Some(Ok(record)) => match route_record_to_action(record, service_key) {
Ok(action) => return Ok(Some(action)),
Err(err) => return Err(service_unavailable("Invalid service route", err)),
},
Some(Err(err)) => {
warn!(
route_key = %route_key,
service_key = %service_key.as_str(),
error = %err,
"Public service route lookup failed; falling back to built-in routing"
);
}
None => {}
}
}
if let Some(env_key) = service_key.default_upstream_env_var() {
if let Some(target_url) = env_target_url(env_key) {
return Ok(Some(ServiceRouteAction::HttpUpstream { target_url }));
}
}
Ok(service_key
.default_local_path_prefix()
.map(|prefix| ServiceRouteAction::LocalPath {
prefix: prefix.to_string(),
}))
}
/// Builds the 503 returned when no route, default upstream, or local fallback
/// exists for `service_key`. It names the env var that would configure one.
fn unconfigured_service_response(req: &ServiceRequest, service_key: ServiceKey) -> HttpResponse {
    let route_key = resolved_route_key(req).unwrap_or_else(|| "default".to_string());
    let detail = match service_key.default_upstream_env_var() {
        Some(env_key) => format!(
            "No active route exists for service '{}' on route '{}' and {} is unset.",
            service_key.as_str(),
            route_key,
            env_key
        ),
        None => format!(
            "No active route exists for service '{}' on route '{}'.",
            service_key.as_str(),
            route_key
        ),
    };
    service_unavailable("Service route unconfigured", detail)
}
/// Looks up the stored service route for (`route_key`, `service_key`).
///
/// Returns `None` in two cases: no record exists, or the logging database
/// (`AppState`, its `logging_client_name`, or the pool for it) is unavailable.
/// A query failure is returned as `Some(Err(message))`.
async fn lookup_service_route_record(
    req: &ServiceRequest,
    route_key: &str,
    service_key: ServiceKey,
) -> Option<Result<PublicServiceRouteRecord, String>> {
    let state = req.app_data::<web::Data<AppState>>()?;
    let client_name = state.logging_client_name.as_ref()?;
    let pool = state.pg_registry.get_pool(client_name)?;
    get_public_service_route_by_key(&pool, route_key, service_key.as_str())
        .await
        .map_err(|err| err.to_string())
        .transpose()
}
/// Converts a stored route record into a routing action.
///
/// An inactive record maps to [`ServiceRouteAction::Disabled`]. Active records need a
/// supported `target_kind` plus its non-blank target field (`local_path_prefix` or
/// `target_url`); otherwise a descriptive error message is returned.
fn route_record_to_action(
    record: PublicServiceRouteRecord,
    service_key: ServiceKey,
) -> Result<ServiceRouteAction, String> {
    if !record.is_active {
        return Ok(ServiceRouteAction::Disabled);
    }
    let service = service_key.as_str();
    // Trims an optional field and rejects it when missing or blank.
    let required = |value: Option<&str>, field: &str| -> Result<String, String> {
        value
            .map(str::trim)
            .filter(|trimmed| !trimmed.is_empty())
            .map(str::to_string)
            .ok_or_else(|| {
                format!(
                    "service route '{}' for '{}' is missing {}",
                    record.route_key, service, field
                )
            })
    };
    match ServiceRouteTargetKind::parse(&record.target_kind) {
        Some(ServiceRouteTargetKind::LocalPath) => Ok(ServiceRouteAction::LocalPath {
            prefix: required(record.local_path_prefix.as_deref(), "local_path_prefix")?,
        }),
        Some(ServiceRouteTargetKind::HttpUpstream) => Ok(ServiceRouteAction::HttpUpstream {
            target_url: required(record.target_url.as_deref(), "target_url")?,
        }),
        None => Err(format!(
            "service route '{}' for '{}' has unsupported target_kind '{}'",
            record.route_key, service, record.target_kind
        )),
    }
}
/// Route key for this request.
///
/// Prefers the key recorded in `ResolvedAthenaClientSource` and otherwise falls back
/// to the resolved client name.
fn resolved_route_key(req: &ServiceRequest) -> Option<String> {
    let explicit_route_key = req
        .extensions()
        .get::<ResolvedAthenaClientSource>()
        .and_then(|source| source.route_key.clone());
    explicit_route_key.or_else(|| resolved_client_name(req))
}
/// Client name for this request.
///
/// Prefers the `ResolvedAthenaClient` extension and otherwise uses the trimmed,
/// non-empty `X-Athena-Client` header.
fn resolved_client_name(req: &ServiceRequest) -> Option<String> {
    let from_extensions = req
        .extensions()
        .get::<ResolvedAthenaClient>()
        .map(|client| client.0.clone());
    from_extensions.or_else(|| {
        let raw = req.headers().get("X-Athena-Client")?.to_str().ok()?.trim();
        (!raw.is_empty()).then(|| raw.to_owned())
    })
}
/// Returns `true` when the request is an HTTP/1.1 WebSocket opening handshake.
///
/// RFC 6455 §4.1 requires BOTH an `Upgrade` header naming `websocket` and a
/// `Connection` header whose token list includes `Upgrade`. Requiring only one of them
/// misclassifies other upgrade attempts. For example, curl's h2c probe sends
/// `Connection: Upgrade, HTTP2-Settings` with `Upgrade: h2c`. That would turn a plain
/// `GET /realtime/info` into a websocket route and a 503 on local-path configs.
/// Token matching is ASCII case-insensitive over comma-separated values.
fn websocket_upgrade_requested(req: &ServiceRequest) -> bool {
    let headers = req.headers();
    header_has_token(headers.get(header::UPGRADE), "websocket")
        && header_has_token(headers.get(header::CONNECTION), "upgrade")
}

/// Returns `true` when a comma-separated header value contains `token`
/// (ASCII case-insensitive). Non-UTF-8/visible-ASCII values never match.
fn header_has_token(value: Option<&header::HeaderValue>, token: &str) -> bool {
    value
        .and_then(|value| value.to_str().ok())
        .is_some_and(|value| {
            value
                .split(',')
                .any(|segment| segment.trim().eq_ignore_ascii_case(token))
        })
}
fn rewrite_request_path(req: &mut ServiceRequest, new_path: &str) -> Result<(), Error> {
let mut parts = req.head().uri.clone().into_parts();
let query = parts
.path_and_query
.as_ref()
.and_then(|value| value.query())
.map(str::to_string);
let path_and_query = match query {
Some(query) => format!("{new_path}?{query}"),
None => new_path.to_string(),
};
parts.path_and_query = Some(
path_and_query
.parse::<PathAndQuery>()
.map_err(ErrorInternalServerError)?,
);
let uri = Uri::from_parts(parts).map_err(ErrorInternalServerError)?;
req.match_info_mut().get_mut().update(&uri);
req.head_mut().uri = uri;
Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for the pure path-resolution and rewriting helpers.
    use super::{
        ServiceKey, join_rewritten_path, normalized_service_tail, resolve_service_request_path,
        rewrite_db_local_path, rewrite_local_service_path, rewrite_realtime_local_path,
        split_service_path,
    };
    #[test]
    fn parses_db_service_path() {
        let (service_key, tail_path) = split_service_path("/db/query").expect("service route");
        assert_eq!(service_key, ServiceKey::Db);
        assert_eq!(tail_path, "query");
    }
    #[test]
    fn normalizes_realtime_alias() {
        let (service_key, tail_path) =
            split_service_path("/real-time/info").expect("service route");
        assert_eq!(service_key, ServiceKey::Realtime);
        assert_eq!(tail_path, "info");
    }
    #[test]
    fn service_segment_is_case_insensitive_and_tail_optional() {
        let (service_key, tail_path) = split_service_path("/DB").expect("service route");
        assert_eq!(service_key, ServiceKey::Db);
        assert_eq!(tail_path, "");
    }
    #[test]
    fn unknown_or_empty_service_segment_is_not_routed() {
        assert!(split_service_path("/unknown/thing").is_none());
        assert!(split_service_path("/").is_none());
        assert!(resolve_service_request_path("/gateway/fetch", None).is_none());
    }
    #[test]
    fn joins_rewritten_path_for_db_alias() {
        assert_eq!(join_rewritten_path("/gateway", "query"), "/gateway/query");
        assert_eq!(join_rewritten_path("/gateway", ""), "/gateway");
    }
    #[test]
    fn joins_root_and_slash_heavy_inputs() {
        assert_eq!(join_rewritten_path("", ""), "/");
        assert_eq!(join_rewritten_path("/", "/a/b/"), "/a/b");
        assert_eq!(join_rewritten_path("/storage/", "bucket"), "/storage/bucket");
    }
    #[test]
    fn keeps_canonical_db_gateway_suffix_stable() {
        assert_eq!(
            rewrite_db_local_path("/gateway", "gateway/fetch"),
            "/gateway/fetch"
        );
    }
    #[test]
    fn maps_short_db_fetch_suffix_to_gateway_surface() {
        assert_eq!(rewrite_db_local_path("/gateway", "fetch"), "/gateway/fetch");
    }
    #[test]
    fn db_root_and_query_surfaces() {
        assert_eq!(rewrite_db_local_path("/gateway", ""), "/gateway");
        assert_eq!(rewrite_db_local_path("/gateway", "query"), "/query");
    }
    #[test]
    fn custom_db_prefix_does_not_duplicate_trailing_segment() {
        assert_eq!(
            rewrite_db_local_path("/tenant/gateway", "fetch"),
            "/tenant/gateway/fetch"
        );
    }
    #[test]
    fn preserves_existing_rest_surface_under_db_prefix() {
        assert_eq!(
            rewrite_db_local_path("/gateway", "rest/v1/users"),
            "/rest/v1/users"
        );
    }
    #[test]
    fn rewrites_realtime_info_to_local_wss_info() {
        assert_eq!(rewrite_realtime_local_path("/wss", "info"), "/wss/info");
        assert_eq!(rewrite_realtime_local_path("/wss", ""), "/wss/info");
        assert_eq!(rewrite_realtime_local_path("/wss", "/socket/"), "/wss/gateway");
    }
    #[test]
    fn rewrites_realtime_socket_to_gateway_alias() {
        assert_eq!(
            rewrite_local_service_path(ServiceKey::Realtime, "/wss", "socket"),
            "/wss/gateway"
        );
    }
    #[test]
    fn storage_and_auth_tails_are_only_trimmed() {
        assert_eq!(
            rewrite_local_service_path(ServiceKey::Storage, "/storage", "bucket/object"),
            "/storage/bucket/object"
        );
        assert_eq!(normalized_service_tail(ServiceKey::Auth, "/token/"), "token");
    }
    #[test]
    fn normalizes_db_tail_for_upstream_routing() {
        assert_eq!(
            normalized_service_tail(ServiceKey::Db, "gateway/fetch"),
            "gateway/fetch"
        );
        assert_eq!(
            normalized_service_tail(ServiceKey::Db, "fetch"),
            "gateway/fetch"
        );
        assert_eq!(
            normalized_service_tail(ServiceKey::Db, "rest/v1/users"),
            "rest/v1/users"
        );
    }
    #[test]
    fn normalizes_realtime_tail_for_upstream_routing() {
        assert_eq!(normalized_service_tail(ServiceKey::Realtime, ""), "info");
        assert_eq!(
            normalized_service_tail(ServiceKey::Realtime, "socket"),
            "gateway"
        );
    }
    #[test]
    fn host_service_alias_uses_full_request_path() {
        let (service_key, tail_path) =
            resolve_service_request_path("/gateway/fetch", Some("db")).expect("service route");
        assert_eq!(service_key, ServiceKey::Db);
        assert_eq!(tail_path, "gateway/fetch");
    }
    #[test]
    fn explicit_service_path_wins_over_host_alias() {
        let (service_key, tail_path) =
            resolve_service_request_path("/storage/bucket/object", Some("db"))
                .expect("service route");
        assert_eq!(service_key, ServiceKey::Storage);
        assert_eq!(tail_path, "bucket/object");
    }
    #[test]
    fn host_alias_with_root_path_has_empty_tail() {
        let (service_key, tail_path) =
            resolve_service_request_path("/", Some("realtime")).expect("service route");
        assert_eq!(service_key, ServiceKey::Realtime);
        assert_eq!(tail_path, "");
    }
}