koda_core/tools/
web_fetch.rs1use crate::providers::ToolDefinition;
19use anyhow::Result;
20use serde_json::{Value, json};
21
/// Time limit, in seconds, that `web_fetch` allows for sending the request and
/// receiving the response; exceeding it yields a "Request timed out" error.
const DEFAULT_TIMEOUT_SECS: u64 = 15;
23
24pub fn definitions() -> Vec<ToolDefinition> {
26 vec![ToolDefinition {
27 name: "WebFetch".to_string(),
28 description: "Fetch content from a URL. HTML is stripped to readable text by default; \
29 set raw=true for raw HTML. Only use URLs from tool results or user input — \
30 never guess or generate URLs from memory. \
31 For documentation lookup, prefer reading local files first."
32 .to_string(),
33 parameters: json!({
34 "type": "object",
35 "properties": {
36 "url": {
37 "type": "string",
38 "description": "The URL to fetch (must start with http:// or https://)"
39 },
40 "raw": {
41 "type": "boolean",
42 "description": "If true, return raw HTML instead of stripped text (default: false)"
43 }
44 },
45 "required": ["url"]
46 }),
47 }]
48}
49
50pub async fn web_fetch(args: &Value, max_body_chars: usize) -> Result<String> {
52 let url = args["url"]
53 .as_str()
54 .ok_or_else(|| anyhow::anyhow!("Missing 'url' argument"))?;
55 let raw = args["raw"].as_bool().unwrap_or(false);
56
57 if !url.starts_with("http://") && !url.starts_with("https://") {
58 anyhow::bail!("URL must start with http:// or https://");
59 }
60
61 if !is_safe_url(url) {
63 anyhow::bail!(
64 "URL blocked: requests to internal/private networks are not allowed. \
65 This includes localhost, private IPs, and cloud metadata endpoints."
66 );
67 }
68
69 if let Ok(parsed) = url::Url::parse(url)
73 && let Some(host) = parsed.host_str()
74 {
75 if parsed
77 .host()
78 .is_some_and(|h| matches!(h, url::Host::Domain(_)))
79 {
80 match tokio::net::lookup_host(format!(
81 "{}:{}",
82 host,
83 parsed.port_or_known_default().unwrap_or(80)
84 ))
85 .await
86 {
87 Ok(addrs) => {
88 for addr in addrs {
89 if !is_safe_ip(addr.ip()) {
90 anyhow::bail!(
91 "URL blocked: domain '{host}' resolves to private/internal IP {}.",
92 addr.ip()
93 );
94 }
95 }
96 }
97 Err(e) => {
98 anyhow::bail!("DNS resolution failed for '{host}': {e}");
99 }
100 }
101 }
102 }
103
104 static HTTP_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
105 let client = HTTP_CLIENT
106 .get_or_init(|| crate::providers::build_http_client(None))
107 .clone();
108 let response = tokio::time::timeout(
109 std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS),
110 client
111 .get(url)
112 .header("User-Agent", "Koda/0.1 (AI coding agent)")
113 .send(),
114 )
115 .await
116 .map_err(|_| anyhow::anyhow!("Request timed out after {DEFAULT_TIMEOUT_SECS}s"))?
117 .map_err(|e| anyhow::anyhow!("HTTP request failed: {e}"))?;
118
119 let status = response.status();
120 if !status.is_success() {
121 anyhow::bail!("HTTP {status} for {url}");
122 }
123
124 let body = response
125 .text()
126 .await
127 .map_err(|e| anyhow::anyhow!("Failed to read response body: {e}"))?;
128
129 let content = if raw { body } else { strip_html(&body) };
130
131 if content.len() > max_body_chars {
132 Ok(format!(
133 "{}\n\n[TRUNCATED: response was {} chars. \
134 Consider fetching a more specific URL.]",
135 &content[..max_body_chars],
136 content.len()
137 ))
138 } else {
139 Ok(content)
140 }
141}
142
/// Returns `true` when `ip` is acceptable as the target of an outbound fetch,
/// and `false` for addresses that point at the local host or an internal network.
///
/// Rejected IPv4 ranges: `0.0.0.0/8`, loopback `127.0.0.0/8`, the private
/// ranges `10.0.0.0/8`, `172.16.0.0/12` and `192.168.0.0/16`, link-local
/// `169.254.0.0/16` (which contains the `169.254.169.254` metadata address),
/// shared address space `100.64.0.0/10`, and the broadcast address.
///
/// Rejected IPv6 ranges: loopback `::1`, unspecified `::`, unique-local
/// `fc00::/7`, and link-local `fe80::/10`. IPv4-mapped addresses
/// (`::ffff:a.b.c.d`) are judged by their embedded IPv4 address.
pub(crate) fn is_safe_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(v4) => {
            let [first, second, ..] = v4.octets();
            let blocked = first == 0
                || v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_broadcast()
                // 100.64.0.0/10: carrier-grade NAT / shared address space.
                || (first == 100 && (64..=127).contains(&second));
            !blocked
        }
        std::net::IpAddr::V6(v6) => {
            // Unwrap `::ffff:a.b.c.d` first so the IPv4 rules cannot be
            // sidestepped by spelling the address in IPv6 form.
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_safe_ip(std::net::IpAddr::V4(v4));
            }
            let prefix = v6.segments()[0];
            let blocked = v6.is_loopback()
                || v6.is_unspecified()
                // fc00::/7: unique-local (the IPv6 counterpart of private ranges).
                || (prefix & 0xfe00) == 0xfc00
                // fe80::/10: link-local.
                || (prefix & 0xffc0) == 0xfe80;
            !blocked
        }
    }
}
171
172pub(crate) fn is_safe_url(url_str: &str) -> bool {
175 let Ok(parsed) = url::Url::parse(url_str) else {
176 return false;
177 };
178 let Some(host) = parsed.host_str() else {
179 return false;
180 };
181
182 let blocked_hosts = [
184 "169.254.169.254",
185 "metadata.google.internal",
186 "metadata.internal",
187 "localhost",
188 "0.0.0.0",
189 ];
190 if blocked_hosts.contains(&host) {
191 return false;
192 }
193
194 if host.ends_with(".internal") || host.ends_with(".local") {
196 return false;
197 }
198
199 match parsed.host() {
201 Some(url::Host::Ipv4(ip)) => {
202 if !is_safe_ip(std::net::IpAddr::V4(ip)) {
203 return false;
204 }
205 }
206 Some(url::Host::Ipv6(ip)) => {
207 if !is_safe_ip(std::net::IpAddr::V6(ip)) {
208 return false;
209 }
210 }
211 Some(url::Host::Domain(_)) => {
212 }
215 None => return false,
216 }
217
218 true
219}
220
/// Reduces an HTML document to readable plain text.
///
/// - Everything between `<` and `>` is discarded.
/// - `<script>` and `<style>` elements are discarded together with their
///   contents (tag names are matched ignoring ASCII case). An unterminated
///   element swallows the rest of the input.
/// - A newline is emitted in front of any tag whose name starts with `br`,
///   `p`, `div`, `h`, `li` or `tr`.
/// - Runs of whitespace in text are collapsed to a single space.
/// - The entities `&lt;`, `&gt;`, `&quot;`, `&#39;`, `&apos;`, `&nbsp;` and
///   `&amp;` are decoded after the tags have been removed.
fn strip_html(html: &str) -> String {
    // Tag-name prefixes that introduce a line break in the output.
    const BLOCK_TAG_PREFIXES: [&str; 6] = ["<br", "<p", "<div", "<h", "<li", "<tr"];

    // ASCII-case-insensitive `starts_with` on raw bytes.
    fn has_prefix_ci(haystack: &[u8], prefix: &str) -> bool {
        haystack.len() >= prefix.len()
            && haystack[..prefix.len()].eq_ignore_ascii_case(prefix.as_bytes())
    }

    // Byte offset of the first ASCII-case-insensitive occurrence of `needle`.
    fn find_ci(haystack: &[u8], needle: &str) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|window| window.eq_ignore_ascii_case(needle.as_bytes()))
    }

    let mut text = String::with_capacity(html.len());
    let mut in_tag = false;
    let mut last_was_space = false;
    let mut rest = html;

    while let Some(ch) = rest.chars().next() {
        if ch == '<' {
            let bytes = rest.as_bytes();

            let closer = if has_prefix_ci(bytes, "<script") {
                Some("</script>")
            } else if has_prefix_ci(bytes, "<style") {
                Some("</style>")
            } else {
                None
            };

            if let Some(closer) = closer {
                // Jump past the whole element. The closing tag is pure ASCII,
                // so the offset just after it is always a char boundary.
                rest = match find_ci(bytes, closer) {
                    Some(at) => &rest[at + closer.len()..],
                    None => "",
                };
                // The closing tag has been consumed, so we are back in text;
                // otherwise text right after `</script>` would be dropped.
                in_tag = false;
                continue;
            }

            if BLOCK_TAG_PREFIXES
                .iter()
                .any(|prefix| has_prefix_ci(bytes, prefix))
            {
                text.push('\n');
                last_was_space = true;
            }
            in_tag = true;
        } else if ch == '>' {
            in_tag = false;
        } else if !in_tag {
            if ch.is_whitespace() {
                if !last_was_space {
                    text.push(' ');
                    last_was_space = true;
                }
            } else {
                text.push(ch);
                last_was_space = false;
            }
        }
        rest = &rest[ch.len_utf8()..];
    }

    // `&amp;` is decoded last so that an escaped entity such as `&amp;lt;`
    // becomes the literal text `&lt;` instead of being decoded twice.
    text.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&apos;", "'")
        .replace("&nbsp;", " ")
        .replace("&amp;", "&")
}
319
#[cfg(test)]
mod tests {
    use super::*;

    // ── strip_html ──────────────────────────────────────────────────────

    #[test]
    fn test_strip_html_basic() {
        let html = "<h1>Hello</h1><p>World &amp; friends</p>";
        let result = strip_html(html);
        assert!(result.contains("Hello"));
        assert!(result.contains("World & friends"));
        assert!(!result.contains("<h1>"));
    }

    #[test]
    fn test_strip_html_script_removal() {
        let html = "<p>Before</p><script>alert('xss')</script><p>After</p>";
        let result = strip_html(html);
        assert!(result.contains("Before"));
        assert!(result.contains("After"));
        assert!(!result.contains("alert"));
    }

    #[test]
    fn test_strip_html_whitespace_collapse() {
        let html = "<p>  lots   of   spaces  </p>";
        let result = strip_html(html);
        // Runs of whitespace must be collapsed, so no double space may remain.
        assert!(!result.contains("  "));
    }

    // ── web_fetch argument and SSRF validation (no network access needed) ──

    #[tokio::test]
    async fn test_web_fetch_bad_url() {
        let args = json!({ "url": "not-a-url" });
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_web_fetch_blocks_ssrf() {
        let args = json!({ "url": "http://169.254.169.254/latest/meta-data/" });
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test]
    async fn test_web_fetch_missing_url() {
        let args = json!({});
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
    }

    // ── is_safe_url ─────────────────────────────────────────────────────

    #[test]
    fn test_is_safe_url_blocks_metadata() {
        assert!(!is_safe_url("http://169.254.169.254/latest/meta-data/"));
        assert!(!is_safe_url("http://metadata.google.internal/"));
    }

    #[test]
    fn test_is_safe_url_blocks_localhost() {
        assert!(!is_safe_url("http://localhost:8080/admin"));
        assert!(!is_safe_url("http://127.0.0.1/secret"));
        assert!(!is_safe_url("http://0.0.0.0/"));
    }

    #[test]
    fn test_is_safe_url_blocks_private_ips() {
        assert!(!is_safe_url("http://10.0.0.1/internal"));
        assert!(!is_safe_url("http://172.16.0.1/admin"));
        assert!(!is_safe_url("http://192.168.1.1/config"));
    }

    #[test]
    fn test_is_safe_url_blocks_userinfo_bypass() {
        // The part before `@` is userinfo; the real host is what follows it.
        assert!(!is_safe_url(
            "http://evil.com@169.254.169.254/latest/meta-data/"
        ));
        assert!(!is_safe_url("http://user:pass@127.0.0.1/"));
    }

    #[test]
    fn test_is_safe_url_blocks_ipv6_mapped() {
        assert!(!is_safe_url("http://[::ffff:127.0.0.1]/"));
        assert!(!is_safe_url("http://[::1]/"));
    }

    #[test]
    fn test_is_safe_url_allows_public() {
        assert!(is_safe_url("https://docs.rs/tokio/latest/tokio/"));
        assert!(is_safe_url("https://api.github.com/repos"));
        assert!(is_safe_url("https://example.com"));
    }

    // ── is_safe_ip ──────────────────────────────────────────────────────

    #[test]
    fn test_is_safe_ip_blocks_private() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
        assert!(!is_safe_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(!is_safe_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED)));
    }

    #[test]
    fn test_is_safe_ip_allows_public() {
        use std::net::{IpAddr, Ipv4Addr};
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))));
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34))));
    }
}