use crate::serve::{ServerError, ServerResult};
use anyhow::Context;
use axum::{
body::Body,
extract::{
ws::{Message as MsgAxm, WebSocket, WebSocketUpgrade},
Request, State,
},
http::{Response, Uri},
routing::{any, get, Router},
RequestExt,
};
use bytes::BytesMut;
use futures_util::{sink::SinkExt, stream::StreamExt, TryStreamExt};
use http::{header::HOST, HeaderMap};
use std::sync::Arc;
use tokio_tungstenite::{
connect_async,
tungstenite::{protocol::CloseFrame, Message as MsgTng},
};
use tower_http::trace::TraceLayer;
/// Outbound header that receives the `Host` value the client originally sent, because the
/// proxy rewrites `Host` to point at the backend.
const X_FORWARDED_HOST: &str = "x-forwarded-host";
/// Outbound header that receives the protocol of the inbound connection.
const X_FORWARDED_PROTO: &str = "x-forwarded-proto";
/// Proxies inbound HTTP requests to an HTTP backend.
pub(crate) struct ProxyHandlerHttp {
/// Protocol of the inbound connection; passed to the backend as `X-Forwarded-Proto`
/// (only when the inbound request carries a `Host` header).
proto: String,
/// Client used to send the forwarded requests to the backend.
client: reqwest::Client,
/// Backend base URI; its path is also the default mount point (see `path`).
backend: Uri,
/// Headers applied to every outbound request after the inbound headers are copied:
/// a non-empty value replaces the header, an empty value removes it.
request_headers: HeaderMap,
/// Optional mount point used instead of the backend path.
rewrite: Option<String>,
}
/// Join the backend base URI with the path and query of an inbound request.
///
/// Scheme and authority are taken from `backend`. The resulting path is the backend path
/// followed by the request path (leading slashes of both are stripped before joining); a `/`
/// is inserted between them only when both are non-empty and the backend path does not
/// already end with `/`. The request's query string, if present, is appended; a query on
/// `backend` itself is not carried over.
///
/// # Errors
///
/// Returns an error if the assembled pieces do not form a valid URI.
fn make_outbound_uri(backend: &Uri, request: &Uri) -> anyhow::Result<Uri> {
    let base_path = backend.path().trim_start_matches('/');
    let request_path = request.path().trim_start_matches('/');
    let needs_separator =
        !(base_path.is_empty() || request_path.is_empty() || base_path.ends_with('/'));

    let mut path_and_query = String::from("/");
    path_and_query.push_str(base_path);
    if needs_separator {
        path_and_query.push('/');
    }
    path_and_query.push_str(request_path);
    if let Some(query) = request.query() {
        path_and_query.push('?');
        path_and_query.push_str(query);
    }

    let scheme = backend.scheme_str().unwrap_or_default();
    let authority = backend.authority().map_or("", |auth| auth.as_str());

    Uri::builder()
        .scheme(scheme)
        .authority(authority)
        .path_and_query(path_and_query)
        .build()
        .context("error building proxy request to backend")
}
fn make_outbound_request(
inbound_proto: &str,
outbound_uri: &Uri,
method: http::Method,
original_headers: HeaderMap,
override_headers: HeaderMap,
) -> anyhow::Result<http::request::Builder> {
let mut request = http::Request::builder()
.uri(outbound_uri.to_string())
.method(method);
let Some(outbound_host) = outbound_uri.authority().map(|authority| authority.host()) else {
anyhow::bail!("No host found in outbound URI");
};
for key in original_headers.keys() {
let values = original_headers
.get_all(key)
.iter()
.cloned()
.collect::<Vec<_>>();
for value in values {
if key == HOST {
request = request.header(HOST, outbound_host);
request = request.header(X_FORWARDED_HOST, value);
request = request.header(X_FORWARDED_PROTO, inbound_proto);
} else {
request = request.header(key, value);
}
}
}
if let Some(headers) = request.headers_mut() {
for (key, value) in override_headers {
let Some(key) = key else { continue };
if value.is_empty() {
headers.remove(key);
} else {
headers.insert(key, value);
}
}
}
Ok(request)
}
impl ProxyHandlerHttp {
    /// Create a shared HTTP proxy handler.
    ///
    /// * `proto` - protocol of the inbound connection, passed on as `X-Forwarded-Proto`.
    /// * `client` - client used to talk to the backend.
    /// * `backend` - backend base URI; its path is the default mount point.
    /// * `request_headers` - headers applied to every outbound request (an empty value removes
    ///   the header).
    /// * `rewrite` - optional mount point used instead of the backend path.
    pub fn new(
        proto: String,
        client: reqwest::Client,
        backend: Uri,
        request_headers: HeaderMap,
        rewrite: Option<String>,
    ) -> Arc<Self> {
        let handler = Self {
            rewrite,
            request_headers,
            backend,
            client,
            proto,
        };
        Arc::new(handler)
    }

    /// Mount the proxy on `router` under [`Self::path`], accepting every HTTP method.
    pub fn register(self: Arc<Self>, router: Router) -> Router {
        let service = any(Self::proxy_http_request)
            .layer(TraceLayer::new_for_http())
            .with_state(Arc::clone(&self));
        router.nest_service(self.path(), service)
    }

    /// Mount point of this proxy: the rewrite path if configured, otherwise the backend path.
    pub fn path(&self) -> &str {
        match self.rewrite.as_deref() {
            Some(rewritten) => rewritten,
            None => self.backend.path(),
        }
    }

    /// Forward one inbound request to the backend and stream the backend's response back.
    ///
    /// The inbound body is fully buffered before it is sent; the response body is streamed.
    #[tracing::instrument(level = "debug", skip(state, req))]
    async fn proxy_http_request(
        State(state): State<Arc<Self>>,
        req: Request,
    ) -> ServerResult<Response<Body>> {
        let target = make_outbound_uri(&state.backend, req.uri())?;
        let request_builder = make_outbound_request(
            &state.proto,
            &target,
            req.method().clone(),
            req.headers().clone(),
            state.request_headers.clone(),
        )?;

        // Buffer the whole inbound body before forwarding it.
        let payload = req
            .into_body()
            .into_data_stream()
            .try_collect::<BytesMut>()
            .await
            .map_err(|err| ServerError(err.into()))?
            .freeze();

        let http_request = request_builder
            .body(reqwest::Body::from(payload))
            .context("error building outbound request to proxy backend")?;
        let reqwest_request: reqwest::Request = http_request
            .try_into()
            .context("error translating outbound request")?;

        let upstream = state
            .client
            .execute(reqwest_request)
            .await
            .context("error proxying request to proxy backend")?;

        // Mirror status and every header (including repeated ones) from the backend response.
        let response_builder = upstream.headers().iter().fold(
            Response::builder().status(upstream.status()),
            |builder, (name, value)| builder.header(name, value),
        );

        let response = response_builder
            .body(Body::from_stream(upstream.bytes_stream()))
            .context("error building proxy response")?;
        Ok(response)
    }
}
pub struct ProxyHandlerWebSocket {
proto: String,
backend: Uri,
rewrite: Option<String>,
request_headers: HeaderMap,
}
impl ProxyHandlerWebSocket {
pub fn new(
proto: String,
backend: Uri,
headers: HeaderMap,
rewrite: Option<String>,
) -> Arc<Self> {
Arc::new(Self {
proto,
backend,
rewrite,
request_headers: headers,
})
}
pub fn register(self: Arc<Self>, router: Router) -> Router {
let proxy = self.clone();
let override_headers = self.request_headers.clone();
let proto = self.proto.clone();
router.nest_service(
self.path(),
get(|req: Request<Body>| async move {
let req_headers = req.headers().to_owned();
let uri = req.uri().clone();
let ws = req.extract::<WebSocketUpgrade, _>().await;
ws.map(|e| {
e.on_upgrade(|socket| async move {
proxy
.clone()
.proxy_ws_request(&proto, socket, uri, req_headers, override_headers)
.await
})
})
}),
)
}
pub fn path(&self) -> &str {
self.rewrite
.as_deref()
.unwrap_or_else(|| self.backend.path())
}
#[tracing::instrument(level = "debug", skip(self, ws))]
async fn proxy_ws_request(
self: Arc<Self>,
inbound_proto: &str,
ws: WebSocket,
request_uri: Uri,
req_headers: HeaderMap,
override_headers: HeaderMap,
) {
tracing::debug!("new websocket connection");
let outbound_uri = match make_outbound_uri(&self.backend, &request_uri) {
Ok(outbound_uri) => outbound_uri,
Err(err) => {
tracing::error!(error = ?err, "failed to build proxy uri from {:?}", &request_uri);
return;
}
};
let outbound_request = match make_outbound_request(
inbound_proto,
&outbound_uri,
http::Method::GET,
req_headers,
override_headers,
) {
Ok(outbound_uri) => outbound_uri,
Err(err) => {
tracing::error!(error = ?err, "failed to create outbound request");
return;
}
};
let outbound_request = match outbound_request
.body(())
.context("Failed to build outbound request")
{
Ok(outbound_uri) => outbound_uri,
Err(err) => {
tracing::error!(error = ?err, "failed to build outbound request");
return;
}
};
let (backend, _res) = match connect_async(outbound_request).await {
Ok(backend) => backend,
Err(err) => {
tracing::error!(error = ?err, "error establishing WebSocket connection to backend {:?} for proxy", &outbound_uri);
return;
}
};
let (mut backend_sink, mut backend_stream) = backend.split();
let (mut frontend_sink, mut frontend_stream) = ws.split();
let stream_to_backend = async move {
while let Some(Ok(msg_axm)) = frontend_stream.next().await {
let msg_tng = match msg_axm {
MsgAxm::Text(msg) => MsgTng::Text(msg.as_str().into()),
MsgAxm::Binary(msg) => MsgTng::Binary(msg),
MsgAxm::Ping(msg) => MsgTng::Ping(msg),
MsgAxm::Pong(msg) => MsgTng::Pong(msg),
MsgAxm::Close(Some(close_frame)) => MsgTng::Close(Some(CloseFrame {
code: close_frame.code.into(),
reason: close_frame.reason.as_str().into(),
})),
MsgAxm::Close(None) => MsgTng::Close(None),
};
if let Err(err) = backend_sink.send(msg_tng).await {
tracing::error!(error = ?err, "error forwarding frontend WebSocket message to backend");
return;
}
}
};
let stream_to_frontend = async move {
while let Some(Ok(msg)) = backend_stream.next().await {
let msg_axm = match msg {
MsgTng::Binary(val) => MsgAxm::Binary(val),
MsgTng::Text(val) => MsgAxm::Text(val.as_str().into()),
MsgTng::Ping(val) => MsgAxm::Ping(val),
MsgTng::Pong(val) => MsgAxm::Pong(val),
MsgTng::Close(Some(frame)) => {
MsgAxm::Close(Some(axum::extract::ws::CloseFrame {
code: frame.code.into(),
reason: frame.reason.as_str().into(),
}))
}
MsgTng::Close(None) => MsgAxm::Close(None),
MsgTng::Frame(_) => continue,
};
if let Err(err) = frontend_sink.send(msg_axm).await {
tracing::error!(error = ?err, "error forwarding backend WebSocket message to frontend");
return;
}
}
};
tokio::select! {
_ = stream_to_backend => (),
_ = stream_to_frontend => ()
};
tracing::debug!("websocket connection closed");
}
}
#[cfg(test)]
mod tests {
    use crate::proxy::make_outbound_uri;
    use axum::http::{HeaderValue, Uri};
    use http::{
        header::{
            ACCEPT, ACCEPT_ENCODING, CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, COOKIE, DATE,
            EXPECT, HOST, USER_AGENT,
        },
        HeaderMap,
    };

    use super::{make_outbound_request, X_FORWARDED_HOST, X_FORWARDED_PROTO};

    #[test]
    fn make_outbound_uri_two_base_paths() {
        let backend = Uri::from_static("https://backend/");
        let request = Uri::from_static("http://localhost/");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/")
        )
    }

    #[test]
    fn make_outbound_uri_two_empty_paths() {
        let backend = Uri::from_static("https://backend");
        let request = Uri::from_static("http://localhost");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/")
        )
    }

    #[test]
    fn make_outbound_uri_two_with_query() {
        let backend = Uri::from_static("https://backend/");
        let request = Uri::from_static("http://localhost/auth?user=user&pwd=secret");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/auth?user=user&pwd=secret")
        )
    }

    #[test]
    fn make_outbound_uri_two_slash_at_end() {
        let backend = Uri::from_static("https://backend/");
        let request = Uri::from_static("http://localhost/auth/");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/auth/")
        )
    }

    #[test]
    fn make_outbound_uri_request_with_path() {
        let backend = Uri::from_static("https://backend/");
        let request = Uri::from_static("http://localhost/auth");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/auth")
        )
    }

    #[test]
    fn make_outbound_uri_request_with_sub_paths() {
        let backend = Uri::from_static("https://backend/sub");
        let request = Uri::from_static("http://localhost/auth");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/sub/auth")
        )
    }

    #[test]
    fn make_outbound_uri_sub_path_with_trailing_slash_and_query() {
        // The backend's trailing slash must not be doubled when joining.
        let backend = Uri::from_static("https://backend/sub/");
        let request = Uri::from_static("http://localhost/auth?user=user");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/sub/auth?user=user")
        )
    }

    #[test]
    fn make_outbound_uri_sub_path_and_root_request() {
        // A root request must not add a trailing slash to the backend path.
        let backend = Uri::from_static("https://backend/sub");
        let request = Uri::from_static("http://localhost/");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("https://backend/sub")
        )
    }

    #[test]
    fn make_outbound_uri_keeps_backend_port() {
        let backend = Uri::from_static("http://backend:8080/api");
        let request = Uri::from_static("http://localhost/v1");
        assert_eq!(
            make_outbound_uri(&backend, &request).expect("Unexpected error"),
            Uri::from_static("http://backend:8080/api/v1")
        )
    }

    #[test]
    fn make_outbound_request_from_uri_and_headers() {
        let backend_uri = Uri::from_static("https://backend/sub");
        let inbound_uri = Uri::from_static("http://localhost/auth");
        let inbound_headers = vec![
            (
                HOST,
                HeaderValue::from_str("localhost").expect("Failed to create Header Value"),
            ),
            (
                USER_AGENT,
                HeaderValue::from_str("curl/7.64.1").expect("Failed to create Header Value"),
            ),
            (
                ACCEPT,
                HeaderValue::from_str("*/*").expect("Failed to create Header Value"),
            ),
            (
                ACCEPT_ENCODING,
                HeaderValue::from_str("deflate, gzip").expect("Failed to create Header Value"),
            ),
            (
                CONNECTION,
                HeaderValue::from_str("keep-alive").expect("Failed to create Header Value"),
            ),
            (
                CONTENT_LENGTH,
                HeaderValue::from_str("0").expect("Failed to create Header Value"),
            ),
            (
                CONTENT_TYPE,
                HeaderValue::from_str("application/json").expect("Failed to create Header Value"),
            ),
            (
                DATE,
                HeaderValue::from_str("Tue, 01 Dec 2020 00:00:00 GMT")
                    .expect("Failed to create Header Value"),
            ),
            (
                EXPECT,
                HeaderValue::from_str("").expect("Failed to create Header Value"),
            ),
            (
                COOKIE,
                HeaderValue::from_str("cookie1=value1; cookie2=value2")
                    .expect("Failed to create Header Value"),
            ),
            (
                COOKIE,
                HeaderValue::from_str("cookie3=value1; cookie4=value2")
                    .expect("Failed to create Header Value"),
            ),
        ];
        let mut want_headers = HeaderMap::new();
        for (key, val) in inbound_headers {
            want_headers.append(key, val);
        }

        let have_outbound_uri = make_outbound_uri(&backend_uri, &inbound_uri)
            .expect("Failed to create Uri instance from inbound");
        let have_outbound_req = make_outbound_request(
            "http",
            &have_outbound_uri,
            http::Method::GET,
            want_headers.clone(),
            Default::default(),
        )
        .expect("Failed to create Request instance from inbound")
        .body(())
        .expect("Failed to create Request from builder");

        assert_eq!(have_outbound_req.uri(), &have_outbound_uri);
        assert_eq!(have_outbound_req.method(), &http::Method::GET);

        // Host is rewritten to the backend; the original value and protocol are forwarded.
        assert_eq!(
            have_outbound_req
                .headers()
                .get(HOST)
                .expect("Expected HOST header"),
            &HeaderValue::from_static("backend")
        );
        assert_eq!(
            have_outbound_req
                .headers()
                .get(X_FORWARDED_HOST)
                .expect("Expected X-Forwarded-Host header"),
            &HeaderValue::from_static("localhost")
        );
        assert_eq!(
            have_outbound_req
                .headers()
                .get(X_FORWARDED_PROTO)
                .expect("Expected X-Forwarded-Proto header"),
            &HeaderValue::from_static("http")
        );

        // Every other inbound header, including repeated values, is copied verbatim.
        for key in want_headers.keys() {
            if key == HOST {
                continue;
            }
            let want = want_headers.get_all(key).iter().collect::<Vec<_>>();
            let have = have_outbound_req
                .headers()
                .get_all(key)
                .iter()
                .collect::<Vec<_>>();
            assert_eq!(have, want, "mismatched values for header {key}");
        }
    }
}