#![cfg(feature = "rustls")]
use std::io::{Read, Write};
use std::net::{Shutdown, SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
#[cfg(feature = "h2")]
use std::sync::Once;
use std::sync::{Arc, Mutex};
use std::thread;
use async_net::TcpListener;
use async_tls::TlsAcceptor;
use futures_lite::future::block_on;
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
#[cfg(feature = "h2")]
use h2::server;
#[cfg(feature = "h2")]
use http::Response as HttpResponse;
use rcgen::generate_simple_self_signed;
use rustls::ServerConfig;
#[cfg(feature = "h2")]
use tokio::runtime::Builder as TokioRuntimeBuilder;
#[cfg(feature = "h2")]
use tokio_rustls::TlsAcceptor as TokioTlsAcceptor;
#[cfg(feature = "h2")]
use tokio_rustls::rustls::ServerConfig as TokioRustlsServerConfig;
#[cfg(feature = "h2")]
use tokio_rustls::rustls::crypto::aws_lc_rs;
#[cfg(feature = "h2")]
use tokio_rustls::rustls::pki_types::{
CertificateDer as TokioCertificateDer, PrivateKeyDer as TokioPrivateKeyDer, PrivatePkcs8KeyDer,
};
use ugi::{Client, Proxy};
/// Guards the one-time installation of aws-lc-rs as the default rustls crypto
/// provider used by the tokio-rustls (HTTP/2) test server.
#[cfg(feature = "h2")]
static TOKIO_RUSTLS_PROVIDER: Once = Once::new();
/// Drives anything convertible into a future (for example a `ugi` request
/// builder) to completion on the current thread and returns its output.
fn run<T>(value: T) -> T::Output
where
    T: std::future::IntoFuture,
{
    block_on(std::future::IntoFuture::into_future(value))
}
/// An HTTPS request routed through an unauthenticated HTTP `CONNECT` proxy
/// reaches the target, and the proxy never receives a `Proxy-Authorization`
/// header.
#[test]
fn http_connect_proxy_tunnels_https_without_auth() {
    let (base, target_addr, requests, https_handle) =
        block_on(spawn_https_request_server()).unwrap();
    let (proxy, connect_request, proxy_handle) =
        spawn_http_connect_proxy_server(target_addr, None).unwrap();
    let client = Client::builder()
        .proxy(proxy)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let response = run(client.get(format!("{base}/http-connect"))).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "ok");
    block_on(client.close()).unwrap();
    proxy_handle.join().unwrap();
    https_handle.join().unwrap();

    let connect_request = connect_request.lock().unwrap().clone().unwrap();
    assert!(connect_request.starts_with("CONNECT localhost:"));
    assert!(
        !connect_request
            .to_ascii_lowercase()
            .contains("proxy-authorization:")
    );

    let request = requests.lock().unwrap().clone().unwrap();
    assert!(request.starts_with("GET /http-connect HTTP/1.1\r\n"));
    // Header names are case-insensitive, so match against a lowercased copy.
    // The Host header must also name the tunnelled target's port.
    let expected_host = format!("\r\nhost: localhost:{}\r\n", target_addr.port());
    assert!(request.to_ascii_lowercase().contains(&expected_host));
}
/// Credentials supplied via `Proxy::http_with_auth` reach the proxy as a Basic
/// `Proxy-Authorization` header on the `CONNECT` request, and the tunnelled
/// HTTPS request still succeeds.
#[test]
fn http_connect_proxy_tunnels_https_with_auth() {
    let (base_url, target, captured_request, server_thread) =
        block_on(spawn_https_request_server()).unwrap();
    // "dXNlcjpwYXNz" is base64("user:pass").
    let (proxy, captured_connect, proxy_thread) =
        spawn_http_connect_proxy_server(target, Some("Basic dXNlcjpwYXNz")).unwrap();
    let authed_proxy = Proxy::http_with_auth(proxy.addr(), "user", "pass");
    let client = Client::builder()
        .proxy(authed_proxy)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();

    let url = format!("{base_url}/http-connect-auth");
    let response = run(client.get(url)).unwrap();
    let body = block_on(response.text()).unwrap();
    assert_eq!(body, "ok");

    block_on(client.close()).unwrap();
    proxy_thread.join().unwrap();
    server_thread.join().unwrap();

    let connect_head = captured_connect
        .lock()
        .unwrap()
        .clone()
        .unwrap()
        .to_ascii_lowercase();
    assert!(connect_head.contains("\r\nproxy-authorization: basic dxnlcjpwyxnz\r\n"));

    let head = captured_request.lock().unwrap().clone().unwrap();
    assert!(head.starts_with("GET /http-connect-auth HTTP/1.1\r\n"));
}
/// An HTTPS request routed through an unauthenticated SOCKS5 proxy reaches the
/// target with the expected request line and `Host` header.
#[test]
fn socks5_proxy_tunnels_https_without_auth() {
    let (base, target_addr, requests, https_handle) =
        block_on(spawn_https_request_server()).unwrap();
    let (proxy, proxy_handle) = spawn_socks5_proxy_server(target_addr, None).unwrap();
    let client = Client::builder()
        .proxy(proxy)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();
    let response = run(client.get(format!("{base}/socks5"))).unwrap();
    assert_eq!(block_on(response.text()).unwrap(), "ok");
    block_on(client.close()).unwrap();
    proxy_handle.join().unwrap();
    https_handle.join().unwrap();

    let request = requests.lock().unwrap().clone().unwrap();
    assert!(request.starts_with("GET /socks5 HTTP/1.1\r\n"));
    // Header names are case-insensitive, so match against a lowercased copy.
    // The Host header must also name the tunnelled target's port.
    let expected_host = format!("\r\nhost: localhost:{}\r\n", target_addr.port());
    assert!(request.to_ascii_lowercase().contains(&expected_host));
}
/// SOCKS5 username/password authentication succeeds and the HTTPS request is
/// tunnelled through to the target.
#[test]
fn socks5_proxy_tunnels_https_with_auth() {
    let (base_url, target, captured_request, server_thread) =
        block_on(spawn_https_request_server()).unwrap();
    let (proxy, proxy_thread) =
        spawn_socks5_proxy_server(target, Some(("user", "pass"))).unwrap();
    let client = Client::builder()
        .proxy(proxy)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();

    let url = format!("{base_url}/socks5-auth");
    let response = run(client.get(url)).unwrap();
    let body = block_on(response.text()).unwrap();
    assert_eq!(body, "ok");

    block_on(client.close()).unwrap();
    proxy_thread.join().unwrap();
    server_thread.join().unwrap();

    let head = captured_request.lock().unwrap().clone().unwrap();
    assert!(head.starts_with("GET /socks5-auth HTTP/1.1\r\n"));
}
/// An HTTP/2-only request travels through an HTTP `CONNECT` tunnel, is
/// answered over HTTP/2, and its custom header reaches the target.
#[cfg(feature = "h2")]
#[test]
fn http_connect_proxy_tunnels_https_over_http2() {
    let (base_url, target, captured, server_thread) = spawn_https_h2_request_server().unwrap();
    let (proxy, captured_connect, proxy_thread) =
        spawn_http_connect_proxy_server(target, None).unwrap();
    let client = Client::builder()
        .proxy(proxy)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();

    let request = client
        .get(format!("{base_url}/h2-http-connect"))
        .http2_only()
        .header("x-through-proxy", "1")
        .unwrap();
    let response = run(request).unwrap();
    assert_eq!(response.version(), ugi::Version::Http2);
    let body = block_on(response.text()).unwrap();
    assert_eq!(body, "ok");

    block_on(client.close()).unwrap();
    proxy_thread.join().unwrap();
    server_thread.join().unwrap();

    let connect_head = captured_connect.lock().unwrap().clone().unwrap();
    assert!(connect_head.starts_with("CONNECT localhost:"));

    let seen = captured.lock().unwrap().clone().unwrap();
    assert_eq!(seen.method, "GET");
    assert_eq!(seen.path, "/h2-http-connect");
    assert_eq!(seen.version, "HTTP/2.0");
    assert_eq!(seen.header_x_through_proxy.as_deref(), Some("1"));
}
/// An HTTP/2-only request travels through a SOCKS5 tunnel, is answered over
/// HTTP/2, and its custom header reaches the target.
#[cfg(feature = "h2")]
#[test]
fn socks5_proxy_tunnels_https_over_http2() {
    let (base_url, target, captured, server_thread) = spawn_https_h2_request_server().unwrap();
    let (proxy, proxy_thread) = spawn_socks5_proxy_server(target, None).unwrap();
    let client = Client::builder()
        .proxy(proxy)
        .danger_accept_invalid_certs(true)
        .build()
        .unwrap();

    let request = client
        .get(format!("{base_url}/h2-socks5"))
        .http2_only()
        .header("x-through-proxy", "1")
        .unwrap();
    let response = run(request).unwrap();
    assert_eq!(response.version(), ugi::Version::Http2);
    let body = block_on(response.text()).unwrap();
    assert_eq!(body, "ok");

    block_on(client.close()).unwrap();
    proxy_thread.join().unwrap();
    server_thread.join().unwrap();

    let seen = captured.lock().unwrap().clone().unwrap();
    assert_eq!(seen.method, "GET");
    assert_eq!(seen.path, "/h2-socks5");
    assert_eq!(seen.version, "HTTP/2.0");
    assert_eq!(seen.header_x_through_proxy.as_deref(), Some("1"));
}
/// Starts a single-shot HTTPS/1.1 server on an ephemeral loopback port.
///
/// The server uses a freshly generated self-signed certificate for
/// `localhost`. It accepts exactly one TLS connection, records that
/// connection's request head, and answers `200 OK` with the body `ok` and
/// `Connection: close`.
///
/// Returns `(base_url, listen_addr, captured_request_head, server_thread)`.
/// The captured head stays `None` until the request has been read.
///
/// # Errors
/// Fails with a `Transport` error if certificate generation, TLS
/// configuration, binding, or address lookup fails.
async fn spawn_https_request_server() -> ugi::Result<(
    String,
    SocketAddr,
    Arc<Mutex<Option<String>>>,
    thread::JoinHandle<()>,
)> {
    let certified = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to generate tls certificate",
            err,
        )
    })?;
    let chain = vec![rustls::Certificate(certified.cert.der().to_vec())];
    let private_key = rustls::PrivateKey(certified.signing_key.serialize_der());
    let tls_config = ServerConfig::builder()
        .with_safe_defaults()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to build tls server config",
                err,
            )
        })?;
    let acceptor = TlsAcceptor::from(Arc::new(tls_config));

    let listener = TcpListener::bind(("127.0.0.1", 0)).await.map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to bind https test server",
            err,
        )
    })?;
    let local_addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect https test server",
            err,
        )
    })?;

    let captured = Arc::new(Mutex::new(None));
    let server_captured = Arc::clone(&captured);
    let server_thread = thread::spawn(move || {
        block_on(async move {
            let (tcp, _peer) = listener.accept().await.unwrap();
            let mut tls = acceptor.accept(tcp).await.unwrap();
            let head = read_http_request_head(&mut tls).await.unwrap();
            *server_captured.lock().unwrap() = Some(head);
            tls.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok")
                .await
                .unwrap();
            tls.flush().await.unwrap();
        });
    });

    let base_url = format!("https://localhost:{}", local_addr.port());
    Ok((base_url, local_addr, captured, server_thread))
}
/// The parts of the HTTP/2 request received by `spawn_https_h2_request_server`
/// that the HTTP/2 proxy tests assert on.
#[cfg(feature = "h2")]
#[derive(Clone, Debug, Eq, PartialEq)]
struct CapturedH2Request {
    /// Request method as text, e.g. `"GET"`.
    method: String,
    /// Path of the request URI, as returned by `Uri::path`.
    path: String,
    /// `Debug` rendering of the request's HTTP version, e.g. `"HTTP/2.0"`.
    version: String,
    /// Value of the `x-through-proxy` header. `None` if the header is absent
    /// or `HeaderValue::to_str` rejects it.
    header_x_through_proxy: Option<String>,
}
/// Starts a single-shot HTTPS server that advertises only HTTP/2 (ALPN `h2`)
/// on an ephemeral loopback port.
///
/// The server runs on its own thread inside a current-thread Tokio runtime and
/// accepts one TLS connection. For the first request stream it captures the
/// method, path, version and `x-through-proxy` header, answers `200` with the
/// body `ok`, and then keeps polling the connection for up to 50ms before
/// shutting down.
///
/// Returns `(base_url, listen_addr, captured_request, server_thread)`. The
/// captured request stays `None` until the server has received it.
///
/// # Errors
/// Fails with a `Transport` error if certificate generation or TLS
/// configuration fails, or if the server thread exits before reporting its
/// bound address.
#[cfg(feature = "h2")]
fn spawn_https_h2_request_server() -> ugi::Result<(
    String,
    SocketAddr,
    Arc<Mutex<Option<CapturedH2Request>>>,
    thread::JoinHandle<()>,
)> {
    // Install the crypto provider at most once per process. The result is
    // ignored; presumably a provider may already be installed.
    TOKIO_RUSTLS_PROVIDER.call_once(|| {
        let _ = aws_lc_rs::default_provider().install_default();
    });
    let cert = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to generate tls certificate",
            err,
        )
    })?;
    let cert_der = TokioCertificateDer::from(cert.cert.der().to_vec());
    // The serialized signing key is wrapped as PKCS#8 DER.
    let key_der =
        TokioPrivateKeyDer::from(PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der()));
    let mut server_config = TokioRustlsServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to build h2 tls server config",
                err,
            )
        })?;
    // Offer only "h2" during ALPN so the client must speak HTTP/2.
    server_config.alpn_protocols = vec![b"h2".to_vec()];
    let acceptor = TokioTlsAcceptor::from(Arc::new(server_config));
    let request = Arc::new(Mutex::new(None));
    let task_request = Arc::clone(&request);
    // The listener is bound inside the runtime thread, so its address is sent
    // back to this thread over a channel.
    let (addr_tx, addr_rx) = std::sync::mpsc::sync_channel(1);
    let handle = thread::spawn(move || {
        let runtime = TokioRuntimeBuilder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async move {
            let listener = tokio::net::TcpListener::bind(("127.0.0.1", 0))
                .await
                .unwrap();
            let addr = listener.local_addr().unwrap();
            addr_tx.send(addr).unwrap();
            let (stream, _) = listener.accept().await.unwrap();
            let tls = acceptor.accept(stream).await.unwrap();
            let mut connection = server::handshake(tls).await.unwrap();
            // Only the first request stream on the connection is served.
            if let Some(result) = connection.accept().await {
                let (request_head, mut respond) = result.unwrap();
                let captured = CapturedH2Request {
                    method: request_head.method().as_str().to_owned(),
                    path: request_head.uri().path().to_owned(),
                    version: format!("{:?}", request_head.version()),
                    header_x_through_proxy: request_head
                        .headers()
                        .get("x-through-proxy")
                        .and_then(|value| value.to_str().ok())
                        .map(str::to_owned),
                };
                *task_request.lock().unwrap() = Some(captured);
                let response = HttpResponse::builder()
                    .status(200)
                    .header("content-length", "2")
                    .body(())
                    .unwrap();
                let mut send = respond.send_response(response, false).unwrap();
                // The `true` flag ends the stream with this data frame.
                send.send_data(bytes::Bytes::from_static(b"ok"), true)
                    .unwrap();
                // Keep driving the connection briefly, presumably so the queued
                // response frames are flushed before it is dropped.
                // NOTE(review): a fixed 50ms grace period may be tight on slow machines.
                let _ =
                    tokio::time::timeout(std::time::Duration::from_millis(50), connection.accept())
                        .await;
            }
        });
    });
    let addr = addr_rx.recv().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to receive h2 test server address",
            err,
        )
    })?;
    Ok((
        format!("https://localhost:{}", addr.port()),
        addr,
        request,
        handle,
    ))
}
/// Starts a single-connection HTTP `CONNECT` proxy on an ephemeral loopback port.
///
/// The proxy thread accepts one client and asserts that its request line is
/// `CONNECT localhost:<target port> HTTP/1.1`. When `expected_proxy_auth` is
/// given, it also asserts that the request carries a matching
/// `Proxy-Authorization` header, compared case-insensitively.
///
/// The thread then stores the raw request head, replies
/// `200 Connection Established`, dials `target_addr`, and relays bytes both
/// ways until the tunnel closes.
///
/// Returns `(proxy, captured_connect_head, proxy_thread)`.
///
/// # Errors
/// Fails with a `Transport` error if the listener cannot be bound or inspected.
fn spawn_http_connect_proxy_server(
    target_addr: SocketAddr,
    expected_proxy_auth: Option<&'static str>,
) -> ugi::Result<(Proxy, Arc<Mutex<Option<String>>>, thread::JoinHandle<()>)> {
    let listener = StdTcpListener::bind(("127.0.0.1", 0)).map_err(|err| {
        ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to bind http proxy", err)
    })?;
    let proxy_addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect http proxy",
            err,
        )
    })?;

    let captured = Arc::new(Mutex::new(None));
    let proxy_captured = Arc::clone(&captured);
    let proxy_thread = thread::spawn(move || {
        let (mut client_side, _) = listener.accept().unwrap();
        let head = read_http_request_head_blocking(&mut client_side).unwrap();

        let expected_line = format!("CONNECT localhost:{} HTTP/1.1\r\n", target_addr.port());
        assert!(head.starts_with(&expected_line));
        if let Some(auth) = expected_proxy_auth {
            let expected_header = format!(
                "\r\nproxy-authorization: {}\r\n",
                auth.to_ascii_lowercase()
            );
            assert!(head.to_ascii_lowercase().contains(&expected_header));
        }
        *proxy_captured.lock().unwrap() = Some(head);

        client_side
            .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
            .unwrap();
        client_side.flush().unwrap();
        let target_side = StdTcpStream::connect(target_addr).unwrap();
        relay_bidirectional(client_side, target_side).unwrap();
    });

    Ok((Proxy::http(proxy_addr), captured, proxy_thread))
}
/// Starts a single-connection SOCKS5 proxy on an ephemeral loopback port.
///
/// The proxy thread accepts one client and walks the SOCKS5 handshake,
/// asserting at each step:
/// * the greeting is version 5 and offers username/password auth (`0x02`)
///   when `expected_auth` is `Some`, or no auth (`0x00`) otherwise;
/// * with auth, the sub-negotiation (version `0x01`) carries exactly the
///   expected username and password;
/// * the request header is `05 01 00 03` (CONNECT to a domain name), the
///   domain is `localhost`, and the port is the target's port.
///
/// It then replies success with a bound address of `127.0.0.1:<target port>`,
/// dials `target_addr`, and relays bytes both ways until the tunnel closes.
///
/// Returns the matching `Proxy` (with credentials when `expected_auth` is set)
/// and the proxy thread's handle.
///
/// # Errors
/// Fails with a `Transport` error if the listener cannot be bound or inspected.
fn spawn_socks5_proxy_server(
    target_addr: SocketAddr,
    expected_auth: Option<(&'static str, &'static str)>,
) -> ugi::Result<(Proxy, thread::JoinHandle<()>)> {
    let listener = StdTcpListener::bind(("127.0.0.1", 0)).map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to bind socks5 proxy",
            err,
        )
    })?;
    let addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect socks5 proxy",
            err,
        )
    })?;

    let proxy_thread = thread::spawn(move || {
        let (mut conn, _) = listener.accept().unwrap();
        // Reads exactly `len` bytes from the client, panicking on short reads.
        let read_n = |conn: &mut StdTcpStream, len: usize| -> Vec<u8> {
            let mut buf = vec![0_u8; len];
            conn.read_exact(&mut buf).unwrap();
            buf
        };

        let greeting = read_n(&mut conn, 2);
        assert_eq!(greeting[0], 0x05);
        let methods = read_n(&mut conn, usize::from(greeting[1]));
        if let Some((username, password)) = expected_auth {
            assert!(methods.contains(&0x02));
            conn.write_all(&[0x05, 0x02]).unwrap();
            let auth_header = read_n(&mut conn, 2);
            assert_eq!(auth_header[0], 0x01);
            let username_buf = read_n(&mut conn, usize::from(auth_header[1]));
            let password_len = read_n(&mut conn, 1)[0];
            let password_buf = read_n(&mut conn, usize::from(password_len));
            assert_eq!(String::from_utf8_lossy(&username_buf), username);
            assert_eq!(String::from_utf8_lossy(&password_buf), password);
            conn.write_all(&[0x01, 0x00]).unwrap();
        } else {
            assert!(methods.contains(&0x00));
            conn.write_all(&[0x05, 0x00]).unwrap();
        }

        let request_header = read_n(&mut conn, 4);
        assert_eq!(request_header, [0x05_u8, 0x01, 0x00, 0x03]);
        let domain_len = read_n(&mut conn, 1)[0];
        let domain = read_n(&mut conn, usize::from(domain_len));
        let port_bytes = read_n(&mut conn, 2);
        assert_eq!(String::from_utf8_lossy(&domain), "localhost");
        assert_eq!(
            u16::from_be_bytes([port_bytes[0], port_bytes[1]]),
            target_addr.port()
        );

        // Success reply: version 5, status 0, reserved, IPv4 127.0.0.1:<target port>.
        let [port_hi, port_lo] = target_addr.port().to_be_bytes();
        let reply = [0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, port_hi, port_lo];
        conn.write_all(&reply).unwrap();
        conn.flush().unwrap();

        let upstream = StdTcpStream::connect(target_addr).unwrap();
        relay_bidirectional(conn, upstream).unwrap();
    });

    let proxy = if let Some((username, password)) = expected_auth {
        Proxy::socks5_with_auth(addr, username, password)
    } else {
        Proxy::socks5(addr)
    };
    Ok((proxy, proxy_thread))
}
/// Reads an HTTP/1.x request head from an async stream, through the blank line
/// that ends the headers, and returns it as lossily decoded text.
///
/// Data is read in chunks of up to 1024 bytes. Any bytes the peer sent after
/// the head within the same chunk are therefore included in the result.
///
/// # Errors
/// Returns a `Transport` error if a read fails or the stream reaches EOF before
/// the `\r\n\r\n` terminator.
async fn read_http_request_head<S>(stream: &mut S) -> ugi::Result<String>
where
    S: AsyncRead + Unpin,
{
    let mut head: Vec<u8> = Vec::new();
    let mut chunk = [0_u8; 1024];
    while !head.windows(4).any(|w| w == b"\r\n\r\n") {
        let n = stream.read(&mut chunk).await.map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to read https request head",
                err,
            )
        })?;
        if n == 0 {
            return Err(ugi::Error::new(
                ugi::ErrorKind::Transport,
                "unexpected eof while reading https request head",
            ));
        }
        head.extend_from_slice(&chunk[..n]);
    }
    Ok(String::from_utf8_lossy(&head).into_owned())
}
/// Reads an HTTP/1.x request head (request line plus headers, through the
/// terminating blank line) from `stream` and returns it as lossily decoded text.
///
/// The stream is consumed one byte at a time so nothing past the `\r\n\r\n`
/// terminator is taken off the socket. For a `CONNECT` request, any bytes that
/// follow belong to the tunnel and must be left for the relay.
///
/// # Errors
/// Returns an error of kind `ErrorKind::UnexpectedEof` if the peer closes the
/// connection before the terminator arrives. Any other read error is
/// propagated.
fn read_http_request_head_blocking(stream: &mut StdTcpStream) -> std::io::Result<String> {
    let mut buffer = Vec::new();
    let mut byte = [0_u8; 1];
    while !buffer.ends_with(b"\r\n\r\n") {
        // `read_exact` retries on `Interrupted` and reports early EOF as
        // `UnexpectedEof`, matching the async variant's refusal of a partial head.
        stream.read_exact(&mut byte)?;
        buffer.push(byte[0]);
    }
    Ok(String::from_utf8_lossy(&buffer).into_owned())
}
/// Pumps bytes between `upstream` (the proxy's client-facing socket) and
/// `downstream` (the socket to the target) until both directions are done.
///
/// Client→target copying runs on a helper thread; target→client copying runs on
/// the calling thread. When a direction reaches EOF, the write half of its
/// destination is shut down so that peer observes EOF too. Errors accepted by
/// [`is_tunnel_close_error`] are treated as an ordinary end of stream.
///
/// # Errors
/// Returns one of the following:
/// * a socket-setup error;
/// * an unexpected error from the target→client direction;
/// * the helper thread's unexpected error, once target→client has finished
///   cleanly.
///
/// # Panics
/// Panics if the helper thread panics.
fn relay_bidirectional(
    mut upstream: StdTcpStream,
    mut downstream: StdTcpStream,
) -> std::io::Result<()> {
    upstream.set_nodelay(true)?;
    downstream.set_nodelay(true)?;
    let mut upstream_reader = upstream.try_clone()?;
    let mut downstream_writer = downstream.try_clone()?;
    let relay = thread::spawn(move || -> std::io::Result<()> {
        let copied = copy_until_closed(&mut upstream_reader, &mut downstream_writer);
        // Propagate the client's EOF (or failure) to the target in every case.
        let _ = downstream_writer.shutdown(Shutdown::Write);
        copied
    });
    if let Err(err) = copy_until_closed(&mut downstream, &mut upstream) {
        // Tear down both sockets so the helper thread is not left parked on a
        // half-open tunnel. It is deliberately not joined here: if shutdown
        // does not interrupt a blocked read on some platform, joining would
        // turn this error into a hang.
        let _ = upstream.shutdown(Shutdown::Both);
        let _ = downstream.shutdown(Shutdown::Both);
        return Err(err);
    }
    let _ = upstream.shutdown(Shutdown::Write);
    relay.join().expect("tunnel relay thread panicked")
}

/// Copies `reader` into `writer` until EOF. Tunnel-teardown errors (see
/// [`is_tunnel_close_error`]) count as a normal end of stream.
fn copy_until_closed<R, W>(reader: &mut R, writer: &mut W) -> std::io::Result<()>
where
    R: Read + ?Sized,
    W: Write + ?Sized,
{
    match std::io::copy(reader, writer) {
        Err(err) if !is_tunnel_close_error(&err) => Err(err),
        _ => Ok(()),
    }
}

/// Returns `true` for I/O error kinds that mean one side of the tunnel went
/// away (broken pipe, connection reset, unexpected EOF, not connected), as
/// opposed to a genuine relay failure.
fn is_tunnel_close_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::BrokenPipe
            | std::io::ErrorKind::ConnectionReset
            | std::io::ErrorKind::UnexpectedEof
            | std::io::ErrorKind::NotConnected
    )
}