use std::sync::{Arc, Mutex};
use axum::{
body::Body,
extract::{ws::Message as AxumMessage, Path, State, WebSocketUpgrade},
http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri},
response::{IntoResponse, Response},
};
use futures_util::{SinkExt, StreamExt};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::aws_lc_rs;
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme};
use sha2::{Digest, Sha256};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message as TungMessage;
use crate::AppState;
/// Request header carrying the credentials the relay should present to the
/// *peer*. It is never forwarded under its own name: `build_forward_headers`
/// strips it and, when present, copies its value into `Authorization`.
pub const UPSTREAM_AUTH_HEADER: &str = "x-mobux-upstream-authorization";
/// Header carrying how many relay hops a request has already taken.
pub const HOP_HEADER: &str = "x-mobux-relay-hop";
/// Loop-guard limit: an incoming hop count at or above this value is rejected.
const MAX_HOPS: u32 = 4;
/// Port assumed for a peer written without an explicit `:port`.
///
/// Uses this relay's own `PORT` environment variable, falling back to 8080
/// when it is unset or not a valid `u16` (presumably because peers usually
/// run with the same configuration as this instance).
fn default_peer_port() -> u16 {
    std::env::var("PORT")
        .ok()
        .and_then(|p| p.parse().ok())
        .unwrap_or(8080)
}

/// Normalize a user-supplied peer into canonical `host:port` form.
///
/// The result is both the `MOBUX_PEERS` allowlist key and the key under which
/// certificate pins are stored. Equivalent spellings of one peer must
/// therefore collapse to the same string, or a new spelling would start a
/// fresh trust-on-first-use and sidestep an existing pin:
///
/// * surrounding whitespace is trimmed and the input is ASCII-lowercased
///   (hostnames are case-insensitive);
/// * a missing port is filled in from [`default_peer_port`];
/// * IPv6 literals must be bracketed, with or without a port
///   (`[::1]`, `[::1]:5151`).
///
/// # Errors
///
/// Returns a human-readable message when the input:
///
/// * is empty;
/// * contains whitespace or any of `/ \ @ ? #` (which could smuggle a path,
///   userinfo, query or fragment into the `https://{peer}` URL);
/// * has an empty host or an unbracketed IPv6 literal;
/// * has a port that is not a valid `u16`.
pub fn canonical_peer(raw: &str) -> Result<String, String> {
    /// True for an IPv6-style bracketed literal such as `[::1]`.
    fn is_bracketed(s: &str) -> bool {
        s.starts_with('[') && s.ends_with(']')
    }

    let raw = raw.trim().to_ascii_lowercase();
    if raw.is_empty() {
        return Err("empty peer".into());
    }
    let has_forbidden_char = raw
        .chars()
        .any(|c| c.is_whitespace() || matches!(c, '/' | '\\' | '@' | '?' | '#'));
    if has_forbidden_char {
        return Err("invalid peer (host or host:port only)".into());
    }
    // In a bare bracketed literal ("[::1]") the last ':' sits inside the
    // brackets, so it is not a host/port separator.
    let (host, port) = match raw.rsplit_once(':') {
        Some((host, port)) if !is_bracketed(raw.as_str()) => {
            if host.is_empty() {
                return Err("empty peer host".into());
            }
            let port: u16 = port
                .parse()
                .map_err(|_| format!("invalid peer port: {port}"))?;
            (host, port)
        }
        _ => (raw.as_str(), default_peer_port()),
    };
    // Any ':' left in the host must belong to a bracketed IPv6 literal;
    // otherwise "host:port" would be ambiguous and the URL unparsable.
    if host.contains(':') && !is_bracketed(host) {
        return Err(format!(
            "invalid peer host: {host} (IPv6 literals must be bracketed)"
        ));
    }
    Ok(format!("{host}:{port}"))
}
/// Whether the (already canonical) `peer` may be relayed to.
///
/// When `MOBUX_PEERS` is unset, not valid Unicode, or blank, every peer is
/// allowed. Otherwise it is read as a comma-separated list. Each entry is
/// canonicalized with [`canonical_peer`], entries that fail to parse are
/// ignored, and `peer` must equal one of the rest exactly.
fn peer_allowed(peer: &str) -> bool {
    let configured = std::env::var("MOBUX_PEERS").unwrap_or_default();
    if configured.trim().is_empty() {
        return true;
    }
    configured
        .split(',')
        .filter_map(|entry| canonical_peer(entry).ok())
        .any(|allowed| allowed == peer)
}
/// Fingerprint of a DER-encoded certificate: its SHA-256 digest as lowercase
/// hex (64 characters). This is the value stored and compared as a peer pin.
pub fn cert_fingerprint(der: &[u8]) -> String {
    hex::encode(Sha256::digest(der))
}
/// Failure modes of the relay handlers; each is rendered as a JSON error body.
#[derive(Debug)]
pub enum RelayError {
    /// The peer's certificate fingerprint differs from the stored pin
    /// (rendered as HTTP 409 `pin_mismatch`).
    PinMismatch { expected: String, actual: String },
    /// Reaching the peer, or reading/writing local pin state, failed
    /// (rendered as HTTP 502 `upstream_error`).
    Upstream(String),
    /// The incoming request was invalid or disallowed
    /// (rendered as HTTP 400 `bad_request`).
    BadRequest(String),
}
impl RelayError {
    /// Convert the error into a JSON response of the form
    /// `{"error": <kind>, "message": <text>}` with a matching status code.
    fn into_response(self) -> Response {
        let (status, kind, message) = match self {
            RelayError::PinMismatch { expected, actual } => (
                StatusCode::CONFLICT,
                "pin_mismatch",
                format!(
                    "peer certificate fingerprint changed (pinned {expected}, got {actual}); \
                     delete the pin to re-trust this peer"
                ),
            ),
            RelayError::Upstream(msg) => (StatusCode::BAD_GATEWAY, "upstream_error", msg),
            RelayError::BadRequest(msg) => (StatusCode::BAD_REQUEST, "bad_request", msg),
        };
        structured_error(status, kind, &message)
    }
}
/// Build a response with `status` and the JSON body
/// `{"error": kind, "message": message}`.
fn structured_error(status: StatusCode, kind: &str, message: &str) -> Response {
    let payload = axum::Json(serde_json::json!({
        "error": kind,
        "message": message
    }));
    (status, payload).into_response()
}
/// Trust-on-first-use certificate verifier used for one relay connection.
///
/// Trust is decided purely by certificate fingerprint. The two mutex-guarded
/// slots let the caller inspect, after the handshake, what was seen.
#[derive(Debug)]
struct TofuVerifier {
    /// Pinned fingerprint for this peer, or `None` on first contact.
    expected: Option<String>,
    /// Fingerprint of the certificate most recently presented by the server.
    observed: Mutex<Option<String>>,
    /// `(pinned, presented)` fingerprints if a handshake was rejected.
    mismatch: Mutex<Option<(String, String)>>,
}

impl TofuVerifier {
    /// Create a verifier that enforces `expected` when it is `Some`. With
    /// `None`, any certificate is accepted and its fingerprint recorded.
    fn new(expected: Option<String>) -> Self {
        let observed = Mutex::default();
        let mismatch = Mutex::default();
        TofuVerifier {
            expected,
            observed,
            mismatch,
        }
    }
}
/// Pin-based verification: trust depends only on the SHA-256 fingerprint of
/// the server's end-entity certificate, not on a CA chain.
impl ServerCertVerifier for TofuVerifier {
    /// Accept the certificate if no pin is set, or if its fingerprint equals
    /// the pin.
    ///
    /// The presented fingerprint is always recorded in `observed`, so the
    /// caller can persist it after a successful first contact. On a pin
    /// mismatch, the `(pinned, presented)` pair is stored in `mismatch` and
    /// the handshake is aborted. Intermediates, the server name, OCSP data and
    /// the current time are not consulted; the pin is the only trust anchor.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        let fp = cert_fingerprint(end_entity.as_ref());
        *self.observed.lock().unwrap() = Some(fp.clone());
        match &self.expected {
            Some(pin) if pin != &fp => {
                // Recorded so the caller can report a structured PinMismatch
                // rather than the opaque rustls error returned below.
                *self.mismatch.lock().unwrap() = Some((pin.clone(), fp));
                Err(rustls::Error::General(
                    "mobux: peer cert pin mismatch".into(),
                ))
            }
            _ => Ok(ServerCertVerified::assertion()),
        }
    }

    /// Check the TLS 1.2 handshake signature against the presented
    /// certificate's key, using aws-lc-rs's verification algorithms. This
    /// still proves the server holds the private key of the (pinned)
    /// certificate.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &aws_lc_rs::default_provider().signature_verification_algorithms,
        )
    }

    /// TLS 1.3 counterpart of `verify_tls12_signature`.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &aws_lc_rs::default_provider().signature_verification_algorithms,
        )
    }

    /// Signature schemes advertised to the server: exactly those aws-lc-rs
    /// can verify in the two methods above.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        aws_lc_rs::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
/// Build a rustls client config whose only certificate check is a fresh
/// [`TofuVerifier`] seeded with `expected`.
///
/// The verifier is also returned so the caller can read what it observed
/// after the connection attempt. No client certificate is offered.
fn pinned_client_config(expected: Option<String>) -> (ClientConfig, Arc<TofuVerifier>) {
    let verifier = Arc::new(TofuVerifier::new(expected));
    let dangerous = ClientConfig::builder().dangerous();
    let config = dangerous
        .with_custom_certificate_verifier(verifier.clone())
        .with_no_client_auth();
    (config, verifier)
}
/// Decide the result of a connection attempt from the verifier's records.
///
/// Precedence:
/// 1. A recorded pin mismatch yields [`RelayError::PinMismatch`], even if the
///    transport also failed.
/// 2. Otherwise a `transport_err` yields [`RelayError::Upstream`].
/// 3. Otherwise, on first contact (`!had_pin`), the observed fingerprint is
///    persisted as the peer's pin. A persistence failure is only logged; the
///    request still succeeds.
fn resolve_pin_outcome(
    state: &AppState,
    peer: &str,
    had_pin: bool,
    verifier: &TofuVerifier,
    transport_err: Option<String>,
) -> Result<(), RelayError> {
    let mismatch = verifier.mismatch.lock().unwrap().clone();
    if let Some((expected, actual)) = mismatch {
        return Err(RelayError::PinMismatch { expected, actual });
    }
    if let Some(err) = transport_err {
        return Err(RelayError::Upstream(err));
    }
    if had_pin {
        return Ok(());
    }
    // Clone out of the guard so the lock is not held during database I/O.
    let observed = verifier.observed.lock().unwrap().clone();
    if let Some(fp) = observed {
        if let Err(e) = state.db.insert_peer_pin(peer, &fp) {
            eprintln!("[relay] WARN: failed to persist pin for {peer}: {e:#}");
        }
    }
    Ok(())
}
/// Whether an incoming request header must not be forwarded to the peer.
///
/// Drops `Host` (reqwest derives it from the URL), the relay's own
/// credentials (`Authorization`, `Cookie`), the upstream-credential carrier
/// [`UPSTREAM_AUTH_HEADER`], and hop-by-hop connection-management headers.
fn is_stripped_request_header(name: &HeaderName) -> bool {
    let lowered = name.as_str();
    lowered == UPSTREAM_AUTH_HEADER
        || matches!(
            lowered,
            "host"
                | "authorization"
                | "cookie"
                | "connection"
                | "keep-alive"
                | "proxy-authenticate"
                | "proxy-authorization"
                | "te"
                | "trailer"
                | "transfer-encoding"
                | "upgrade"
        )
}
/// Build the header set to send to the peer from the client's request headers.
///
/// Headers rejected by `is_stripped_request_header` are dropped; every other
/// header is copied, keeping all values of repeated headers. If the client
/// supplied [`UPSTREAM_AUTH_HEADER`], its first value becomes the forwarded
/// `Authorization` header. The relay's own credentials therefore never reach
/// the peer, and the upstream header itself is never forwarded.
pub fn build_forward_headers(incoming: &HeaderMap) -> HeaderMap {
    let mut forwarded = HeaderMap::new();
    let kept = incoming
        .iter()
        .filter(|(name, _)| !is_stripped_request_header(name));
    for (name, value) in kept {
        forwarded.append(name.clone(), value.clone());
    }
    if let Some(peer_auth) = incoming.get(UPSTREAM_AUTH_HEADER) {
        forwarded.insert(header::AUTHORIZATION, peer_auth.clone());
    }
    forwarded
}
fn is_stripped_response_header(name: &HeaderName) -> bool {
let n = name.as_str();
n == header::CONNECTION.as_str()
|| n == header::TRANSFER_ENCODING.as_str()
|| n == header::WWW_AUTHENTICATE.as_str()
|| n == "keep-alive"
|| n == "proxy-authenticate"
|| n == "te"
|| n == "trailer"
|| n == "upgrade"
}
fn check_loop_guard(headers: &HeaderMap, forward_path: &str) -> Result<u32, RelayError> {
if forward_path.starts_with("/r/") {
return Err(RelayError::BadRequest(
"refusing to relay a path that is itself a relay path (loop guard)".into(),
));
}
let hop = headers
.get(HeaderName::from_static(HOP_HEADER))
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u32>().ok())
.unwrap_or(0);
if hop >= MAX_HOPS {
return Err(RelayError::BadRequest(format!(
"relay hop limit ({MAX_HOPS}) exceeded (loop guard)"
)));
}
Ok(hop + 1)
}
pub async fn relay_http(
State(state): State<AppState>,
Path((peer, rest)): Path<(String, String)>,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Response {
match relay_http_inner(state, peer, rest, method, uri, headers, body).await {
Ok(resp) => resp,
Err(e) => e.into_response(),
}
}
#[allow(clippy::too_many_arguments)]
async fn relay_http_inner(
state: AppState,
peer: String,
rest: String,
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<Response, RelayError> {
let peer = canonical_peer(&peer).map_err(RelayError::BadRequest)?;
if !peer_allowed(&peer) {
return Err(RelayError::BadRequest(format!(
"peer {peer} not in MOBUX_PEERS allowlist"
)));
}
let forward_path = format!("/{}", rest.trim_start_matches('/'));
let next_hop = check_loop_guard(&headers, &forward_path)?;
let query = uri.query().map(|q| format!("?{q}")).unwrap_or_default();
let url_str = format!("https://{peer}{forward_path}{query}");
let url = reqwest::Url::parse(&url_str)
.map_err(|e| RelayError::BadRequest(format!("invalid peer URL ({url_str}): {e}")))?;
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.map_err(|e| RelayError::BadRequest(format!("reading request body: {e}")))?;
let mut fwd_headers = build_forward_headers(&headers);
fwd_headers.insert(
HeaderName::from_static(HOP_HEADER),
HeaderValue::from_str(&next_hop.to_string()).unwrap(),
);
let pin = state
.db
.peer_pin(&peer)
.map_err(|e| RelayError::Upstream(format!("reading pin: {e}")))?;
let had_pin = pin.is_some();
let (tls, verifier) = pinned_client_config(pin);
let client = reqwest::Client::builder()
.use_preconfigured_tls(tls)
.connect_timeout(std::time::Duration::from_secs(5))
.build()
.map_err(|e| RelayError::Upstream(format!("building client: {e}")))?;
let upstream = client
.request(method, url)
.headers(fwd_headers)
.body(body_bytes)
.send()
.await;
let resp = match upstream {
Ok(r) => {
resolve_pin_outcome(&state, &peer, had_pin, &verifier, None)?;
r
}
Err(e) => {
resolve_pin_outcome(&state, &peer, had_pin, &verifier, Some(e.to_string()))?;
return Err(RelayError::Upstream(e.to_string()));
}
};
let status = resp.status();
let resp_headers = resp.headers().clone();
let bytes = resp
.bytes()
.await
.map_err(|e| RelayError::Upstream(format!("reading upstream body: {e}")))?;
let mut out = Response::builder().status(status);
for (name, value) in resp_headers.iter() {
if is_stripped_response_header(name) {
continue;
}
out = out.header(name, value);
}
out.body(Body::from(bytes))
.map_err(|e| RelayError::Upstream(format!("building response: {e}")))
}
/// Axum handler that upgrades the client to a WebSocket and bridges it to
/// `wss://{peer}/ws/{rest}` on the given peer.
///
/// Before the upgrade is accepted:
/// * the peer is canonicalized and checked against the `MOBUX_PEERS`
///   allowlist;
/// * an `upstream_auth` query parameter, if present, is removed from the
///   forwarded query and handed to `pump_ws` to authenticate against the
///   peer (all other query parameters are forwarded unchanged);
/// * the stored certificate pin is read, so a database error is reported
///   as an HTTP error response.
///
/// Errors after the upgrade are only logged.
pub async fn relay_ws(
    State(state): State<AppState>,
    Path((peer, rest)): Path<(String, String)>,
    uri: Uri,
    ws: WebSocketUpgrade,
) -> Response {
    let peer = match canonical_peer(&peer) {
        Ok(p) => p,
        Err(e) => return RelayError::BadRequest(e).into_response(),
    };
    if !peer_allowed(&peer) {
        return RelayError::BadRequest(format!("peer {peer} not in MOBUX_PEERS allowlist"))
            .into_response();
    }
    // The forwarded path is always rooted at `/ws/`, so, unlike the HTTP relay,
    // it can never itself be a `/r/` relay path; a relay-path loop check on it
    // could never fire and is not needed here.
    let forward_path = format!("/ws/{}", rest.trim_start_matches('/'));
    let (upstream_auth, fwd_query) = split_ws_query(uri.query());
    let target = format!("wss://{peer}{forward_path}{fwd_query}");
    let pin = match state.db.peer_pin(&peer) {
        Ok(p) => p,
        Err(e) => return RelayError::Upstream(format!("reading pin: {e}")).into_response(),
    };
    ws.on_upgrade(move |client_socket| async move {
        if let Err(e) = pump_ws(state, peer, target, upstream_auth, pin, client_socket).await {
            eprintln!("[relay] ws error: {e}");
        }
    })
}
/// Separate the `upstream_auth=` parameter from a WebSocket query string.
///
/// Returns `(upstream_auth, forwarded_query)`:
/// * `upstream_auth` is the raw (still percent-encoded) value of the last
///   `upstream_auth=` pair, if any;
/// * `forwarded_query` holds every other non-empty pair, in order, joined
///   with `&` and prefixed with `?`, or is empty if nothing remains.
fn split_ws_query(query: Option<&str>) -> (Option<String>, String) {
    let mut upstream_auth = None;
    let mut forwarded = String::new();
    for pair in query.unwrap_or("").split('&') {
        match pair.strip_prefix("upstream_auth=") {
            Some(value) => upstream_auth = Some(value.to_owned()),
            None if pair.is_empty() => {}
            None => {
                forwarded.push(if forwarded.is_empty() { '?' } else { '&' });
                forwarded.push_str(pair);
            }
        }
    }
    (upstream_auth, forwarded)
}
/// Dial `target` on the peer, then shuttle WebSocket messages in both
/// directions between it and `client_socket` until one side stops.
///
/// * TLS uses a TOFU verifier seeded with `pin`. After the dial, the pin
///   outcome is resolved the same way as for HTTP relaying: a mismatch is
///   reported, and on first contact the fingerprint is persisted.
/// * `upstream_auth`, if given, is percent-decoded and sent to the peer as
///   `Authorization: Basic <value>`.
///
/// Returns `Err` with a description if building the request, the dial, or
/// the pin check fails. Once pumping has started, a receive error, end of
/// stream, or failed send on either side ends the session with `Ok(())`.
async fn pump_ws(
    state: AppState,
    peer: String,
    target: String,
    upstream_auth: Option<String>,
    pin: Option<String>,
    client_socket: axum::extract::ws::WebSocket,
) -> Result<(), String> {
    let had_pin = pin.is_some();
    let (tls, verifier) = pinned_client_config(pin);
    let connector = tokio_tungstenite::Connector::Rustls(Arc::new(tls));
    let mut request = target
        .into_client_request()
        .map_err(|e| format!("building ws request: {e}"))?;
    if let Some(auth) = upstream_auth {
        // The value arrived via the query string, so undo URL encoding first.
        let decoded = percent_decode(&auth);
        request.headers_mut().insert(
            header::AUTHORIZATION,
            HeaderValue::from_str(&format!("Basic {decoded}"))
                .map_err(|e| format!("bad upstream auth: {e}"))?,
        );
    }
    let dial =
        tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
            .await;
    let (peer_ws, _resp) = match dial {
        Ok(ok) => {
            resolve_pin_outcome(&state, &peer, had_pin, &verifier, None)
                .map_err(|e| format!("{e:?}"))?;
            ok
        }
        Err(e) => {
            // A recorded pin mismatch is reported in preference to the
            // generic dial error.
            resolve_pin_outcome(&state, &peer, had_pin, &verifier, Some(e.to_string()))
                .map_err(|e| format!("{e:?}"))?;
            return Err(format!("ws dial: {e}"));
        }
    };
    let (mut peer_tx, mut peer_rx) = peer_ws.split();
    let (mut client_tx, mut client_rx) = client_socket.split();
    // Forward whichever side produces a message first. Messages with no
    // counterpart (raw tungstenite frames) are dropped, and Close frames are
    // forwarded without their code/reason.
    // NOTE(review): neither sink is explicitly closed after the loop; both
    // connections are simply dropped. Consider sending a close frame.
    loop {
        tokio::select! {
            msg = client_rx.next() => match msg {
                Some(Ok(m)) => {
                    if let Some(tm) = axum_to_tung(m) {
                        if peer_tx.send(tm).await.is_err() { break; }
                    }
                }
                _ => break,
            },
            msg = peer_rx.next() => match msg {
                Some(Ok(m)) => {
                    if let Some(am) = tung_to_axum(m) {
                        if client_tx.send(am).await.is_err() { break; }
                    }
                }
                _ => break,
            },
        }
    }
    Ok(())
}
/// Convert a client-side axum WebSocket message into a tungstenite message
/// for the peer.
///
/// Every variant has a counterpart, so the result is always `Some`. Close
/// frames are forwarded without their code or reason.
fn axum_to_tung(m: AxumMessage) -> Option<TungMessage> {
    let converted = match m {
        AxumMessage::Text(text) => TungMessage::Text(text.as_str().into()),
        AxumMessage::Binary(data) => TungMessage::Binary(data),
        AxumMessage::Ping(data) => TungMessage::Ping(data),
        AxumMessage::Pong(data) => TungMessage::Pong(data),
        AxumMessage::Close(_) => TungMessage::Close(None),
    };
    Some(converted)
}
/// Convert a peer-side tungstenite message into an axum message for the
/// client.
///
/// Returns `None` for raw `Frame`s, which have no axum equivalent. Close
/// frames are forwarded without their code or reason.
fn tung_to_axum(m: TungMessage) -> Option<AxumMessage> {
    Some(match m {
        TungMessage::Text(text) => AxumMessage::Text(text.as_str().into()),
        TungMessage::Binary(data) => AxumMessage::Binary(data),
        TungMessage::Ping(data) => AxumMessage::Ping(data),
        TungMessage::Pong(data) => AxumMessage::Pong(data),
        TungMessage::Close(_) => AxumMessage::Close(None),
        TungMessage::Frame(_) => return None,
    })
}
/// Decode a URL query component: `%XX` (hex, either case) becomes that byte
/// and `+` becomes a space.
///
/// A `%` not followed by two hex digits is kept literally. Decoded bytes that
/// are not valid UTF-8 are replaced with U+FFFD.
fn percent_decode(s: &str) -> String {
    fn hex_value(byte: u8) -> Option<u8> {
        (byte as char).to_digit(16).map(|digit| digit as u8)
    }

    let src = s.as_bytes();
    let mut decoded = Vec::with_capacity(src.len());
    let mut pos = 0;
    while let Some(&byte) = src.get(pos) {
        let escaped = if byte == b'%' {
            let high = src.get(pos + 1).copied().and_then(hex_value);
            let low = src.get(pos + 2).copied().and_then(hex_value);
            high.zip(low).map(|(h, l)| h << 4 | l)
        } else {
            None
        };
        match escaped {
            Some(value) => {
                decoded.push(value);
                pos += 3;
            }
            None => {
                decoded.push(if byte == b'+' { b' ' } else { byte });
                pos += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Axum handler that forgets the stored certificate pin for `peer`.
///
/// After deletion, the next connection is treated as a first contact and its
/// certificate is re-pinned. Responds with `{"deleted": <result>}`, where
/// `<result>` is whatever `delete_peer_pin` on the database reports.
pub async fn delete_peer_pin(State(state): State<AppState>, Path(peer): Path<String>) -> Response {
    let peer = match canonical_peer(&peer) {
        Ok(canonical) => canonical,
        Err(reason) => return RelayError::BadRequest(reason).into_response(),
    };
    state.db.delete_peer_pin(&peer).map_or_else(
        |e| RelayError::Upstream(format!("deleting pin: {e}")).into_response(),
        |deleted| {
            (
                StatusCode::OK,
                axum::Json(serde_json::json!({ "deleted": deleted })),
            )
                .into_response()
        },
    )
}
// Unit tests for the relay's pure helpers: fingerprinting, the TOFU
// verifier's bookkeeping, header filtering, the loop guard, WebSocket query
// splitting and percent-decoding.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `HeaderValue` from a known-valid string.
    fn hv(s: &str) -> HeaderValue {
        HeaderValue::from_str(s).unwrap()
    }

    #[test]
    fn fingerprint_is_stable_lowercase_hex_sha256() {
        // Known SHA-256 digest of the ASCII string "hello".
        let fp = cert_fingerprint(b"hello");
        assert_eq!(
            fp,
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
        assert_eq!(fp, cert_fingerprint(b"hello"), "deterministic");
        assert_ne!(fp, cert_fingerprint(b"hella"), "input-sensitive");
    }

    #[test]
    fn verifier_records_observed_on_first_contact() {
        // No pin: any certificate is accepted and its fingerprint recorded.
        let v = TofuVerifier::new(None);
        let der = CertificateDer::from(b"fake-cert-der".to_vec());
        let name = ServerName::try_from("peer.example").unwrap();
        let ok = v.verify_server_cert(&der, &[], &name, &[], UnixTime::now());
        assert!(ok.is_ok(), "first contact accepts");
        assert_eq!(
            v.observed.lock().unwrap().clone(),
            Some(cert_fingerprint(b"fake-cert-der"))
        );
        assert!(v.mismatch.lock().unwrap().is_none());
    }

    #[test]
    fn verifier_accepts_matching_pin() {
        let der = CertificateDer::from(b"cert".to_vec());
        let v = TofuVerifier::new(Some(cert_fingerprint(b"cert")));
        let name = ServerName::try_from("peer.example").unwrap();
        assert!(v
            .verify_server_cert(&der, &[], &name, &[], UnixTime::now())
            .is_ok());
        assert!(v.mismatch.lock().unwrap().is_none());
    }

    #[test]
    fn verifier_flags_pin_mismatch() {
        // A differing pin must abort the handshake and record both values.
        let der = CertificateDer::from(b"new-cert".to_vec());
        let v = TofuVerifier::new(Some("deadbeef".to_string()));
        let name = ServerName::try_from("peer.example").unwrap();
        let res = v.verify_server_cert(&der, &[], &name, &[], UnixTime::now());
        assert!(res.is_err(), "mismatch aborts handshake");
        let (expected, actual) = v.mismatch.lock().unwrap().clone().unwrap();
        assert_eq!(expected, "deadbeef");
        assert_eq!(actual, cert_fingerprint(b"new-cert"));
    }

    #[test]
    fn forward_headers_swap_upstream_auth_into_authorization() {
        let mut h = HeaderMap::new();
        h.insert(header::AUTHORIZATION, hv("Basic relay-creds"));
        h.insert(
            HeaderName::from_static(UPSTREAM_AUTH_HEADER),
            hv("Basic peer-creds"),
        );
        h.insert(header::CONTENT_TYPE, hv("application/json"));
        let out = build_forward_headers(&h);
        assert_eq!(
            out.get(header::AUTHORIZATION).unwrap(),
            "Basic peer-creds",
            "upstream auth becomes Authorization"
        );
        assert!(
            out.get(HeaderName::from_static(UPSTREAM_AUTH_HEADER))
                .is_none(),
            "upstream header is stripped, never forwarded"
        );
        assert_eq!(out.get(header::CONTENT_TYPE).unwrap(), "application/json");
    }

    #[test]
    fn forward_headers_drop_relay_auth_when_no_upstream() {
        let mut h = HeaderMap::new();
        h.insert(header::AUTHORIZATION, hv("Basic relay-creds"));
        h.insert(header::COOKIE, hv("mobux_session=secret"));
        let out = build_forward_headers(&h);
        assert!(
            out.get(header::AUTHORIZATION).is_none(),
            "relay's own creds never reach the peer"
        );
        assert!(
            out.get(header::COOKIE).is_none(),
            "relay cookie not forwarded"
        );
    }

    #[test]
    fn response_strips_www_authenticate_so_browser_never_prompts() {
        assert!(
            is_stripped_response_header(&header::WWW_AUTHENTICATE),
            "WWW-Authenticate must never be forwarded to the browser"
        );
        assert!(!is_stripped_response_header(&header::CONTENT_TYPE));
        let mut peer_headers = HeaderMap::new();
        peer_headers.insert(
            header::WWW_AUTHENTICATE,
            hv("Basic realm=\"mobux\", charset=\"UTF-8\""),
        );
        peer_headers.insert(header::CONTENT_TYPE, hv("application/json"));
        let status = StatusCode::UNAUTHORIZED;
        // Same header-copy loop as relay_http_inner uses for the peer's reply.
        let mut out = Response::builder().status(status);
        for (name, value) in peer_headers.iter() {
            if is_stripped_response_header(name) {
                continue;
            }
            out = out.header(name, value);
        }
        let resp = out.body(Body::empty()).unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED, "401 preserved");
        assert!(
            resp.headers().get(header::WWW_AUTHENTICATE).is_none(),
            "relay must not forward the peer's WWW-Authenticate challenge"
        );
        assert_eq!(
            resp.headers().get(header::CONTENT_TYPE).unwrap(),
            "application/json",
            "non-stripped headers still pass through"
        );
    }

    #[test]
    fn canonical_peer_adds_default_port_and_validates() {
        // Clear PORT so the 8080 fallback applies.
        // NOTE(review): this mutates process-wide env and could race with any
        // other test that reads PORT concurrently.
        std::env::remove_var("PORT");
        assert_eq!(canonical_peer("host-b").unwrap(), "host-b:8080");
        assert_eq!(canonical_peer("host-b:5151").unwrap(), "host-b:5151");
        assert!(canonical_peer("").is_err());
        assert!(canonical_peer("a/b").is_err());
        assert!(canonical_peer("host:notaport").is_err());
    }

    #[test]
    fn loop_guard_rejects_relay_paths_and_hop_cap() {
        let h = HeaderMap::new();
        assert!(check_loop_guard(&h, "/r/other/api/x").is_err());
        // No hop header counts as 0, so the next hop is 1.
        assert_eq!(check_loop_guard(&h, "/api/sessions").unwrap(), 1);
        let mut h2 = HeaderMap::new();
        h2.insert(HeaderName::from_static(HOP_HEADER), hv("2"));
        assert_eq!(check_loop_guard(&h2, "/api/sessions").unwrap(), 3);
        let mut h3 = HeaderMap::new();
        h3.insert(
            HeaderName::from_static(HOP_HEADER),
            hv(&MAX_HOPS.to_string()),
        );
        assert!(check_loop_guard(&h3, "/api/sessions").is_err());
    }

    #[test]
    fn split_ws_query_extracts_and_strips_upstream_auth() {
        let (auth, fwd) = split_ws_query(Some("token=abc&upstream_auth=QmFzaWM&x=1"));
        assert_eq!(auth.as_deref(), Some("QmFzaWM"));
        assert_eq!(fwd, "?token=abc&x=1");
        let (auth2, fwd2) = split_ws_query(Some("upstream_auth=only"));
        assert_eq!(auth2.as_deref(), Some("only"));
        assert_eq!(fwd2, "");
        let (auth3, fwd3) = split_ws_query(None);
        assert!(auth3.is_none());
        assert_eq!(fwd3, "");
    }

    #[test]
    fn forwarded_url_parses_for_normal_paths() {
        // Sanity check of the URL shape relay_http_inner builds.
        let url = reqwest::Url::parse("https://devbox:5151/api/sessions?x=1");
        assert!(url.is_ok());
        assert!(reqwest::Url::parse("https://[bad:5151/api/sessions").is_err());
    }

    #[test]
    fn percent_decode_handles_basic_creds() {
        assert_eq!(percent_decode("dXNlcjpwYXNz"), "dXNlcjpwYXNz");
        assert_eq!(percent_decode("a%2Bb"), "a+b");
        assert_eq!(percent_decode("a%2Fb"), "a/b");
    }
}