use actix_web::dev::ServiceRequest;
use actix_web::{Error, HttpMessage, HttpRequest, HttpResponse, web};
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::Message as UpstreamWebsocketMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use tracing::warn;
use crate::api::response::{bad_gateway, not_found};
use super::proxy::{
PROXY_HOP_HEADER, PROXY_ORIGIN_HOST_HEADER, ROUTE_KEY_HEADER, SERVICE_KEY_HEADER,
build_realtime_websocket_target_url, current_proxy_hop, header_is_hop_by_hop,
};
/// Route path of the internal endpoint (registered by `services`) whose handler serves realtime
/// websocket requests previously prepared with `prepare_realtime_websocket_proxy`.
pub(super) const INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH: &str =
"/__athena_internal/service-router/realtime-websocket-proxy";
/// Upstream routing details stored in the request extensions by
/// `prepare_realtime_websocket_proxy` and read back (cloned) by the internal proxy handler.
#[derive(Debug, Clone)]
struct PreparedRealtimeWebsocketProxy {
/// Upstream target URL; combined with `request_tail_path` and the client's query string by
/// `build_realtime_websocket_target_url`.
target_url: String,
/// Remaining request path passed to `build_realtime_websocket_target_url`.
request_tail_path: String,
/// Route key sent upstream in the route-key header when non-blank after trimming.
route_key: Option<String>,
/// Client name sent upstream as `x-athena-client` when non-blank after trimming.
client_name: Option<String>,
}
/// Records where a realtime websocket request should be proxied by storing a
/// [`PreparedRealtimeWebsocketProxy`] in the request extensions.
///
/// The internal proxy route reads this value back; requests without it are refused there.
/// All borrowed arguments are copied into owned strings.
///
/// # Errors
///
/// Never fails; always returns `Ok(())`.
pub(super) fn prepare_realtime_websocket_proxy(
    req: &mut ServiceRequest,
    target_url: &str,
    request_tail_path: &str,
    route_key: Option<&str>,
    client_name: Option<&str>,
) -> Result<(), Error> {
    let prepared = PreparedRealtimeWebsocketProxy {
        target_url: String::from(target_url),
        request_tail_path: String::from(request_tail_path),
        route_key: route_key.map(ToOwned::to_owned),
        client_name: client_name.map(ToOwned::to_owned),
    };
    req.extensions_mut().insert(prepared);
    Ok(())
}
/// Internal endpoint that bridges a client websocket to the prepared realtime upstream.
///
/// Requires the [`PreparedRealtimeWebsocketProxy`] stored by
/// [`prepare_realtime_websocket_proxy`]; without it the request is answered via `not_found`.
/// Requests whose proxy hop count is already one or more are refused via `bad_gateway` to
/// break service-route loops.
///
/// The client's upgrade request is validated before the upstream is dialled. Once the upstream
/// handshake succeeds, the subprotocol the upstream selected (if any) is echoed to the client.
/// Traffic is then relayed by [`bridge_websocket_streams`] on a spawned task.
///
/// # Errors
///
/// Returns the error produced by `actix_ws::handle` when the client request is not a valid
/// websocket upgrade. Upstream failures are reported as `bad_gateway` responses instead.
async fn proxy_realtime_websocket(
    req: HttpRequest,
    body: web::Payload,
) -> Result<HttpResponse, Error> {
    let Some(prepared) = req
        .extensions()
        .get::<PreparedRealtimeWebsocketProxy>()
        .cloned()
    else {
        return Ok(not_found(
            "Prepared realtime proxy missing",
            "This internal realtime websocket proxy route may only be used through the Athena service router.",
        ));
    };

    let hop = current_proxy_hop(&req);
    if hop >= 1 {
        return Ok(bad_gateway(
            "Service route loop detected",
            "Refusing to proxy a realtime websocket request that already passed through an Athena service route target.",
        ));
    }

    // Validate the client handshake first, so a malformed upgrade fails fast without opening
    // (and then abandoning) an upstream connection. Nothing is sent to the client until
    // `response` is returned, so an upstream failure below can still be answered with
    // `bad_gateway` instead of the upgrade.
    let (mut response, session, message_stream) = actix_ws::handle(&req, body)?;

    let upstream_request = match build_upstream_websocket_request(&req, &prepared, hop + 1) {
        Ok(request) => request,
        Err(err) => return Ok(bad_gateway("Invalid realtime websocket target", err)),
    };
    let (upstream_socket, upstream_response) = match connect_async(upstream_request).await {
        Ok(connection) => connection,
        Err(err) => {
            warn!(
                target_url = %prepared.target_url,
                request_tail_path = %prepared.request_tail_path,
                route_key = ?prepared.route_key,
                client_name = ?prepared.client_name,
                error = %err,
                "Failed to connect realtime websocket upstream"
            );
            return Ok(bad_gateway(
                "Realtime websocket upstream unavailable",
                err.to_string(),
            ));
        }
    };

    // The client headers (including any offered subprotocols) were forwarded upstream, so the
    // upstream chose the subprotocol. Echo that choice: browsers fail the connection when they
    // offered subprotocols and the 101 response selects none, and the client must know which
    // protocol the upstream is speaking.
    if let Some(protocol) = upstream_response
        .headers()
        .get("sec-websocket-protocol")
        .and_then(|value| actix_web::http::header::HeaderValue::from_bytes(value.as_bytes()).ok())
    {
        response
            .headers_mut()
            .insert(actix_web::http::header::SEC_WEBSOCKET_PROTOCOL, protocol);
    }

    actix_web::rt::spawn(async move {
        bridge_websocket_streams(session, message_stream, upstream_socket).await;
    });
    Ok(response)
}
/// Builds the handshake request for the upstream realtime websocket target.
///
/// The URL is produced by `build_realtime_websocket_target_url` from the prepared target URL,
/// the prepared tail path and the client's query string. Client headers are copied across,
/// except for:
/// - hop-by-hop headers;
/// - the websocket handshake headers;
/// - `Host`;
/// - every header this function sets itself, so a client cannot inject its own values.
///
/// The function then sets the hop counter to `next_hop` and the service key to `realtime`.
/// Three more headers are set only when the value is present and non-blank after trimming:
/// - the route key;
/// - the client name (as `x-athena-client`);
/// - the client's original `Host` (under the origin-host header).
///
/// # Errors
///
/// Returns a display string when the target URL cannot be built, or when a header name or
/// value cannot be represented in the upstream request.
fn build_upstream_websocket_request(
    req: &HttpRequest,
    prepared: &PreparedRealtimeWebsocketProxy,
    next_hop: u8,
) -> Result<tokio_tungstenite::tungstenite::http::Request<()>, String> {
    use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue};

    /// Header carrying the Athena client name to the upstream.
    const CLIENT_NAME_HEADER: &str = "x-athena-client";

    /// Trims `value` and discards it when nothing is left.
    fn non_blank(value: Option<&str>) -> Option<&str> {
        value.map(str::trim).filter(|value| !value.is_empty())
    }

    let upstream_url = build_realtime_websocket_target_url(
        &prepared.target_url,
        &prepared.request_tail_path,
        req.query_string(),
    )?;
    let mut request = upstream_url
        .as_str()
        .into_client_request()
        .map_err(|err| err.to_string())?;

    // Client headers that are never copied upstream:
    // - Athena's internal headers. This function sets them authoritatively below, so a
    //   client-supplied copy must not survive when the proxy has no value of its own.
    // - `host`. The upstream handshake targets the upstream's own host; the client's host is
    //   forwarded separately under the origin-host header.
    // - The handshake headers owned by the upstream connection's own handshake, plus
    //   `sec-websocket-extensions` (presumably because the bridge negotiates no extensions).
    let excluded_headers = [
        PROXY_HOP_HEADER,
        ROUTE_KEY_HEADER,
        SERVICE_KEY_HEADER,
        PROXY_ORIGIN_HOST_HEADER,
        CLIENT_NAME_HEADER,
        "host",
        "sec-websocket-key",
        "sec-websocket-version",
        "sec-websocket-extensions",
        "upgrade",
    ];
    for (name, value) in req.headers() {
        let name_str = name.as_str();
        if header_is_hop_by_hop(name_str)
            || excluded_headers
                .iter()
                .any(|excluded| name_str.eq_ignore_ascii_case(excluded))
        {
            continue;
        }
        // actix and tungstenite use distinct `http` header types, so convert via raw bytes.
        let upstream_name =
            HeaderName::from_bytes(name_str.as_bytes()).map_err(|err| err.to_string())?;
        let upstream_value =
            HeaderValue::from_bytes(value.as_bytes()).map_err(|err| err.to_string())?;
        request.headers_mut().append(upstream_name, upstream_value);
    }

    let headers = request.headers_mut();
    headers.insert(
        HeaderName::from_static(PROXY_HOP_HEADER),
        HeaderValue::from(u16::from(next_hop)),
    );
    headers.insert(
        HeaderName::from_static(SERVICE_KEY_HEADER),
        HeaderValue::from_static("realtime"),
    );
    if let Some(route_key) = non_blank(prepared.route_key.as_deref()) {
        headers.insert(
            HeaderName::from_static(ROUTE_KEY_HEADER),
            HeaderValue::from_str(route_key).map_err(|err| err.to_string())?,
        );
    }
    if let Some(client_name) = non_blank(prepared.client_name.as_deref()) {
        headers.insert(
            HeaderName::from_static(CLIENT_NAME_HEADER),
            HeaderValue::from_str(client_name).map_err(|err| err.to_string())?,
        );
    }
    let client_host = req
        .headers()
        .get(actix_web::http::header::HOST)
        .and_then(|value| value.to_str().ok());
    if let Some(host) = non_blank(client_host) {
        headers.insert(
            HeaderName::from_static(PROXY_ORIGIN_HOST_HEADER),
            HeaderValue::from_str(host).map_err(|err| err.to_string())?,
        );
    }
    Ok(request)
}
/// Upstream websocket connection as returned by `connect_async`: a websocket over a TCP
/// stream that may or may not be TLS-wrapped.
type UpstreamSocket = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
async fn bridge_websocket_streams(
mut session: actix_ws::Session,
mut message_stream: actix_ws::MessageStream,
mut upstream_socket: UpstreamSocket,
) {
loop {
tokio::select! {
maybe_client_message = message_stream.next() => {
match maybe_client_message {
Some(Ok(actix_ws::Message::Text(text))) => {
if upstream_socket.send(UpstreamWebsocketMessage::Text(text.to_string())).await.is_err() {
break;
}
}
Some(Ok(actix_ws::Message::Binary(bytes))) => {
if upstream_socket.send(UpstreamWebsocketMessage::Binary(bytes.to_vec())).await.is_err() {
break;
}
}
Some(Ok(actix_ws::Message::Ping(bytes))) => {
if upstream_socket.send(UpstreamWebsocketMessage::Ping(bytes.to_vec())).await.is_err() {
break;
}
}
Some(Ok(actix_ws::Message::Pong(bytes))) => {
if upstream_socket.send(UpstreamWebsocketMessage::Pong(bytes.to_vec())).await.is_err() {
break;
}
}
Some(Ok(actix_ws::Message::Close(_))) => {
let _ = upstream_socket.send(UpstreamWebsocketMessage::Close(None)).await;
break;
}
Some(Ok(actix_ws::Message::Continuation(_))) => {
warn!("Unsupported websocket continuation frame on realtime proxy; closing bridge");
let _ = upstream_socket.send(UpstreamWebsocketMessage::Close(None)).await;
break;
}
Some(Ok(actix_ws::Message::Nop)) => {}
Some(Err(err)) => {
warn!(error = %err, "Realtime websocket client stream failed");
break;
}
None => {
let _ = upstream_socket.send(UpstreamWebsocketMessage::Close(None)).await;
break;
}
}
}
maybe_upstream_message = upstream_socket.next() => {
match maybe_upstream_message {
Some(Ok(UpstreamWebsocketMessage::Text(text))) => {
if session.text(text).await.is_err() {
break;
}
}
Some(Ok(UpstreamWebsocketMessage::Binary(bytes))) => {
if session.binary(bytes).await.is_err() {
break;
}
}
Some(Ok(UpstreamWebsocketMessage::Ping(bytes))) => {
if session.ping(&bytes).await.is_err() {
break;
}
}
Some(Ok(UpstreamWebsocketMessage::Pong(bytes))) => {
if session.pong(&bytes).await.is_err() {
break;
}
}
Some(Ok(UpstreamWebsocketMessage::Close(_))) => {
let _ = session.clone().close(None).await;
return;
}
Some(Ok(UpstreamWebsocketMessage::Frame(_))) => {}
Some(Err(err)) => {
warn!(error = %err, "Realtime websocket upstream stream failed");
break;
}
None => {
break;
}
}
}
}
}
let _ = session.close(None).await;
}
/// Registers the internal realtime websocket proxy endpoint: `GET` requests to
/// [`INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH`] are handled by `proxy_realtime_websocket`.
pub(super) fn services(cfg: &mut web::ServiceConfig) {
    let proxy_route = web::get().to(proxy_realtime_websocket);
    let proxy_resource = web::resource(INTERNAL_REALTIME_WEBSOCKET_PROXY_PATH).route(proxy_route);
    cfg.service(proxy_resource);
}