use crate::providers::ToolDefinition;
use anyhow::Result;
use serde_json::{Value, json};
/// Timeout, in seconds, that `web_fetch` applies to the request-and-redirect phase of a fetch.
const DEFAULT_TIMEOUT_SECS: u64 = 15;
/// Redirect hop limit that `web_fetch` passes to `safely_follow_redirects`.
pub(crate) const MAX_REDIRECTS: usize = 10;
/// Value sent as the `User-Agent` header by `safely_follow_redirects`.
const USER_AGENT: &str = "Koda/0.1 (AI coding agent)";
/// Returns the tool definitions exposed by this module.
///
/// The list holds a single entry, `WebFetch`, whose JSON schema takes a
/// required `url` string and an optional `raw` boolean.
pub fn definitions() -> Vec<ToolDefinition> {
    let url_schema = json!({
        "type": "string",
        "description": "The URL to fetch (must start with http:// or https://)"
    });
    let raw_schema = json!({
        "type": "boolean",
        "description": "If true, return raw HTML instead of stripped text (default: false)"
    });
    let parameters = json!({
        "type": "object",
        "properties": {
            "url": url_schema,
            "raw": raw_schema
        },
        "required": ["url"]
    });
    let description = String::from(
        "Fetch content from a URL. HTML is stripped to readable text by default; \
         set raw=true for raw HTML. Only use URLs from tool results or user input — \
         never guess or generate URLs from memory. \
         For documentation lookup, prefer reading local files first.",
    );

    vec![ToolDefinition {
        name: String::from("WebFetch"),
        description,
        parameters,
    }]
}
/// Parses `url_str` and rejects it unless it is considered safe to fetch.
///
/// Checks run in this order: the string must start with `http://` or
/// `https://`; it must pass `is_safe_url`; it must parse as a URL; and, when
/// the host is a domain name, every address it resolves to must pass
/// `is_safe_ip`.
///
/// # Errors
/// Returns an error when any of those checks fails, or when DNS resolution of
/// a domain host fails.
async fn validate_url_safety(url_str: &str) -> Result<url::Url> {
    let has_http_scheme = url_str.starts_with("http://") || url_str.starts_with("https://");
    if !has_http_scheme {
        anyhow::bail!("URL must start with http:// or https://");
    }

    if !is_safe_url(url_str) {
        anyhow::bail!(
            "URL blocked: requests to internal/private networks are not allowed. \
             This includes localhost, private IPs, and cloud metadata endpoints."
        );
    }

    let parsed = match url::Url::parse(url_str) {
        Ok(url) => url,
        Err(e) => anyhow::bail!("Failed to parse URL '{url_str}': {e}"),
    };

    // IP-literal hosts were already vetted by `is_safe_url`; only domain names
    // need to be resolved and have their addresses checked.
    if let Some(url::Host::Domain(host)) = parsed.host() {
        let port = parsed.port_or_known_default().unwrap_or(80);
        let resolved = tokio::net::lookup_host(format!("{host}:{port}"))
            .await
            .map_err(|e| anyhow::anyhow!("DNS resolution failed for '{host}': {e}"))?;

        // A single unsafe address is enough to reject the whole domain.
        let first_unsafe = resolved
            .map(|addr| addr.ip())
            .find(|candidate| !is_safe_ip(*candidate));
        if let Some(unsafe_ip) = first_unsafe {
            anyhow::bail!(
                "URL blocked: domain '{host}' resolves to private/internal IP {}.",
                unsafe_ip
            );
        }
    }

    Ok(parsed)
}
pub(crate) async fn safely_follow_redirects<F, Fut>(
client: &reqwest::Client,
initial_url: url::Url,
max_hops: usize,
validator: F,
) -> Result<reqwest::Response>
where
F: Fn(String) -> Fut,
Fut: std::future::Future<Output = Result<url::Url>>,
{
let mut current_url = initial_url;
for hop in 0..=max_hops {
let response = client
.get(current_url.clone())
.header("User-Agent", USER_AGENT)
.send()
.await
.map_err(|e| anyhow::anyhow!("HTTP request failed: {e}"))?;
let status = response.status();
if !status.is_redirection() {
return Ok(response);
}
if !matches!(status.as_u16(), 301 | 302 | 303 | 307 | 308) {
return Ok(response);
}
if hop == max_hops {
anyhow::bail!(
"WebFetch exceeded max redirect hops ({max_hops}); last URL: {current_url}"
);
}
let location = response
.headers()
.get(reqwest::header::LOCATION)
.ok_or_else(|| {
anyhow::anyhow!(
"Redirect status {status} from {current_url} but no Location header"
)
})?
.to_str()
.map_err(|e| {
anyhow::anyhow!(
"Redirect Location header from {current_url} is not valid UTF-8: {e}"
)
})?
.to_string();
let next_url = current_url.join(&location).map_err(|e| {
anyhow::anyhow!(
"Failed to resolve redirect Location '{location}' against {current_url}: {e}"
)
})?;
current_url = validator(next_url.to_string()).await.map_err(|e| {
anyhow::anyhow!("Redirect from {current_url} to {next_url} blocked by SSRF policy: {e}")
})?;
}
unreachable!("safely_follow_redirects loop exited without returning")
}
/// Returns the shared HTTP client used by `web_fetch`, building it on first use.
///
/// The client is created with `reqwest::redirect::Policy::none()`, so it does
/// not follow redirects on its own.
fn web_fetch_client() -> &'static reqwest::Client {
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();

    // Initializer run at most once by the `OnceLock`.
    fn build() -> reqwest::Client {
        let no_auto_redirects = reqwest::redirect::Policy::none();
        crate::providers::build_http_client_with_redirect_policy(None, no_auto_redirects)
    }

    CLIENT.get_or_init(build)
}
/// Fetches the URL in `args["url"]` and returns its body as text.
///
/// `args["raw"]` (default `false`) selects between the raw body and the output
/// of `strip_html`. The result is capped at `max_body_chars` characters; when
/// it is longer, the first `max_body_chars` characters are returned followed
/// by a truncation notice.
///
/// # Errors
/// Returns an error when `url` is missing or not a string, when the URL or any
/// redirect target is rejected by `validate_url_safety`, when the request
/// phase exceeds `DEFAULT_TIMEOUT_SECS`, when the final status is not a
/// success, or when the body cannot be read as text.
pub async fn web_fetch(args: &Value, max_body_chars: usize) -> Result<String> {
    let url_str = args["url"]
        .as_str()
        .ok_or_else(|| anyhow::anyhow!("Missing 'url' argument"))?;
    let raw = args["raw"].as_bool().unwrap_or(false);

    let initial_url = validate_url_safety(url_str).await?;
    let client = web_fetch_client();

    // The outer `?` handles the timeout, the inner one the fetch result.
    // NOTE(review): the timeout wraps only the request/redirect phase; the
    // body read below is not covered by it — confirm the client has its own
    // timeout configured.
    let response = tokio::time::timeout(
        std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS),
        safely_follow_redirects(client, initial_url, MAX_REDIRECTS, |u| async move {
            validate_url_safety(&u).await
        }),
    )
    .await
    .map_err(|_| anyhow::anyhow!("Request timed out after {DEFAULT_TIMEOUT_SECS}s"))??;

    let final_url = response.url().clone();
    let status = response.status();
    if !status.is_success() {
        anyhow::bail!("HTTP {status} for {final_url}");
    }

    let body = response
        .text()
        .await
        .map_err(|e| anyhow::anyhow!("Failed to read response body: {e}"))?;
    let content = if raw { body } else { strip_html(&body) };

    Ok(truncate_to_char_limit(content, max_body_chars))
}

/// Caps `content` at `max_chars` characters, appending a truncation notice
/// that reports the full character count when anything was cut.
fn truncate_to_char_limit(content: String, max_chars: usize) -> String {
    // `nth(max_chars)` is the first character beyond the limit. Its byte
    // offset is a character boundary, so slicing there cannot panic the way a
    // raw byte-count slice does on multi-byte UTF-8 text.
    match content.char_indices().nth(max_chars) {
        None => content,
        Some((cut, _)) => format!(
            "{}\n\n[TRUNCATED: response was {} chars. \
             Consider fetching a more specific URL.]",
            &content[..cut],
            content.chars().count()
        ),
    }
}
/// Returns `true` when `ip` is acceptable as a fetch target and `false` when
/// it refers to the local host or to a private/internal network.
///
/// Blocked IPv4 ranges: `0.0.0.0/8`, `10.0.0.0/8`, `100.64.0.0/10`,
/// `127.0.0.0/8`, `169.254.0.0/16`, `172.16.0.0/12` and `192.168.0.0/16`.
/// Blocked IPv6 ranges: `::`, `::1`, unique-local `fc00::/7` and link-local
/// `fe80::/10`. An IPv4-mapped IPv6 address is judged by its embedded IPv4
/// address.
pub(crate) fn is_safe_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [a, b, _, _] = v4.octets();
            let blocked = a == 0 // "this network", includes the unspecified address
                || a == 10
                || a == 127
                || (a == 100 && (64..=127).contains(&b)) // carrier-grade NAT shared space
                || (a == 169 && b == 254) // link-local, includes cloud metadata
                || (a == 172 && (16..=31).contains(&b))
                || (a == 192 && b == 168);
            !blocked
        }
        std::net::IpAddr::V6(v6) => {
            if v6.is_loopback() || v6.is_unspecified() {
                return false;
            }
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_safe_ip(std::net::IpAddr::V4(v4));
            }
            // Prefix tests on the first 16-bit group: fc00::/7 and fe80::/10.
            let first_group = v6.segments()[0];
            let unique_local = (first_group & 0xfe00) == 0xfc00;
            let link_local = (first_group & 0xffc0) == 0xfe80;
            !(unique_local || link_local)
        }
    }
}
/// Returns `true` when `url_str` parses as a URL whose host is not an obvious
/// internal target.
///
/// IP-literal hosts are judged by `is_safe_ip`. Domain hosts are compared,
/// after removing trailing dots and lowercasing ASCII letters, against a
/// denylist of exact names and of suffixes (`.internal`, `.local`,
/// `.localhost`). A URL that fails to parse or has no host is not safe.
///
/// This is a string-level check only: it performs no DNS resolution, so a
/// public-looking domain that resolves to an internal address still passes.
pub(crate) fn is_safe_url(url_str: &str) -> bool {
    const BLOCKED_HOSTS: [&str; 5] = [
        "169.254.169.254",
        "metadata.google.internal",
        "metadata.internal",
        "localhost",
        "0.0.0.0",
    ];
    const BLOCKED_SUFFIXES: [&str; 3] = [".internal", ".local", ".localhost"];

    let Ok(parsed) = url::Url::parse(url_str) else {
        return false;
    };

    match parsed.host() {
        None => false,
        Some(url::Host::Ipv4(ip)) => is_safe_ip(std::net::IpAddr::V4(ip)),
        Some(url::Host::Ipv6(ip)) => is_safe_ip(std::net::IpAddr::V6(ip)),
        Some(url::Host::Domain(domain)) => {
            // A fully-qualified name such as "localhost." names the same host
            // as "localhost"; normalise before comparing so the trailing dot
            // cannot be used to slip past the denylist.
            let name = domain.trim_end_matches('.').to_ascii_lowercase();
            let exact_match = BLOCKED_HOSTS.contains(&name.as_str());
            let suffix_match = BLOCKED_SUFFIXES
                .iter()
                .any(|suffix| name.ends_with(*suffix));
            !(exact_match || suffix_match)
        }
    }
}
/// Converts HTML into plain text.
///
/// Everything between `<` and `>` is dropped, as is the whole content of
/// `script` and `style` elements. Tags whose name starts with `br`, `p`,
/// `div`, `h`, `li` or `tr` start a new line in the output. Runs of whitespace
/// collapse to a single space. Finally a small set of character entities is
/// decoded (see `decode_html_entities`).
///
/// Tag names are matched ASCII-case-insensitively against the original
/// characters, so non-ASCII text can never shift the match positions.
fn strip_html(html: &str) -> String {
    // Opening-tag prefixes that start a new output line.
    const LINE_BREAK_TAGS: [&str; 6] = ["<br", "<p", "<div", "<h", "<li", "<tr"];

    // True when `haystack` begins with `needle`, ignoring ASCII case.
    fn has_prefix(haystack: &[char], needle: &str) -> bool {
        let mut hay = haystack.iter();
        needle
            .chars()
            .all(|want| hay.next().is_some_and(|got| got.eq_ignore_ascii_case(&want)))
    }

    let chars: Vec<char> = html.chars().collect();
    let mut text = String::with_capacity(html.len());
    let mut in_tag = false;
    // While inside a script/style element: the closing tag that ends it.
    let mut raw_text_closer: Option<&'static str> = None;
    let mut last_was_space = false;

    let mut i = 0;
    while i < chars.len() {
        let rest = &chars[i..];

        if let Some(closer) = raw_text_closer {
            if has_prefix(rest, closer) {
                raw_text_closer = None;
                // The closing tag is consumed whole, so we are back in text;
                // without this reset, text right after it would be dropped.
                in_tag = false;
                // Closers are ASCII, so byte length equals character count.
                i += closer.len();
            } else {
                i += 1;
            }
            continue;
        }

        match chars[i] {
            '<' => {
                if has_prefix(rest, "<script") {
                    raw_text_closer = Some("</script>");
                } else if has_prefix(rest, "<style") {
                    raw_text_closer = Some("</style>");
                }
                in_tag = true;
                if LINE_BREAK_TAGS.iter().any(|tag| has_prefix(rest, tag)) {
                    text.push('\n');
                    // Treat the newline as whitespace so no space follows it.
                    last_was_space = true;
                }
            }
            '>' => in_tag = false,
            _ if in_tag => {}
            ch if ch.is_whitespace() => {
                if !last_was_space {
                    text.push(' ');
                    last_was_space = true;
                }
            }
            ch => {
                text.push(ch);
                last_was_space = false;
            }
        }
        i += 1;
    }

    decode_html_entities(&text)
}

/// Decodes the named entities `amp`, `lt`, `gt`, `quot`, `nbsp` and the
/// numeric entity `#39` in a single left-to-right pass.
///
/// Decoding in one pass means the output of one replacement is never decoded
/// a second time. Unrecognised entities and bare ampersands are kept as-is.
fn decode_html_entities(text: &str) -> String {
    // Entity bodies (the part after the ampersand) and their replacements.
    const ENTITIES: [(&str, char); 6] = [
        ("amp;", '&'),
        ("lt;", '<'),
        ("gt;", '>'),
        ("quot;", '"'),
        ("#39;", '\''),
        ("nbsp;", ' '),
    ];

    let mut out = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp_at) = rest.find('&') {
        out.push_str(&rest[..amp_at]);
        let after_amp = &rest[amp_at + 1..];
        match ENTITIES
            .iter()
            .find(|(body, _)| after_amp.starts_with(*body))
        {
            Some((body, replacement)) => {
                out.push(*replacement);
                rest = &after_amp[body.len()..];
            }
            None => {
                out.push('&');
                rest = after_amp;
            }
        }
    }
    out.push_str(rest);
    out
}
use crate::tools::{Tool, ToolEffect, ToolExecCtx, ToolResult};
use async_trait::async_trait;
/// `Tool` implementation that exposes `web_fetch` under the name `WebFetch`.
pub struct WebFetchTool;

#[async_trait]
impl Tool for WebFetchTool {
    /// Name under which the tool is registered.
    fn name(&self) -> &'static str {
        "WebFetch"
    }

    /// Returns the `WebFetch` entry from `definitions()`.
    ///
    /// Panics if `definitions()` no longer contains that entry.
    fn definition(&self) -> ToolDefinition {
        let mut candidates = definitions().into_iter();
        let found = candidates.find(|def| def.name == "WebFetch");
        found.expect("web_fetch::definitions() must contain WebFetch")
    }

    /// Always classified as read-only, regardless of the arguments.
    fn classify(&self, _args: &serde_json::Value) -> ToolEffect {
        ToolEffect::ReadOnly
    }

    /// Runs `web_fetch` with the body cap taken from `ctx.caps.web_body_chars`
    /// and converts the outcome with `crate::tools::wrap_result`.
    async fn execute(&self, ctx: &ToolExecCtx<'_>, args: &serde_json::Value) -> ToolResult {
        let body_cap = ctx.caps.web_body_chars;
        let outcome = web_fetch(args, body_cap).await;
        crate::tools::wrap_result(outcome)
    }
}
/// Unit tests for HTML stripping and the URL/IP safety checks, plus
/// integration tests that drive `safely_follow_redirects` against a local
/// scripted HTTP server.
#[cfg(test)]
mod tests {
    use super::*;

    // ── HTML stripping ──────────────────────────────────────────────────

    #[test]
    fn test_strip_html_basic() {
        let html = "<h1>Hello</h1><p>World & friends</p>";
        let result = strip_html(html);
        assert!(result.contains("Hello"));
        assert!(result.contains("World & friends"));
        assert!(!result.contains("<h1>"));
    }

    #[test]
    fn test_strip_html_script_removal() {
        let html = "<p>Before</p><script>alert('xss')</script><p>After</p>";
        let result = strip_html(html);
        assert!(result.contains("Before"));
        assert!(result.contains("After"));
        assert!(!result.contains("alert"));
    }

    #[test]
    fn test_strip_html_whitespace_collapse() {
        // Build the runs of spaces explicitly so the input really contains
        // consecutive whitespace, then check that no double space survives
        // while the single separators between words do.
        let gap = " ".repeat(4);
        let html = format!("<p>lots{gap}of{gap}spaces</p>");
        let result = strip_html(&html);
        assert!(result.contains("lots of spaces"));
        assert!(!result.contains(&" ".repeat(2)));
    }

    // ── Argument and URL validation ─────────────────────────────────────

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_web_fetch_bad_url() {
        let args = json!({ "url": "not-a-url" });
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_is_safe_url_blocks_metadata() {
        assert!(!is_safe_url("http://169.254.169.254/latest/meta-data/"));
        assert!(!is_safe_url("http://metadata.google.internal/"));
    }

    #[test]
    fn test_is_safe_url_blocks_localhost() {
        assert!(!is_safe_url("http://localhost:8080/admin"));
        assert!(!is_safe_url("http://127.0.0.1/secret"));
        assert!(!is_safe_url("http://0.0.0.0/"));
    }

    #[test]
    fn test_is_safe_url_blocks_private_ips() {
        assert!(!is_safe_url("http://10.0.0.1/internal"));
        assert!(!is_safe_url("http://172.16.0.1/admin"));
        assert!(!is_safe_url("http://192.168.1.1/config"));
    }

    #[test]
    fn test_is_safe_url_blocks_userinfo_bypass() {
        // The part before '@' is userinfo; the real host follows it.
        assert!(!is_safe_url(
            "http://evil.com@169.254.169.254/latest/meta-data/"
        ));
        assert!(!is_safe_url("http://user:pass@127.0.0.1/"));
    }

    #[test]
    fn test_is_safe_url_blocks_ipv6_mapped() {
        assert!(!is_safe_url("http://[::ffff:127.0.0.1]/"));
        assert!(!is_safe_url("http://[::1]/"));
    }

    #[test]
    fn test_is_safe_url_allows_public() {
        assert!(is_safe_url("https://docs.rs/tokio/latest/tokio/"));
        assert!(is_safe_url("https://api.github.com/repos"));
        assert!(is_safe_url("https://example.com"));
    }

    #[test]
    fn test_is_safe_ip_blocks_private() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
        assert!(!is_safe_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(!is_safe_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED)));
    }

    #[test]
    fn test_is_safe_ip_allows_public() {
        use std::net::{IpAddr, Ipv4Addr};
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))));
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34))));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_web_fetch_blocks_ssrf() {
        let args = json!({ "url": "http://169.254.169.254/latest/meta-data/" });
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_web_fetch_missing_url() {
        let args = json!({});
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
    }

    // ── Scripted local server for redirect tests ────────────────────────

    use axum::{Router, extract::State, http::StatusCode, response::IntoResponse, routing::get};
    use std::sync::{Arc, Mutex as StdMutex};
    use tokio_util::sync::CancellationToken;

    /// One scripted server response.
    #[derive(Clone, Debug)]
    enum Step {
        /// Reply 302 with this `Location` value.
        Redirect(String),
        /// Reply 200 with this body.
        Ok(String),
    }

    /// Shared script for the test server; stored reversed so `pop()` yields
    /// the steps in their original order.
    #[derive(Clone)]
    struct ServerState {
        steps: Arc<StdMutex<Vec<Step>>>,
    }

    /// Serves the next scripted step for any request, or a 500 once the
    /// script has run out.
    async fn handler(
        State(state): State<ServerState>,
        uri: axum::http::Uri,
    ) -> axum::response::Response {
        let step = state.steps.lock().expect("steps mutex poisoned").pop();
        match step {
            Some(Step::Redirect(loc)) => (
                StatusCode::FOUND,
                [(axum::http::header::LOCATION, loc)],
                String::new(),
            )
                .into_response(),
            Some(Step::Ok(body)) => (StatusCode::OK, body).into_response(),
            None => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("unexpected request to {uri} — test script exhausted"),
            )
                .into_response(),
        }
    }

    /// Starts a server on an ephemeral 127.0.0.1 port that answers requests
    /// with `steps` in order. Returns its base URL and a token that shuts the
    /// server down when cancelled.
    async fn spawn_test_server(steps: Vec<Step>) -> (String, CancellationToken) {
        let mut reversed = steps;
        reversed.reverse();
        let state = ServerState {
            steps: Arc::new(StdMutex::new(reversed)),
        };
        let app = Router::new()
            .route("/{*path}", get(handler))
            .route("/", get(handler))
            .with_state(state);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let url = format!("http://{addr}");
        let ct = CancellationToken::new();
        let ct_server = ct.clone();
        tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move { ct_server.cancelled_owned().await })
                .await
                .ok();
        });
        (url, ct)
    }

    /// Validator that only parses the URL. The test server lives on loopback,
    /// which the production validator rejects, so redirect-mechanics tests
    /// use this one instead.
    async fn permissive_validator(url: String) -> Result<url::Url> {
        url::Url::parse(&url).map_err(|e| anyhow::anyhow!("parse: {e}"))
    }

    /// Client built the same way as the production one: no automatic redirects.
    fn test_client() -> reqwest::Client {
        crate::providers::build_http_client_with_redirect_policy(
            None,
            reqwest::redirect::Policy::none(),
        )
    }

    // ── Redirect handling ───────────────────────────────────────────────

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_redirect_to_loopback_is_blocked_by_production_validator() {
        let (server_url, ct) = spawn_test_server(vec![Step::Redirect(
            "http://127.0.0.1:1/secret".to_string(),
        )])
        .await;
        let initial = url::Url::parse(&server_url).unwrap();
        let client = test_client();
        let result = safely_follow_redirects(&client, initial, MAX_REDIRECTS, |u| async move {
            validate_url_safety(&u).await
        })
        .await;
        ct.cancel();
        let err = result.expect_err("redirect to loopback must be rejected");
        let msg = err.to_string();
        assert!(
            msg.contains("blocked") || msg.contains("SSRF"),
            "error should mention SSRF/blocked, got: {msg}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_redirect_to_cloud_metadata_is_blocked() {
        let (server_url, ct) = spawn_test_server(vec![Step::Redirect(
            "http://169.254.169.254/latest/meta-data/iam/security-credentials/".to_string(),
        )])
        .await;
        let initial = url::Url::parse(&server_url).unwrap();
        let client = test_client();
        let result = safely_follow_redirects(&client, initial, MAX_REDIRECTS, |u| async move {
            validate_url_safety(&u).await
        })
        .await;
        ct.cancel();
        assert!(
            result.is_err(),
            "redirect to 169.254.169.254 must be rejected"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_max_redirect_hops_enforced() {
        // More redirects are scripted than the cap of 3 allows.
        let mut steps: Vec<Step> = (0..11)
            .map(|i| Step::Redirect(format!("/hop{i}")))
            .collect();
        steps.push(Step::Ok("never reached".to_string()));
        let (server_url, ct) = spawn_test_server(steps).await;
        let initial = url::Url::parse(&server_url).unwrap();
        let client = test_client();
        let result = safely_follow_redirects(&client, initial, 3, permissive_validator).await;
        ct.cancel();
        let err = result.expect_err("hop limit must be enforced");
        let msg = err.to_string();
        assert!(
            msg.contains("max redirect hops"),
            "error should mention hop cap, got: {msg}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_relative_redirect_resolves_against_current_url() {
        let (server_url, ct) = spawn_test_server(vec![
            Step::Redirect("/b".to_string()),
            Step::Ok("final body".to_string()),
        ])
        .await;
        let initial = url::Url::parse(&format!("{server_url}/a")).unwrap();
        let client = test_client();
        let response =
            safely_follow_redirects(&client, initial, MAX_REDIRECTS, permissive_validator)
                .await
                .expect("relative redirect should succeed");
        let final_url = response.url().clone();
        let body = response.text().await.unwrap();
        ct.cancel();
        assert_eq!(body, "final body");
        assert!(
            final_url.path().ends_with("/b"),
            "final URL should be the relative-resolved /b, got: {final_url}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_scheme_relative_redirect_revalidated() {
        // "//host/path" inherits the scheme but names a different host.
        let (server_url, ct) =
            spawn_test_server(vec![Step::Redirect("//127.0.0.1:1/x".to_string())]).await;
        let initial = url::Url::parse(&server_url).unwrap();
        let client = test_client();
        let result = safely_follow_redirects(&client, initial, MAX_REDIRECTS, |u| async move {
            validate_url_safety(&u).await
        })
        .await;
        ct.cancel();
        assert!(
            result.is_err(),
            "scheme-relative redirect to loopback must be rejected"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_happy_path_two_hop_redirect_chain() {
        let (server_url, ct) = spawn_test_server(vec![
            Step::Redirect("/step2".to_string()),
            Step::Redirect("/final".to_string()),
            Step::Ok("hello world".to_string()),
        ])
        .await;
        let initial = url::Url::parse(&server_url).unwrap();
        let client = test_client();
        let response =
            safely_follow_redirects(&client, initial, MAX_REDIRECTS, permissive_validator)
                .await
                .expect("happy path should succeed");
        let body = response.text().await.unwrap();
        ct.cancel();
        assert_eq!(body, "hello world");
    }
}