use actix_web::http::{StatusCode, header as actix_header};
use actix_web::{HttpRequest, HttpResponse};
use reqwest::Url;
use serde_json::Value;
use tracing::warn;
use crate::AppState;
use crate::api::gateway::contracts::GatewayOperationKind;
use crate::api::gateway::response::{
GATEWAY_ERROR_CODE_INVALID_ROUTE_TARGET, GATEWAY_ERROR_CODE_ROUTE_TARGET_LOOP_DETECTED,
GATEWAY_ERROR_CODE_ROUTE_TARGET_RESPONSE_UNAVAILABLE,
GATEWAY_ERROR_CODE_ROUTE_TARGET_UNAVAILABLE, gateway_bad_request_with_code,
gateway_service_unavailable_with_code,
};
/// Request header carrying the proxy hop count: read from incoming requests to detect
/// route-target loops and written (incremented by one) on every forwarded request.
const PROXY_HOP_HEADER: &str = "x-athena-proxy-hop";
/// Request header set on forwarded requests to the value of the incoming `Host` header.
const PROXY_ORIGIN_HOST_HEADER: &str = "x-athena-origin-host";
/// Combines a configured route target with the path and query of the incoming request.
///
/// The target's own path (minus trailing slashes) acts as a prefix for `request_path`, and
/// the target's own query string, when non-empty, is placed before `request_query`
/// (joined with `&`).
///
/// # Errors
///
/// Returns a human-readable message when `target_url` is blank, cannot be parsed, or does
/// not use an `http`/`https` scheme with a host.
pub(super) fn build_route_target_url(
    target_url: &str,
    request_path: &str,
    request_query: &str,
) -> Result<Url, String> {
    let candidate = target_url.trim();
    if candidate.is_empty() {
        return Err("route target URL is empty".to_string());
    }

    let mut url = match Url::parse(candidate) {
        Ok(url) => url,
        Err(err) => return Err(format!("route target URL is invalid: {err}")),
    };

    let scheme_ok = matches!(url.scheme(), "http" | "https");
    if !(scheme_ok && url.host_str().is_some()) {
        return Err("route target URL must include an http or https scheme and host".to_string());
    }

    // Compute the joined path as an owned string so the borrow of `url` ends before mutation.
    let joined_path = {
        let prefix = url.path().trim_end_matches('/');
        if !prefix.is_empty() {
            // Exactly one slash between the target prefix and the request path.
            format!("{prefix}/{}", request_path.trim_start_matches('/'))
        } else if request_path.starts_with('/') {
            request_path.to_string()
        } else {
            format!("/{request_path}")
        }
    };
    url.set_path(&joined_path);

    // Blank pieces are dropped; the request query is appended exactly as received.
    let merged_query = match (
        url.query().filter(|existing| !existing.is_empty()),
        Some(request_query).filter(|incoming| !incoming.trim().is_empty()),
    ) {
        (Some(existing), Some(incoming)) => Some(format!("{existing}&{incoming}")),
        (Some(only), None) | (None, Some(only)) => Some(only.to_string()),
        (None, None) => None,
    };
    url.set_query(merged_query.as_deref());

    Ok(url)
}
/// Reports whether a header must not be copied verbatim between the client-facing and
/// upstream sides of the proxy.
///
/// Besides connection-level headers, the list also contains `content-length`, `host` and
/// `accept-encoding`. The comparison ignores ASCII case.
fn header_is_hop_by_hop(name: &str) -> bool {
    const NOT_FORWARDED: [&str; 11] = [
        "connection",
        "content-length",
        "host",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "accept-encoding",
    ];
    NOT_FORWARDED
        .iter()
        .any(|&blocked| name.eq_ignore_ascii_case(blocked))
}
/// Reads the proxy hop count from the incoming request.
///
/// Yields `0` when the hop header is absent, is not valid visible ASCII, or does not parse
/// as a `u8` after trimming surrounding whitespace.
fn current_proxy_hop(req: &HttpRequest) -> u8 {
    let Some(raw) = req.headers().get(PROXY_HOP_HEADER) else {
        return 0;
    };
    match raw.to_str() {
        Ok(text) => text.trim().parse::<u8>().unwrap_or(0),
        Err(_) => 0,
    }
}
pub(crate) async fn proxy_fetch_to_route_target(
req: &HttpRequest,
app_state: &AppState,
target_url: &str,
client_name: &str,
body: Option<&Value>,
) -> HttpResponse {
let hop = current_proxy_hop(req);
if hop >= 1 {
return gateway_bad_request_with_code(
GATEWAY_ERROR_CODE_ROUTE_TARGET_LOOP_DETECTED,
GatewayOperationKind::Fetch,
"Route target loop detected",
"Refusing to proxy a request that already passed through an Athena route target.",
);
}
let upstream_url = match build_route_target_url(target_url, req.path(), req.query_string()) {
Ok(url) => url,
Err(err) => {
return gateway_service_unavailable_with_code(
GATEWAY_ERROR_CODE_INVALID_ROUTE_TARGET,
GatewayOperationKind::Fetch,
"Invalid route target",
err,
);
}
};
let method = req
.method()
.as_str()
.parse::<reqwest::Method>()
.unwrap_or(reqwest::Method::POST);
let mut upstream = app_state.client.request(method, upstream_url.clone());
for (name, value) in req.headers() {
let name_str = name.as_str();
if header_is_hop_by_hop(name_str)
|| name_str.eq_ignore_ascii_case(PROXY_HOP_HEADER)
|| name_str.eq_ignore_ascii_case("x-athena-client")
{
continue;
}
let Ok(reqwest_name) = reqwest::header::HeaderName::from_bytes(name_str.as_bytes()) else {
continue;
};
let Ok(reqwest_value) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) else {
continue;
};
upstream = upstream.header(reqwest_name, reqwest_value);
}
upstream = upstream
.header("X-Athena-Client", client_name)
.header(PROXY_HOP_HEADER, (hop + 1).to_string());
if let Some(host) = req
.headers()
.get(actix_header::HOST)
.and_then(|value| value.to_str().ok())
.filter(|value| !value.trim().is_empty())
{
upstream = upstream.header(PROXY_ORIGIN_HOST_HEADER, host);
}
if let Some(value) = body {
upstream = upstream.json(value);
}
let upstream_response = match upstream.send().await {
Ok(response) => response,
Err(err) => {
warn!(
target_url = %target_url,
upstream_url = %upstream_url,
client_name = %client_name,
error = %err,
"Failed to proxy gateway fetch to route target"
);
return gateway_service_unavailable_with_code(
GATEWAY_ERROR_CODE_ROUTE_TARGET_UNAVAILABLE,
GatewayOperationKind::Fetch,
"Route target unavailable",
err.to_string(),
);
}
};
let status = StatusCode::from_u16(upstream_response.status().as_u16())
.unwrap_or(StatusCode::BAD_GATEWAY);
let response_headers = upstream_response
.headers()
.iter()
.filter_map(|(name, value)| {
let name_str = name.as_str();
if header_is_hop_by_hop(name_str) {
return None;
}
let actix_name = actix_header::HeaderName::from_bytes(name_str.as_bytes()).ok()?;
let actix_value = actix_header::HeaderValue::from_bytes(value.as_bytes()).ok()?;
Some((actix_name, actix_value))
})
.collect::<Vec<_>>();
let response_body = match upstream_response.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
return gateway_service_unavailable_with_code(
GATEWAY_ERROR_CODE_ROUTE_TARGET_RESPONSE_UNAVAILABLE,
GatewayOperationKind::Fetch,
"Route target response unavailable",
err.to_string(),
);
}
};
let mut builder = HttpResponse::build(status);
for (name, value) in response_headers {
builder.append_header((name, value));
}
builder.body(response_body)
}
#[cfg(test)]
mod tests {
    use super::build_route_target_url;

    /// A target without a path adopts the request path and query as-is.
    #[test]
    fn builds_target_url_from_host_root() {
        let built = build_route_target_url(
            "https://mirror1.athena-cluster.com",
            "/gateway/fetch",
            "limit=10",
        );
        let url = built.expect("url");
        let expected = "https://mirror1.athena-cluster.com/gateway/fetch?limit=10";
        assert_eq!(url.as_str(), expected);
    }

    /// The target's own path prefixes the request path and its query comes first.
    #[test]
    fn preserves_target_base_path_and_query() {
        let built = build_route_target_url(
            "https://mirror1.athena-cluster.com/base?token=1",
            "/gateway/data",
            "page=2",
        );
        let url = built.expect("url");
        let expected = "https://mirror1.athena-cluster.com/base/gateway/data?token=1&page=2";
        assert_eq!(url.as_str(), expected);
    }

    /// Schemes other than http/https are refused with an explanatory message.
    #[test]
    fn rejects_non_http_targets() {
        let outcome = build_route_target_url("postgres://mirror1/db", "/gateway/fetch", "");
        let message = outcome.expect_err("expected error");
        assert!(message.contains("http or https"));
    }
}