// koda_core/tools/web_fetch.rs
//! WebFetch tool — retrieve content from a URL.
//!
//! Fetches a web page and converts HTML to readable text.
//! Body cap is set by `OutputCaps` (context-scaled).
//!
//! ## Parameters
//!
//! - **`url`** (required) — The URL to fetch
//! - **`raw`** (optional) — If true, return raw HTML instead of stripped text
//!
//! ## Behavior
//!
//! - HTML pages are converted to clean text (strips tags, scripts, styles)
//! - JSON and plain text are returned as-is
//! - Output is truncated to context-scaled caps
//! - Follows redirects (up to 10 hops)
//! - Timeout: 15 seconds (see `DEFAULT_TIMEOUT_SECS`)

use crate::providers::ToolDefinition;
use anyhow::Result;
use serde_json::{Value, json};

/// Whole-request timeout in seconds (covers connect, headers, and body).
const DEFAULT_TIMEOUT_SECS: u64 = 15;
23
24/// Return tool definitions for the LLM.
25pub fn definitions() -> Vec<ToolDefinition> {
26    vec![ToolDefinition {
27        name: "WebFetch".to_string(),
28        description: "Fetch content from a URL. HTML is stripped to readable text by default; \
29            set raw=true for raw HTML. Only use URLs from tool results or user input — \
30            never guess or generate URLs from memory. \
31            For documentation lookup, prefer reading local files first."
32            .to_string(),
33        parameters: json!({
34            "type": "object",
35            "properties": {
36                "url": {
37                    "type": "string",
38                    "description": "The URL to fetch (must start with http:// or https://)"
39                },
40                "raw": {
41                    "type": "boolean",
42                    "description": "If true, return raw HTML instead of stripped text (default: false)"
43                }
44            },
45            "required": ["url"]
46        }),
47    }]
48}
49
50/// Fetch a URL and return its content.
51pub async fn web_fetch(args: &Value, max_body_chars: usize) -> Result<String> {
52    let url = args["url"]
53        .as_str()
54        .ok_or_else(|| anyhow::anyhow!("Missing 'url' argument"))?;
55    let raw = args["raw"].as_bool().unwrap_or(false);
56
57    if !url.starts_with("http://") && !url.starts_with("https://") {
58        anyhow::bail!("URL must start with http:// or https://");
59    }
60
61    // SSRF protection: block requests to internal/private networks
62    if !is_safe_url(url) {
63        anyhow::bail!(
64            "URL blocked: requests to internal/private networks are not allowed. \
65             This includes localhost, private IPs, and cloud metadata endpoints."
66        );
67    }
68
69    // DNS rebinding protection: resolve the hostname and verify the IP
70    // is not private/internal before making the request. This prevents
71    // TOCTOU attacks where DNS re-resolves to a different IP.
72    if let Ok(parsed) = url::Url::parse(url)
73        && let Some(host) = parsed.host_str()
74    {
75        // Only resolve domain names (IPs are already checked above)
76        if parsed
77            .host()
78            .is_some_and(|h| matches!(h, url::Host::Domain(_)))
79        {
80            match tokio::net::lookup_host(format!(
81                "{}:{}",
82                host,
83                parsed.port_or_known_default().unwrap_or(80)
84            ))
85            .await
86            {
87                Ok(addrs) => {
88                    for addr in addrs {
89                        if !is_safe_ip(addr.ip()) {
90                            anyhow::bail!(
91                                "URL blocked: domain '{host}' resolves to private/internal IP {}.",
92                                addr.ip()
93                            );
94                        }
95                    }
96                }
97                Err(e) => {
98                    anyhow::bail!("DNS resolution failed for '{host}': {e}");
99                }
100            }
101        }
102    }
103
104    static HTTP_CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
105    let client = HTTP_CLIENT
106        .get_or_init(|| crate::providers::build_http_client(None))
107        .clone();
108    let response = tokio::time::timeout(
109        std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS),
110        client
111            .get(url)
112            .header("User-Agent", "Koda/0.1 (AI coding agent)")
113            .send(),
114    )
115    .await
116    .map_err(|_| anyhow::anyhow!("Request timed out after {DEFAULT_TIMEOUT_SECS}s"))?
117    .map_err(|e| anyhow::anyhow!("HTTP request failed: {e}"))?;
118
119    let status = response.status();
120    if !status.is_success() {
121        anyhow::bail!("HTTP {status} for {url}");
122    }
123
124    let body = response
125        .text()
126        .await
127        .map_err(|e| anyhow::anyhow!("Failed to read response body: {e}"))?;
128
129    let content = if raw { body } else { strip_html(&body) };
130
131    if content.len() > max_body_chars {
132        Ok(format!(
133            "{}\n\n[TRUNCATED: response was {} chars. \
134             Consider fetching a more specific URL.]",
135            &content[..max_body_chars],
136            content.len()
137        ))
138    } else {
139        Ok(content)
140    }
141}

/// Check if an IP address is safe (not private/internal/loopback).
///
/// Blocks, for IPv4: unspecified, broadcast, 0.0.0.0/8, loopback 127/8,
/// RFC 1918 ranges (10/8, 172.16/12, 192.168/16), link-local 169.254/16
/// (includes cloud metadata endpoints), and CGNAT 100.64/10.
/// For IPv6: loopback, unspecified, unique-local fc00::/7, link-local
/// fe80::/10, and IPv4-mapped addresses (checked by the IPv4 rules).
pub(crate) fn is_safe_ip(ip: std::net::IpAddr) -> bool {
    match ip {
        std::net::IpAddr::V4(ipv4) => {
            let o = ipv4.octets();
            let blocked = ipv4.is_unspecified()
                || ipv4.is_broadcast() // 255.255.255.255
                || o[0] == 0 // 0.0.0.0/8 "this network"
                || o[0] == 127 // loopback
                || o[0] == 10 // 10.0.0.0/8 private
                || (o[0] == 172 && (16..=31).contains(&o[1])) // 172.16.0.0/12 private
                || (o[0] == 192 && o[1] == 168) // 192.168.0.0/16 private
                || (o[0] == 169 && o[1] == 254) // link-local / cloud metadata
                || (o[0] == 100 && (64..=127).contains(&o[1])); // 100.64.0.0/10 CGNAT
            !blocked
        }
        std::net::IpAddr::V6(ipv6) => {
            if ipv6.is_loopback() || ipv6.is_unspecified() {
                return false;
            }
            let seg0 = ipv6.segments()[0];
            // fc00::/7 unique-local, fe80::/10 link-local
            if seg0 & 0xfe00 == 0xfc00 || seg0 & 0xffc0 == 0xfe80 {
                return false;
            }
            // IPv4-mapped (::ffff:a.b.c.d) — apply the IPv4 rules
            if let Some(ipv4) = ipv6.to_ipv4_mapped() {
                return is_safe_ip(std::net::IpAddr::V4(ipv4));
            }
            true
        }
    }
}
171
172/// Check if a URL is safe to fetch (not internal/private network).
173/// Uses the `url` crate for robust parsing (handles userinfo@, IPv6, etc.).
174pub(crate) fn is_safe_url(url_str: &str) -> bool {
175    let Ok(parsed) = url::Url::parse(url_str) else {
176        return false;
177    };
178    let Some(host) = parsed.host_str() else {
179        return false;
180    };
181
182    // Block known metadata hostnames
183    let blocked_hosts = [
184        "169.254.169.254",
185        "metadata.google.internal",
186        "metadata.internal",
187        "localhost",
188        "0.0.0.0",
189    ];
190    if blocked_hosts.contains(&host) {
191        return false;
192    }
193
194    // Block .internal and .local TLDs
195    if host.ends_with(".internal") || host.ends_with(".local") {
196        return false;
197    }
198
199    // Block private/reserved IPs using the parsed host
200    match parsed.host() {
201        Some(url::Host::Ipv4(ip)) => {
202            if !is_safe_ip(std::net::IpAddr::V4(ip)) {
203                return false;
204            }
205        }
206        Some(url::Host::Ipv6(ip)) => {
207            if !is_safe_ip(std::net::IpAddr::V6(ip)) {
208                return false;
209            }
210        }
211        Some(url::Host::Domain(_)) => {
212            // Domain names — hostname checks above are sufficient
213            // (DNS resolution check happens separately in web_fetch)
214        }
215        None => return false,
216    }
217
218    true
219}
220
221/// Strip HTML tags and collapse whitespace for readability.
222fn strip_html(html: &str) -> String {
223    let mut result = String::with_capacity(html.len());
224    let mut in_tag = false;
225    let mut in_script = false;
226    let mut in_style = false;
227    let mut last_was_space = false;
228
229    let lower = html.to_lowercase();
230    let chars: Vec<char> = html.chars().collect();
231    let lower_chars: Vec<char> = lower.chars().collect();
232
233    let mut i = 0;
234    while i < chars.len() {
235        if in_script {
236            // Skip until </script>
237            if i + 9 <= lower_chars.len()
238                && lower_chars[i..i + 9].iter().collect::<String>() == "</script>"
239            {
240                in_script = false;
241                i += 9;
242            } else {
243                i += 1;
244            }
245            continue;
246        }
247        if in_style {
248            if i + 8 <= lower_chars.len()
249                && lower_chars[i..i + 8].iter().collect::<String>() == "</style>"
250            {
251                in_style = false;
252                i += 8;
253            } else {
254                i += 1;
255            }
256            continue;
257        }
258
259        if chars[i] == '<' {
260            // Check for <script or <style
261            if i + 7 <= lower_chars.len()
262                && lower_chars[i..i + 7].iter().collect::<String>() == "<script"
263            {
264                in_script = true;
265            } else if i + 6 <= lower_chars.len()
266                && lower_chars[i..i + 6].iter().collect::<String>() == "<style"
267            {
268                in_style = true;
269            }
270            in_tag = true;
271            // Block-level tags → newline
272            let tag_start: String = lower_chars[i..std::cmp::min(i + 10, lower_chars.len())]
273                .iter()
274                .collect();
275            if tag_start.starts_with("<br")
276                || tag_start.starts_with("<p")
277                || tag_start.starts_with("<div")
278                || tag_start.starts_with("<h")
279                || tag_start.starts_with("<li")
280                || tag_start.starts_with("<tr")
281            {
282                result.push('\n');
283                last_was_space = true;
284            }
285            i += 1;
286            continue;
287        }
288
289        if chars[i] == '>' {
290            in_tag = false;
291            i += 1;
292            continue;
293        }
294
295        if !in_tag {
296            let ch = chars[i];
297            if ch.is_whitespace() {
298                if !last_was_space {
299                    result.push(' ');
300                    last_was_space = true;
301                }
302            } else {
303                result.push(ch);
304                last_was_space = false;
305            }
306        }
307        i += 1;
308    }
309
310    // Decode common HTML entities
311    result
312        .replace("&amp;", "&")
313        .replace("&lt;", "<")
314        .replace("&gt;", ">")
315        .replace("&quot;", "\"")
316        .replace("&#39;", "'")
317        .replace("&nbsp;", " ")
318}

#[cfg(test)]
mod tests {
    use super::*;

    // ── strip_html ──

    #[test]
    fn test_strip_html_basic() {
        // Tags removed, entities decoded
        let html = "<h1>Hello</h1><p>World &amp; friends</p>";
        let result = strip_html(html);
        assert!(result.contains("Hello"));
        assert!(result.contains("World & friends"));
        assert!(!result.contains("<h1>"));
    }

    #[test]
    fn test_strip_html_script_removal() {
        // Script bodies are dropped entirely, surrounding text kept
        let html = "<p>Before</p><script>alert('xss')</script><p>After</p>";
        let result = strip_html(html);
        assert!(result.contains("Before"));
        assert!(result.contains("After"));
        assert!(!result.contains("alert"));
    }

    #[test]
    fn test_strip_html_whitespace_collapse() {
        let html = "<p>  lots   of    spaces  </p>";
        let result = strip_html(html);
        assert!(!result.contains("   ")); // No triple spaces
    }

    // ── web_fetch argument validation ──

    #[tokio::test]
    async fn test_web_fetch_bad_url() {
        // Non-http(s) scheme must be rejected before any network activity
        let args = json!({ "url": "not-a-url" });
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
    }

    // ── is_safe_url (SSRF hostname/IP blocklist) ──

    #[test]
    fn test_is_safe_url_blocks_metadata() {
        assert!(!is_safe_url("http://169.254.169.254/latest/meta-data/"));
        assert!(!is_safe_url("http://metadata.google.internal/"));
    }

    #[test]
    fn test_is_safe_url_blocks_localhost() {
        assert!(!is_safe_url("http://localhost:8080/admin"));
        assert!(!is_safe_url("http://127.0.0.1/secret"));
        assert!(!is_safe_url("http://0.0.0.0/"));
    }

    #[test]
    fn test_is_safe_url_blocks_private_ips() {
        assert!(!is_safe_url("http://10.0.0.1/internal"));
        assert!(!is_safe_url("http://172.16.0.1/admin"));
        assert!(!is_safe_url("http://192.168.1.1/config"));
    }

    #[test]
    fn test_is_safe_url_blocks_userinfo_bypass() {
        // RFC 3986 userinfo@ component should not fool the parser
        assert!(!is_safe_url(
            "http://evil.com@169.254.169.254/latest/meta-data/"
        ));
        assert!(!is_safe_url("http://user:pass@127.0.0.1/"));
    }

    #[test]
    fn test_is_safe_url_blocks_ipv6_mapped() {
        assert!(!is_safe_url("http://[::ffff:127.0.0.1]/"));
        assert!(!is_safe_url("http://[::1]/"));
    }

    #[test]
    fn test_is_safe_url_allows_public() {
        assert!(is_safe_url("https://docs.rs/tokio/latest/tokio/"));
        assert!(is_safe_url("https://api.github.com/repos"));
        assert!(is_safe_url("https://example.com"));
    }

    // ── is_safe_ip tests (#526) ──

    #[test]
    fn test_is_safe_ip_blocks_private() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))));
        assert!(!is_safe_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)));
        assert!(!is_safe_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)));
        assert!(!is_safe_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED)));
    }

    #[test]
    fn test_is_safe_ip_allows_public() {
        use std::net::{IpAddr, Ipv4Addr};
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))));
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34))));
    }

    // ── web_fetch SSRF / argument errors (no network reached) ──

    #[tokio::test]
    async fn test_web_fetch_blocks_ssrf() {
        let args = json!({ "url": "http://169.254.169.254/latest/meta-data/" });
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test]
    async fn test_web_fetch_missing_url() {
        let args = json!({});
        let result = web_fetch(&args, 15_000).await;
        assert!(result.is_err());
    }
435}