use std::io::Cursor;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
use std::time::Duration;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{json, Value};
use crate::traits::{
Tool, ToolCallSemantics, ToolCapabilities, ToolTargetHintKind, ToolVerificationMode,
};
/// Character budget `web_fetch` uses when the caller does not supply `max_chars`.
const DEFAULT_MAX_CHARS: usize = 20_000;
/// Upper bound that a caller-supplied `max_chars` is clamped to.
const MAX_MAX_CHARS: usize = 50_000;
/// Pre-flight SSRF check for a URL that is about to be fetched.
///
/// The URL is rejected when any of the following holds:
/// * it does not parse, or has no host;
/// * its scheme is anything other than `http` / `https`;
/// * its host is one of a fixed set of loopback / cloud-metadata names, or ends in
///   `.internal`, `.local` or `.localhost`;
/// * any address the host resolves to is rejected by `is_blocked_ip`.
///
/// Host-name comparisons ignore ASCII case and a trailing root dot, so `localhost.`
/// and `foo.internal.` are treated exactly like `localhost` and `foo.internal`.
///
/// A host that fails to resolve is not rejected here unless it is itself a blocked
/// IP literal.
///
/// Resolution goes through the blocking `ToSocketAddrs` API.
///
/// # Errors
/// Returns a human-readable reason string describing why the URL was refused.
pub fn validate_url_for_ssrf(url: &str) -> Result<(), String> {
    const BLOCKED_HOSTS: &[&str] = &[
        "localhost",
        "127.0.0.1",
        "::1",
        "[::1]",
        "0.0.0.0",
        "metadata.google.internal",
        "metadata.goog",
        "169.254.169.254",
    ];
    const BLOCKED_SUFFIXES: &[&str] = &[".internal", ".local", ".localhost"];

    let parsed = reqwest::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
    match parsed.scheme() {
        "http" | "https" => {}
        scheme => {
            return Err(format!(
                "Blocked scheme '{}': only http/https allowed",
                scheme
            ))
        }
    }
    let host = parsed
        .host_str()
        .ok_or_else(|| "URL must have a host".to_string())?;
    let host_lower = host.to_lowercase();
    // `localhost.` is the fully-qualified spelling of `localhost`; without this the
    // name-based checks below could be sidestepped by appending a dot.
    let host_name = host_lower.trim_end_matches('.');

    if BLOCKED_HOSTS.iter().any(|blocked| *blocked == host_name) {
        return Err(format!("Blocked host: {}", host));
    }
    if BLOCKED_SUFFIXES
        .iter()
        .any(|suffix| host_name.ends_with(suffix))
    {
        return Err(format!("Blocked internal hostname: {}", host));
    }

    let port = parsed.port().unwrap_or(match parsed.scheme() {
        "https" => 443,
        _ => 80,
    });
    let socket_addr = format!("{}:{}", host, port);
    match socket_addr.to_socket_addrs() {
        Ok(addrs) => {
            // Every resolved address must be acceptable, not just the first one.
            for addr in addrs {
                if is_blocked_ip(addr.ip()) {
                    return Err(format!(
                        "Blocked IP address {} (resolved from {})",
                        addr.ip(),
                        host
                    ));
                }
            }
        }
        Err(_) => {
            // `host_str()` keeps the square brackets around IPv6 literals, which
            // `IpAddr::from_str` does not accept; strip them before parsing.
            let literal = host.trim_start_matches('[').trim_end_matches(']');
            if let Ok(ip) = literal.parse::<IpAddr>() {
                if is_blocked_ip(ip) {
                    return Err(format!("Blocked IP address: {}", ip));
                }
            }
        }
    }
    Ok(())
}
/// Returns `true` when `ip` must not be contacted, delegating to the
/// family-specific check for IPv4 or IPv6.
fn is_blocked_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V6(v6) => is_blocked_ipv6(v6),
        IpAddr::V4(v4) => is_blocked_ipv4(v4),
    }
}
/// Returns `true` when the IPv4 address lies in one of the refused ranges.
///
/// Refused ranges, one pattern per line below:
/// `127.0.0.0/8`, `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`,
/// `169.254.0.0/16`, `100.64.0.0/10`, the three documentation `/24` nets,
/// the broadcast address and the unspecified address.
fn is_blocked_ipv4(ip: Ipv4Addr) -> bool {
    matches!(
        ip.octets(),
        [127, ..]
            | [10, ..]
            | [172, 16..=31, ..]
            | [192, 168, ..]
            | [169, 254, ..]
            | [100, 64..=127, ..]
            | [192, 0, 2, _]
            | [198, 51, 100, _]
            | [203, 0, 113, _]
            | [255, 255, 255, 255]
            | [0, 0, 0, 0]
    )
}
/// Returns `true` when the IPv6 address must not be contacted.
///
/// Loopback and unspecified addresses are refused outright. An IPv4-mapped
/// address (`::ffff:a.b.c.d`) is judged by its embedded IPv4 address. Anything
/// else is refused only if it is link-local (`fe80::/10`) or unique-local
/// (`fc00::/7`).
fn is_blocked_ipv6(ip: Ipv6Addr) -> bool {
    if ip.is_loopback() || ip.is_unspecified() {
        return true;
    }
    match ip.to_ipv4_mapped() {
        Some(embedded) => is_blocked_ipv4(embedded),
        None => {
            let prefix = ip.segments()[0];
            let link_local = prefix & 0xffc0 == 0xfe80;
            let unique_local = prefix & 0xfe00 == 0xfc00;
            link_local || unique_local
        }
    }
}
/// Coarse reason why a URL is refused by the SSRF policy.
///
/// Each class maps to a fixed description via [`BlockedHostClass::label`], so the
/// reason can be reported without echoing any part of the offending URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockedHostClass {
    /// Loopback or unspecified target.
    Loopback,
    /// Private-network target.
    PrivateNetwork,
    /// Link-local or cloud-metadata target.
    LinkLocalMetadata,
    /// URL scheme other than `http` / `https`.
    DisallowedScheme,
    /// URL that does not parse or has no host.
    Malformed,
    /// Any other reserved or blocked target.
    OtherReserved,
}

impl BlockedHostClass {
    /// Fixed, human-readable description of this class.
    pub fn label(self) -> &'static str {
        match self {
            Self::Loopback => "loopback address",
            Self::PrivateNetwork => "private network",
            Self::LinkLocalMetadata => "link-local/metadata address",
            Self::DisallowedScheme => "disallowed scheme",
            Self::Malformed => "malformed URL",
            Self::OtherReserved => "reserved/blocked address",
        }
    }
}
/// Maps an IPv4 address to the class under which it is refused, or `None` when
/// the address is not in any refused range.
fn classify_blocked_ipv4(ip: Ipv4Addr) -> Option<BlockedHostClass> {
    // The ranges are disjoint, so arm order does not affect the result.
    let class = match ip.octets() {
        [127, ..] | [0, 0, 0, 0] => BlockedHostClass::Loopback,
        [10, ..] | [172, 16..=31, ..] | [192, 168, ..] | [100, 64..=127, ..] => {
            BlockedHostClass::PrivateNetwork
        }
        [169, 254, ..] => BlockedHostClass::LinkLocalMetadata,
        [255, 255, 255, 255] | [192, 0, 2, _] | [198, 51, 100, _] | [203, 0, 113, _] => {
            BlockedHostClass::OtherReserved
        }
        _ => return None,
    };
    Some(class)
}
/// Maps an IPv6 address to the class under which it is refused, or `None` when
/// it is acceptable.
///
/// IPv4-mapped addresses take the class of their embedded IPv4 address.
fn classify_blocked_ipv6(ip: Ipv6Addr) -> Option<BlockedHostClass> {
    if ip.is_loopback() || ip.is_unspecified() {
        return Some(BlockedHostClass::Loopback);
    }
    match ip.to_ipv4_mapped() {
        Some(embedded) => classify_blocked_ipv4(embedded),
        None => match ip.segments()[0] {
            // fe80::/10
            prefix if prefix & 0xffc0 == 0xfe80 => Some(BlockedHostClass::LinkLocalMetadata),
            // fc00::/7
            prefix if prefix & 0xfe00 == 0xfc00 => Some(BlockedHostClass::PrivateNetwork),
            _ => None,
        },
    }
}
/// Classifies an address of either family by delegating to the matching
/// family-specific classifier.
fn classify_blocked_ip(ip: IpAddr) -> Option<BlockedHostClass> {
    match ip {
        IpAddr::V6(addr) => classify_blocked_ipv6(addr),
        IpAddr::V4(addr) => classify_blocked_ipv4(addr),
    }
}
/// Classifies why `url` would be refused by the SSRF policy.
///
/// Returns `None` when nothing about the URL is objectionable, otherwise the
/// first matching [`BlockedHostClass`], checked in this order:
/// 1. unparsable URL or missing host → `Malformed`;
/// 2. scheme other than `http` / `https` → `DisallowedScheme`;
/// 3. well-known loopback / metadata host names and the `.internal`,
///    `.localhost`, `.local` suffixes;
/// 4. the class of the first refused address the host resolves to.
///
/// Host-name comparisons ignore ASCII case and a trailing root dot, so
/// `localhost.` is classified like `localhost`.
///
/// A host that fails to resolve yields `None` unless it is itself a refused IP
/// literal. Resolution goes through the blocking `ToSocketAddrs` API.
pub fn classify_blocked_host(url: &str) -> Option<BlockedHostClass> {
    let parsed = match reqwest::Url::parse(url) {
        Ok(u) => u,
        Err(_) => return Some(BlockedHostClass::Malformed),
    };
    match parsed.scheme() {
        "http" | "https" => {}
        _ => return Some(BlockedHostClass::DisallowedScheme),
    }
    let host = match parsed.host_str() {
        Some(h) => h,
        None => return Some(BlockedHostClass::Malformed),
    };
    let host_lower = host.to_lowercase();
    // Drop the root dot of a fully-qualified name so `localhost.` cannot slip
    // past the name-based rules below.
    let host_name = host_lower.trim_end_matches('.');
    match host_name {
        "localhost" | "127.0.0.1" | "::1" | "[::1]" | "0.0.0.0" => {
            return Some(BlockedHostClass::Loopback)
        }
        "metadata.google.internal" | "metadata.goog" | "169.254.169.254" => {
            return Some(BlockedHostClass::LinkLocalMetadata)
        }
        _ => {}
    }
    if host_name.ends_with(".internal") {
        return Some(BlockedHostClass::LinkLocalMetadata);
    }
    if host_name.ends_with(".localhost") {
        return Some(BlockedHostClass::Loopback);
    }
    if host_name.ends_with(".local") {
        return Some(BlockedHostClass::OtherReserved);
    }
    let port = parsed.port().unwrap_or(match parsed.scheme() {
        "https" => 443,
        _ => 80,
    });
    let socket_addr = format!("{}:{}", host, port);
    match socket_addr.to_socket_addrs() {
        Ok(addrs) => {
            for addr in addrs {
                if let Some(class) = classify_blocked_ip(addr.ip()) {
                    return Some(class);
                }
            }
        }
        Err(_) => {
            // `host_str()` keeps the square brackets around IPv6 literals, which
            // `IpAddr::from_str` does not accept; strip them before parsing.
            let literal = host.trim_start_matches('[').trim_end_matches(']');
            if let Ok(ip) = literal.parse::<IpAddr>() {
                if let Some(class) = classify_blocked_ip(ip) {
                    return Some(class);
                }
            }
        }
    }
    None
}
/// Builds the shared HTTP client used for page fetches.
///
/// The client has a 30-second timeout, presents a Firefox user agent together
/// with browser-like default headers, and follows redirects only while each
/// target passes `validate_url_for_ssrf` and fewer than 10 hops have already
/// been taken.
///
/// NOTE(review): `Accept-Encoding` advertises `gzip, deflate, br`. Whether
/// responses in those encodings are decoded depends on which reqwest
/// decompression features the crate enables; confirm against `Cargo.toml`.
///
/// # Panics
/// Panics if the underlying client cannot be constructed.
pub fn build_browser_client() -> Client {
    const BROWSER_HEADERS: [(&str, &str); 10] = [
        (
            "Accept",
            "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        ),
        ("Accept-Language", "en-US,en;q=0.5"),
        ("Accept-Encoding", "gzip, deflate, br"),
        ("DNT", "1"),
        ("Upgrade-Insecure-Requests", "1"),
        ("Sec-Fetch-Dest", "document"),
        ("Sec-Fetch-Mode", "navigate"),
        ("Sec-Fetch-Site", "none"),
        ("Sec-Fetch-User", "?1"),
        ("Sec-GPC", "1"),
    ];

    let mut headers = reqwest::header::HeaderMap::new();
    for &(name, value) in BROWSER_HEADERS.iter() {
        headers.insert(name, value.parse().unwrap());
    }

    // Re-validate every hop: a public URL must not be able to bounce the
    // client onto an internal address.
    let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
        let target = attempt.url().to_string();
        let blocked = validate_url_for_ssrf(&target).is_err();
        if blocked || attempt.previous().len() >= 10 {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    Client::builder()
        .timeout(Duration::from_secs(30))
        .redirect(redirect_policy)
        .user_agent(
            "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:142.0) Gecko/20100101 Firefox/142.0",
        )
        .default_headers(headers)
        .build()
        .expect("failed to build browser HTTP client")
}
/// Tool that downloads a web page and returns its readable text.
pub struct WebFetchTool {
    /// HTTP client reused for every fetch made by this tool.
    client: Client,
}

impl WebFetchTool {
    /// Creates the tool with a client obtained from `build_browser_client`.
    pub fn new() -> Self {
        let client = build_browser_client();
        Self { client }
    }
}
#[async_trait]
impl Tool for WebFetchTool {
    /// Name under which the tool is registered.
    fn name(&self) -> &str {
        "web_fetch"
    }

    /// One-line summary of what the tool is for.
    fn description(&self) -> &str {
        "Fetch a readable web page and extract its content; not for REST/JSON API endpoints"
    }

    /// JSON schema describing the tool's arguments: a required `url` string and
    /// an optional integer `max_chars`.
    fn schema(&self) -> Value {
        json!({
            "name": "web_fetch",
            "description": "Fetch a readable web page and extract its content. Strips ads/navigation. Do NOT use for REST/JSON API endpoints or machine-readable responses; use http_request for APIs. For login-required sites, use browser instead.",
            "parameters": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "The URL to fetch"
                    },
                    "max_chars": {
                        "type": "integer",
                        "description": "Maximum characters to return (default 20000, max 50000)"
                    }
                },
                "required": ["url"],
                "additionalProperties": false
            }
        })
    }

    /// Static capability flags reported for this tool.
    fn capabilities(&self) -> ToolCapabilities {
        ToolCapabilities {
            read_only: true,
            external_side_effect: true,
            needs_approval: false,
            idempotent: true,
            high_impact_write: false,
        }
    }

    /// Describes a call as an observation verified by its result content, with
    /// the requested URL as the target hint. The hint is the empty string when
    /// the arguments are not valid JSON or carry no string `url`.
    fn call_semantics(&self, arguments: &str) -> ToolCallSemantics {
        let url = serde_json::from_str::<Value>(arguments)
            .ok()
            .and_then(|args| {
                args.get("url")
                    .and_then(|value| value.as_str())
                    .map(str::to_string)
            })
            .unwrap_or_default();
        ToolCallSemantics::observation()
            .with_verification_mode(ToolVerificationMode::ResultContent)
            .with_target_hint(ToolTargetHintKind::Url, url)
    }

    /// Fetches `url` and returns its extracted text, truncated to `max_chars`.
    ///
    /// A URL refused by the SSRF check and a non-success HTTP status are both
    /// reported as an `Ok` message rather than an error.
    ///
    /// # Errors
    /// Fails when `arguments` is not valid JSON, when `url` is missing or not a
    /// string, or when sending the request or reading the body fails.
    async fn call(&self, arguments: &str) -> anyhow::Result<String> {
        let args: Value = serde_json::from_str(arguments)?;
        let url = args["url"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("Missing required parameter: url"))?;
        let max_chars = args["max_chars"]
            .as_u64()
            // Clamp while the value is still a u64: casting first would silently
            // wrap an oversized request on targets where usize is 32 bits wide.
            .map(|n| n.clamp(1, MAX_MAX_CHARS as u64) as usize)
            .unwrap_or(DEFAULT_MAX_CHARS);
        if let Err(reason) = validate_url_for_ssrf(url) {
            return Ok(format!("Request blocked: {}", reason));
        }
        let resp = self.client.get(url).send().await?;
        if !resp.status().is_success() {
            return Ok(format!("Error fetching {}: HTTP {}", url, resp.status()));
        }
        let html = resp.text().await?;
        Ok(build_fetch_reply(url, &html, max_chars))
    }
}
/// Extracted text with fewer characters than this (after trimming) is treated as
/// near-empty by `build_fetch_reply`.
const MIN_EXTRACTED_CHARS: usize = 200;
/// The raw HTML must be at least this many bytes before a near-empty extraction
/// is reported as a failure; smaller pages are assumed to be genuinely short.
const MIN_HTML_FOR_EXTRACTION_CHECK: usize = 10_000;
/// Turns fetched HTML into the reply text for `url`.
///
/// Normally the extracted text is formatted and truncated to `max_chars`. When a
/// large page (at least `MIN_HTML_FOR_EXTRACTION_CHECK` bytes) yields fewer than
/// `MIN_EXTRACTED_CHARS` characters, the reply instead carries the little text
/// that was found plus an explicit extraction-failure notice.
fn build_fetch_reply(url: &str, html: &str, max_chars: usize) -> String {
    let text = extract_page_text(html, url);
    let trimmed = text.trim();
    let readable_chars = trimmed.chars().count();

    let extraction_plausible =
        readable_chars >= MIN_EXTRACTED_CHARS || html.len() < MIN_HTML_FOR_EXTRACTION_CHECK;
    if extraction_plausible {
        return format_fetch_result(url, &text, max_chars);
    }

    let page_kb = html.len() / 1024;
    format!(
        "Content from {}:\n\n{}\n\n[⚠EXTRACTION FAILED — this {} KB page yielded only {} \
         characters of readable text; it is likely JavaScript-rendered or blocking \
         non-browser clients. Do NOT treat this as the page's content, and do NOT re-fetch \
         this same URL — the result will not change. Fetch a different source URL from your \
         search results, or use a browser-based tool on this URL if one is available.]",
        url, trimmed, page_kb, readable_chars,
    )
}
/// Extracts readable text from `html`.
///
/// Script, style and noscript blocks are removed first. The readability
/// extractor is tried next; if it fails or produces only whitespace, the
/// cleaned HTML is converted with `htmd`, and if that fails too the cleaned
/// HTML itself is returned. An unparsable `url` is replaced by
/// `http://example.com` as the base URL.
fn extract_page_text(html: &str, url: &str) -> String {
    let cleaned = strip_nonvisible_blocks(html);
    let base_url = match reqwest::Url::parse(url) {
        Ok(parsed) => parsed,
        Err(_) => reqwest::Url::parse("http://example.com").unwrap(),
    };
    let mut reader = Cursor::new(cleaned.as_bytes());
    let readable = llm_readability::extractor::extract(&mut reader, &base_url)
        .ok()
        .map(|product| product.text)
        .filter(|text| !text.trim().is_empty());
    match readable {
        Some(text) => text,
        None => htmd::convert(&cleaned).unwrap_or_else(|_| cleaned.clone()),
    }
}
/// Removes `<script>`, `<style>` and `<noscript>` blocks, in that order, and
/// returns the remaining markup.
fn strip_nonvisible_blocks(html: &str) -> String {
    ["script", "style", "noscript"]
        .iter()
        .fold(html.to_string(), |remaining, tag| {
            strip_tag_blocks(&remaining, tag)
        })
}
/// Removes every `<tag …>…</tag>` block from `html`, matching the tag name
/// ASCII-case-insensitively.
///
/// An opening tag is recognised only when the name is followed by ASCII
/// whitespace, `/` or `>`, so `<scripts>` is left alone when stripping `script`.
/// A closing tag may carry whitespace before its `>` (`</script >`). If a block
/// is never closed, everything from its opening tag to the end is dropped.
fn strip_tag_blocks(html: &str, tag: &str) -> String {
    let open = format!("<{}", tag);
    let close = format!("</{}", tag);
    let bytes = html.as_bytes();
    let mut out = String::with_capacity(html.len());
    let mut pos = 0;
    while let Some(start) = find_ascii_ci(bytes, open.as_bytes(), pos) {
        let name_end = start + open.len();
        if !bytes.get(name_end).map_or(false, |&b| ends_tag_name(b)) {
            // A longer tag name, or the document ends right after the name:
            // not our tag, so keep the text verbatim and keep scanning.
            out.push_str(&html[pos..name_end]);
            pos = name_end;
            continue;
        }
        out.push_str(&html[pos..start]);
        match find_close_tag_end(bytes, close.as_bytes(), name_end) {
            Some(end) => pos = end,
            None => {
                pos = html.len();
                break;
            }
        }
    }
    out.push_str(&html[pos..]);
    out
}

/// Whether `byte` may directly follow a tag name: ASCII whitespace (including
/// `\r` and form feed), `/`, or `>`.
fn ends_tag_name(byte: u8) -> bool {
    byte.is_ascii_whitespace() || byte == b'/' || byte == b'>'
}

/// Finds the closing tag whose `</name` prefix is `close`, searching from
/// `from`, and returns the index just past its terminating `>`.
fn find_close_tag_end(bytes: &[u8], close: &[u8], from: usize) -> Option<usize> {
    let mut search_from = from;
    while let Some(at) = find_ascii_ci(bytes, close, search_from) {
        let name_end = at + close.len();
        if bytes.get(name_end).map_or(false, |&b| ends_tag_name(b)) {
            return bytes[name_end..]
                .iter()
                .position(|&b| b == b'>')
                .map(|offset| name_end + offset + 1);
        }
        // `</scripts>` and the like: a different tag, look further on.
        search_from = name_end;
    }
    None
}

/// Returns the index of the first ASCII-case-insensitive occurrence of `needle`
/// in `haystack` at or after `from`, or `None` if there is none, `needle` is
/// empty, or `from` is past the end.
fn find_ascii_ci(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> {
    if from >= haystack.len() || needle.is_empty() {
        return None;
    }
    haystack[from..]
        .windows(needle.len())
        .position(|w| w.eq_ignore_ascii_case(needle))
        .map(|i| i + from)
}
/// Formats extracted page text as the tool reply, limited to `max_chars`
/// characters.
///
/// The reply always starts with a `Content from <url>:` header. Text that fits
/// the budget is appended unchanged. Longer text is cut after exactly
/// `max_chars` characters and followed by a truncation notice that reports the
/// shown and total character counts.
///
/// The budget is counted in characters, not bytes, so multi-byte text (CJK,
/// emoji) gets the same allowance as ASCII and is never split mid-character.
fn format_fetch_result(url: &str, text: &str, max_chars: usize) -> String {
    let mut result = format!("Content from {}:\n\n", url);
    // `nth(max_chars)` is the first character beyond the budget; its byte
    // offset is where to cut. `None` means the whole text fits.
    match text.char_indices().nth(max_chars) {
        Some((cut, _)) => {
            result.push_str(&text[..cut]);
            result.push_str("\n\n");
            result.push_str(&crate::utils::truncation_notice_with_hint(
                max_chars,
                text.chars().count(),
                "If the answer may be in the omitted part, re-fetch with a larger max_chars \
                 (up to 50000) or fetch a more specific page; if the full content is still \
                 not visible, tell the user the page was longer than you could read.",
            ));
        }
        None => result.push_str(text),
    }
    result
}
// Tests for the SSRF policy: every URL that `classify_blocked_host` refuses must
// also be refused by `validate_url_for_ssrf`, and class labels must be fixed
// strings that never echo request data.
#[cfg(test)]
mod ssrf_policy_tests {
use super::*;
// Asserts that `url` is classified as `expected`, that the validator rejects it
// too, and that the class label does not appear inside the URL.
fn assert_blocked(url: &str, expected: BlockedHostClass) {
let class = classify_blocked_host(url);
assert_eq!(
class,
Some(expected),
"expected {url} blocked as {expected:?}, got {class:?}"
);
assert!(
validate_url_for_ssrf(url).is_err(),
"validate_url_for_ssrf must also reject {url}"
);
let label = expected.label();
assert!(
!url.contains(label) || label.is_empty(),
"label must be a fixed class string, not derived from the url"
);
}
// Asserts that the classifier raises no objection to `url`.
fn assert_allowed(url: &str) {
assert_eq!(
classify_blocked_host(url),
None,
"{url} should be allowed (public)"
);
}
// Loopback literals, `localhost`, `[::1]` and the unspecified address.
#[test]
fn loopback_is_blocked() {
assert_blocked("http://127.0.0.1/", BlockedHostClass::Loopback);
assert_blocked("http://127.0.0.1:8080/admin", BlockedHostClass::Loopback);
assert_blocked("http://127.5.6.7/", BlockedHostClass::Loopback); assert_blocked("http://localhost/", BlockedHostClass::Loopback);
assert_blocked("http://[::1]/", BlockedHostClass::Loopback);
assert_blocked("http://0.0.0.0/", BlockedHostClass::Loopback);
}
// 10/8, 172.16/12, 192.168/16 and 100.64/10 all classify as private network.
#[test]
fn rfc1918_private_ranges_are_blocked() {
assert_blocked("http://10.0.0.1/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://10.255.255.255/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://172.16.0.1/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://172.31.255.255/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://192.168.1.1/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://100.64.0.1/", BlockedHostClass::PrivateNetwork);
}
// 169.254/16, the named metadata host and any `.internal` name.
#[test]
fn link_local_and_cloud_metadata_are_blocked() {
assert_blocked(
"http://169.254.169.254/latest/meta-data/",
BlockedHostClass::LinkLocalMetadata,
);
assert_blocked("http://169.254.0.1/", BlockedHostClass::LinkLocalMetadata);
assert_blocked(
"http://metadata.google.internal/",
BlockedHostClass::LinkLocalMetadata,
);
assert_blocked(
"http://anything.internal/",
BlockedHostClass::LinkLocalMetadata,
);
}
// `::ffff:a.b.c.d` takes the class of the embedded IPv4 address.
#[test]
fn ipv4_mapped_ipv6_is_blocked_by_embedded_class() {
assert_blocked("http://[::ffff:127.0.0.1]/", BlockedHostClass::Loopback);
assert_blocked(
"http://[::ffff:10.0.0.1]/",
BlockedHostClass::PrivateNetwork,
);
assert_blocked(
"http://[::ffff:169.254.169.254]/",
BlockedHostClass::LinkLocalMetadata,
);
}
// fc00::/7 is private network, fe80::/10 is link-local.
#[test]
fn ipv6_unique_local_and_link_local_are_blocked() {
assert_blocked("http://[fc00::1]/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://[fd12:3456::1]/", BlockedHostClass::PrivateNetwork);
assert_blocked("http://[fe80::1]/", BlockedHostClass::LinkLocalMetadata);
}
// Anything other than http/https is refused before the host is looked at.
#[test]
fn disallowed_schemes_are_blocked() {
assert_blocked("file:///etc/passwd", BlockedHostClass::DisallowedScheme);
assert_blocked("ftp://example.com/", BlockedHostClass::DisallowedScheme);
assert_blocked(
"data:text/html,<h1>hi</h1>",
BlockedHostClass::DisallowedScheme,
);
}
// Unparsable input and a host-less URL both classify as malformed.
#[test]
fn malformed_urls_are_blocked() {
assert_eq!(
classify_blocked_host("not a url"),
Some(BlockedHostClass::Malformed)
);
assert_eq!(
classify_blocked_host("http://"),
Some(BlockedHostClass::Malformed)
);
}
// IP literals with explicit ports go through the resolve path; a host that does
// not resolve is deliberately left unclassified.
#[test]
fn resolve_path_classifies_private_targets() {
assert_blocked("http://10.1.2.3:8443/x", BlockedHostClass::PrivateNetwork);
assert_blocked("http://127.0.0.1:9000/x", BlockedHostClass::Loopback);
assert_eq!(
classify_blocked_host("http://nonexistent-host.invalid/"),
None,
"unresolvable host is not pre-flight blocked (request-time only)"
);
}
// Public names and public IP literals must pass.
// NOTE(review): the named hosts are resolved via DNS when the test runs.
#[test]
fn public_urls_are_allowed() {
assert_allowed("https://example.com/");
assert_allowed("https://www.rust-lang.org/learn");
assert_allowed("http://93.184.216.34/"); assert_allowed("https://8.8.8.8/");
}
// Every label is non-empty and free of URL-like or address-like fragments.
#[test]
fn labels_are_fixed_and_leak_no_data() {
for class in [
BlockedHostClass::Loopback,
BlockedHostClass::PrivateNetwork,
BlockedHostClass::LinkLocalMetadata,
BlockedHostClass::DisallowedScheme,
BlockedHostClass::Malformed,
BlockedHostClass::OtherReserved,
] {
let label = class.label();
assert!(!label.is_empty());
for forbidden in ["://", "?", "=", "@", "127.", "169.254", "secret"] {
assert!(
!label.contains(forbidden),
"label {label:?} must not contain {forbidden:?}"
);
}
}
}
}
// Tests for `format_fetch_result`: header, truncation notice and safe cutting of
// multi-byte text.
#[cfg(test)]
mod format_tests {
use super::*;
// Over-long text keeps the header and gains a truncation notice.
// NOTE(review): the asserted notice wording ("OUTPUT TRUNCATED", "Do NOT enumerate")
// presumably comes from `crate::utils::truncation_notice_with_hint` — confirm there.
#[test]
fn truncated_fetch_carries_instructional_notice() {
let text = "a".repeat(120);
let out = format_fetch_result("https://example.com/page", &text, 50);
assert!(out.contains("Content from https://example.com/page"));
assert!(out.contains("OUTPUT TRUNCATED"));
assert!(out.contains("Do NOT enumerate"));
assert!(out.contains("max_chars"));
assert!(!out.contains("wc -l"));
assert!(!out.contains("[Truncated]"));
}
// Text within the budget is returned as-is, with no notice.
#[test]
fn untruncated_fetch_has_no_notice() {
let out = format_fetch_result("https://example.com", "short content", 1000);
assert!(out.contains("short content"));
assert!(!out.contains("OUTPUT TRUNCATED"));
}
// A limit that lands inside a multi-byte character must not panic.
#[test]
fn truncation_respects_char_boundaries() {
let text = format!("{}🦀🦀🦀", "x".repeat(49));
let out = format_fetch_result("https://example.com", &text, 51);
assert!(out.contains("OUTPUT TRUNCATED"));
}
}
// Tests for `build_fetch_reply`: script/style stripping and the
// extraction-failure notice.
#[cfg(test)]
mod extraction_tests {
use super::*;
// Builds a page whose <head> carries ~8 KB of script and ~8 KB of style ahead of a
// short readable body with a heading, repeated prose and a two-row table.
fn script_heavy_page() -> String {
format!(
"<html><head><title>Ecuador national football team - Wikipedia</title>\
<script>(function(){{var className=\"client-js vector-feature-language-{}\";}})();</script>\
<style>.mw-parser{{display:none}}{}</style></head>\
<body><h1>Squad</h1><p>{}</p>\
<table><tr><td>Willian Pacho</td></tr><tr><td>Moises Caicedo</td></tr></table>\
</body></html>",
"x".repeat(8_000),
"y".repeat(8_000),
"The current squad was announced ahead of the tournament. ".repeat(20),
)
}
// Script and style bodies must not appear in the reply; real content must.
#[test]
fn script_and_style_blocks_never_reach_the_model() {
let reply = build_fetch_reply(
"https://en.wikipedia.org/wiki/X",
&script_heavy_page(),
2_000,
);
assert!(
!reply.contains("var className"),
"script body leaked: {}",
reply
);
assert!(
!reply.contains("display:none"),
"style body leaked: {}",
reply
);
assert!(reply.contains("Willian Pacho") || reply.contains("current squad"));
}
// A 20 KB page that is all script yields the extraction-failure notice instead
// of a truncation notice.
#[test]
fn near_empty_extraction_gets_instructional_notice() {
let html = format!(
"<html><head><script>{}</script></head><body><div id=\"root\"></div></body></html>",
"z".repeat(20_000)
);
let reply = build_fetch_reply("https://spa.example.com", &html, 20_000);
assert!(
reply.contains("EXTRACTION FAILED"),
"missing notice: {}",
reply
);
assert!(reply.contains("different source"));
assert!(!reply.contains("OUTPUT TRUNCATED"));
}
// A page with enough readable text is not flagged.
#[test]
fn normal_page_has_no_extraction_notice() {
let reply = build_fetch_reply("https://example.com", &script_heavy_page(), 20_000);
assert!(!reply.contains("EXTRACTION FAILED"));
}
// Tiny pages fall below the HTML size threshold and are never flagged.
#[test]
fn small_thin_pages_are_not_flagged() {
let reply = build_fetch_reply(
"https://example.com",
"<html><body><p>hi</p></body></html>",
20_000,
);
assert!(!reply.contains("EXTRACTION FAILED"));
}
}