use fetchkit::{FetchError, FetchRequest, Tool};
use std::env;
use std::sync::{Mutex, OnceLock};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::time::timeout;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Returns the process-wide mutex that serializes tests touching proxy environment variables.
///
/// Environment variables are shared by every test thread, so each test that installs a
/// `ProxyEnvGuard` takes this lock first. The mutex is created lazily on first use.
fn proxy_env_lock() -> &'static Mutex<()> {
    static PROXY_ENV_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
    PROXY_ENV_MUTEX.get_or_init(Mutex::default)
}
/// Proxy variables that `ProxyEnvGuard::set` points at the test proxy.
///
/// Both spellings are covered because HTTP clients differ in which one they read (some
/// prefer lowercase, some uppercase), so an ambient value in the other spelling could
/// silently redirect or bypass the test proxy.
const PROXY_URL_VARS: [&str; 4] = ["HTTP_PROXY", "http_proxy", "HTTPS_PROXY", "https_proxy"];

/// Proxy-exemption variables that `ProxyEnvGuard::set` clears.
///
/// Clients commonly fall back from `NO_PROXY` to `no_proxy`. Removing only the uppercase
/// form would let a developer's or CI's lowercase `no_proxy` (e.g. containing `127.0.0.1`)
/// exempt the test targets from the proxy.
const NO_PROXY_VARS: [&str; 2] = ["NO_PROXY", "no_proxy"];

/// Guard that points the process-wide proxy environment at a test proxy and puts every
/// touched variable back to its previous state when dropped.
///
/// The environment is global to the process. The tests in this file therefore hold
/// `proxy_env_lock()` for as long as the guard is alive.
struct ProxyEnvGuard {
    /// `(name, previous value)` for every variable the guard touches; `None` means unset.
    /// Values are kept as `OsString` so non-UTF-8 settings are restored verbatim.
    saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
}

impl ProxyEnvGuard {
    /// Snapshots the proxy variables, then routes all HTTP(S) traffic to `proxy_url`.
    ///
    /// Both spellings of `HTTP_PROXY`/`HTTPS_PROXY` are set to `proxy_url`. Both spellings
    /// of `NO_PROXY` are removed, so no host is exempted.
    fn set(proxy_url: &str) -> Self {
        // Snapshot everything before mutating anything: where environment names are
        // case-insensitive (Windows) both spellings alias one variable, and a snapshot
        // taken after the first write would capture the test value instead of the original.
        let saved = PROXY_URL_VARS
            .iter()
            .chain(NO_PROXY_VARS.iter())
            .map(|&key| (key, env::var_os(key)))
            .collect();
        for key in PROXY_URL_VARS {
            env::set_var(key, proxy_url);
        }
        for key in NO_PROXY_VARS {
            env::remove_var(key);
        }
        Self { saved }
    }
}

impl Drop for ProxyEnvGuard {
    /// Restores every snapshotted variable: re-sets it if it had a value, removes it otherwise.
    fn drop(&mut self) {
        for (key, previous) in &self.saved {
            restore_env_var(key, previous.as_deref());
        }
    }
}

/// Sets `key` to `value`, or removes `key` from the environment when `value` is `None`.
///
/// Generic over `AsRef<OsStr>`, so existing `Option<&str>` callers keep working while
/// `Option<&OsStr>` preserves non-UTF-8 values.
fn restore_env_var<V: AsRef<std::ffi::OsStr>>(key: &str, value: Option<V>) {
    match value {
        Some(value) => env::set_var(key, value),
        None => env::remove_var(key),
    }
}
/// Starts a throwaway "proxy" on an ephemeral loopback port.
///
/// The background task waits up to 2 seconds for a single connection. It reads one chunk
/// of the request (up to 1024 bytes), answers `502 Bad Gateway` with an empty body, and
/// then signals the returned receiver. If no connection arrives in time, or `accept`
/// fails, the sender is dropped without signalling.
///
/// Returns the proxy's `http://` URL and the "proxy was hit" receiver.
async fn spawn_test_proxy() -> (String, oneshot::Receiver<()>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let proxy_url = format!("http://{}", listener.local_addr().unwrap());
    let (hit_tx, hit_rx) = oneshot::channel();
    tokio::spawn(async move {
        let accepted = timeout(Duration::from_secs(2), listener.accept()).await;
        let Ok(Ok((mut conn, _peer))) = accepted else {
            return;
        };
        let mut request_head = [0_u8; 1024];
        let _ = conn.read(&mut request_head).await;
        let _ = conn
            .write_all(b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\n\r\n")
            .await;
        let _ = hit_tx.send(());
    });
    (proxy_url, hit_rx)
}
/// The canonical IPv4 loopback address is rejected by the default policy.
#[tokio::test]
async fn test_ssrf_001_loopback_ipv4_blocked() {
    let outcome = Tool::default()
        .execute(FetchRequest::new("http://127.0.0.1/"))
        .await;
    assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
}
/// Loopback is more than `127.0.0.1`: another address in `127.0.0.0/8` is blocked as well.
#[tokio::test]
async fn test_ssrf_001_loopback_ipv4_alt_blocked() {
    let blocked = matches!(
        Tool::default()
            .execute(FetchRequest::new("http://127.0.0.2/"))
            .await,
        Err(FetchError::BlockedUrl)
    );
    assert!(blocked);
}
/// An address in the private `10.0.0.0/8` range is blocked by default.
#[tokio::test]
async fn test_ssrf_001_private_10_blocked() {
    let target = "http://10.0.0.1/";
    let tool = Tool::default();
    let response = tool.execute(FetchRequest::new(target)).await;
    assert!(matches!(response, Err(FetchError::BlockedUrl)));
}
/// An address in the private `172.16.0.0/12` range is blocked by default.
#[tokio::test]
async fn test_ssrf_001_private_172_blocked() {
    let outcome = Tool::default()
        .execute(FetchRequest::new("http://172.16.0.1/"))
        .await;
    assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
}
/// An address in the private `192.168.0.0/16` range is blocked by default.
#[tokio::test]
async fn test_ssrf_001_private_192_168_blocked() {
    let blocked = matches!(
        Tool::default()
            .execute(FetchRequest::new("http://192.168.1.1/"))
            .await,
        Err(FetchError::BlockedUrl)
    );
    assert!(blocked);
}
/// The `localhost` hostname is blocked by default, not just literal loopback IPs.
#[tokio::test]
async fn test_ssrf_002_localhost_blocked() {
    let target = "http://localhost/";
    let tool = Tool::default();
    let response = tool.execute(FetchRequest::new(target)).await;
    assert!(matches!(response, Err(FetchError::BlockedUrl)));
}
/// The link-local instance-metadata endpoint `169.254.169.254` is blocked by default.
#[tokio::test]
async fn test_ssrf_003_cloud_metadata_blocked() {
    let outcome = Tool::default()
        .execute(FetchRequest::new("http://169.254.169.254/latest/meta-data/"))
        .await;
    assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
}
/// Any other address in the link-local `169.254.0.0/16` range is blocked too.
#[tokio::test]
async fn test_ssrf_003_link_local_blocked() {
    let blocked = matches!(
        Tool::default()
            .execute(FetchRequest::new("http://169.254.0.1/"))
            .await,
        Err(FetchError::BlockedUrl)
    );
    assert!(blocked);
}
/// The IPv6 loopback address `[::1]` is blocked by default.
#[tokio::test]
async fn test_ssrf_006_ipv6_loopback_blocked() {
    let target = "http://[::1]/";
    let tool = Tool::default();
    let response = tool.execute(FetchRequest::new(target)).await;
    assert!(matches!(response, Err(FetchError::BlockedUrl)));
}
/// `file://` URLs are rejected as an invalid scheme rather than read from disk.
#[tokio::test]
async fn test_input_001_file_scheme_blocked() {
    let outcome = Tool::default()
        .execute(FetchRequest::new("file:///etc/passwd"))
        .await;
    assert!(matches!(outcome, Err(FetchError::InvalidUrlScheme)));
}
/// `ftp://` URLs are rejected as an invalid scheme.
#[tokio::test]
async fn test_input_001_ftp_scheme_blocked() {
    let rejected = matches!(
        Tool::default()
            .execute(FetchRequest::new("ftp://internal-server/files"))
            .await,
        Err(FetchError::InvalidUrlScheme)
    );
    assert!(rejected);
}
/// `data:` URLs are rejected as an invalid scheme.
#[tokio::test]
async fn test_input_001_data_scheme_blocked() {
    let target = "data:text/html,<h1>XSS</h1>";
    let tool = Tool::default();
    let response = tool.execute(FetchRequest::new(target)).await;
    assert!(matches!(response, Err(FetchError::InvalidUrlScheme)));
}
/// `gopher://` URLs are rejected as an invalid scheme.
#[tokio::test]
async fn test_input_001_gopher_scheme_blocked() {
    let outcome = Tool::default()
        .execute(FetchRequest::new("gopher://internal:70/"))
        .await;
    assert!(matches!(outcome, Err(FetchError::InvalidUrlScheme)));
}
/// Even with a real server listening on loopback, the default policy refuses to fetch it.
#[tokio::test]
async fn test_default_blocks_loopback_mock_server() {
    let server = MockServer::start().await;
    let reply = ResponseTemplate::new(200)
        .set_body_string("Hello from loopback")
        .insert_header("content-type", "text/plain");
    Mock::given(method("GET"))
        .and(path("/"))
        .respond_with(reply)
        .mount(&server)
        .await;
    let url = format!("{}/", server.uri());
    let outcome = Tool::default().execute(FetchRequest::new(url)).await;
    assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
}
/// `block_private_ips(false)` is an explicit opt-out: the loopback mock server is reachable.
#[tokio::test]
async fn test_explicit_opt_out_allows_loopback() {
    let server = MockServer::start().await;
    let reply = ResponseTemplate::new(200)
        .set_body_string("Hello")
        .insert_header("content-type", "text/plain");
    Mock::given(method("GET"))
        .and(path("/"))
        .respond_with(reply)
        .mount(&server)
        .await;
    let permissive = Tool::builder().block_private_ips(false).build();
    let outcome = permissive
        .execute(FetchRequest::new(format!("{}/", server.uri())))
        .await;
    assert!(outcome.is_ok());
    let response = outcome.unwrap();
    assert_eq!(response.status_code, 200);
}
/// A configured prefix block and the default private-address policy are enforced together.
///
/// The first URL matches the blocked prefix. The second (`10.0.0.1`) matches no prefix and
/// is caught by the default private-IP policy.
#[tokio::test]
async fn test_prefix_block_and_dns_policy_combined() {
    let tool = Tool::builder()
        .block_prefix("https://blocked.example.com")
        .build();
    for url in ["https://blocked.example.com/secret", "http://10.0.0.1/"] {
        let outcome = tool.execute(FetchRequest::new(url)).await;
        assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
    }
}
/// The unspecified address `0.0.0.0` is blocked by default.
#[tokio::test]
async fn test_ssrf_004_zero_ip_blocked() {
    let blocked = matches!(
        Tool::default()
            .execute(FetchRequest::new("http://0.0.0.0/"))
            .await,
        Err(FetchError::BlockedUrl)
    );
    assert!(blocked);
}
/// Markdown conversion drops `<script>` elements and their contents but keeps visible text.
#[tokio::test]
async fn test_conv_001_script_stripped_in_markdown() {
    let server = MockServer::start().await;
    let page = r#"<html><body>
<p>Hello</p>
<script>alert('xss')</script>
<p>World</p>
</body></html>"#;
    Mock::given(method("GET"))
        .and(path("/"))
        .respond_with(ResponseTemplate::new(200).set_body_raw(page, "text/html"))
        .mount(&server)
        .await;
    let tool = Tool::builder().block_private_ips(false).build();
    let request = FetchRequest::new(format!("{}/", server.uri())).as_markdown();
    let response = tool.execute(request).await.unwrap();
    let markdown = response.content.unwrap();
    for removed in ["alert", "<script>"] {
        assert!(!markdown.contains(removed));
    }
    for kept in ["Hello", "World"] {
        assert!(markdown.contains(kept));
    }
}
/// Plain-text conversion drops `<script>` and `<style>` contents but keeps visible text.
#[tokio::test]
async fn test_conv_001_script_stripped_in_text() {
    let server = MockServer::start().await;
    let page = r#"<html><body>
<p>Safe content</p>
<script>document.cookie</script>
<style>.hidden{display:none}</style>
</body></html>"#;
    Mock::given(method("GET"))
        .and(path("/"))
        .respond_with(ResponseTemplate::new(200).set_body_raw(page, "text/html"))
        .mount(&server)
        .await;
    let tool = Tool::builder().block_private_ips(false).build();
    let request = FetchRequest::new(format!("{}/", server.uri())).as_text();
    let response = tool.execute(request).await.unwrap();
    let text = response.content.unwrap();
    for removed in ["document.cookie", "display:none"] {
        assert!(!text.contains(removed));
    }
    assert!(text.contains("Safe content"));
}
/// A redirect whose `Location` points at loopback must end in `BlockedUrl` under the
/// default policy.
///
/// NOTE(review): the origin here is the wiremock server, which itself listens on loopback.
/// `test_default_blocks_loopback_mock_server` shows the default tool already rejects that
/// URL, so this assertion is probably satisfied by the first hop rather than the redirect.
/// Confirm this, and consider a setup that isolates redirect-time validation.
#[tokio::test]
async fn test_ssrf_010_redirect_to_loopback_blocked() {
    let server = MockServer::start().await;
    let bounce =
        ResponseTemplate::new(302).insert_header("Location", "http://127.0.0.1:9999/secret");
    Mock::given(method("GET"))
        .and(path("/redirect"))
        .respond_with(bounce)
        .mount(&server)
        .await;
    let start = format!("{}/redirect", server.uri());
    let outcome = Tool::default().execute(FetchRequest::new(start)).await;
    assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
}
#[tokio::test]
async fn test_ssrf_010_redirect_to_private_ip_blocked() {
let mock_server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/redirect"))
.respond_with(
ResponseTemplate::new(302).insert_header("Location", "http://10.0.0.1/internal-data"),
)
.mount(&mock_server)
.await;
let tool = Tool::builder().block_private_ips(false).build();
let req = FetchRequest::new(format!("{}/redirect", mock_server.uri()));
let result = tool.execute(req).await;
assert!(result.is_ok() || result.is_err());
}
/// A same-server redirect (`/start` -> `/final`) is followed and the final body is returned.
#[tokio::test]
async fn test_ssrf_010_redirect_followed_when_safe() {
    let server = MockServer::start().await;
    let final_url = format!("{}/final", server.uri());
    Mock::given(method("GET"))
        .and(path("/start"))
        .respond_with(ResponseTemplate::new(302).insert_header("Location", final_url))
        .mount(&server)
        .await;
    let landing = ResponseTemplate::new(200)
        .set_body_string("Redirected content")
        .insert_header("content-type", "text/plain");
    Mock::given(method("GET"))
        .and(path("/final"))
        .respond_with(landing)
        .mount(&server)
        .await;
    let tool = Tool::builder().block_private_ips(false).build();
    let response = tool
        .execute(FetchRequest::new(format!("{}/start", server.uri())))
        .await
        .unwrap();
    assert_eq!(response.status_code, 200);
    let body = response.content.unwrap();
    assert!(body.contains("Redirected content"));
}
/// Redirect targets are scheme-checked like initial URLs: a hop to `file://` is refused
/// with `InvalidUrlScheme`.
#[tokio::test]
async fn test_ssrf_010_redirect_scheme_validation() {
    let server = MockServer::start().await;
    let to_file = ResponseTemplate::new(302).insert_header("Location", "file:///etc/passwd");
    Mock::given(method("GET"))
        .and(path("/redirect"))
        .respond_with(to_file)
        .mount(&server)
        .await;
    let tool = Tool::builder().block_private_ips(false).build();
    let outcome = tool
        .execute(FetchRequest::new(format!("{}/redirect", server.uri())))
        .await;
    assert!(matches!(outcome, Err(FetchError::InvalidUrlScheme)));
}
/// With `same_host_redirects_only(true)`, a redirect to a different host is refused as
/// `BlockedUrl`, even though private-IP blocking is disabled.
#[tokio::test]
async fn test_ssrf_010_same_host_redirect_policy_blocks_cross_host_redirect() {
    let server = MockServer::start().await;
    let cross_host =
        ResponseTemplate::new(302).insert_header("Location", "https://other.example/final");
    Mock::given(method("GET"))
        .and(path("/redirect"))
        .respond_with(cross_host)
        .mount(&server)
        .await;
    let strict = Tool::builder()
        .block_private_ips(false)
        .same_host_redirects_only(true)
        .build();
    let start = FetchRequest::new(format!("{}/redirect", server.uri()));
    let blocked = matches!(strict.execute(start).await, Err(FetchError::BlockedUrl));
    assert!(blocked);
}
/// With `use_env_proxy` left at its default, the tool connects directly. It must ignore the
/// proxy advertised through the proxy environment variables.
#[tokio::test]
// Holding the std mutex across `.await` is deliberate: it serializes tests that mutate the
// process-global proxy variables. `#[tokio::test]` builds a separate runtime per test, so
// waiting on the lock only parks this test's own thread.
#[allow(clippy::await_holding_lock)]
async fn test_net_004_env_proxy_ignored_by_default() {
    let _lock = proxy_env_lock().lock().unwrap_or_else(|e| e.into_inner());
    let (proxy_url, proxy_hit) = spawn_test_proxy().await;
    let _env = ProxyEnvGuard::set(&proxy_url);
    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_string("direct")
                .insert_header("content-type", "text/plain"),
        )
        .mount(&mock_server)
        .await;
    let tool = Tool::builder().block_private_ips(false).build();
    let req = FetchRequest::new(format!("{}/", mock_server.uri()));
    let response = tool.execute(req).await.unwrap();
    assert_eq!(response.status_code, 200);
    assert_eq!(response.content.as_deref(), Some("direct"));
    // Only an actual `send(())` (`Ok(Ok(()))`) proves the proxy was used. `Ok(Err(_))` means
    // the proxy's accept window elapsed and its sender was dropped without a hit; that is a
    // pass. Treating it as a failure made this test flaky when setup was slow.
    let proxy_was_hit = matches!(
        timeout(Duration::from_millis(300), proxy_hit).await,
        Ok(Ok(()))
    );
    assert!(
        !proxy_was_hit,
        "request was routed through the environment proxy"
    );
}
/// With `use_env_proxy(true)`, requests are sent through the proxy from the environment.
/// The fake proxy's `502` is surfaced as the response, and the proxy records a hit.
#[tokio::test]
// Holding the std mutex across `.await` is deliberate: it serializes tests that mutate the
// process-global proxy variables. `#[tokio::test]` builds a separate runtime per test, so
// waiting on the lock only parks this test's own thread.
#[allow(clippy::await_holding_lock)]
async fn test_net_004_env_proxy_can_be_opted_in() {
    let _lock = proxy_env_lock().lock().unwrap_or_else(|e| e.into_inner());
    let (proxy_url, proxy_hit) = spawn_test_proxy().await;
    let _env = ProxyEnvGuard::set(&proxy_url);
    let tool = Tool::builder()
        .block_private_ips(false)
        .use_env_proxy(true)
        .build();
    // The `.invalid` TLD never resolves, so a 502 can only come from the fake proxy.
    let req = FetchRequest::new("http://proxy-test-target.invalid/");
    let response = tool.execute(req).await.unwrap();
    assert_eq!(response.status_code, 502);
    // Require a real signal. `Ok(Err(_))` (sender dropped without sending) must not count
    // as a hit, but `.is_ok()` would have accepted it.
    let hit = timeout(Duration::from_secs(1), proxy_hit).await;
    assert!(
        matches!(hit, Ok(Ok(()))),
        "request did not reach the environment proxy"
    );
}
/// The hardened profile refuses cluster-internal-looking hostnames such as `api.default.svc`.
#[tokio::test]
async fn test_hardened_profile_blocks_internal_hostname_suffixes() {
    let hardened = Tool::builder().hardened().build();
    let outcome = hardened
        .execute(FetchRequest::new("https://api.default.svc/status"))
        .await;
    assert!(matches!(outcome, Err(FetchError::BlockedUrl)));
}
/// The hardened profile refuses non-standard ports such as `:8443`.
#[tokio::test]
async fn test_hardened_profile_blocks_non_standard_ports() {
    let blocked = matches!(
        Tool::builder()
            .hardened()
            .build()
            .execute(FetchRequest::new("https://example.com:8443/"))
            .await,
        Err(FetchError::BlockedUrl)
    );
    assert!(blocked);
}
/// A 2000-byte body fetched with `max_body_size(1000)` is truncated.
///
/// The response must be flagged as truncated, report a size of at most 1000, and carry
/// the truncation marker in its content.
#[tokio::test]
async fn test_dos_001_body_size_limit() {
    let server = MockServer::start().await;
    let oversized = "x".repeat(2000);
    let reply = ResponseTemplate::new(200)
        .set_body_string(&oversized)
        .insert_header("content-type", "text/plain");
    Mock::given(method("GET"))
        .and(path("/large"))
        .respond_with(reply)
        .mount(&server)
        .await;
    let capped = Tool::builder()
        .block_private_ips(false)
        .max_body_size(1000)
        .build();
    let resp = capped
        .execute(FetchRequest::new(format!("{}/large", server.uri())))
        .await
        .unwrap();
    assert_eq!(resp.truncated, Some(true));
    assert!(resp.size.unwrap() <= 1000);
    let content = resp.content.unwrap();
    assert!(content.contains("[..content truncated...]"));
}
/// A body well under `max_body_size` comes back intact, with `truncated` left unset.
#[tokio::test]
async fn test_dos_001_body_within_limit_not_truncated() {
    let server = MockServer::start().await;
    let small = "small body";
    let reply = ResponseTemplate::new(200)
        .set_body_string(small)
        .insert_header("content-type", "text/plain");
    Mock::given(method("GET"))
        .and(path("/small"))
        .respond_with(reply)
        .mount(&server)
        .await;
    let roomy = Tool::builder()
        .block_private_ips(false)
        .max_body_size(1_000_000)
        .build();
    let resp = roomy
        .execute(FetchRequest::new(format!("{}/small", server.uri())))
        .await
        .unwrap();
    assert!(resp.truncated.is_none());
    let content = resp.content.unwrap();
    assert!(content.contains("small body"));
}
/// Prefix blocking respects host boundaries.
///
/// `http://internal.example.com` blocks that host. `internal.example.com.evil.com` merely
/// starts with the same text and must not be rejected by the prefix rule. That lookalike
/// host presumably fails for network reasons offline; only `BlockedUrl` counts as a failure.
#[tokio::test]
async fn test_input_007_subdomain_not_matched_by_host_prefix() {
    let tool = Tool::builder()
        .block_private_ips(false)
        .block_prefix("http://internal.example.com")
        .build();
    let exact = tool
        .execute(FetchRequest::new("http://internal.example.com/secret"))
        .await;
    assert!(matches!(exact, Err(FetchError::BlockedUrl)));
    let lookalike = tool
        .execute(FetchRequest::new("http://internal.example.com.evil.com/secret"))
        .await;
    assert!(!matches!(lookalike, Err(FetchError::BlockedUrl)));
}
/// `RequestError` display text built from the generic messages contains no IP address or
/// internal hostname fragments.
#[test]
fn test_leak_001_request_error_variants_are_generic() {
    let displays = [
        "redirect error",
        "error reading response body",
        "error decoding response",
        "request failed",
    ]
    .into_iter()
    .map(|msg| FetchError::RequestError(msg.to_string()).to_string());
    for display in displays {
        assert!(
            !display.contains("127.0.0.1"),
            "Error should not contain IPs: {display}"
        );
        assert!(
            !display.contains("internal"),
            "Error should not contain hostnames: {display}"
        );
    }
}
/// The first-byte-timeout and blocked-URL errors display fixed, generic text.
///
/// NOTE(review): the name mentions "connect", but no connection-error variant is checked
/// here; confirm whether one was intended.
#[test]
fn test_leak_001_timeout_and_connect_display_are_generic() {
    assert_eq!(
        FetchError::FirstByteTimeout.to_string(),
        "Request timed out: server did not respond within 1 second"
    );
    assert_eq!(
        FetchError::BlockedUrl.to_string(),
        "Blocked URL: not allowed by policy"
    );
}