mod common;
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use axum::Router;
use axum::response::sse::{Event, Sse};
use axum::routing::get;
use axum_server::tls_rustls::RustlsConfig;
use common::*;
use futures_util::stream::StreamExt;
use http_body_util::BodyExt;
use noxy::http::{Body, BoxError, HttpService, full_body};
use noxy::{CertificateAuthority, Proxy};
use rcgen::{CertificateParams, KeyPair};
use tokio::net::TcpListener;
/// Tower middleware that stamps a fixed header onto every response produced
/// by the wrapped service.
struct AddResponseHeader {
    /// Service whose responses receive the header.
    inner: HttpService,
    /// Name of the header to insert (replaces any existing value under this name).
    name: http::HeaderName,
    /// Value of the header to insert.
    value: http::HeaderValue,
}

impl tower::Service<http::Request<Body>> for AddResponseHeader {
    type Response = http::Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<http::Response<Body>, BoxError>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is entirely delegated to the wrapped service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: http::Request<Body>) -> Self::Future {
        let response_future = self.inner.call(req);
        // The boxed future must own its data, so take copies of the header pair.
        let (header_name, header_value) = (self.name.clone(), self.value.clone());
        Box::pin(async move {
            response_future.await.map(|mut response| {
                response.headers_mut().insert(header_name, header_value);
                response
            })
        })
    }
}
/// Verifies that a proxy with no layers relays the upstream body unchanged.
#[tokio::test]
async fn proxy_relays_data() {
    let upstream_addr = start_upstream("hello world").await;
    let proxy_addr = start_proxy(vec![]).await;
    let client = http_client(proxy_addr);
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let body = client
        .get(url)
        .send()
        .await
        .unwrap()
        .text()
        .await
        .unwrap();
    assert_eq!(body, "hello world");
}
/// Verifies that a single response-header layer is applied to proxied
/// responses without disturbing the body.
#[tokio::test]
async fn proxy_applies_layer() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = start_proxy(vec![Box::new(|inner: HttpService| {
        tower::util::BoxService::new(AddResponseHeader {
            inner,
            name: http::HeaderName::from_static("x-proxy"),
            value: http::HeaderValue::from_static("noxy"),
        })
    })])
    .await;
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let client = http_client(proxy_addr);
    let resp = client.get(url).send().await.unwrap();
    let tag = resp.headers().get("x-proxy").unwrap();
    assert_eq!(tag, "noxy");
    let body = resp.text().await.unwrap();
    assert_eq!(body, "hello");
}
/// Verifies that every layer in a multi-layer stack is applied: each one adds
/// its own response header.
#[tokio::test]
async fn proxy_chains_layers() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = start_proxy(vec![
        Box::new(|inner: HttpService| {
            tower::util::BoxService::new(AddResponseHeader {
                inner,
                name: http::HeaderName::from_static("x-first"),
                value: http::HeaderValue::from_static("1"),
            })
        }),
        Box::new(|inner: HttpService| {
            tower::util::BoxService::new(AddResponseHeader {
                inner,
                name: http::HeaderName::from_static("x-second"),
                value: http::HeaderValue::from_static("2"),
            })
        }),
    ])
    .await;
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let client = http_client(proxy_addr);
    let resp = client.get(url).send().await.unwrap();
    for (header, expected) in [("x-first", "1"), ("x-second", "2")] {
        assert_eq!(resp.headers().get(header).unwrap(), expected);
    }
}
/// Verifies that a plain (non-CONNECT) GET sent straight to the proxy's
/// listener is answered with 400 Bad Request.
#[tokio::test]
async fn proxy_rejects_non_connect() {
    let proxy_addr = start_proxy(vec![]).await;
    // Disable proxying in reqwest so the request hits the proxy as an origin server.
    let direct_client = reqwest::Client::builder().no_proxy().build().unwrap();
    let target = format!("http://{proxy_addr}/");
    let resp = direct_client.get(target).send().await.unwrap();
    assert_eq!(resp.status(), 400);
}
/// Tower middleware that injects a fixed header into every request before
/// handing it to the wrapped service.
struct AddRequestHeader {
    /// Service that receives the modified request.
    inner: HttpService,
    /// Name of the header to insert (replaces any existing value under this name).
    name: http::HeaderName,
    /// Value of the header to insert.
    value: http::HeaderValue,
}

impl tower::Service<http::Request<Body>> for AddRequestHeader {
    type Response = http::Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<http::Response<Body>, BoxError>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is entirely delegated to the wrapped service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
        let header_name = self.name.clone();
        let header_value = self.value.clone();
        let headers = req.headers_mut();
        headers.insert(header_name, header_value);
        self.inner.call(req)
    }
}
/// Verifies that a request-header layer's header reaches the upstream: the
/// upstream handler echoes the value of `x-injected` back as the body.
#[tokio::test]
async fn proxy_injects_request_header() {
    install_crypto_provider();
    // Self-signed "localhost" certificate for the upstream TLS server.
    let config = {
        let key_pair = KeyPair::generate().unwrap();
        let cert = CertificateParams::new(vec!["localhost".to_string()])
            .unwrap()
            .self_signed(&key_pair)
            .unwrap();
        RustlsConfig::from_der(vec![cert.der().to_vec()], key_pair.serialized_der().to_vec())
            .await
            .unwrap()
    };
    // Echo the injected header, or an empty body when it is absent.
    let app = Router::new().route(
        "/",
        get(|headers: axum::http::HeaderMap| async move {
            match headers.get("x-injected") {
                Some(v) => v.to_str().unwrap().to_string(),
                None => String::new(),
            }
        }),
    );
    let server_handle = axum_server::Handle::new();
    let listener_handle = server_handle.clone();
    tokio::spawn(async move {
        axum_server::bind_rustls("127.0.0.1:0".parse().unwrap(), config)
            .handle(server_handle)
            .serve(app.into_make_service())
            .await
            .unwrap();
    });
    let upstream_addr = listener_handle.listening().await.unwrap();
    let proxy_addr = start_proxy(vec![Box::new(|inner: HttpService| {
        tower::util::BoxService::new(AddRequestHeader {
            inner,
            name: http::HeaderName::from_static("x-injected"),
            value: http::HeaderValue::from_static("from-proxy"),
        })
    })])
    .await;
    let client = http_client(proxy_addr);
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let echoed = client
        .get(url)
        .send()
        .await
        .unwrap()
        .text()
        .await
        .unwrap();
    assert_eq!(echoed, "from-proxy");
}
/// Tower middleware that buffers each response body in full and adds an
/// `x-body-checksum` header holding the wrapping 8-bit sum of its bytes.
/// The buffered body is forwarded unchanged.
struct AddBodyChecksum {
    /// Service whose responses are checksummed.
    inner: HttpService,
}

impl tower::Service<http::Request<Body>> for AddBodyChecksum {
    type Response = http::Response<Body>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<http::Response<Body>, BoxError>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is entirely delegated to the wrapped service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: http::Request<Body>) -> Self::Future {
        let upstream = self.inner.call(req);
        Box::pin(async move {
            let response = upstream.await?;
            let (parts, body) = response.into_parts();
            // Collect the whole body so the checksum can be computed before forwarding.
            let bytes = body.collect().await?.to_bytes();
            let checksum = bytes.iter().copied().fold(0u8, u8::wrapping_add);
            let header_value = http::HeaderValue::from_str(&checksum.to_string()).unwrap();
            let mut rebuilt = http::Response::from_parts(parts, full_body(bytes));
            rebuilt
                .headers_mut()
                .insert(http::HeaderName::from_static("x-body-checksum"), header_value);
            Ok(rebuilt)
        })
    }
}
/// Verifies that a layer can buffer the full response body, derive a header
/// from it, and still forward the body intact.
#[tokio::test]
async fn proxy_layer_buffers_body_for_checksum() {
    let body = "hello world";
    // Wrapping 8-bit sum of the body bytes, as computed by `AddBodyChecksum`.
    let expected_checksum: u8 = body.bytes().fold(0u8, u8::wrapping_add);
    let upstream_addr = start_upstream(body).await;
    let proxy_addr = start_proxy(vec![Box::new(|inner: HttpService| {
        tower::util::BoxService::new(AddBodyChecksum { inner })
    })])
    .await;
    let client = http_client(proxy_addr);
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let resp = client.get(url).send().await.unwrap();
    let checksum_header = resp.headers().get("x-body-checksum").unwrap();
    assert_eq!(checksum_header, &expected_checksum.to_string());
    assert_eq!(resp.text().await.unwrap(), body);
}
/// Verifies that the proxy forwards a server-sent-events response chunk by
/// chunk instead of buffering the whole body.
///
/// The upstream emits `EVENT_COUNT` events, each after the first delayed by
/// `EVENT_DELAY`. A streaming proxy delivers them spread over roughly
/// `EVENT_DELAY * (EVENT_COUNT - 1)`. A proxy that buffers the body delivers
/// them all together once the upstream finishes.
#[tokio::test]
async fn proxy_streams_sse_incrementally() {
    const EVENT_COUNT: usize = 5;
    const EVENT_DELAY: Duration = Duration::from_millis(100);
    install_crypto_provider();
    let key_pair = KeyPair::generate().unwrap();
    let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
    let cert = params.self_signed(&key_pair).unwrap();
    let cert_der = cert.der().to_vec();
    let key_der = key_pair.serialized_der().to_vec();
    let config = RustlsConfig::from_der(vec![cert_der], key_der)
        .await
        .unwrap();
    let app = Router::new().route(
        "/sse",
        get(|| async {
            Sse::new(futures_util::stream::unfold(0usize, |i| async move {
                if i >= EVENT_COUNT {
                    return None;
                }
                if i > 0 {
                    tokio::time::sleep(EVENT_DELAY).await;
                }
                Some((
                    Ok::<_, Infallible>(Event::default().data(format!("event-{i}"))),
                    i + 1,
                ))
            }))
        }),
    );
    let handle = axum_server::Handle::new();
    let listener_handle = handle.clone();
    tokio::spawn(async move {
        axum_server::bind_rustls("127.0.0.1:0".parse().unwrap(), config)
            .handle(handle)
            .serve(app.into_make_service())
            .await
            .unwrap();
    });
    let upstream_addr = listener_handle.listening().await.unwrap();
    let proxy_addr = start_proxy(vec![]).await;
    let client = http_client(proxy_addr);
    let start = Instant::now();
    let resp = client
        .get(format!("https://localhost:{}/sse", upstream_addr.port()))
        .send()
        .await
        .unwrap();
    let mut stream = resp.bytes_stream();
    let mut events = Vec::new();
    // Accumulate raw bytes and decode only complete events, so a multi-byte
    // UTF-8 character split across two chunks is not mangled.
    let mut buf: Vec<u8> = Vec::new();
    while let Some(chunk) = stream.next().await {
        buf.extend_from_slice(&chunk.unwrap());
        while let Some(pos) = buf.windows(2).position(|w| w == b"\n\n") {
            let event_text = String::from_utf8_lossy(&buf[..pos]).into_owned();
            buf.drain(..pos + 2);
            if event_text.contains("data:") {
                events.push((event_text, start.elapsed()));
            }
        }
    }
    assert_eq!(events.len(), EVENT_COUNT, "should receive all SSE events");
    let total_stream_time = EVENT_DELAY * (EVENT_COUNT as u32 - 1);
    // Compare arrivals with each other rather than with `start`. `start`
    // precedes the CONNECT and both TLS handshakes, so an absolute bound on
    // the first event would be sensitive to setup latency on slow machines.
    // Arrival times are pushed in order from a monotonic clock, so the
    // subtraction cannot underflow.
    let first_arrival = events[0].1;
    let last_arrival = events[EVENT_COUNT - 1].1;
    let spread = last_arrival - first_arrival;
    assert!(
        spread >= total_stream_time / 2,
        "events arrived within {spread:?} of each other (first at {first_arrival:?}, last at \
         {last_arrival:?}), expected them spread over ~{total_stream_time:?} — \
         proxy may be buffering instead of streaming",
    );
    for (i, (event_text, _)) in events.iter().enumerate() {
        assert!(
            event_text.contains(&format!("event-{i}")),
            "event {i} should contain 'event-{i}', got: {event_text}"
        );
    }
}
/// Checks that the dummy CA can mint a leaf certificate for a host name.
/// The test checks only that both returned DER blobs are non-empty; it does
/// not parse them.
#[test]
fn certificate_authority_generates_valid_cert() {
    let ca = CertificateAuthority::from_pem_files("tests/dummy-cert.pem", "tests/dummy-key.pem")
        .unwrap();
    let (cert_der, key_der) = ca.generate_cert("example.com").unwrap();
    let key_bytes = key_der.secret_der();
    assert!(!cert_der.is_empty());
    assert!(!key_bytes.is_empty());
}
/// Verifies that a client which stalls part-way through its CONNECT request
/// is disconnected once the configured handshake timeout (200ms) elapses.
#[tokio::test]
async fn handshake_timeout_drops_slow_connection() {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    let upstream_addr = start_upstream("hello").await;
    let proxy = Proxy::builder()
        .ca_pem_files("tests/dummy-cert.pem", "tests/dummy-key.pem")
        .unwrap()
        .danger_accept_invalid_upstream_certs()
        .handshake_timeout(Duration::from_millis(200))
        .build()
        .unwrap();
    let proxy_addr = spawn_proxy(proxy).await;
    let mut stream = tokio::net::TcpStream::connect(proxy_addr).await.unwrap();
    // Send only the start of the request line: no HTTP version, no CRLF.
    stream
        .write_all(format!("CONNECT localhost:{}", upstream_addr.port()).as_bytes())
        .await
        .unwrap();
    // Stall for well past the 200ms handshake timeout.
    tokio::time::sleep(Duration::from_millis(500)).await;
    // Finish the request. A failed write means the proxy already dropped us,
    // which is the expected outcome.
    let result = stream.write_all(b" HTTP/1.1\r\n\r\n").await;
    if result.is_ok() {
        let mut buf = [0u8; 128];
        // Bound the read so that a proxy which neither closes the socket nor
        // responds fails the test instead of hanging it forever. A read error
        // (e.g. connection reset) is treated the same as EOF.
        let n = tokio::time::timeout(Duration::from_secs(5), stream.read(&mut buf))
            .await
            .expect("proxy neither closed the connection nor responded within 5s")
            .unwrap_or(0);
        assert_eq!(n, 0, "expected proxy to have closed the connection");
    }
}
/// Verifies that a generous handshake timeout (10s) does not interfere with an
/// ordinary proxied request.
#[tokio::test]
async fn handshake_timeout_allows_fast_connection() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = spawn_proxy(
        Proxy::builder()
            .ca_pem_files("tests/dummy-cert.pem", "tests/dummy-key.pem")
            .unwrap()
            .danger_accept_invalid_upstream_certs()
            .handshake_timeout(Duration::from_secs(10))
            .build()
            .unwrap(),
    )
    .await;
    let client = http_client(proxy_addr);
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let resp = client.get(url).send().await.unwrap();
    let body = resp.text().await.unwrap();
    assert_eq!(body, "hello");
}
/// Verifies that `max_connections` limits concurrency. With a limit of 2 and
/// an upstream that takes 300ms per request, three concurrent requests cannot
/// all finish within a single 300ms round.
#[tokio::test]
async fn max_connections_applies_backpressure() {
    install_crypto_provider();
    // Self-signed "localhost" certificate for the upstream TLS server.
    let config = {
        let key_pair = KeyPair::generate().unwrap();
        let cert = CertificateParams::new(vec!["localhost".to_string()])
            .unwrap()
            .self_signed(&key_pair)
            .unwrap();
        RustlsConfig::from_der(vec![cert.der().to_vec()], key_pair.serialized_der().to_vec())
            .await
            .unwrap()
    };
    // Each request takes 300ms, so the in-flight requests overlap.
    let slow_handler = || async {
        tokio::time::sleep(Duration::from_millis(300)).await;
        "ok"
    };
    let app = Router::new().route("/", get(slow_handler));
    let server_handle = axum_server::Handle::new();
    let listener_handle = server_handle.clone();
    tokio::spawn(async move {
        axum_server::bind_rustls("127.0.0.1:0".parse().unwrap(), config)
            .handle(server_handle)
            .serve(app.into_make_service())
            .await
            .unwrap();
    });
    let upstream_addr = listener_handle.listening().await.unwrap();
    let proxy = Proxy::builder()
        .ca_pem_files("tests/dummy-cert.pem", "tests/dummy-key.pem")
        .unwrap()
        .danger_accept_invalid_upstream_certs()
        .max_connections(2)
        .build()
        .unwrap();
    let proxy_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let proxy_addr = proxy_listener.local_addr().unwrap();
    tokio::spawn(async move {
        proxy.listen_on(proxy_listener).await.unwrap();
    });
    let client = http_client(proxy_addr);
    let start = Instant::now();
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let (r1, r2, r3) = tokio::join!(
        client.get(&url).send(),
        client.get(&url).send(),
        client.get(&url).send(),
    );
    let elapsed = start.elapsed();
    for response in [r1, r2, r3] {
        assert_eq!(response.unwrap().text().await.unwrap(), "ok");
    }
    assert!(
        elapsed >= Duration::from_millis(500),
        "3 requests with max_connections=2 and 300ms upstream should take ~600ms, took {elapsed:?}"
    );
}
/// Returns a proxy builder that has not yet been built or started.
///
/// The builder is configured to:
/// - use the dummy CA,
/// - skip upstream certificate verification,
/// - register two credential pairs: `admin`/`secret` and `user2`/`pass2`.
fn start_authenticated_proxy() -> noxy::ProxyBuilder {
    let builder = Proxy::builder()
        .ca_pem_files("tests/dummy-cert.pem", "tests/dummy-key.pem")
        .unwrap()
        .danger_accept_invalid_upstream_certs();
    [("admin", "secret"), ("user2", "pass2")]
        .into_iter()
        .fold(builder, |b, (user, pass)| b.credential(user, pass))
}
/// Builds a reqwest client with two settings:
/// - it tunnels `https://` requests through the proxy at `proxy_addr`, sending
///   basic proxy credentials `username`/`password`;
/// - it trusts the dummy CA, so certificates the proxy re-issues are accepted.
fn http_client_with_auth(
    proxy_addr: SocketAddr,
    username: &str,
    password: &str,
) -> reqwest::Client {
    let ca_cert = {
        let pem = std::fs::read("tests/dummy-cert.pem").unwrap();
        reqwest::tls::Certificate::from_pem(&pem).unwrap()
    };
    let authenticated_proxy = reqwest::Proxy::https(format!("http://{proxy_addr}"))
        .unwrap()
        .basic_auth(username, password);
    let builder = reqwest::Client::builder()
        .proxy(authenticated_proxy)
        .add_root_certificate(ca_cert);
    builder.build().unwrap()
}
/// Verifies that an authenticated proxy refuses a client that uses the plain
/// `http_client` helper instead of `http_client_with_auth`. The helper
/// presumably sends no proxy credentials; confirm in `common`.
#[tokio::test]
async fn proxy_auth_rejects_missing_credentials() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = spawn_proxy(start_authenticated_proxy().build().unwrap()).await;
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let outcome = http_client(proxy_addr).get(url).send().await;
    assert!(
        outcome.is_err(),
        "expected request without credentials to fail"
    );
}
/// Verifies that a known user name with the wrong password is refused.
#[tokio::test]
async fn proxy_auth_rejects_wrong_credentials() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = spawn_proxy(start_authenticated_proxy().build().unwrap()).await;
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let outcome = http_client_with_auth(proxy_addr, "admin", "wrong")
        .get(url)
        .send()
        .await;
    assert!(
        outcome.is_err(),
        "expected request with wrong credentials to fail"
    );
}
/// Verifies that the first registered credential pair (`admin`/`secret`) is
/// accepted and the request is relayed.
#[tokio::test]
async fn proxy_auth_accepts_valid_credentials() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = spawn_proxy(start_authenticated_proxy().build().unwrap()).await;
    let client = http_client_with_auth(proxy_addr, "admin", "secret");
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let body = client
        .get(url)
        .send()
        .await
        .unwrap()
        .text()
        .await
        .unwrap();
    assert_eq!(body, "hello");
}
/// Verifies that the second registered credential pair (`user2`/`pass2`) is
/// also accepted, so more than one credential can be configured.
#[tokio::test]
async fn proxy_auth_accepts_second_credential() {
    let upstream_addr = start_upstream("hello").await;
    let proxy_addr = spawn_proxy(start_authenticated_proxy().build().unwrap()).await;
    let client = http_client_with_auth(proxy_addr, "user2", "pass2");
    let url = format!("https://localhost:{}/", upstream_addr.port());
    let body = client
        .get(url)
        .send()
        .await
        .unwrap()
        .text()
        .await
        .unwrap();
    assert_eq!(body, "hello");
}
/// Verifies that a WebSocket session can be established and used through a
/// CONNECT tunnel whose TLS the proxy intercepts.
///
/// The upstream is a hand-rolled hyper HTTP/1.1 server, with upgrades
/// enabled, that serves an axum echo handler. The client connects in three
/// steps:
/// 1. It sends a raw CONNECT request.
/// 2. It performs a TLS handshake, trusting only `tests/dummy-cert.pem`.
/// 3. It performs the WebSocket handshake and exchanges one message.
#[tokio::test]
async fn proxy_relays_websocket() {
    use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
    use futures_util::{SinkExt, StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    /// Echoes text and binary frames back to the sender; other frame types
    /// are not echoed. Stops on end of stream, a receive error, or a failed send.
    async fn echo_ws(mut socket: WebSocket) {
        while let Some(Ok(msg)) = socket.recv().await {
            if matches!(msg, Message::Text(_) | Message::Binary(_))
                && socket.send(msg).await.is_err()
            {
                break;
            }
        }
    }
    install_crypto_provider();
    let key_pair = KeyPair::generate().unwrap();
    let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
    let cert = params.self_signed(&key_pair).unwrap();
    let cert_der = cert.der().clone();
    let key_der =
        rustls::pki_types::PrivateKeyDer::Pkcs8(key_pair.serialized_der().to_vec().into());
    let mut server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .unwrap();
    // Offer only HTTP/1.1, matching the http1 connection builder used below.
    server_config.alpn_protocols = vec![b"http/1.1".to_vec()];
    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config));
    let app = Router::new().route(
        "/ws",
        get(|ws: WebSocketUpgrade| async { ws.on_upgrade(echo_ws) }),
    );
    let upstream_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let upstream_addr = upstream_listener.local_addr().unwrap();
    tokio::spawn(async move {
        loop {
            let Ok((stream, _)) = upstream_listener.accept().await else {
                break;
            };
            let acceptor = acceptor.clone();
            let app = app.clone();
            tokio::spawn(async move {
                let Ok(tls_stream) = acceptor.accept(stream).await else {
                    return;
                };
                let io = hyper_util::rt::TokioIo::new(tls_stream);
                // `with_upgrades` lets the WebSocket upgrade take over this connection.
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        io,
                        hyper::service::service_fn(
                            move |req: hyper::Request<hyper::body::Incoming>| {
                                let app = app.clone();
                                async move {
                                    use tower::Service;
                                    let mut app = app;
                                    let req = req.map(axum::body::Body::new);
                                    Ok::<_, Infallible>(app.call(req).await.unwrap())
                                }
                            },
                        ),
                    )
                    .with_upgrades()
                    .await
                    .ok();
            });
        }
    });
    let proxy_addr = start_proxy(vec![]).await;
    let port = upstream_addr.port();
    let mut stream = tokio::net::TcpStream::connect(proxy_addr).await.unwrap();
    stream
        .write_all(
            format!("CONNECT localhost:{port} HTTP/1.1\r\nHost: localhost:{port}\r\n\r\n")
                .as_bytes(),
        )
        .await
        .unwrap();
    // Read until the blank line that ends the CONNECT response headers.
    let mut buf = [0u8; 256];
    let mut total = 0;
    loop {
        // Once the buffer is full, `read` into the empty slice returns Ok(0),
        // which would be misreported as the proxy closing the connection.
        assert!(
            total < buf.len(),
            "CONNECT response headers did not fit in {} bytes: {:?}",
            buf.len(),
            String::from_utf8_lossy(&buf[..total])
        );
        let n = stream.read(&mut buf[total..]).await.unwrap();
        assert!(n > 0, "proxy closed connection before 200 response");
        total += n;
        if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") {
            break;
        }
    }
    let resp = std::str::from_utf8(&buf[..total]).unwrap();
    assert!(
        resp.starts_with("HTTP/1.1 200"),
        "expected 200, got: {resp}"
    );
    // Inside the tunnel, the certificate presented for "localhost" must chain
    // to the dummy CA. The upstream's own self-signed cert is not trusted here.
    let ca_pem = std::fs::read("tests/dummy-cert.pem").unwrap();
    let ca_certs: Vec<_> = rustls_pemfile::certs(&mut &*ca_pem)
        .collect::<Result<_, _>>()
        .unwrap();
    let mut root_store = rustls::RootCertStore::empty();
    for cert in ca_certs {
        root_store.add(cert).unwrap();
    }
    let tls_config = rustls::ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config));
    let server_name = rustls::pki_types::ServerName::try_from("localhost").unwrap();
    let tls_stream = connector.connect(server_name, stream).await.unwrap();
    // TLS is already provided by `tls_stream`, so the handshake URL uses `ws://`.
    let (mut ws, _) =
        tokio_tungstenite::client_async(format!("ws://localhost:{port}/ws"), tls_stream)
            .await
            .unwrap();
    ws.send(tokio_tungstenite::tungstenite::Message::Text(
        "hello".into(),
    ))
    .await
    .unwrap();
    let msg = ws.next().await.unwrap().unwrap();
    assert_eq!(msg.into_text().unwrap(), "hello");
    ws.close(None).await.unwrap();
}
/// Verifies that the proxy can accept clients over TLS on its own listener
/// (`tls_identity`) while still intercepting the tunneled HTTPS request.
///
/// The client trusts two roots:
/// - the self-signed listener certificate;
/// - the dummy CA used for intercepted upstream connections.
#[tokio::test]
async fn forward_proxy_with_tls_listener() {
    /// Deletes the listed files when dropped, so the temporary PEM files are
    /// cleaned up even if an assertion below panics.
    struct RemoveOnDrop(Vec<std::path::PathBuf>);
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            for path in &self.0 {
                // Best-effort cleanup: the file may already be gone.
                let _ = std::fs::remove_file(path);
            }
        }
    }
    install_crypto_provider();
    let upstream_addr = start_upstream("hello tls forward").await;
    let key_pair = KeyPair::generate().unwrap();
    let params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
    let cert = params.self_signed(&key_pair).unwrap();
    let cert_pem = cert.pem();
    let key_pem = key_pair.serialize_pem();
    // Include the process id so concurrent runs of this test binary do not
    // overwrite or delete each other's files.
    let pid = std::process::id();
    let cert_path = std::env::temp_dir().join(format!("noxy-test-forward-tls-cert-{pid}.pem"));
    let key_path = std::env::temp_dir().join(format!("noxy-test-forward-tls-key-{pid}.pem"));
    let _cleanup = RemoveOnDrop(vec![cert_path.clone(), key_path.clone()]);
    std::fs::write(&cert_path, &cert_pem).unwrap();
    std::fs::write(&key_path, &key_pem).unwrap();
    let proxy = Proxy::builder()
        .ca_pem_files("tests/dummy-cert.pem", "tests/dummy-key.pem")
        .unwrap()
        .danger_accept_invalid_upstream_certs()
        .tls_identity(&cert_path, &key_path)
        .unwrap()
        .build()
        .unwrap();
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let proxy_addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        proxy.listen_on(listener).await.unwrap();
    });
    let listener_ca = reqwest::tls::Certificate::from_pem(cert_pem.as_bytes()).unwrap();
    let ca_pem = std::fs::read("tests/dummy-cert.pem").unwrap();
    let mitm_ca = reqwest::tls::Certificate::from_pem(&ca_pem).unwrap();
    // `https://` proxy URL: the client speaks TLS to the proxy itself.
    let client = reqwest::Client::builder()
        .proxy(reqwest::Proxy::all(format!("https://localhost:{}", proxy_addr.port())).unwrap())
        .add_root_certificate(listener_ca)
        .add_root_certificate(mitm_ca)
        .build()
        .unwrap();
    let resp = client
        .get(format!("https://localhost:{}/", upstream_addr.port()))
        .send()
        .await
        .unwrap();
    assert_eq!(resp.status(), 200);
    assert_eq!(resp.text().await.unwrap(), "hello tls forward");
}