use std::{
collections::HashMap,
net::SocketAddr,
str::FromStr,
sync::{Arc, OnceLock},
time::Duration,
};
use http::StatusCode;
use hyper_util::rt::TokioExecutor;
use iroh::{
Endpoint, EndpointId,
address_lookup::MemoryLookup,
endpoint::{BindError, presets},
protocol::Router,
};
use n0_error::{AnyError, Result, StdResultExt, stack_error};
use n0_future::task::AbortOnDropHandle;
use n0_tracing_test::traced_test;
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_util::time::FutureExt;
use tracing::debug;
use crate::{
ALPN, Authority, HttpProxyRequest, HttpProxyRequestKind, HttpRequest, HttpResponse,
IROH_DESTINATION_HEADER,
downstream::{
Deny, DownstreamMetrics, DownstreamProxy, EndpointAuthority, HttpProxyOpts, ProxyMode,
RequestHandler, SrcAddr,
opts::{RequestHandlerChain, StaticForwardProxy, StaticReverseProxy},
},
upstream::{AcceptAll, AuthError, AuthHandler, UpstreamMetrics, UpstreamProxy},
util::Prebuffered,
};
/// Binds a new [`Endpoint`] with the `Minimal` preset and a shared in-memory address lookup.
///
/// The [`MemoryLookup`] is stored in a function-local `static`, so every endpoint created
/// through this helper is configured with a clone of the same lookup. After binding, the
/// endpoint's own address is added to that lookup.
async fn bind_endpoint() -> Result<Endpoint, BindError> {
    static SHARED_LOOKUP: OnceLock<MemoryLookup> = OnceLock::new();
    let lookup = SHARED_LOOKUP.get_or_init(MemoryLookup::default);

    let builder = Endpoint::builder(presets::Minimal).address_lookup(lookup.clone());
    let bound = builder.bind().await?;
    lookup.add_endpoint_info(bound.addr());
    Ok(bound)
}
/// Spawns an upstream proxy with the [`AcceptAll`] auth handler.
///
/// Returns the router and its endpoint id; the metrics handle is discarded.
async fn spawn_upstream_proxy() -> Result<(Router, EndpointId)> {
    spawn_upstream_proxy_with_auth(AcceptAll)
        .await
        .map(|(router, endpoint_id, _metrics)| (router, endpoint_id))
}
/// Spawns an upstream proxy with the [`AcceptAll`] auth handler and also returns its metrics.
async fn spawn_upstream_proxy_with_metrics() -> Result<(Router, EndpointId, Arc<UpstreamMetrics>)> {
    let spawned = spawn_upstream_proxy_with_auth(AcceptAll).await?;
    Ok(spawned)
}
/// Spawns an upstream proxy that uses `auth` as its authorization handler.
///
/// Binds a fresh endpoint, registers the proxy for [`ALPN`] on a new [`Router`], and returns
/// the router, the endpoint id, and the proxy's metrics handle.
async fn spawn_upstream_proxy_with_auth(
    auth: impl AuthHandler + 'static,
) -> Result<(Router, EndpointId, Arc<UpstreamMetrics>)> {
    let endpoint = bind_endpoint().await?;
    let proxy = UpstreamProxy::new(auth)?;
    // Grab the metrics before the proxy is moved into the router.
    let metrics = proxy.metrics();
    let router = Router::builder(endpoint).accept(ALPN, proxy).spawn();
    let id = router.endpoint().id();
    debug!(endpoint_id = %id.fmt_short(), "spawned upstream proxy");
    Ok((router, id, metrics))
}
async fn spawn_downstream_proxy(
mode: ProxyMode,
) -> Result<(SocketAddr, EndpointId, AbortOnDropHandle<Result>)> {
let endpoint = bind_endpoint().await?;
let endpoint_id = endpoint.id();
let proxy = DownstreamProxy::new(endpoint, Default::default());
let listener = TcpListener::bind("localhost:0").await?;
let tcp_addr = listener.local_addr()?;
debug!(endpoint_id=%endpoint_id.fmt_short(), %tcp_addr, "spawned downstream proxy");
let task = tokio::spawn(async move { proxy.forward_tcp_listener(listener, mode).await });
Ok((tcp_addr, endpoint_id, AbortOnDropHandle::new(task)))
}
/// Spawns a downstream proxy listening on an ephemeral local TCP port in the given `mode`.
///
/// Returns the TCP address the proxy listens on, the proxy's endpoint id, its metrics handle,
/// and a handle that aborts the listener task when dropped.
async fn spawn_downstream_proxy_with_metrics(
    mode: ProxyMode,
) -> Result<(
    SocketAddr,
    EndpointId,
    Arc<DownstreamMetrics>,
    AbortOnDropHandle<Result>,
)> {
    let endpoint = bind_endpoint().await?;
    let id = endpoint.id();
    let proxy = DownstreamProxy::new(endpoint, Default::default());
    // Clone the metrics before the proxy is moved into the spawned task.
    let metrics = proxy.metrics().clone();

    let listener = TcpListener::bind("localhost:0").await?;
    let local = listener.local_addr()?;
    debug!(endpoint_id = %id.fmt_short(), tcp_addr = %local, "spawned downstream proxy");

    let serve = async move { proxy.forward_tcp_listener(listener, mode).await };
    let handle = AbortOnDropHandle::new(tokio::spawn(serve));
    Ok((local, id, metrics, handle))
}
/// Starts `origin_server::run` on an ephemeral local TCP port.
///
/// `label` is passed through to the origin server. Returns the listening address and a handle
/// that aborts the server task when dropped.
async fn spawn_origin_server(label: &'static str) -> Result<(SocketAddr, AbortOnDropHandle<()>)> {
    let listener = TcpListener::bind("localhost:0").await?;
    let local = listener.local_addr()?;
    debug!(label = %label, tcp_addr = %local, "spawned origin server");

    let serve = async move { origin_server::run(listener, label).await };
    let handle = AbortOnDropHandle::new(tokio::spawn(serve));
    Ok((local, handle))
}
/// Starts `origin_server::run_echo_body` on an ephemeral local TCP port.
///
/// `label` is passed through to the origin server. Returns the listening address and a handle
/// that aborts the server task when dropped.
async fn spawn_origin_server_echo_body(
    label: &'static str,
) -> Result<(SocketAddr, AbortOnDropHandle<()>)> {
    let listener = TcpListener::bind("localhost:0").await?;
    let local = listener.local_addr()?;
    debug!(label = %label, tcp_addr = %local, "spawned origin server");

    let serve = async move { origin_server::run_echo_body(listener, label).await };
    let handle = AbortOnDropHandle::new(tokio::spawn(serve));
    Ok((local, handle))
}
/// Starts a raw TCP echo server on an ephemeral local port.
///
/// Each accepted connection gets its own task that copies everything it reads back to the
/// peer. The accept loop stops at the first accept error. Returns the listening address and a
/// handle that aborts the accept loop when dropped.
async fn spawn_echo_server() -> Result<(SocketAddr, AbortOnDropHandle<()>)> {
    let listener = TcpListener::bind("localhost:0").await?;
    let local = listener.local_addr()?;

    let accept_loop = async move {
        while let Ok((mut conn, _peer)) = listener.accept().await {
            tokio::spawn(async move {
                let (mut rd, mut wr) = conn.split();
                // Copy errors are ignored on purpose: the echo is best-effort.
                let _ = tokio::io::copy(&mut rd, &mut wr).await;
            });
        }
    };
    let handle = AbortOnDropHandle::new(tokio::spawn(accept_loop));
    Ok((local, handle))
}
/// Starts `origin_server::run_with_websocket` on an ephemeral local TCP port.
///
/// `label` is passed through to the origin server. Returns the listening address and a handle
/// that aborts the server task when dropped.
async fn spawn_websocket_origin(
    label: &'static str,
) -> Result<(SocketAddr, AbortOnDropHandle<()>)> {
    let listener = TcpListener::bind("localhost:0").await?;
    let local = listener.local_addr()?;
    debug!(label = %label, tcp_addr = %local, "spawned websocket origin server");

    let serve = async move { origin_server::run_with_websocket(listener, label).await };
    let handle = AbortOnDropHandle::new(tokio::spawn(serve));
    Ok((local, handle))
}
// Error type returned by `create_http_connect_tunnel`.
//
// NOTE(review): `from_sources` presumably generates the `From` conversions that let `?` turn
// the wrapped source errors into this enum — confirm against the n0_error docs.
#[stack_error(derive, from_sources)]
enum ConnectError {
    // Wraps an I/O error (e.g. from connecting to or writing to the proxy).
    Io(#[error(source)] std::io::Error),
    // Wraps an error produced while reading the proxy's response.
    ReadResponse(#[error(source)] AnyError),
    // The proxy answered the CONNECT request with a status other than 200 OK.
    Status(StatusCode),
}
/// Opens a TCP connection to `proxy_addr` and issues an HTTP/1.1 `CONNECT` for `origin_addr`.
///
/// When `destination_header` is set, an [`IROH_DESTINATION_HEADER`] header carrying that
/// endpoint id is added to the request.
///
/// # Errors
///
/// Returns [`ConnectError::Status`] if the proxy replies with anything other than `200 OK`;
/// I/O failures and response-reading failures are converted via `?`.
///
/// On success, returns the joined read/write halves of the connection. The read half is the
/// same prebuffered reader the response was read from.
async fn create_http_connect_tunnel(
    proxy_addr: SocketAddr,
    origin_addr: impl std::fmt::Display,
    destination_header: Option<EndpointId>,
) -> Result<tokio::io::Join<impl AsyncRead + Unpin, impl AsyncWrite + Unpin>, ConnectError> {
    let conn = TcpStream::connect(proxy_addr).await?;
    let (read_half, mut write_half) = conn.into_split();

    // Assemble the request head: request line, Host, optional destination header, blank line.
    let mut head = format!("CONNECT {origin_addr} HTTP/1.1\r\nHost: {origin_addr}\r\n");
    if let Some(destination) = destination_header {
        head.push_str(&format!("{IROH_DESTINATION_HEADER}: {destination}\r\n"));
    }
    head.push_str("\r\n");
    write_half.write_all(head.as_bytes()).await?;

    let mut buffered = Prebuffered::new(read_half, 8192);
    let response = HttpResponse::read(&mut buffered).await?;
    if response.status == StatusCode::OK {
        Ok(tokio::io::join(buffered, write_half))
    } else {
        Err(ConnectError::Status(response.status))
    }
}
/// Reads one HTTP response from `stream` and returns its status code and body.
///
/// The body length is taken from the `Content-Length` header; a missing or unparseable header
/// is treated as a zero-length body. Reading the head and reading the body are each bounded by
/// a three second timeout.
async fn read_http_response(
    stream: &mut (impl AsyncRead + Unpin + Send),
) -> Result<(u16, Vec<u8>)> {
    const IO_TIMEOUT: Duration = Duration::from_secs(3);

    let mut reader = Prebuffered::new(stream, 8192);
    let head = HttpResponse::read(&mut reader)
        .timeout(IO_TIMEOUT)
        .await
        .anyerr()??;
    debug!("RESPONSE {:?}", head);

    let body_len = match head.headers.get(http::header::CONTENT_LENGTH) {
        Some(value) => value
            .to_str()
            .ok()
            .and_then(|text| text.parse::<usize>().ok())
            .unwrap_or(0),
        None => 0,
    };

    let mut body = vec![0u8; body_len];
    if !body.is_empty() {
        reader
            .read_exact(&mut body)
            .timeout(IO_TIMEOUT)
            .await
            .anyerr()??;
    }
    Ok((head.status.as_u16(), body))
}
/// Request handler that picks the destination endpoint from the [`IROH_DESTINATION_HEADER`]
/// request header.
struct HeaderResolver;

impl RequestHandler for HeaderResolver {
    /// Resolves the destination from the request header.
    ///
    /// Denies the request as a bad request when the header is missing, is not a valid header
    /// string, or does not parse as an [`EndpointId`].
    async fn handle_request(
        &self,
        src_addr: SrcAddr,
        req: &mut HttpRequest,
    ) -> Result<EndpointId, Deny> {
        let Some(raw) = req.headers.get(IROH_DESTINATION_HEADER) else {
            return Err(Deny::bad_request("missing iroh-destination header"));
        };
        let text = raw
            .to_str()
            .std_context("invalid iroh-destination header")
            .map_err(Deny::bad_request)?;
        let parsed = EndpointId::from_str(text).map_err(Deny::bad_request);
        // The forwarded-for update runs whether or not the endpoint id parsed.
        req.set_forwarded_for_if_tcp(src_addr);
        parsed
    }
}
/// Request handler that routes by the first dot-separated label of the request host.
struct SubdomainRouter {
    /// Maps a host label (the text before the first `.`) to its destination.
    routes: HashMap<String, EndpointAuthority>,
}

impl RequestHandler for SubdomainRouter {
    /// Looks up the host's first label in `routes` and rewrites the request to the matching
    /// authority.
    ///
    /// Denies with a bad request when no host is present, with `404 Not Found` when the label
    /// has no route, and with `500 Internal Server Error` when rewriting the authority fails.
    async fn handle_request(
        &self,
        _src_addr: SrcAddr,
        req: &mut HttpRequest,
    ) -> Result<EndpointId, Deny> {
        let Some(host) = req.host() else {
            return Err(Deny::bad_request("missing host header"));
        };
        let label = host
            .split('.')
            .next()
            .ok_or_else(|| Deny::bad_request("invalid host header"))?;
        let Some(route) = self.routes.get(label).cloned() else {
            return Err(Deny::new(StatusCode::NOT_FOUND, "unknown subdomain"));
        };
        req.set_absolute_http_authority(route.authority)
            .map_err(|err| Deny::new(StatusCode::INTERNAL_SERVER_ERROR, err))?;
        Ok(route.endpoint_id)
    }
}
/// Auth handler that only admits the remote endpoint ids in the wrapped list.
struct AllowEndpoints(Vec<EndpointId>);

impl AuthHandler for AllowEndpoints {
    /// Returns `Ok(())` when `remote_id` is in the list, [`AuthError::Forbidden`] otherwise.
    async fn authorize(
        &self,
        remote_id: EndpointId,
        _req: &HttpProxyRequest,
    ) -> Result<(), AuthError> {
        if !self.0.contains(&remote_id) {
            return Err(AuthError::Forbidden);
        }
        Ok(())
    }
}
/// Auth handler that only admits requests whose target authority is in the wrapped list.
struct AllowAuthorities(Vec<String>);

impl AuthHandler for AllowAuthorities {
    /// Returns `Ok(())` when the request's target authority (as a string) is in the list,
    /// [`AuthError::Forbidden`] otherwise.
    ///
    /// For tunnel requests the target is used directly; for absolute-form requests the
    /// authority is extracted from the request URI.
    async fn authorize(
        &self,
        _remote_id: EndpointId,
        req: &HttpProxyRequest,
    ) -> Result<(), AuthError> {
        let target = match &req.kind {
            HttpProxyRequestKind::Tunnel { target } => target.to_string(),
            HttpProxyRequestKind::Absolute { target, .. } => Authority::from_absolute_uri(target)
                .map(|a| a.to_string())
                .unwrap_or_default(),
        };
        // Fail closed: a URI without a usable authority collapses to "" above, and that must
        // never match, even if the allow-list happens to contain an empty entry.
        let allowed = !target.is_empty() && self.0.contains(&target);
        debug!(?allowed, ?target, list=?self.0, "AllowAuthorities::authorize");
        if allowed {
            Ok(())
        } else {
            Err(AuthError::Forbidden)
        }
    }
}
/// In TCP mode, bytes written to the downstream proxy reach the echo server behind the
/// upstream proxy and are echoed back unchanged.
#[tokio::test]
#[traced_test]
async fn test_tcp_mode() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (echo_addr, _echo_task) = spawn_echo_server().await?;

    let authority = Authority::from_authority_str(&echo_addr.to_string())?;
    let mode = ProxyMode::Tcp(EndpointAuthority::new(upstream_id, authority));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;

    let mut conn = TcpStream::connect(proxy_addr).await?;
    conn.write_all(b"hello tcp").await?;
    // Close the write side so the read below can run until EOF.
    conn.shutdown().await?;
    let mut echoed = Vec::new();
    conn.read_to_end(&mut echoed).await?;
    assert_eq!(echoed, b"hello tcp");

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A `reqwest` client configured with the downstream proxy as its HTTP proxy can fetch a URL
/// from the origin server through a static forward proxy.
#[tokio::test]
#[traced_test]
async fn test_http_forward_absolute_form() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let http_proxy = reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?;
    let client = reqwest::Client::builder()
        .proxy(http_proxy)
        .build()
        .anyerr()?;

    let url = format!("http://{origin_addr}/test/path");
    let response = client.get(url).send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    let body = response.text().await.anyerr()?;
    assert_eq!(body, "origin GET /test/path");

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// An HTTP/1.1 `CONNECT` tunnel through a static forward proxy carries a plain GET request to
/// the origin server and its response back.
#[tokio::test]
#[traced_test]
async fn test_http_forward_connect() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, origin_task) = spawn_origin_server("origin").await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let mut tunnel = create_http_connect_tunnel(proxy_addr, origin_addr, None).await?;
    tunnel
        .write_all(b"GET /tunnel/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n")
        .await?;
    let (code, payload) = read_http_response(&mut tunnel).await?;
    assert_eq!(code, 200);
    assert_eq!(payload, b"origin GET /tunnel/test");

    // Explicit teardown: downstream proxy, then upstream router, then origin server.
    proxy_task.abort();
    upstream_router.shutdown().await.anyerr()?;
    origin_task.abort();
    Ok(())
}
/// A static reverse proxy forwards an origin-form request sent directly to the proxy address
/// to the configured origin server.
#[tokio::test]
#[traced_test]
async fn test_http_reverse_simple() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;

    let authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, authority);
    let opts = HttpProxyOpts::new(StaticReverseProxy(destination));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let client = reqwest::Client::new();
    let url = format!("http://{proxy_addr}/reverse/path");
    let response = client.get(url).send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    let body = response.text().await.anyerr()?;
    assert_eq!(body, "origin GET /reverse/path");

    // Stop the downstream proxy before shutting the upstream router down.
    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// With [`HeaderResolver`], each absolute-form request is routed to the upstream proxy named
/// in its destination header, so two requests reach two different origin servers.
#[tokio::test]
#[traced_test]
async fn test_http_forward_absolute_dynamic() -> Result {
    let (upstream1_router, upstream1_id) = spawn_upstream_proxy().await?;
    let (origin1_addr, _origin1_task) = spawn_origin_server("alpha").await?;
    let (upstream2_router, upstream2_id) = spawn_upstream_proxy().await?;
    let (origin2_addr, _origin2_task) = spawn_origin_server("beta").await?;

    let opts = HttpProxyOpts::new(HeaderResolver);
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    // First request: routed to upstream 1 / origin "alpha".
    let first_request = format!(
        "GET http://{origin1_addr}/path1 HTTP/1.1\r\n\
         Host: {origin1_addr}\r\n\
         {IROH_DESTINATION_HEADER}: {upstream1_id}\r\n\
         Connection: close\r\n\r\n"
    );
    let mut first_conn = TcpStream::connect(proxy_addr).await?;
    first_conn.write_all(first_request.as_bytes()).await?;
    let (first_code, first_body) = read_http_response(&mut first_conn).await?;
    assert_eq!(first_code, 200);
    assert_eq!(first_body, b"alpha GET /path1");

    // Second request: routed to upstream 2 / origin "beta".
    let mut second_conn = TcpStream::connect(proxy_addr).await?;
    let second_request = format!(
        "GET http://{origin2_addr}/path2 HTTP/1.1\r\n\
         Host: {origin2_addr}\r\n\
         {IROH_DESTINATION_HEADER}: {upstream2_id}\r\n\
         Connection: close\r\n\r\n"
    );
    second_conn.write_all(second_request.as_bytes()).await?;
    let (second_code, second_body) = read_http_response(&mut second_conn).await?;
    assert_eq!(second_code, 200);
    assert_eq!(second_body, b"beta GET /path2");

    drop(proxy_task);
    upstream1_router.shutdown().await.anyerr()?;
    upstream2_router.shutdown().await.anyerr()?;
    Ok(())
}
/// With [`HeaderResolver`], a proxied request that carries no destination header is answered
/// with `400 Bad Request`.
#[tokio::test]
#[traced_test]
async fn test_http_forward_dynamic_missing_header() -> Result {
    let opts = HttpProxyOpts::new(HeaderResolver);
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let http_proxy = reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?;
    let client = reqwest::Client::builder()
        .proxy(http_proxy)
        .build()
        .anyerr()?;

    let response = client
        .get("http://example.com/path")
        .send()
        .await
        .anyerr()?;
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    Ok(())
}
/// With [`SubdomainRouter`], the first label of the `Host` header selects which upstream and
/// origin server an origin-form request is forwarded to.
#[tokio::test]
#[traced_test]
async fn test_http_reverse_dynamic() -> Result {
    let (upstream1_router, upstream1_id) = spawn_upstream_proxy().await?;
    let (origin1_addr, _origin1_task) = spawn_origin_server("server1").await?;
    let (upstream2_router, upstream2_id) = spawn_upstream_proxy().await?;
    let (origin2_addr, _origin2_task) = spawn_origin_server("server2").await?;

    let routes = HashMap::from([
        (
            "proxy1".to_string(),
            EndpointAuthority::new(
                upstream1_id,
                Authority::from_authority_str(&origin1_addr.to_string())?,
            ),
        ),
        (
            "proxy2".to_string(),
            EndpointAuthority::new(
                upstream2_id,
                Authority::from_authority_str(&origin2_addr.to_string())?,
            ),
        ),
    ]);
    let opts = HttpProxyOpts::new(SubdomainRouter { routes });
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    // Host "proxy1.*" must land on server1.
    let mut first_conn = TcpStream::connect(proxy_addr).await?;
    first_conn
        .write_all(b"GET /path HTTP/1.1\r\nHost: proxy1.example.com\r\nConnection: close\r\n\r\n")
        .await?;
    let (first_code, first_body) = read_http_response(&mut first_conn).await?;
    assert_eq!(first_code, 200);
    assert_eq!(first_body, b"server1 GET /path");

    // Host "proxy2.*" must land on server2.
    let mut second_conn = TcpStream::connect(proxy_addr).await?;
    second_conn
        .write_all(b"GET /path HTTP/1.1\r\nHost: proxy2.example.com\r\nConnection: close\r\n\r\n")
        .await?;
    let (second_code, second_body) = read_http_response(&mut second_conn).await?;
    assert_eq!(second_code, 200);
    assert_eq!(second_body, b"server2 GET /path");

    upstream1_router.shutdown().await.anyerr()?;
    upstream2_router.shutdown().await.anyerr()?;
    Ok(())
}
/// With an empty [`SubdomainRouter`] route table, a request for any host is answered with
/// `404 Not Found`.
#[tokio::test]
#[traced_test]
async fn test_http_reverse_dynamic_unknown_subdomain() -> Result {
    let router = SubdomainRouter {
        routes: HashMap::new(),
    };
    let mode = ProxyMode::Http(HttpProxyOpts::new(router));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;

    let client = reqwest::Client::new();
    let request = client
        .get(format!("http://{proxy_addr}/path"))
        .header("Host", "unknown.example.com");
    let response = request.send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::NOT_FOUND);
    Ok(())
}
/// An upstream proxy guarded by [`AllowEndpoints`] serves the allow-listed downstream proxy
/// and answers a different downstream proxy with `403` and an empty body.
#[tokio::test]
#[traced_test]
async fn test_upstream_auth_endpoint() -> Result {
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;

    // The allowed downstream proxy is spawned first so its id can go into the allow-list.
    let allowed_mode = ProxyMode::Http(HttpProxyOpts::new(HeaderResolver));
    let (proxy_addr, downstream_id, proxy_task) = spawn_downstream_proxy(allowed_mode).await?;
    let (upstream_router, upstream_id, _metrics) =
        spawn_upstream_proxy_with_auth(AllowEndpoints(vec![downstream_id])).await?;

    let allowed_request = format!(
        "GET http://{origin_addr}/test HTTP/1.1\r\n\
         Host: {origin_addr}\r\n\
         {IROH_DESTINATION_HEADER}: {upstream_id}\r\n\
         Connection: close\r\n\r\n"
    );
    let mut allowed_conn = TcpStream::connect(proxy_addr).await?;
    allowed_conn.write_all(allowed_request.as_bytes()).await?;
    let (code, payload) = read_http_response(&mut allowed_conn).await?;
    assert_eq!(code, 200);
    assert_eq!(payload, b"origin GET /test");

    // A second downstream proxy has a different endpoint id, which is not on the list.
    let other_mode = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr2, _, proxy_task2) = spawn_downstream_proxy(other_mode).await?;
    let denied_request = format!(
        "GET http://{origin_addr}/fail HTTP/1.1\r\n\
         Host: {origin_addr}\r\n\
         {IROH_DESTINATION_HEADER}: {upstream_id}\r\n\
         Connection: close\r\n\r\n"
    );
    let mut denied_conn = TcpStream::connect(proxy_addr2).await?;
    denied_conn.write_all(denied_request.as_bytes()).await?;
    let (code, payload) = read_http_response(&mut denied_conn).await?;
    assert_eq!(code, 403);
    assert!(payload.is_empty());

    drop(proxy_task);
    drop(proxy_task2);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// An upstream proxy guarded by [`AllowAuthorities`] tunnels to the allow-listed origin and
/// answers a `CONNECT` for any other origin with `403` and an empty body.
#[tokio::test]
#[traced_test]
async fn test_upstream_auth_authority() -> Result {
    let (allowed_addr, _allowed_task) = spawn_origin_server("allowed").await?;
    let (denied_addr, _denied_task) = spawn_origin_server("denied").await?;

    let auth = AllowAuthorities(vec![allowed_addr.to_string()]);
    let (upstream_router, upstream_id, _metrics) = spawn_upstream_proxy_with_auth(auth).await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    // Allowed authority: the tunnel is established and the origin answers through it.
    let mut tunnel = create_http_connect_tunnel(proxy_addr, allowed_addr, None).await?;
    tunnel
        .write_all(b"GET /check HTTP/1.1\r\nHost: x\r\nConnection: close\r\n\r\n")
        .await?;
    let (upstream_status, payload) = read_http_response(&mut tunnel).await?;
    assert_eq!(upstream_status, 200);
    assert_eq!(payload, b"allowed GET /check");

    // Denied authority: the CONNECT itself is rejected.
    let connect = format!("CONNECT {denied_addr} HTTP/1.1\r\nHost: {denied_addr}\r\n\r\n");
    let mut raw = TcpStream::connect(proxy_addr).await?;
    raw.write_all(connect.as_bytes()).await?;
    let (code, payload) = read_http_response(&mut raw).await?;
    assert_eq!(code, StatusCode::FORBIDDEN);
    assert!(payload.is_empty());

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A POST body sent through a static forward proxy reaches the body-echoing origin server.
#[tokio::test]
#[traced_test]
async fn test_http_forward_post_with_body() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server_echo_body("origin").await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let http_proxy = reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?;
    let client = reqwest::Client::builder()
        .proxy(http_proxy)
        .build()
        .anyerr()?;

    let request = client
        .post(format!("http://{origin_addr}/upload"))
        .body("hello request body");
    let response = request.send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = response.text().await.anyerr()?;
    assert_eq!(echoed, "origin POST /upload: hello request body");

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Two consecutive POSTs from the same client through a static reverse proxy each reach the
/// body-echoing origin server with their own body.
#[tokio::test]
#[traced_test]
async fn test_http_reverse_post_with_body() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server_echo_body("origin").await?;

    let authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, authority);
    let opts = HttpProxyOpts::new(StaticReverseProxy(destination));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let client = reqwest::Client::new();
    // Send both bodies sequentially with the same client.
    for body in ["post body content", "post body content 2"] {
        let response = client
            .post(format!("http://{proxy_addr}/data"))
            .body(body)
            .send()
            .await
            .anyerr()?;
        assert_eq!(response.status(), StatusCode::OK);
        let echoed = response.text().await.anyerr()?;
        assert_eq!(echoed, format!("origin POST /data: {body}"));
    }

    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Bytes that are not a valid HTTP request are answered with status `400`.
#[tokio::test]
#[traced_test]
async fn test_invalid_http_request() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let mut conn = TcpStream::connect(proxy_addr).await?;
    conn.write_all(b"NOT VALID HTTP\r\n\r\n").await?;
    let (code, _body) = read_http_response(&mut conn).await?;
    assert_eq!(code, 400);

    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// An origin-form request sent to a proxy configured only with [`StaticForwardProxy`] is
/// answered with status `400`.
#[tokio::test]
#[traced_test]
async fn test_origin_form_to_forward_only_proxy() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let mut conn = TcpStream::connect(proxy_addr).await?;
    conn.write_all(b"GET /path HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n")
        .await?;
    let (code, _body) = read_http_response(&mut conn).await?;
    assert_eq!(code, 400);

    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// An absolute-form request sent to a proxy configured only with [`StaticReverseProxy`] is
/// answered with status `400`.
#[tokio::test]
#[traced_test]
async fn test_forward_request_to_reverse_only_proxy() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;

    let authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, authority);
    let opts = HttpProxyOpts::new(StaticReverseProxy(destination));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let request = format!(
        "GET http://{origin_addr}/path HTTP/1.1\r\nHost: {origin_addr}\r\nConnection: close\r\n\r\n"
    );
    let mut conn = TcpStream::connect(proxy_addr).await?;
    conn.write_all(request.as_bytes()).await?;
    let (code, _body) = read_http_response(&mut conn).await?;
    assert_eq!(code, 400);

    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A `CONNECT` request sent to a proxy configured only with [`StaticReverseProxy`] is answered
/// with status `400`.
#[tokio::test]
#[traced_test]
async fn test_connect_to_reverse_only_proxy() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;

    let authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, authority);
    let opts = HttpProxyOpts::new(StaticReverseProxy(destination));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let request = "CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n\r\n".to_string();
    let mut conn = TcpStream::connect(proxy_addr).await?;
    conn.write_all(request.as_bytes()).await?;
    let (code, _body) = read_http_response(&mut conn).await?;
    assert_eq!(code, 400);

    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A `CONNECT` to `127.0.0.1:1` through a static forward proxy is answered with status `502`.
#[tokio::test]
#[traced_test]
async fn test_connect_unreachable_origin() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let mut conn = TcpStream::connect(proxy_addr).await?;
    conn.write_all(b"CONNECT 127.0.0.1:1 HTTP/1.1\r\nHost: 127.0.0.1:1\r\n\r\n")
        .await?;
    let (code, _body) = read_http_response(&mut conn).await?;
    assert_eq!(code, 502);

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
#[tokio::test]
#[traced_test]
async fn test_concurrent_requests() -> Result {
let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;
let mode = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;
let client = reqwest::Client::builder()
.proxy(reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?)
.build()
.anyerr()?;
let mut handles = Vec::new();
for i in 0..10 {
let client = client.clone();
let url = format!("http://{origin_addr}/request/{i}");
handles.push(tokio::spawn(async move {
let res = client.get(&url).send().await?;
let text = res.text().await?;
Ok::<_, reqwest::Error>(text)
}));
}
for (i, handle) in handles.into_iter().enumerate() {
let text = handle.await.anyerr()?.anyerr()?;
assert_eq!(text, format!("origin GET /request/{i}"));
}
upstream_router.shutdown().await.anyerr()?;
Ok(())
}
/// A 1 MiB POST body sent through a static forward proxy is echoed back intact by the origin.
#[tokio::test]
#[traced_test]
async fn test_large_request_body() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server_echo_body("origin").await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let http_proxy = reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?;
    let client = reqwest::Client::builder()
        .proxy(http_proxy)
        .build()
        .anyerr()?;

    let payload = "x".repeat(1024 * 1024);
    let request = client
        .post(format!("http://{origin_addr}/large"))
        .body(payload.clone());
    let response = request.send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    let echoed = response.text().await.anyerr()?;
    assert_eq!(echoed, format!("origin POST /large: {payload}"));

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A [`RequestHandlerChain`] holding both a forward and a reverse handler serves a proxied
/// request from one origin and a direct request from another through the same listener.
#[tokio::test]
#[traced_test]
async fn test_forward_and_reverse_combined() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (forward_origin_addr, _forward_origin_task) = spawn_origin_server("forward").await?;
    let (reverse_origin_addr, _reverse_origin_task) = spawn_origin_server("reverse").await?;

    let reverse_authority = Authority::from_authority_str(&reverse_origin_addr.to_string())?;
    let reverse_destination = EndpointAuthority::new(upstream_id, reverse_authority);
    let chain = RequestHandlerChain::default()
        .push(StaticForwardProxy(upstream_id))
        .push(StaticReverseProxy(reverse_destination));
    let (proxy_addr, _, proxy_task) =
        spawn_downstream_proxy(ProxyMode::Http(HttpProxyOpts::new(chain))).await?;

    // Forward path: the client talks to the origin through the proxy.
    let http_proxy = reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?;
    let proxied_client = reqwest::Client::builder()
        .proxy(http_proxy)
        .build()
        .anyerr()?;
    let forward_url = format!("http://{forward_origin_addr}/forward-path");
    let response = proxied_client.get(forward_url).send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(response.text().await.anyerr()?, "forward GET /forward-path");

    // Reverse path: the client addresses the proxy itself.
    let direct_client = reqwest::Client::new();
    let reverse_url = format!("http://{proxy_addr}/reverse-path");
    let response = direct_client.get(reverse_url).send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(response.text().await.anyerr()?, "reverse GET /reverse-path");

    drop(proxy_task);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Two absolute-form GET requests sent over a single HTTP/2 connection to a static forward
/// proxy each receive their own response from the origin server.
#[tokio::test]
#[traced_test]
async fn h2_multiple_connect_requests_single_connection() -> Result<()> {
    use bytes::Bytes;
    use http_body_util::{BodyExt, Full};
    use hyper::Request;
    use hyper_util::rt::{TokioExecutor, TokioIo};

    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;
    let mode = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr, _, proxy_task) = spawn_downstream_proxy(mode).await?;

    // One TCP connection, one HTTP/2 handshake; both requests go through `sender`.
    let tcp = TcpStream::connect(proxy_addr).await?;
    let (mut sender, conn) =
        hyper::client::conn::http2::handshake(TokioExecutor::new(), TokioIo::new(tcp))
            .await
            .anyerr()?;
    // The connection future must be polled for requests to make progress.
    let conn_task = tokio::spawn(async move { conn.await });

    let first = Request::builder()
        .method(http::Method::GET)
        .uri(format!("http://{origin_addr}/path1"))
        .body(Full::new(Bytes::new()))
        .anyerr()?;
    let first_response = sender.send_request(first).await.anyerr()?;
    assert_eq!(first_response.status(), StatusCode::OK);
    let first_body = first_response
        .into_body()
        .collect()
        .await
        .anyerr()?
        .to_bytes();
    assert_eq!(first_body.as_ref(), b"origin GET /path1");

    let second = Request::builder()
        .method(http::Method::GET)
        .uri(format!("http://{origin_addr}/path2"))
        .body(Full::new(Bytes::new()))
        .anyerr()?;
    let second_response = sender.send_request(second).await.anyerr()?;
    assert_eq!(second_response.status(), StatusCode::OK);
    let second_body = second_response
        .into_body()
        .collect()
        .await
        .anyerr()?
        .to_bytes();
    assert_eq!(second_body.as_ref(), b"origin GET /path2");

    drop(proxy_task);
    conn_task.abort();
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A `reqwest` client using HTTP/2 with prior knowledge can fetch a path through a static
/// reverse proxy.
#[tokio::test]
#[traced_test]
async fn h2_reqwest_reverse() -> Result<()> {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;

    let authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, authority);
    let opts = HttpProxyOpts::new(StaticReverseProxy(destination));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let client = reqwest::Client::builder()
        .http2_prior_knowledge()
        .build()
        .anyerr()?;
    let url = format!("http://{proxy_addr}/reverse/path");
    let response = client.get(url).send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    let body = response.text().await.anyerr()?;
    assert_eq!(body, "origin GET /reverse/path");

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A `reqwest` client using HTTP/2 with prior knowledge, configured with the downstream proxy
/// as its HTTP proxy, can fetch a URL through a static forward proxy.
#[tokio::test]
#[traced_test]
async fn h2_reqwest_forward() -> Result<()> {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;
    let opts = HttpProxyOpts::new(StaticForwardProxy(upstream_id));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(ProxyMode::Http(opts)).await?;

    let http_proxy = reqwest::Proxy::http(format!("http://{proxy_addr}")).anyerr()?;
    let client = reqwest::Client::builder()
        .http2_prior_knowledge()
        .proxy(http_proxy)
        .build()
        .anyerr()?;

    let url = format!("http://{origin_addr}/test/path");
    let response = client.get(url).send().await.anyerr()?;
    assert_eq!(response.status(), StatusCode::OK);
    let body = response.text().await.anyerr()?;
    assert_eq!(body, "origin GET /test/path");

    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// A [`hyper::rt::Executor`] that runs each future on the Tokio runtime via `tokio::spawn`.
struct SpawnExecutor;

impl<Fut: std::future::Future + Send + 'static> hyper::rt::Executor<Fut> for SpawnExecutor
where
    Fut::Output: Send + 'static,
{
    /// Spawns `fut` as a detached task; the join handle is dropped immediately.
    fn execute(&self, fut: Fut) {
        drop(tokio::spawn(fut));
    }
}
/// A WebSocket session works through an HTTP/1.1 `CONNECT` tunnel on a static forward proxy:
/// after the tunnel is established, the handshake succeeds and two text frames are echoed back.
#[tokio::test]
#[traced_test]
async fn test_websocket_http1_forward_connect() -> Result {
    use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, handshake};
    use hyper::Request;
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_websocket_origin("wsorigin").await?;
    let mode = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;
    let mut stream = TcpStream::connect(proxy_addr).await?;
    let connect_req = format!(
        "CONNECT {origin_addr} HTTP/1.1\r\n\
         Host: {origin_addr}\r\n\r\n"
    );
    stream.write_all(connect_req.as_bytes()).await?;

    // Consume the CONNECT response head exactly up to its terminating blank line. A single
    // `read` may return only part of the head, and any bytes left unread would then be
    // misinterpreted as the start of the WebSocket handshake response. Reading one byte at a
    // time guarantees nothing past the head is consumed either.
    let mut head = Vec::with_capacity(256);
    let mut byte = [0u8; 1];
    while !head.ends_with(b"\r\n\r\n") {
        let n = stream.read(&mut byte).await?;
        if n == 0 {
            // EOF before the head was complete; the assertions below report what was received.
            break;
        }
        head.push(byte[0]);
    }
    let response = String::from_utf8_lossy(&head);
    assert!(
        response.starts_with("HTTP/1.1 200"),
        "Expected 200 OK, got: {response}"
    );
    assert!(
        response.ends_with("\r\n\r\n"),
        "Incomplete CONNECT response: {response}"
    );

    let req = Request::builder()
        .method("GET")
        .uri("/ws")
        .header("Host", origin_addr.to_string())
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header("Sec-WebSocket-Key", handshake::generate_key())
        .header("Sec-WebSocket-Version", "13")
        .body(http_body_util::Empty::<bytes::Bytes>::new())
        .anyerr()?;
    let (ws, _) = handshake::client(&SpawnExecutor, req, stream)
        .await
        .anyerr()?;
    let mut ws = FragmentCollector::new(ws);

    // First round trip.
    ws.write_frame(Frame::text(Payload::Borrowed(b"hello websocket")))
        .await
        .anyerr()?;
    let frame = ws.read_frame().await.anyerr()?;
    assert_eq!(frame.opcode, OpCode::Text);
    assert_eq!(&*frame.payload, b"hello websocket");

    // Second round trip on the same connection.
    ws.write_frame(Frame::text(Payload::Borrowed(b"second message")))
        .await
        .anyerr()?;
    let frame = ws.read_frame().await.anyerr()?;
    assert_eq!(frame.opcode, OpCode::Text);
    assert_eq!(&*frame.payload, b"second message");

    ws.write_frame(Frame::close_raw(vec![].into()))
        .await
        .anyerr()?;
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Sends a WebSocket upgrade to the forward proxy as an absolute-form HTTP/1.1
/// request and checks that text frames are echoed back by the origin.
#[tokio::test]
#[traced_test]
async fn test_websocket_http1_forward_absolute() -> Result {
    use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, handshake};
    use hyper::Request;

    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_websocket_origin("wsorigin").await?;
    let forward = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(forward).await?;

    let tcp = TcpStream::connect(proxy_addr).await?;
    // The request target names the origin; the connection goes to the proxy.
    let upgrade_req = Request::builder()
        .method("GET")
        .uri(format!("http://{origin_addr}/ws"))
        .header("Host", origin_addr.to_string())
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header("Sec-WebSocket-Key", handshake::generate_key())
        .header("Sec-WebSocket-Version", "13")
        .body(http_body_util::Empty::<bytes::Bytes>::new())
        .anyerr()?;
    let (socket, _) = handshake::client(&SpawnExecutor, upgrade_req, tcp)
        .await
        .anyerr()?;
    let mut socket = FragmentCollector::new(socket);

    // Each message is sent as a text frame and must come back unchanged.
    let messages: [&[u8]; 2] = [b"hello absolute-form ws", b"second message"];
    for message in messages {
        socket
            .write_frame(Frame::text(Payload::Borrowed(message)))
            .await
            .anyerr()?;
        let echoed = socket.read_frame().await.anyerr()?;
        assert_eq!(echoed.opcode, OpCode::Text);
        assert_eq!(&*echoed.payload, message);
    }

    // Send a close frame with an empty payload before tearing down.
    socket
        .write_frame(Frame::close_raw(vec![].into()))
        .await
        .anyerr()?;
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Performs an HTTP/1.1 WebSocket upgrade against the proxy in static
/// reverse-proxy mode and checks that a text frame is echoed by the origin.
#[tokio::test]
#[traced_test]
async fn test_websocket_http1_reverse() -> Result {
    use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, handshake};
    use hyper::Request;

    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_websocket_origin("wsorigin").await?;

    // Every request hitting the proxy is routed to this fixed destination.
    let origin_authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, origin_authority);
    let reverse = ProxyMode::Http(HttpProxyOpts::new(StaticReverseProxy(destination)));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(reverse).await?;

    let tcp = TcpStream::connect(proxy_addr).await?;
    // The client addresses the proxy itself: origin-form target, proxy as Host.
    let upgrade_req = Request::builder()
        .method("GET")
        .uri("/ws")
        .header("Host", proxy_addr.to_string())
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header("Sec-WebSocket-Key", handshake::generate_key())
        .header("Sec-WebSocket-Version", "13")
        .body(http_body_util::Empty::<bytes::Bytes>::new())
        .anyerr()?;
    let (socket, _) = handshake::client(&SpawnExecutor, upgrade_req, tcp)
        .await
        .anyerr()?;
    let mut socket = FragmentCollector::new(socket);

    let message: &[u8] = b"reverse proxy ws";
    socket
        .write_frame(Frame::text(Payload::Borrowed(message)))
        .await
        .anyerr()?;
    let echoed = socket.read_frame().await.anyerr()?;
    assert_eq!(echoed.opcode, OpCode::Text);
    assert_eq!(&*echoed.payload, message);

    // Send a close frame with an empty payload before tearing down.
    socket
        .write_frame(Frame::close_raw(vec![].into()))
        .await
        .anyerr()?;
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Opens a `CONNECT` tunnel over an HTTP/2 connection to the forward proxy,
/// then runs an HTTP/1.1 WebSocket handshake inside the upgraded stream and
/// checks that text frames are echoed by the origin.
#[tokio::test]
#[traced_test]
async fn test_websocket_h2_forward_connect() -> Result {
    use bytes::Bytes;
    use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, handshake};
    use http_body_util::Empty;
    use hyper::Request;
    use hyper_util::rt::TokioIo;

    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_websocket_origin("wsorigin").await?;
    let forward = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(forward).await?;

    // HTTP/2 client connection to the proxy; the driver future runs in its
    // own task until it is aborted at the end of the test.
    let tcp = TcpStream::connect(proxy_addr).await?;
    let proxy_io = TokioIo::new(tcp);
    let (mut h2, driver) = hyper::client::conn::http2::handshake(TokioExecutor::new(), proxy_io)
        .await
        .anyerr()?;
    let driver_task = tokio::spawn(async move { driver.await });

    // CONNECT with the origin address as the request target.
    let tunnel_req = Request::builder()
        .method(http::Method::CONNECT)
        .uri(origin_addr.to_string())
        .body(Empty::<Bytes>::new())
        .anyerr()?;
    let tunnel_res = h2.send_request(tunnel_req).await.anyerr()?;
    assert_eq!(tunnel_res.status(), StatusCode::OK);
    let tunnel = hyper::upgrade::on(tunnel_res).await.anyerr()?;

    // WebSocket handshake inside the tunnel.
    let upgrade_req = Request::builder()
        .method("GET")
        .uri("/ws")
        .header("Host", origin_addr.to_string())
        .header("Upgrade", "websocket")
        .header("Connection", "Upgrade")
        .header("Sec-WebSocket-Key", handshake::generate_key())
        .header("Sec-WebSocket-Version", "13")
        .body(Empty::<Bytes>::new())
        .anyerr()?;
    let (socket, _) = handshake::client(&SpawnExecutor, upgrade_req, TokioIo::new(tunnel))
        .await
        .anyerr()?;
    let mut socket = FragmentCollector::new(socket);

    // One fixed message first ...
    socket
        .write_frame(Frame::text(Payload::Borrowed(b"h2 websocket test")))
        .await
        .anyerr()?;
    let echoed = socket.read_frame().await.anyerr()?;
    assert_eq!(echoed.opcode, OpCode::Text);
    assert_eq!(&*echoed.payload, b"h2 websocket test");

    // ... then three generated ones, each awaited before the next is sent.
    for text in (0..3).map(|i| format!("message {i}")) {
        socket
            .write_frame(Frame::text(Payload::Owned(text.clone().into_bytes())))
            .await
            .anyerr()?;
        let echoed = socket.read_frame().await.anyerr()?;
        assert_eq!(echoed.opcode, OpCode::Text);
        assert_eq!(&*echoed.payload, text.as_bytes());
    }

    // Send a close frame with an empty payload before tearing down.
    socket
        .write_frame(Frame::close_raw(vec![].into()))
        .await
        .anyerr()?;
    driver_task.abort();
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
#[tokio::test]
#[traced_test]
async fn test_websocket_h2_reverse() -> Result {
use bytes::Bytes;
use fastwebsockets::{FragmentCollector, Frame, OpCode, handshake};
use http_body_util::Empty;
use hyper::Request;
use hyper_util::rt::TokioIo;
let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
let (origin_addr, _origin_task) = spawn_websocket_origin("wsorigin").await?;
let destination = EndpointAuthority::new(
upstream_id,
Authority::from_authority_str(&origin_addr.to_string())?,
);
let mode = ProxyMode::Http(HttpProxyOpts::new(StaticReverseProxy(destination)));
let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;
let stream = TcpStream::connect(proxy_addr).await?;
let io = TokioIo::new(stream);
let (mut sender, conn) = hyper::client::conn::http2::handshake(TokioExecutor::new(), io)
.await
.anyerr()?;
let conn_task = tokio::spawn(async move { conn.await });
let req = Request::builder()
.method(http::Method::CONNECT)
.uri(format!("http://{}/ws", proxy_addr))
.header("Sec-WebSocket-Key", handshake::generate_key())
.header("Sec-WebSocket-Version", "13")
.extension(hyper::ext::Protocol::from_static("websocket"))
.body(Empty::<Bytes>::new())
.anyerr()?;
let res = sender.send_request(req).await.anyerr()?;
debug!("HTTP/2 extended CONNECT response status: {}", res.status());
assert_eq!(res.status(), http::StatusCode::OK);
let upgraded = hyper::upgrade::on(res).await.anyerr()?;
debug!("client upgraded");
let mut ws = FragmentCollector::new(fastwebsockets::WebSocket::after_handshake(
TokioIo::new(upgraded),
fastwebsockets::Role::Client,
));
ws.write_frame(Frame::text(fastwebsockets::Payload::Borrowed(
b"hello h2 extended connect",
)))
.await
.anyerr()?;
debug!("written");
let frame = ws.read_frame().await.anyerr()?;
debug!("read");
assert_eq!(frame.opcode, OpCode::Text);
assert_eq!(frame.payload.as_ref(), b"hello h2 extended connect");
ws.write_frame(Frame::close_raw(vec![].into()))
.await
.anyerr()?;
conn_task.abort();
upstream_router.shutdown().await.anyerr()?;
Ok(())
}
/// Sends a forward-proxy request that carries hop-by-hop headers and checks
/// that `Proxy-Authorization` and `Keep-Alive` do not reach the origin.
///
/// The origin is a raw TCP server that answers with the request head it
/// received as the response body, so the client can inspect exactly what
/// arrived there.
#[tokio::test]
#[traced_test]
async fn test_hop_by_hop_headers_not_forwarded() -> Result {
    let listener = TcpListener::bind("localhost:0").await?;
    let origin_addr = listener.local_addr()?;
    let origin_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        // Read the complete request head. A single `read` may return only
        // part of it, and headers arriving later would then be missing from
        // the echo, letting the "not forwarded" assertions pass vacuously.
        let mut head: Vec<u8> = Vec::with_capacity(4096);
        let mut chunk = [0u8; 4096];
        while !head.windows(4).any(|w| w == b"\r\n\r\n") {
            let n = stream.read(&mut chunk).await.unwrap();
            if n == 0 {
                break;
            }
            head.extend_from_slice(&chunk[..n]);
        }
        let body = String::from_utf8_lossy(&head).into_owned();
        let response = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
            body.len(),
            body
        );
        stream.write_all(response.as_bytes()).await.unwrap();
    });
    let _origin_task = AbortOnDropHandle::new(origin_task);

    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let mode = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;

    let mut stream = TcpStream::connect(proxy_addr).await?;
    let req = format!(
        "GET http://{origin_addr}/test HTTP/1.1\r\n\
         Host: {origin_addr}\r\n\
         Connection: keep-alive\r\n\
         Proxy-Authorization: Basic secret\r\n\
         Keep-Alive: timeout=5\r\n\
         \r\n"
    );
    stream.write_all(req.as_bytes()).await?;

    let (status, body) = read_http_response(&mut stream).await?;
    assert_eq!(status, 200);
    let body_str = String::from_utf8_lossy(&body);
    debug!("Origin received: {}", body_str);

    let received = body_str.to_lowercase();
    // Guard against an empty echo: the negative checks below only mean
    // something if the origin really reported the request it saw.
    assert!(
        received.contains("/test"),
        "origin did not echo the forwarded request: {body_str:?}"
    );
    assert!(
        !received.contains("proxy-authorization"),
        "Proxy-Authorization header was forwarded but shouldn't be"
    );
    // Match the header name including the colon so that a forwarded
    // `Connection: keep-alive` value does not count as a `Keep-Alive` header.
    assert!(
        !received.contains("keep-alive:"),
        "Keep-Alive header was forwarded but shouldn't be"
    );
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Has a raw TCP origin answer with hop-by-hop response headers and checks
/// that `Proxy-Authenticate` and `Keep-Alive` are absent from the response the
/// client receives through the forward proxy.
#[tokio::test]
#[traced_test]
async fn test_response_hop_by_hop_headers_filtered() -> Result {
    let listener = TcpListener::bind("localhost:0").await?;
    let origin_addr = listener.local_addr()?;
    let origin_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        // The request content is irrelevant here; consume what is available
        // and reply with a fixed response.
        let mut buf = vec![0u8; 4096];
        let _n = stream.read(&mut buf).await.unwrap();
        let response = "HTTP/1.1 200 OK\r\n\
                        Content-Length: 2\r\n\
                        Connection: keep-alive\r\n\
                        Keep-Alive: timeout=5\r\n\
                        Proxy-Authenticate: Basic\r\n\
                        \r\nOK";
        stream.write_all(response.as_bytes()).await.unwrap();
    });
    let _origin_task = AbortOnDropHandle::new(origin_task);

    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let mode = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
    let (proxy_addr, _, _proxy_task) = spawn_downstream_proxy(mode).await?;

    let mut stream = TcpStream::connect(proxy_addr).await?;
    let req = format!(
        "GET http://{origin_addr}/test HTTP/1.1\r\n\
         Host: {origin_addr}\r\n\
         Connection: close\r\n\
         \r\n"
    );
    stream.write_all(req.as_bytes()).await?;

    // Read until the whole response head has arrived. A single `read` may
    // return nothing or only part of the head, in which case the negative
    // header assertions below would pass without having seen the headers.
    let mut response_buf: Vec<u8> = Vec::with_capacity(4096);
    let mut chunk = [0u8; 4096];
    while !response_buf.windows(4).any(|w| w == b"\r\n\r\n") {
        let n = stream.read(&mut chunk).await?;
        if n == 0 {
            break;
        }
        response_buf.extend_from_slice(&chunk[..n]);
    }
    let response_str = String::from_utf8_lossy(&response_buf);
    debug!("Client received response: {}", response_str);
    let response_lower = response_str.to_lowercase();

    // Make sure a complete, successful response was actually received before
    // asserting on what it does not contain.
    assert!(
        response_lower.contains("\r\n\r\n"),
        "incomplete response head received from proxy: {response_str:?}"
    );
    assert_eq!(
        response_str.split_whitespace().nth(1),
        Some("200"),
        "unexpected status line in proxy response: {response_str:?}"
    );
    assert!(
        !response_lower.contains("proxy-authenticate"),
        "Proxy-Authenticate response header was forwarded but shouldn't be"
    );
    // Match the header name including the colon so that a
    // `Connection: keep-alive` value does not count as a `Keep-Alive` header.
    assert!(
        !response_lower.contains("keep-alive:"),
        "Keep-Alive response header was forwarded but shouldn't be"
    );
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
/// Starts a downstream proxy that accepts connections on a Unix domain socket.
///
/// The socket is created in the system temp directory under a name derived
/// from the endpoint's short id. Returns the socket path, the endpoint id and
/// a handle that aborts the proxy task when dropped.
#[cfg(unix)]
async fn spawn_downstream_proxy_uds(
    mode: ProxyMode,
) -> Result<(std::path::PathBuf, EndpointId, AbortOnDropHandle<Result>)> {
    let endpoint = bind_endpoint().await?;
    let endpoint_id = endpoint.id();
    let downstream = DownstreamProxy::new(endpoint, Default::default());

    let socket_name = format!("iroh-{}.sock", endpoint_id.fmt_short());
    let socket_path = std::env::temp_dir().join(socket_name);
    // Best effort: a missing file is the normal case, so the result is ignored.
    let _ = std::fs::remove_file(&socket_path);
    let uds_listener = tokio::net::UnixListener::bind(&socket_path)?;
    debug!(endpoint_id=%endpoint_id.fmt_short(), ?socket_path, "spawned downstream UDS proxy");

    let serve_task =
        tokio::spawn(async move { downstream.forward_uds_listener(uds_listener, mode).await });
    Ok((socket_path, endpoint_id, AbortOnDropHandle::new(serve_task)))
}
/// Sends an HTTP/1.1 request over a Unix domain socket to the reverse proxy
/// and checks that the origin's response comes back.
#[cfg(unix)]
#[tokio::test]
#[traced_test]
async fn test_uds_http1_reverse() -> Result {
    let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
    let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;

    // Every request hitting the proxy is routed to this fixed destination.
    let origin_authority = Authority::from_authority_str(&origin_addr.to_string())?;
    let destination = EndpointAuthority::new(upstream_id, origin_authority);
    let reverse = ProxyMode::Http(HttpProxyOpts::new(StaticReverseProxy(destination)));
    let (socket_path, _, proxy_task) = spawn_downstream_proxy_uds(reverse).await?;

    let mut uds = tokio::net::UnixStream::connect(&socket_path).await?;
    let request = b"GET /uds/path HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
    uds.write_all(request).await?;

    let (status, body) = read_http_response(&mut uds).await?;
    assert_eq!(status, 200);
    assert_eq!(body, b"origin GET /uds/path");

    // Stop the proxy first, then remove the socket file it was listening on.
    drop(proxy_task);
    let _ = std::fs::remove_file(&socket_path);
    upstream_router.shutdown().await.anyerr()?;
    Ok(())
}
#[cfg(unix)]
#[tokio::test]
#[traced_test]
async fn test_uds_http2_reverse() -> Result<()> {
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::Request;
use hyper_util::rt::TokioIo;
let (upstream_router, upstream_id) = spawn_upstream_proxy().await?;
let (origin_addr, _origin_task) = spawn_origin_server("origin").await?;
let destination = EndpointAuthority::new(
upstream_id,
Authority::from_authority_str(&origin_addr.to_string())?,
);
let mode = ProxyMode::Http(HttpProxyOpts::new(StaticReverseProxy(destination)));
let (socket_path, _, proxy_task) = spawn_downstream_proxy_uds(mode).await?;
let stream = tokio::net::UnixStream::connect(&socket_path).await?;
let io = TokioIo::new(stream);
let (mut sender, conn) = hyper::client::conn::http2::handshake(TokioExecutor::new(), io)
.await
.anyerr()?;
let conn_task = tokio::spawn(async move { conn.await });
let req1 = Request::builder()
.method(http::Method::GET)
.uri("/uds/h2/path1")
.body(Full::new(Bytes::new()))
.anyerr()?;
let res1 = sender.send_request(req1).await.anyerr()?;
assert_eq!(res1.status(), http::StatusCode::OK);
let body1 = res1.into_body().collect().await.anyerr()?.to_bytes();
assert_eq!(body1.as_ref(), b"origin GET /uds/h2/path1");
let req2 = Request::builder()
.method(http::Method::GET)
.uri("/uds/h2/path2")
.body(Full::new(Bytes::new()))
.anyerr()?;
let res2 = sender.send_request(req2).await.anyerr()?;
assert_eq!(res2.status(), http::StatusCode::OK);
let body2 = res2.into_body().collect().await.anyerr()?.to_bytes();
assert_eq!(body2.as_ref(), b"origin GET /uds/h2/path2");
conn_task.abort();
drop(proxy_task);
let _ = std::fs::remove_file(&socket_path);
upstream_router.shutdown().await.anyerr()?;
Ok(())
}
/// Minimal HTTP/1.1 origin servers built on hyper.
///
/// Each `run*` function accepts connections from `listener` until `accept`
/// fails and serves every connection in its own spawned task.
mod origin_server {
    use std::convert::Infallible;

    use fastwebsockets::{FragmentCollector, Frame, OpCode, Payload, WebSocketError};
    use http::HeaderValue;
    use http_body_util::{BodyExt, Empty, Full};
    use hyper::{Request, Response, body::Bytes, server::conn::http1, service::service_fn};
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpListener;
    use tracing::debug;

    /// Wraps `body` in a response with an explicit `Content-Length` header.
    fn text_response(body: String) -> Response<Full<Bytes>> {
        let len = body.len();
        let mut res = Response::new(Full::new(Bytes::from(body)));
        res.headers_mut()
            .insert(http::header::CONTENT_LENGTH, HeaderValue::from(len));
        res
    }

    /// Answers every request with `"<label> <method> <path>"`.
    ///
    /// `label` is a `&'static str` and therefore `Copy`, so it is captured by
    /// value in each task and handler without reference counting.
    pub(super) async fn run(listener: TcpListener, label: &'static str) {
        while let Ok((stream, _)) = listener.accept().await {
            let io = TokioIo::new(stream);
            tokio::task::spawn(async move {
                let handler = move |req: Request<hyper::body::Incoming>| {
                    debug!("origin {label}: {req:?}");
                    async move {
                        let body = format!("{} {} {}", label, req.method(), req.uri().path());
                        Ok::<_, Infallible>(text_response(body))
                    }
                };
                // Connection-level errors are ignored on purpose.
                let _ = http1::Builder::new()
                    .serve_connection(io, service_fn(handler))
                    .await;
            });
        }
    }

    /// Answers every request with `"<label> <method> <path>: <request body>"`.
    ///
    /// The request body is collected completely before the response is built;
    /// a body error panics the handler.
    pub(super) async fn run_echo_body(listener: TcpListener, label: &'static str) {
        while let Ok((stream, _)) = listener.accept().await {
            let io = TokioIo::new(stream);
            tokio::task::spawn(async move {
                let handler = move |req: Request<hyper::body::Incoming>| async move {
                    let method = req.method().clone();
                    let path = req.uri().path().to_string();
                    let body_bytes = req.collect().await.unwrap().to_bytes();
                    let body_str = String::from_utf8_lossy(&body_bytes);
                    let response = format!("{} {} {}: {}", label, method, path, body_str);
                    Ok::<_, Infallible>(Response::new(Full::new(Bytes::from(response))))
                };
                // Connection-level errors are ignored on purpose.
                let _ = http1::Builder::new()
                    .serve_connection(io, service_fn(handler))
                    .await;
            });
        }
    }

    /// Like [`run`], but upgrade requests for `/ws` are turned into a
    /// WebSocket session that echoes text and binary frames.
    pub(super) async fn run_with_websocket(listener: TcpListener, label: &'static str) {
        while let Ok((stream, _)) = listener.accept().await {
            let io = TokioIo::new(stream);
            tokio::task::spawn(async move {
                let handler = move |mut req: Request<hyper::body::Incoming>| {
                    debug!("origin {label}: {req:?}");
                    async move {
                        let path = req.uri().path().to_string();
                        if path == "/ws" && fastwebsockets::upgrade::is_upgrade_request(&req) {
                            let (response, fut) =
                                fastwebsockets::upgrade::upgrade(&mut req).unwrap();
                            // The session runs in its own task so the upgrade
                            // response can be returned right away.
                            tokio::spawn(async move {
                                if let Err(e) = handle_websocket(fut).await {
                                    debug!("WebSocket error: {e:?}");
                                }
                            });
                            return Ok(
                                response.map(|_| Empty::new().map_err(|e| match e {}).boxed())
                            );
                        }
                        // Plain request: same answer as `run`, boxed so both
                        // branches share one body type.
                        let body = format!("{} {} {}", label, req.method(), path);
                        let res = text_response(body)
                            .map(|full| full.map_err(|e| match e {}).boxed());
                        Ok::<_, WebSocketError>(res)
                    }
                };
                // Connection-level errors are ignored on purpose.
                let _ = http1::Builder::new()
                    .serve_connection(io, service_fn(handler))
                    .with_upgrades()
                    .await;
            });
        }
    }

    /// Completes the upgrade and echoes text and binary frames until a close
    /// frame is read. Frames with any other opcode are ignored.
    async fn handle_websocket(
        fut: fastwebsockets::upgrade::UpgradeFut,
    ) -> Result<(), WebSocketError> {
        let ws = fut.await?;
        let mut ws = FragmentCollector::new(ws);
        loop {
            let frame = ws.read_frame().await?;
            let reply = match frame.opcode {
                OpCode::Close => return Ok(()),
                OpCode::Text => Frame::text(Payload::Owned(frame.payload.to_vec())),
                OpCode::Binary => Frame::binary(Payload::Owned(frame.payload.to_vec())),
                _ => continue,
            };
            ws.write_frame(reply).await?;
        }
    }
}
/// Tests that assert on the counters exposed by the downstream and upstream
/// proxy metrics.
mod metrics {
    use bytes::Bytes;

    use super::*;

    /// Request handler that denies every request as a bad request.
    struct RejectAll;

    impl RequestHandler for RejectAll {
        async fn handle_request(
            &self,
            _src_addr: SrcAddr,
            _req: &mut HttpRequest,
        ) -> Result<EndpointId, Deny> {
            Err(Deny::bad_request("denied by RejectAll"))
        }
    }

    /// Streams a POST body through the forward proxy and checks the request,
    /// byte and connection counters during and after the request.
    #[tokio::test]
    #[traced_test]
    async fn http_metrics_track_requests() -> Result {
        let (origin_addr, origin_task) = spawn_origin_server_echo_body("origin").await?;
        let (upstream_router, upstream_id, upstream_metrics) =
            spawn_upstream_proxy_with_metrics().await?;
        let forward = ProxyMode::Http(HttpProxyOpts::new(StaticForwardProxy(upstream_id)));
        let (proxy_addr, _, downstream_metrics, proxy_task) =
            spawn_downstream_proxy_with_metrics(forward).await?;

        let proxy_url = format!("http://{proxy_addr}");
        let http_client = reqwest::Client::builder()
            .proxy(reqwest::Proxy::http(proxy_url).anyerr()?)
            .build()
            .anyerr()?;

        // The request body is fed through a channel so that the request is
        // still in flight while the feeder task inspects the metrics.
        let (chunk_tx, chunk_rx) = tokio::sync::mpsc::channel::<std::io::Result<Bytes>>(1);
        let upload =
            reqwest::Body::wrap_stream(tokio_stream::wrappers::ReceiverStream::new(chunk_rx));
        let feeder = tokio::task::spawn({
            let downstream_metrics = downstream_metrics.clone();
            async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                chunk_tx.send(Ok(Bytes::from("hello"))).await.unwrap();
                assert_eq!(downstream_metrics.active_requests(), 1);
                // Dropping the sender ends the body stream.
                drop(chunk_tx);
            }
        });

        let response = http_client
            .post(format!("http://{origin_addr}/upload"))
            .body(upload)
            .send()
            .await
            .anyerr()?;
        // Surfaces a failed assertion inside the feeder task.
        feeder.await.expect("task panicked");
        assert_eq!(response.status(), StatusCode::OK);
        let echoed = response.bytes().await.anyerr()?;
        assert_eq!(echoed.as_ref(), b"origin POST /upload: hello");

        // Request counters once the request has finished.
        assert_eq!(downstream_metrics.active_requests(), 0);
        assert_eq!(downstream_metrics.requests_accepted.get(), 1);
        assert_eq!(downstream_metrics.requests_accepted_h1.get(), 1);
        assert_eq!(downstream_metrics.requests_completed.get(), 1);
        assert_eq!(downstream_metrics.requests_denied.get(), 0);
        assert!(downstream_metrics.bytes_to_upstream.get() > 0);
        assert!(downstream_metrics.bytes_from_upstream.get() > 0);

        // Connection counters: one connection opened and still active.
        assert_eq!(downstream_metrics.iroh_connections_opened.get(), 1);
        assert_eq!(downstream_metrics.iroh_connections_closed_error.get(), 0);
        assert_eq!(downstream_metrics.iroh_connections_closed_idle.get(), 0);
        assert_eq!(downstream_metrics.active_iroh_connections(), 1);

        drop(origin_task);
        drop(proxy_task);
        upstream_router.shutdown().await.anyerr()?;
        // Short pause before re-checking; the test expects the connection to
        // be counted as closed with an error after the upstream shut down.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(downstream_metrics.iroh_connections_closed_error.get(), 1);

        // Upstream per-origin byte counters must mirror the downstream totals.
        let origin_authority = Authority::from_authority_str(&origin_addr.to_string()).unwrap();
        let per_origin = upstream_metrics.get(&origin_authority).expect("exists");
        assert_eq!(
            per_origin.bytes_from_origin(),
            downstream_metrics.bytes_from_upstream.get()
        );
        assert_eq!(
            per_origin.bytes_to_origin(),
            downstream_metrics.bytes_to_upstream.get()
        );
        debug!("downstream metrics: {downstream_metrics:#?}");
        Ok(())
    }

    /// A request rejected by the handler is counted as denied only: not as
    /// completed, not as failed, and not as active afterwards.
    #[tokio::test]
    #[traced_test]
    async fn http_metrics_deny_behavior() -> Result {
        let deny_all = ProxyMode::Http(HttpProxyOpts::new(RejectAll));
        let (proxy_addr, _, metrics, proxy_task) =
            spawn_downstream_proxy_with_metrics(deny_all).await?;

        let proxy_url = format!("http://{proxy_addr}");
        let http_client = reqwest::Client::builder()
            .proxy(reqwest::Proxy::http(proxy_url).anyerr()?)
            .build()
            .anyerr()?;
        let response = http_client
            .get("http://example.com/denied")
            .send()
            .await
            .anyerr()?;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // Short pause before reading the counters.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(metrics.requests_denied.get(), 1);
        assert_eq!(metrics.requests_completed.get(), 0);
        assert_eq!(metrics.requests_failed.get(), 0);
        assert_eq!(metrics.active_requests(), 0);
        drop(proxy_task);
        Ok(())
    }
}