Skip to main content

atomcode_core/tool/
web_fetch.rs

1use std::net::IpAddr;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6use reqwest::redirect::Policy;
7use serde::Deserialize;
8use serde_json::json;
9use url::Url;
10
11use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
12
/// Tool that fetches one http(s) URL and returns its content as text, with
/// SSRF checks applied to the starting URL and to every redirect hop.
pub struct WebFetchTool;
14
15#[derive(Deserialize)]
16struct WebFetchArgs {
17    url: String,
18    #[serde(default = "default_max_chars")]
19    max_chars: usize,
20}
21
/// Serde default for `WebFetchArgs::max_chars`.
fn default_max_chars() -> usize {
    20_000
}
25
/// Upper bound on raw response bytes buffered before the download is cut
/// short. Stops a hostile server from exhausting memory by streaming
/// endlessly while staying inside the per-request timeout.
const MAX_RESPONSE_BYTES: usize = 2 << 20; // 2 MiB
30
/// Maximum number of redirects followed per fetch. This is tighter than
/// reqwest's default of 10 because each hop repeats the scheme, DNS and IP
/// checks. Legitimate sites rarely chain more than two or three hops
/// (http→https, apex→www, vanity→canonical).
const MAX_REDIRECTS: u8 = 5;
35
/// Whole-request deadline in seconds, passed to reqwest's `timeout`.
const REQUEST_TIMEOUT_SECS: u64 = 20;
/// Connect-phase deadline in seconds, passed to reqwest's `connect_timeout`.
const CONNECT_TIMEOUT_SECS: u64 = 5;
38
39fn validate_scheme(url: &Url) -> Result<(), String> {
40    match url.scheme() {
41        "http" | "https" => Ok(()),
42        other => Err(format!(
43            "scheme `{}` not allowed — only http(s) URLs can be fetched",
44            other
45        )),
46    }
47}
48
/// Reject IPs that point inside the host, the local network, or cloud metadata.
///
/// Catches the classic SSRF targets:
/// - loopback and RFC 1918 private ranges;
/// - link-local, including 169.254.169.254 (the AWS/GCP/Azure metadata
///   endpoint);
/// - CGNAT and reserved ranges;
/// - IPv6 unique-local, link-local and deprecated site-local ranges.
///
/// It also catches any IPv6 form that embeds an IPv4 address:
/// - IPv4-mapped `::ffff:a.b.c.d`;
/// - IPv4-compatible `::a.b.c.d`;
/// - NAT64 well-known prefix `64:ff9b::/96`;
/// - 6to4 `2002::/16`.
///
/// The embedded IPv4 address is re-checked against the IPv4 rules, so e.g.
/// `64:ff9b::7f00:1` is refused as loopback.
///
/// # Errors
/// Returns a message naming the address and the blocked category.
fn is_safe_ip(ip: IpAddr) -> Result<(), String> {
    let reject = |category: &str| {
        Err(format!(
            "refusing to connect to {ip} ({category}) — SSRF protection"
        ))
    };
    match ip {
        IpAddr::V4(v4) => {
            if v4.is_loopback() {
                return reject("loopback 127.0.0.0/8");
            }
            if v4.is_private() {
                return reject("private network");
            }
            if v4.is_link_local() {
                return reject("link-local / cloud metadata");
            }
            if v4.is_broadcast() {
                return reject("broadcast");
            }
            if v4.is_multicast() {
                return reject("multicast");
            }
            if v4.is_unspecified() {
                return reject("unspecified 0.0.0.0");
            }
            let o = v4.octets();
            if o[0] == 0 {
                return reject("reserved 0.0.0.0/8");
            }
            if o[0] >= 240 {
                return reject("reserved 240.0.0.0/4");
            }
            // CGNAT 100.64.0.0/10 — commonly used as carrier private space
            if o[0] == 100 && (o[1] & 0xc0) == 64 {
                return reject("CGNAT 100.64/10");
            }
            Ok(())
        }
        IpAddr::V6(v6) => {
            if v6.is_loopback() {
                return reject("loopback ::1");
            }
            if v6.is_unspecified() {
                return reject("unspecified ::");
            }
            if v6.is_multicast() {
                return reject("multicast");
            }
            let seg = v6.segments();
            let first = seg[0];
            // Unique local addresses fc00::/7
            if (first & 0xfe00) == 0xfc00 {
                return reject("unique-local fc00::/7");
            }
            // Link-local fe80::/10 — includes IPv6 metadata endpoints
            if (first & 0xffc0) == 0xfe80 {
                return reject("link-local fe80::/10");
            }
            // Deprecated site-local fec0::/10. Never globally routable, so
            // anything reachable there is internal.
            if (first & 0xffc0) == 0xfec0 {
                return reject("site-local fec0::/10");
            }
            // IPv4-mapped ::ffff:a.b.c.d AND IPv4-compatible ::a.b.c.d.
            // Unwrap and re-check against the v4 rules.
            if let Some(v4) = v6.to_ipv4() {
                return is_safe_ip(IpAddr::V4(v4));
            }
            // Rebuild an IPv4 address from two 16-bit segments (high, low).
            let embedded = |hi: u16, lo: u16| {
                IpAddr::V4(std::net::Ipv4Addr::from(
                    (u32::from(hi) << 16) | u32::from(lo),
                ))
            };
            // NAT64 well-known prefix 64:ff9b::/96. A NAT64 gateway translates
            // these to the IPv4 address held in the low 32 bits.
            if seg[..6] == [0x0064, 0xff9b, 0, 0, 0, 0] {
                return is_safe_ip(embedded(seg[6], seg[7]));
            }
            // 6to4 2002::/16 carries an IPv4 address in bits 16..48.
            if first == 0x2002 {
                return is_safe_ip(embedded(seg[1], seg[2]));
            }
            Ok(())
        }
    }
}
119
120/// Resolve the URL's host and check every returned IP. Every address must be
121/// safe: partial acceptance would let a host resolve to [1.2.3.4, 127.0.0.1]
122/// and gamble on which reqwest picks.
123///
124/// Caveat: DNS is looked up here and again by the kernel when reqwest connects
125/// — a TTL=0 attacker could in theory rebind between the two. Mitigation would
126/// require pinning the verified IP into the reqwest client's resolver, which
127/// we can add later if the threat model warrants it. Today's protection still
128/// eliminates the 99% of SSRF attempts that rely on literal-IP or static-DNS
129/// targets (file://, localhost, 169.254.169.254, fixed internal hostnames).
130async fn validate_host(url: &Url) -> Result<(), String> {
131    let host = url
132        .host_str()
133        .ok_or_else(|| format!("URL has no host: {}", url))?;
134    // Literal IP in URL: check directly, bypass DNS.
135    if let Ok(ip) = host.parse::<IpAddr>() {
136        return is_safe_ip(ip);
137    }
138    let port = url.port_or_known_default().unwrap_or(80);
139    let addrs = tokio::net::lookup_host((host, port))
140        .await
141        .map_err(|e| format!("DNS resolution failed for `{}`: {}", host, e))?;
142    let mut saw_any = false;
143    for addr in addrs {
144        saw_any = true;
145        is_safe_ip(addr.ip())?;
146    }
147    if !saw_any {
148        return Err(format!("DNS returned no addresses for `{}`", host));
149    }
150    Ok(())
151}
152
153fn err_result(msg: impl Into<String>) -> ToolResult {
154    ToolResult {
155        call_id: String::new(),
156        output: msg.into(),
157        success: false,
158    }
159}
160
/// Test-only helper: does `host` equal, or sit under, one of the allowlisted
/// documentation / code-hosting domains?
///
/// A single trailing dot is tolerated and comparison is ASCII
/// case-insensitive. Only whole labels match, so `evilgithub.com` and
/// `github.com.evil.com` are rejected. Compiled only under `cfg(test)`;
/// `approval()` does not consult it.
#[cfg(test)]
fn host_is_auto_approved(host: &str) -> bool {
    const ALLOWLIST: &[&str] = &[
        "github.com",
        "docs.rs",
        "raw.githubusercontent.com",
        "atomgit.com",
        "gitcode.com",
        "csdn.net",
        "openatom.cn",
    ];
    let normalized = host.trim_end_matches('.').to_ascii_lowercase();
    ALLOWLIST.iter().any(|&allowed| {
        normalized == allowed
            || normalized
                .strip_suffix(allowed)
                .map_or(false, |prefix| prefix.ends_with('.'))
    })
}
177
178#[async_trait]
179impl Tool for WebFetchTool {
180    fn definition(&self) -> ToolDef {
181        ToolDef {
182            name: "web_fetch",
183            description: "Fetch a web page and return its content as clean text.\n\
184                Use after web_search to read a specific page (documentation, README, API reference).\n\
185                HTML is automatically converted to readable text.\n\
186                Only http:// and https:// URLs are allowed; requests to localhost, \
187                private networks, and cloud metadata endpoints are blocked.\n\
188                Examples:\n\
189                - {\"url\": \"https://github.com/user/repo\"}\n\
190                - {\"url\": \"https://docs.rs/reqwest/latest/reqwest/\"}".to_string(),
191            parameters: json!({
192                "type": "object",
193                "properties": {
194                    "url": { "type": "string", "description": "Absolute http(s) URL to fetch" },
195                    "max_chars": { "type": "integer", "description": "Max characters to return (default 20000)" }
196                },
197                "required": ["url"]
198            }),
199        }
200    }
201
202    fn approval(&self, args: &str) -> ApprovalRequirement {
203        // web_fetch is always auto-approved. URL validation and scheme checks
204        // are performed during execution - invalid URLs will return an error
205        // result rather than blocking for user approval.
206        let _ = args; // suppress unused variable warning
207        ApprovalRequirement::AutoApprove
208    }
209
210    async fn execute(&self, args: &str, _ctx: &ToolContext) -> Result<ToolResult> {
211        let parsed: WebFetchArgs = match serde_json::from_str(args) {
212            Ok(p) => p,
213            Err(e) => {
214                return Ok(err_result(format!(
215                    "Invalid web_fetch arguments: {}. Provide {{\"url\":\"https://...\"}}.",
216                    e
217                )))
218            }
219        };
220        let max = parsed.max_chars.min(50000);
221
222        let client = match reqwest::Client::builder()
223            // Handle redirects manually so every hop re-runs scheme + IP checks.
224            // reqwest's built-in follower would let a 302 rebind to 127.0.0.1
225            // after we've already validated the start URL's host.
226            .redirect(Policy::none())
227            .connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
228            .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
229            .user_agent("Mozilla/5.0 (compatible; atomcode/web_fetch)")
230            .build()
231        {
232            Ok(c) => c,
233            Err(e) => return Ok(err_result(format!("Failed to build HTTP client: {}", e))),
234        };
235
236        let mut url = match Url::parse(&parsed.url) {
237            Ok(u) => u,
238            Err(e) => return Ok(err_result(format!("Invalid URL: {}", e))),
239        };
240
241        let mut hops = 0u8;
242        let response = loop {
243            if let Err(e) = validate_scheme(&url) {
244                return Ok(err_result(format!("Blocked: {}", e)));
245            }
246            if let Err(e) = validate_host(&url).await {
247                return Ok(err_result(format!("Blocked: {}", e)));
248            }
249
250            let resp = match client.get(url.clone()).send().await {
251                Ok(r) => r,
252                Err(e) => return Ok(err_result(format!("Failed to fetch {}: {}", url, e))),
253            };
254
255            if !resp.status().is_redirection() {
256                break resp;
257            }
258            if hops >= MAX_REDIRECTS {
259                return Ok(err_result(format!(
260                    "Too many redirects (>{}) starting from {}",
261                    MAX_REDIRECTS, parsed.url
262                )));
263            }
264            let Some(loc) = resp.headers().get(reqwest::header::LOCATION) else {
265                // Redirect status without Location — treat as terminal response
266                // so the caller sees the original status body.
267                break resp;
268            };
269            let loc_str = match loc.to_str() {
270                Ok(s) => s,
271                Err(_) => {
272                    return Ok(err_result(format!(
273                        "Redirect from {} has non-ASCII Location header",
274                        url
275                    )))
276                }
277            };
278            // Location may be relative — resolve against current URL.
279            url = match url.join(loc_str) {
280                Ok(u) => u,
281                Err(e) => {
282                    return Ok(err_result(format!(
283                        "Bad redirect target `{}` from {}: {}",
284                        loc_str, url, e
285                    )))
286                }
287            };
288            hops += 1;
289        };
290
291        let final_url = url.to_string();
292        let status = response.status();
293        if !status.is_success() {
294            return Ok(err_result(format!(
295                "HTTP {} from {}",
296                status.as_u16(),
297                final_url
298            )));
299        }
300
301        let ct_header = response
302            .headers()
303            .get(reqwest::header::CONTENT_TYPE)
304            .and_then(|v| v.to_str().ok())
305            .map(|s| s.to_ascii_lowercase());
306        let ct_is_html = ct_header
307            .as_deref()
308            .map(|s| s.contains("text/html") || s.contains("application/xhtml"))
309            .unwrap_or(false);
310
311        // Stream with a byte cap. Prevents OOM on an endless slow-serve attack
312        // that would otherwise creep under the per-request timeout.
313        let mut stream = response.bytes_stream();
314        let mut buf: Vec<u8> = Vec::with_capacity(16 * 1024);
315        let mut hit_cap = false;
316        while let Some(chunk) = stream.next().await {
317            let chunk = match chunk {
318                Ok(c) => c,
319                Err(e) => {
320                    return Ok(err_result(format!(
321                        "Failed mid-stream for {}: {}",
322                        final_url, e
323                    )))
324                }
325            };
326            if buf.len() + chunk.len() > MAX_RESPONSE_BYTES {
327                let remaining = MAX_RESPONSE_BYTES - buf.len();
328                buf.extend_from_slice(&chunk[..remaining]);
329                hit_cap = true;
330                break;
331            }
332            buf.extend_from_slice(&chunk);
333        }
334
335        if buf.is_empty() {
336            return Ok(err_result(format!("Empty response from {}", final_url)));
337        }
338        let body = String::from_utf8_lossy(&buf).to_string();
339
340        // Fall back to shape-sniffing only when the server sent no Content-Type.
341        // Prevents misclassifying JSON payloads that happen to start with '<'
342        // (rare, but the old code hit this).
343        let is_html = ct_is_html || (ct_header.is_none() && body.trim_start().starts_with('<'));
344        let text = if is_html { html_to_text(&body) } else { body };
345
346        let output = if text.len() > max {
347            let mut end = max;
348            while end > 0 && !text.is_char_boundary(end) {
349                end -= 1;
350            }
351            format!(
352                "{}\n\n[Truncated at {} chars, {} total]",
353                &text[..end],
354                max,
355                text.len()
356            )
357        } else {
358            text
359        };
360
361        if output.trim().is_empty() {
362            return Ok(err_result(format!(
363                "Page fetched but no readable text content found at {}",
364                final_url
365            )));
366        }
367
368        let cap_note = if hit_cap {
369            format!(
370                "\n\n[Response exceeded {} bytes — content was truncated before text extraction]",
371                MAX_RESPONSE_BYTES
372            )
373        } else {
374            String::new()
375        };
376
377        Ok(ToolResult {
378            call_id: String::new(),
379            output: format!("Content from {}:\n\n{}{}", final_url, output, cap_note),
380            success: true,
381        })
382    }
383}
384
385/// Convert HTML to readable plain text.
386/// Handles: block elements as newlines, links, lists, headings, script/style removal.
387fn html_to_text(html: &str) -> String {
388    // Phase 1: Remove script, style, and head content entirely
389    let cleaned = remove_tag_content(html, "script");
390    let cleaned = remove_tag_content(&cleaned, "style");
391    let cleaned = remove_tag_content(&cleaned, "head");
392    let cleaned = remove_tag_content(&cleaned, "nav");
393    let cleaned = remove_tag_content(&cleaned, "footer");
394
395    // Phase 2: Convert block elements to newlines
396    let mut result = cleaned.clone();
397    for tag in &[
398        "p",
399        "div",
400        "br",
401        "li",
402        "tr",
403        "h1",
404        "h2",
405        "h3",
406        "h4",
407        "h5",
408        "h6",
409        "article",
410        "section",
411        "blockquote",
412        "pre",
413        "dd",
414        "dt",
415    ] {
416        // Opening tags → newline
417        result = replace_tag_with(&result, tag, "\n");
418    }
419
420    // Phase 3: Strip remaining HTML tags
421    let mut text = String::with_capacity(result.len());
422    let mut in_tag = false;
423    for c in result.chars() {
424        match c {
425            '<' => in_tag = true,
426            '>' => in_tag = false,
427            _ if !in_tag => text.push(c),
428            _ => {}
429        }
430    }
431
432    // Phase 4: Decode HTML entities
433    let text = text
434        .replace("&amp;", "&")
435        .replace("&lt;", "<")
436        .replace("&gt;", ">")
437        .replace("&quot;", "\"")
438        .replace("&#x27;", "'")
439        .replace("&#39;", "'")
440        .replace("&nbsp;", " ")
441        .replace("&#x2F;", "/")
442        .replace("&apos;", "'")
443        .replace("&#160;", " ");
444
445    // Phase 5: Clean up whitespace — collapse blank lines, trim
446    let mut lines: Vec<&str> = Vec::new();
447    let mut prev_blank = false;
448    for line in text.lines() {
449        let trimmed = line.trim();
450        if trimmed.is_empty() {
451            if !prev_blank && !lines.is_empty() {
452                lines.push("");
453                prev_blank = true;
454            }
455        } else {
456            lines.push(trimmed);
457            prev_blank = false;
458        }
459    }
460
461    // Remove leading/trailing blank lines
462    while lines.first() == Some(&"") {
463        lines.remove(0);
464    }
465    while lines.last() == Some(&"") {
466        lines.pop();
467    }
468
469    lines.join("\n")
470}
471
/// Remove every `<tag …>…</tag>` element, contents included (e.g. `<script>`).
///
/// Matching is ASCII case-insensitive. A match only counts when the tag name
/// is followed by `>`, `/`, whitespace or end of input, so removing `head`
/// leaves `<header>` intact. If an opening tag has no matching `</tag>`,
/// everything from that tag to the end of the input is dropped.
fn remove_tag_content(html: &str, tag: &str) -> String {
    let tag = tag.to_ascii_lowercase();
    let open = format!("<{}", tag);
    let close = format!("</{}>", tag);
    let mut result = String::with_capacity(html.len());
    let mut pos = 0;
    // ASCII-only lowercasing keeps every byte offset identical to `html`, so
    // indices found in `lower` can slice `html` directly. Unicode
    // `to_lowercase` can change byte lengths (U+212A KELVIN SIGN → `k`,
    // `İ` → `i̇`). That misaligned the offsets and could slice the wrong text
    // or panic on a non-char-boundary index.
    let lower = html.to_ascii_lowercase();

    loop {
        let Some(rel) = lower[pos..].find(&open) else {
            result.push_str(&html[pos..]);
            break;
        };
        let abs_start = pos + rel;
        // Boundary check: `<head` must not match `<header`. The byte right
        // after `<{tag}` has to be a real tag-name terminator. Without this,
        // a `<header>` later in the document would hijack the `<head>` pass
        // and, finding no `</head>`, drop the rest of the page.
        let after = abs_start + open.len();
        let next = lower.as_bytes().get(after).copied();
        let is_tag_boundary = matches!(
            next,
            None | Some(b'>') | Some(b'/') | Some(b' ') | Some(b'\t') | Some(b'\n') | Some(b'\r')
        );
        if !is_tag_boundary {
            // Prefix collision (e.g. `<header` while searching `<head`).
            // Emit `<` literally, advance one byte, keep scanning.
            result.push_str(&html[pos..=abs_start]);
            pos = abs_start + 1;
            continue;
        }
        result.push_str(&html[pos..abs_start]);
        if let Some(end) = lower[abs_start..].find(&close) {
            pos = abs_start + end + close.len();
        } else {
            // Truly unclosed tag — drop from here to EOF (browser-tolerant
            // behavior for `<script>` etc.).
            break;
        }
    }
    result
}
516
/// Replace every opening `<tag …>` with `replacement`, leaving closing tags
/// and content untouched.
///
/// Matching is ASCII case-insensitive and uses the same tag-name boundary
/// rule as `remove_tag_content`, so `<p` never matches `<pre>`. An opening
/// tag with no `>` anywhere after it has its `<tag` prefix dropped and no
/// replacement emitted.
fn replace_tag_with(html: &str, tag: &str, replacement: &str) -> String {
    let mut result = String::with_capacity(html.len());
    // ASCII-only lowercasing preserves byte offsets, so indices from `lower`
    // are valid in `html`. Unicode `to_lowercase` can change byte lengths and
    // previously caused misaligned (even panicking) slices.
    let lower = html.to_ascii_lowercase();
    let open = format!("<{}", tag.to_ascii_lowercase());
    let mut pos = 0;

    loop {
        let Some(rel) = lower[pos..].find(&open) else {
            result.push_str(&html[pos..]);
            break;
        };
        let abs_start = pos + rel;
        // Same boundary check as remove_tag_content — `<p` must not match
        // `<pre>`, `<h1` must not match `<h10>` (defensive), etc.
        let after = abs_start + open.len();
        let next = lower.as_bytes().get(after).copied();
        let is_tag_boundary = matches!(
            next,
            None | Some(b'>') | Some(b'/') | Some(b' ') | Some(b'\t') | Some(b'\n') | Some(b'\r')
        );
        if !is_tag_boundary {
            result.push_str(&html[pos..=abs_start]);
            pos = abs_start + 1;
            continue;
        }
        result.push_str(&html[pos..abs_start]);
        if let Some(end) = html[abs_start..].find('>') {
            result.push_str(replacement);
            pos = abs_start + end + 1;
        } else {
            pos = abs_start + open.len();
        }
    }
    result
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    // ── IP safety ──────────────────────────────────────────────────────────

    #[test]
    fn is_safe_ip_rejects_loopback_v4() {
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))).is_err());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(127, 255, 255, 254))).is_err());
    }

    #[test]
    fn is_safe_ip_rejects_private_v4() {
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))).is_err());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1))).is_err());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255))).is_err());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))).is_err());
    }

    #[test]
    fn is_safe_ip_rejects_cloud_metadata() {
        // The one we really care about — AWS/GCP/Azure instance metadata.
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))).is_err());
    }

    #[test]
    fn is_safe_ip_rejects_unspecified_and_broadcast() {
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))).is_err());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255))).is_err());
    }

    #[test]
    fn is_safe_ip_rejects_reserved_and_multicast_v4() {
        // 0.0.0.0/8 beyond the bare unspecified address
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(0, 1, 2, 3))).is_err());
        // 224.0.0.0/4 multicast
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 1))).is_err());
        // 240.0.0.0/4 reserved
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(240, 0, 0, 1))).is_err());
    }

    #[test]
    fn is_safe_ip_rejects_cgnat() {
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1))).is_err());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(100, 127, 255, 255))).is_err());
        // Boundary: 100.63.x.x and 100.128.x.x are public (not CGNAT), must pass
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(100, 63, 0, 1))).is_ok());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(100, 128, 0, 1))).is_ok());
    }

    #[test]
    fn is_safe_ip_accepts_public_v4() {
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))).is_ok());
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))).is_ok());
        // 140.82.112.3 is in a github.com range.
        assert!(is_safe_ip(IpAddr::V4(Ipv4Addr::new(140, 82, 112, 3))).is_ok());
    }

    #[test]
    fn is_safe_ip_rejects_v6_loopback_and_local() {
        assert!(is_safe_ip(IpAddr::V6(Ipv6Addr::LOCALHOST)).is_err());
        assert!(is_safe_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED)).is_err());
        // fc00::/7 ULA
        assert!(is_safe_ip(IpAddr::V6("fc00::1".parse().unwrap())).is_err());
        assert!(is_safe_ip(IpAddr::V6("fd12:3456:789a::1".parse().unwrap())).is_err());
        // fe80::/10 link-local
        assert!(is_safe_ip(IpAddr::V6("fe80::1".parse().unwrap())).is_err());
    }

    #[test]
    fn is_safe_ip_rejects_v6_multicast() {
        assert!(is_safe_ip(IpAddr::V6("ff02::1".parse().unwrap())).is_err());
    }

    #[test]
    fn is_safe_ip_ipv4_mapped_v6_rechecks_against_v4_rules() {
        // ::ffff:127.0.0.1 must be rejected as loopback
        let mapped = IpAddr::V6("::ffff:127.0.0.1".parse().unwrap());
        assert!(is_safe_ip(mapped).is_err());
        // ::ffff:8.8.8.8 is public — must pass
        let public_mapped = IpAddr::V6("::ffff:8.8.8.8".parse().unwrap());
        assert!(is_safe_ip(public_mapped).is_ok());
    }

    #[test]
    fn is_safe_ip_accepts_public_v6() {
        // Google public DNS 2001:4860:4860::8888
        assert!(is_safe_ip(IpAddr::V6("2001:4860:4860::8888".parse().unwrap())).is_ok());
    }

    // ── Scheme whitelist ───────────────────────────────────────────────────

    #[test]
    fn scheme_allows_http_and_https() {
        assert!(validate_scheme(&Url::parse("http://example.com").unwrap()).is_ok());
        assert!(validate_scheme(&Url::parse("https://example.com").unwrap()).is_ok());
    }

    #[test]
    fn scheme_blocks_file_and_other_protocols() {
        assert!(validate_scheme(&Url::parse("file:///etc/passwd").unwrap()).is_err());
        assert!(validate_scheme(&Url::parse("gopher://evil.com/").unwrap()).is_err());
        assert!(validate_scheme(&Url::parse("ftp://example.com/").unwrap()).is_err());
        assert!(validate_scheme(&Url::parse("dict://evil.com/").unwrap()).is_err());
    }

    // ── Auto-approve allowlist ─────────────────────────────────────────────

    #[test]
    fn auto_approve_known_docs() {
        assert!(host_is_auto_approved("github.com"));
        assert!(host_is_auto_approved("api.github.com"));
        assert!(host_is_auto_approved("docs.rs"));
        assert!(host_is_auto_approved("raw.githubusercontent.com"));
    }

    #[test]
    fn auto_approve_chinese_dev_ecosystem() {
        // Apex + www subdomain for each — matches real URLs users hand the model.
        assert!(host_is_auto_approved("atomgit.com"));
        assert!(host_is_auto_approved("www.atomgit.com"));
        assert!(host_is_auto_approved("api.atomgit.com"));
        assert!(host_is_auto_approved("gitcode.com"));
        assert!(host_is_auto_approved("www.gitcode.com"));
        assert!(host_is_auto_approved("csdn.net"));
        assert!(host_is_auto_approved("www.csdn.net"));
        assert!(host_is_auto_approved("blog.csdn.net"));
        assert!(host_is_auto_approved("openatom.cn"));
        assert!(host_is_auto_approved("www.openatom.cn"));
    }

    #[test]
    fn auto_approve_is_exact_suffix_match_only() {
        // Must not match e.g. "evilgithub.com" or "github.com.evil.com".
        assert!(!host_is_auto_approved("evilgithub.com"));
        assert!(!host_is_auto_approved("github.com.evil.com"));
        assert!(!host_is_auto_approved("notdocs.rs"));
    }

    #[test]
    fn auto_approve_trailing_dot_tolerated() {
        // DNS-legal trailing dot shouldn't bypass the match.
        assert!(host_is_auto_approved("github.com."));
    }

    #[test]
    fn auto_approve_is_case_insensitive() {
        assert!(host_is_auto_approved("GitHub.com"));
    }

    // ── approval() end-to-end ──────────────────────────────────────────────
    // approval() auto-approves everything; blocking happens in execute().

    #[test]
    fn approval_auto_approves_localhost_literal() {
        let tool = WebFetchTool;
        let args = r#"{"url":"http://127.0.0.1:8080/"}"#;
        assert!(matches!(
            tool.approval(args),
            ApprovalRequirement::AutoApprove
        ));
    }

    #[test]
    fn approval_auto_approves_file_scheme() {
        let tool = WebFetchTool;
        let args = r#"{"url":"file:///etc/passwd"}"#;
        assert!(matches!(
            tool.approval(args),
            ApprovalRequirement::AutoApprove
        ));
    }

    #[test]
    fn approval_auto_approves_github() {
        let tool = WebFetchTool;
        let args = r#"{"url":"https://github.com/rust-lang/rust"}"#;
        assert!(matches!(
            tool.approval(args),
            ApprovalRequirement::AutoApprove
        ));
    }

    #[test]
    fn approval_auto_approves_unknown_domain() {
        let tool = WebFetchTool;
        let args = r#"{"url":"https://example.com/"}"#;
        assert!(matches!(
            tool.approval(args),
            ApprovalRequirement::AutoApprove
        ));
    }

    #[test]
    fn approval_auto_approves_malformed_args() {
        let tool = WebFetchTool;
        assert!(matches!(
            tool.approval("{}"),
            ApprovalRequirement::AutoApprove
        ));
        assert!(matches!(
            tool.approval(""),
            ApprovalRequirement::AutoApprove
        ));
    }

    // ── execute() SSRF smoke tests ─────────────────────────────────────────

    #[tokio::test]
    async fn execute_blocks_file_scheme() {
        let tool = WebFetchTool;
        let ctx = ToolContext::new(std::env::temp_dir());
        let args = r#"{"url":"file:///etc/passwd"}"#;
        let r = tool.execute(args, &ctx).await.unwrap();
        assert!(!r.success, "file:// must fail");
        assert!(
            r.output.contains("scheme") || r.output.contains("Blocked"),
            "unexpected error: {}",
            r.output
        );
    }

    #[tokio::test]
    async fn execute_blocks_localhost() {
        let tool = WebFetchTool;
        let ctx = ToolContext::new(std::env::temp_dir());
        let args = r#"{"url":"http://127.0.0.1:1/"}"#;
        let r = tool.execute(args, &ctx).await.unwrap();
        assert!(!r.success, "127.0.0.1 must fail");
        assert!(
            r.output.contains("Blocked") || r.output.contains("SSRF"),
            "unexpected error: {}",
            r.output
        );
    }

    #[tokio::test]
    async fn execute_blocks_cloud_metadata() {
        let tool = WebFetchTool;
        let ctx = ToolContext::new(std::env::temp_dir());
        let args = r#"{"url":"http://169.254.169.254/latest/meta-data/"}"#;
        let r = tool.execute(args, &ctx).await.unwrap();
        assert!(!r.success, "cloud metadata must fail");
        assert!(
            r.output.contains("Blocked") || r.output.contains("SSRF"),
            "unexpected error: {}",
            r.output
        );
    }

    #[tokio::test]
    async fn execute_blocks_private_network() {
        let tool = WebFetchTool;
        let ctx = ToolContext::new(std::env::temp_dir());
        let args = r#"{"url":"http://10.0.0.1/"}"#;
        let r = tool.execute(args, &ctx).await.unwrap();
        assert!(!r.success, "10.0.0.1 must fail");
    }

    #[tokio::test]
    async fn execute_rejects_url_that_looks_like_curl_flag() {
        // Pre-refactor the old curl-based impl would parse `-Kfoo` as a flag.
        // The new impl parses with url::Url, which rejects anything that
        // doesn't start with a valid scheme, so this fails at URL parse.
        let tool = WebFetchTool;
        let ctx = ToolContext::new(std::env::temp_dir());
        let args = r#"{"url":"-K/etc/passwd"}"#;
        let r = tool.execute(args, &ctx).await.unwrap();
        assert!(!r.success);
        assert!(
            r.output.contains("Invalid URL") || r.output.contains("scheme"),
            "unexpected error: {}",
            r.output
        );
    }

    // ── html_to_text / tag matching ────────────────────────────────────────

    #[test]
    fn remove_tag_content_keeps_prefix_collision_tags() {
        // Repro for the gitcode.com/cann SSR page: `<head>...</head>` is
        // followed later by `<header>...</header>`. The old naive prefix
        // match treated `<header` as if it opened a `<head>` block, searched
        // for a `</head>` that did not exist, and silently dropped the rest
        // of the document.
        let html = "<head><title>t</title></head>\
                    <body><header>nav</header><main>BODY-CONTENT</main></body>";
        let out = remove_tag_content(html, "head");
        assert!(
            out.contains("BODY-CONTENT"),
            "body content was discarded: {}",
            out
        );
        assert!(
            out.contains("<header>nav</header>"),
            "header element should be preserved (only <head> removed): {}",
            out
        );
        assert!(
            !out.contains("<title>"),
            "real <head> contents must still be removed: {}",
            out
        );
    }

    #[test]
    fn remove_tag_content_is_case_insensitive() {
        assert_eq!(remove_tag_content("<SCRIPT>x</SCRIPT>y", "script"), "y");
    }

    #[test]
    fn replace_tag_with_keeps_prefix_collision_tags() {
        // Same boundary bug surface: replacing `<p>` opens must not also
        // replace `<pre>` opens.
        let out = replace_tag_with("<p>A</p><pre>B</pre>", "p", "\n");
        // `<p>` becomes "\n", but `<pre>` must stay untouched.
        assert!(
            out.contains("<pre>B</pre>"),
            "<pre> should not be matched by <p>: {}",
            out
        );
    }

    #[test]
    fn html_to_text_extracts_body_when_header_follows_head() {
        // End-to-end: structure mirrors what gitcode.com/cann ships.
        let html = "<!doctype html><html><head><title>x</title></head>\
                    <body><header class=\"nav\">topbar</header>\
                    <main><h1>Title</h1><p>Real article text.</p></main>\
                    </body></html>";
        let text = html_to_text(html);
        assert!(
            text.contains("Real article text."),
            "main body lost: {:?}",
            text
        );
        assert!(text.contains("Title"), "heading lost: {:?}", text);
    }

    #[test]
    fn html_to_text_decodes_basic_entities() {
        assert_eq!(html_to_text("<p>a &amp; b &lt;c&gt;</p>"), "a & b <c>");
    }

    #[test]
    fn remove_tag_content_handles_truly_unclosed_tag() {
        // If a tag really has no closing element, content BEFORE it must
        // survive. We accept either outcome for the open tag and what follows
        // it: kept verbatim, or stripped.
        let html = "<p>KEEP-ME</p><script>oops no close";
        let out = remove_tag_content(html, "script");
        assert!(out.contains("KEEP-ME"), "leading content lost: {}", out);
    }
}