use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use axum::extract::ws::{WebSocket, WebSocketUpgrade};
use axum::extract::State;
use axum::http::{header, HeaderMap, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use tokio::sync::oneshot;
use super::ws_edge::{
inject_bearer_handshake, pump_ws_stream, run_injected_ws_session, run_ws_session,
REDWIRE_WS_PATH, REDWIRE_WS_SUBPROTOCOL,
};
use super::RedDBServer;
/// Fallback UI page, embedded at compile time from
/// `ui_bridge_fixture/index.html`. `serve_ui` returns it for `/` (and
/// `/index.html`) when no `ui_dir` is configured.
const FIXTURE_INDEX: &str = include_str!("ui_bridge_fixture/index.html");
/// Options shared by the loopback UI servers spawned in this file.
#[derive(Debug, Clone, Default)]
pub struct UiBridgeConfig {
/// Directory static assets are read from. When `None`, only the embedded
/// fixture `index.html` is served and every other path answers 404.
pub ui_dir: Option<PathBuf>,
/// Port to bind on `127.0.0.1`. The address actually bound is read back
/// from the listener, so `0` presumably lets the OS pick one — confirm.
pub port: u16,
/// Token handed to the injected-handshake session helpers
/// (`run_injected_ws_session` / `inject_bearer_handshake`) when present.
pub injected_token: Option<String>,
/// Passed to `ui_auth::inject_auth_mode_config` for every HTML response.
pub auth_mode: super::ui_auth::UiAuthMode,
}
/// Remote endpoint that bridged WebSocket sessions are forwarded to.
#[derive(Debug, Clone)]
pub struct RemoteRedwireTarget {
/// Host used for the outbound TCP connect and, when `tls` is set, as the
/// TLS server name.
pub host: String,
/// TCP port of the remote endpoint.
pub port: u16,
/// When `true` the TCP stream is wrapped in TLS before any bytes are pumped.
pub tls: bool,
/// Optional PEM-encoded certificates added to the trust store on top of
/// the bundled `webpki_roots` set.
pub ca_pem: Option<Vec<u8>>,
}
/// What an upgraded `/redwire` WebSocket gets connected to.
enum BridgeBackend {
/// An in-process server; each session runs against a clone of it.
Embedded(Box<RedDBServer>),
/// A remote endpoint reached over TCP (optionally TLS) per session.
Remote(RemoteRedwireTarget),
/// No bridging: the direct UI server registers no WebSocket route, and an
/// upgrade that reaches this backend is closed immediately.
Direct,
}
/// Result of classifying a connection URI with [`classify_ui_target`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UiTarget {
/// A bare path (no `://`) or a URI that parses as a file target.
File,
/// A RedWire endpoint described by host/port/TLS flag.
Remote(RemoteRedwireTargetSpec),
/// A WebSocket-native endpoint; `ws_url` is the `ws://` / `wss://` URL
/// built from the parsed host and port with the `/redwire` path.
Direct { ws_url: String },
}
/// Comparable host/port/TLS triple carried by [`UiTarget::Remote`].
///
/// Unlike [`RemoteRedwireTarget`] it has no CA bundle field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemoteRedwireTargetSpec {
/// Host component as returned by the URI parser.
pub host: String,
/// Port component as returned by the URI parser.
pub port: u16,
/// TLS flag as returned by the URI parser.
pub tls: bool,
}
/// Classifies a connection string into the kind of UI server it needs.
///
/// * Anything without `://` is treated as a local file path.
/// * Otherwise the string is handed to `reddb_wire::parse`; file targets map
///   to [`UiTarget::File`], RedWire targets to [`UiTarget::Remote`], and
///   WebSocket-native targets to [`UiTarget::Direct`] with a ready-to-use
///   `ws://` / `wss://` URL.
///
/// # Errors
///
/// Returns a message naming the supported schemes when the URI fails to
/// parse or parses to any other target kind.
pub fn classify_ui_target(uri: &str) -> Result<UiTarget, String> {
    if !uri.contains("://") {
        return Ok(UiTarget::File);
    }
    match reddb_wire::parse(uri) {
        Ok(reddb_wire::ConnectionTarget::File { .. }) => Ok(UiTarget::File),
        Ok(reddb_wire::ConnectionTarget::RedWire { host, port, tls }) => {
            Ok(UiTarget::Remote(RemoteRedwireTargetSpec { host, port, tls }))
        }
        Ok(reddb_wire::ConnectionTarget::WsNative { host, port, tls }) => {
            let scheme = if tls { "wss" } else { "ws" };
            // A ':' can only appear in a host that is an IPv6 literal. In a
            // URL authority such a literal must be bracketed, otherwise the
            // port separator is ambiguous. NOTE(review): whether the parser
            // hands back bracketed or bare literals is not visible here;
            // already-bracketed hosts are left untouched.
            let bare_ipv6 = host.contains(':') && !host.starts_with('[');
            let ws_url = if bare_ipv6 {
                format!("{scheme}://[{host}]:{port}/redwire")
            } else {
                format!("{scheme}://{host}:{port}/redwire")
            };
            Ok(UiTarget::Direct { ws_url })
        }
        Ok(_) | Err(_) => Err(format!(
            "unsupported target for red ui; supported schemes: \
             file://, red://, reds://, red+ws://, red+wss://; got: {uri}"
        )),
    }
}
/// Per-server state cloned into every axum handler invocation.
#[derive(Clone)]
struct BridgeState {
/// Where upgraded WebSocket sessions are routed.
backend: Arc<BridgeBackend>,
/// Exact `Origin` header values accepted by the WebSocket upgrade handler.
allowed_origins: Arc<Vec<String>>,
/// Static asset directory; `None` means "serve the embedded fixture only".
ui_dir: Option<Arc<PathBuf>>,
/// Token forwarded to the session helpers when present.
injected_token: Option<Arc<String>>,
/// Auth mode injected into HTML responses.
auth_mode: super::ui_auth::UiAuthMode,
/// When set, HTML responses get a `window.REDDB_WS_URL` script injected
/// that points at this URL.
direct_ws_url: Option<Arc<String>>,
}
/// Handle to a running loopback UI server.
pub struct UiBridge {
/// Address the listener is actually bound to.
local_addr: SocketAddr,
/// Fires the graceful-shutdown future of the serve task; taken on shutdown.
shutdown_tx: Option<oneshot::Sender<()>>,
/// Join handle of the spawned serve task.
join: tokio::task::JoinHandle<()>,
/// Set only for direct UI servers: the external WebSocket URL the UI
/// should use instead of the loopback bridge.
direct_ws_url: Option<String>,
}
impl UiBridge {
    /// Socket address the HTTP listener is bound to.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Root URL of the served UI (`http://<addr>/`).
    pub fn ui_url(&self) -> String {
        let addr = self.local_addr;
        format!("http://{addr}/")
    }

    /// WebSocket URL the UI should connect to: the configured direct URL when
    /// there is one, otherwise the loopback bridge endpoint on this server.
    pub fn ws_url(&self) -> String {
        match &self.direct_ws_url {
            Some(direct) => direct.clone(),
            None => format!("ws://{}{}", self.local_addr, REDWIRE_WS_PATH),
        }
    }

    /// Signals graceful shutdown and waits for the serve task to finish.
    ///
    /// Both the signal send and the join are best-effort: their errors are
    /// discarded.
    pub async fn shutdown(mut self) {
        let pending_signal = self.shutdown_tx.take();
        if let Some(sender) = pending_signal {
            let _ = sender.send(());
        }
        let _ = self.join.await;
    }
}
/// Returns `true` only when `origin` is present and is byte-for-byte equal to
/// one of the entries in `allowed`.
///
/// A missing `Origin` header is always rejected, and an empty allowlist
/// rejects everything.
pub fn loopback_ws_origin_allowed(origin: Option<&str>, allowed: &[String]) -> bool {
    let Some(candidate) = origin else {
        return false;
    };
    allowed
        .iter()
        .map(String::as_str)
        .any(|entry| entry == candidate)
}
/// Builds the default origin allowlist for a UI served on `port`: the
/// `http://` origins for `127.0.0.1` and `localhost`, in that order.
fn seed_loopback_origins(port: u16) -> Vec<String> {
    ["127.0.0.1", "localhost"]
        .iter()
        .map(|host| format!("http://{host}:{port}"))
        .collect()
}
/// Starts a loopback UI server whose WebSocket endpoint is backed by the
/// given in-process `server`.
///
/// # Errors
///
/// Propagates any I/O error from binding the listener.
pub async fn spawn_ui_bridge(
    server: RedDBServer,
    config: UiBridgeConfig,
) -> std::io::Result<UiBridge> {
    let backend = BridgeBackend::Embedded(Box::new(server));
    spawn_ui_bridge_backend(backend, config).await
}
/// Starts a loopback UI server whose WebSocket endpoint forwards each
/// session to the remote `target`.
///
/// # Errors
///
/// Propagates any I/O error from binding the listener.
pub async fn spawn_ui_bridge_remote(
    target: RemoteRedwireTarget,
    config: UiBridgeConfig,
) -> std::io::Result<UiBridge> {
    let backend = BridgeBackend::Remote(target);
    spawn_ui_bridge_backend(backend, config).await
}
async fn spawn_ui_bridge_backend(
backend: BridgeBackend,
config: UiBridgeConfig,
) -> std::io::Result<UiBridge> {
let listener = tokio::net::TcpListener::bind(("127.0.0.1", config.port)).await?;
let local_addr = listener.local_addr()?;
let state = BridgeState {
backend: Arc::new(backend),
allowed_origins: Arc::new(seed_loopback_origins(local_addr.port())),
ui_dir: config.ui_dir.map(Arc::new),
injected_token: config.injected_token.map(Arc::new),
auth_mode: config.auth_mode,
direct_ws_url: None,
};
let router = axum::Router::new()
.route(REDWIRE_WS_PATH, get(loopback_redwire_upgrade))
.fallback(serve_ui)
.with_state(state);
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let join = tokio::spawn(async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
})
.await;
});
Ok(UiBridge {
local_addr,
shutdown_tx: Some(shutdown_tx),
join,
direct_ws_url: None,
})
}
pub async fn spawn_direct_ui_server(
ws_url: String,
config: UiBridgeConfig,
) -> std::io::Result<UiBridge> {
let listener = tokio::net::TcpListener::bind(("127.0.0.1", config.port)).await?;
let local_addr = listener.local_addr()?;
let state = BridgeState {
backend: Arc::new(BridgeBackend::Direct),
allowed_origins: Arc::new(vec![]),
ui_dir: config.ui_dir.map(Arc::new),
injected_token: config.injected_token.map(Arc::new),
auth_mode: config.auth_mode,
direct_ws_url: Some(Arc::new(ws_url.clone())),
};
let router = axum::Router::new().fallback(serve_ui).with_state(state);
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let join = tokio::spawn(async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
})
.await;
});
Ok(UiBridge {
local_addr,
shutdown_tx: Some(shutdown_tx),
join,
direct_ws_url: Some(ws_url),
})
}
/// Handles the WebSocket upgrade request on the loopback bridge route.
///
/// The request is rejected with 403 unless its `Origin` header is on the
/// allowlist (a missing or non-UTF-8 header counts as not allowed). Accepted
/// upgrades negotiate the RedWire subprotocol and are then dispatched
/// according to the configured backend.
async fn loopback_redwire_upgrade(
    State(state): State<BridgeState>,
    headers: HeaderMap,
    ws: WebSocketUpgrade,
) -> Response {
    let origin = match headers.get(header::ORIGIN) {
        Some(raw) => raw.to_str().ok(),
        None => None,
    };
    if !loopback_ws_origin_allowed(origin, &state.allowed_origins) {
        let denial = (
            StatusCode::FORBIDDEN,
            "origin not allowed for loopback redwire websocket",
        );
        return denial.into_response();
    }

    let backend = Arc::clone(&state.backend);
    let injected_token = state.injected_token.clone();
    let upgrade = ws.protocols([REDWIRE_WS_SUBPROTOCOL]);
    upgrade.on_upgrade(move |socket| async move {
        let token: Option<&str> = injected_token.as_deref().map(String::as_str);
        match (&*backend, token) {
            (BridgeBackend::Embedded(server), Some(token)) => {
                run_injected_ws_session(socket, (**server).clone(), token).await;
            }
            (BridgeBackend::Embedded(server), None) => {
                run_ws_session(socket, (**server).clone()).await;
            }
            (BridgeBackend::Remote(target), token) => {
                run_remote_ws_session(socket, target, token).await;
            }
            // The direct server never mounts this route; if it is reached
            // anyway there is nothing to bridge to, so just close.
            (BridgeBackend::Direct, _) => {
                close_ws(socket).await;
            }
        }
    })
}
/// Bridges one browser WebSocket to the remote RedWire `target`.
///
/// Opens a TCP connection, optionally wraps it in TLS, then either runs the
/// bearer-injecting handshake (when `injected_token` is set) or pumps bytes
/// straight through. Connect and TLS failures are logged at `warn` and the
/// WebSocket is closed.
async fn run_remote_ws_session(
    socket: WebSocket,
    target: &RemoteRedwireTarget,
    injected_token: Option<&str>,
) {
    let endpoint = (target.host.as_str(), target.port);
    let tcp = match tokio::net::TcpStream::connect(endpoint).await {
        Ok(stream) => stream,
        Err(err) => {
            tracing::warn!(
                host = %target.host,
                port = target.port,
                err = %err,
                "ui bridge: connect to remote redwire target failed"
            );
            close_ws(socket).await;
            return;
        }
    };

    if target.tls {
        match wrap_remote_tls(tcp, target).await {
            Err(err) => {
                tracing::warn!(
                    host = %target.host,
                    port = target.port,
                    err = %err,
                    "ui bridge: TLS handshake to remote redwire target failed"
                );
                close_ws(socket).await;
            }
            Ok(tls) => match injected_token {
                Some(token) => {
                    inject_bearer_handshake(socket, tls, token).await;
                }
                None => {
                    pump_ws_stream(socket, tls).await;
                }
            },
        }
    } else {
        // Plaintext target: hand the raw TCP stream to the same helpers.
        match injected_token {
            Some(token) => {
                inject_bearer_handshake(socket, tcp, token).await;
            }
            None => {
                pump_ws_stream(socket, tcp).await;
            }
        }
    }
}
/// Sends a bare Close frame on `socket`; a send failure is ignored.
async fn close_ws(mut socket: WebSocket) {
    let close_frame = axum::extract::ws::Message::Close(None);
    let _ = socket.send(close_frame).await;
}
/// Performs a TLS client handshake over `tcp` for the given `target`.
///
/// The trust store is the bundled `webpki_roots` set plus any certificates
/// parsed from `target.ca_pem`. No client certificate is presented, and
/// `target.host` is used as the server name.
///
/// # Errors
///
/// Returns an `io::Error` when a PEM entry cannot be parsed or added, when
/// the host is not a valid TLS server name, or when the handshake fails.
async fn wrap_remote_tls(
    tcp: tokio::net::TcpStream,
    target: &RemoteRedwireTarget,
) -> std::io::Result<tokio_rustls::client::TlsStream<tokio::net::TcpStream>> {
    use rustls::pki_types::ServerName;
    use rustls::{ClientConfig, RootCertStore};

    // Installing fails once a default provider is already set; that outcome
    // is deliberately ignored.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let mut trust = RootCertStore::empty();
    trust.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    if let Some(pem) = &target.ca_pem {
        let mut pem_reader = std::io::BufReader::new(&pem[..]);
        for entry in rustls_pemfile::certs(&mut pem_reader) {
            let der = entry.map_err(std::io::Error::other)?;
            if let Err(e) = trust.add(der) {
                return Err(std::io::Error::other(format!("add CA cert: {e}")));
            }
        }
    }

    let client_config = ClientConfig::builder()
        .with_root_certificates(trust)
        .with_no_client_auth();
    let connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));

    let server_name = match ServerName::try_from(target.host.clone()) {
        Ok(name) => name,
        Err(e) => {
            return Err(std::io::Error::other(format!(
                "invalid TLS server name: {e}"
            )))
        }
    };
    connector.connect(server_name, tcp).await
}
/// Fallback handler that serves the UI's static assets.
///
/// The request path, minus leading slashes, is the asset path; an empty path
/// means `index.html`. Without a configured `ui_dir` only the embedded
/// fixture index is available. With a `ui_dir`, the file is read from disk on
/// a blocking thread. HTML responses get the direct WebSocket URL (if any)
/// and the auth mode injected before being returned.
///
/// Any path that is not a plain chain of normal file-name segments, and any
/// read failure, answers 404.
async fn serve_ui(State(state): State<BridgeState>, uri: Uri) -> Response {
    let rel = match uri.path().trim_start_matches('/') {
        "" => "index.html",
        other => other,
    };

    let (content_type, mut body) = match &state.ui_dir {
        None => {
            if rel != "index.html" {
                return not_found();
            }
            (
                "text/html; charset=utf-8",
                FIXTURE_INDEX.as_bytes().to_vec(),
            )
        }
        Some(dir) => {
            // URL-level check: no empty, `.` or `..` segments.
            let url_segments_plain = rel
                .split('/')
                .all(|seg| !seg.is_empty() && seg != "." && seg != "..");
            // Platform-level check: splitting on '/' alone misses what the OS
            // path parser recognises. On Windows `..\x`, `C:/x` and `C:x`
            // would otherwise reach `join`, which walks out of — or entirely
            // replaces — the base directory for such inputs.
            let os_components_plain = std::path::Path::new(rel)
                .components()
                .all(|part| matches!(part, std::path::Component::Normal(_)));
            if !(url_segments_plain && os_components_plain) {
                return not_found();
            }

            let full = dir.join(rel);
            match tokio::task::spawn_blocking(move || std::fs::read(&full)).await {
                Ok(Ok(bytes)) => (content_type_for(rel), bytes),
                _ => return not_found(),
            }
        }
    };

    if content_type.starts_with("text/html") {
        if let Some(ws_url) = &state.direct_ws_url {
            body = inject_ws_url_config(body, ws_url);
        }
        body = super::ui_auth::inject_auth_mode_config(body, state.auth_mode);
    }
    (StatusCode::OK, [(header::CONTENT_TYPE, content_type)], body).into_response()
}
/// Inserts `<script>window.REDDB_WS_URL="…";</script>` immediately before the
/// first `</head>` in `html`.
///
/// The closing tag is matched ASCII case-insensitively. When the document has
/// no `</head>` it is returned unchanged.
///
/// `ws_url` is escaped for a double-quoted JavaScript string embedded in an
/// HTML `<script>` element: quotes and backslashes are backslash-escaped, and
/// `<`, `>`, `&`, control characters and U+2028/U+2029 are written as
/// `\uXXXX`. Ordinary URLs contain none of these, so they are emitted
/// verbatim.
fn inject_ws_url_config(html: Vec<u8>, ws_url: &str) -> Vec<u8> {
    const MARKER: &[u8] = b"</head>";

    let Some(pos) = html
        .windows(MARKER.len())
        .position(|window| window.eq_ignore_ascii_case(MARKER))
    else {
        return html;
    };

    let mut literal = String::with_capacity(ws_url.len());
    for ch in ws_url.chars() {
        match ch {
            '"' => literal.push_str("\\\""),
            '\\' => literal.push_str("\\\\"),
            // `<`/`>`/`&` must not be able to form `</script>` or an HTML
            // comment opener; the line separators and control characters
            // would terminate or corrupt the string literal. All of these
            // are below U+10000, so a single \uXXXX escape suffices.
            '<' | '>' | '&' | '\u{2028}' | '\u{2029}' => {
                literal.push_str(&format!("\\u{:04x}", ch as u32));
            }
            c if c.is_control() => {
                literal.push_str(&format!("\\u{:04x}", c as u32));
            }
            c => literal.push(c),
        }
    }

    let snippet = format!("<script>window.REDDB_WS_URL=\"{literal}\";</script>");
    let mut out = Vec::with_capacity(html.len() + snippet.len());
    out.extend_from_slice(&html[..pos]);
    out.extend_from_slice(snippet.as_bytes());
    out.extend_from_slice(&html[pos..]);
    out
}
/// Maps an asset path to the `Content-Type` value it is served with.
///
/// The extension is the text after the last `.` of the final `/`-separated
/// segment, compared ASCII case-insensitively. A name without any `.` has no
/// extension, so an extensionless file that happens to be called `html` or
/// `js` is not mistaken for that type. Unknown or missing extensions map to
/// `application/octet-stream`.
pub(crate) fn content_type_for(path: &str) -> &'static str {
    let file_name = path.rsplit('/').next().unwrap_or(path);
    let ext = match file_name.rsplit_once('.') {
        Some((_, ext)) => ext.to_ascii_lowercase(),
        None => String::new(),
    };
    match ext.as_str() {
        "html" | "htm" => "text/html; charset=utf-8",
        "js" | "mjs" => "text/javascript; charset=utf-8",
        "css" => "text/css; charset=utf-8",
        // Source maps are JSON documents.
        "json" | "map" => "application/json; charset=utf-8",
        "svg" => "image/svg+xml",
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "ico" => "image/x-icon",
        "wasm" => "application/wasm",
        "txt" => "text/plain; charset=utf-8",
        _ => "application/octet-stream",
    }
}
/// Builds a `200 OK` response carrying `body`, with the `Content-Type`
/// chosen from `path` by [`content_type_for`].
fn content_type_response(path: &str, body: Vec<u8>) -> Response {
    let mime = content_type_for(path);
    let headers = [(header::CONTENT_TYPE, mime)];
    (StatusCode::OK, headers, body).into_response()
}
/// Builds a `200 OK` response carrying `body` as UTF-8 HTML.
fn html_response(body: Vec<u8>) -> Response {
    let headers = [(header::CONTENT_TYPE, "text/html; charset=utf-8")];
    (StatusCode::OK, headers, body).into_response()
}
/// Plain-text `404 Not Found` response used for every unservable path.
fn not_found() -> Response {
    let reply = (StatusCode::NOT_FOUND, "not found");
    reply.into_response()
}
// Unit tests for the pure helpers in this file: origin allowlisting, URI
// classification, HTML injection and content-type mapping. Nothing here
// binds a socket.
#[cfg(test)]
mod tests {
use super::*;
// Allowlist as seeded for a UI served on port 7777.
fn origins() -> Vec<String> {
seed_loopback_origins(7777)
}
// Both seeded loopback spellings of the served origin are accepted.
#[test]
fn served_origin_is_allowed() {
assert!(loopback_ws_origin_allowed(
Some("http://127.0.0.1:7777"),
&origins()
));
assert!(loopback_ws_origin_allowed(
Some("http://localhost:7777"),
&origins()
));
}
// A request without an Origin header is never accepted.
#[test]
fn missing_origin_is_rejected() {
assert!(!loopback_ws_origin_allowed(None, &origins()));
}
// Foreign hosts and the right host on the wrong port are both rejected.
#[test]
fn cross_site_origin_is_rejected() {
assert!(!loopback_ws_origin_allowed(
Some("http://evil.example.com"),
&origins()
));
assert!(!loopback_ws_origin_allowed(
Some("http://127.0.0.1:9999"),
&origins()
));
}
// With nothing on the allowlist even a loopback origin is denied.
#[test]
fn empty_allowlist_denies_every_origin() {
assert!(!loopback_ws_origin_allowed(
Some("http://127.0.0.1:7777"),
&[]
));
}
// `red://` without a port: remote, plaintext, default RedWire port.
#[test]
fn red_scheme_classifies_as_remote_plaintext_default_port() {
assert_eq!(
classify_ui_target("red://db.internal").unwrap(),
UiTarget::Remote(RemoteRedwireTargetSpec {
host: "db.internal".to_string(),
port: reddb_wire::DEFAULT_PORT_RED,
tls: false,
})
);
// Pins the numeric value of the default so a change is noticed here.
assert_eq!(reddb_wire::DEFAULT_PORT_RED, 5050);
}
// `reds://` without a port: remote, TLS, same default port.
#[test]
fn reds_scheme_classifies_as_remote_tls_default_port() {
assert_eq!(
classify_ui_target("reds://db.internal").unwrap(),
UiTarget::Remote(RemoteRedwireTargetSpec {
host: "db.internal".to_string(),
port: reddb_wire::DEFAULT_PORT_RED,
tls: true,
})
);
}
// An explicit port in the URI overrides the default for both schemes.
#[test]
fn red_scheme_honours_explicit_port() {
assert_eq!(
classify_ui_target("red://127.0.0.1:6000").unwrap(),
UiTarget::Remote(RemoteRedwireTargetSpec {
host: "127.0.0.1".to_string(),
port: 6000,
tls: false,
})
);
assert_eq!(
classify_ui_target("reds://host:7001").unwrap(),
UiTarget::Remote(RemoteRedwireTargetSpec {
host: "host".to_string(),
port: 7001,
tls: true,
})
);
}
// `file://` URIs and scheme-less paths are local file targets.
#[test]
fn file_and_bare_path_classify_as_local() {
assert_eq!(
classify_ui_target("file:///var/lib/db.rdb").unwrap(),
UiTarget::File
);
assert_eq!(classify_ui_target("./data.rdb").unwrap(), UiTarget::File);
assert_eq!(classify_ui_target("data.rdb").unwrap(), UiTarget::File);
}
// Other schemes, and a `red://` URI with a comma-separated host list, are
// errors.
#[test]
fn unsupported_scheme_is_rejected() {
assert!(classify_ui_target("grpc://host:5055").is_err());
assert!(classify_ui_target("http://host").is_err());
assert!(classify_ui_target("red://a,b").is_err());
}
// `red+wss://` without a port yields a direct `wss://` URL on port 443.
#[test]
fn red_plus_wss_classifies_as_direct_default_port() {
assert_eq!(
classify_ui_target("red+wss://mydb.db.reddb.io").unwrap(),
UiTarget::Direct {
ws_url: "wss://mydb.db.reddb.io:443/redwire".to_string(),
}
);
}
// An explicit port is carried into the direct `wss://` URL.
#[test]
fn red_plus_wss_with_explicit_port_classifies_as_direct() {
assert_eq!(
classify_ui_target("red+wss://host:5055").unwrap(),
UiTarget::Direct {
ws_url: "wss://host:5055/redwire".to_string(),
}
);
}
// `red+ws://` yields a plaintext `ws://` direct URL.
#[test]
fn red_plus_ws_classifies_as_direct_plaintext() {
assert_eq!(
classify_ui_target("red+ws://host:8080").unwrap(),
UiTarget::Direct {
ws_url: "ws://host:8080/redwire".to_string(),
}
);
}
// The rejection message must list every supported scheme.
#[test]
fn unsupported_scheme_error_names_supported_set() {
let err = classify_ui_target("mongodb://host").unwrap_err();
for scheme in ["file://", "red://", "reds://", "red+ws://", "red+wss://"] {
assert!(
err.contains(scheme),
"error must mention {scheme}: got: {err}"
);
}
}
// The config script lands directly in front of `</head>`.
#[test]
fn inject_ws_url_inserts_before_head_close() {
let html = b"<html><head></head><body></body></html>".to_vec();
let out = inject_ws_url_config(html, "wss://host:443/redwire");
let s = String::from_utf8(out).unwrap();
assert!(
s.contains("<script>window.REDDB_WS_URL=\"wss://host:443/redwire\";</script></head>"),
"snippet must appear before </head>: {s}"
);
}
// Without a `</head>` there is no insertion point; bytes pass through.
#[test]
fn inject_ws_url_noop_when_no_head_close() {
let html = b"<html><body>no head close</body></html>".to_vec();
let orig = html.clone();
let out = inject_ws_url_config(html, "wss://host/redwire");
assert_eq!(out, orig, "html without </head> must be returned unchanged");
}
// Spot-checks the asset types a UI bundle ships, plus the unknown fallback.
#[test]
fn content_types_cover_bundle_assets() {
assert_eq!(content_type_for("index.html"), "text/html; charset=utf-8");
assert_eq!(content_type_for("app.js"), "text/javascript; charset=utf-8");
assert_eq!(content_type_for("style.css"), "text/css; charset=utf-8");
assert_eq!(content_type_for("data.bin"), "application/octet-stream");
}
}