1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{
15 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
16};
17use crate::net::is_private_ip;
18
/// Parameters for the `fetch` tool: a single URL to retrieve.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// HTTPS URL to fetch; validated and SSRF-checked before any request.
    url: String,
}
24
/// One scrape request, parsed from a fenced ```scrape JSON block or from a
/// `web_scrape` structured tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// HTTPS URL of the page to scrape.
    url: String,
    /// CSS selector choosing the elements to extract.
    select: String,
    /// Extraction mode: "text" (default), "html", or "attr:<name>".
    #[serde(default = "default_extract")]
    extract: String,
    /// Maximum number of matches to return; `scrape_instruction` applies a
    /// default of 10 when absent.
    limit: Option<usize>,
}
37
/// Serde default for `ScrapeInstruction::extract`: the "text" mode.
fn default_extract() -> String {
    String::from("text")
}
41
/// How to turn a matched element into an output string.
#[derive(Debug)]
enum ExtractMode {
    Text,
    Html,
    Attr(String),
}

impl ExtractMode {
    /// Parses a mode string; anything unrecognized falls back to `Text`.
    ///
    /// `"attr:<name>"` selects attribute extraction for `<name>` (an empty
    /// name after the colon is passed through as-is).
    fn parse(s: &str) -> Self {
        if s == "html" {
            return Self::Html;
        }
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        // "text" and every unknown mode string resolve to plain text.
        Self::Text
    }
}
61
/// Tool executor that fetches and scrapes HTTPS web pages with SSRF
/// protections: scheme allow-list, private-host/IP blocking, DNS pinning,
/// and manual re-validated redirects.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request HTTP timeout, built from `ScrapeConfig::timeout` (seconds).
    timeout: Duration,
    // Hard cap on the response body size; larger bodies are rejected.
    max_body_bytes: usize,
    // Optional audit sink; when set, every fetch/scrape attempt is logged.
    audit_logger: Option<Arc<AuditLogger>>,
}
72
73impl WebScrapeExecutor {
74 #[must_use]
75 pub fn new(config: &ScrapeConfig) -> Self {
76 Self {
77 timeout: Duration::from_secs(config.timeout),
78 max_body_bytes: config.max_body_bytes,
79 audit_logger: None,
80 }
81 }
82
83 #[must_use]
84 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
85 self.audit_logger = Some(logger);
86 self
87 }
88
89 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
90 let mut builder = reqwest::Client::builder()
91 .timeout(self.timeout)
92 .redirect(reqwest::redirect::Policy::none());
93 builder = builder.resolve_to_addrs(host, addrs);
94 builder.build().unwrap_or_default()
95 }
96}
97
impl ToolExecutor for WebScrapeExecutor {
    /// Advertises the two tools this executor provides: `web_scrape`
    /// (fenced ```scrape block invocation) and `fetch` (structured tool call).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Runs every fenced ```scrape block found in `response`, sequentially.
    ///
    /// Returns `Ok(None)` when the response contains no scrape blocks. The
    /// first failing block aborts the run and its error is returned (outputs
    /// from earlier successful blocks are discarded). Every executed block —
    /// success or failure — is audit-logged with its measured duration.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON fails the whole call; this happens before any
            // audit entry can be written for the block (no URL known yet).
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Audit the failure, then propagate it unchanged.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Dispatches a structured tool call to `web_scrape` or `fetch`.
    ///
    /// Returns `Ok(None)` for tool ids this executor does not own, so other
    /// executors in the chain can claim them. Both outcomes are audit-logged
    /// with the measured duration before being returned.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms, None)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms, Some(&e))
                            .await;
                        Err(e)
                    }
                }
            }
            _ => Ok(None),
        }
    }

    /// Both tools are network reads, safe to retry on transient failure.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
263
264fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
265 match e {
266 ToolError::Blocked { command } => AuditResult::Blocked {
267 reason: command.clone(),
268 },
269 ToolError::Timeout { .. } => AuditResult::Timeout,
270 _ => AuditResult::Error {
271 message: e.to_string(),
272 },
273 }
274}
275
impl WebScrapeExecutor {
    /// Writes one entry to the audit log, if a logger is attached; a no-op
    /// otherwise.
    ///
    /// `command` carries the target URL. When `error` is present it is
    /// expanded into category/domain/phase labels for structured filtering.
    async fn log_audit(
        &self,
        tool: &str,
        command: &str,
        result: AuditResult,
        duration_ms: u64,
        error: Option<&ToolError>,
    ) {
        if let Some(ref logger) = self.audit_logger {
            let (error_category, error_domain, error_phase) =
                error.map_or((None, None, None), |e| {
                    let cat = e.category();
                    (
                        Some(cat.label().to_owned()),
                        Some(cat.domain().label().to_owned()),
                        Some(cat.phase().label().to_owned()),
                    )
                });
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
                error_category,
                error_domain,
                error_phase,
                claim_source: Some(ClaimSource::WebScrape),
                mcp_server_id: None,
                injection_flagged: false,
                embedding_anomalous: false,
            };
            logger.log(&entry).await;
        }
    }

    /// Validates and fetches `params.url`, returning the body as text.
    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
        let parsed = validate_url(&params.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        self.fetch_html(&params.url, &host, &addrs).await
    }

    /// Fetches the instruction's URL and extracts values with its selector.
    ///
    /// HTML parsing is CPU-bound, so it runs on the blocking thread pool.
    /// `limit` defaults to 10 when unspecified. A `JoinError` from the
    /// blocking task is surfaced as an execution error.
    async fn scrape_instruction(
        &self,
        instruction: &ScrapeInstruction,
    ) -> Result<String, ToolError> {
        let parsed = validate_url(&instruction.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
        let selector = instruction.select.clone();
        let extract = ExtractMode::parse(&instruction.extract);
        let limit = instruction.limit.unwrap_or(10);
        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
            .await
            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
    }

    /// Fetches `url` and returns the body, manually following at most
    /// `MAX_REDIRECTS` redirects.
    ///
    /// Automatic redirects are disabled in the client so that every hop's
    /// URL is re-validated (`validate_url`) and re-resolved with the SSRF
    /// check (`resolve_and_validate`) before a freshly pinned client
    /// connects to it — otherwise a public server could redirect into a
    /// private network. Relative `Location` values are resolved against the
    /// current URL.
    ///
    /// NOTE(review): the body is fully buffered before the size check, so
    /// `max_body_bytes` caps what is returned but not peak memory; a
    /// streaming read with an incremental limit would bound memory as well.
    async fn fetch_html(
        &self,
        url: &str,
        host: &str,
        addrs: &[SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;

        let mut current_url = url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = self.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                // Resolve relative Location headers against the current URL.
                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                // Re-run the full SSRF pipeline on the redirect target before
                // the next hop connects anywhere.
                let validated = validate_url(next_url.as_str())?;
                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;

                current_url = next_url.to_string();
                current_host = next_host;
                current_addrs = next_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Http {
                    status: status.as_u16(),
                    message: status.canonical_reason().unwrap_or("unknown").to_owned(),
                });
            }

            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > self.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    self.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        // Unreachable in practice: every iteration returns or continues, and
        // the final hop returns above. Kept to satisfy the return type.
        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
}
426
/// Returns the contents of every fenced ```scrape block in `text`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
430
431fn validate_url(raw: &str) -> Result<Url, ToolError> {
432 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
433 command: format!("invalid URL: {raw}"),
434 })?;
435
436 if parsed.scheme() != "https" {
437 return Err(ToolError::Blocked {
438 command: format!("scheme not allowed: {}", parsed.scheme()),
439 });
440 }
441
442 if let Some(host) = parsed.host()
443 && is_private_host(&host)
444 {
445 return Err(ToolError::Blocked {
446 command: format!(
447 "private/local host blocked: {}",
448 parsed.host_str().unwrap_or("")
449 ),
450 });
451 }
452
453 Ok(parsed)
454}
455
456fn is_private_host(host: &url::Host<&str>) -> bool {
457 match host {
458 url::Host::Domain(d) => {
459 #[allow(clippy::case_sensitive_file_extension_comparisons)]
462 {
463 *d == "localhost"
464 || d.ends_with(".localhost")
465 || d.ends_with(".internal")
466 || d.ends_with(".local")
467 }
468 }
469 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
470 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
471 }
472}
473
474async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
480 let Some(host) = url.host_str() else {
481 return Ok((String::new(), vec![]));
482 };
483 let port = url.port_or_known_default().unwrap_or(443);
484 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
485 .await
486 .map_err(|e| ToolError::Blocked {
487 command: format!("DNS resolution failed: {e}"),
488 })?
489 .collect();
490 for addr in &addrs {
491 if is_private_ip(addr.ip()) {
492 return Err(ToolError::Blocked {
493 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
494 });
495 }
496 }
497 Ok((host.to_owned(), addrs))
498}
499
500fn parse_and_extract(
501 html: &str,
502 selector: &str,
503 extract: &ExtractMode,
504 limit: usize,
505) -> Result<String, ToolError> {
506 let soup = scrape_core::Soup::parse(html);
507
508 let tags = soup.find_all(selector).map_err(|e| {
509 ToolError::Execution(std::io::Error::new(
510 std::io::ErrorKind::InvalidData,
511 format!("invalid selector: {e}"),
512 ))
513 })?;
514
515 let mut results = Vec::new();
516
517 for tag in tags.into_iter().take(limit) {
518 let value = match extract {
519 ExtractMode::Text => tag.text(),
520 ExtractMode::Html => tag.inner_html(),
521 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
522 };
523 if !value.trim().is_empty() {
524 results.push(value.trim().to_owned());
525 }
526 }
527
528 if results.is_empty() {
529 Ok(format!("No results for selector: {selector}"))
530 } else {
531 Ok(results.join("\n"))
532 }
533}
534
535#[cfg(test)]
536mod tests {
537 use super::*;
538
    // --- fenced ```scrape block extraction ---

    #[test]
    fn extract_single_block() {
        let text =
            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("example.com"));
    }

    #[test]
    fn extract_multiple_blocks() {
        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 2);
    }

    #[test]
    fn no_blocks_returns_empty() {
        let blocks = extract_scrape_blocks("plain text, no code blocks");
        assert!(blocks.is_empty());
    }

    // A fence that is never closed must not yield a block.
    #[test]
    fn unclosed_block_ignored() {
        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
        assert!(blocks.is_empty());
    }

    // Only ```scrape fences count; other languages are skipped.
    #[test]
    fn non_scrape_block_ignored() {
        let text =
            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("x.com"));
    }

    #[test]
    fn multiline_json_block() {
        let text =
            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
        assert_eq!(instr.url, "https://example.com");
    }

    // --- ScrapeInstruction deserialization ---

    #[test]
    fn parse_valid_instruction() {
        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.url, "https://example.com");
        assert_eq!(instr.select, "h1");
        assert_eq!(instr.extract, "text");
        assert_eq!(instr.limit, Some(5));
    }

    // Omitted fields: `extract` defaults to "text", `limit` stays None.
    #[test]
    fn parse_minimal_instruction() {
        let json = r#"{"url":"https://example.com","select":"p"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "text");
        assert!(instr.limit.is_none());
    }

    #[test]
    fn parse_attr_extract() {
        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "attr:href");
    }

    #[test]
    fn parse_invalid_json_errors() {
        let result = serde_json::from_str::<ScrapeInstruction>("not json");
        assert!(result.is_err());
    }
620
    // --- ExtractMode parsing ---

    #[test]
    fn extract_mode_text() {
        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
    }

    #[test]
    fn extract_mode_html() {
        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
    }

    #[test]
    fn extract_mode_attr() {
        let mode = ExtractMode::parse("attr:href");
        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
    }

    #[test]
    fn extract_mode_unknown_defaults_to_text() {
        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
    }

    // --- validate_url: scheme allow-list and private-host blocking ---

    #[test]
    fn valid_https_url() {
        assert!(validate_url("https://example.com").is_ok());
    }

    #[test]
    fn http_rejected() {
        let err = validate_url("http://example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ftp_rejected() {
        let err = validate_url("ftp://files.example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn file_rejected() {
        let err = validate_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn invalid_url_rejected() {
        let err = validate_url("not a url").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn localhost_blocked() {
        let err = validate_url("https://localhost/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn loopback_ip_blocked() {
        let err = validate_url("https://127.0.0.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_10_blocked() {
        let err = validate_url("https://10.0.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_172_blocked() {
        let err = validate_url("https://172.16.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_192_blocked() {
        let err = validate_url("https://192.168.1.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_loopback_blocked() {
        let err = validate_url("https://[::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn public_ip_allowed() {
        assert!(validate_url("https://93.184.216.34/page").is_ok());
    }
715
    // --- parse_and_extract: selector matching and extraction modes ---

    #[test]
    fn extract_text_from_html() {
        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "Hello World");
    }

    #[test]
    fn extract_multiple_elements() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A\nB\nC");
    }

    #[test]
    fn extract_with_limit() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
        assert_eq!(result, "A\nB");
    }

    #[test]
    fn extract_attr_href() {
        let html = r#"<a href="https://example.com">Link</a>"#;
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert_eq!(result, "https://example.com");
    }

    #[test]
    fn extract_inner_html() {
        let html = "<div><span>inner</span></div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<span>inner</span>"));
    }

    #[test]
    fn no_matches_returns_message() {
        let html = "<html><body><p>text</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    // Whitespace-only matches are dropped rather than emitted as blank lines.
    #[test]
    fn empty_text_skipped() {
        let html = "<ul><li> </li><li>A</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A");
    }

    #[test]
    fn invalid_selector_errors() {
        let html = "<html><body></body></html>";
        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
        assert!(result.is_err());
    }

    #[test]
    fn empty_html_returns_no_results() {
        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn nested_selector() {
        let html = "<div><span>inner</span></div><span>outer</span>";
        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "inner");
    }

    // A missing attribute extracts as an empty string, which is skipped, so
    // the "no results" message is returned.
    #[test]
    fn attr_missing_returns_empty() {
        let html = r"<a>No href</a>";
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn extract_html_mode() {
        let html = "<div><b>bold</b> text</div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<b>bold</b>"));
    }

    #[test]
    fn limit_zero_returns_no_results() {
        let html = "<ul><li>A</li><li>B</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }
809
    // --- validate_url: ports, reserved ranges, and IPv6 forms ---

    #[test]
    fn url_with_port_allowed() {
        assert!(validate_url("https://example.com:8443/path").is_ok());
    }

    #[test]
    fn link_local_ip_blocked() {
        let err = validate_url("https://169.254.1.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn url_no_scheme_rejected() {
        let err = validate_url("example.com/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn unspecified_ipv4_blocked() {
        let err = validate_url("https://0.0.0.0/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn broadcast_ipv4_blocked() {
        let err = validate_url("https://255.255.255.255/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_link_local_blocked() {
        let err = validate_url("https://[fe80::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_unique_local_blocked() {
        let err = validate_url("https://[fd12::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // ::ffff:a.b.c.d forms must be treated as their embedded IPv4 value.
    #[test]
    fn ipv4_mapped_ipv6_loopback_blocked() {
        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv4_mapped_ipv6_private_blocked() {
        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }
864
    // --- executor-level behavior (async) ---

    #[tokio::test]
    async fn executor_no_blocks_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("plain text").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn executor_invalid_json_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\nnot json\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_blocked_url_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_private_ip_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    // 192.0.2.0/24 (TEST-NET-1) is reserved but not "private", so it passes
    // SSRF checks and fails at connect time instead.
    #[tokio::test]
    async fn executor_unreachable_host_returns_error() {
        let config = ScrapeConfig {
            timeout: 1,
            max_body_bytes: 1_048_576,
        };
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_localhost_url_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_empty_text_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("").await;
        assert!(result.unwrap().is_none());
    }

    // The first failing block aborts the run even when later blocks are fine.
    #[tokio::test]
    async fn executor_multiple_blocks_first_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
                        ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(result.is_err());
    }

    // --- additional validate_url / is_private_host edge cases ---

    #[test]
    fn validate_url_empty_string() {
        let err = validate_url("").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_javascript_scheme_blocked() {
        let err = validate_url("javascript:alert(1)").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_data_scheme_blocked() {
        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn is_private_host_public_domain_is_false() {
        let host: url::Host<&str> = url::Host::Domain("example.com");
        assert!(!is_private_host(&host));
    }

    #[test]
    fn is_private_host_localhost_is_true() {
        let host: url::Host<&str> = url::Host::Domain("localhost");
        assert!(is_private_host(&host));
    }

    #[test]
    fn is_private_host_ipv6_unspecified_is_true() {
        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
        assert!(is_private_host(&host));
    }

    // 2001:db8::/32 is documentation space; the host check treats it as
    // public (routing reachability is not this function's concern).
    #[test]
    fn is_private_host_public_ipv6_is_false() {
        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
        assert!(!is_private_host(&host));
    }
982
    /// Builds an executor with fixed limits and starts a wiremock server for
    /// it to talk to.
    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
        let server = wiremock::MockServer::start().await;
        let executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 1_048_576,
            audit_logger: None,
        };
        (executor, server)
    }

    /// Extracts the wiremock server's host and socket address so tests can
    /// hand `fetch_html` pre-"resolved" addresses without real DNS.
    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
        let uri = server.uri();
        let url = Url::parse(&uri).unwrap();
        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
        let port = url.port().unwrap_or(80);
        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
        (host, vec![addr])
    }

    /// Test-only copy of `fetch_html`'s redirect loop that skips the per-hop
    /// `validate_url`/`resolve_and_validate` SSRF steps, so redirect behavior
    /// can be exercised against a local (loopback) wiremock server that the
    /// production validation would reject.
    async fn follow_redirects_raw(
        executor: &WebScrapeExecutor,
        start_url: &str,
        host: &str,
        addrs: &[std::net::SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;
        let mut current_url = start_url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = executor.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                current_url = next_url.to_string();
                // Host/addr pinning is deliberately left unchanged across
                // hops (all redirects stay on the same mock server); these
                // no-ops keep the `mut` bindings warning-free.
                let _ = &mut current_host;
                let _ = &mut current_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "HTTP {status}",
                ))));
            }

            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > executor.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    executor.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
1093
    // --- fetch_html against a local wiremock server ---

    #[tokio::test]
    async fn fetch_html_success_returns_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/page", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap(), "<h1>OK</h1>");
    }

    #[tokio::test]
    async fn fetch_html_non_2xx_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/forbidden"))
            .respond_with(ResponseTemplate::new(403))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/forbidden", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("403"), "expected 403 in error: {msg}");
    }

    #[tokio::test]
    async fn fetch_html_404_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/missing", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("404"), "expected 404 in error: {msg}");
    }
1152
1153 #[tokio::test]
1154 async fn fetch_html_redirect_no_location_returns_error() {
1155 use wiremock::matchers::{method, path};
1156 use wiremock::{Mock, ResponseTemplate};
1157
1158 let (executor, server) = mock_server_executor().await;
1159 Mock::given(method("GET"))
1161 .and(path("/redirect-no-loc"))
1162 .respond_with(ResponseTemplate::new(302))
1163 .mount(&server)
1164 .await;
1165
1166 let (host, addrs) = server_host_and_addr(&server);
1167 let url = format!("{}/redirect-no-loc", server.uri());
1168 let result = executor.fetch_html(&url, &host, &addrs).await;
1169 assert!(result.is_err());
1170 let msg = result.unwrap_err().to_string();
1171 assert!(
1172 msg.contains("Location") || msg.contains("location"),
1173 "expected Location-related error: {msg}"
1174 );
1175 }
1176
1177 #[tokio::test]
1178 async fn fetch_html_single_redirect_followed() {
1179 use wiremock::matchers::{method, path};
1180 use wiremock::{Mock, ResponseTemplate};
1181
1182 let (executor, server) = mock_server_executor().await;
1183 let final_url = format!("{}/final", server.uri());
1184
1185 Mock::given(method("GET"))
1186 .and(path("/start"))
1187 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1188 .mount(&server)
1189 .await;
1190
1191 Mock::given(method("GET"))
1192 .and(path("/final"))
1193 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1194 .mount(&server)
1195 .await;
1196
1197 let (host, addrs) = server_host_and_addr(&server);
1198 let url = format!("{}/start", server.uri());
1199 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1200 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1201 assert_eq!(result.unwrap(), "<p>final</p>");
1202 }
1203
1204 #[tokio::test]
1205 async fn fetch_html_three_redirects_allowed() {
1206 use wiremock::matchers::{method, path};
1207 use wiremock::{Mock, ResponseTemplate};
1208
1209 let (executor, server) = mock_server_executor().await;
1210 let hop2 = format!("{}/hop2", server.uri());
1211 let hop3 = format!("{}/hop3", server.uri());
1212 let final_dest = format!("{}/done", server.uri());
1213
1214 Mock::given(method("GET"))
1215 .and(path("/hop1"))
1216 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1217 .mount(&server)
1218 .await;
1219 Mock::given(method("GET"))
1220 .and(path("/hop2"))
1221 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1222 .mount(&server)
1223 .await;
1224 Mock::given(method("GET"))
1225 .and(path("/hop3"))
1226 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1227 .mount(&server)
1228 .await;
1229 Mock::given(method("GET"))
1230 .and(path("/done"))
1231 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1232 .mount(&server)
1233 .await;
1234
1235 let (host, addrs) = server_host_and_addr(&server);
1236 let url = format!("{}/hop1", server.uri());
1237 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1238 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1239 assert_eq!(result.unwrap(), "<p>done</p>");
1240 }
1241
1242 #[tokio::test]
1243 async fn fetch_html_four_redirects_rejected() {
1244 use wiremock::matchers::{method, path};
1245 use wiremock::{Mock, ResponseTemplate};
1246
1247 let (executor, server) = mock_server_executor().await;
1248 let hop2 = format!("{}/r2", server.uri());
1249 let hop3 = format!("{}/r3", server.uri());
1250 let hop4 = format!("{}/r4", server.uri());
1251 let hop5 = format!("{}/r5", server.uri());
1252
1253 for (from, to) in [
1254 ("/r1", &hop2),
1255 ("/r2", &hop3),
1256 ("/r3", &hop4),
1257 ("/r4", &hop5),
1258 ] {
1259 Mock::given(method("GET"))
1260 .and(path(from))
1261 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1262 .mount(&server)
1263 .await;
1264 }
1265
1266 let (host, addrs) = server_host_and_addr(&server);
1267 let url = format!("{}/r1", server.uri());
1268 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1269 assert!(result.is_err(), "4 redirects should be rejected");
1270 let msg = result.unwrap_err().to_string();
1271 assert!(
1272 msg.contains("redirect"),
1273 "expected redirect-related error: {msg}"
1274 );
1275 }
1276
1277 #[tokio::test]
1278 async fn fetch_html_body_too_large_returns_error() {
1279 use wiremock::matchers::{method, path};
1280 use wiremock::{Mock, ResponseTemplate};
1281
1282 let small_limit_executor = WebScrapeExecutor {
1283 timeout: Duration::from_secs(5),
1284 max_body_bytes: 10,
1285 audit_logger: None,
1286 };
1287 let server = wiremock::MockServer::start().await;
1288 Mock::given(method("GET"))
1289 .and(path("/big"))
1290 .respond_with(
1291 ResponseTemplate::new(200)
1292 .set_body_string("this body is definitely longer than ten bytes"),
1293 )
1294 .mount(&server)
1295 .await;
1296
1297 let (host, addrs) = server_host_and_addr(&server);
1298 let url = format!("{}/big", server.uri());
1299 let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
1300 assert!(result.is_err());
1301 let msg = result.unwrap_err().to_string();
1302 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1303 }
1304
1305 #[test]
1306 fn extract_scrape_blocks_empty_block_content() {
1307 let text = "```scrape\n\n```";
1308 let blocks = extract_scrape_blocks(text);
1309 assert_eq!(blocks.len(), 1);
1310 assert!(blocks[0].is_empty());
1311 }
1312
1313 #[test]
1314 fn extract_scrape_blocks_whitespace_only() {
1315 let text = "```scrape\n \n```";
1316 let blocks = extract_scrape_blocks(text);
1317 assert_eq!(blocks.len(), 1);
1318 }
1319
1320 #[test]
1321 fn parse_and_extract_multiple_selectors() {
1322 let html = "<div><h1>Title</h1><p>Para</p></div>";
1323 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1324 assert!(result.contains("Title"));
1325 assert!(result.contains("Para"));
1326 }
1327
1328 #[test]
1329 fn webscrape_executor_new_with_custom_config() {
1330 let config = ScrapeConfig {
1331 timeout: 60,
1332 max_body_bytes: 512,
1333 };
1334 let executor = WebScrapeExecutor::new(&config);
1335 assert_eq!(executor.max_body_bytes, 512);
1336 }
1337
1338 #[test]
1339 fn webscrape_executor_debug() {
1340 let config = ScrapeConfig::default();
1341 let executor = WebScrapeExecutor::new(&config);
1342 let dbg = format!("{executor:?}");
1343 assert!(dbg.contains("WebScrapeExecutor"));
1344 }
1345
1346 #[test]
1347 fn extract_mode_attr_empty_name() {
1348 let mode = ExtractMode::parse("attr:");
1349 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1350 }
1351
1352 #[test]
1353 fn default_extract_returns_text() {
1354 assert_eq!(default_extract(), "text");
1355 }
1356
1357 #[test]
1358 fn scrape_instruction_debug() {
1359 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1360 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1361 let dbg = format!("{instr:?}");
1362 assert!(dbg.contains("ScrapeInstruction"));
1363 }
1364
1365 #[test]
1366 fn extract_mode_debug() {
1367 let mode = ExtractMode::Text;
1368 let dbg = format!("{mode:?}");
1369 assert!(dbg.contains("Text"));
1370 }
1371
1372 #[test]
1377 fn max_redirects_constant_is_three() {
1378 const MAX_REDIRECTS: usize = 3;
1382 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
1383 }
1384
1385 #[test]
1388 fn redirect_no_location_error_message() {
1389 let err = std::io::Error::other("redirect with no Location");
1390 assert!(err.to_string().contains("redirect with no Location"));
1391 }
1392
1393 #[test]
1395 fn too_many_redirects_error_message() {
1396 let err = std::io::Error::other("too many redirects");
1397 assert!(err.to_string().contains("too many redirects"));
1398 }
1399
1400 #[test]
1402 fn non_2xx_status_error_format() {
1403 let status = reqwest::StatusCode::FORBIDDEN;
1404 let msg = format!("HTTP {status}");
1405 assert!(msg.contains("403"));
1406 }
1407
1408 #[test]
1410 fn not_found_status_error_format() {
1411 let status = reqwest::StatusCode::NOT_FOUND;
1412 let msg = format!("HTTP {status}");
1413 assert!(msg.contains("404"));
1414 }
1415
1416 #[test]
1418 fn relative_redirect_same_host_path() {
1419 let base = Url::parse("https://example.com/current").unwrap();
1420 let resolved = base.join("/other").unwrap();
1421 assert_eq!(resolved.as_str(), "https://example.com/other");
1422 }
1423
1424 #[test]
1426 fn relative_redirect_relative_path() {
1427 let base = Url::parse("https://example.com/a/b").unwrap();
1428 let resolved = base.join("c").unwrap();
1429 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1430 }
1431
1432 #[test]
1434 fn absolute_redirect_overrides_base() {
1435 let base = Url::parse("https://example.com/page").unwrap();
1436 let resolved = base.join("https://other.com/target").unwrap();
1437 assert_eq!(resolved.as_str(), "https://other.com/target");
1438 }
1439
1440 #[test]
1442 fn redirect_http_downgrade_rejected() {
1443 let location = "http://example.com/page";
1444 let base = Url::parse("https://example.com/start").unwrap();
1445 let next = base.join(location).unwrap();
1446 let err = validate_url(next.as_str()).unwrap_err();
1447 assert!(matches!(err, ToolError::Blocked { .. }));
1448 }
1449
1450 #[test]
1452 fn redirect_location_private_ip_blocked() {
1453 let location = "https://192.168.100.1/admin";
1454 let base = Url::parse("https://example.com/start").unwrap();
1455 let next = base.join(location).unwrap();
1456 let err = validate_url(next.as_str()).unwrap_err();
1457 assert!(matches!(err, ToolError::Blocked { .. }));
1458 let ToolError::Blocked { command: cmd } = err else {
1459 panic!("expected Blocked");
1460 };
1461 assert!(
1462 cmd.contains("private") || cmd.contains("scheme"),
1463 "error message should describe the block reason: {cmd}"
1464 );
1465 }
1466
1467 #[test]
1469 fn redirect_location_internal_domain_blocked() {
1470 let location = "https://metadata.internal/latest/meta-data/";
1471 let base = Url::parse("https://example.com/start").unwrap();
1472 let next = base.join(location).unwrap();
1473 let err = validate_url(next.as_str()).unwrap_err();
1474 assert!(matches!(err, ToolError::Blocked { .. }));
1475 }
1476
1477 #[test]
1479 fn redirect_chain_three_hops_all_public() {
1480 let hops = [
1481 "https://redirect1.example.com/hop1",
1482 "https://redirect2.example.com/hop2",
1483 "https://destination.example.com/final",
1484 ];
1485 for hop in hops {
1486 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1487 }
1488 }
1489
1490 #[test]
1495 fn redirect_to_private_ip_rejected_by_validate_url() {
1496 let private_targets = [
1498 "https://127.0.0.1/secret",
1499 "https://10.0.0.1/internal",
1500 "https://192.168.1.1/admin",
1501 "https://172.16.0.1/data",
1502 "https://[::1]/path",
1503 "https://[fe80::1]/path",
1504 "https://localhost/path",
1505 "https://service.internal/api",
1506 ];
1507 for target in private_targets {
1508 let result = validate_url(target);
1509 assert!(result.is_err(), "expected error for {target}");
1510 assert!(
1511 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1512 "expected Blocked for {target}"
1513 );
1514 }
1515 }
1516
1517 #[test]
1519 fn redirect_relative_url_resolves_correctly() {
1520 let base = Url::parse("https://example.com/page").unwrap();
1521 let relative = "/other";
1522 let resolved = base.join(relative).unwrap();
1523 assert_eq!(resolved.as_str(), "https://example.com/other");
1524 }
1525
1526 #[test]
1528 fn redirect_to_http_rejected() {
1529 let err = validate_url("http://example.com/page").unwrap_err();
1530 assert!(matches!(err, ToolError::Blocked { .. }));
1531 }
1532
1533 #[test]
1534 fn ipv4_mapped_ipv6_link_local_blocked() {
1535 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1536 assert!(matches!(err, ToolError::Blocked { .. }));
1537 }
1538
1539 #[test]
1540 fn ipv4_mapped_ipv6_public_allowed() {
1541 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1542 }
1543
1544 #[tokio::test]
1547 async fn fetch_http_scheme_blocked() {
1548 let config = ScrapeConfig::default();
1549 let executor = WebScrapeExecutor::new(&config);
1550 let call = crate::executor::ToolCall {
1551 tool_id: "fetch".to_owned(),
1552 params: {
1553 let mut m = serde_json::Map::new();
1554 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1555 m
1556 },
1557 };
1558 let result = executor.execute_tool_call(&call).await;
1559 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1560 }
1561
1562 #[tokio::test]
1563 async fn fetch_private_ip_blocked() {
1564 let config = ScrapeConfig::default();
1565 let executor = WebScrapeExecutor::new(&config);
1566 let call = crate::executor::ToolCall {
1567 tool_id: "fetch".to_owned(),
1568 params: {
1569 let mut m = serde_json::Map::new();
1570 m.insert(
1571 "url".to_owned(),
1572 serde_json::json!("https://192.168.1.1/secret"),
1573 );
1574 m
1575 },
1576 };
1577 let result = executor.execute_tool_call(&call).await;
1578 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1579 }
1580
1581 #[tokio::test]
1582 async fn fetch_localhost_blocked() {
1583 let config = ScrapeConfig::default();
1584 let executor = WebScrapeExecutor::new(&config);
1585 let call = crate::executor::ToolCall {
1586 tool_id: "fetch".to_owned(),
1587 params: {
1588 let mut m = serde_json::Map::new();
1589 m.insert(
1590 "url".to_owned(),
1591 serde_json::json!("https://localhost/page"),
1592 );
1593 m
1594 },
1595 };
1596 let result = executor.execute_tool_call(&call).await;
1597 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1598 }
1599
1600 #[tokio::test]
1601 async fn fetch_unknown_tool_returns_none() {
1602 let config = ScrapeConfig::default();
1603 let executor = WebScrapeExecutor::new(&config);
1604 let call = crate::executor::ToolCall {
1605 tool_id: "unknown_tool".to_owned(),
1606 params: serde_json::Map::new(),
1607 };
1608 let result = executor.execute_tool_call(&call).await;
1609 assert!(result.unwrap().is_none());
1610 }
1611
1612 #[tokio::test]
1613 async fn fetch_returns_body_via_mock() {
1614 use wiremock::matchers::{method, path};
1615 use wiremock::{Mock, ResponseTemplate};
1616
1617 let (executor, server) = mock_server_executor().await;
1618 Mock::given(method("GET"))
1619 .and(path("/content"))
1620 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
1621 .mount(&server)
1622 .await;
1623
1624 let (host, addrs) = server_host_and_addr(&server);
1625 let url = format!("{}/content", server.uri());
1626 let result = executor.fetch_html(&url, &host, &addrs).await;
1627 assert!(result.is_ok());
1628 assert_eq!(result.unwrap(), "plain text content");
1629 }
1630
1631 #[test]
1632 fn tool_definitions_returns_web_scrape_and_fetch() {
1633 let config = ScrapeConfig::default();
1634 let executor = WebScrapeExecutor::new(&config);
1635 let defs = executor.tool_definitions();
1636 assert_eq!(defs.len(), 2);
1637 assert_eq!(defs[0].id, "web_scrape");
1638 assert_eq!(
1639 defs[0].invocation,
1640 crate::registry::InvocationHint::FencedBlock("scrape")
1641 );
1642 assert_eq!(defs[1].id, "fetch");
1643 assert_eq!(
1644 defs[1].invocation,
1645 crate::registry::InvocationHint::ToolCall
1646 );
1647 }
1648
1649 #[test]
1650 fn tool_definitions_schema_has_all_params() {
1651 let config = ScrapeConfig::default();
1652 let executor = WebScrapeExecutor::new(&config);
1653 let defs = executor.tool_definitions();
1654 let obj = defs[0].schema.as_object().unwrap();
1655 let props = obj["properties"].as_object().unwrap();
1656 assert!(props.contains_key("url"));
1657 assert!(props.contains_key("select"));
1658 assert!(props.contains_key("extract"));
1659 assert!(props.contains_key("limit"));
1660 let req = obj["required"].as_array().unwrap();
1661 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1662 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1663 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1664 }
1665
1666 #[test]
1669 fn subdomain_localhost_blocked() {
1670 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1671 assert!(is_private_host(&host));
1672 }
1673
1674 #[test]
1675 fn internal_tld_blocked() {
1676 let host: url::Host<&str> = url::Host::Domain("service.internal");
1677 assert!(is_private_host(&host));
1678 }
1679
1680 #[test]
1681 fn local_tld_blocked() {
1682 let host: url::Host<&str> = url::Host::Domain("printer.local");
1683 assert!(is_private_host(&host));
1684 }
1685
1686 #[test]
1687 fn public_domain_not_blocked() {
1688 let host: url::Host<&str> = url::Host::Domain("example.com");
1689 assert!(!is_private_host(&host));
1690 }
1691
1692 #[tokio::test]
1695 async fn resolve_loopback_rejected() {
1696 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1698 let result = resolve_and_validate(&url).await;
1700 assert!(
1701 result.is_err(),
1702 "loopback IP must be rejected by resolve_and_validate"
1703 );
1704 let err = result.unwrap_err();
1705 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1706 }
1707
1708 #[tokio::test]
1709 async fn resolve_private_10_rejected() {
1710 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1711 let result = resolve_and_validate(&url).await;
1712 assert!(result.is_err());
1713 assert!(matches!(
1714 result.unwrap_err(),
1715 crate::executor::ToolError::Blocked { .. }
1716 ));
1717 }
1718
1719 #[tokio::test]
1720 async fn resolve_private_192_rejected() {
1721 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1722 let result = resolve_and_validate(&url).await;
1723 assert!(result.is_err());
1724 assert!(matches!(
1725 result.unwrap_err(),
1726 crate::executor::ToolError::Blocked { .. }
1727 ));
1728 }
1729
1730 #[tokio::test]
1731 async fn resolve_ipv6_loopback_rejected() {
1732 let url = url::Url::parse("https://[::1]/path").unwrap();
1733 let result = resolve_and_validate(&url).await;
1734 assert!(result.is_err());
1735 assert!(matches!(
1736 result.unwrap_err(),
1737 crate::executor::ToolError::Blocked { .. }
1738 ));
1739 }
1740
1741 #[tokio::test]
1742 async fn resolve_no_host_returns_ok() {
1743 let url = url::Url::parse("https://example.com/path").unwrap();
1745 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1747 let result = resolve_and_validate(&url_no_host).await;
1749 assert!(result.is_ok());
1750 let (host, addrs) = result.unwrap();
1751 assert!(host.is_empty());
1752 assert!(addrs.is_empty());
1753 drop(url);
1754 drop(url_no_host);
1755 }
1756
1757 async fn make_file_audit_logger(
1761 dir: &tempfile::TempDir,
1762 ) -> (
1763 std::sync::Arc<crate::audit::AuditLogger>,
1764 std::path::PathBuf,
1765 ) {
1766 use crate::audit::AuditLogger;
1767 use crate::config::AuditConfig;
1768 let path = dir.path().join("audit.log");
1769 let config = AuditConfig {
1770 enabled: true,
1771 destination: path.display().to_string(),
1772 };
1773 let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
1774 (logger, path)
1775 }
1776
1777 #[tokio::test]
1778 async fn with_audit_sets_logger() {
1779 let config = ScrapeConfig::default();
1780 let executor = WebScrapeExecutor::new(&config);
1781 assert!(executor.audit_logger.is_none());
1782
1783 let dir = tempfile::tempdir().unwrap();
1784 let (logger, _path) = make_file_audit_logger(&dir).await;
1785 let executor = executor.with_audit(logger);
1786 assert!(executor.audit_logger.is_some());
1787 }
1788
1789 #[test]
1790 fn tool_error_to_audit_result_blocked_maps_correctly() {
1791 let err = ToolError::Blocked {
1792 command: "scheme not allowed: http".into(),
1793 };
1794 let result = tool_error_to_audit_result(&err);
1795 assert!(
1796 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1797 );
1798 }
1799
1800 #[test]
1801 fn tool_error_to_audit_result_timeout_maps_correctly() {
1802 let err = ToolError::Timeout { timeout_secs: 15 };
1803 let result = tool_error_to_audit_result(&err);
1804 assert!(matches!(result, AuditResult::Timeout));
1805 }
1806
1807 #[test]
1808 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1809 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1810 let result = tool_error_to_audit_result(&err);
1811 assert!(
1812 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1813 );
1814 }
1815
1816 #[tokio::test]
1817 async fn fetch_audit_blocked_url_logged() {
1818 let dir = tempfile::tempdir().unwrap();
1819 let (logger, log_path) = make_file_audit_logger(&dir).await;
1820
1821 let config = ScrapeConfig::default();
1822 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1823
1824 let call = crate::executor::ToolCall {
1825 tool_id: "fetch".to_owned(),
1826 params: {
1827 let mut m = serde_json::Map::new();
1828 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1829 m
1830 },
1831 };
1832 let result = executor.execute_tool_call(&call).await;
1833 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1834
1835 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1836 assert!(
1837 content.contains("\"tool\":\"fetch\""),
1838 "expected tool=fetch in audit: {content}"
1839 );
1840 assert!(
1841 content.contains("\"type\":\"blocked\""),
1842 "expected type=blocked in audit: {content}"
1843 );
1844 assert!(
1845 content.contains("http://example.com"),
1846 "expected URL in audit command field: {content}"
1847 );
1848 }
1849
1850 #[tokio::test]
1851 async fn log_audit_success_writes_to_file() {
1852 let dir = tempfile::tempdir().unwrap();
1853 let (logger, log_path) = make_file_audit_logger(&dir).await;
1854
1855 let config = ScrapeConfig::default();
1856 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1857
1858 executor
1859 .log_audit(
1860 "fetch",
1861 "https://example.com/page",
1862 AuditResult::Success,
1863 42,
1864 None,
1865 )
1866 .await;
1867
1868 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1869 assert!(
1870 content.contains("\"tool\":\"fetch\""),
1871 "expected tool=fetch in audit: {content}"
1872 );
1873 assert!(
1874 content.contains("\"type\":\"success\""),
1875 "expected type=success in audit: {content}"
1876 );
1877 assert!(
1878 content.contains("\"command\":\"https://example.com/page\""),
1879 "expected command URL in audit: {content}"
1880 );
1881 assert!(
1882 content.contains("\"duration_ms\":42"),
1883 "expected duration_ms in audit: {content}"
1884 );
1885 }
1886
1887 #[tokio::test]
1888 async fn log_audit_blocked_writes_to_file() {
1889 let dir = tempfile::tempdir().unwrap();
1890 let (logger, log_path) = make_file_audit_logger(&dir).await;
1891
1892 let config = ScrapeConfig::default();
1893 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1894
1895 executor
1896 .log_audit(
1897 "web_scrape",
1898 "http://evil.com/page",
1899 AuditResult::Blocked {
1900 reason: "scheme not allowed: http".into(),
1901 },
1902 0,
1903 None,
1904 )
1905 .await;
1906
1907 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1908 assert!(
1909 content.contains("\"tool\":\"web_scrape\""),
1910 "expected tool=web_scrape in audit: {content}"
1911 );
1912 assert!(
1913 content.contains("\"type\":\"blocked\""),
1914 "expected type=blocked in audit: {content}"
1915 );
1916 assert!(
1917 content.contains("scheme not allowed"),
1918 "expected block reason in audit: {content}"
1919 );
1920 }
1921
1922 #[tokio::test]
1923 async fn web_scrape_audit_blocked_url_logged() {
1924 let dir = tempfile::tempdir().unwrap();
1925 let (logger, log_path) = make_file_audit_logger(&dir).await;
1926
1927 let config = ScrapeConfig::default();
1928 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
1929
1930 let call = crate::executor::ToolCall {
1931 tool_id: "web_scrape".to_owned(),
1932 params: {
1933 let mut m = serde_json::Map::new();
1934 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1935 m.insert("select".to_owned(), serde_json::json!("h1"));
1936 m
1937 },
1938 };
1939 let result = executor.execute_tool_call(&call).await;
1940 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1941
1942 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
1943 assert!(
1944 content.contains("\"tool\":\"web_scrape\""),
1945 "expected tool=web_scrape in audit: {content}"
1946 );
1947 assert!(
1948 content.contains("\"type\":\"blocked\""),
1949 "expected type=blocked in audit: {content}"
1950 );
1951 }
1952
1953 #[tokio::test]
1954 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1955 let config = ScrapeConfig::default();
1956 let executor = WebScrapeExecutor::new(&config);
1957 assert!(executor.audit_logger.is_none());
1958
1959 let call = crate::executor::ToolCall {
1960 tool_id: "fetch".to_owned(),
1961 params: {
1962 let mut m = serde_json::Map::new();
1963 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1964 m
1965 },
1966 };
1967 let result = executor.execute_tool_call(&call).await;
1969 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1970 }
1971
1972 #[tokio::test]
1974 async fn fetch_execute_tool_call_end_to_end() {
1975 use wiremock::matchers::{method, path};
1976 use wiremock::{Mock, ResponseTemplate};
1977
1978 let (executor, server) = mock_server_executor().await;
1979 Mock::given(method("GET"))
1980 .and(path("/e2e"))
1981 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
1982 .mount(&server)
1983 .await;
1984
1985 let (host, addrs) = server_host_and_addr(&server);
1986 let result = executor
1988 .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
1989 .await;
1990 assert!(result.is_ok());
1991 assert!(result.unwrap().contains("end-to-end"));
1992 }
1993}