use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use std::pin::Pin;
use std::future::Future;
use relay_core_lib::capture::source::{CaptureSource, IncomingConnection};
use relay_core_lib::start_proxy;
use relay_core_lib::interceptor::NoOpInterceptor;
use relay_core_lib::tls::CertificateAuthority;
use hyper_util::rt::TokioIo;
use http_body_util::BodyExt;
use relay_core_api::flow::FlowUpdate;
use relay_core_api::policy::ProxyPolicy;
/// Test-only [`CaptureSource`] that mimics transparent interception: every
/// connection accepted on `listener` is reported as if the client had
/// originally addressed it to the fixed `target_addr`.
struct MockTransparentSource {
    /// Socket the proxy accepts client connections on.
    listener: TcpListener,
    /// Destination reported as the original target of every accepted connection.
    target_addr: SocketAddr,
}

impl MockTransparentSource {
    /// Binds a listener on `addr`; connections accepted from it will be
    /// attributed to `target_addr`.
    ///
    /// # Panics
    ///
    /// Panics if `addr` cannot be bound.
    async fn new(addr: SocketAddr, target_addr: SocketAddr) -> Self {
        Self {
            listener: TcpListener::bind(addr)
                .await
                .expect("Failed to bind mock source"),
            target_addr,
        }
    }
}

impl CaptureSource for MockTransparentSource {
    type IO = TcpStream;

    /// Waits for the next client and tags the connection with the configured
    /// target address. Accept errors are propagated via `?`.
    fn accept(&mut self) -> Pin<Box<dyn Future<Output = relay_core_lib::error::Result<IncomingConnection<Self::IO>>> + Send + '_>> {
        // `SocketAddr` is `Copy`; take it up front so the future only needs the listener.
        let original_dst = self.target_addr;
        Box::pin(async move {
            let (stream, client_addr) = self.listener.accept().await?;
            let connection = IncomingConnection {
                stream,
                client_addr,
                target_addr: Some(original_dst),
            };
            Ok(connection)
        })
    }

    /// Returns the listener's bound address, or an empty list if it cannot be queried.
    fn listen_addrs(&self) -> Vec<SocketAddr> {
        self.listener.local_addr().ok().into_iter().collect()
    }
}
/// Verifies that a connection accepted by a transparent capture source is
/// forwarded to the destination the source reports as `target_addr`.
///
/// A minimal HTTP server stands in for the real destination; the mock source
/// reports it as the original target of every accepted connection, and the
/// test checks that the client receives that server's response via the proxy.
#[tokio::test]
async fn test_transparent_proxy_routing() {
    // Another test in this binary may already have installed the process-wide
    // provider, in which case installation fails; that is fine, so ignore it.
    let _ = rustls::crypto::ring::default_provider().install_default();

    // Upper bound on each network wait so a misbehaving proxy fails the test
    // instead of hanging it forever.
    let step_timeout = tokio::time::Duration::from_secs(10);

    // Minimal HTTP server: consumes the request head, replies with a fixed body.
    let echo_listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
        .await
        .expect("Failed to bind echo server");
    let echo_socket_addr = echo_listener.local_addr().unwrap();
    let echo_port = echo_socket_addr.port();
    tokio::spawn(async move {
        loop {
            if let Ok((mut socket, _)) = echo_listener.accept().await {
                tokio::spawn(async move {
                    use tokio::io::{AsyncReadExt, AsyncWriteExt};
                    // Read until the end of the request head (or EOF/error) before
                    // replying, so the response is never sent for a partially-read
                    // request; closing a socket with unread data may also make the
                    // kernel reset the connection instead of closing it cleanly.
                    let mut request = Vec::with_capacity(4096);
                    let mut chunk = [0u8; 4096];
                    loop {
                        match socket.read(&mut chunk).await {
                            Ok(0) | Err(_) => break,
                            Ok(n) => {
                                request.extend_from_slice(&chunk[..n]);
                                if request.windows(4).any(|w| w == b"\r\n\r\n") {
                                    break;
                                }
                            }
                        }
                    }
                    let response = "HTTP/1.1 200 OK\r\nContent-Length: 12\r\nConnection: close\r\n\r\nHello World!";
                    let _ = socket.write_all(response.as_bytes()).await;
                });
            }
        }
    });

    let source =
        MockTransparentSource::new(SocketAddr::from(([127, 0, 0, 1], 0)), echo_socket_addr).await;
    let proxy_socket_addr = source.listener.local_addr().unwrap();
    let interceptor = Arc::new(NoOpInterceptor {});
    let ca = Arc::new(CertificateAuthority::new().expect("Failed to create CA"));
    let (on_flow, mut flow_rx) = tokio::sync::mpsc::channel::<FlowUpdate>(10);
    // Drain flow updates so the bounded channel can never fill up and stall the
    // proxy (relevant if it awaits `send`; harmless otherwise).
    tokio::spawn(async move { while flow_rx.recv().await.is_some() {} });
    tokio::spawn(async move {
        let mut policy = ProxyPolicy::default();
        policy.transparent_enabled = true;
        // Named binding (not `_`) keeps the sender alive until `start_proxy` returns.
        let (_policy_tx, policy_rx) = tokio::sync::watch::channel(policy);
        start_proxy(source, on_flow, interceptor, ca, policy_rx, None, None).await.unwrap();
    });

    // No startup delay is needed: the source's listener is already bound, so
    // this connect is queued in the listen backlog until the proxy accepts it.
    let stream = tokio::net::TcpStream::connect(proxy_socket_addr)
        .await
        .expect("Failed to connect to proxy");
    let io = TokioIo::new(stream);
    let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.expect("Handshake failed");
    tokio::spawn(async move {
        if let Err(e) = conn.await {
            eprintln!("Connection failed: {:?}", e);
        }
    });

    let req = hyper::Request::builder()
        .uri("/")
        .header("Host", format!("127.0.0.1:{}", echo_port))
        .body(http_body_util::Full::new(bytes::Bytes::new()))
        .unwrap();
    let res = tokio::time::timeout(step_timeout, sender.send_request(req))
        .await
        .expect("Timed out waiting for proxied response")
        .expect("Request failed");
    assert_eq!(res.status(), 200);
    let body = tokio::time::timeout(step_timeout, res.collect())
        .await
        .expect("Timed out reading response body")
        .unwrap()
        .to_bytes();
    assert_eq!(body, "Hello World!");
}
/// Verifies that the proxy refuses to forward a transparent connection whose
/// reported original destination is the proxy's own listening address, and
/// answers with status 508 (Loop Detected) instead.
#[tokio::test]
async fn test_transparent_proxy_loop_detection() {
    // Another test in this binary may already have installed the process-wide
    // provider, in which case installation fails; that is fine, so ignore it.
    let _ = rustls::crypto::ring::default_provider().install_default();

    // If loop detection regresses, the proxy could keep dialing itself; bound
    // the wait so that shows up as a test failure rather than a hang.
    let response_timeout = tokio::time::Duration::from_secs(10);

    // Bind first so the source can report its own address as the target,
    // which is exactly the self-referencing route under test.
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))).await.unwrap();
    let proxy_socket_addr = listener.local_addr().unwrap();
    let source = MockTransparentSource { listener, target_addr: proxy_socket_addr };
    let interceptor = Arc::new(NoOpInterceptor {});
    let ca = Arc::new(CertificateAuthority::new().expect("Failed to create CA"));
    let (on_flow, mut flow_rx) = tokio::sync::mpsc::channel::<FlowUpdate>(10);
    // Drain flow updates so the bounded channel can never fill up and stall the
    // proxy (relevant if it awaits `send`; harmless otherwise).
    tokio::spawn(async move { while flow_rx.recv().await.is_some() {} });
    tokio::spawn(async move {
        let mut policy = ProxyPolicy::default();
        policy.transparent_enabled = true;
        // Named binding (not `_`) keeps the sender alive until `start_proxy` returns.
        let (_policy_tx, policy_rx) = tokio::sync::watch::channel(policy);
        start_proxy(source, on_flow, interceptor, ca, policy_rx, None, None).await.unwrap();
    });

    // No startup delay is needed: the listener is already bound, so this
    // connect is queued in the listen backlog until the proxy accepts it.
    let stream = tokio::net::TcpStream::connect(proxy_socket_addr)
        .await
        .expect("Failed to connect to proxy");
    let io = TokioIo::new(stream);
    let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.expect("Handshake failed");
    tokio::spawn(async move {
        if let Err(e) = conn.await {
            eprintln!("Connection failed: {:?}", e);
        }
    });

    let req = hyper::Request::builder()
        .uri("/")
        .header("Host", "example.com")
        .body(http_body_util::Full::new(bytes::Bytes::new()))
        .unwrap();
    let res = tokio::time::timeout(response_timeout, sender.send_request(req))
        .await
        .expect("Timed out waiting for loop-detection response")
        .expect("Request failed");
    assert_eq!(res.status(), 508);
}