Skip to main content

zeph_tools/
scrape.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::{IpAddr, SocketAddr};
5use std::time::Duration;
6
7use schemars::JsonSchema;
8use serde::Deserialize;
9use url::Url;
10
11use crate::config::ScrapeConfig;
12use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
13
14#[derive(Debug, Deserialize, JsonSchema)]
15struct FetchParams {
16    /// HTTPS URL to fetch
17    url: String,
18}
19
20#[derive(Debug, Deserialize, JsonSchema)]
21struct ScrapeInstruction {
22    /// HTTPS URL to scrape
23    url: String,
24    /// CSS selector
25    select: String,
26    /// Extract mode: text, html, or attr:<name>
27    #[serde(default = "default_extract")]
28    extract: String,
29    /// Max results to return
30    limit: Option<usize>,
31}
32
/// Serde default for `ScrapeInstruction::extract`: the `"text"` extract mode.
fn default_extract() -> String {
    String::from("text")
}
36
/// How a matched element is converted into an output string.
#[derive(Debug)]
enum ExtractMode {
    /// Use the element's text.
    Text,
    /// Use the element's inner HTML.
    Html,
    /// Use the value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parses the textual extract mode of an instruction.
    ///
    /// `"html"` selects [`ExtractMode::Html`]; anything starting with `"attr:"`
    /// selects [`ExtractMode::Attr`] with the remainder as the attribute name
    /// (which may be empty). Every other input, including `"text"` and unknown
    /// values, yields [`ExtractMode::Text`].
    fn parse(s: &str) -> Self {
        if let Some(attribute) = s.strip_prefix("attr:") {
            return Self::Attr(attribute.to_owned());
        }
        if s == "html" { Self::Html } else { Self::Text }
    }
}
56
/// Tool executor that pulls data out of web pages with CSS selectors.
///
/// It looks for ` ```scrape ` fenced blocks holding JSON instructions in an LLM
/// response, downloads the requested URL, and parses the HTML with `scrape-core`.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Timeout handed to every HTTP client this executor builds.
    timeout: Duration,
    /// Largest response body, in bytes, that a fetch will accept.
    max_body_bytes: usize,
}
66
67impl WebScrapeExecutor {
68    #[must_use]
69    pub fn new(config: &ScrapeConfig) -> Self {
70        Self {
71            timeout: Duration::from_secs(config.timeout),
72            max_body_bytes: config.max_body_bytes,
73        }
74    }
75
76    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
77        let mut builder = reqwest::Client::builder()
78            .timeout(self.timeout)
79            .redirect(reqwest::redirect::Policy::none());
80        builder = builder.resolve_to_addrs(host, addrs);
81        builder.build().unwrap_or_default()
82    }
83}
84
85impl ToolExecutor for WebScrapeExecutor {
86    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
87        use crate::registry::{InvocationHint, ToolDef};
88        vec![
89            ToolDef {
90                id: "web_scrape".into(),
91                description: "Scrape data from a web page via CSS selectors".into(),
92                schema: schemars::schema_for!(ScrapeInstruction),
93                invocation: InvocationHint::FencedBlock("scrape"),
94            },
95            ToolDef {
96                id: "fetch".into(),
97                description: "Fetch a URL and return content as plain text".into(),
98                schema: schemars::schema_for!(FetchParams),
99                invocation: InvocationHint::ToolCall,
100            },
101        ]
102    }
103
104    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
105        let blocks = extract_scrape_blocks(response);
106        if blocks.is_empty() {
107            return Ok(None);
108        }
109
110        let mut outputs = Vec::with_capacity(blocks.len());
111        #[allow(clippy::cast_possible_truncation)]
112        let blocks_executed = blocks.len() as u32;
113
114        for block in &blocks {
115            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
116                ToolError::Execution(std::io::Error::new(
117                    std::io::ErrorKind::InvalidData,
118                    e.to_string(),
119                ))
120            })?;
121            outputs.push(self.scrape_instruction(&instruction).await?);
122        }
123
124        Ok(Some(ToolOutput {
125            tool_name: "web-scrape".to_owned(),
126            summary: outputs.join("\n\n"),
127            blocks_executed,
128            filter_stats: None,
129            diff: None,
130            streamed: false,
131            terminal_id: None,
132            locations: None,
133            raw_response: None,
134        }))
135    }
136
137    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
138        match call.tool_id.as_str() {
139            "web_scrape" => {
140                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
141                let result = self.scrape_instruction(&instruction).await?;
142                Ok(Some(ToolOutput {
143                    tool_name: "web-scrape".to_owned(),
144                    summary: result,
145                    blocks_executed: 1,
146                    filter_stats: None,
147                    diff: None,
148                    streamed: false,
149                    terminal_id: None,
150                    locations: None,
151                    raw_response: None,
152                }))
153            }
154            "fetch" => {
155                let p: FetchParams = deserialize_params(&call.params)?;
156                let result = self.handle_fetch(&p).await?;
157                Ok(Some(ToolOutput {
158                    tool_name: "fetch".to_owned(),
159                    summary: result,
160                    blocks_executed: 1,
161                    filter_stats: None,
162                    diff: None,
163                    streamed: false,
164                    terminal_id: None,
165                    locations: None,
166                    raw_response: None,
167                }))
168            }
169            _ => Ok(None),
170        }
171    }
172}
173
174impl WebScrapeExecutor {
175    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
176        let parsed = validate_url(&params.url)?;
177        let (host, addrs) = resolve_and_validate(&parsed).await?;
178        self.fetch_html(&params.url, &host, &addrs).await
179    }
180
181    async fn scrape_instruction(
182        &self,
183        instruction: &ScrapeInstruction,
184    ) -> Result<String, ToolError> {
185        let parsed = validate_url(&instruction.url)?;
186        let (host, addrs) = resolve_and_validate(&parsed).await?;
187        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
188        let selector = instruction.select.clone();
189        let extract = ExtractMode::parse(&instruction.extract);
190        let limit = instruction.limit.unwrap_or(10);
191        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
192            .await
193            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
194    }
195
196    /// Fetches the HTML at `url`, manually following up to 3 redirects.
197    ///
198    /// Each redirect target is validated with `validate_url` and `resolve_and_validate`
199    /// before following, preventing SSRF via redirect chains.
200    ///
201    /// # Errors
202    ///
203    /// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
204    /// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
205    async fn fetch_html(
206        &self,
207        url: &str,
208        host: &str,
209        addrs: &[SocketAddr],
210    ) -> Result<String, ToolError> {
211        const MAX_REDIRECTS: usize = 3;
212
213        let mut current_url = url.to_owned();
214        let mut current_host = host.to_owned();
215        let mut current_addrs = addrs.to_vec();
216
217        for hop in 0..=MAX_REDIRECTS {
218            // Build a per-hop client pinned to the current hop's validated addresses.
219            let client = self.build_client(&current_host, &current_addrs);
220            let resp = client
221                .get(&current_url)
222                .send()
223                .await
224                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
225
226            let status = resp.status();
227
228            if status.is_redirection() {
229                if hop == MAX_REDIRECTS {
230                    return Err(ToolError::Execution(std::io::Error::other(
231                        "too many redirects",
232                    )));
233                }
234
235                let location = resp
236                    .headers()
237                    .get(reqwest::header::LOCATION)
238                    .and_then(|v| v.to_str().ok())
239                    .ok_or_else(|| {
240                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
241                    })?;
242
243                // Resolve relative redirect URLs against the current URL.
244                let base = Url::parse(&current_url)
245                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
246                let next_url = base
247                    .join(location)
248                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
249
250                let validated = validate_url(next_url.as_str())?;
251                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
252
253                current_url = next_url.to_string();
254                current_host = next_host;
255                current_addrs = next_addrs;
256                continue;
257            }
258
259            if !status.is_success() {
260                return Err(ToolError::Execution(std::io::Error::other(format!(
261                    "HTTP {status}",
262                ))));
263            }
264
265            let bytes = resp
266                .bytes()
267                .await
268                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
269
270            if bytes.len() > self.max_body_bytes {
271                return Err(ToolError::Execution(std::io::Error::other(format!(
272                    "response too large: {} bytes (max: {})",
273                    bytes.len(),
274                    self.max_body_bytes,
275                ))));
276            }
277
278            return String::from_utf8(bytes.to_vec())
279                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
280        }
281
282        Err(ToolError::Execution(std::io::Error::other(
283            "too many redirects",
284        )))
285    }
286}
287
288fn extract_scrape_blocks(text: &str) -> Vec<&str> {
289    crate::executor::extract_fenced_blocks(text, "scrape")
290}
291
292fn validate_url(raw: &str) -> Result<Url, ToolError> {
293    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
294        command: format!("invalid URL: {raw}"),
295    })?;
296
297    if parsed.scheme() != "https" {
298        return Err(ToolError::Blocked {
299            command: format!("scheme not allowed: {}", parsed.scheme()),
300        });
301    }
302
303    if let Some(host) = parsed.host()
304        && is_private_host(&host)
305    {
306        return Err(ToolError::Blocked {
307            command: format!(
308                "private/local host blocked: {}",
309                parsed.host_str().unwrap_or("")
310            ),
311        });
312    }
313
314    Ok(parsed)
315}
316
/// Reports whether `ip` must not be contacted by the scraper.
///
/// IPv4: loopback, RFC 1918 private, link-local, unspecified, broadcast, and the
/// RFC 6598 shared address space `100.64.0.0/10` (carrier-grade NAT).
///
/// IPv6: loopback, unspecified, link-local `fe80::/10`, unique-local `fc00::/7`,
/// and IPv4-mapped addresses (`::ffff:a.b.c.d`) whose embedded IPv4 address is
/// itself rejected by the IPv4 rules above.
///
/// Documentation ranges such as `192.0.2.0/24` and `2001:db8::/32` are not
/// treated as private.
pub(crate) fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => is_private_ipv4(v4),
        IpAddr::V6(v6) => {
            // ::ffff:x.x.x.x — judge the address by its embedded IPv4 part.
            if let Some(mapped) = v6.to_ipv4_mapped() {
                return is_private_ipv4(mapped);
            }
            let first = v6.segments()[0];
            v6.is_loopback()
                || v6.is_unspecified()
                // fe80::/10 — link-local
                || (first & 0xffc0) == 0xfe80
                // fc00::/7 — unique local
                || (first & 0xfe00) == 0xfc00
        }
    }
}

/// IPv4 half of [`is_private_ip`], shared by the plain and IPv4-mapped paths.
fn is_private_ipv4(v4: std::net::Ipv4Addr) -> bool {
    let [first, second, _, _] = v4.octets();
    // 100.64.0.0/10: first octet 100, top two bits of the second octet are `01`.
    let shared_address_space = first == 100 && (second & 0xc0) == 0x40;

    v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_unspecified()
        || v4.is_broadcast()
        || shared_address_space
}
354
355fn is_private_host(host: &url::Host<&str>) -> bool {
356    match host {
357        url::Host::Domain(d) => {
358            // Exact match or subdomain of localhost (e.g. foo.localhost)
359            // and .internal/.local TLDs used in cloud/k8s environments.
360            #[allow(clippy::case_sensitive_file_extension_comparisons)]
361            {
362                *d == "localhost"
363                    || d.ends_with(".localhost")
364                    || d.ends_with(".internal")
365                    || d.ends_with(".local")
366            }
367        }
368        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
369        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
370    }
371}
372
373/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
374/// and returns the hostname and validated socket addresses.
375///
376/// Returning the addresses allows the caller to pin the HTTP client to these exact
377/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
378async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
379    let Some(host) = url.host_str() else {
380        return Ok((String::new(), vec![]));
381    };
382    let port = url.port_or_known_default().unwrap_or(443);
383    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
384        .await
385        .map_err(|e| ToolError::Blocked {
386            command: format!("DNS resolution failed: {e}"),
387        })?
388        .collect();
389    for addr in &addrs {
390        if is_private_ip(addr.ip()) {
391            return Err(ToolError::Blocked {
392                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
393            });
394        }
395    }
396    Ok((host.to_owned(), addrs))
397}
398
399fn parse_and_extract(
400    html: &str,
401    selector: &str,
402    extract: &ExtractMode,
403    limit: usize,
404) -> Result<String, ToolError> {
405    let soup = scrape_core::Soup::parse(html);
406
407    let tags = soup.find_all(selector).map_err(|e| {
408        ToolError::Execution(std::io::Error::new(
409            std::io::ErrorKind::InvalidData,
410            format!("invalid selector: {e}"),
411        ))
412    })?;
413
414    let mut results = Vec::new();
415
416    for tag in tags.into_iter().take(limit) {
417        let value = match extract {
418            ExtractMode::Text => tag.text(),
419            ExtractMode::Html => tag.inner_html(),
420            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
421        };
422        if !value.trim().is_empty() {
423            results.push(value.trim().to_owned());
424        }
425    }
426
427    if results.is_empty() {
428        Ok(format!("No results for selector: {selector}"))
429    } else {
430        Ok(results.join("\n"))
431    }
432}
433
434#[cfg(test)]
435mod tests {
436    use super::*;
437
438    // --- extract_scrape_blocks ---
439
#[test]
fn extract_single_block() {
    let response =
        "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
    let found = extract_scrape_blocks(response);
    let [only] = found.as_slice() else {
        panic!("expected exactly one block, got {}", found.len());
    };
    assert!(only.contains("example.com"));
}

#[test]
fn extract_multiple_blocks() {
    let response = concat!(
        "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```",
        "\ntext\n",
        "```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```",
    );
    assert_eq!(extract_scrape_blocks(response).len(), 2);
}

#[test]
fn no_blocks_returns_empty() {
    assert!(extract_scrape_blocks("plain text, no code blocks").is_empty());
}

#[test]
fn unclosed_block_ignored() {
    let response = "```scrape\n{\"url\":\"https://x.com\"}";
    assert!(extract_scrape_blocks(response).is_empty());
}

#[test]
fn non_scrape_block_ignored() {
    let response =
        "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
    let found = extract_scrape_blocks(response);
    let [only] = found.as_slice() else {
        panic!("expected exactly one block, got {}", found.len());
    };
    assert!(only.contains("x.com"));
}

#[test]
fn multiline_json_block() {
    let response =
        "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
    let found = extract_scrape_blocks(response);
    assert_eq!(found.len(), 1);
    let parsed: ScrapeInstruction = serde_json::from_str(found[0]).unwrap();
    assert_eq!(parsed.url, "https://example.com");
}
486
487    // --- ScrapeInstruction parsing ---
488
#[test]
fn parse_valid_instruction() {
    let ScrapeInstruction {
        url,
        select,
        extract,
        limit,
    } = serde_json::from_str(
        r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#,
    )
    .unwrap();
    assert_eq!(url, "https://example.com");
    assert_eq!(select, "h1");
    assert_eq!(extract, "text");
    assert_eq!(limit, Some(5));
}

#[test]
fn parse_minimal_instruction() {
    let parsed: ScrapeInstruction =
        serde_json::from_str(r#"{"url":"https://example.com","select":"p"}"#).unwrap();
    // Omitted fields fall back to their defaults.
    assert_eq!(parsed.extract, "text");
    assert_eq!(parsed.limit, None);
}

#[test]
fn parse_attr_extract() {
    let parsed: ScrapeInstruction = serde_json::from_str(
        r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#,
    )
    .unwrap();
    assert_eq!(parsed.extract, "attr:href");
}

#[test]
fn parse_invalid_json_errors() {
    assert!(serde_json::from_str::<ScrapeInstruction>("not json").is_err());
}
519
520    // --- ExtractMode ---
521
#[test]
fn extract_mode_text() {
    let mode = ExtractMode::parse("text");
    assert!(matches!(mode, ExtractMode::Text));
}

#[test]
fn extract_mode_html() {
    let mode = ExtractMode::parse("html");
    assert!(matches!(mode, ExtractMode::Html));
}

#[test]
fn extract_mode_attr() {
    match ExtractMode::parse("attr:href") {
        ExtractMode::Attr(name) => assert_eq!(name, "href"),
        other => panic!("expected Attr, got {other:?}"),
    }
}

#[test]
fn extract_mode_unknown_defaults_to_text() {
    let mode = ExtractMode::parse("unknown");
    assert!(matches!(mode, ExtractMode::Text));
}
542
543    // --- validate_url ---
544
#[test]
fn valid_https_url() {
    let outcome = validate_url("https://example.com");
    assert!(outcome.is_ok());
}

#[test]
fn http_rejected() {
    let outcome = validate_url("http://example.com");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn ftp_rejected() {
    let outcome = validate_url("ftp://files.example.com");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn file_rejected() {
    let outcome = validate_url("file:///etc/passwd");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn invalid_url_rejected() {
    let outcome = validate_url("not a url");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn localhost_blocked() {
    let outcome = validate_url("https://localhost/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn loopback_ip_blocked() {
    let outcome = validate_url("https://127.0.0.1/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn private_10_blocked() {
    let outcome = validate_url("https://10.0.0.1/api");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn private_172_blocked() {
    let outcome = validate_url("https://172.16.0.1/api");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn private_192_blocked() {
    let outcome = validate_url("https://192.168.1.1/api");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn ipv6_loopback_blocked() {
    let outcome = validate_url("https://[::1]/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn public_ip_allowed() {
    let outcome = validate_url("https://93.184.216.34/page");
    assert!(outcome.is_ok());
}
614
615    // --- parse_and_extract ---
616
#[test]
fn extract_text_from_html() {
    let page = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
    let extracted = parse_and_extract(page, "h1", &ExtractMode::Text, 10);
    assert_eq!(extracted.unwrap(), "Hello World");
}

#[test]
fn extract_multiple_elements() {
    let page = "<ul><li>A</li><li>B</li><li>C</li></ul>";
    let extracted = parse_and_extract(page, "li", &ExtractMode::Text, 10);
    assert_eq!(extracted.unwrap(), "A\nB\nC");
}

#[test]
fn extract_with_limit() {
    let page = "<ul><li>A</li><li>B</li><li>C</li></ul>";
    let extracted = parse_and_extract(page, "li", &ExtractMode::Text, 2);
    assert_eq!(extracted.unwrap(), "A\nB");
}

#[test]
fn extract_attr_href() {
    let page = r#"<a href="https://example.com">Link</a>"#;
    let mode = ExtractMode::Attr("href".to_owned());
    let extracted = parse_and_extract(page, "a", &mode, 10);
    assert_eq!(extracted.unwrap(), "https://example.com");
}

#[test]
fn extract_inner_html() {
    let page = "<div><span>inner</span></div>";
    let extracted = parse_and_extract(page, "div", &ExtractMode::Html, 10).unwrap();
    assert!(extracted.contains("<span>inner</span>"));
}

#[test]
fn no_matches_returns_message() {
    let page = "<html><body><p>text</p></body></html>";
    let extracted = parse_and_extract(page, "h1", &ExtractMode::Text, 10).unwrap();
    assert!(extracted.starts_with("No results for selector:"));
}

#[test]
fn empty_text_skipped() {
    let page = "<ul><li>  </li><li>A</li></ul>";
    let extracted = parse_and_extract(page, "li", &ExtractMode::Text, 10);
    assert_eq!(extracted.unwrap(), "A");
}

#[test]
fn invalid_selector_errors() {
    let page = "<html><body></body></html>";
    assert!(parse_and_extract(page, "[[[invalid", &ExtractMode::Text, 10).is_err());
}

#[test]
fn empty_html_returns_no_results() {
    let extracted = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
    assert!(extracted.starts_with("No results for selector:"));
}

#[test]
fn nested_selector() {
    let page = "<div><span>inner</span></div><span>outer</span>";
    let extracted = parse_and_extract(page, "div > span", &ExtractMode::Text, 10);
    assert_eq!(extracted.unwrap(), "inner");
}

#[test]
fn attr_missing_returns_empty() {
    let page = "<a>No href</a>";
    let mode = ExtractMode::Attr("href".to_owned());
    let extracted = parse_and_extract(page, "a", &mode, 10).unwrap();
    assert!(extracted.starts_with("No results for selector:"));
}

#[test]
fn extract_html_mode() {
    let page = "<div><b>bold</b> text</div>";
    let extracted = parse_and_extract(page, "div", &ExtractMode::Html, 10).unwrap();
    assert!(extracted.contains("<b>bold</b>"));
}

#[test]
fn limit_zero_returns_no_results() {
    let page = "<ul><li>A</li><li>B</li></ul>";
    let extracted = parse_and_extract(page, "li", &ExtractMode::Text, 0).unwrap();
    assert!(extracted.starts_with("No results for selector:"));
}
708
709    // --- validate_url edge cases ---
710
#[test]
fn url_with_port_allowed() {
    let outcome = validate_url("https://example.com:8443/path");
    assert!(outcome.is_ok());
}

#[test]
fn link_local_ip_blocked() {
    let outcome = validate_url("https://169.254.1.1/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn url_no_scheme_rejected() {
    let outcome = validate_url("example.com/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn unspecified_ipv4_blocked() {
    let outcome = validate_url("https://0.0.0.0/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn broadcast_ipv4_blocked() {
    let outcome = validate_url("https://255.255.255.255/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn ipv6_link_local_blocked() {
    let outcome = validate_url("https://[fe80::1]/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn ipv6_unique_local_blocked() {
    let outcome = validate_url("https://[fd12::1]/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn ipv4_mapped_ipv6_loopback_blocked() {
    let outcome = validate_url("https://[::ffff:127.0.0.1]/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn ipv4_mapped_ipv6_private_blocked() {
    let outcome = validate_url("https://[::ffff:10.0.0.1]/path");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}
763
764    // --- WebScrapeExecutor (no-network) ---
765
766    #[tokio::test]
767    async fn executor_no_blocks_returns_none() {
768        let config = ScrapeConfig::default();
769        let executor = WebScrapeExecutor::new(&config);
770        let result = executor.execute("plain text").await;
771        assert!(result.unwrap().is_none());
772    }
773
774    #[tokio::test]
775    async fn executor_invalid_json_errors() {
776        let config = ScrapeConfig::default();
777        let executor = WebScrapeExecutor::new(&config);
778        let response = "```scrape\nnot json\n```";
779        let result = executor.execute(response).await;
780        assert!(matches!(result, Err(ToolError::Execution(_))));
781    }
782
783    #[tokio::test]
784    async fn executor_blocked_url_errors() {
785        let config = ScrapeConfig::default();
786        let executor = WebScrapeExecutor::new(&config);
787        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
788        let result = executor.execute(response).await;
789        assert!(matches!(result, Err(ToolError::Blocked { .. })));
790    }
791
792    #[tokio::test]
793    async fn executor_private_ip_blocked() {
794        let config = ScrapeConfig::default();
795        let executor = WebScrapeExecutor::new(&config);
796        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
797        let result = executor.execute(response).await;
798        assert!(matches!(result, Err(ToolError::Blocked { .. })));
799    }
800
801    #[tokio::test]
802    async fn executor_unreachable_host_returns_error() {
803        let config = ScrapeConfig {
804            timeout: 1,
805            max_body_bytes: 1_048_576,
806        };
807        let executor = WebScrapeExecutor::new(&config);
808        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
809        let result = executor.execute(response).await;
810        assert!(matches!(result, Err(ToolError::Execution(_))));
811    }
812
813    #[tokio::test]
814    async fn executor_localhost_url_blocked() {
815        let config = ScrapeConfig::default();
816        let executor = WebScrapeExecutor::new(&config);
817        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
818        let result = executor.execute(response).await;
819        assert!(matches!(result, Err(ToolError::Blocked { .. })));
820    }
821
822    #[tokio::test]
823    async fn executor_empty_text_returns_none() {
824        let config = ScrapeConfig::default();
825        let executor = WebScrapeExecutor::new(&config);
826        let result = executor.execute("").await;
827        assert!(result.unwrap().is_none());
828    }
829
830    #[tokio::test]
831    async fn executor_multiple_blocks_first_blocked() {
832        let config = ScrapeConfig::default();
833        let executor = WebScrapeExecutor::new(&config);
834        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
835             ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
836        let result = executor.execute(response).await;
837        assert!(result.is_err());
838    }
839
#[test]
fn validate_url_empty_string() {
    let outcome = validate_url("");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn validate_url_javascript_scheme_blocked() {
    let outcome = validate_url("javascript:alert(1)");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn validate_url_data_scheme_blocked() {
    let outcome = validate_url("data:text/html,<h1>hi</h1>");
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}

#[test]
fn is_private_host_public_domain_is_false() {
    assert!(!is_private_host(&url::Host::Domain("example.com")));
}

#[test]
fn is_private_host_localhost_is_true() {
    assert!(is_private_host(&url::Host::Domain("localhost")));
}

#[test]
fn is_private_host_ipv6_unspecified_is_true() {
    let unspecified = std::net::Ipv6Addr::UNSPECIFIED;
    assert!(is_private_host(&url::Host::Ipv6(unspecified)));
}

#[test]
fn is_private_host_public_ipv6_is_false() {
    let address: std::net::Ipv6Addr = "2001:db8::1".parse().unwrap();
    assert!(!is_private_host(&url::Host::Ipv6(address)));
}
881
    // --- fetch_html / redirect logic: wiremock HTTP server tests ---
    //
    // These tests use a local wiremock server so that no external HTTPS connection is needed.
    // The server binds to 127.0.0.1, so the tests call `fetch_html` directly (bypassing
    // `validate_url`, whose SSRF guard would otherwise block loopback connections).
    // Redirect-following cases go through the test-only `follow_redirects_raw` helper instead.
888
889    /// Helper: returns executor + (server_url, server_addr) from a running wiremock mock server.
890    /// The server address is passed to `fetch_html` via `resolve_to_addrs` so the client
891    /// connects to the mock instead of doing a real DNS lookup.
892    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
893        let server = wiremock::MockServer::start().await;
894        let executor = WebScrapeExecutor {
895            timeout: Duration::from_secs(5),
896            max_body_bytes: 1_048_576,
897        };
898        (executor, server)
899    }
900
901    /// Parses the mock server's URI into (host_str, socket_addr) for use with `build_client`.
902    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
903        let uri = server.uri();
904        let url = Url::parse(&uri).unwrap();
905        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
906        let port = url.port().unwrap_or(80);
907        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
908        (host, vec![addr])
909    }
910
911    /// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
912    /// `resolve_and_validate`. This lets us exercise the redirect-counting and
913    /// missing-Location logic against a plain HTTP wiremock server.
914    async fn follow_redirects_raw(
915        executor: &WebScrapeExecutor,
916        start_url: &str,
917        host: &str,
918        addrs: &[std::net::SocketAddr],
919    ) -> Result<String, ToolError> {
920        const MAX_REDIRECTS: usize = 3;
921        let mut current_url = start_url.to_owned();
922        let mut current_host = host.to_owned();
923        let mut current_addrs = addrs.to_vec();
924
925        for hop in 0..=MAX_REDIRECTS {
926            let client = executor.build_client(&current_host, &current_addrs);
927            let resp = client
928                .get(&current_url)
929                .send()
930                .await
931                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
932
933            let status = resp.status();
934
935            if status.is_redirection() {
936                if hop == MAX_REDIRECTS {
937                    return Err(ToolError::Execution(std::io::Error::other(
938                        "too many redirects",
939                    )));
940                }
941
942                let location = resp
943                    .headers()
944                    .get(reqwest::header::LOCATION)
945                    .and_then(|v| v.to_str().ok())
946                    .ok_or_else(|| {
947                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
948                    })?;
949
950                let base = Url::parse(&current_url)
951                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
952                let next_url = base
953                    .join(location)
954                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
955
956                // Re-use same host/addrs (mock server is always the same endpoint).
957                current_url = next_url.to_string();
958                // Preserve host/addrs as-is since the mock server doesn't change.
959                let _ = &mut current_host;
960                let _ = &mut current_addrs;
961                continue;
962            }
963
964            if !status.is_success() {
965                return Err(ToolError::Execution(std::io::Error::other(format!(
966                    "HTTP {status}",
967                ))));
968            }
969
970            let bytes = resp
971                .bytes()
972                .await
973                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
974
975            if bytes.len() > executor.max_body_bytes {
976                return Err(ToolError::Execution(std::io::Error::other(format!(
977                    "response too large: {} bytes (max: {})",
978                    bytes.len(),
979                    executor.max_body_bytes,
980                ))));
981            }
982
983            return String::from_utf8(bytes.to_vec())
984                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
985        }
986
987        Err(ToolError::Execution(std::io::Error::other(
988            "too many redirects",
989        )))
990    }
991
992    #[tokio::test]
993    async fn fetch_html_success_returns_body() {
994        use wiremock::matchers::{method, path};
995        use wiremock::{Mock, ResponseTemplate};
996
997        let (executor, server) = mock_server_executor().await;
998        Mock::given(method("GET"))
999            .and(path("/page"))
1000            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1001            .mount(&server)
1002            .await;
1003
1004        let (host, addrs) = server_host_and_addr(&server);
1005        let url = format!("{}/page", server.uri());
1006        let result = executor.fetch_html(&url, &host, &addrs).await;
1007        assert!(result.is_ok(), "expected Ok, got: {result:?}");
1008        assert_eq!(result.unwrap(), "<h1>OK</h1>");
1009    }
1010
1011    #[tokio::test]
1012    async fn fetch_html_non_2xx_returns_error() {
1013        use wiremock::matchers::{method, path};
1014        use wiremock::{Mock, ResponseTemplate};
1015
1016        let (executor, server) = mock_server_executor().await;
1017        Mock::given(method("GET"))
1018            .and(path("/forbidden"))
1019            .respond_with(ResponseTemplate::new(403))
1020            .mount(&server)
1021            .await;
1022
1023        let (host, addrs) = server_host_and_addr(&server);
1024        let url = format!("{}/forbidden", server.uri());
1025        let result = executor.fetch_html(&url, &host, &addrs).await;
1026        assert!(result.is_err());
1027        let msg = result.unwrap_err().to_string();
1028        assert!(msg.contains("403"), "expected 403 in error: {msg}");
1029    }
1030
1031    #[tokio::test]
1032    async fn fetch_html_404_returns_error() {
1033        use wiremock::matchers::{method, path};
1034        use wiremock::{Mock, ResponseTemplate};
1035
1036        let (executor, server) = mock_server_executor().await;
1037        Mock::given(method("GET"))
1038            .and(path("/missing"))
1039            .respond_with(ResponseTemplate::new(404))
1040            .mount(&server)
1041            .await;
1042
1043        let (host, addrs) = server_host_and_addr(&server);
1044        let url = format!("{}/missing", server.uri());
1045        let result = executor.fetch_html(&url, &host, &addrs).await;
1046        assert!(result.is_err());
1047        let msg = result.unwrap_err().to_string();
1048        assert!(msg.contains("404"), "expected 404 in error: {msg}");
1049    }
1050
1051    #[tokio::test]
1052    async fn fetch_html_redirect_no_location_returns_error() {
1053        use wiremock::matchers::{method, path};
1054        use wiremock::{Mock, ResponseTemplate};
1055
1056        let (executor, server) = mock_server_executor().await;
1057        // 302 with no Location header
1058        Mock::given(method("GET"))
1059            .and(path("/redirect-no-loc"))
1060            .respond_with(ResponseTemplate::new(302))
1061            .mount(&server)
1062            .await;
1063
1064        let (host, addrs) = server_host_and_addr(&server);
1065        let url = format!("{}/redirect-no-loc", server.uri());
1066        let result = executor.fetch_html(&url, &host, &addrs).await;
1067        assert!(result.is_err());
1068        let msg = result.unwrap_err().to_string();
1069        assert!(
1070            msg.contains("Location") || msg.contains("location"),
1071            "expected Location-related error: {msg}"
1072        );
1073    }
1074
1075    #[tokio::test]
1076    async fn fetch_html_single_redirect_followed() {
1077        use wiremock::matchers::{method, path};
1078        use wiremock::{Mock, ResponseTemplate};
1079
1080        let (executor, server) = mock_server_executor().await;
1081        let final_url = format!("{}/final", server.uri());
1082
1083        Mock::given(method("GET"))
1084            .and(path("/start"))
1085            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1086            .mount(&server)
1087            .await;
1088
1089        Mock::given(method("GET"))
1090            .and(path("/final"))
1091            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1092            .mount(&server)
1093            .await;
1094
1095        let (host, addrs) = server_host_and_addr(&server);
1096        let url = format!("{}/start", server.uri());
1097        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1098        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1099        assert_eq!(result.unwrap(), "<p>final</p>");
1100    }
1101
1102    #[tokio::test]
1103    async fn fetch_html_three_redirects_allowed() {
1104        use wiremock::matchers::{method, path};
1105        use wiremock::{Mock, ResponseTemplate};
1106
1107        let (executor, server) = mock_server_executor().await;
1108        let hop2 = format!("{}/hop2", server.uri());
1109        let hop3 = format!("{}/hop3", server.uri());
1110        let final_dest = format!("{}/done", server.uri());
1111
1112        Mock::given(method("GET"))
1113            .and(path("/hop1"))
1114            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1115            .mount(&server)
1116            .await;
1117        Mock::given(method("GET"))
1118            .and(path("/hop2"))
1119            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1120            .mount(&server)
1121            .await;
1122        Mock::given(method("GET"))
1123            .and(path("/hop3"))
1124            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1125            .mount(&server)
1126            .await;
1127        Mock::given(method("GET"))
1128            .and(path("/done"))
1129            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1130            .mount(&server)
1131            .await;
1132
1133        let (host, addrs) = server_host_and_addr(&server);
1134        let url = format!("{}/hop1", server.uri());
1135        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1136        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1137        assert_eq!(result.unwrap(), "<p>done</p>");
1138    }
1139
1140    #[tokio::test]
1141    async fn fetch_html_four_redirects_rejected() {
1142        use wiremock::matchers::{method, path};
1143        use wiremock::{Mock, ResponseTemplate};
1144
1145        let (executor, server) = mock_server_executor().await;
1146        let hop2 = format!("{}/r2", server.uri());
1147        let hop3 = format!("{}/r3", server.uri());
1148        let hop4 = format!("{}/r4", server.uri());
1149        let hop5 = format!("{}/r5", server.uri());
1150
1151        for (from, to) in [
1152            ("/r1", &hop2),
1153            ("/r2", &hop3),
1154            ("/r3", &hop4),
1155            ("/r4", &hop5),
1156        ] {
1157            Mock::given(method("GET"))
1158                .and(path(from))
1159                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1160                .mount(&server)
1161                .await;
1162        }
1163
1164        let (host, addrs) = server_host_and_addr(&server);
1165        let url = format!("{}/r1", server.uri());
1166        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1167        assert!(result.is_err(), "4 redirects should be rejected");
1168        let msg = result.unwrap_err().to_string();
1169        assert!(
1170            msg.contains("redirect"),
1171            "expected redirect-related error: {msg}"
1172        );
1173    }
1174
1175    #[tokio::test]
1176    async fn fetch_html_body_too_large_returns_error() {
1177        use wiremock::matchers::{method, path};
1178        use wiremock::{Mock, ResponseTemplate};
1179
1180        let small_limit_executor = WebScrapeExecutor {
1181            timeout: Duration::from_secs(5),
1182            max_body_bytes: 10,
1183        };
1184        let server = wiremock::MockServer::start().await;
1185        Mock::given(method("GET"))
1186            .and(path("/big"))
1187            .respond_with(
1188                ResponseTemplate::new(200)
1189                    .set_body_string("this body is definitely longer than ten bytes"),
1190            )
1191            .mount(&server)
1192            .await;
1193
1194        let (host, addrs) = server_host_and_addr(&server);
1195        let url = format!("{}/big", server.uri());
1196        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1197        assert!(result.is_err());
1198        let msg = result.unwrap_err().to_string();
1199        assert!(msg.contains("too large"), "expected too-large error: {msg}");
1200    }
1201
1202    #[test]
1203    fn extract_scrape_blocks_empty_block_content() {
1204        let text = "```scrape\n\n```";
1205        let blocks = extract_scrape_blocks(text);
1206        assert_eq!(blocks.len(), 1);
1207        assert!(blocks[0].is_empty());
1208    }
1209
1210    #[test]
1211    fn extract_scrape_blocks_whitespace_only() {
1212        let text = "```scrape\n   \n```";
1213        let blocks = extract_scrape_blocks(text);
1214        assert_eq!(blocks.len(), 1);
1215    }
1216
1217    #[test]
1218    fn parse_and_extract_multiple_selectors() {
1219        let html = "<div><h1>Title</h1><p>Para</p></div>";
1220        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1221        assert!(result.contains("Title"));
1222        assert!(result.contains("Para"));
1223    }
1224
1225    #[test]
1226    fn webscrape_executor_new_with_custom_config() {
1227        let config = ScrapeConfig {
1228            timeout: 60,
1229            max_body_bytes: 512,
1230        };
1231        let executor = WebScrapeExecutor::new(&config);
1232        assert_eq!(executor.max_body_bytes, 512);
1233    }
1234
1235    #[test]
1236    fn webscrape_executor_debug() {
1237        let config = ScrapeConfig::default();
1238        let executor = WebScrapeExecutor::new(&config);
1239        let dbg = format!("{executor:?}");
1240        assert!(dbg.contains("WebScrapeExecutor"));
1241    }
1242
1243    #[test]
1244    fn extract_mode_attr_empty_name() {
1245        let mode = ExtractMode::parse("attr:");
1246        assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1247    }
1248
1249    #[test]
1250    fn default_extract_returns_text() {
1251        assert_eq!(default_extract(), "text");
1252    }
1253
1254    #[test]
1255    fn scrape_instruction_debug() {
1256        let json = r#"{"url":"https://example.com","select":"h1"}"#;
1257        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1258        let dbg = format!("{instr:?}");
1259        assert!(dbg.contains("ScrapeInstruction"));
1260    }
1261
1262    #[test]
1263    fn extract_mode_debug() {
1264        let mode = ExtractMode::Text;
1265        let dbg = format!("{mode:?}");
1266        assert!(dbg.contains("Text"));
1267    }
1268
1269    // --- fetch_html redirect logic: constant and validation unit tests ---
1270
1271    /// MAX_REDIRECTS is 3; the 4th redirect attempt must be rejected.
1272    /// Verify the boundary is correct by inspecting the constant value.
1273    #[test]
1274    fn max_redirects_constant_is_three() {
1275        // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
1276        // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
1277        // This test documents the expected limit.
1278        const MAX_REDIRECTS: usize = 3;
1279        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1280    }
1281
1282    /// Verifies that a Location-less redirect would produce an error string containing the
1283    /// expected message, matching the error path in fetch_html.
1284    #[test]
1285    fn redirect_no_location_error_message() {
1286        let err = std::io::Error::other("redirect with no Location");
1287        assert!(err.to_string().contains("redirect with no Location"));
1288    }
1289
1290    /// Verifies that a too-many-redirects condition produces the expected error string.
1291    #[test]
1292    fn too_many_redirects_error_message() {
1293        let err = std::io::Error::other("too many redirects");
1294        assert!(err.to_string().contains("too many redirects"));
1295    }
1296
1297    /// Verifies that a non-2xx HTTP status produces an error message with the status code.
1298    #[test]
1299    fn non_2xx_status_error_format() {
1300        let status = reqwest::StatusCode::FORBIDDEN;
1301        let msg = format!("HTTP {status}");
1302        assert!(msg.contains("403"));
1303    }
1304
1305    /// Verifies that a 404 response status code formats into the expected error message.
1306    #[test]
1307    fn not_found_status_error_format() {
1308        let status = reqwest::StatusCode::NOT_FOUND;
1309        let msg = format!("HTTP {status}");
1310        assert!(msg.contains("404"));
1311    }
1312
1313    /// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
1314    #[test]
1315    fn relative_redirect_same_host_path() {
1316        let base = Url::parse("https://example.com/current").unwrap();
1317        let resolved = base.join("/other").unwrap();
1318        assert_eq!(resolved.as_str(), "https://example.com/other");
1319    }
1320
1321    /// Verifies relative redirect resolution preserves scheme and host.
1322    #[test]
1323    fn relative_redirect_relative_path() {
1324        let base = Url::parse("https://example.com/a/b").unwrap();
1325        let resolved = base.join("c").unwrap();
1326        assert_eq!(resolved.as_str(), "https://example.com/a/c");
1327    }
1328
1329    /// Verifies that an absolute redirect URL overrides base URL completely.
1330    #[test]
1331    fn absolute_redirect_overrides_base() {
1332        let base = Url::parse("https://example.com/page").unwrap();
1333        let resolved = base.join("https://other.com/target").unwrap();
1334        assert_eq!(resolved.as_str(), "https://other.com/target");
1335    }
1336
1337    /// Verifies that a redirect Location of http:// (downgrade) is rejected.
1338    #[test]
1339    fn redirect_http_downgrade_rejected() {
1340        let location = "http://example.com/page";
1341        let base = Url::parse("https://example.com/start").unwrap();
1342        let next = base.join(location).unwrap();
1343        let err = validate_url(next.as_str()).unwrap_err();
1344        assert!(matches!(err, ToolError::Blocked { .. }));
1345    }
1346
1347    /// Verifies that a redirect to a private IP literal is blocked.
1348    #[test]
1349    fn redirect_location_private_ip_blocked() {
1350        let location = "https://192.168.100.1/admin";
1351        let base = Url::parse("https://example.com/start").unwrap();
1352        let next = base.join(location).unwrap();
1353        let err = validate_url(next.as_str()).unwrap_err();
1354        assert!(matches!(err, ToolError::Blocked { .. }));
1355        let cmd = match err {
1356            ToolError::Blocked { command } => command,
1357            _ => panic!("expected Blocked"),
1358        };
1359        assert!(
1360            cmd.contains("private") || cmd.contains("scheme"),
1361            "error message should describe the block reason: {cmd}"
1362        );
1363    }
1364
1365    /// Verifies that a redirect to a .internal domain is blocked.
1366    #[test]
1367    fn redirect_location_internal_domain_blocked() {
1368        let location = "https://metadata.internal/latest/meta-data/";
1369        let base = Url::parse("https://example.com/start").unwrap();
1370        let next = base.join(location).unwrap();
1371        let err = validate_url(next.as_str()).unwrap_err();
1372        assert!(matches!(err, ToolError::Blocked { .. }));
1373    }
1374
1375    /// Verifies that a chain of 3 valid public redirects passes validate_url at every hop.
1376    #[test]
1377    fn redirect_chain_three_hops_all_public() {
1378        let hops = [
1379            "https://redirect1.example.com/hop1",
1380            "https://redirect2.example.com/hop2",
1381            "https://destination.example.com/final",
1382        ];
1383        for hop in hops {
1384            assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1385        }
1386    }
1387
1388    // --- SSRF redirect chain defense ---
1389
1390    /// Verifies that a redirect Location pointing to a private IP is rejected by validate_url
1391    /// before any connection attempt — simulating the validation step inside fetch_html.
1392    #[test]
1393    fn redirect_to_private_ip_rejected_by_validate_url() {
1394        // These would appear as Location headers in a redirect response.
1395        let private_targets = [
1396            "https://127.0.0.1/secret",
1397            "https://10.0.0.1/internal",
1398            "https://192.168.1.1/admin",
1399            "https://172.16.0.1/data",
1400            "https://[::1]/path",
1401            "https://[fe80::1]/path",
1402            "https://localhost/path",
1403            "https://service.internal/api",
1404        ];
1405        for target in private_targets {
1406            let result = validate_url(target);
1407            assert!(result.is_err(), "expected error for {target}");
1408            assert!(
1409                matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1410                "expected Blocked for {target}"
1411            );
1412        }
1413    }
1414
1415    /// Verifies that relative redirect URLs are resolved correctly before validation.
1416    #[test]
1417    fn redirect_relative_url_resolves_correctly() {
1418        let base = Url::parse("https://example.com/page").unwrap();
1419        let relative = "/other";
1420        let resolved = base.join(relative).unwrap();
1421        assert_eq!(resolved.as_str(), "https://example.com/other");
1422    }
1423
1424    /// Verifies that a protocol-relative redirect to http:// is rejected (scheme check).
1425    #[test]
1426    fn redirect_to_http_rejected() {
1427        let err = validate_url("http://example.com/page").unwrap_err();
1428        assert!(matches!(err, ToolError::Blocked { .. }));
1429    }
1430
1431    #[test]
1432    fn ipv4_mapped_ipv6_link_local_blocked() {
1433        let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1434        assert!(matches!(err, ToolError::Blocked { .. }));
1435    }
1436
1437    #[test]
1438    fn ipv4_mapped_ipv6_public_allowed() {
1439        assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1440    }
1441
1442    // --- fetch tool ---
1443
1444    #[tokio::test]
1445    async fn fetch_http_scheme_blocked() {
1446        let config = ScrapeConfig::default();
1447        let executor = WebScrapeExecutor::new(&config);
1448        let call = crate::executor::ToolCall {
1449            tool_id: "fetch".to_owned(),
1450            params: {
1451                let mut m = serde_json::Map::new();
1452                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1453                m
1454            },
1455        };
1456        let result = executor.execute_tool_call(&call).await;
1457        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1458    }
1459
1460    #[tokio::test]
1461    async fn fetch_private_ip_blocked() {
1462        let config = ScrapeConfig::default();
1463        let executor = WebScrapeExecutor::new(&config);
1464        let call = crate::executor::ToolCall {
1465            tool_id: "fetch".to_owned(),
1466            params: {
1467                let mut m = serde_json::Map::new();
1468                m.insert(
1469                    "url".to_owned(),
1470                    serde_json::json!("https://192.168.1.1/secret"),
1471                );
1472                m
1473            },
1474        };
1475        let result = executor.execute_tool_call(&call).await;
1476        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1477    }
1478
1479    #[tokio::test]
1480    async fn fetch_localhost_blocked() {
1481        let config = ScrapeConfig::default();
1482        let executor = WebScrapeExecutor::new(&config);
1483        let call = crate::executor::ToolCall {
1484            tool_id: "fetch".to_owned(),
1485            params: {
1486                let mut m = serde_json::Map::new();
1487                m.insert(
1488                    "url".to_owned(),
1489                    serde_json::json!("https://localhost/page"),
1490                );
1491                m
1492            },
1493        };
1494        let result = executor.execute_tool_call(&call).await;
1495        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1496    }
1497
1498    #[tokio::test]
1499    async fn fetch_unknown_tool_returns_none() {
1500        let config = ScrapeConfig::default();
1501        let executor = WebScrapeExecutor::new(&config);
1502        let call = crate::executor::ToolCall {
1503            tool_id: "unknown_tool".to_owned(),
1504            params: serde_json::Map::new(),
1505        };
1506        let result = executor.execute_tool_call(&call).await;
1507        assert!(result.unwrap().is_none());
1508    }
1509
1510    #[tokio::test]
1511    async fn fetch_returns_body_via_mock() {
1512        use wiremock::matchers::{method, path};
1513        use wiremock::{Mock, ResponseTemplate};
1514
1515        let (executor, server) = mock_server_executor().await;
1516        Mock::given(method("GET"))
1517            .and(path("/content"))
1518            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1519            .mount(&server)
1520            .await;
1521
1522        let (host, addrs) = server_host_and_addr(&server);
1523        let url = format!("{}/content", server.uri());
1524        let result = executor.fetch_html(&url, &host, &addrs).await;
1525        assert!(result.is_ok());
1526        assert_eq!(result.unwrap(), "plain text content");
1527    }
1528
1529    #[test]
1530    fn tool_definitions_returns_web_scrape_and_fetch() {
1531        let config = ScrapeConfig::default();
1532        let executor = WebScrapeExecutor::new(&config);
1533        let defs = executor.tool_definitions();
1534        assert_eq!(defs.len(), 2);
1535        assert_eq!(defs[0].id, "web_scrape");
1536        assert_eq!(
1537            defs[0].invocation,
1538            crate::registry::InvocationHint::FencedBlock("scrape")
1539        );
1540        assert_eq!(defs[1].id, "fetch");
1541        assert_eq!(
1542            defs[1].invocation,
1543            crate::registry::InvocationHint::ToolCall
1544        );
1545    }
1546
1547    #[test]
1548    fn tool_definitions_schema_has_all_params() {
1549        let config = ScrapeConfig::default();
1550        let executor = WebScrapeExecutor::new(&config);
1551        let defs = executor.tool_definitions();
1552        let obj = defs[0].schema.as_object().unwrap();
1553        let props = obj["properties"].as_object().unwrap();
1554        assert!(props.contains_key("url"));
1555        assert!(props.contains_key("select"));
1556        assert!(props.contains_key("extract"));
1557        assert!(props.contains_key("limit"));
1558        let req = obj["required"].as_array().unwrap();
1559        assert!(req.iter().any(|v| v.as_str() == Some("url")));
1560        assert!(req.iter().any(|v| v.as_str() == Some("select")));
1561        assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1562    }
1563
1564    // --- is_private_host: new domain checks (AUD-02) ---
1565
1566    #[test]
1567    fn subdomain_localhost_blocked() {
1568        let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1569        assert!(is_private_host(&host));
1570    }
1571
1572    #[test]
1573    fn internal_tld_blocked() {
1574        let host: url::Host<&str> = url::Host::Domain("service.internal");
1575        assert!(is_private_host(&host));
1576    }
1577
1578    #[test]
1579    fn local_tld_blocked() {
1580        let host: url::Host<&str> = url::Host::Domain("printer.local");
1581        assert!(is_private_host(&host));
1582    }
1583
1584    #[test]
1585    fn public_domain_not_blocked() {
1586        let host: url::Host<&str> = url::Host::Domain("example.com");
1587        assert!(!is_private_host(&host));
1588    }
1589
1590    // --- resolve_and_validate: private IP rejection ---
1591
1592    #[tokio::test]
1593    async fn resolve_loopback_rejected() {
1594        // 127.0.0.1 resolves directly (literal IP in DNS query)
1595        let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1596        // validate_url catches this before resolve_and_validate, but test directly
1597        let result = resolve_and_validate(&url).await;
1598        assert!(
1599            result.is_err(),
1600            "loopback IP must be rejected by resolve_and_validate"
1601        );
1602        let err = result.unwrap_err();
1603        assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1604    }
1605
1606    #[tokio::test]
1607    async fn resolve_private_10_rejected() {
1608        let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1609        let result = resolve_and_validate(&url).await;
1610        assert!(result.is_err());
1611        assert!(matches!(
1612            result.unwrap_err(),
1613            crate::executor::ToolError::Blocked { .. }
1614        ));
1615    }
1616
1617    #[tokio::test]
1618    async fn resolve_private_192_rejected() {
1619        let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1620        let result = resolve_and_validate(&url).await;
1621        assert!(result.is_err());
1622        assert!(matches!(
1623            result.unwrap_err(),
1624            crate::executor::ToolError::Blocked { .. }
1625        ));
1626    }
1627
1628    #[tokio::test]
1629    async fn resolve_ipv6_loopback_rejected() {
1630        let url = url::Url::parse("https://[::1]/path").unwrap();
1631        let result = resolve_and_validate(&url).await;
1632        assert!(result.is_err());
1633        assert!(matches!(
1634            result.unwrap_err(),
1635            crate::executor::ToolError::Blocked { .. }
1636        ));
1637    }
1638
1639    #[tokio::test]
1640    async fn resolve_no_host_returns_ok() {
1641        // URL without a resolvable host — should pass through
1642        let url = url::Url::parse("https://example.com/path").unwrap();
1643        // We can't do a live DNS test, but we can verify a URL with no host
1644        let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1645        // data: URLs have no host; resolve_and_validate should return Ok with empty addrs
1646        let result = resolve_and_validate(&url_no_host).await;
1647        assert!(result.is_ok());
1648        let (host, addrs) = result.unwrap();
1649        assert!(host.is_empty());
1650        assert!(addrs.is_empty());
1651        drop(url);
1652        drop(url_no_host);
1653    }
1654
1655    // CR-10: fetch end-to-end via execute_tool_call -> handle_fetch -> fetch_html
1656    #[tokio::test]
1657    async fn fetch_execute_tool_call_end_to_end() {
1658        use wiremock::matchers::{method, path};
1659        use wiremock::{Mock, ResponseTemplate};
1660
1661        let (executor, server) = mock_server_executor().await;
1662        Mock::given(method("GET"))
1663            .and(path("/e2e"))
1664            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1665            .mount(&server)
1666            .await;
1667
1668        let (host, addrs) = server_host_and_addr(&server);
1669        // Call fetch_html directly (bypassing SSRF guard for loopback mock server)
1670        let result = executor
1671            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1672            .await;
1673        assert!(result.is_ok());
1674        assert!(result.unwrap().contains("end-to-end"));
1675    }
1676}