use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::State;
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::axum_edge::EdgeState;
use super::http_handler_metrics::HttpTransport;
use crate::wire::redwire::listener::handle_session_consume_magic;
/// HTTP path of the redwire WebSocket endpoint.
/// NOTE(review): presumably the route the edge router mounts `redwire_ws_upgrade` on — confirm.
pub(super) const REDWIRE_WS_PATH: &str = "/redwire";
/// WebSocket subprotocol name passed to `WebSocketUpgrade::protocols` in `redwire_ws_upgrade`.
pub(super) const REDWIRE_WS_SUBPROTOCOL: &str = "reddb.redwire.v1";
/// Buffer size, in bytes, passed to `tokio::io::duplex` for the in-memory pipe that connects
/// the WebSocket bridge to the redwire session.
const WS_BRIDGE_BUF: usize = 64 * 1024;
/// Size, in bytes, of the buffer used to read session output; each outbound binary
/// WebSocket message carries at most this many bytes.
const WS_READ_CHUNK: usize = 16 * 1024;
/// Reasons a redwire WebSocket upgrade request is refused before the handshake.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum WsRejection {
    /// The request did not arrive over the HTTPS edge.
    NotTls,
    /// No `Origin` value was available to check against the allowlist.
    OriginMissing,
    /// The `Origin` value did not exactly match any allowlisted origin.
    OriginRejected,
}

impl WsRejection {
    /// Returns the HTTP status and plain-text body sent back for this rejection.
    ///
    /// Every variant answers `403 Forbidden`; only the explanatory message differs.
    fn status_and_msg(&self) -> (StatusCode, &'static str) {
        let body = match self {
            Self::NotTls => "redwire websocket requires TLS (wss://)",
            Self::OriginMissing => "redwire websocket upgrade requires an Origin header",
            Self::OriginRejected => "origin not allowed for redwire websocket",
        };
        (StatusCode::FORBIDDEN, body)
    }
}

impl IntoResponse for WsRejection {
    /// Renders the rejection as a `(status, body)` axum response.
    fn into_response(self) -> Response {
        self.status_and_msg().into_response()
    }
}
/// Decides whether a redwire WebSocket upgrade may proceed.
///
/// The transport is checked first: anything other than `HttpTransport::Https` is refused
/// with `NotTls` regardless of the origin. Otherwise an `Origin` must be supplied and must
/// equal (exactly, byte-for-byte) one of the `allowlist` entries.
///
/// # Errors
/// `NotTls`, `OriginMissing` or `OriginRejected`, in that order of precedence.
pub(super) fn ws_upgrade_decision(
    transport: HttpTransport,
    origin: Option<&str>,
    allowlist: &[String],
) -> Result<(), WsRejection> {
    if transport != HttpTransport::Https {
        return Err(WsRejection::NotTls);
    }
    let requested = origin.ok_or(WsRejection::OriginMissing)?;
    // Exact string comparison on purpose: no prefix/suffix or wildcard matching.
    let permitted = allowlist.iter().any(|entry| entry.as_str() == requested);
    if permitted {
        Ok(())
    } else {
        Err(WsRejection::OriginRejected)
    }
}
pub(super) async fn redwire_ws_upgrade(
State(state): State<EdgeState>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> Response {
let origin = headers
.get(header::ORIGIN)
.and_then(|value| value.to_str().ok());
if let Err(rejection) =
ws_upgrade_decision(state.transport, origin, state.server.websocket_allowed_origins())
{
return rejection.into_response();
}
let server = state.server.clone();
ws.protocols([REDWIRE_WS_SUBPROTOCOL])
.on_upgrade(move |socket| async move {
run_ws_session(socket, server).await;
})
}
/// What the bridge should do with one item pulled from the WebSocket stream.
enum WsInbound {
    /// Binary payload to forward verbatim into the redwire session.
    Data(Bytes),
    /// A message that is not forwarded (text, ping, pong, ...).
    Ignore,
    /// The client closed, the stream ended, or the transport errored.
    Eof,
}

/// Maps the result of `WebSocket::recv` to a bridge action.
///
/// Only binary messages carry redwire bytes; a Close frame, a receive error, or the end of
/// the stream all terminate the bridge; everything else is dropped.
fn classify_inbound(inbound: Option<Result<Message, axum::Error>>) -> WsInbound {
    match inbound {
        None | Some(Err(_)) => WsInbound::Eof,
        Some(Ok(message)) => match message {
            Message::Binary(payload) => WsInbound::Data(payload),
            Message::Close(_) => WsInbound::Eof,
            _ => WsInbound::Ignore,
        },
    }
}
async fn run_ws_session(mut socket: WebSocket, server: super::RedDBServer) {
let runtime = Arc::new(server.runtime().clone());
let auth_store = runtime.auth_store();
let oauth = runtime.oauth_validator();
let (session_io, net_io) = tokio::io::duplex(WS_BRIDGE_BUF);
let session = tokio::spawn(async move {
let _ = handle_session_consume_magic(session_io, runtime, auth_store, oauth).await;
});
let (mut net_read, mut net_write) = tokio::io::split(net_io);
let mut out_buf = vec![0u8; WS_READ_CHUNK];
loop {
tokio::select! {
inbound = socket.recv() => {
match classify_inbound(inbound) {
WsInbound::Data(bytes) => {
if net_write.write_all(&bytes).await.is_err() {
break;
}
}
WsInbound::Ignore => {}
WsInbound::Eof => break,
}
}
outbound = net_read.read(&mut out_buf) => {
match outbound {
Ok(0) => break,
Ok(n) => {
let msg = Message::Binary(Bytes::copy_from_slice(&out_buf[..n]));
if socket.send(msg).await.is_err() {
break;
}
}
Err(_) => break,
}
}
}
}
drop(net_write);
drop(net_read);
let _ = socket.send(Message::Close(None)).await;
session.abort();
}
#[cfg(test)]
mod tests {
    use super::*;

    fn allowlist() -> Vec<String> {
        vec![
            "https://app.example.com".to_string(),
            "https://admin.example.com".to_string(),
        ]
    }

    #[test]
    fn allowed_origin_over_tls_is_accepted() {
        let result = ws_upgrade_decision(
            HttpTransport::Https,
            Some("https://app.example.com"),
            &allowlist(),
        );
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn non_allowlisted_origin_is_rejected() {
        let result = ws_upgrade_decision(
            HttpTransport::Https,
            Some("https://evil.example.com"),
            &allowlist(),
        );
        assert_eq!(result, Err(WsRejection::OriginRejected));
    }

    #[test]
    fn missing_origin_is_rejected() {
        let result = ws_upgrade_decision(HttpTransport::Https, None, &allowlist());
        assert_eq!(result, Err(WsRejection::OriginMissing));
    }

    #[test]
    fn plain_ws_over_http_edge_is_rejected_even_when_origin_allowed() {
        let result = ws_upgrade_decision(
            HttpTransport::Http,
            Some("https://app.example.com"),
            &allowlist(),
        );
        assert_eq!(result, Err(WsRejection::NotTls));
    }

    #[test]
    fn tls_check_runs_before_origin_check() {
        // With no Origin at all, a plain-HTTP edge still reports NotTls first.
        let result = ws_upgrade_decision(HttpTransport::Http, None, &allowlist());
        assert_eq!(result, Err(WsRejection::NotTls));
    }

    #[test]
    fn empty_allowlist_rejects_every_origin() {
        let result =
            ws_upgrade_decision(HttpTransport::Https, Some("https://app.example.com"), &[]);
        assert_eq!(result, Err(WsRejection::OriginRejected));
    }

    #[test]
    fn origin_match_is_exact_not_prefix() {
        let result = ws_upgrade_decision(
            HttpTransport::Https,
            Some("https://app.example.com.evil.com"),
            &allowlist(),
        );
        assert_eq!(result, Err(WsRejection::OriginRejected));
    }

    fn assert_data(got: WsInbound, expected: &[u8]) {
        match got {
            WsInbound::Data(b) => assert_eq!(&b[..], expected),
            WsInbound::Ignore => panic!("expected Data, got Ignore"),
            WsInbound::Eof => panic!("expected Data, got Eof"),
        }
    }

    #[test]
    fn binary_message_passes_through_byte_for_byte() {
        let frame = vec![0xFE, 0x01, 0x10, 0x00, 0x00, 0x00];
        let got = classify_inbound(Some(Ok(Message::Binary(Bytes::from(frame.clone())))));
        assert_data(got, &frame);
    }

    #[test]
    fn binary_messages_reassemble_in_order() {
        let first = classify_inbound(Some(Ok(Message::Binary(Bytes::from_static(&[0xFE, 0x01])))));
        let second =
            classify_inbound(Some(Ok(Message::Binary(Bytes::from_static(&[0x10, 0x00])))));
        assert_data(first, &[0xFE, 0x01]);
        assert_data(second, &[0x10, 0x00]);
    }

    #[test]
    fn close_and_stream_end_map_to_eof() {
        assert!(matches!(
            classify_inbound(Some(Ok(Message::Close(None)))),
            WsInbound::Eof
        ));
        assert!(matches!(classify_inbound(None), WsInbound::Eof));
    }

    #[test]
    fn transport_errors_map_to_eof() {
        let err = axum::Error::new(std::io::Error::other("boom"));
        assert!(matches!(classify_inbound(Some(Err(err))), WsInbound::Eof));
    }

    #[test]
    fn control_frames_are_ignored_not_forwarded() {
        assert!(matches!(
            classify_inbound(Some(Ok(Message::Ping(Bytes::new())))),
            WsInbound::Ignore
        ));
        assert!(matches!(
            classify_inbound(Some(Ok(Message::Pong(Bytes::new())))),
            WsInbound::Ignore
        ));
    }

    #[test]
    fn text_messages_are_ignored_not_forwarded() {
        // redwire is a binary protocol; text frames must never reach the session.
        assert!(matches!(
            classify_inbound(Some(Ok(Message::Text("hello".into())))),
            WsInbound::Ignore
        ));
    }
}