athena_rs 3.23.0

Hyper-performant polyglot database driver
Documentation
use actix_web::http::{StatusCode, header as actix_header};
use actix_web::{HttpRequest, HttpResponse};
use reqwest::Url;
use serde_json::Value;
use tracing::warn;

use crate::AppState;
use crate::api::gateway::contracts::GatewayOperationKind;
use crate::api::gateway::response::{
    GATEWAY_ERROR_CODE_INVALID_ROUTE_TARGET, GATEWAY_ERROR_CODE_ROUTE_TARGET_LOOP_DETECTED,
    GATEWAY_ERROR_CODE_ROUTE_TARGET_RESPONSE_UNAVAILABLE,
    GATEWAY_ERROR_CODE_ROUTE_TARGET_UNAVAILABLE, gateway_bad_request_with_code,
    gateway_service_unavailable_with_code,
};

/// Header carrying the proxy hop count: read from the inbound request to detect loops and
/// written (incremented by one) on the outgoing upstream request.
const PROXY_HOP_HEADER: &str = "x-athena-proxy-hop";
/// Header set on the outgoing upstream request to the inbound request's `Host` value,
/// when that value is present and non-blank.
const PROXY_ORIGIN_HOST_HEADER: &str = "x-athena-origin-host";

/// Combines a configured route target with the path and query of an incoming request.
///
/// `target_url` is trimmed and must parse as an absolute `http` or `https` URL with a host.
/// Its path (minus trailing slashes) is used as a prefix for `request_path`, and its query,
/// if non-empty, is kept in front of `request_query`, joined with `&`.
///
/// # Errors
///
/// Returns a human-readable message when the target is empty, fails to parse, or does not
/// use an `http`/`https` scheme with a host.
pub(super) fn build_route_target_url(
    target_url: &str,
    request_path: &str,
    request_query: &str,
) -> Result<Url, String> {
    let candidate = target_url.trim();
    if candidate.is_empty() {
        return Err("route target URL is empty".to_string());
    }

    let mut url = match Url::parse(candidate) {
        Ok(url) => url,
        Err(err) => return Err(format!("route target URL is invalid: {err}")),
    };
    let scheme_ok = matches!(url.scheme(), "http" | "https");
    if !(scheme_ok && url.host_str().is_some()) {
        return Err("route target URL must include an http or https scheme and host".to_string());
    }

    // An empty prefix means the target points at the host root, so the request path is used
    // as-is (with a leading slash guaranteed); otherwise exactly one slash joins the two.
    let joined_path = match url.path().trim_end_matches('/') {
        "" if request_path.starts_with('/') => request_path.to_string(),
        "" => format!("/{request_path}"),
        prefix => format!("{prefix}/{}", request_path.trim_start_matches('/')),
    };
    url.set_path(&joined_path);

    // The target's own query comes first; a whitespace-only request query counts as absent.
    let inherited = url.query().filter(|query| !query.is_empty());
    let forwarded = Some(request_query).filter(|query| !query.trim().is_empty());
    let merged_query = match (inherited, forwarded) {
        (Some(first), Some(second)) => Some(format!("{first}&{second}")),
        (Some(only), None) | (None, Some(only)) => Some(only.to_string()),
        (None, None) => None,
    };
    url.set_query(merged_query.as_deref());

    Ok(url)
}

/// Reports whether `name` is on the list of headers the proxy refuses to copy between the
/// inbound and upstream sides. The comparison ignores ASCII case.
fn header_is_hop_by_hop(name: &str) -> bool {
    // Connection-level headers plus a few (`content-length`, `host`, `accept-encoding`)
    // that are likewise never copied across.
    const SKIPPED_HEADERS: [&str; 11] = [
        "connection",
        "content-length",
        "host",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "accept-encoding",
    ];

    SKIPPED_HEADERS
        .iter()
        .any(|skipped| name.eq_ignore_ascii_case(skipped))
}

/// Reads the hop counter from the inbound request's `PROXY_HOP_HEADER`.
///
/// Yields `0` when the header is missing, is not valid visible ASCII, or does not parse
/// as a `u8` after trimming whitespace.
fn current_proxy_hop(req: &HttpRequest) -> u8 {
    let Some(raw) = req.headers().get(PROXY_HOP_HEADER) else {
        return 0;
    };
    match raw.to_str() {
        Ok(text) => text.trim().parse::<u8>().unwrap_or(0),
        Err(_) => 0,
    }
}

/// Forwards the inbound gateway fetch request to a route target and relays its reply.
///
/// The upstream URL is derived from `target_url` plus the inbound request's path and query
/// string. Inbound headers are copied to the upstream request except for those rejected by
/// `header_is_hop_by_hop` and the gateway-owned headers (`x-athena-proxy-hop`,
/// `x-athena-client`, `x-athena-origin-host`), which are always set by this function so a
/// caller cannot supply its own values. When `body` is `Some`, it is sent as JSON.
///
/// # Returns
///
/// * The loop-detected response from `gateway_bad_request_with_code` when the inbound request
///   already carries a hop count of one or more.
/// * A response from `gateway_service_unavailable_with_code` when the target URL is invalid,
///   the upstream call fails, or the upstream body cannot be read.
/// * Otherwise, a response with the upstream status, the upstream headers that pass the
///   `header_is_hop_by_hop` filter, and the upstream body bytes.
pub(crate) async fn proxy_fetch_to_route_target(
    req: &HttpRequest,
    app_state: &AppState,
    target_url: &str,
    client_name: &str,
    body: Option<&Value>,
) -> HttpResponse {
    // A request that already went through one route target must not be proxied again.
    let hop = current_proxy_hop(req);
    if hop >= 1 {
        return gateway_bad_request_with_code(
            GATEWAY_ERROR_CODE_ROUTE_TARGET_LOOP_DETECTED,
            GatewayOperationKind::Fetch,
            "Route target loop detected",
            "Refusing to proxy a request that already passed through an Athena route target.",
        );
    }

    let upstream_url = match build_route_target_url(target_url, req.path(), req.query_string()) {
        Ok(url) => url,
        Err(err) => {
            return gateway_service_unavailable_with_code(
                GATEWAY_ERROR_CODE_INVALID_ROUTE_TARGET,
                GatewayOperationKind::Fetch,
                "Invalid route target",
                err,
            );
        }
    };
    // Mirror the inbound method; fall back to POST if it cannot be converted.
    let method = req
        .method()
        .as_str()
        .parse::<reqwest::Method>()
        .unwrap_or(reqwest::Method::POST);

    let mut upstream = app_state.client.request(method, upstream_url.clone());
    for (name, value) in req.headers() {
        let name_str = name.as_str();
        // Gateway-owned headers are dropped here and re-set below. `RequestBuilder::header`
        // appends instead of replacing, so a caller-supplied copy would otherwise reach the
        // route target alongside (or instead of) the value written by the gateway.
        if header_is_hop_by_hop(name_str)
            || name_str.eq_ignore_ascii_case(PROXY_HOP_HEADER)
            || name_str.eq_ignore_ascii_case(PROXY_ORIGIN_HOST_HEADER)
            || name_str.eq_ignore_ascii_case("x-athena-client")
        {
            continue;
        }
        // Headers that do not convert cleanly between the two HTTP crates are skipped.
        let Ok(reqwest_name) = reqwest::header::HeaderName::from_bytes(name_str.as_bytes()) else {
            continue;
        };
        let Ok(reqwest_value) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) else {
            continue;
        };
        upstream = upstream.header(reqwest_name, reqwest_value);
    }

    upstream = upstream
        .header("X-Athena-Client", client_name)
        .header(PROXY_HOP_HEADER, (hop + 1).to_string());
    // Pass the original `Host` along so the route target can see which host was addressed.
    if let Some(host) = req
        .headers()
        .get(actix_header::HOST)
        .and_then(|value| value.to_str().ok())
        .filter(|value| !value.trim().is_empty())
    {
        upstream = upstream.header(PROXY_ORIGIN_HOST_HEADER, host);
    }
    if let Some(value) = body {
        upstream = upstream.json(value);
    }

    let upstream_response = match upstream.send().await {
        Ok(response) => response,
        Err(err) => {
            warn!(
                target_url = %target_url,
                upstream_url = %upstream_url,
                client_name = %client_name,
                error = %err,
                "Failed to proxy gateway fetch to route target"
            );
            return gateway_service_unavailable_with_code(
                GATEWAY_ERROR_CODE_ROUTE_TARGET_UNAVAILABLE,
                GatewayOperationKind::Fetch,
                "Route target unavailable",
                err.to_string(),
            );
        }
    };

    // Status and headers are captured first because reading the body consumes the response.
    let status = StatusCode::from_u16(upstream_response.status().as_u16())
        .unwrap_or(StatusCode::BAD_GATEWAY);
    let response_headers = upstream_response
        .headers()
        .iter()
        .filter_map(|(name, value)| {
            let name_str = name.as_str();
            if header_is_hop_by_hop(name_str) {
                return None;
            }
            let actix_name = actix_header::HeaderName::from_bytes(name_str.as_bytes()).ok()?;
            let actix_value = actix_header::HeaderValue::from_bytes(value.as_bytes()).ok()?;
            Some((actix_name, actix_value))
        })
        .collect::<Vec<_>>();
    let response_body = match upstream_response.bytes().await {
        Ok(bytes) => bytes,
        Err(err) => {
            return gateway_service_unavailable_with_code(
                GATEWAY_ERROR_CODE_ROUTE_TARGET_RESPONSE_UNAVAILABLE,
                GatewayOperationKind::Fetch,
                "Route target response unavailable",
                err.to_string(),
            );
        }
    };

    let mut builder = HttpResponse::build(status);
    for (name, value) in response_headers {
        builder.append_header((name, value));
    }
    builder.body(response_body)
}

#[cfg(test)]
mod tests {
    use super::build_route_target_url;

    /// A target with no path takes the request path and query unchanged.
    #[test]
    fn builds_target_url_from_host_root() {
        let target = "https://mirror1.athena-cluster.com";
        let built = build_route_target_url(target, "/gateway/fetch", "limit=10").expect("url");

        let expected = "https://mirror1.athena-cluster.com/gateway/fetch?limit=10";
        assert_eq!(built.as_str(), expected);
    }

    /// A target with its own path and query keeps both, ahead of the request's.
    #[test]
    fn preserves_target_base_path_and_query() {
        let target = "https://mirror1.athena-cluster.com/base?token=1";
        let built = build_route_target_url(target, "/gateway/data", "page=2").expect("url");

        let expected = "https://mirror1.athena-cluster.com/base/gateway/data?token=1&page=2";
        assert_eq!(built.as_str(), expected);
    }

    /// Targets with a scheme other than http/https are refused with an explanatory message.
    #[test]
    fn rejects_non_http_targets() {
        let outcome = build_route_target_url("postgres://mirror1/db", "/gateway/fetch", "");
        let message = outcome.expect_err("expected error");

        assert!(message.contains("http or https"));
    }
}