use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::server::conn::http1 as server_http1;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use rustgate::cert::CertificateAuthority;
use rustgate::handler::LoggingHandler;
use rustgate::proxy::{handle_connection, ProxyState};
use std::sync::Arc;
use tempfile::TempDir;
use tokio::net::TcpListener;
type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
fn full_body(data: &'static str) -> BoxBody {
Full::new(Bytes::from(data))
.map_err(|_| unreachable!())
.boxed()
}
fn empty_body() -> BoxBody {
http_body_util::Empty::<Bytes>::new()
.map_err(|_| unreachable!())
.boxed()
}
async fn spawn_upstream(body: &'static str) -> (String, tokio::task::JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);
let svc = service_fn(move |_req: Request<hyper::body::Incoming>| {
let body = body;
async move {
let resp = Response::builder()
.status(200)
.body(full_body(body))
.unwrap();
Ok::<_, hyper::Error>(resp)
}
});
let _ = server_http1::Builder::new()
.serve_connection(io, svc)
.await;
});
(format!("127.0.0.1:{}", addr.port()), handle)
}
async fn spawn_proxy(mitm: bool) -> (String, tokio::task::JoinHandle<()>, TempDir) {
let tmp = TempDir::new().unwrap();
let ca = Arc::new(CertificateAuthority::with_dir(tmp.path().to_path_buf()).await.unwrap());
let state = Arc::new(ProxyState {
ca,
mitm,
intercept: false,
log_traffic: false,
handler: Arc::new(LoggingHandler),
});
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
let (stream, peer) = listener.accept().await.unwrap();
handle_connection(stream, peer, state).await;
});
(format!("127.0.0.1:{}", addr.port()), handle, tmp)
}
#[tokio::test]
async fn test_http_forward() {
let (upstream_addr, _upstream) = spawn_upstream("hello from upstream").await;
let (proxy_addr, _proxy, _tmp) = spawn_proxy(false).await;
let stream = tokio::net::TcpStream::connect(&proxy_addr).await.unwrap();
let io = TokioIo::new(stream);
let (mut sender, conn) =
hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::builder()
.method("GET")
.uri(format!("http://{upstream_addr}/test"))
.body(empty_body())
.unwrap();
let res = sender.send_request(req).await.unwrap();
assert_eq!(res.status(), 200);
let body = res.into_body().collect().await.unwrap().to_bytes();
assert_eq!(&body[..], b"hello from upstream");
}
#[tokio::test]
async fn test_connect_tunnel_returns_200() {
let (proxy_addr, _proxy, _tmp) = spawn_proxy(false).await;
let stream = tokio::net::TcpStream::connect(&proxy_addr).await.unwrap();
let io = TokioIo::new(stream);
let (mut sender, conn) =
hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::builder()
.method("CONNECT")
.uri("example.com:443")
.body(empty_body())
.unwrap();
let res = sender.send_request(req).await.unwrap();
assert_eq!(res.status(), 200);
}
#[tokio::test]
async fn test_ca_reload_preserves_identity() {
let tmp = TempDir::new().unwrap();
let dir = tmp.path().to_path_buf();
let ca1 = CertificateAuthority::with_dir(dir.clone()).await.unwrap();
let pem_after_create = std::fs::read_to_string(dir.join("ca.pem")).unwrap();
let ca2 = CertificateAuthority::with_dir(dir.clone()).await.unwrap();
let pem_after_reload = std::fs::read_to_string(dir.join("ca.pem")).unwrap();
assert_eq!(pem_after_create, pem_after_reload);
let leaf1 = ca1.get_or_create_cert("test.example.com").await.unwrap();
let leaf2 = ca2.get_or_create_cert("test.example.com").await.unwrap();
let issuer1 = extract_issuer_der(&leaf1.cert_der);
let issuer2 = extract_issuer_der(&leaf2.cert_der);
assert_eq!(issuer1, issuer2, "Leaf certs should have the same issuer");
}
fn extract_issuer_der(der: &[u8]) -> Vec<u8> {
let tbs_start = asn1_header_len(der);
let tbs_inner_start = tbs_start + asn1_header_len(&der[tbs_start..]);
let mut pos = tbs_inner_start;
for _ in 0..3 {
pos += asn1_element_len(&der[pos..]);
}
let issuer_len = asn1_element_len(&der[pos..]);
der[pos..pos + issuer_len].to_vec()
}
fn asn1_header_len(data: &[u8]) -> usize {
let len_byte = data[1];
if len_byte & 0x80 == 0 {
2
} else {
2 + (len_byte & 0x7f) as usize
}
}
fn asn1_element_len(data: &[u8]) -> usize {
let len_byte = data[1];
if len_byte & 0x80 == 0 {
2 + len_byte as usize
} else {
let num_bytes = (len_byte & 0x7f) as usize;
let mut len = 0usize;
for &b in &data[2..2 + num_bytes] {
len = (len << 8) | b as usize;
}
2 + num_bytes + len
}
}
#[test]
fn test_parse_host_port_ipv4() {
let (host, port) = rustgate::proxy::parse_host_port("example.com:8080");
assert_eq!(host, "example.com");
assert_eq!(port, 8080);
}
#[test]
fn test_parse_host_port_default() {
let (host, port) = rustgate::proxy::parse_host_port("example.com");
assert_eq!(host, "example.com");
assert_eq!(port, 443);
}
#[test]
fn test_parse_host_port_ipv6() {
let (host, port) = rustgate::proxy::parse_host_port("[::1]:443");
assert_eq!(host, "::1");
assert_eq!(port, 443);
}
#[test]
fn test_parse_host_port_ipv6_custom_port() {
let (host, port) = rustgate::proxy::parse_host_port("[2001:db8::1]:8080");
assert_eq!(host, "2001:db8::1");
assert_eq!(port, 8080);
}