#![cfg(feature = "default-client")]
use aws_smithy_async::time::SystemTimeSource;
use aws_smithy_http_client::{proxy::ProxyConfig, tls, Connector};
use aws_smithy_runtime_api::client::http::{
http_client_fn, HttpClient, HttpConnector, HttpConnectorSettings, SharedHttpConnector,
};
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use base64::Engine;
use http_1x::{Request, Response, StatusCode};
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
/// In-process HTTP/1 server used by these tests as a stand-in for a proxy (or a direct origin).
///
/// Dropping the value fires `shutdown_tx`, which makes the background accept loop exit.
#[derive(Debug)]
struct MockProxyServer {
/// Reference-counted token used purely for counting: one clone is held by the accept loop and
/// one by every connection task, so the strong count reflects the number of live connections.
conn_count: Arc<()>,
/// Local address the listener was bound to (an OS-assigned port on 127.0.0.1).
addr: SocketAddr,
/// Sender half of the shutdown signal; taken and fired in `Drop`.
shutdown_tx: Option<oneshot::Sender<()>>,
/// Every request seen by the server, in arrival order, shared with the connection tasks.
request_log: Arc<Mutex<Vec<RecordedRequest>>>,
}
/// Snapshot of the parts of an incoming request that the tests assert on.
#[derive(Debug, Clone)]
struct RecordedRequest {
/// Request method rendered as a string (e.g. "GET", "CONNECT").
method: String,
/// Request target exactly as received, rendered as a string.
uri: String,
/// Header name -> value. Values that are not valid visible ASCII are stored as an empty string,
/// and a repeated header name keeps only the last value collected.
headers: HashMap<String, String>,
}
impl MockProxyServer {
async fn new<F>(handler: F) -> Self
where
F: Fn(RecordedRequest) -> Response<String> + Send + Sync + 'static,
{
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let request_log = Arc::new(Mutex::new(Vec::new()));
let request_log_clone = request_log.clone();
let conn_count = Arc::new(());
let server_conn_count = conn_count.clone();
let handler = Arc::new(handler);
tokio::spawn(async move {
let mut shutdown_rx = shutdown_rx;
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, _)) => {
let io = TokioIo::new(stream);
let handler = handler.clone();
let request_log = request_log_clone.clone();
let stream_conn_count = server_conn_count.clone();
tokio::spawn(async move {
let _stream_conn_count = stream_conn_count;
let service = service_fn(move |req: Request<Incoming>| {
let handler = handler.clone();
let request_log = request_log.clone();
async move {
let recorded = RecordedRequest {
method: req.method().to_string(),
uri: req.uri().to_string(),
headers: req.headers().iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect(),
};
request_log.lock().unwrap().push(recorded.clone());
let response = handler(recorded);
let (parts, body) = response.into_parts();
let hyper_response = Response::from_parts(parts, body);
Ok::<_, Infallible>(hyper_response)
}
});
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.await
{
eprintln!("Mock proxy server connection error: {}", err);
}
});
}
Err(_) => break,
}
}
_ = &mut shutdown_rx => {
break;
}
}
}
});
Self {
addr,
shutdown_tx: Some(shutdown_tx),
request_log,
conn_count,
}
}
fn conn_count(&self) -> usize {
Arc::strong_count(&self.conn_count)
.checked_sub(2)
.expect("de-count 2 refs")
}
async fn with_response(status: StatusCode, body: &str) -> Self {
let body = body.to_string();
Self::new(move |_req| {
Response::builder()
.status(status)
.body(body.clone())
.unwrap()
})
.await
}
async fn with_auth_validation(expected_user: &str, expected_pass: &str) -> Self {
let expected_auth = format!(
"Basic {}",
base64::prelude::BASE64_STANDARD.encode(format!("{}:{}", expected_user, expected_pass))
);
Self::new(move |req| {
if let Some(auth_header) = req.headers.get("proxy-authorization") {
if auth_header == &expected_auth {
Response::builder()
.status(StatusCode::OK)
.body("authenticated".to_string())
.unwrap()
} else {
Response::builder()
.status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
.body("invalid credentials".to_string())
.unwrap()
}
} else {
Response::builder()
.status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
.header("proxy-authenticate", "Basic realm=\"proxy\"")
.body("authentication required".to_string())
.unwrap()
}
})
.await
}
fn addr(&self) -> SocketAddr {
self.addr
}
fn requests(&self) -> Vec<RecordedRequest> {
self.request_log.lock().unwrap().clone()
}
}
impl Drop for MockProxyServer {
    /// Signals the background accept loop to stop.
    fn drop(&mut self) {
        let Some(shutdown) = self.shutdown_tx.take() else {
            return;
        };
        // The receiver may already be gone (accept loop exited); that is not an error here.
        let _ = shutdown.send(());
    }
}
/// Runs `test` with the given environment variables set, then restores the previous values.
///
/// Callers are serialized through a process-wide mutex so that concurrently running tests do not
/// observe each other's environment changes.
///
/// The previous values are restored by a drop guard, so they are put back even when `test`
/// panics, and a mutex poisoned by such a panic is recovered instead of failing every later
/// caller. Original values are captured with `var_os`, so values that are not valid Unicode are
/// restored rather than removed.
///
/// * `vars` - `(name, value)` pairs to set for the duration of `test`.
/// * `test` - closure producing the future to run while the variables are set.
///
/// Returns whatever the future produced by `test` resolves to.
// The std mutex guard is held across `.await` on purpose: it is what serializes the tests.
#[allow(clippy::await_holding_lock)]
async fn with_env_vars<F, Fut, R>(vars: &[(&str, &str)], test: F) -> R
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = R>,
{
    /// Puts the saved values back when dropped (normal completion or unwinding).
    struct RestoreEnv<'a> {
        saved: Vec<(&'a str, Option<std::ffi::OsString>)>,
    }

    impl Drop for RestoreEnv<'_> {
        fn drop(&mut self) {
            for (key, original) in self.saved.drain(..) {
                match original {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    static ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

    // The mutex protects no data, so a poisoned lock is still perfectly usable.
    let _guard = ENV_MUTEX
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    // Declared after `_guard` so it is dropped first: the environment is restored before the
    // lock is released.
    let _restore = RestoreEnv {
        saved: vars
            .iter()
            .map(|(key, _)| (*key, std::env::var_os(key)))
            .collect(),
    };

    for (key, value) in vars {
        std::env::set_var(key, value);
    }

    test().await
}
/// Sends a single GET to `target_url` using `proxy_config` and a 90 second pool idle timeout.
///
/// Returns the response status and the body decoded as UTF-8; the connector used for the
/// request is discarded.
async fn make_http_request_through_proxy(
    proxy_config: ProxyConfig,
    target_url: &str,
) -> Result<(StatusCode, String), Box<dyn std::error::Error + Send + Sync>> {
    let idle_timeout = Some(Duration::from_secs(90));
    let (status, body, _connector) =
        make_http_request_through_proxy_with_pool_timeout(proxy_config, idle_timeout, target_url)
            .await?;
    Ok((status, body))
}
async fn make_http_request_through_proxy_with_pool_timeout(
proxy_config: ProxyConfig,
pool_idle_timeout: Option<Duration>,
target_url: &str,
) -> Result<(StatusCode, String, SharedHttpConnector), Box<dyn std::error::Error + Send + Sync>> {
let http_client = http_client_fn(move |settings, _components| {
let connector = Connector::builder()
.proxy_config(proxy_config.clone())
.pool_idle_timeout(pool_idle_timeout)
.connector_settings(settings.clone())
.build_http();
aws_smithy_runtime_api::client::http::SharedHttpConnector::new(connector)
});
let connector_settings = HttpConnectorSettings::builder().build();
let runtime_components = RuntimeComponentsBuilder::for_tests()
.with_time_source(Some(SystemTimeSource::new()))
.build()
.unwrap();
let http_connector = http_client.http_connector(&connector_settings, &runtime_components);
let request = HttpRequest::get(target_url)
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
let response = http_connector.call(request).await?;
let status = response.status();
let body_bytes = response.into_body().collect().await?.to_bytes();
let body_string = String::from_utf8(body_bytes.to_vec())?;
Ok((status.into(), body_string, http_connector))
}
/// Checks that the connection to the proxy is still counted just before the configured pool idle
/// timeout elapses and is no longer counted shortly after it.
///
/// Runs on real (unpaused) time, so it takes a little over `TIMEOUT` to finish.
#[tokio::test(start_paused = false)]
async fn test_http_proxy_connection_pool_timeout() {
    const TIMEOUT: Duration = Duration::from_secs(10);
    let target_url = "http://aws.amazon.com/api/data";

    let mock_proxy = MockProxyServer::new(|req| {
        assert_eq!(req.method, "GET");
        assert_eq!(req.uri, "http://aws.amazon.com/api/data");
        Response::builder()
            .status(StatusCode::OK)
            .body("proxied response from mock server".to_string())
            .unwrap()
    })
    .await;
    assert_eq!(mock_proxy.conn_count(), 0);

    tracing::info!("Start!");
    let proxy_config = ProxyConfig::http(format!("http://{}", mock_proxy.addr())).unwrap();
    let start = tokio::time::Instant::now();

    // `_connector` must stay bound for the rest of the test: dropping the connector early would
    // make the later connection-count checks meaningless.
    let (status, body, _connector) =
        make_http_request_through_proxy_with_pool_timeout(proxy_config, Some(TIMEOUT), target_url)
            .await
            .expect("HTTP request through proxy should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "proxied response from mock server");

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1);
    assert_eq!(recorded[0].method, "GET");
    assert_eq!(recorded[0].uri, target_url);

    // Right after the request the connection is open...
    assert_eq!(mock_proxy.conn_count(), 1);

    // ...one second before the timeout it is still open...
    let just_before_timeout = start + TIMEOUT - Duration::from_secs(1);
    tokio::time::sleep_until(just_before_timeout).await;
    assert_eq!(mock_proxy.conn_count(), 1);

    // ...and a few seconds later it has been closed.
    tokio::time::sleep(Duration::from_secs(3)).await;
    assert_eq!(mock_proxy.conn_count(), 0);
}
/// A plain HTTP GET configured with an HTTP proxy reaches the mock proxy with the full target
/// URL, and the proxy's response is returned to the caller.
#[tokio::test]
async fn test_http_proxy_basic_request() {
    let target_url = "http://aws.amazon.com/api/data";
    let mock_proxy = MockProxyServer::new(|req| {
        assert_eq!(req.method, "GET");
        assert_eq!(req.uri, "http://aws.amazon.com/api/data");
        Response::builder()
            .status(StatusCode::OK)
            .body("proxied response from mock server".to_string())
            .unwrap()
    })
    .await;

    let proxy_config = ProxyConfig::http(format!("http://{}", mock_proxy.addr())).unwrap();
    let (status, body) = make_http_request_through_proxy(proxy_config, target_url)
        .await
        .expect("HTTP request through proxy should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "proxied response from mock server");

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1);
    assert_eq!(recorded[0].method, "GET");
    assert_eq!(recorded[0].uri, target_url);
}
/// Credentials supplied via `with_basic_auth` are sent to the proxy as a Basic
/// `proxy-authorization` header and accepted by the validating mock.
#[tokio::test]
async fn test_proxy_authentication() {
    let mock_proxy = MockProxyServer::with_auth_validation("testuser", "testpass").await;
    let proxy_url = format!("http://{}", mock_proxy.addr());
    let proxy_config = ProxyConfig::http(proxy_url)
        .unwrap()
        .with_basic_auth("testuser", "testpass");

    let target_url = "http://aws.amazon.com/protected/resource";
    let (status, body) = make_http_request_through_proxy(proxy_config, target_url)
        .await
        .expect("Authenticated proxy request should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "authenticated");

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1);

    let encoded = base64::prelude::BASE64_STANDARD.encode("testuser:testpass");
    let expected_auth = format!("Basic {}", encoded);
    assert_eq!(
        recorded[0].headers.get("proxy-authorization"),
        Some(&expected_auth)
    );
}
/// Credentials embedded in the proxy URL (`user:pass@host`) are sent to the proxy as a Basic
/// `proxy-authorization` header.
#[tokio::test]
async fn test_proxy_url_embedded_auth() {
    let mock_proxy = MockProxyServer::with_auth_validation("urluser", "urlpass").await;
    let proxy_config =
        ProxyConfig::http(format!("http://urluser:urlpass@{}", mock_proxy.addr())).unwrap();

    let target_url = "http://aws.amazon.com/api/test";
    let (status, body) = make_http_request_through_proxy(proxy_config, target_url)
        .await
        .expect("URL-embedded auth proxy request should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "authenticated");

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1);

    let encoded = base64::prelude::BASE64_STANDARD.encode("urluser:urlpass");
    let expected_auth = format!("Basic {}", encoded);
    assert_eq!(
        recorded[0].headers.get("proxy-authorization"),
        Some(&expected_auth)
    );
}
/// When the proxy URL embeds credentials and `with_basic_auth` is also called, the header that
/// reaches the proxy carries the URL-embedded credentials.
#[tokio::test]
async fn test_proxy_auth_precedence() {
    let mock_proxy = MockProxyServer::with_auth_validation("urluser", "urlpass").await;
    let url_with_credentials = format!("http://urluser:urlpass@{}", mock_proxy.addr());
    let proxy_config = ProxyConfig::http(url_with_credentials)
        .unwrap()
        .with_basic_auth("programmatic", "auth");

    let target_url = "http://aws.amazon.com/precedence/test";
    let (status, body) = make_http_request_through_proxy(proxy_config, target_url)
        .await
        .expect("Auth precedence test should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "authenticated");

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1);

    let encoded = base64::prelude::BASE64_STANDARD.encode("urluser:urlpass");
    let expected_auth = format!("Basic {}", encoded);
    assert_eq!(
        recorded[0].headers.get("proxy-authorization"),
        Some(&expected_auth)
    );
}
#[tokio::test]
async fn test_proxy_from_environment_variables() {
let mock_proxy = MockProxyServer::with_response(StatusCode::OK, "env proxy response").await;
with_env_vars(
&[
("HTTP_PROXY", &format!("http://{}", mock_proxy.addr())),
("NO_PROXY", "localhost,127.0.0.1"),
],
|| async {
let proxy_config = ProxyConfig::from_env();
let target_url = "http://aws.amazon.com/v1/data";
let result = make_http_request_through_proxy(proxy_config, target_url).await;
let (status, body) = result.expect("Environment proxy request should succeed");
assert_eq!(status, StatusCode::OK);
assert_eq!(body, "env proxy response");
let requests = mock_proxy.requests();
assert_eq!(requests.len(), 1);
assert_eq!(requests[0].uri, target_url);
},
)
.await;
}
/// A target whose host matches the `no_proxy` rule is contacted directly: the proxy sees no
/// requests and the direct server sees exactly one.
#[tokio::test]
async fn test_no_proxy_bypass_rules() {
    let mock_proxy = MockProxyServer::with_response(StatusCode::OK, "should not reach here").await;
    let direct_server = MockProxyServer::with_response(StatusCode::OK, "direct connection").await;

    let direct_ip = "127.0.0.1";
    let proxy_config = ProxyConfig::http(format!("http://{}", mock_proxy.addr()))
        .unwrap()
        .no_proxy(direct_ip);

    let target_url = format!("http://{}/test", direct_server.addr());
    let (status, body) = make_http_request_through_proxy(proxy_config, &target_url)
        .await
        .expect("Direct connection should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "direct connection");

    assert_eq!(
        mock_proxy.requests().len(),
        0,
        "Proxy should not have received any requests due to NO_PROXY bypass"
    );
    assert_eq!(
        direct_server.requests().len(),
        1,
        "Direct server should have received the request"
    );
}
/// With `ProxyConfig::disabled()` the request goes straight to the target server; the recorded
/// request target may be either the full URL or just the path.
#[tokio::test]
async fn test_proxy_disabled() {
    let direct_server = MockProxyServer::with_response(StatusCode::OK, "direct connection").await;
    let full_url = format!("http://{}/get", direct_server.addr());

    let (status, body) = make_http_request_through_proxy(ProxyConfig::disabled(), &full_url)
        .await
        .expect("Direct connection should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "direct connection");

    let recorded = direct_server.requests();
    assert_eq!(
        recorded.len(),
        1,
        "Direct server should have received the request"
    );

    let first = &recorded[0];
    assert_eq!(first.method, "GET");
    assert!(
        first.uri == full_url || first.uri == "/get",
        "URI should be either full URL or path, got: {}",
        first.uri
    );
}
/// A proxy configured with `ProxyConfig::https` is not used for a plain `http://` target: the
/// request reaches the direct server and the proxy records nothing.
#[tokio::test]
async fn test_https_proxy_configuration() {
    let mock_proxy = MockProxyServer::with_response(StatusCode::OK, "https proxy response").await;
    let direct_server =
        MockProxyServer::with_response(StatusCode::OK, "direct http connection").await;

    let proxy_config = ProxyConfig::https(format!("http://{}", mock_proxy.addr())).unwrap();
    let target_url = format!("http://{}/api", direct_server.addr());

    let (status, body) = make_http_request_through_proxy(proxy_config.clone(), &target_url)
        .await
        .expect("HTTP request should succeed via direct connection");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "direct http connection");

    assert_eq!(
        mock_proxy.requests().len(),
        0,
        "HTTP request should not go through HTTPS-only proxy"
    );
    assert_eq!(
        direct_server.requests().len(),
        1,
        "Direct server should have received the HTTP request"
    );
}
/// A proxy configured with `ProxyConfig::all` receives a plain HTTP request with the full
/// target URL.
#[tokio::test]
async fn test_all_traffic_proxy() {
    let mock_proxy = MockProxyServer::with_response(StatusCode::OK, "all traffic proxy").await;
    let proxy_config = ProxyConfig::all(format!("http://{}", mock_proxy.addr())).unwrap();

    let target_url = "http://aws.amazon.com/api/endpoint";
    let (status, body) = make_http_request_through_proxy(proxy_config.clone(), target_url)
        .await
        .expect("HTTP request through all-traffic proxy should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "all traffic proxy");

    let recorded = mock_proxy.requests();
    assert_eq!(
        recorded.len(),
        1,
        "Proxy should have received exactly one request"
    );
    let first = &recorded[0];
    assert_eq!(first.method, "GET");
    assert_eq!(first.uri, target_url);
}
/// A proxy address nothing is expected to listen on (127.0.0.1:1) makes the request fail with
/// an error whose text looks connection-related.
#[tokio::test]
async fn test_proxy_connection_failure() {
    let proxy_config = ProxyConfig::http("http://127.0.0.1:1").unwrap();
    let target_url = "http://aws.amazon.com/api/test";

    let outcome = make_http_request_through_proxy(proxy_config, target_url).await;
    assert!(
        outcome.is_err(),
        "Request should fail when proxy is unreachable"
    );

    let error = outcome.unwrap_err();
    let error_msg = error.to_string().to_lowercase();
    // Any one of these fragments is accepted as evidence of a connection-level failure.
    let accepted_fragments = [
        "connection",
        "refused",
        "unreachable",
        "timeout",
        "connect",
        "io error",
    ];
    let looks_connection_related = accepted_fragments
        .iter()
        .any(|fragment| error_msg.contains(*fragment));
    assert!(
        looks_connection_related,
        "Error should be connection-related, got: {}",
        error
    );
}
/// Wrong Basic credentials still produce a completed request: the caller sees the proxy's 407
/// and the proxy records the (wrong) credentials it was sent.
#[tokio::test]
async fn test_proxy_authentication_failure() {
    let mock_proxy = MockProxyServer::with_auth_validation("correct", "password").await;
    let proxy_url = format!("http://{}", mock_proxy.addr());
    let proxy_config = ProxyConfig::http(proxy_url)
        .unwrap()
        .with_basic_auth("wrong", "credentials");

    let target_url = "http://aws.amazon.com/secure/api";
    let (status, _body) = make_http_request_through_proxy(proxy_config, target_url)
        .await
        .expect("Request should complete (even with auth failure)");
    assert_eq!(status, StatusCode::PROXY_AUTHENTICATION_REQUIRED);

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1, "Proxy should have received the request");

    let encoded = base64::prelude::BASE64_STANDARD.encode("wrong:credentials");
    let expected_wrong_auth = format!("Basic {}", encoded);
    assert_eq!(
        recorded[0].headers.get("proxy-authorization"),
        Some(&expected_wrong_auth)
    );
}
/// `ProxyConfig::disabled()` wins over an `HTTP_PROXY` environment variable: the request goes
/// to the direct server and the proxy named in the environment sees nothing.
#[tokio::test]
async fn test_explicit_proxy_disable_overrides_environment() {
    let mock_proxy = MockProxyServer::new(|_req| {
        panic!("Request should not reach proxy when explicitly disabled");
    })
    .await;
    let direct_server = MockProxyServer::with_response(StatusCode::OK, "direct connection").await;

    let proxy_url = format!("http://{}", mock_proxy.addr());
    let env = [("HTTP_PROXY", proxy_url.as_str())];

    with_env_vars(&env, || async {
        let target_url = format!("http://{}/test", direct_server.addr());
        let (status, body) = make_http_request_through_proxy(ProxyConfig::disabled(), &target_url)
            .await
            .expect("Direct connection should succeed");
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "direct connection");

        assert_eq!(
            mock_proxy.requests().len(),
            0,
            "Proxy should not receive requests when explicitly disabled"
        );
        assert_eq!(
            direct_server.requests().len(),
            1,
            "Direct server should have received the request"
        );
    })
    .await;
}
async fn make_https_request_through_proxy(
proxy_config: ProxyConfig,
target_url: &str,
tls_provider: tls::Provider,
) -> Result<(StatusCode, String), Box<dyn std::error::Error + Send + Sync>> {
let http_client = http_client_fn(move |settings, _components| {
let connector = Connector::builder()
.proxy_config(proxy_config.clone())
.connector_settings(settings.clone())
.tls_provider(tls_provider.clone())
.build();
aws_smithy_runtime_api::client::http::SharedHttpConnector::new(connector)
});
let connector_settings = HttpConnectorSettings::builder().build();
let runtime_components = RuntimeComponentsBuilder::for_tests()
.with_time_source(Some(SystemTimeSource::new()))
.build()
.unwrap();
let http_connector = http_client.http_connector(&connector_settings, &runtime_components);
let request = HttpRequest::get(target_url)
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
let response = http_connector.call(request).await?;
let status = response.status();
let body_bytes = response.into_body().collect().await?.to_bytes();
let body_string = String::from_utf8(body_bytes.to_vec())?;
Ok((status.into(), body_string))
}
/// Shared body for the "CONNECT with credentials" tests.
///
/// An `https://` target behind an all-traffic proxy must produce a single `CONNECT host:443`
/// request carrying the Basic credentials. The mock answers 400, so the overall request is
/// expected to fail. `provider_name` only labels assertion messages.
async fn run_https_connect_with_auth_test(tls_provider: tls::Provider, provider_name: &str) {
    let mock_proxy = MockProxyServer::new(|req| {
        let encoded = base64::prelude::BASE64_STANDARD.encode("connectuser:connectpass");
        let expected_auth = format!("Basic {}", encoded);

        assert_eq!(req.method, "CONNECT");
        assert_eq!(req.uri, "secure.aws.amazon.com:443");
        assert_eq!(req.headers.get("proxy-authorization"), Some(&expected_auth));

        Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body("CONNECT tunnel setup failed".to_string())
            .unwrap()
    })
    .await;

    let proxy_url = format!("http://{}", mock_proxy.addr());
    let proxy_config = ProxyConfig::all(proxy_url)
        .unwrap()
        .with_basic_auth("connectuser", "connectpass");

    let target_url = "https://secure.aws.amazon.com/api/secure";
    let outcome = make_https_request_through_proxy(proxy_config, target_url, tls_provider).await;
    assert!(
        outcome.is_err(),
        "CONNECT tunnel should fail with 400 response for {}",
        provider_name
    );

    assert_eq!(
        mock_proxy.requests().len(),
        1,
        "Proxy should have received exactly one CONNECT request for {}",
        provider_name
    );
}
/// Shared body for the "CONNECT without credentials" tests.
///
/// An `https://` target behind an all-traffic proxy with no credentials configured must produce
/// a single `CONNECT host:443` request without a `proxy-authorization` header. The mock answers
/// 407, so the overall request is expected to fail with an error that mentions the proxy,
/// authentication, or the connection. `provider_name` only labels assertion messages.
async fn run_https_connect_auth_required_test(tls_provider: tls::Provider, provider_name: &str) {
    let mock_proxy = MockProxyServer::new(|req| {
        assert_eq!(req.method, "CONNECT");
        assert_eq!(req.uri, "secure.aws.amazon.com:443");
        assert!(!req.headers.contains_key("proxy-authorization"));
        Response::builder()
            .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
            .body("Proxy authentication required for CONNECT".to_string())
            .unwrap()
    })
    .await;

    let proxy_config = ProxyConfig::all(format!("http://{}", mock_proxy.addr())).unwrap();
    let target_url = "https://secure.aws.amazon.com/api/secure";

    let outcome = make_https_request_through_proxy(proxy_config, target_url, tls_provider).await;
    assert!(
        outcome.is_err(),
        "CONNECT tunnel should fail with 407 response for {}",
        provider_name
    );

    let error_msg = outcome.unwrap_err().to_string();
    let error_msg_lower = error_msg.to_lowercase();
    // Any one of these fragments is accepted as evidence that CONNECT was attempted.
    let accepted_fragments = ["407", "proxy", "auth", "io error", "connection"];
    let mentions_connect_failure = accepted_fragments
        .iter()
        .any(|fragment| error_msg_lower.contains(*fragment));
    assert!(
        mentions_connect_failure,
        "Error should be connection-related (indicating CONNECT was attempted) for {}, got: {}",
        provider_name,
        error_msg
    );

    assert_eq!(
        mock_proxy.requests().len(),
        1,
        "Proxy should have received exactly one CONNECT request for {}",
        provider_name
    );
}
/// Runs the CONNECT-with-credentials scenario using the rustls provider in ring mode.
#[cfg(feature = "rustls-ring")]
#[tokio::test]
async fn test_https_connect_with_auth_rustls() {
    let provider = tls::Provider::rustls(tls::rustls_provider::CryptoMode::Ring);
    run_https_connect_with_auth_test(provider, "rustls").await;
}
/// Runs the CONNECT-without-credentials scenario using the rustls provider in ring mode.
#[cfg(feature = "rustls-ring")]
#[tokio::test]
async fn test_https_connect_auth_required_rustls() {
    let provider = tls::Provider::rustls(tls::rustls_provider::CryptoMode::Ring);
    run_https_connect_auth_required_test(provider, "rustls").await;
}
/// Runs the CONNECT-with-credentials scenario using the s2n-tls provider.
#[cfg(feature = "s2n-tls")]
#[tokio::test]
async fn test_https_connect_with_auth_s2n_tls() {
    let provider = tls::Provider::S2nTls;
    run_https_connect_with_auth_test(provider, "s2n-tls").await;
}
/// Runs the CONNECT-without-credentials scenario using the s2n-tls provider.
#[cfg(feature = "s2n-tls")]
#[tokio::test]
async fn test_https_connect_auth_required_s2n_tls() {
    let provider = tls::Provider::S2nTls;
    run_https_connect_auth_required_test(provider, "s2n-tls").await;
}
/// A proxied HTTP request must use the absolute-form request target (`http://host/path`) and
/// carry a `host` header naming the target host.
#[tokio::test]
async fn test_http_proxy_absolute_uri_form() {
    let target_host = "api.example.com";
    let target_path = "/v1/data";
    let expected_absolute_uri = format!("http://{}{}", target_host, target_path);

    // Owned copies for the handler closure, which must be `'static`.
    let handler_expected_uri = expected_absolute_uri.clone();
    let handler_expected_host = target_host.to_string();
    let mock_proxy = MockProxyServer::new(move |req| {
        assert_eq!(req.method, "GET");
        assert_eq!(req.uri, handler_expected_uri);
        assert_eq!(req.headers.get("host"), Some(&handler_expected_host));
        Response::builder()
            .status(StatusCode::OK)
            .body("proxied response".to_string())
            .unwrap()
    })
    .await;

    let proxy_config = ProxyConfig::http(format!("http://{}", mock_proxy.addr())).unwrap();
    let (status, body) = make_http_request_through_proxy(proxy_config, &expected_absolute_uri)
        .await
        .expect("HTTP request through proxy should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "proxied response");

    let recorded = mock_proxy.requests();
    assert_eq!(recorded.len(), 1);
    assert_eq!(recorded[0].uri, expected_absolute_uri);
}
/// A direct (proxy disabled) HTTP request reaches the server with a request target that is, or
/// at least ends with, the request path.
#[tokio::test]
async fn test_direct_http_origin_uri_form() {
    let target_path = "/v1/data";
    let direct_server = MockProxyServer::new(move |req| {
        assert_eq!(req.method, "GET");
        let is_origin_form = req.uri == target_path || req.uri.ends_with(target_path);
        assert!(
            is_origin_form,
            "Expected origin form URI ending with '{}', got '{}'",
            target_path,
            req.uri
        );
        Response::builder()
            .status(StatusCode::OK)
            .body("direct response".to_string())
            .unwrap()
    })
    .await;

    let target_url = format!("http://{}{}", direct_server.addr(), target_path);
    let (status, body) = make_http_request_through_proxy(ProxyConfig::disabled(), &target_url)
        .await
        .expect("Direct HTTP request should succeed");
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body, "direct response");

    assert_eq!(direct_server.requests().len(), 1);
}
/// Contrasts the request target seen by a proxy (absolute form, starting with `http://`) with
/// the one seen by a directly contacted server (not starting with `http://`).
#[tokio::test]
async fn test_uri_form_proxy_vs_direct() {
    let target_host = "test.example.com";
    let target_path = "/api/test";
    let full_url = format!("http://{}{}", target_host, target_path);

    // Part 1: through a proxy.
    {
        let host_for_handler = target_host.to_string();
        let path_for_handler = target_path.to_string();
        let mock_proxy = MockProxyServer::new(move |req| {
            assert!(req.uri.starts_with("http://"));
            assert!(req.uri.contains(&host_for_handler));
            assert!(req.uri.contains(&path_for_handler));
            Response::builder()
                .status(StatusCode::OK)
                .body("proxy response".to_string())
                .unwrap()
        })
        .await;

        let proxy_config = ProxyConfig::http(format!("http://{}", mock_proxy.addr())).unwrap();
        let outcome = make_http_request_through_proxy(proxy_config, &full_url).await;
        assert!(outcome.is_ok(), "Proxy request should succeed");

        let recorded = mock_proxy.requests();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].uri, full_url);
    }

    // Part 2: straight to the server, proxy disabled.
    {
        let path_for_handler = target_path.to_string();
        let direct_server = MockProxyServer::new(move |req| {
            assert!(!req.uri.starts_with("http://"));
            assert!(req.uri == path_for_handler || req.uri.ends_with(&path_for_handler));
            Response::builder()
                .status(StatusCode::OK)
                .body("direct response".to_string())
                .unwrap()
        })
        .await;

        let direct_url = format!("http://{}{}", direct_server.addr(), target_path);
        let outcome = make_http_request_through_proxy(ProxyConfig::disabled(), &direct_url).await;
        assert!(outcome.is_ok(), "Direct request should succeed");

        assert_eq!(direct_server.requests().len(), 1);
    }
}
/// Shared body for the CONNECT request-target tests.
///
/// An `https://` target behind an all-traffic proxy must produce exactly one `CONNECT` request
/// whose target is `host:443`. The outcome of the HTTPS request itself is deliberately ignored;
/// only what the proxy recorded is checked. `provider_name` only labels assertion messages.
async fn run_connect_uri_form_test(tls_provider: tls::Provider, provider_name: &str) {
    let target_host = "secure.example.com";
    let target_port = 443;
    let expected_connect_uri = format!("{}:{}", target_host, target_port);

    let handler_expected_uri = expected_connect_uri.clone();
    let mock_proxy = MockProxyServer::new(move |req| {
        if req.method != "CONNECT" {
            return Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body("Unexpected non-CONNECT request".to_string())
                .unwrap();
        }
        assert_eq!(req.uri, handler_expected_uri);
        Response::builder()
            .status(StatusCode::OK)
            .body("Connection established".to_string())
            .unwrap()
    })
    .await;

    let proxy_config = ProxyConfig::all(format!("http://{}", mock_proxy.addr())).unwrap();
    let target_url = format!("https://{}/api/secure", target_host);
    // Success or failure of the tunnelled request is irrelevant to this test.
    let _outcome = make_https_request_through_proxy(proxy_config, &target_url, tls_provider).await;

    let recorded = mock_proxy.requests();
    assert_eq!(
        recorded.len(),
        1,
        "Should have received exactly one CONNECT request for {}",
        provider_name
    );
    let first = &recorded[0];
    assert_eq!(first.method, "CONNECT");
    assert_eq!(first.uri, expected_connect_uri);
}
/// Runs the CONNECT request-target scenario using the rustls provider in ring mode.
#[cfg(feature = "rustls-ring")]
#[tokio::test]
async fn test_connect_uri_form_rustls() {
    let provider = tls::Provider::rustls(tls::rustls_provider::CryptoMode::Ring);
    run_connect_uri_form_test(provider, "rustls").await;
}
/// Runs the CONNECT request-target scenario using the s2n-tls provider.
#[cfg(feature = "s2n-tls")]
#[tokio::test]
async fn test_connect_uri_form_s2n_tls() {
    let provider = tls::Provider::S2nTls;
    run_connect_uri_form_test(provider, "s2n-tls").await;
}