athena_rs 3.26.2

Hyper-performant polyglot database driver
Documentation
use actix_web::dev::ServiceRequest;
use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, web};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::Message as UpstreamWebsocketMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use tracing::warn;

use crate::api::response::{bad_gateway, not_found};

use super::proxy::{
    PROXY_HOP_HEADER, PROXY_ORIGIN_HOST_HEADER, ROUTE_KEY_HEADER, SERVICE_KEY_HEADER,
    build_realtime_websocket_target_url, current_proxy_hop, header_is_hop_by_hop,
};

/// Path of the internal resource that `services` registers for the realtime websocket proxy
/// handler (`proxy_realtime_websocket`).
///
/// Requests that reach this path without prepared proxy state (see
/// `prepare_realtime_websocket_proxy`) are answered with a `404`.
pub(super) const INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH: &str =
    "/__athena_internal/service-router/realtime-websocket-proxy";

/// Upstream routing details for a single realtime websocket request.
///
/// Stored in the request extensions by `prepare_realtime_websocket_proxy` and cloned back out
/// by `proxy_realtime_websocket` when the internal proxy route runs.
#[derive(Debug, Clone)]
struct PreparedRealtimeWebsocketProxy {
    /// Upstream target; passed to `build_realtime_websocket_target_url` together with
    /// `request_tail_path` and the client's query string.
    target_url: String,
    /// Remaining request path used to build the upstream websocket URL.
    request_tail_path: String,
    /// Optional route key; sent upstream in `ROUTE_KEY_HEADER` when non-blank after trimming.
    route_key: Option<String>,
    /// Optional client name; sent upstream as `x-athena-client` when non-blank after trimming.
    client_name: Option<String>,
}

/// Records where a realtime websocket request should be proxied to.
///
/// The details are attached to `req`'s extensions as a `PreparedRealtimeWebsocketProxy`, which
/// the internal proxy route later reads. Any previously prepared value is replaced.
///
/// # Errors
///
/// Never fails today; the `Result` return keeps the signature open for future validation.
pub(super) fn prepare_realtime_websocket_proxy(
    req: &mut ServiceRequest,
    target_url: &str,
    request_tail_path: &str,
    route_key: Option<&str>,
    client_name: Option<&str>,
) -> Result<(), Error> {
    let prepared = PreparedRealtimeWebsocketProxy {
        target_url: String::from(target_url),
        request_tail_path: String::from(request_tail_path),
        route_key: route_key.map(ToOwned::to_owned),
        client_name: client_name.map(ToOwned::to_owned),
    };
    req.extensions_mut().insert(prepared);
    Ok(())
}

/// Handles the internal realtime websocket proxy route by bridging the client to the upstream.
///
/// Requires a `PreparedRealtimeWebsocketProxy` in the request extensions (stored by
/// `prepare_realtime_websocket_proxy`); without it the route answers `404`.
///
/// Flow:
/// 1. Refuse requests whose proxy hop count is already 1 or more (loop protection, `502`).
/// 2. Build the upstream handshake request (`502` if that fails).
/// 3. Validate the client's websocket handshake *before* dialing the upstream, so malformed
///    upgrade requests never open (and then abandon) an upstream connection.
/// 4. Connect to the upstream (`502` if unreachable).
/// 5. Copy the upstream's selected `Sec-WebSocket-Protocol` (if any) into the client's `101`
///    response. The client's offered subprotocols are forwarded upstream, so the client must
///    learn which one the upstream picked; some clients fail the handshake otherwise.
/// 6. Spawn a task that pumps frames between both sockets.
///
/// # Errors
///
/// Propagates the error from `actix_ws::handle` when the client request is not a valid
/// websocket handshake. Upstream problems are reported as `502` responses instead.
async fn proxy_realtime_websocket(
    req: HttpRequest,
    body: web::Payload,
) -> Result<HttpResponse, Error> {
    let Some(prepared) = req
        .extensions()
        .get::<PreparedRealtimeWebsocketProxy>()
        .cloned()
    else {
        return Ok(not_found(
            "Prepared realtime proxy missing",
            "This internal realtime websocket proxy route may only be used through the Athena service router.",
        ));
    };

    let hop = current_proxy_hop(&req);
    if hop >= 1 {
        return Ok(bad_gateway(
            "Service route loop detected",
            "Refusing to proxy a realtime websocket request that already passed through an Athena service route target.",
        ));
    }

    let upstream_request = match build_upstream_websocket_request(&req, &prepared, hop + 1) {
        Ok(request) => request,
        Err(err) => {
            warn!(
                target_url = %prepared.target_url,
                request_tail_path = %prepared.request_tail_path,
                error = %err,
                "Failed to build realtime websocket upstream request"
            );
            return Ok(bad_gateway("Invalid realtime websocket target", err));
        }
    };

    // Validate the client handshake first: it is purely local, and failing here avoids
    // opening an upstream connection that would immediately be dropped.
    let (mut response, session, message_stream) = actix_ws::handle(&req, body)?;

    let (upstream_socket, upstream_response) = match connect_async(upstream_request).await {
        Ok(connection) => connection,
        Err(err) => {
            warn!(
                target_url = %prepared.target_url,
                request_tail_path = %prepared.request_tail_path,
                route_key = ?prepared.route_key,
                client_name = ?prepared.client_name,
                error = %err,
                "Failed to connect realtime websocket upstream"
            );
            return Ok(bad_gateway(
                "Realtime websocket upstream unavailable",
                err.to_string(),
            ));
        }
    };

    // The upstream and actix use separate `http` header types, so convert via raw bytes.
    if let Some(protocol) = upstream_response
        .headers()
        .get("sec-websocket-protocol")
        .and_then(|value| actix_web::http::header::HeaderValue::from_bytes(value.as_bytes()).ok())
    {
        response
            .headers_mut()
            .insert(actix_web::http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
    }

    actix_web::rt::spawn(bridge_websocket_streams(
        session,
        message_stream,
        upstream_socket,
    ));

    Ok(response)
}

fn build_upstream_websocket_request(
    req: &HttpRequest,
    prepared: &PreparedRealtimeWebsocketProxy,
    next_hop: u8,
) -> Result<tokio_tungstenite::tungstenite::http::Request<()>, String> {
    let upstream_url = build_realtime_websocket_target_url(
        &prepared.target_url,
        &prepared.request_tail_path,
        req.query_string(),
    )?;
    let mut request = upstream_url
        .as_str()
        .into_client_request()
        .map_err(|err| err.to_string())?;

    for (name, value) in req.headers() {
        let name_str = name.as_str();
        if header_is_hop_by_hop(name_str)
            || name_str.eq_ignore_ascii_case(PROXY_HOP_HEADER)
            || name_str.eq_ignore_ascii_case(ROUTE_KEY_HEADER)
            || name_str.eq_ignore_ascii_case(SERVICE_KEY_HEADER)
            || name_str.eq_ignore_ascii_case("sec-websocket-key")
            || name_str.eq_ignore_ascii_case("sec-websocket-version")
            || name_str.eq_ignore_ascii_case("sec-websocket-extensions")
            || name_str.eq_ignore_ascii_case("upgrade")
        {
            continue;
        }

        let upstream_name =
            tokio_tungstenite::tungstenite::http::HeaderName::from_bytes(name_str.as_bytes())
                .map_err(|err| err.to_string())?;
        let upstream_value =
            tokio_tungstenite::tungstenite::http::HeaderValue::from_bytes(value.as_bytes())
                .map_err(|err| err.to_string())?;
        request.headers_mut().append(upstream_name, upstream_value);
    }

    request.headers_mut().insert(
        tokio_tungstenite::tungstenite::http::HeaderName::from_static(PROXY_HOP_HEADER),
        tokio_tungstenite::tungstenite::http::HeaderValue::from_str(&next_hop.to_string())
            .map_err(|err| err.to_string())?,
    );
    request.headers_mut().insert(
        tokio_tungstenite::tungstenite::http::HeaderName::from_static(SERVICE_KEY_HEADER),
        tokio_tungstenite::tungstenite::http::HeaderValue::from_static("realtime"),
    );

    if let Some(route_key) = prepared
        .route_key
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        request.headers_mut().insert(
            tokio_tungstenite::tungstenite::http::HeaderName::from_static(ROUTE_KEY_HEADER),
            tokio_tungstenite::tungstenite::http::HeaderValue::from_str(route_key)
                .map_err(|err| err.to_string())?,
        );
    }
    if let Some(client_name) = prepared
        .client_name
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        request.headers_mut().insert(
            tokio_tungstenite::tungstenite::http::HeaderName::from_static("x-athena-client"),
            tokio_tungstenite::tungstenite::http::HeaderValue::from_str(client_name)
                .map_err(|err| err.to_string())?,
        );
    }
    if let Some(host) = req
        .headers()
        .get(actix_web::http::header::HOST)
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        request.headers_mut().insert(
            tokio_tungstenite::tungstenite::http::HeaderName::from_static(PROXY_ORIGIN_HOST_HEADER),
            tokio_tungstenite::tungstenite::http::HeaderValue::from_str(host)
                .map_err(|err| err.to_string())?,
        );
    }

    Ok(request)
}

/// Upstream websocket connection as returned by `connect_async`: a TCP stream that is
/// optionally TLS-wrapped (`MaybeTlsStream`).
type UpstreamSocket = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

async fn bridge_websocket_streams(
    mut session: actix_ws::Session,
    mut message_stream: actix_ws::MessageStream,
    mut upstream_socket: UpstreamSocket,
) {
    loop {
        tokio::select! {
            maybe_client_message = message_stream.next() => {
                match maybe_client_message {
                    Some(Ok(actix_ws::Message::Text(text))) => {
                        if upstream_socket.send(UpstreamWebsocketMessage::Text(text.to_string())).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(actix_ws::Message::Binary(bytes))) => {
                        if upstream_socket.send(UpstreamWebsocketMessage::Binary(bytes.to_vec())).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(actix_ws::Message::Ping(bytes))) => {
                        if upstream_socket.send(UpstreamWebsocketMessage::Ping(bytes.to_vec())).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(actix_ws::Message::Pong(bytes))) => {
                        if upstream_socket.send(UpstreamWebsocketMessage::Pong(bytes.to_vec())).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(actix_ws::Message::Close(_))) => {
                        let _ = upstream_socket.send(UpstreamWebsocketMessage::Close(None)).await;
                        break;
                    }
                    Some(Ok(actix_ws::Message::Continuation(_))) => {
                        warn!("Unsupported websocket continuation frame on realtime proxy; closing bridge");
                        let _ = upstream_socket.send(UpstreamWebsocketMessage::Close(None)).await;
                        break;
                    }
                    Some(Ok(actix_ws::Message::Nop)) => {}
                    Some(Err(err)) => {
                        warn!(error = %err, "Realtime websocket client stream failed");
                        break;
                    }
                    None => {
                        let _ = upstream_socket.send(UpstreamWebsocketMessage::Close(None)).await;
                        break;
                    }
                }
            }
            maybe_upstream_message = upstream_socket.next() => {
                match maybe_upstream_message {
                    Some(Ok(UpstreamWebsocketMessage::Text(text))) => {
                        if session.text(text).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(UpstreamWebsocketMessage::Binary(bytes))) => {
                        if session.binary(bytes).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(UpstreamWebsocketMessage::Ping(bytes))) => {
                        if session.ping(&bytes).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(UpstreamWebsocketMessage::Pong(bytes))) => {
                        if session.pong(&bytes).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(UpstreamWebsocketMessage::Close(_))) => {
                        let _ = session.clone().close(None).await;
                        return;
                    }
                    Some(Ok(UpstreamWebsocketMessage::Frame(_))) => {}
                    Some(Err(err)) => {
                        warn!(error = %err, "Realtime websocket upstream stream failed");
                        break;
                    }
                    None => {
                        break;
                    }
                }
            }
        }
    }

    let _ = session.close(None).await;
}

/// Registers the internal realtime websocket proxy resource.
///
/// `GET` requests to `INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH` are handled by
/// `proxy_realtime_websocket`.
pub(super) fn services(cfg: &mut web::ServiceConfig) {
    let websocket_route = web::get().to(proxy_realtime_websocket);
    let proxy_resource =
        web::resource(INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH).route(websocket_route);
    cfg.service(proxy_resource);
}