Skip to main content

zeph_tools/
scrape.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
15
// Parameters of the `fetch` tool call (deserialized from `ToolCall::params` via
// `deserialize_params`). NOTE(review): schemars appears to turn the `///` field docs into the
// JSON-schema descriptions returned by `tool_definitions`, so they are effectively runtime text
// shown to the LLM — edit them deliberately.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// HTTPS URL to fetch
    url: String,
}
21
// One scrape request: parsed from the JSON body of a ```scrape fenced block (`execute`) or from
// the `web_scrape` tool-call params (`execute_tool_call`). NOTE(review): the `///` field docs are
// likely emitted by schemars as JSON-schema descriptions for the LLM, so they are left verbatim.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// HTTPS URL to scrape
    url: String,
    /// CSS selector
    select: String,
    // Missing `extract` falls back to `default_extract()` ("text"); the string is later turned
    // into an `ExtractMode` by `ExtractMode::parse`.
    /// Extract mode: text, html, or attr:<name>
    #[serde(default = "default_extract")]
    extract: String,
    // `None` is replaced by 10 in `scrape_instruction`.
    /// Max results to return
    limit: Option<usize>,
}
34
/// Serde default for the `extract` field of `ScrapeInstruction`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
38
/// How each element matched by the CSS selector is turned into output text.
#[derive(Debug)]
enum ExtractMode {
    /// Element text (rendered via the tag's `text()`).
    Text,
    /// Element inner HTML (rendered via the tag's `inner_html()`).
    Html,
    /// Value of the named attribute, requested as `attr:<name>`.
    Attr(String),
}

impl ExtractMode {
    /// Parses the `extract` string of a scrape instruction.
    ///
    /// `"html"` selects [`ExtractMode::Html`]; any `"attr:<name>"` selects
    /// [`ExtractMode::Attr`] carrying `<name>` (which may be empty). `"text"` and every
    /// unrecognised value fall back to [`ExtractMode::Text`].
    fn parse(s: &str) -> Self {
        if s == "html" {
            return Self::Html;
        }
        s.strip_prefix("attr:")
            .map_or(Self::Text, |name| Self::Attr(name.to_owned()))
    }
}
58
/// Extracts data from web pages via CSS selectors.
///
/// Detects ` ```scrape ` blocks in LLM responses containing JSON instructions,
/// fetches the URL, and parses HTML with `scrape-core`.
///
/// It also serves the structured `web_scrape` and `fetch` tool calls. Every request must be
/// HTTPS; hosts and resolved addresses are checked against private/local ranges, and redirects
/// are followed manually (up to 3 hops) with each hop re-validated.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Per-request timeout applied to every HTTP client built by this executor.
    timeout: Duration,
    /// Upper bound on an accepted response body; larger bodies are rejected with an error.
    max_body_bytes: usize,
    /// Optional sink for audit entries; set via `with_audit`, `None` disables auditing.
    audit_logger: Option<Arc<AuditLogger>>,
}
69
70impl WebScrapeExecutor {
71    #[must_use]
72    pub fn new(config: &ScrapeConfig) -> Self {
73        Self {
74            timeout: Duration::from_secs(config.timeout),
75            max_body_bytes: config.max_body_bytes,
76            audit_logger: None,
77        }
78    }
79
80    #[must_use]
81    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
82        self.audit_logger = Some(logger);
83        self
84    }
85
86    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
87        let mut builder = reqwest::Client::builder()
88            .timeout(self.timeout)
89            .redirect(reqwest::redirect::Policy::none());
90        builder = builder.resolve_to_addrs(host, addrs);
91        builder.build().unwrap_or_default()
92    }
93}
94
95impl ToolExecutor for WebScrapeExecutor {
96    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
97        use crate::registry::{InvocationHint, ToolDef};
98        vec![
99            ToolDef {
100                id: "web_scrape".into(),
101                description: "Extract structured data from a web page using CSS selectors.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures\nExample: {\"url\": \"https://example.com\", \"select\": \"h1\", \"extract\": \"text\"}".into(),
102                schema: schemars::schema_for!(ScrapeInstruction),
103                invocation: InvocationHint::FencedBlock("scrape"),
104            },
105            ToolDef {
106                id: "fetch".into(),
107                description: "Fetch a URL and return the response body as plain text.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures\nExample: {\"url\": \"https://api.example.com/data.json\"}".into(),
108                schema: schemars::schema_for!(FetchParams),
109                invocation: InvocationHint::ToolCall,
110            },
111        ]
112    }
113
114    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
115        let blocks = extract_scrape_blocks(response);
116        if blocks.is_empty() {
117            return Ok(None);
118        }
119
120        let mut outputs = Vec::with_capacity(blocks.len());
121        #[allow(clippy::cast_possible_truncation)]
122        let blocks_executed = blocks.len() as u32;
123
124        for block in &blocks {
125            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
126                ToolError::Execution(std::io::Error::new(
127                    std::io::ErrorKind::InvalidData,
128                    e.to_string(),
129                ))
130            })?;
131            let start = Instant::now();
132            let scrape_result = self.scrape_instruction(&instruction).await;
133            #[allow(clippy::cast_possible_truncation)]
134            let duration_ms = start.elapsed().as_millis() as u64;
135            match scrape_result {
136                Ok(output) => {
137                    self.log_audit(
138                        "web_scrape",
139                        &instruction.url,
140                        AuditResult::Success,
141                        duration_ms,
142                    )
143                    .await;
144                    outputs.push(output);
145                }
146                Err(e) => {
147                    let audit_result = tool_error_to_audit_result(&e);
148                    self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
149                        .await;
150                    return Err(e);
151                }
152            }
153        }
154
155        Ok(Some(ToolOutput {
156            tool_name: "web-scrape".to_owned(),
157            summary: outputs.join("\n\n"),
158            blocks_executed,
159            filter_stats: None,
160            diff: None,
161            streamed: false,
162            terminal_id: None,
163            locations: None,
164            raw_response: None,
165        }))
166    }
167
168    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
169        match call.tool_id.as_str() {
170            "web_scrape" => {
171                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
172                let start = Instant::now();
173                let result = self.scrape_instruction(&instruction).await;
174                #[allow(clippy::cast_possible_truncation)]
175                let duration_ms = start.elapsed().as_millis() as u64;
176                match result {
177                    Ok(output) => {
178                        self.log_audit(
179                            "web_scrape",
180                            &instruction.url,
181                            AuditResult::Success,
182                            duration_ms,
183                        )
184                        .await;
185                        Ok(Some(ToolOutput {
186                            tool_name: "web-scrape".to_owned(),
187                            summary: output,
188                            blocks_executed: 1,
189                            filter_stats: None,
190                            diff: None,
191                            streamed: false,
192                            terminal_id: None,
193                            locations: None,
194                            raw_response: None,
195                        }))
196                    }
197                    Err(e) => {
198                        let audit_result = tool_error_to_audit_result(&e);
199                        self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
200                            .await;
201                        Err(e)
202                    }
203                }
204            }
205            "fetch" => {
206                let p: FetchParams = deserialize_params(&call.params)?;
207                let start = Instant::now();
208                let result = self.handle_fetch(&p).await;
209                #[allow(clippy::cast_possible_truncation)]
210                let duration_ms = start.elapsed().as_millis() as u64;
211                match result {
212                    Ok(output) => {
213                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms)
214                            .await;
215                        Ok(Some(ToolOutput {
216                            tool_name: "fetch".to_owned(),
217                            summary: output,
218                            blocks_executed: 1,
219                            filter_stats: None,
220                            diff: None,
221                            streamed: false,
222                            terminal_id: None,
223                            locations: None,
224                            raw_response: None,
225                        }))
226                    }
227                    Err(e) => {
228                        let audit_result = tool_error_to_audit_result(&e);
229                        self.log_audit("fetch", &p.url, audit_result, duration_ms)
230                            .await;
231                        Err(e)
232                    }
233                }
234            }
235            _ => Ok(None),
236        }
237    }
238}
239
240fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
241    match e {
242        ToolError::Blocked { command } => AuditResult::Blocked {
243            reason: command.clone(),
244        },
245        ToolError::Timeout { .. } => AuditResult::Timeout,
246        _ => AuditResult::Error {
247            message: e.to_string(),
248        },
249    }
250}
251
252impl WebScrapeExecutor {
253    async fn log_audit(&self, tool: &str, command: &str, result: AuditResult, duration_ms: u64) {
254        if let Some(ref logger) = self.audit_logger {
255            let entry = AuditEntry {
256                timestamp: chrono_now(),
257                tool: tool.into(),
258                command: command.into(),
259                result,
260                duration_ms,
261            };
262            logger.log(&entry).await;
263        }
264    }
265
266    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
267        let parsed = validate_url(&params.url)?;
268        let (host, addrs) = resolve_and_validate(&parsed).await?;
269        self.fetch_html(&params.url, &host, &addrs).await
270    }
271
272    async fn scrape_instruction(
273        &self,
274        instruction: &ScrapeInstruction,
275    ) -> Result<String, ToolError> {
276        let parsed = validate_url(&instruction.url)?;
277        let (host, addrs) = resolve_and_validate(&parsed).await?;
278        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
279        let selector = instruction.select.clone();
280        let extract = ExtractMode::parse(&instruction.extract);
281        let limit = instruction.limit.unwrap_or(10);
282        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
283            .await
284            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
285    }
286
287    /// Fetches the HTML at `url`, manually following up to 3 redirects.
288    ///
289    /// Each redirect target is validated with `validate_url` and `resolve_and_validate`
290    /// before following, preventing SSRF via redirect chains.
291    ///
292    /// # Errors
293    ///
294    /// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
295    /// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
296    async fn fetch_html(
297        &self,
298        url: &str,
299        host: &str,
300        addrs: &[SocketAddr],
301    ) -> Result<String, ToolError> {
302        const MAX_REDIRECTS: usize = 3;
303
304        let mut current_url = url.to_owned();
305        let mut current_host = host.to_owned();
306        let mut current_addrs = addrs.to_vec();
307
308        for hop in 0..=MAX_REDIRECTS {
309            // Build a per-hop client pinned to the current hop's validated addresses.
310            let client = self.build_client(&current_host, &current_addrs);
311            let resp = client
312                .get(&current_url)
313                .send()
314                .await
315                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
316
317            let status = resp.status();
318
319            if status.is_redirection() {
320                if hop == MAX_REDIRECTS {
321                    return Err(ToolError::Execution(std::io::Error::other(
322                        "too many redirects",
323                    )));
324                }
325
326                let location = resp
327                    .headers()
328                    .get(reqwest::header::LOCATION)
329                    .and_then(|v| v.to_str().ok())
330                    .ok_or_else(|| {
331                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
332                    })?;
333
334                // Resolve relative redirect URLs against the current URL.
335                let base = Url::parse(&current_url)
336                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
337                let next_url = base
338                    .join(location)
339                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
340
341                let validated = validate_url(next_url.as_str())?;
342                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;
343
344                current_url = next_url.to_string();
345                current_host = next_host;
346                current_addrs = next_addrs;
347                continue;
348            }
349
350            if !status.is_success() {
351                return Err(ToolError::Execution(std::io::Error::other(format!(
352                    "HTTP {status}",
353                ))));
354            }
355
356            let bytes = resp
357                .bytes()
358                .await
359                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
360
361            if bytes.len() > self.max_body_bytes {
362                return Err(ToolError::Execution(std::io::Error::other(format!(
363                    "response too large: {} bytes (max: {})",
364                    bytes.len(),
365                    self.max_body_bytes,
366                ))));
367            }
368
369            return String::from_utf8(bytes.to_vec())
370                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
371        }
372
373        Err(ToolError::Execution(std::io::Error::other(
374            "too many redirects",
375        )))
376    }
377}
378
379fn extract_scrape_blocks(text: &str) -> Vec<&str> {
380    crate::executor::extract_fenced_blocks(text, "scrape")
381}
382
383fn validate_url(raw: &str) -> Result<Url, ToolError> {
384    let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
385        command: format!("invalid URL: {raw}"),
386    })?;
387
388    if parsed.scheme() != "https" {
389        return Err(ToolError::Blocked {
390            command: format!("scheme not allowed: {}", parsed.scheme()),
391        });
392    }
393
394    if let Some(host) = parsed.host()
395        && is_private_host(&host)
396    {
397        return Err(ToolError::Blocked {
398            command: format!(
399                "private/local host blocked: {}",
400                parsed.host_str().unwrap_or("")
401            ),
402        });
403    }
404
405    Ok(parsed)
406}
407
/// Returns `true` when `ip` must not be contacted by the scraper (SSRF protection).
///
/// IPv4 addresses are blocked if they are:
/// - loopback, private (RFC 1918), link-local, unspecified or broadcast;
/// - multicast;
/// - in 0.0.0.0/8 ("this network");
/// - in 100.64.0.0/10 (shared address space / carrier-grade NAT);
/// - in 240.0.0.0/4 (reserved).
///
/// IPv6 addresses are blocked if they are:
/// - loopback, unspecified or multicast;
/// - link-local (fe80::/10), site-local (fec0::/10) or unique-local (fc00::/7);
/// - any form embedding an IPv4 address that is itself blocked: IPv4-mapped `::ffff:a.b.c.d`,
///   IPv4-compatible `::a.b.c.d`, or NAT64 `64:ff9b::/96`.
pub(crate) fn is_private_ip(ip: IpAddr) -> bool {
    /// IPv4 policy, shared by native IPv4 and IPv4-embedding IPv6 addresses.
    fn ipv4_blocked(v4: std::net::Ipv4Addr) -> bool {
        let [first, second, _, _] = v4.octets();
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_unspecified()
            || v4.is_broadcast()
            || v4.is_multicast()
            // 0.0.0.0/8 — "this network"; never a valid destination (0.0.0.0 reaches the local
            // host on Linux).
            || first == 0
            // 100.64.0.0/10 — shared address space (CGNAT); also hosts cloud metadata endpoints
            // such as 100.100.100.200.
            || (first == 100 && (second & 0xc0) == 64)
            // 240.0.0.0/4 — reserved.
            || first >= 240
    }

    match ip {
        IpAddr::V4(v4) => ipv4_blocked(v4),
        IpAddr::V6(v6) => {
            if v6.is_loopback() || v6.is_unspecified() || v6.is_multicast() {
                return true;
            }
            let seg = v6.segments();
            // fe80::/10 — link-local
            if seg[0] & 0xffc0 == 0xfe80 {
                return true;
            }
            // fec0::/10 — deprecated site-local, still routed internally on some networks
            if seg[0] & 0xffc0 == 0xfec0 {
                return true;
            }
            // fc00::/7 — unique local
            if seg[0] & 0xfe00 == 0xfc00 {
                return true;
            }
            // `to_ipv4` covers both IPv4-mapped (::ffff:a.b.c.d) and IPv4-compatible
            // (::a.b.c.d); NAT64 (64:ff9b::/96) carries the IPv4 address in the last 32 bits.
            let embedded = v6.to_ipv4().or_else(|| {
                (seg[..6] == [0x64, 0xff9b, 0, 0, 0, 0]).then(|| {
                    std::net::Ipv4Addr::from((u32::from(seg[6]) << 16) | u32::from(seg[7]))
                })
            });
            embedded.is_some_and(ipv4_blocked)
        }
    }
}
445
446fn is_private_host(host: &url::Host<&str>) -> bool {
447    match host {
448        url::Host::Domain(d) => {
449            // Exact match or subdomain of localhost (e.g. foo.localhost)
450            // and .internal/.local TLDs used in cloud/k8s environments.
451            #[allow(clippy::case_sensitive_file_extension_comparisons)]
452            {
453                *d == "localhost"
454                    || d.ends_with(".localhost")
455                    || d.ends_with(".internal")
456                    || d.ends_with(".local")
457            }
458        }
459        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
460        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
461    }
462}
463
464/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
465/// and returns the hostname and validated socket addresses.
466///
467/// Returning the addresses allows the caller to pin the HTTP client to these exact
468/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
469async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
470    let Some(host) = url.host_str() else {
471        return Ok((String::new(), vec![]));
472    };
473    let port = url.port_or_known_default().unwrap_or(443);
474    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
475        .await
476        .map_err(|e| ToolError::Blocked {
477            command: format!("DNS resolution failed: {e}"),
478        })?
479        .collect();
480    for addr in &addrs {
481        if is_private_ip(addr.ip()) {
482            return Err(ToolError::Blocked {
483                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
484            });
485        }
486    }
487    Ok((host.to_owned(), addrs))
488}
489
490fn parse_and_extract(
491    html: &str,
492    selector: &str,
493    extract: &ExtractMode,
494    limit: usize,
495) -> Result<String, ToolError> {
496    let soup = scrape_core::Soup::parse(html);
497
498    let tags = soup.find_all(selector).map_err(|e| {
499        ToolError::Execution(std::io::Error::new(
500            std::io::ErrorKind::InvalidData,
501            format!("invalid selector: {e}"),
502        ))
503    })?;
504
505    let mut results = Vec::new();
506
507    for tag in tags.into_iter().take(limit) {
508        let value = match extract {
509            ExtractMode::Text => tag.text(),
510            ExtractMode::Html => tag.inner_html(),
511            ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
512        };
513        if !value.trim().is_empty() {
514            results.push(value.trim().to_owned());
515        }
516    }
517
518    if results.is_empty() {
519        Ok(format!("No results for selector: {selector}"))
520    } else {
521        Ok(results.join("\n"))
522    }
523}
524
525#[cfg(test)]
526mod tests {
527    use super::*;
528
529    // --- extract_scrape_blocks ---
530
531    #[test]
532    fn extract_single_block() {
533        let text =
534            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
535        let blocks = extract_scrape_blocks(text);
536        assert_eq!(blocks.len(), 1);
537        assert!(blocks[0].contains("example.com"));
538    }
539
540    #[test]
541    fn extract_multiple_blocks() {
542        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
543        let blocks = extract_scrape_blocks(text);
544        assert_eq!(blocks.len(), 2);
545    }
546
547    #[test]
548    fn no_blocks_returns_empty() {
549        let blocks = extract_scrape_blocks("plain text, no code blocks");
550        assert!(blocks.is_empty());
551    }
552
553    #[test]
554    fn unclosed_block_ignored() {
555        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
556        assert!(blocks.is_empty());
557    }
558
559    #[test]
560    fn non_scrape_block_ignored() {
561        let text =
562            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
563        let blocks = extract_scrape_blocks(text);
564        assert_eq!(blocks.len(), 1);
565        assert!(blocks[0].contains("x.com"));
566    }
567
568    #[test]
569    fn multiline_json_block() {
570        let text =
571            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
572        let blocks = extract_scrape_blocks(text);
573        assert_eq!(blocks.len(), 1);
574        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
575        assert_eq!(instr.url, "https://example.com");
576    }
577
578    // --- ScrapeInstruction parsing ---
579
580    #[test]
581    fn parse_valid_instruction() {
582        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
583        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
584        assert_eq!(instr.url, "https://example.com");
585        assert_eq!(instr.select, "h1");
586        assert_eq!(instr.extract, "text");
587        assert_eq!(instr.limit, Some(5));
588    }
589
590    #[test]
591    fn parse_minimal_instruction() {
592        let json = r#"{"url":"https://example.com","select":"p"}"#;
593        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
594        assert_eq!(instr.extract, "text");
595        assert!(instr.limit.is_none());
596    }
597
598    #[test]
599    fn parse_attr_extract() {
600        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
601        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
602        assert_eq!(instr.extract, "attr:href");
603    }
604
605    #[test]
606    fn parse_invalid_json_errors() {
607        let result = serde_json::from_str::<ScrapeInstruction>("not json");
608        assert!(result.is_err());
609    }
610
611    // --- ExtractMode ---
612
613    #[test]
614    fn extract_mode_text() {
615        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
616    }
617
618    #[test]
619    fn extract_mode_html() {
620        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
621    }
622
623    #[test]
624    fn extract_mode_attr() {
625        let mode = ExtractMode::parse("attr:href");
626        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
627    }
628
629    #[test]
630    fn extract_mode_unknown_defaults_to_text() {
631        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
632    }
633
634    // --- validate_url ---
635
636    #[test]
637    fn valid_https_url() {
638        assert!(validate_url("https://example.com").is_ok());
639    }
640
641    #[test]
642    fn http_rejected() {
643        let err = validate_url("http://example.com").unwrap_err();
644        assert!(matches!(err, ToolError::Blocked { .. }));
645    }
646
647    #[test]
648    fn ftp_rejected() {
649        let err = validate_url("ftp://files.example.com").unwrap_err();
650        assert!(matches!(err, ToolError::Blocked { .. }));
651    }
652
653    #[test]
654    fn file_rejected() {
655        let err = validate_url("file:///etc/passwd").unwrap_err();
656        assert!(matches!(err, ToolError::Blocked { .. }));
657    }
658
659    #[test]
660    fn invalid_url_rejected() {
661        let err = validate_url("not a url").unwrap_err();
662        assert!(matches!(err, ToolError::Blocked { .. }));
663    }
664
665    #[test]
666    fn localhost_blocked() {
667        let err = validate_url("https://localhost/path").unwrap_err();
668        assert!(matches!(err, ToolError::Blocked { .. }));
669    }
670
671    #[test]
672    fn loopback_ip_blocked() {
673        let err = validate_url("https://127.0.0.1/path").unwrap_err();
674        assert!(matches!(err, ToolError::Blocked { .. }));
675    }
676
677    #[test]
678    fn private_10_blocked() {
679        let err = validate_url("https://10.0.0.1/api").unwrap_err();
680        assert!(matches!(err, ToolError::Blocked { .. }));
681    }
682
683    #[test]
684    fn private_172_blocked() {
685        let err = validate_url("https://172.16.0.1/api").unwrap_err();
686        assert!(matches!(err, ToolError::Blocked { .. }));
687    }
688
689    #[test]
690    fn private_192_blocked() {
691        let err = validate_url("https://192.168.1.1/api").unwrap_err();
692        assert!(matches!(err, ToolError::Blocked { .. }));
693    }
694
695    #[test]
696    fn ipv6_loopback_blocked() {
697        let err = validate_url("https://[::1]/path").unwrap_err();
698        assert!(matches!(err, ToolError::Blocked { .. }));
699    }
700
701    #[test]
702    fn public_ip_allowed() {
703        assert!(validate_url("https://93.184.216.34/page").is_ok());
704    }
705
706    // --- parse_and_extract ---
707
708    #[test]
709    fn extract_text_from_html() {
710        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
711        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
712        assert_eq!(result, "Hello World");
713    }
714
715    #[test]
716    fn extract_multiple_elements() {
717        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
718        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
719        assert_eq!(result, "A\nB\nC");
720    }
721
722    #[test]
723    fn extract_with_limit() {
724        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
725        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
726        assert_eq!(result, "A\nB");
727    }
728
729    #[test]
730    fn extract_attr_href() {
731        let html = r#"<a href="https://example.com">Link</a>"#;
732        let result =
733            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
734        assert_eq!(result, "https://example.com");
735    }
736
737    #[test]
738    fn extract_inner_html() {
739        let html = "<div><span>inner</span></div>";
740        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
741        assert!(result.contains("<span>inner</span>"));
742    }
743
744    #[test]
745    fn no_matches_returns_message() {
746        let html = "<html><body><p>text</p></body></html>";
747        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
748        assert!(result.starts_with("No results for selector:"));
749    }
750
751    #[test]
752    fn empty_text_skipped() {
753        let html = "<ul><li>  </li><li>A</li></ul>";
754        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
755        assert_eq!(result, "A");
756    }
757
758    #[test]
759    fn invalid_selector_errors() {
760        let html = "<html><body></body></html>";
761        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
762        assert!(result.is_err());
763    }
764
765    #[test]
766    fn empty_html_returns_no_results() {
767        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
768        assert!(result.starts_with("No results for selector:"));
769    }
770
771    #[test]
772    fn nested_selector() {
773        let html = "<div><span>inner</span></div><span>outer</span>";
774        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
775        assert_eq!(result, "inner");
776    }
777
778    #[test]
779    fn attr_missing_returns_empty() {
780        let html = r#"<a>No href</a>"#;
781        let result =
782            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
783        assert!(result.starts_with("No results for selector:"));
784    }
785
786    #[test]
787    fn extract_html_mode() {
788        let html = "<div><b>bold</b> text</div>";
789        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
790        assert!(result.contains("<b>bold</b>"));
791    }
792
793    #[test]
794    fn limit_zero_returns_no_results() {
795        let html = "<ul><li>A</li><li>B</li></ul>";
796        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
797        assert!(result.starts_with("No results for selector:"));
798    }
799
800    // --- validate_url edge cases ---
801
802    #[test]
803    fn url_with_port_allowed() {
804        assert!(validate_url("https://example.com:8443/path").is_ok());
805    }
806
807    #[test]
808    fn link_local_ip_blocked() {
809        let err = validate_url("https://169.254.1.1/path").unwrap_err();
810        assert!(matches!(err, ToolError::Blocked { .. }));
811    }
812
813    #[test]
814    fn url_no_scheme_rejected() {
815        let err = validate_url("example.com/path").unwrap_err();
816        assert!(matches!(err, ToolError::Blocked { .. }));
817    }
818
819    #[test]
820    fn unspecified_ipv4_blocked() {
821        let err = validate_url("https://0.0.0.0/path").unwrap_err();
822        assert!(matches!(err, ToolError::Blocked { .. }));
823    }
824
825    #[test]
826    fn broadcast_ipv4_blocked() {
827        let err = validate_url("https://255.255.255.255/path").unwrap_err();
828        assert!(matches!(err, ToolError::Blocked { .. }));
829    }
830
831    #[test]
832    fn ipv6_link_local_blocked() {
833        let err = validate_url("https://[fe80::1]/path").unwrap_err();
834        assert!(matches!(err, ToolError::Blocked { .. }));
835    }
836
837    #[test]
838    fn ipv6_unique_local_blocked() {
839        let err = validate_url("https://[fd12::1]/path").unwrap_err();
840        assert!(matches!(err, ToolError::Blocked { .. }));
841    }
842
843    #[test]
844    fn ipv4_mapped_ipv6_loopback_blocked() {
845        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
846        assert!(matches!(err, ToolError::Blocked { .. }));
847    }
848
849    #[test]
850    fn ipv4_mapped_ipv6_private_blocked() {
851        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
852        assert!(matches!(err, ToolError::Blocked { .. }));
853    }
854
855    // --- WebScrapeExecutor (no-network) ---
856
857    #[tokio::test]
858    async fn executor_no_blocks_returns_none() {
859        let config = ScrapeConfig::default();
860        let executor = WebScrapeExecutor::new(&config);
861        let result = executor.execute("plain text").await;
862        assert!(result.unwrap().is_none());
863    }
864
865    #[tokio::test]
866    async fn executor_invalid_json_errors() {
867        let config = ScrapeConfig::default();
868        let executor = WebScrapeExecutor::new(&config);
869        let response = "```scrape\nnot json\n```";
870        let result = executor.execute(response).await;
871        assert!(matches!(result, Err(ToolError::Execution(_))));
872    }
873
874    #[tokio::test]
875    async fn executor_blocked_url_errors() {
876        let config = ScrapeConfig::default();
877        let executor = WebScrapeExecutor::new(&config);
878        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
879        let result = executor.execute(response).await;
880        assert!(matches!(result, Err(ToolError::Blocked { .. })));
881    }
882
883    #[tokio::test]
884    async fn executor_private_ip_blocked() {
885        let config = ScrapeConfig::default();
886        let executor = WebScrapeExecutor::new(&config);
887        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
888        let result = executor.execute(response).await;
889        assert!(matches!(result, Err(ToolError::Blocked { .. })));
890    }
891
892    #[tokio::test]
893    async fn executor_unreachable_host_returns_error() {
894        let config = ScrapeConfig {
895            timeout: 1,
896            max_body_bytes: 1_048_576,
897        };
898        let executor = WebScrapeExecutor::new(&config);
899        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
900        let result = executor.execute(response).await;
901        assert!(matches!(result, Err(ToolError::Execution(_))));
902    }
903
904    #[tokio::test]
905    async fn executor_localhost_url_blocked() {
906        let config = ScrapeConfig::default();
907        let executor = WebScrapeExecutor::new(&config);
908        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
909        let result = executor.execute(response).await;
910        assert!(matches!(result, Err(ToolError::Blocked { .. })));
911    }
912
913    #[tokio::test]
914    async fn executor_empty_text_returns_none() {
915        let config = ScrapeConfig::default();
916        let executor = WebScrapeExecutor::new(&config);
917        let result = executor.execute("").await;
918        assert!(result.unwrap().is_none());
919    }
920
921    #[tokio::test]
922    async fn executor_multiple_blocks_first_blocked() {
923        let config = ScrapeConfig::default();
924        let executor = WebScrapeExecutor::new(&config);
925        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
926             ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
927        let result = executor.execute(response).await;
928        assert!(result.is_err());
929    }
930
931    #[test]
932    fn validate_url_empty_string() {
933        let err = validate_url("").unwrap_err();
934        assert!(matches!(err, ToolError::Blocked { .. }));
935    }
936
937    #[test]
938    fn validate_url_javascript_scheme_blocked() {
939        let err = validate_url("javascript:alert(1)").unwrap_err();
940        assert!(matches!(err, ToolError::Blocked { .. }));
941    }
942
943    #[test]
944    fn validate_url_data_scheme_blocked() {
945        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
946        assert!(matches!(err, ToolError::Blocked { .. }));
947    }
948
949    #[test]
950    fn is_private_host_public_domain_is_false() {
951        let host: url::Host<&str> = url::Host::Domain("example.com");
952        assert!(!is_private_host(&host));
953    }
954
955    #[test]
956    fn is_private_host_localhost_is_true() {
957        let host: url::Host<&str> = url::Host::Domain("localhost");
958        assert!(is_private_host(&host));
959    }
960
961    #[test]
962    fn is_private_host_ipv6_unspecified_is_true() {
963        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
964        assert!(is_private_host(&host));
965    }
966
967    #[test]
968    fn is_private_host_public_ipv6_is_false() {
969        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
970        assert!(!is_private_host(&host));
971    }
972
    // --- fetch_html redirect logic: wiremock HTTP server tests ---
    //
    // These tests use a local wiremock server to exercise `fetch_html` and its redirect-following
    // logic without requiring an external HTTPS connection. The server binds to 127.0.0.1, and
    // the tests never go through `validate_url`, so the SSRF guard that would otherwise block
    // loopback connections is bypassed. Most tests call `fetch_html` directly; the multi-hop
    // redirect tests use `follow_redirects_raw`, which mirrors `fetch_html`'s loop without the
    // per-hop validation.
979
980    /// Helper: returns executor + (server_url, server_addr) from a running wiremock mock server.
981    /// The server address is passed to `fetch_html` via `resolve_to_addrs` so the client
982    /// connects to the mock instead of doing a real DNS lookup.
983    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
984        let server = wiremock::MockServer::start().await;
985        let executor = WebScrapeExecutor {
986            timeout: Duration::from_secs(5),
987            max_body_bytes: 1_048_576,
988            audit_logger: None,
989        };
990        (executor, server)
991    }
992
993    /// Parses the mock server's URI into (host_str, socket_addr) for use with `build_client`.
994    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
995        let uri = server.uri();
996        let url = Url::parse(&uri).unwrap();
997        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
998        let port = url.port().unwrap_or(80);
999        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1000        (host, vec![addr])
1001    }
1002
1003    /// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
1004    /// `resolve_and_validate`. This lets us exercise the redirect-counting and
1005    /// missing-Location logic against a plain HTTP wiremock server.
1006    async fn follow_redirects_raw(
1007        executor: &WebScrapeExecutor,
1008        start_url: &str,
1009        host: &str,
1010        addrs: &[std::net::SocketAddr],
1011    ) -> Result<String, ToolError> {
1012        const MAX_REDIRECTS: usize = 3;
1013        let mut current_url = start_url.to_owned();
1014        let mut current_host = host.to_owned();
1015        let mut current_addrs = addrs.to_vec();
1016
1017        for hop in 0..=MAX_REDIRECTS {
1018            let client = executor.build_client(&current_host, &current_addrs);
1019            let resp = client
1020                .get(&current_url)
1021                .send()
1022                .await
1023                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1024
1025            let status = resp.status();
1026
1027            if status.is_redirection() {
1028                if hop == MAX_REDIRECTS {
1029                    return Err(ToolError::Execution(std::io::Error::other(
1030                        "too many redirects",
1031                    )));
1032                }
1033
1034                let location = resp
1035                    .headers()
1036                    .get(reqwest::header::LOCATION)
1037                    .and_then(|v| v.to_str().ok())
1038                    .ok_or_else(|| {
1039                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
1040                    })?;
1041
1042                let base = Url::parse(&current_url)
1043                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1044                let next_url = base
1045                    .join(location)
1046                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1047
1048                // Re-use same host/addrs (mock server is always the same endpoint).
1049                current_url = next_url.to_string();
1050                // Preserve host/addrs as-is since the mock server doesn't change.
1051                let _ = &mut current_host;
1052                let _ = &mut current_addrs;
1053                continue;
1054            }
1055
1056            if !status.is_success() {
1057                return Err(ToolError::Execution(std::io::Error::other(format!(
1058                    "HTTP {status}",
1059                ))));
1060            }
1061
1062            let bytes = resp
1063                .bytes()
1064                .await
1065                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1066
1067            if bytes.len() > executor.max_body_bytes {
1068                return Err(ToolError::Execution(std::io::Error::other(format!(
1069                    "response too large: {} bytes (max: {})",
1070                    bytes.len(),
1071                    executor.max_body_bytes,
1072                ))));
1073            }
1074
1075            return String::from_utf8(bytes.to_vec())
1076                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1077        }
1078
1079        Err(ToolError::Execution(std::io::Error::other(
1080            "too many redirects",
1081        )))
1082    }
1083
1084    #[tokio::test]
1085    async fn fetch_html_success_returns_body() {
1086        use wiremock::matchers::{method, path};
1087        use wiremock::{Mock, ResponseTemplate};
1088
1089        let (executor, server) = mock_server_executor().await;
1090        Mock::given(method("GET"))
1091            .and(path("/page"))
1092            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1093            .mount(&server)
1094            .await;
1095
1096        let (host, addrs) = server_host_and_addr(&server);
1097        let url = format!("{}/page", server.uri());
1098        let result = executor.fetch_html(&url, &host, &addrs).await;
1099        assert!(result.is_ok(), "expected Ok, got: {result:?}");
1100        assert_eq!(result.unwrap(), "<h1>OK</h1>");
1101    }
1102
1103    #[tokio::test]
1104    async fn fetch_html_non_2xx_returns_error() {
1105        use wiremock::matchers::{method, path};
1106        use wiremock::{Mock, ResponseTemplate};
1107
1108        let (executor, server) = mock_server_executor().await;
1109        Mock::given(method("GET"))
1110            .and(path("/forbidden"))
1111            .respond_with(ResponseTemplate::new(403))
1112            .mount(&server)
1113            .await;
1114
1115        let (host, addrs) = server_host_and_addr(&server);
1116        let url = format!("{}/forbidden", server.uri());
1117        let result = executor.fetch_html(&url, &host, &addrs).await;
1118        assert!(result.is_err());
1119        let msg = result.unwrap_err().to_string();
1120        assert!(msg.contains("403"), "expected 403 in error: {msg}");
1121    }
1122
1123    #[tokio::test]
1124    async fn fetch_html_404_returns_error() {
1125        use wiremock::matchers::{method, path};
1126        use wiremock::{Mock, ResponseTemplate};
1127
1128        let (executor, server) = mock_server_executor().await;
1129        Mock::given(method("GET"))
1130            .and(path("/missing"))
1131            .respond_with(ResponseTemplate::new(404))
1132            .mount(&server)
1133            .await;
1134
1135        let (host, addrs) = server_host_and_addr(&server);
1136        let url = format!("{}/missing", server.uri());
1137        let result = executor.fetch_html(&url, &host, &addrs).await;
1138        assert!(result.is_err());
1139        let msg = result.unwrap_err().to_string();
1140        assert!(msg.contains("404"), "expected 404 in error: {msg}");
1141    }
1142
1143    #[tokio::test]
1144    async fn fetch_html_redirect_no_location_returns_error() {
1145        use wiremock::matchers::{method, path};
1146        use wiremock::{Mock, ResponseTemplate};
1147
1148        let (executor, server) = mock_server_executor().await;
1149        // 302 with no Location header
1150        Mock::given(method("GET"))
1151            .and(path("/redirect-no-loc"))
1152            .respond_with(ResponseTemplate::new(302))
1153            .mount(&server)
1154            .await;
1155
1156        let (host, addrs) = server_host_and_addr(&server);
1157        let url = format!("{}/redirect-no-loc", server.uri());
1158        let result = executor.fetch_html(&url, &host, &addrs).await;
1159        assert!(result.is_err());
1160        let msg = result.unwrap_err().to_string();
1161        assert!(
1162            msg.contains("Location") || msg.contains("location"),
1163            "expected Location-related error: {msg}"
1164        );
1165    }
1166
1167    #[tokio::test]
1168    async fn fetch_html_single_redirect_followed() {
1169        use wiremock::matchers::{method, path};
1170        use wiremock::{Mock, ResponseTemplate};
1171
1172        let (executor, server) = mock_server_executor().await;
1173        let final_url = format!("{}/final", server.uri());
1174
1175        Mock::given(method("GET"))
1176            .and(path("/start"))
1177            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1178            .mount(&server)
1179            .await;
1180
1181        Mock::given(method("GET"))
1182            .and(path("/final"))
1183            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1184            .mount(&server)
1185            .await;
1186
1187        let (host, addrs) = server_host_and_addr(&server);
1188        let url = format!("{}/start", server.uri());
1189        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1190        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1191        assert_eq!(result.unwrap(), "<p>final</p>");
1192    }
1193
1194    #[tokio::test]
1195    async fn fetch_html_three_redirects_allowed() {
1196        use wiremock::matchers::{method, path};
1197        use wiremock::{Mock, ResponseTemplate};
1198
1199        let (executor, server) = mock_server_executor().await;
1200        let hop2 = format!("{}/hop2", server.uri());
1201        let hop3 = format!("{}/hop3", server.uri());
1202        let final_dest = format!("{}/done", server.uri());
1203
1204        Mock::given(method("GET"))
1205            .and(path("/hop1"))
1206            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1207            .mount(&server)
1208            .await;
1209        Mock::given(method("GET"))
1210            .and(path("/hop2"))
1211            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1212            .mount(&server)
1213            .await;
1214        Mock::given(method("GET"))
1215            .and(path("/hop3"))
1216            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1217            .mount(&server)
1218            .await;
1219        Mock::given(method("GET"))
1220            .and(path("/done"))
1221            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1222            .mount(&server)
1223            .await;
1224
1225        let (host, addrs) = server_host_and_addr(&server);
1226        let url = format!("{}/hop1", server.uri());
1227        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1228        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1229        assert_eq!(result.unwrap(), "<p>done</p>");
1230    }
1231
1232    #[tokio::test]
1233    async fn fetch_html_four_redirects_rejected() {
1234        use wiremock::matchers::{method, path};
1235        use wiremock::{Mock, ResponseTemplate};
1236
1237        let (executor, server) = mock_server_executor().await;
1238        let hop2 = format!("{}/r2", server.uri());
1239        let hop3 = format!("{}/r3", server.uri());
1240        let hop4 = format!("{}/r4", server.uri());
1241        let hop5 = format!("{}/r5", server.uri());
1242
1243        for (from, to) in [
1244            ("/r1", &hop2),
1245            ("/r2", &hop3),
1246            ("/r3", &hop4),
1247            ("/r4", &hop5),
1248        ] {
1249            Mock::given(method("GET"))
1250                .and(path(from))
1251                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1252                .mount(&server)
1253                .await;
1254        }
1255
1256        let (host, addrs) = server_host_and_addr(&server);
1257        let url = format!("{}/r1", server.uri());
1258        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1259        assert!(result.is_err(), "4 redirects should be rejected");
1260        let msg = result.unwrap_err().to_string();
1261        assert!(
1262            msg.contains("redirect"),
1263            "expected redirect-related error: {msg}"
1264        );
1265    }
1266
1267    #[tokio::test]
1268    async fn fetch_html_body_too_large_returns_error() {
1269        use wiremock::matchers::{method, path};
1270        use wiremock::{Mock, ResponseTemplate};
1271
1272        let small_limit_executor = WebScrapeExecutor {
1273            timeout: Duration::from_secs(5),
1274            max_body_bytes: 10,
1275            audit_logger: None,
1276        };
1277        let server = wiremock::MockServer::start().await;
1278        Mock::given(method("GET"))
1279            .and(path("/big"))
1280            .respond_with(
1281                ResponseTemplate::new(200)
1282                    .set_body_string("this body is definitely longer than ten bytes"),
1283            )
1284            .mount(&server)
1285            .await;
1286
1287        let (host, addrs) = server_host_and_addr(&server);
1288        let url = format!("{}/big", server.uri());
1289        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1290        assert!(result.is_err());
1291        let msg = result.unwrap_err().to_string();
1292        assert!(msg.contains("too large"), "expected too-large error: {msg}");
1293    }
1294
1295    #[test]
1296    fn extract_scrape_blocks_empty_block_content() {
1297        let text = "```scrape\n\n```";
1298        let blocks = extract_scrape_blocks(text);
1299        assert_eq!(blocks.len(), 1);
1300        assert!(blocks[0].is_empty());
1301    }
1302
1303    #[test]
1304    fn extract_scrape_blocks_whitespace_only() {
1305        let text = "```scrape\n   \n```";
1306        let blocks = extract_scrape_blocks(text);
1307        assert_eq!(blocks.len(), 1);
1308    }
1309
1310    #[test]
1311    fn parse_and_extract_multiple_selectors() {
1312        let html = "<div><h1>Title</h1><p>Para</p></div>";
1313        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1314        assert!(result.contains("Title"));
1315        assert!(result.contains("Para"));
1316    }
1317
1318    #[test]
1319    fn webscrape_executor_new_with_custom_config() {
1320        let config = ScrapeConfig {
1321            timeout: 60,
1322            max_body_bytes: 512,
1323        };
1324        let executor = WebScrapeExecutor::new(&config);
1325        assert_eq!(executor.max_body_bytes, 512);
1326    }
1327
1328    #[test]
1329    fn webscrape_executor_debug() {
1330        let config = ScrapeConfig::default();
1331        let executor = WebScrapeExecutor::new(&config);
1332        let dbg = format!("{executor:?}");
1333        assert!(dbg.contains("WebScrapeExecutor"));
1334    }
1335
1336    #[test]
1337    fn extract_mode_attr_empty_name() {
1338        let mode = ExtractMode::parse("attr:");
1339        assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1340    }
1341
1342    #[test]
1343    fn default_extract_returns_text() {
1344        assert_eq!(default_extract(), "text");
1345    }
1346
1347    #[test]
1348    fn scrape_instruction_debug() {
1349        let json = r#"{"url":"https://example.com","select":"h1"}"#;
1350        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1351        let dbg = format!("{instr:?}");
1352        assert!(dbg.contains("ScrapeInstruction"));
1353    }
1354
1355    #[test]
1356    fn extract_mode_debug() {
1357        let mode = ExtractMode::Text;
1358        let dbg = format!("{mode:?}");
1359        assert!(dbg.contains("Text"));
1360    }
1361
1362    // --- fetch_html redirect logic: constant and validation unit tests ---
1363
1364    /// MAX_REDIRECTS is 3; the 4th redirect attempt must be rejected.
1365    /// Verify the boundary is correct by inspecting the constant value.
1366    #[test]
1367    fn max_redirects_constant_is_three() {
1368        // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
1369        // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
1370        // This test documents the expected limit.
1371        const MAX_REDIRECTS: usize = 3;
1372        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1373    }
1374
1375    /// Verifies that a Location-less redirect would produce an error string containing the
1376    /// expected message, matching the error path in fetch_html.
1377    #[test]
1378    fn redirect_no_location_error_message() {
1379        let err = std::io::Error::other("redirect with no Location");
1380        assert!(err.to_string().contains("redirect with no Location"));
1381    }
1382
1383    /// Verifies that a too-many-redirects condition produces the expected error string.
1384    #[test]
1385    fn too_many_redirects_error_message() {
1386        let err = std::io::Error::other("too many redirects");
1387        assert!(err.to_string().contains("too many redirects"));
1388    }
1389
1390    /// Verifies that a non-2xx HTTP status produces an error message with the status code.
1391    #[test]
1392    fn non_2xx_status_error_format() {
1393        let status = reqwest::StatusCode::FORBIDDEN;
1394        let msg = format!("HTTP {status}");
1395        assert!(msg.contains("403"));
1396    }
1397
1398    /// Verifies that a 404 response status code formats into the expected error message.
1399    #[test]
1400    fn not_found_status_error_format() {
1401        let status = reqwest::StatusCode::NOT_FOUND;
1402        let msg = format!("HTTP {status}");
1403        assert!(msg.contains("404"));
1404    }
1405
1406    /// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
1407    #[test]
1408    fn relative_redirect_same_host_path() {
1409        let base = Url::parse("https://example.com/current").unwrap();
1410        let resolved = base.join("/other").unwrap();
1411        assert_eq!(resolved.as_str(), "https://example.com/other");
1412    }
1413
1414    /// Verifies relative redirect resolution preserves scheme and host.
1415    #[test]
1416    fn relative_redirect_relative_path() {
1417        let base = Url::parse("https://example.com/a/b").unwrap();
1418        let resolved = base.join("c").unwrap();
1419        assert_eq!(resolved.as_str(), "https://example.com/a/c");
1420    }
1421
1422    /// Verifies that an absolute redirect URL overrides base URL completely.
1423    #[test]
1424    fn absolute_redirect_overrides_base() {
1425        let base = Url::parse("https://example.com/page").unwrap();
1426        let resolved = base.join("https://other.com/target").unwrap();
1427        assert_eq!(resolved.as_str(), "https://other.com/target");
1428    }
1429
1430    /// Verifies that a redirect Location of http:// (downgrade) is rejected.
1431    #[test]
1432    fn redirect_http_downgrade_rejected() {
1433        let location = "http://example.com/page";
1434        let base = Url::parse("https://example.com/start").unwrap();
1435        let next = base.join(location).unwrap();
1436        let err = validate_url(next.as_str()).unwrap_err();
1437        assert!(matches!(err, ToolError::Blocked { .. }));
1438    }
1439
1440    /// Verifies that a redirect to a private IP literal is blocked.
1441    #[test]
1442    fn redirect_location_private_ip_blocked() {
1443        let location = "https://192.168.100.1/admin";
1444        let base = Url::parse("https://example.com/start").unwrap();
1445        let next = base.join(location).unwrap();
1446        let err = validate_url(next.as_str()).unwrap_err();
1447        assert!(matches!(err, ToolError::Blocked { .. }));
1448        let cmd = match err {
1449            ToolError::Blocked { command } => command,
1450            _ => panic!("expected Blocked"),
1451        };
1452        assert!(
1453            cmd.contains("private") || cmd.contains("scheme"),
1454            "error message should describe the block reason: {cmd}"
1455        );
1456    }
1457
1458    /// Verifies that a redirect to a .internal domain is blocked.
1459    #[test]
1460    fn redirect_location_internal_domain_blocked() {
1461        let location = "https://metadata.internal/latest/meta-data/";
1462        let base = Url::parse("https://example.com/start").unwrap();
1463        let next = base.join(location).unwrap();
1464        let err = validate_url(next.as_str()).unwrap_err();
1465        assert!(matches!(err, ToolError::Blocked { .. }));
1466    }
1467
1468    /// Verifies that a chain of 3 valid public redirects passes validate_url at every hop.
1469    #[test]
1470    fn redirect_chain_three_hops_all_public() {
1471        let hops = [
1472            "https://redirect1.example.com/hop1",
1473            "https://redirect2.example.com/hop2",
1474            "https://destination.example.com/final",
1475        ];
1476        for hop in hops {
1477            assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1478        }
1479    }
1480
1481    // --- SSRF redirect chain defense ---
1482
1483    /// Verifies that a redirect Location pointing to a private IP is rejected by validate_url
1484    /// before any connection attempt — simulating the validation step inside fetch_html.
1485    #[test]
1486    fn redirect_to_private_ip_rejected_by_validate_url() {
1487        // These would appear as Location headers in a redirect response.
1488        let private_targets = [
1489            "https://127.0.0.1/secret",
1490            "https://10.0.0.1/internal",
1491            "https://192.168.1.1/admin",
1492            "https://172.16.0.1/data",
1493            "https://[::1]/path",
1494            "https://[fe80::1]/path",
1495            "https://localhost/path",
1496            "https://service.internal/api",
1497        ];
1498        for target in private_targets {
1499            let result = validate_url(target);
1500            assert!(result.is_err(), "expected error for {target}");
1501            assert!(
1502                matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1503                "expected Blocked for {target}"
1504            );
1505        }
1506    }
1507
1508    /// Verifies that relative redirect URLs are resolved correctly before validation.
1509    #[test]
1510    fn redirect_relative_url_resolves_correctly() {
1511        let base = Url::parse("https://example.com/page").unwrap();
1512        let relative = "/other";
1513        let resolved = base.join(relative).unwrap();
1514        assert_eq!(resolved.as_str(), "https://example.com/other");
1515    }
1516
1517    /// Verifies that a protocol-relative redirect to http:// is rejected (scheme check).
1518    #[test]
1519    fn redirect_to_http_rejected() {
1520        let err = validate_url("http://example.com/page").unwrap_err();
1521        assert!(matches!(err, ToolError::Blocked { .. }));
1522    }
1523
1524    #[test]
1525    fn ipv4_mapped_ipv6_link_local_blocked() {
1526        let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1527        assert!(matches!(err, ToolError::Blocked { .. }));
1528    }
1529
1530    #[test]
1531    fn ipv4_mapped_ipv6_public_allowed() {
1532        assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1533    }
1534
1535    // --- fetch tool ---
1536
1537    #[tokio::test]
1538    async fn fetch_http_scheme_blocked() {
1539        let config = ScrapeConfig::default();
1540        let executor = WebScrapeExecutor::new(&config);
1541        let call = crate::executor::ToolCall {
1542            tool_id: "fetch".to_owned(),
1543            params: {
1544                let mut m = serde_json::Map::new();
1545                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1546                m
1547            },
1548        };
1549        let result = executor.execute_tool_call(&call).await;
1550        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1551    }
1552
1553    #[tokio::test]
1554    async fn fetch_private_ip_blocked() {
1555        let config = ScrapeConfig::default();
1556        let executor = WebScrapeExecutor::new(&config);
1557        let call = crate::executor::ToolCall {
1558            tool_id: "fetch".to_owned(),
1559            params: {
1560                let mut m = serde_json::Map::new();
1561                m.insert(
1562                    "url".to_owned(),
1563                    serde_json::json!("https://192.168.1.1/secret"),
1564                );
1565                m
1566            },
1567        };
1568        let result = executor.execute_tool_call(&call).await;
1569        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1570    }
1571
1572    #[tokio::test]
1573    async fn fetch_localhost_blocked() {
1574        let config = ScrapeConfig::default();
1575        let executor = WebScrapeExecutor::new(&config);
1576        let call = crate::executor::ToolCall {
1577            tool_id: "fetch".to_owned(),
1578            params: {
1579                let mut m = serde_json::Map::new();
1580                m.insert(
1581                    "url".to_owned(),
1582                    serde_json::json!("https://localhost/page"),
1583                );
1584                m
1585            },
1586        };
1587        let result = executor.execute_tool_call(&call).await;
1588        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1589    }
1590
1591    #[tokio::test]
1592    async fn fetch_unknown_tool_returns_none() {
1593        let config = ScrapeConfig::default();
1594        let executor = WebScrapeExecutor::new(&config);
1595        let call = crate::executor::ToolCall {
1596            tool_id: "unknown_tool".to_owned(),
1597            params: serde_json::Map::new(),
1598        };
1599        let result = executor.execute_tool_call(&call).await;
1600        assert!(result.unwrap().is_none());
1601    }
1602
1603    #[tokio::test]
1604    async fn fetch_returns_body_via_mock() {
1605        use wiremock::matchers::{method, path};
1606        use wiremock::{Mock, ResponseTemplate};
1607
1608        let (executor, server) = mock_server_executor().await;
1609        Mock::given(method("GET"))
1610            .and(path("/content"))
1611            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1612            .mount(&server)
1613            .await;
1614
1615        let (host, addrs) = server_host_and_addr(&server);
1616        let url = format!("{}/content", server.uri());
1617        let result = executor.fetch_html(&url, &host, &addrs).await;
1618        assert!(result.is_ok());
1619        assert_eq!(result.unwrap(), "plain text content");
1620    }
1621
1622    #[test]
1623    fn tool_definitions_returns_web_scrape_and_fetch() {
1624        let config = ScrapeConfig::default();
1625        let executor = WebScrapeExecutor::new(&config);
1626        let defs = executor.tool_definitions();
1627        assert_eq!(defs.len(), 2);
1628        assert_eq!(defs[0].id, "web_scrape");
1629        assert_eq!(
1630            defs[0].invocation,
1631            crate::registry::InvocationHint::FencedBlock("scrape")
1632        );
1633        assert_eq!(defs[1].id, "fetch");
1634        assert_eq!(
1635            defs[1].invocation,
1636            crate::registry::InvocationHint::ToolCall
1637        );
1638    }
1639
1640    #[test]
1641    fn tool_definitions_schema_has_all_params() {
1642        let config = ScrapeConfig::default();
1643        let executor = WebScrapeExecutor::new(&config);
1644        let defs = executor.tool_definitions();
1645        let obj = defs[0].schema.as_object().unwrap();
1646        let props = obj["properties"].as_object().unwrap();
1647        assert!(props.contains_key("url"));
1648        assert!(props.contains_key("select"));
1649        assert!(props.contains_key("extract"));
1650        assert!(props.contains_key("limit"));
1651        let req = obj["required"].as_array().unwrap();
1652        assert!(req.iter().any(|v| v.as_str() == Some("url")));
1653        assert!(req.iter().any(|v| v.as_str() == Some("select")));
1654        assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1655    }
1656
1657    // --- is_private_host: new domain checks (AUD-02) ---
1658
1659    #[test]
1660    fn subdomain_localhost_blocked() {
1661        let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1662        assert!(is_private_host(&host));
1663    }
1664
1665    #[test]
1666    fn internal_tld_blocked() {
1667        let host: url::Host<&str> = url::Host::Domain("service.internal");
1668        assert!(is_private_host(&host));
1669    }
1670
1671    #[test]
1672    fn local_tld_blocked() {
1673        let host: url::Host<&str> = url::Host::Domain("printer.local");
1674        assert!(is_private_host(&host));
1675    }
1676
1677    #[test]
1678    fn public_domain_not_blocked() {
1679        let host: url::Host<&str> = url::Host::Domain("example.com");
1680        assert!(!is_private_host(&host));
1681    }
1682
1683    // --- resolve_and_validate: private IP rejection ---
1684
1685    #[tokio::test]
1686    async fn resolve_loopback_rejected() {
1687        // 127.0.0.1 resolves directly (literal IP in DNS query)
1688        let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1689        // validate_url catches this before resolve_and_validate, but test directly
1690        let result = resolve_and_validate(&url).await;
1691        assert!(
1692            result.is_err(),
1693            "loopback IP must be rejected by resolve_and_validate"
1694        );
1695        let err = result.unwrap_err();
1696        assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1697    }
1698
1699    #[tokio::test]
1700    async fn resolve_private_10_rejected() {
1701        let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1702        let result = resolve_and_validate(&url).await;
1703        assert!(result.is_err());
1704        assert!(matches!(
1705            result.unwrap_err(),
1706            crate::executor::ToolError::Blocked { .. }
1707        ));
1708    }
1709
1710    #[tokio::test]
1711    async fn resolve_private_192_rejected() {
1712        let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1713        let result = resolve_and_validate(&url).await;
1714        assert!(result.is_err());
1715        assert!(matches!(
1716            result.unwrap_err(),
1717            crate::executor::ToolError::Blocked { .. }
1718        ));
1719    }
1720
1721    #[tokio::test]
1722    async fn resolve_ipv6_loopback_rejected() {
1723        let url = url::Url::parse("https://[::1]/path").unwrap();
1724        let result = resolve_and_validate(&url).await;
1725        assert!(result.is_err());
1726        assert!(matches!(
1727            result.unwrap_err(),
1728            crate::executor::ToolError::Blocked { .. }
1729        ));
1730    }
1731
1732    #[tokio::test]
1733    async fn resolve_no_host_returns_ok() {
1734        // URL without a resolvable host — should pass through
1735        let url = url::Url::parse("https://example.com/path").unwrap();
1736        // We can't do a live DNS test, but we can verify a URL with no host
1737        let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1738        // data: URLs have no host; resolve_and_validate should return Ok with empty addrs
1739        let result = resolve_and_validate(&url_no_host).await;
1740        assert!(result.is_ok());
1741        let (host, addrs) = result.unwrap();
1742        assert!(host.is_empty());
1743        assert!(addrs.is_empty());
1744        drop(url);
1745        drop(url_no_host);
1746    }
1747
1748    // --- audit logging ---
1749
1750    /// Helper: build an AuditLogger writing to a temp file, and return the logger + path.
1751    async fn make_file_audit_logger(
1752        dir: &tempfile::TempDir,
1753    ) -> (
1754        std::sync::Arc<crate::audit::AuditLogger>,
1755        std::path::PathBuf,
1756    ) {
1757        use crate::audit::AuditLogger;
1758        use crate::config::AuditConfig;
1759        let path = dir.path().join("audit.log");
1760        let config = AuditConfig {
1761            enabled: true,
1762            destination: path.display().to_string(),
1763        };
1764        let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1765        (logger, path)
1766    }
1767
1768    #[tokio::test]
1769    async fn with_audit_sets_logger() {
1770        let config = ScrapeConfig::default();
1771        let executor = WebScrapeExecutor::new(&config);
1772        assert!(executor.audit_logger.is_none());
1773
1774        let dir = tempfile::tempdir().unwrap();
1775        let (logger, _path) = make_file_audit_logger(&dir).await;
1776        let executor = executor.with_audit(logger);
1777        assert!(executor.audit_logger.is_some());
1778    }
1779
1780    #[test]
1781    fn tool_error_to_audit_result_blocked_maps_correctly() {
1782        let err = ToolError::Blocked {
1783            command: "scheme not allowed: http".into(),
1784        };
1785        let result = tool_error_to_audit_result(&err);
1786        assert!(
1787            matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1788        );
1789    }
1790
1791    #[test]
1792    fn tool_error_to_audit_result_timeout_maps_correctly() {
1793        let err = ToolError::Timeout { timeout_secs: 15 };
1794        let result = tool_error_to_audit_result(&err);
1795        assert!(matches!(result, AuditResult::Timeout));
1796    }
1797
1798    #[test]
1799    fn tool_error_to_audit_result_execution_error_maps_correctly() {
1800        let err = ToolError::Execution(std::io::Error::other("connection refused"));
1801        let result = tool_error_to_audit_result(&err);
1802        assert!(
1803            matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1804        );
1805    }
1806
1807    #[tokio::test]
1808    async fn fetch_audit_blocked_url_logged() {
1809        let dir = tempfile::tempdir().unwrap();
1810        let (logger, log_path) = make_file_audit_logger(&dir).await;
1811
1812        let config = ScrapeConfig::default();
1813        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1814
1815        let call = crate::executor::ToolCall {
1816            tool_id: "fetch".to_owned(),
1817            params: {
1818                let mut m = serde_json::Map::new();
1819                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1820                m
1821            },
1822        };
1823        let result = executor.execute_tool_call(&call).await;
1824        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1825
1826        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1827        assert!(
1828            content.contains("\"tool\":\"fetch\""),
1829            "expected tool=fetch in audit: {content}"
1830        );
1831        assert!(
1832            content.contains("\"type\":\"blocked\""),
1833            "expected type=blocked in audit: {content}"
1834        );
1835        assert!(
1836            content.contains("http://example.com"),
1837            "expected URL in audit command field: {content}"
1838        );
1839    }
1840
1841    #[tokio::test]
1842    async fn log_audit_success_writes_to_file() {
1843        let dir = tempfile::tempdir().unwrap();
1844        let (logger, log_path) = make_file_audit_logger(&dir).await;
1845
1846        let config = ScrapeConfig::default();
1847        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1848
1849        executor
1850            .log_audit(
1851                "fetch",
1852                "https://example.com/page",
1853                AuditResult::Success,
1854                42,
1855            )
1856            .await;
1857
1858        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1859        assert!(
1860            content.contains("\"tool\":\"fetch\""),
1861            "expected tool=fetch in audit: {content}"
1862        );
1863        assert!(
1864            content.contains("\"type\":\"success\""),
1865            "expected type=success in audit: {content}"
1866        );
1867        assert!(
1868            content.contains("\"command\":\"https://example.com/page\""),
1869            "expected command URL in audit: {content}"
1870        );
1871        assert!(
1872            content.contains("\"duration_ms\":42"),
1873            "expected duration_ms in audit: {content}"
1874        );
1875    }
1876
1877    #[tokio::test]
1878    async fn log_audit_blocked_writes_to_file() {
1879        let dir = tempfile::tempdir().unwrap();
1880        let (logger, log_path) = make_file_audit_logger(&dir).await;
1881
1882        let config = ScrapeConfig::default();
1883        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1884
1885        executor
1886            .log_audit(
1887                "web_scrape",
1888                "http://evil.com/page",
1889                AuditResult::Blocked {
1890                    reason: "scheme not allowed: http".into(),
1891                },
1892                0,
1893            )
1894            .await;
1895
1896        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1897        assert!(
1898            content.contains("\"tool\":\"web_scrape\""),
1899            "expected tool=web_scrape in audit: {content}"
1900        );
1901        assert!(
1902            content.contains("\"type\":\"blocked\""),
1903            "expected type=blocked in audit: {content}"
1904        );
1905        assert!(
1906            content.contains("scheme not allowed"),
1907            "expected block reason in audit: {content}"
1908        );
1909    }
1910
1911    #[tokio::test]
1912    async fn web_scrape_audit_blocked_url_logged() {
1913        let dir = tempfile::tempdir().unwrap();
1914        let (logger, log_path) = make_file_audit_logger(&dir).await;
1915
1916        let config = ScrapeConfig::default();
1917        let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1918
1919        let call = crate::executor::ToolCall {
1920            tool_id: "web_scrape".to_owned(),
1921            params: {
1922                let mut m = serde_json::Map::new();
1923                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1924                m.insert("select".to_owned(), serde_json::json!("h1"));
1925                m
1926            },
1927        };
1928        let result = executor.execute_tool_call(&call).await;
1929        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1930
1931        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1932        assert!(
1933            content.contains("\"tool\":\"web_scrape\""),
1934            "expected tool=web_scrape in audit: {content}"
1935        );
1936        assert!(
1937            content.contains("\"type\":\"blocked\""),
1938            "expected type=blocked in audit: {content}"
1939        );
1940    }
1941
1942    #[tokio::test]
1943    async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1944        let config = ScrapeConfig::default();
1945        let executor = WebScrapeExecutor::new(&config);
1946        assert!(executor.audit_logger.is_none());
1947
1948        let call = crate::executor::ToolCall {
1949            tool_id: "fetch".to_owned(),
1950            params: {
1951                let mut m = serde_json::Map::new();
1952                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1953                m
1954            },
1955        };
1956        // Must not panic even without an audit logger
1957        let result = executor.execute_tool_call(&call).await;
1958        assert!(matches!(result, Err(ToolError::Blocked { .. })));
1959    }
1960
1961    // CR-10: fetch end-to-end via execute_tool_call -> handle_fetch -> fetch_html
1962    #[tokio::test]
1963    async fn fetch_execute_tool_call_end_to_end() {
1964        use wiremock::matchers::{method, path};
1965        use wiremock::{Mock, ResponseTemplate};
1966
1967        let (executor, server) = mock_server_executor().await;
1968        Mock::given(method("GET"))
1969            .and(path("/e2e"))
1970            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1971            .mount(&server)
1972            .await;
1973
1974        let (host, addrs) = server_host_and_addr(&server);
1975        // Call fetch_html directly (bypassing SSRF guard for loopback mock server)
1976        let result = executor
1977            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1978            .await;
1979        assert!(result.is_ok());
1980        assert!(result.unwrap().contains("end-to-end"));
1981    }
1982}