use std::sync::Arc;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::extract::State;
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::axum_edge::EdgeState;
use super::http_handler_metrics::HttpTransport;
use crate::wire::redwire::listener::handle_session_consume_magic;
use reddb_wire::redwire::{
build_auth_response_bearer_payload, build_auth_response_frame, build_client_hello_frame,
decode_frame, encode_frame, read_frame_async, write_frame_async, Frame, MessageKind,
FRAME_HEADER_SIZE, REDWIRE_MAGIC,
};
pub(super) use reddb_wire::redwire::{REDWIRE_WS_PATH, REDWIRE_WS_SUBPROTOCOL};
/// Buffer size, in bytes, passed to `tokio::io::duplex` for the in-memory pipe
/// that connects a WebSocket to its redwire session (64 KiB).
const WS_BRIDGE_BUF: usize = 64 << 10;
/// Largest chunk, in bytes, read from the session side in one go before it is
/// forwarded to the client as a single binary WebSocket message (16 KiB).
const WS_READ_CHUNK: usize = 16 << 10;
/// Reason a redwire WebSocket upgrade request was refused at the HTTP edge.
///
/// Each variant renders as an HTTP error response via `IntoResponse`.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum WsRejection {
    /// The request did not arrive over HTTPS, so `wss://` was not used.
    NotTls,
    /// The upgrade request carried no usable `Origin` header.
    OriginMissing,
    /// The `Origin` header is not in the configured allowlist.
    OriginRejected,
}
impl WsRejection {
    /// Maps the rejection to its HTTP status and a short plain-text reason.
    ///
    /// Every variant is reported as `403 Forbidden`; only the message differs.
    fn status_and_msg(&self) -> (StatusCode, &'static str) {
        let reason = match self {
            Self::NotTls => "redwire websocket requires TLS (wss://)",
            Self::OriginMissing => "redwire websocket upgrade requires an Origin header",
            Self::OriginRejected => "origin not allowed for redwire websocket",
        };
        (StatusCode::FORBIDDEN, reason)
    }
}
impl IntoResponse for WsRejection {
    /// Renders the rejection as its status code with the reason as the body.
    fn into_response(self) -> Response {
        self.status_and_msg().into_response()
    }
}
impl From<reddb_wire::redwire::WsUpgradeRefusal> for WsRejection {
    /// Translates the wire crate's refusal reason into this edge's rejection.
    /// The variants correspond one-to-one.
    fn from(refusal: reddb_wire::redwire::WsUpgradeRefusal) -> Self {
        match refusal {
            reddb_wire::redwire::WsUpgradeRefusal::NotTls => Self::NotTls,
            reddb_wire::redwire::WsUpgradeRefusal::OriginMissing => Self::OriginMissing,
            reddb_wire::redwire::WsUpgradeRefusal::OriginRejected => Self::OriginRejected,
        }
    }
}
/// Decides whether a redwire WebSocket upgrade may proceed.
///
/// Delegates to `reddb_wire::redwire::evaluate_ws_upgrade`, telling it whether
/// the request arrived over HTTPS. Any refusal is converted into a
/// [`WsRejection`]. `Ok(())` means the upgrade is allowed.
pub(super) fn ws_upgrade_decision(
    transport: HttpTransport,
    origin: Option<&str>,
    allowlist: &[String],
) -> Result<(), WsRejection> {
    let over_tls = transport == HttpTransport::Https;
    reddb_wire::redwire::evaluate_ws_upgrade(over_tls, origin, allowlist).map_err(Into::into)
}
/// Axum handler for the redwire WebSocket endpoint.
///
/// The request's `Origin` header (ignored if not valid visible ASCII) and the
/// edge transport are checked by [`ws_upgrade_decision`]. On refusal, the
/// matching [`WsRejection`] response is returned. Otherwise the upgrade
/// registers `REDWIRE_WS_SUBPROTOCOL` via `WebSocketUpgrade::protocols`, and
/// the upgraded socket is served by `run_ws_session`.
pub(super) async fn redwire_ws_upgrade(
    State(state): State<EdgeState>,
    headers: HeaderMap,
    ws: WebSocketUpgrade,
) -> Response {
    let origin = headers.get(header::ORIGIN).and_then(|raw| raw.to_str().ok());
    let allowed_origins = state.server.websocket_allowed_origins();
    match ws_upgrade_decision(state.transport, origin, allowed_origins) {
        Err(rejection) => rejection.into_response(),
        Ok(()) => {
            let server = state.server.clone();
            ws.protocols([REDWIRE_WS_SUBPROTOCOL])
                .on_upgrade(move |socket| async move { run_ws_session(socket, server).await })
        }
    }
}
/// What the bridge should do with one item received from the WebSocket.
enum WsInbound {
    /// Binary payload to forward to the session unchanged.
    Data(Bytes),
    /// Any other successfully received message (e.g. text, ping, pong): skip it.
    Ignore,
    /// Close frame, receive error, or end of stream: stop bridging.
    Eof,
}
/// Classifies the result of `WebSocket::recv` for the bridging loops.
///
/// Only binary messages carry redwire bytes. A close frame, a receive error,
/// or the end of the stream all end the session. Everything else is ignored.
fn classify_inbound(inbound: Option<Result<Message, axum::Error>>) -> WsInbound {
    let Some(Ok(message)) = inbound else {
        return WsInbound::Eof;
    };
    match message {
        Message::Binary(payload) => WsInbound::Data(payload),
        Message::Close(_) => WsInbound::Eof,
        _ => WsInbound::Ignore,
    }
}
/// Serves one redwire session over an upgraded WebSocket.
///
/// The session (`handle_session_consume_magic`, given the server runtime, its
/// auth store and its OAuth validator) runs on a spawned task. It is connected
/// to the WebSocket through a `tokio::io::duplex` pipe that [`pump_ws_stream`]
/// bridges. Once the pump returns, the session task is aborted.
pub(crate) async fn run_ws_session(socket: WebSocket, server: super::RedDBServer) {
    let runtime = Arc::new(server.runtime().clone());
    let auth_store = runtime.auth_store();
    let oauth = runtime.oauth_validator();
    let (session_side, bridge_side) = tokio::io::duplex(WS_BRIDGE_BUF);
    let session_task = tokio::spawn(async move {
        // The session outcome is deliberately discarded here.
        let _ = handle_session_consume_magic(session_side, runtime, auth_store, oauth).await;
    });
    pump_ws_stream(socket, bridge_side).await;
    session_task.abort();
}
/// Serves a redwire session over a WebSocket whose client is authenticated
/// with the supplied bearer `token` rather than its own credentials.
///
/// The session runs on a spawned task at one end of a `tokio::io::duplex`
/// pipe. The WebSocket side first goes through [`inject_bearer_handshake`],
/// which then bridges the rest of the connection. The session task is aborted
/// once that returns.
pub(crate) async fn run_injected_ws_session(
    socket: WebSocket,
    server: super::RedDBServer,
    token: &str,
) {
    let runtime = Arc::new(server.runtime().clone());
    let (auth_store, oauth) = (runtime.auth_store(), runtime.oauth_validator());
    let (session_side, bridge_side) = tokio::io::duplex(WS_BRIDGE_BUF);
    let session_task = tokio::spawn(async move {
        // The session outcome is deliberately discarded here.
        let _ = handle_session_consume_magic(session_side, runtime, auth_store, oauth).await;
    });
    inject_bearer_handshake(socket, bridge_side, token).await;
    session_task.abort();
}
/// Runs the redwire handshake on behalf of a WebSocket client. The client's
/// authentication is swapped for a bearer `token`, and the connection is then
/// bridged to `backend`.
///
/// Steps:
/// 1. Read the 2-byte preamble (`REDWIRE_MAGIC`, minor version) and the
///    client's `Hello` frame from the WebSocket.
/// 2. Send the preamble to `backend`, followed by a bridge-built hello that
///    reuses the client's correlation id and lists `bearer` and `anonymous`.
/// 3. Relay one frame from `backend` back to the client.
/// 4. Read one frame from the client and replace it with a bearer auth
///    response carrying `token` (same correlation id).
/// 5. Relay one reply frame from `backend` to the client. Forward any bytes
///    the client already sent beyond the handshake, then hand off to
///    [`pump_ws_stream`].
///
/// Protocol or backend failures close the WebSocket (best effort). A failure
/// to send to the WebSocket just returns.
pub(crate) async fn inject_bearer_handshake<S>(mut socket: WebSocket, mut backend: S, token: &str)
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send,
{
    let mut reader = WsFrameReader::new();

    // Preamble: magic byte followed by the protocol minor version.
    let Some(preamble) = reader.read_exact_bytes(&mut socket, 2).await else {
        return close_ws(socket).await;
    };
    let (magic, minor) = (preamble[0], preamble[1]);
    if magic != REDWIRE_MAGIC {
        return close_ws(socket).await;
    }
    let Some(ui_hello) = reader
        .read_frame(&mut socket)
        .await
        .filter(|frame| frame.kind == MessageKind::Hello)
    else {
        return close_ws(socket).await;
    };

    // Hello: the bridge speaks for the client towards the backend.
    if backend.write_all(&[REDWIRE_MAGIC, minor]).await.is_err() {
        return close_ws(socket).await;
    }
    let Ok(our_hello) = build_client_hello_frame(
        ui_hello.correlation_id,
        ["bearer", "anonymous"],
        0,
        Some("red-ui-bridge"),
    ) else {
        return close_ws(socket).await;
    };
    if write_frame_async(&mut backend, &our_hello).await.is_err() {
        return close_ws(socket).await;
    }
    let Ok(ack) = read_frame_async(&mut backend).await else {
        return close_ws(socket).await;
    };
    if send_frame(&mut socket, &ack).await.is_err() {
        return;
    }

    // Auth: whatever the client sends is replaced by a bearer response.
    let Some(ui_auth) = reader.read_frame(&mut socket).await else {
        return close_ws(socket).await;
    };
    let Ok(bearer) = build_auth_response_frame(
        ui_auth.correlation_id,
        build_auth_response_bearer_payload(token),
    ) else {
        return close_ws(socket).await;
    };
    if write_frame_async(&mut backend, &bearer).await.is_err() {
        return close_ws(socket).await;
    }
    let Ok(auth_reply) = read_frame_async(&mut backend).await else {
        return close_ws(socket).await;
    };
    if send_frame(&mut socket, &auth_reply).await.is_err() {
        return;
    }

    // Bytes the client sent past its auth frame belong to the session proper.
    let has_pending = !reader.buffered().is_empty();
    if has_pending && backend.write_all(&reader.take_buffered()).await.is_err() {
        return close_ws(socket).await;
    }
    pump_ws_stream(socket, backend).await;
}
/// Encodes `frame` and sends it to the client as one binary WebSocket message.
async fn send_frame(socket: &mut WebSocket, frame: &Frame) -> Result<(), axum::Error> {
    let encoded = encode_frame(frame);
    socket.send(Message::Binary(encoded.into())).await
}
/// Sends a Close frame (no status/reason) and consumes the socket.
async fn close_ws(mut socket: WebSocket) {
    let close = Message::Close(None);
    // Best effort: the peer may already be gone, so a send failure is not actionable.
    let _ = socket.send(close).await;
}
/// Incremental reader that reassembles redwire bytes from WebSocket messages.
///
/// Binary messages are appended to `buf`. Callers then take exact byte counts
/// or whole frames off the front. Leftover bytes can be taken with
/// `take_buffered` once the handshake is done.
struct WsFrameReader {
    // Bytes received from the WebSocket but not yet consumed.
    buf: Vec<u8>,
}
impl WsFrameReader {
fn new() -> Self {
Self { buf: Vec::new() }
}
fn buffered(&self) -> &[u8] {
&self.buf
}
fn take_buffered(&mut self) -> Vec<u8> {
std::mem::take(&mut self.buf)
}
async fn read_exact_bytes(&mut self, socket: &mut WebSocket, n: usize) -> Option<Vec<u8>> {
while self.buf.len() < n {
if !self.fill(socket).await {
return None;
}
}
let tail = self.buf.split_off(n);
Some(std::mem::replace(&mut self.buf, tail))
}
async fn read_frame(&mut self, socket: &mut WebSocket) -> Option<Frame> {
loop {
if self.buf.len() >= FRAME_HEADER_SIZE {
if let Ok((frame, consumed)) = decode_frame(&self.buf) {
self.buf.drain(..consumed);
return Some(frame);
}
}
if !self.fill(socket).await {
return None;
}
}
}
async fn fill(&mut self, socket: &mut WebSocket) -> bool {
loop {
match socket.recv().await {
Some(Ok(Message::Binary(bytes))) => {
self.buf.extend_from_slice(&bytes);
return true;
}
Some(Ok(Message::Close(_))) | Some(Err(_)) | None => return false,
Some(Ok(_)) => continue,
}
}
}
}
pub(crate) async fn pump_ws_stream<S>(mut socket: WebSocket, stream: S)
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite,
{
let (mut net_read, mut net_write) = tokio::io::split(stream);
let mut out_buf = vec![0u8; WS_READ_CHUNK];
loop {
tokio::select! {
inbound = socket.recv() => {
match classify_inbound(inbound) {
WsInbound::Data(bytes) => {
if net_write.write_all(&bytes).await.is_err() {
break;
}
}
WsInbound::Ignore => {}
WsInbound::Eof => break,
}
}
outbound = net_read.read(&mut out_buf) => {
match outbound {
Ok(0) => break,
Ok(n) => {
let msg = Message::Binary(Bytes::copy_from_slice(&out_buf[..n]));
if socket.send(msg).await.is_err() {
break;
}
}
Err(_) => break,
}
}
}
}
drop(net_write);
drop(net_read);
let _ = socket.send(Message::Close(None)).await;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two allowlisted HTTPS origins shared by the upgrade-decision tests.
    fn allowlist() -> Vec<String> {
        ["https://app.example.com", "https://admin.example.com"]
            .iter()
            .map(|origin| origin.to_string())
            .collect()
    }

    /// Runs `ws_upgrade_decision` for an HTTPS edge against [`allowlist`].
    fn decide_https(origin: Option<&str>) -> Result<(), WsRejection> {
        ws_upgrade_decision(HttpTransport::Https, origin, &allowlist())
    }

    #[test]
    fn allowed_origin_over_tls_is_accepted() {
        assert_eq!(decide_https(Some("https://app.example.com")), Ok(()));
    }

    #[test]
    fn non_allowlisted_origin_is_rejected() {
        assert_eq!(
            decide_https(Some("https://evil.example.com")),
            Err(WsRejection::OriginRejected)
        );
    }

    #[test]
    fn missing_origin_is_rejected() {
        assert_eq!(decide_https(None), Err(WsRejection::OriginMissing));
    }

    #[test]
    fn plain_ws_over_http_edge_is_rejected_even_when_origin_allowed() {
        let verdict = ws_upgrade_decision(
            HttpTransport::Http,
            Some("https://app.example.com"),
            &allowlist(),
        );
        assert_eq!(verdict, Err(WsRejection::NotTls));
    }

    #[test]
    fn empty_allowlist_rejects_every_origin() {
        let empty: [String; 0] = [];
        let verdict =
            ws_upgrade_decision(HttpTransport::Https, Some("https://app.example.com"), &empty);
        assert_eq!(verdict, Err(WsRejection::OriginRejected));
    }

    #[test]
    fn origin_match_is_exact_not_prefix() {
        assert_eq!(
            decide_https(Some("https://app.example.com.evil.com")),
            Err(WsRejection::OriginRejected)
        );
    }

    /// Asserts that `got` is `Data` carrying exactly `expected`.
    fn assert_data(got: WsInbound, expected: &[u8]) {
        match got {
            WsInbound::Data(payload) => assert_eq!(&payload[..], expected),
            WsInbound::Ignore => panic!("expected Data, got Ignore"),
            WsInbound::Eof => panic!("expected Data, got Eof"),
        }
    }

    /// Wraps bytes as a successfully received binary WebSocket message.
    fn binary(bytes: Bytes) -> Option<Result<Message, axum::Error>> {
        Some(Ok(Message::Binary(bytes)))
    }

    #[test]
    fn binary_message_passes_through_byte_for_byte() {
        let frame = vec![0xFE, 0x01, 0x10, 0x00, 0x00, 0x00];
        assert_data(classify_inbound(binary(Bytes::from(frame.clone()))), &frame);
    }

    #[test]
    fn binary_messages_reassemble_in_order() {
        let first = classify_inbound(binary(Bytes::from_static(&[0xFE, 0x01])));
        let second = classify_inbound(binary(Bytes::from_static(&[0x10, 0x00])));
        assert_data(first, &[0xFE, 0x01]);
        assert_data(second, &[0x10, 0x00]);
    }

    #[test]
    fn close_and_stream_end_map_to_eof() {
        let closed = classify_inbound(Some(Ok(Message::Close(None))));
        assert!(matches!(closed, WsInbound::Eof));
        assert!(matches!(classify_inbound(None), WsInbound::Eof));
    }

    #[test]
    fn control_frames_are_ignored_not_forwarded() {
        for control in [Message::Ping(Bytes::new()), Message::Pong(Bytes::new())] {
            assert!(matches!(
                classify_inbound(Some(Ok(control))),
                WsInbound::Ignore
            ));
        }
    }
}