use actix_web::http::{StatusCode, header as actix_header};
use actix_web::{HttpRequest, HttpResponse, web};
use reqwest::Url;
use tracing::warn;
use crate::AppState;
use crate::api::response::{bad_gateway, service_unavailable};
/// Hop counter set on every proxied upstream request. `proxy_service_request` refuses an
/// inbound request whose value parses to 1 or more (loop protection) and never forwards the
/// inbound copy.
pub(super) const PROXY_HOP_HEADER: &str = "x-athena-proxy-hop";
/// Carries the inbound request's non-blank `Host` value to the upstream target.
pub(super) const PROXY_ORIGIN_HOST_HEADER: &str = "x-athena-origin-host";
/// Carries the route key (when non-blank) to the upstream target; the inbound copy is dropped.
pub(super) const ROUTE_KEY_HEADER: &str = "x-athena-route-key";
/// Carries the service key to the upstream target; the inbound copy is dropped.
pub(super) const SERVICE_KEY_HEADER: &str = "x-athena-service-key";
/// Builds the upstream HTTP(S) URL for a service route request.
///
/// The path of `target_url` acts as a base: `request_tail_path` (with surrounding `/`
/// trimmed) is appended beneath it, and `request_query` is appended after any query string
/// already present on `target_url`.
///
/// # Errors
///
/// Returns a human-readable message when `target_url` is blank or unparsable, does not use
/// `http`/`https`, has no host, or when `request_tail_path` contains a `.`/`..` segment.
pub(super) fn build_service_target_url(
    target_url: &str,
    request_tail_path: &str,
    request_query: &str,
) -> Result<Url, String> {
    let mut parsed = parse_route_target_url(target_url)?;
    if !matches!(parsed.scheme(), "http" | "https") || parsed.host_str().is_none() {
        return Err(
            "service route target URL must include an http or https scheme and host".to_string(),
        );
    }
    append_tail_and_query(&mut parsed, request_tail_path, request_query)?;
    Ok(parsed)
}

/// Builds the upstream WebSocket URL for a realtime route request.
///
/// `http`/`https` targets are switched to `ws`/`wss`; `ws`/`wss` targets are used as-is.
/// Path and query are combined exactly as in [`build_service_target_url`].
///
/// # Errors
///
/// Returns a human-readable message when `target_url` is blank or unparsable, uses any other
/// scheme, has no host, or when `request_tail_path` contains a `.`/`..` segment.
pub(super) fn build_realtime_websocket_target_url(
    target_url: &str,
    request_tail_path: &str,
    request_query: &str,
) -> Result<Url, String> {
    let mut parsed = parse_route_target_url(target_url)?;
    let websocket_scheme = match parsed.scheme() {
        "http" => Some("ws"),
        "https" => Some("wss"),
        "ws" | "wss" => None,
        _ => {
            return Err(
                "realtime websocket target URL must use http(s) or ws(s) with a host".to_string(),
            );
        }
    };
    if let Some(scheme) = websocket_scheme {
        parsed
            .set_scheme(scheme)
            .map_err(|_| format!("failed to switch realtime target URL to {scheme}"))?;
    }
    if parsed.host_str().is_none() {
        return Err("realtime websocket target URL must include a host".to_string());
    }
    append_tail_and_query(&mut parsed, request_tail_path, request_query)?;
    Ok(parsed)
}

/// Trims `target_url` and parses it, rejecting blank input.
fn parse_route_target_url(target_url: &str) -> Result<Url, String> {
    let trimmed = target_url.trim();
    if trimmed.is_empty() {
        return Err("service route target URL is empty".to_string());
    }
    Url::parse(trimmed).map_err(|err| format!("service route target URL is invalid: {err}"))
}

/// Appends `request_tail_path` beneath the path of `url` and merges `request_query` after any
/// non-empty query already present on `url`.
///
/// Dot segments are rejected up front because `Url::set_path` resolves them (including the
/// `%2e` spellings and, for special schemes, `\` separators), which would otherwise let a
/// request climb out of the configured base path of the target.
fn append_tail_and_query(
    url: &mut Url,
    request_tail_path: &str,
    request_query: &str,
) -> Result<(), String> {
    let tail = request_tail_path.trim_matches('/');
    if tail
        .split(|c: char| c == '/' || c == '\\')
        .any(is_dot_path_segment)
    {
        return Err(
            "service route request path must not contain '.' or '..' segments".to_string(),
        );
    }

    let base = url.path().trim_end_matches('/');
    let next_path = match (base.is_empty(), tail.is_empty()) {
        (true, true) => "/".to_string(),
        (false, true) => base.to_string(),
        // With an empty base this yields "/{tail}", i.e. a root-level target.
        (_, false) => format!("{base}/{tail}"),
    };
    url.set_path(&next_path);

    let existing_query = url.query().filter(|value| !value.is_empty());
    let incoming_query = Some(request_query).filter(|value| !value.trim().is_empty());
    let merged_query = match (existing_query, incoming_query) {
        (Some(existing), Some(incoming)) => Some(format!("{existing}&{incoming}")),
        (Some(only), None) | (None, Some(only)) => Some(only.to_string()),
        (None, None) => None,
    };
    url.set_query(merged_query.as_deref());
    Ok(())
}

/// Reports whether `segment` is a `.` or `..` path segment, counting the percent-encoded
/// `%2e` spelling of `.` that URL parsing also resolves.
fn is_dot_path_segment(segment: &str) -> bool {
    matches!(
        segment.to_ascii_lowercase().replace("%2e", ".").as_str(),
        "." | ".."
    )
}
/// Returns `true` for header names that must not be copied across the proxy boundary.
///
/// The table holds the RFC 7230 hop-by-hop headers plus `content-length`, `host` and
/// `accept-encoding`, which are stripped here as well (presumably so the HTTP client/server
/// libraries supply their own values). Matching is ASCII case-insensitive.
pub(super) fn header_is_hop_by_hop(name: &str) -> bool {
    const STRIPPED_HEADERS: [&str; 11] = [
        "connection",
        "content-length",
        "host",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "accept-encoding",
    ];
    STRIPPED_HEADERS
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}
/// Reads the proxy hop counter from the inbound request.
///
/// Yields the whitespace-trimmed value of `PROXY_HOP_HEADER` parsed as a `u8`, or `0` when
/// the header is missing, is not representable as a string, or does not parse as a `u8`.
pub(super) fn current_proxy_hop(req: &HttpRequest) -> u8 {
    let Some(raw_hop) = req.headers().get(PROXY_HOP_HEADER) else {
        return 0;
    };
    raw_hop
        .to_str()
        .ok()
        .and_then(|text| text.trim().parse::<u8>().ok())
        .unwrap_or_default()
}
pub(super) async fn proxy_service_request(
req: &HttpRequest,
app_state: &AppState,
target_url: &str,
request_tail_path: &str,
service_key: &str,
route_key: Option<&str>,
client_name: Option<&str>,
body: web::Bytes,
) -> HttpResponse {
let hop = current_proxy_hop(req);
if hop >= 1 {
return bad_gateway(
"Service route loop detected",
"Refusing to proxy a request that already passed through an Athena service route target.",
);
}
let upstream_url =
match build_service_target_url(target_url, request_tail_path, req.query_string()) {
Ok(url) => url,
Err(err) => return service_unavailable("Invalid service route target", err),
};
let method = req
.method()
.as_str()
.parse::<reqwest::Method>()
.unwrap_or(reqwest::Method::GET);
let mut upstream = app_state.client.request(method, upstream_url.clone());
for (name, value) in req.headers() {
let name_str = name.as_str();
if header_is_hop_by_hop(name_str)
|| name_str.eq_ignore_ascii_case(PROXY_HOP_HEADER)
|| name_str.eq_ignore_ascii_case(ROUTE_KEY_HEADER)
|| name_str.eq_ignore_ascii_case(SERVICE_KEY_HEADER)
{
continue;
}
let Ok(reqwest_name) = reqwest::header::HeaderName::from_bytes(name_str.as_bytes()) else {
continue;
};
let Ok(reqwest_value) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) else {
continue;
};
upstream = upstream.header(reqwest_name, reqwest_value);
}
upstream = upstream
.header(PROXY_HOP_HEADER, (hop + 1).to_string())
.header(SERVICE_KEY_HEADER, service_key);
if let Some(route_key) = route_key.filter(|value| !value.trim().is_empty()) {
upstream = upstream.header(ROUTE_KEY_HEADER, route_key);
}
if let Some(client_name) = client_name.filter(|value| !value.trim().is_empty()) {
upstream = upstream.header("X-Athena-Client", client_name);
}
if let Some(host) = req
.headers()
.get(actix_header::HOST)
.and_then(|value| value.to_str().ok())
.filter(|value| !value.trim().is_empty())
{
upstream = upstream.header(PROXY_ORIGIN_HOST_HEADER, host);
}
if !body.is_empty() {
upstream = upstream.body(body.clone());
}
let upstream_response = match upstream.send().await {
Ok(response) => response,
Err(err) => {
warn!(
target_url = %target_url,
upstream_url = %upstream_url,
service_key = %service_key,
route_key = ?route_key,
client_name = ?client_name,
error = %err,
"Failed to proxy service route request"
);
return bad_gateway("Service route unavailable", err.to_string());
}
};
let status = StatusCode::from_u16(upstream_response.status().as_u16())
.unwrap_or(StatusCode::BAD_GATEWAY);
let response_headers = upstream_response
.headers()
.iter()
.filter_map(|(name, value)| {
let name_str = name.as_str();
if header_is_hop_by_hop(name_str) {
return None;
}
let actix_name = actix_header::HeaderName::from_bytes(name_str.as_bytes()).ok()?;
let actix_value = actix_header::HeaderValue::from_bytes(value.as_bytes()).ok()?;
Some((actix_name, actix_value))
})
.collect::<Vec<_>>();
let response_body = match upstream_response.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
return bad_gateway("Service route response unavailable", err.to_string());
}
};
let mut builder = HttpResponse::build(status);
for (name, value) in response_headers {
builder.append_header((name, value));
}
builder.body(response_body)
}
#[cfg(test)]
mod tests {
    use super::{build_realtime_websocket_target_url, build_service_target_url};

    #[test]
    fn builds_service_target_url_from_host_root() {
        let url =
            build_service_target_url("https://auth.example.com", "sessions/verify", "token=1")
                .expect("url");
        assert_eq!(
            url.as_str(),
            "https://auth.example.com/sessions/verify?token=1"
        );
    }

    #[test]
    fn preserves_base_path_and_query() {
        let url = build_service_target_url(
            "https://auth.example.com/api/auth?mode=tenant",
            "callback/google",
            "code=abc",
        )
        .expect("url");
        assert_eq!(
            url.as_str(),
            "https://auth.example.com/api/auth/callback/google?mode=tenant&code=abc"
        );
    }

    #[test]
    fn empty_tail_and_query_keep_target_root() {
        let url = build_service_target_url("https://auth.example.com/", "/", "").expect("url");
        assert_eq!(url.as_str(), "https://auth.example.com/");
    }

    #[test]
    fn empty_tail_keeps_base_path_without_trailing_slash() {
        let url = build_service_target_url("https://auth.example.com/api/", "", "").expect("url");
        assert_eq!(url.as_str(), "https://auth.example.com/api");
    }

    #[test]
    fn rejects_blank_unparsable_or_non_http_service_targets() {
        assert!(build_service_target_url("   ", "x", "").is_err());
        assert!(build_service_target_url("not a url", "x", "").is_err());
        assert!(build_service_target_url("ftp://files.example.com", "x", "").is_err());
        assert!(build_service_target_url("ws://realtime.example.com", "x", "").is_err());
    }

    #[test]
    fn derives_realtime_websocket_url_from_https_target() {
        let url = build_realtime_websocket_target_url(
            "https://realtime.example.com/base?mode=tenant",
            "socket",
            "token=1",
        )
        .expect("url");
        assert_eq!(
            url.as_str(),
            "wss://realtime.example.com/base/socket?mode=tenant&token=1"
        );
    }

    #[test]
    fn derives_ws_url_from_http_target_keeping_port() {
        let url = build_realtime_websocket_target_url(
            "http://realtime.example.com:8080/base/",
            "",
            "a=1",
        )
        .expect("url");
        assert_eq!(url.as_str(), "ws://realtime.example.com:8080/base?a=1");
    }

    #[test]
    fn keeps_websocket_scheme_for_ws_targets() {
        let url = build_realtime_websocket_target_url("ws://realtime.example.com", "/socket/", "")
            .expect("url");
        assert_eq!(url.as_str(), "ws://realtime.example.com/socket");
    }

    #[test]
    fn rejects_unsupported_realtime_scheme() {
        assert!(build_realtime_websocket_target_url("ftp://realtime.example.com", "", "").is_err());
        assert!(build_realtime_websocket_target_url("", "", "").is_err());
    }
}