// athena_rs 3.22.1
//
// Hyper-performant polyglot database driver
// Documentation
use actix_web::http::{StatusCode, header as actix_header};
use actix_web::{HttpRequest, HttpResponse, web};
use reqwest::Url;
use tracing::warn;

use crate::AppState;
use crate::api::response::{bad_gateway, service_unavailable};

/// Hop counter stamped on every proxied request (inbound value + 1). In this
/// module, an inbound value of 1 or more makes `proxy_service_request` refuse
/// to proxy again, which guards against service-route loops.
pub(super) const PROXY_HOP_HEADER: &str = "x-athena-proxy-hop";
/// Carries the inbound request's `Host` header value to the upstream target
/// (set only when the inbound `Host` is present and non-blank).
pub(super) const PROXY_ORIGIN_HOST_HEADER: &str = "x-athena-origin-host";
/// Route key sent to the upstream when one is known; any copy supplied by the
/// caller is dropped before forwarding.
pub(super) const ROUTE_KEY_HEADER: &str = "x-athena-route-key";
/// Service key always sent to the upstream; any copy supplied by the caller is
/// dropped before forwarding.
pub(super) const SERVICE_KEY_HEADER: &str = "x-athena-service-key";

/// Builds the upstream HTTP(S) URL for a service route.
///
/// `target_url` is the configured route target; after trimming it must parse as
/// an `http` or `https` URL with a host. `request_tail_path` (leading/trailing
/// slashes ignored) is appended to the target's path. `request_query` is merged
/// after any query string the target already carries, joined with `&`.
///
/// # Errors
///
/// Returns a human-readable message when the target is empty or unparsable,
/// does not use `http`/`https`, or has no host. An error is also returned when
/// the tail path contains `.`/`..` segments that would escape the target's
/// base path.
pub(super) fn build_service_target_url(
    target_url: &str,
    request_tail_path: &str,
    request_query: &str,
) -> Result<Url, String> {
    let mut parsed = parse_route_target(target_url)?;
    if !matches!(parsed.scheme(), "http" | "https") || parsed.host_str().is_none() {
        return Err(
            "service route target URL must include an http or https scheme and host".to_string(),
        );
    }

    append_tail_and_query(&mut parsed, request_tail_path, request_query)?;
    Ok(parsed)
}

/// Builds the upstream WebSocket URL for a realtime service route.
///
/// Accepts `http`/`https` targets, which are switched to `ws`/`wss`, or native
/// `ws`/`wss` targets. The tail path and query are merged exactly as in
/// [`build_service_target_url`].
///
/// # Errors
///
/// Returns a human-readable message when the target is empty or unparsable,
/// uses any other scheme, or has no host. An error is also returned when the
/// scheme switch fails, or when the tail path contains `.`/`..` segments.
pub(super) fn build_realtime_websocket_target_url(
    target_url: &str,
    request_tail_path: &str,
    request_query: &str,
) -> Result<Url, String> {
    let mut parsed = parse_route_target(target_url)?;
    match parsed.scheme() {
        "http" => parsed
            .set_scheme("ws")
            .map_err(|_| "failed to switch realtime target URL to ws".to_string())?,
        "https" => parsed
            .set_scheme("wss")
            .map_err(|_| "failed to switch realtime target URL to wss".to_string())?,
        "ws" | "wss" => {}
        _ => {
            return Err(
                "realtime websocket target URL must use http(s) or ws(s) with a host".to_string(),
            );
        }
    }

    if parsed.host_str().is_none() {
        return Err("realtime websocket target URL must include a host".to_string());
    }

    append_tail_and_query(&mut parsed, request_tail_path, request_query)?;
    Ok(parsed)
}

/// Trims and parses a configured service route target URL (no scheme checks).
fn parse_route_target(target_url: &str) -> Result<Url, String> {
    let trimmed = target_url.trim();
    if trimmed.is_empty() {
        return Err("service route target URL is empty".to_string());
    }
    Url::parse(trimmed).map_err(|err| format!("service route target URL is invalid: {err}"))
}

/// Appends the request tail path to `url`'s path and merges the request query
/// after the URL's existing (non-empty) query.
fn append_tail_and_query(
    url: &mut Url,
    request_tail_path: &str,
    request_query: &str,
) -> Result<(), String> {
    let tail = request_tail_path.trim_matches('/');
    // `Url::set_path` resolves dot segments, including `%2e` spellings, and for
    // special schemes `\` also acts as a separator. A client-controlled tail such
    // as `../admin` would therefore climb out of the target's base path.
    // Reject it instead of letting the URL parser normalize it away.
    if tail.split(|c: char| c == '/' || c == '\\').any(is_dot_segment) {
        return Err(
            "service route request path must not contain '.' or '..' segments".to_string(),
        );
    }

    let base_path = url.path().trim_end_matches('/');
    let next_path = match (base_path.is_empty(), tail.is_empty()) {
        (true, true) => "/".to_string(),
        (false, true) => base_path.to_string(),
        (true, false) => format!("/{tail}"),
        (false, false) => format!("{base_path}/{tail}"),
    };
    url.set_path(&next_path);

    let mut query_parts: Vec<String> = Vec::with_capacity(2);
    if let Some(existing_query) = url.query().filter(|value| !value.is_empty()) {
        query_parts.push(existing_query.to_owned());
    }
    if !request_query.trim().is_empty() {
        query_parts.push(request_query.to_owned());
    }
    // An empty merge clears the query entirely, dropping a dangling `?`.
    if query_parts.is_empty() {
        url.set_query(None);
    } else {
        url.set_query(Some(&query_parts.join("&")));
    }

    Ok(())
}

/// Returns `true` for `.`/`..` path segments, including percent-encoded `%2e` forms.
fn is_dot_segment(segment: &str) -> bool {
    let decoded = segment.to_ascii_lowercase().replace("%2e", ".");
    decoded == "." || decoded == ".."
}

/// Returns `true` when a header must not be copied across the proxy boundary
/// (used for both request and response headers); the comparison is case-insensitive.
///
/// The list covers the hop-by-hop fields RFC 9110 §7.6.1 tells intermediaries to
/// remove (`connection`, `proxy-connection`, `keep-alive`, `te`,
/// `transfer-encoding`, `upgrade`) plus `trailer` and the proxy auth headers.
/// It also includes headers this proxy re-derives itself:
/// - `content-length`: the body is re-sent, so its framing is recomputed.
/// - `host`: the upstream authority comes from the target URL.
/// - `accept-encoding`: presumably stripped so the outbound HTTP client
///   negotiates encoding on its own. NOTE(review): confirm intent.
pub(super) fn header_is_hop_by_hop(name: &str) -> bool {
    const STRIPPED_HEADERS: &[&str] = &[
        "connection",
        "content-length",
        "host",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "accept-encoding",
    ];

    // Case-insensitive scan avoids allocating a lowercased copy per header.
    STRIPPED_HEADERS
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}

/// Reads the Athena proxy hop counter from the inbound request.
///
/// Returns `0` in any of these cases:
/// - the first `x-athena-proxy-hop` header is absent;
/// - its value is not visible-ASCII text;
/// - after trimming whitespace, it is not an integer in `0..=255`.
pub(super) fn current_proxy_hop(req: &HttpRequest) -> u8 {
    let Some(raw_value) = req.headers().get(PROXY_HOP_HEADER) else {
        return 0;
    };

    match raw_value.to_str() {
        Ok(text) => text.trim().parse::<u8>().unwrap_or(0),
        Err(_) => 0,
    }
}

pub(super) async fn proxy_service_request(
    req: &HttpRequest,
    app_state: &AppState,
    target_url: &str,
    request_tail_path: &str,
    service_key: &str,
    route_key: Option<&str>,
    client_name: Option<&str>,
    body: web::Bytes,
) -> HttpResponse {
    let hop = current_proxy_hop(req);
    if hop >= 1 {
        return bad_gateway(
            "Service route loop detected",
            "Refusing to proxy a request that already passed through an Athena service route target.",
        );
    }

    let upstream_url =
        match build_service_target_url(target_url, request_tail_path, req.query_string()) {
            Ok(url) => url,
            Err(err) => return service_unavailable("Invalid service route target", err),
        };
    let method = req
        .method()
        .as_str()
        .parse::<reqwest::Method>()
        .unwrap_or(reqwest::Method::GET);

    let mut upstream = app_state.client.request(method, upstream_url.clone());
    for (name, value) in req.headers() {
        let name_str = name.as_str();
        if header_is_hop_by_hop(name_str)
            || name_str.eq_ignore_ascii_case(PROXY_HOP_HEADER)
            || name_str.eq_ignore_ascii_case(ROUTE_KEY_HEADER)
            || name_str.eq_ignore_ascii_case(SERVICE_KEY_HEADER)
        {
            continue;
        }
        let Ok(reqwest_name) = reqwest::header::HeaderName::from_bytes(name_str.as_bytes()) else {
            continue;
        };
        let Ok(reqwest_value) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) else {
            continue;
        };
        upstream = upstream.header(reqwest_name, reqwest_value);
    }

    upstream = upstream
        .header(PROXY_HOP_HEADER, (hop + 1).to_string())
        .header(SERVICE_KEY_HEADER, service_key);
    if let Some(route_key) = route_key.filter(|value| !value.trim().is_empty()) {
        upstream = upstream.header(ROUTE_KEY_HEADER, route_key);
    }
    if let Some(client_name) = client_name.filter(|value| !value.trim().is_empty()) {
        upstream = upstream.header("X-Athena-Client", client_name);
    }
    if let Some(host) = req
        .headers()
        .get(actix_header::HOST)
        .and_then(|value| value.to_str().ok())
        .filter(|value| !value.trim().is_empty())
    {
        upstream = upstream.header(PROXY_ORIGIN_HOST_HEADER, host);
    }
    if !body.is_empty() {
        upstream = upstream.body(body.clone());
    }

    let upstream_response = match upstream.send().await {
        Ok(response) => response,
        Err(err) => {
            warn!(
                target_url = %target_url,
                upstream_url = %upstream_url,
                service_key = %service_key,
                route_key = ?route_key,
                client_name = ?client_name,
                error = %err,
                "Failed to proxy service route request"
            );
            return bad_gateway("Service route unavailable", err.to_string());
        }
    };

    let status = StatusCode::from_u16(upstream_response.status().as_u16())
        .unwrap_or(StatusCode::BAD_GATEWAY);
    let response_headers = upstream_response
        .headers()
        .iter()
        .filter_map(|(name, value)| {
            let name_str = name.as_str();
            if header_is_hop_by_hop(name_str) {
                return None;
            }
            let actix_name = actix_header::HeaderName::from_bytes(name_str.as_bytes()).ok()?;
            let actix_value = actix_header::HeaderValue::from_bytes(value.as_bytes()).ok()?;
            Some((actix_name, actix_value))
        })
        .collect::<Vec<_>>();
    let response_body = match upstream_response.bytes().await {
        Ok(bytes) => bytes,
        Err(err) => {
            return bad_gateway("Service route response unavailable", err.to_string());
        }
    };

    let mut builder = HttpResponse::build(status);
    for (name, value) in response_headers {
        builder.append_header((name, value));
    }
    builder.body(response_body)
}

#[cfg(test)]
mod tests {
    //! URL-construction tests for service and realtime route targets.

    use super::{build_realtime_websocket_target_url, build_service_target_url};

    #[test]
    fn builds_service_target_url_from_host_root() {
        let url =
            build_service_target_url("https://auth.example.com", "sessions/verify", "token=1")
                .expect("url");

        assert_eq!(
            url.as_str(),
            "https://auth.example.com/sessions/verify?token=1"
        );
    }

    #[test]
    fn preserves_base_path_and_query() {
        let url = build_service_target_url(
            "https://auth.example.com/api/auth?mode=tenant",
            "callback/google",
            "code=abc",
        )
        .expect("url");

        assert_eq!(
            url.as_str(),
            "https://auth.example.com/api/auth/callback/google?mode=tenant&code=abc"
        );
    }

    #[test]
    fn collapses_empty_tail_and_query_to_base_path() {
        let root = build_service_target_url("https://auth.example.com/", "", "").expect("url");
        assert_eq!(root.as_str(), "https://auth.example.com/");

        // A slash-only tail and whitespace-only query add nothing; the base
        // path loses its trailing slash.
        let nested =
            build_service_target_url("https://auth.example.com/api/", "/", "  ").expect("url");
        assert_eq!(nested.as_str(), "https://auth.example.com/api");
    }

    #[test]
    fn trims_slashes_between_base_and_tail() {
        let url = build_service_target_url("https://auth.example.com/api/", "/sessions/", "")
            .expect("url");

        assert_eq!(url.as_str(), "https://auth.example.com/api/sessions");
    }

    #[test]
    fn drops_dangling_question_mark_when_no_query() {
        let url = build_service_target_url("https://auth.example.com/api?", "me", "")
            .expect("url");

        assert_eq!(url.as_str(), "https://auth.example.com/api/me");
    }

    #[test]
    fn rejects_invalid_service_targets() {
        assert!(build_service_target_url("", "x", "").is_err());
        assert!(build_service_target_url("   ", "x", "").is_err());
        assert!(build_service_target_url("not a url", "x", "").is_err());
        assert!(build_service_target_url("ftp://files.example.com", "x", "").is_err());
        assert!(build_service_target_url("wss://realtime.example.com", "x", "").is_err());
    }

    #[test]
    fn derives_realtime_websocket_url_from_https_target() {
        let url = build_realtime_websocket_target_url(
            "https://realtime.example.com/base?mode=tenant",
            "socket",
            "token=1",
        )
        .expect("url");

        assert_eq!(
            url.as_str(),
            "wss://realtime.example.com/base/socket?mode=tenant&token=1"
        );
    }

    #[test]
    fn derives_realtime_websocket_url_from_http_target_keeping_port() {
        let url = build_realtime_websocket_target_url(
            "http://realtime.example.com:8080/v1",
            "socket",
            "",
        )
        .expect("url");

        assert_eq!(url.as_str(), "ws://realtime.example.com:8080/v1/socket");
    }

    #[test]
    fn keeps_native_websocket_schemes() {
        let url =
            build_realtime_websocket_target_url("wss://realtime.example.com", "", "").expect("url");

        assert_eq!(url.as_str(), "wss://realtime.example.com/");
    }

    #[test]
    fn rejects_unsupported_realtime_targets() {
        assert!(build_realtime_websocket_target_url("", "", "").is_err());
        assert!(build_realtime_websocket_target_url("ftp://realtime.example.com", "", "").is_err());
    }
}