1use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use url::Url;
11
12use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
13use crate::config::ScrapeConfig;
14use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
15use crate::net::is_private_ip;
16
/// Parameters for the `fetch` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// URL to fetch; must be HTTPS (enforced later by `validate_url`).
    url: String,
}
22
/// One scrape request, deserialized from a ```scrape fenced block or a
/// `web_scrape` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// Page to fetch; must be HTTPS (enforced later by `validate_url`).
    url: String,
    /// CSS selector to run against the fetched HTML.
    select: String,
    /// Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    /// Maximum number of matches to return; `None` falls back to 10.
    limit: Option<usize>,
}
35
/// Serde default for `ScrapeInstruction::extract`: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
39
/// How a matched HTML element is converted into output text.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parse an extraction-mode string: "text", "html", or "attr:<name>".
    /// Any unrecognized value falls back to `Text`.
    fn parse(s: &str) -> Self {
        // `strip_prefix` both tests and removes the prefix, eliminating the
        // previous dead `unwrap_or` fallback behind a `starts_with` guard.
        if let Some(name) = s.strip_prefix("attr:") {
            return Self::Attr(name.to_owned());
        }
        match s {
            "html" => Self::Html,
            // "text" and unknown values both mean text extraction.
            _ => Self::Text,
        }
    }
}
59
/// Tool executor that fetches HTTPS pages and extracts data via CSS
/// selectors, with SSRF screening (see `validate_url` /
/// `resolve_and_validate`).
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Per-request timeout applied to the HTTP client.
    timeout: Duration,
    /// Maximum accepted response body size, in bytes.
    max_body_bytes: usize,
    /// Optional audit sink; when `None`, no audit entries are written.
    audit_logger: Option<Arc<AuditLogger>>,
}
70
impl WebScrapeExecutor {
    /// Build an executor from configuration (timeout in seconds, max body
    /// bytes). Audit logging is disabled until `with_audit` is called.
    #[must_use]
    pub fn new(config: &ScrapeConfig) -> Self {
        Self {
            timeout: Duration::from_secs(config.timeout),
            max_body_bytes: config.max_body_bytes,
            audit_logger: None,
        }
    }

    /// Attach a shared audit logger (builder-style).
    #[must_use]
    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
        self.audit_logger = Some(logger);
        self
    }

    /// Build an HTTP client that resolves `host` only to the pre-validated
    /// `addrs` (DNS pinning) and never follows redirects itself — redirect
    /// hops are handled manually in `fetch_html` so each target can be
    /// re-validated.
    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
        let mut builder = reqwest::Client::builder()
            .timeout(self.timeout)
            .redirect(reqwest::redirect::Policy::none());
        builder = builder.resolve_to_addrs(host, addrs);
        // NOTE(review): on builder failure this silently falls back to a
        // default client, which follows redirects and has no address pinning
        // — consider surfacing the error instead of downgrading.
        builder.build().unwrap_or_default()
    }
}
95
impl ToolExecutor for WebScrapeExecutor {
    /// Advertise both tools: `web_scrape` (invoked via ```scrape fenced
    /// blocks) and `fetch` (invoked as a structured tool call).
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures\nExample: {\"url\": \"https://example.com\", \"select\": \"h1\", \"extract\": \"text\"}".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures\nExample: {\"url\": \"https://api.example.com/data.json\"}".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
            },
        ]
    }

    /// Run every ```scrape block found in `response`, sequentially; returns
    /// `Ok(None)` when there are none. Aborts on the first failing block
    /// (earlier successes are discarded). Every attempt — success or failure
    /// — is audit-logged with its duration.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // Malformed JSON inside a block is an execution error (not Blocked).
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let start = Instant::now();
            let scrape_result = self.scrape_instruction(&instruction).await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    // Log the failure, then propagate it unchanged.
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
                        .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            // NOTE(review): output name "web-scrape" differs from the tool id
            // "web_scrape" above — confirm whether this mismatch is intended.
            tool_name: "web-scrape".to_owned(),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }))
    }

    /// Dispatch a structured tool call to `web_scrape` or `fetch`; unknown
    /// tool ids return `Ok(None)`. Each invocation is audit-logged with its
    /// duration on both the success and failure paths.
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.scrape_instruction(&instruction).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: "web-scrape".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("web_scrape", &instruction.url, audit_result, duration_ms)
                            .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let start = Instant::now();
                let result = self.handle_fetch(&p).await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit("fetch", &p.url, AuditResult::Success, duration_ms)
                            .await;
                        Ok(Some(ToolOutput {
                            tool_name: "fetch".to_owned(),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit("fetch", &p.url, audit_result, duration_ms)
                            .await;
                        Err(e)
                    }
                }
            }
            // Not our tool — let another executor handle it.
            _ => Ok(None),
        }
    }

    /// Both tools are idempotent reads, so retrying them is safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
244
245fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
246 match e {
247 ToolError::Blocked { command } => AuditResult::Blocked {
248 reason: command.clone(),
249 },
250 ToolError::Timeout { .. } => AuditResult::Timeout,
251 _ => AuditResult::Error {
252 message: e.to_string(),
253 },
254 }
255}
256
impl WebScrapeExecutor {
    /// Write an audit entry if a logger is attached; otherwise a no-op.
    async fn log_audit(&self, tool: &str, command: &str, result: AuditResult, duration_ms: u64) {
        if let Some(ref logger) = self.audit_logger {
            let entry = AuditEntry {
                timestamp: chrono_now(),
                tool: tool.into(),
                command: command.into(),
                result,
                duration_ms,
            };
            logger.log(&entry).await;
        }
    }

    /// `fetch` tool: validate the URL, resolve and screen its addresses,
    /// then download the body as text.
    async fn handle_fetch(&self, params: &FetchParams) -> Result<String, ToolError> {
        let parsed = validate_url(&params.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        self.fetch_html(&params.url, &host, &addrs).await
    }

    /// `web_scrape` tool: fetch the page, then run the CSS selector and
    /// extraction on a blocking thread (HTML parsing is CPU-bound).
    async fn scrape_instruction(
        &self,
        instruction: &ScrapeInstruction,
    ) -> Result<String, ToolError> {
        let parsed = validate_url(&instruction.url)?;
        let (host, addrs) = resolve_and_validate(&parsed).await?;
        let html = self.fetch_html(&instruction.url, &host, &addrs).await?;
        let selector = instruction.select.clone();
        let extract = ExtractMode::parse(&instruction.extract);
        // Default cap of 10 results when the instruction gives no limit.
        let limit = instruction.limit.unwrap_or(10);
        // The `?` unwraps the JoinError layer; the tail expression is the
        // inner `Result` from `parse_and_extract`.
        tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
            .await
            .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
    }

    /// Download `url`, following at most 3 redirects manually so that every
    /// redirect target is re-validated (`validate_url`) and re-resolved
    /// (`resolve_and_validate`) before being contacted. `host`/`addrs` are
    /// the pre-screened DNS results for the initial URL, pinned into the
    /// client to prevent DNS rebinding between validation and connection.
    async fn fetch_html(
        &self,
        url: &str,
        host: &str,
        addrs: &[SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;

        let mut current_url = url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = self.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                // A redirect on the final permitted hop is an error.
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                // Resolve relative Location values against the current URL.
                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                // Re-apply the full SSRF screening to the redirect target.
                let validated = validate_url(next_url.as_str())?;
                let (next_host, next_addrs) = resolve_and_validate(&validated).await?;

                current_url = next_url.to_string();
                current_host = next_host;
                current_addrs = next_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "HTTP {status}",
                ))));
            }

            // NOTE(review): the whole body is buffered before the size check,
            // so an oversized response still costs its full size in memory.
            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > self.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    self.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        // Unreachable: the loop always returns or errors by the last hop.
        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
}
383
/// Return the contents of every ```scrape fenced code block in `text`;
/// each slice is the raw text between the fences.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
387
388fn validate_url(raw: &str) -> Result<Url, ToolError> {
389 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
390 command: format!("invalid URL: {raw}"),
391 })?;
392
393 if parsed.scheme() != "https" {
394 return Err(ToolError::Blocked {
395 command: format!("scheme not allowed: {}", parsed.scheme()),
396 });
397 }
398
399 if let Some(host) = parsed.host()
400 && is_private_host(&host)
401 {
402 return Err(ToolError::Blocked {
403 command: format!(
404 "private/local host blocked: {}",
405 parsed.host_str().unwrap_or("")
406 ),
407 });
408 }
409
410 Ok(parsed)
411}
412
413fn is_private_host(host: &url::Host<&str>) -> bool {
414 match host {
415 url::Host::Domain(d) => {
416 #[allow(clippy::case_sensitive_file_extension_comparisons)]
419 {
420 *d == "localhost"
421 || d.ends_with(".localhost")
422 || d.ends_with(".internal")
423 || d.ends_with(".local")
424 }
425 }
426 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
427 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
428 }
429}
430
431async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
437 let Some(host) = url.host_str() else {
438 return Ok((String::new(), vec![]));
439 };
440 let port = url.port_or_known_default().unwrap_or(443);
441 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
442 .await
443 .map_err(|e| ToolError::Blocked {
444 command: format!("DNS resolution failed: {e}"),
445 })?
446 .collect();
447 for addr in &addrs {
448 if is_private_ip(addr.ip()) {
449 return Err(ToolError::Blocked {
450 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
451 });
452 }
453 }
454 Ok((host.to_owned(), addrs))
455}
456
457fn parse_and_extract(
458 html: &str,
459 selector: &str,
460 extract: &ExtractMode,
461 limit: usize,
462) -> Result<String, ToolError> {
463 let soup = scrape_core::Soup::parse(html);
464
465 let tags = soup.find_all(selector).map_err(|e| {
466 ToolError::Execution(std::io::Error::new(
467 std::io::ErrorKind::InvalidData,
468 format!("invalid selector: {e}"),
469 ))
470 })?;
471
472 let mut results = Vec::new();
473
474 for tag in tags.into_iter().take(limit) {
475 let value = match extract {
476 ExtractMode::Text => tag.text(),
477 ExtractMode::Html => tag.inner_html(),
478 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
479 };
480 if !value.trim().is_empty() {
481 results.push(value.trim().to_owned());
482 }
483 }
484
485 if results.is_empty() {
486 Ok(format!("No results for selector: {selector}"))
487 } else {
488 Ok(results.join("\n"))
489 }
490}
491
492#[cfg(test)]
493mod tests {
494 use super::*;
495
    // --- extract_scrape_blocks: fenced ```scrape block extraction ---

    #[test]
    fn extract_single_block() {
        let text =
            "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("example.com"));
    }

    #[test]
    fn extract_multiple_blocks() {
        let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 2);
    }

    #[test]
    fn no_blocks_returns_empty() {
        let blocks = extract_scrape_blocks("plain text, no code blocks");
        assert!(blocks.is_empty());
    }

    // A fence that is never closed must not produce a block.
    #[test]
    fn unclosed_block_ignored() {
        let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
        assert!(blocks.is_empty());
    }

    #[test]
    fn non_scrape_block_ignored() {
        let text =
            "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].contains("x.com"));
    }

    // Block contents may span multiple lines and still parse as JSON.
    #[test]
    fn multiline_json_block() {
        let text =
            "```scrape\n{\n  \"url\": \"https://example.com\",\n  \"select\": \"h1\"\n}\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
        assert_eq!(instr.url, "https://example.com");
    }
544
    // --- ScrapeInstruction deserialization ---

    #[test]
    fn parse_valid_instruction() {
        let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.url, "https://example.com");
        assert_eq!(instr.select, "h1");
        assert_eq!(instr.extract, "text");
        assert_eq!(instr.limit, Some(5));
    }

    // `extract` serde-defaults to "text"; `limit` stays unset.
    #[test]
    fn parse_minimal_instruction() {
        let json = r#"{"url":"https://example.com","select":"p"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "text");
        assert!(instr.limit.is_none());
    }

    #[test]
    fn parse_attr_extract() {
        let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
        let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
        assert_eq!(instr.extract, "attr:href");
    }

    #[test]
    fn parse_invalid_json_errors() {
        let result = serde_json::from_str::<ScrapeInstruction>("not json");
        assert!(result.is_err());
    }
577
    // --- ExtractMode::parse ---

    #[test]
    fn extract_mode_text() {
        assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
    }

    #[test]
    fn extract_mode_html() {
        assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
    }

    #[test]
    fn extract_mode_attr() {
        let mode = ExtractMode::parse("attr:href");
        assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
    }

    #[test]
    fn extract_mode_unknown_defaults_to_text() {
        assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
    }
600
    // --- validate_url: scheme and private-host screening ---

    #[test]
    fn valid_https_url() {
        assert!(validate_url("https://example.com").is_ok());
    }

    #[test]
    fn http_rejected() {
        let err = validate_url("http://example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ftp_rejected() {
        let err = validate_url("ftp://files.example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn file_rejected() {
        let err = validate_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn invalid_url_rejected() {
        let err = validate_url("not a url").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn localhost_blocked() {
        let err = validate_url("https://localhost/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn loopback_ip_blocked() {
        let err = validate_url("https://127.0.0.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_10_blocked() {
        let err = validate_url("https://10.0.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_172_blocked() {
        let err = validate_url("https://172.16.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn private_192_blocked() {
        let err = validate_url("https://192.168.1.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_loopback_blocked() {
        let err = validate_url("https://[::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn public_ip_allowed() {
        assert!(validate_url("https://93.184.216.34/page").is_ok());
    }
672
    // --- parse_and_extract: selector evaluation on static HTML ---

    #[test]
    fn extract_text_from_html() {
        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "Hello World");
    }

    // Multiple matches are newline-joined, in document order.
    #[test]
    fn extract_multiple_elements() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A\nB\nC");
    }

    #[test]
    fn extract_with_limit() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
        assert_eq!(result, "A\nB");
    }

    #[test]
    fn extract_attr_href() {
        let html = r#"<a href="https://example.com">Link</a>"#;
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert_eq!(result, "https://example.com");
    }

    #[test]
    fn extract_inner_html() {
        let html = "<div><span>inner</span></div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<span>inner</span>"));
    }

    #[test]
    fn no_matches_returns_message() {
        let html = "<html><body><p>text</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    // Whitespace-only matches are dropped from the output.
    #[test]
    fn empty_text_skipped() {
        let html = "<ul><li> </li><li>A</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A");
    }

    #[test]
    fn invalid_selector_errors() {
        let html = "<html><body></body></html>";
        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
        assert!(result.is_err());
    }

    #[test]
    fn empty_html_returns_no_results() {
        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn nested_selector() {
        let html = "<div><span>inner</span></div><span>outer</span>";
        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "inner");
    }

    // A missing attribute yields an empty value, which is then skipped.
    #[test]
    fn attr_missing_returns_empty() {
        let html = r"<a>No href</a>";
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    #[test]
    fn extract_html_mode() {
        let html = "<div><b>bold</b> text</div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<b>bold</b>"));
    }

    #[test]
    fn limit_zero_returns_no_results() {
        let html = "<ul><li>A</li><li>B</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }
766
    // --- validate_url: additional edge cases ---

    #[test]
    fn url_with_port_allowed() {
        assert!(validate_url("https://example.com:8443/path").is_ok());
    }

    #[test]
    fn link_local_ip_blocked() {
        let err = validate_url("https://169.254.1.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn url_no_scheme_rejected() {
        let err = validate_url("example.com/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn unspecified_ipv4_blocked() {
        let err = validate_url("https://0.0.0.0/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn broadcast_ipv4_blocked() {
        let err = validate_url("https://255.255.255.255/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_link_local_blocked() {
        let err = validate_url("https://[fe80::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv6_unique_local_blocked() {
        let err = validate_url("https://[fd12::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    // IPv4-mapped IPv6 forms must not bypass the IPv4 private-range checks.
    #[test]
    fn ipv4_mapped_ipv6_loopback_blocked() {
        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn ipv4_mapped_ipv6_private_blocked() {
        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }
821
    // --- WebScrapeExecutor::execute end-to-end (no real network reached) ---

    #[tokio::test]
    async fn executor_no_blocks_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("plain text").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn executor_invalid_json_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\nnot json\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_blocked_url_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_private_ip_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    // 192.0.2.0/24 (TEST-NET-1) is unroutable, so the connection must fail.
    #[tokio::test]
    async fn executor_unreachable_host_returns_error() {
        let config = ScrapeConfig {
            timeout: 1,
            max_body_bytes: 1_048_576,
        };
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    #[tokio::test]
    async fn executor_localhost_url_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn executor_empty_text_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("").await;
        assert!(result.unwrap().is_none());
    }

    // Execution is sequential and aborts on the first failing block.
    #[tokio::test]
    async fn executor_multiple_blocks_first_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
                        ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(result.is_err());
    }
897
    // --- validate_url / is_private_host unit checks ---

    #[test]
    fn validate_url_empty_string() {
        let err = validate_url("").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_javascript_scheme_blocked() {
        let err = validate_url("javascript:alert(1)").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn validate_url_data_scheme_blocked() {
        let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    #[test]
    fn is_private_host_public_domain_is_false() {
        let host: url::Host<&str> = url::Host::Domain("example.com");
        assert!(!is_private_host(&host));
    }

    #[test]
    fn is_private_host_localhost_is_true() {
        let host: url::Host<&str> = url::Host::Domain("localhost");
        assert!(is_private_host(&host));
    }

    #[test]
    fn is_private_host_ipv6_unspecified_is_true() {
        let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
        assert!(is_private_host(&host));
    }

    // 2001:db8::1 is treated as public by `is_private_ip`.
    #[test]
    fn is_private_host_public_ipv6_is_false() {
        let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
        assert!(!is_private_host(&host));
    }
939
    // Spin up a wiremock server plus an executor with test-friendly limits
    // (constructed directly to bypass `ScrapeConfig`).
    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
        let server = wiremock::MockServer::start().await;
        let executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 1_048_576,
            audit_logger: None,
        };
        (executor, server)
    }

    // Extract the mock server's host and socket address for client pinning.
    fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
        let uri = server.uri();
        let url = Url::parse(&uri).unwrap();
        let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
        let port = url.port().unwrap_or(80);
        let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
        (host, vec![addr])
    }
969
    // Test-local copy of the `fetch_html` redirect loop. Unlike the real
    // implementation it does NOT re-validate or re-resolve redirect targets
    // (see the `let _ = &mut` lines), so redirects can stay on the
    // plain-HTTP mock server.
    async fn follow_redirects_raw(
        executor: &WebScrapeExecutor,
        start_url: &str,
        host: &str,
        addrs: &[std::net::SocketAddr],
    ) -> Result<String, ToolError> {
        const MAX_REDIRECTS: usize = 3;
        let mut current_url = start_url.to_owned();
        let mut current_host = host.to_owned();
        let mut current_addrs = addrs.to_vec();

        for hop in 0..=MAX_REDIRECTS {
            let client = executor.build_client(&current_host, &current_addrs);
            let resp = client
                .get(&current_url)
                .send()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            let status = resp.status();

            if status.is_redirection() {
                // A redirect on the final permitted hop is an error.
                if hop == MAX_REDIRECTS {
                    return Err(ToolError::Execution(std::io::Error::other(
                        "too many redirects",
                    )));
                }

                let location = resp
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        ToolError::Execution(std::io::Error::other("redirect with no Location"))
                    })?;

                let base = Url::parse(&current_url)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
                let next_url = base
                    .join(location)
                    .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

                current_url = next_url.to_string();
                // Host/address pinning deliberately left unchanged (see above).
                let _ = &mut current_host;
                let _ = &mut current_addrs;
                continue;
            }

            if !status.is_success() {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "HTTP {status}",
                ))));
            }

            let bytes = resp
                .bytes()
                .await
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;

            if bytes.len() > executor.max_body_bytes {
                return Err(ToolError::Execution(std::io::Error::other(format!(
                    "response too large: {} bytes (max: {})",
                    bytes.len(),
                    executor.max_body_bytes,
                ))));
            }

            return String::from_utf8(bytes.to_vec())
                .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
        }

        Err(ToolError::Execution(std::io::Error::other(
            "too many redirects",
        )))
    }
1050
1051 #[tokio::test]
1052 async fn fetch_html_success_returns_body() {
1053 use wiremock::matchers::{method, path};
1054 use wiremock::{Mock, ResponseTemplate};
1055
1056 let (executor, server) = mock_server_executor().await;
1057 Mock::given(method("GET"))
1058 .and(path("/page"))
1059 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1060 .mount(&server)
1061 .await;
1062
1063 let (host, addrs) = server_host_and_addr(&server);
1064 let url = format!("{}/page", server.uri());
1065 let result = executor.fetch_html(&url, &host, &addrs).await;
1066 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1067 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1068 }
1069
1070 #[tokio::test]
1071 async fn fetch_html_non_2xx_returns_error() {
1072 use wiremock::matchers::{method, path};
1073 use wiremock::{Mock, ResponseTemplate};
1074
1075 let (executor, server) = mock_server_executor().await;
1076 Mock::given(method("GET"))
1077 .and(path("/forbidden"))
1078 .respond_with(ResponseTemplate::new(403))
1079 .mount(&server)
1080 .await;
1081
1082 let (host, addrs) = server_host_and_addr(&server);
1083 let url = format!("{}/forbidden", server.uri());
1084 let result = executor.fetch_html(&url, &host, &addrs).await;
1085 assert!(result.is_err());
1086 let msg = result.unwrap_err().to_string();
1087 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1088 }
1089
    // 404 is an error too — fetch_html does not treat "missing" as empty content.
    #[tokio::test]
    async fn fetch_html_404_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/missing", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("404"), "expected 404 in error: {msg}");
    }
1109
    // A 302 without a Location header cannot be followed and must error;
    // the error message is expected to mention the missing header.
    #[tokio::test]
    async fn fetch_html_redirect_no_location_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/redirect-no-loc"))
            .respond_with(ResponseTemplate::new(302))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/redirect-no-loc", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Location") || msg.contains("location"),
            "expected Location-related error: {msg}"
        );
    }
1133
    // One 302 hop to an absolute Location on the same mock server is followed
    // and the final body is returned.
    #[tokio::test]
    async fn fetch_html_single_redirect_followed() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let final_url = format!("{}/final", server.uri());

        Mock::given(method("GET"))
            .and(path("/start"))
            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/final"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/start", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
        assert_eq!(result.unwrap(), "<p>final</p>");
    }
1160
    // Exactly three 301 hops are within the redirect budget (boundary case:
    // three is allowed, four — see the next test — is not).
    #[tokio::test]
    async fn fetch_html_three_redirects_allowed() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let hop2 = format!("{}/hop2", server.uri());
        let hop3 = format!("{}/hop3", server.uri());
        let final_dest = format!("{}/done", server.uri());

        Mock::given(method("GET"))
            .and(path("/hop1"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/hop2"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/hop3"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/done"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/hop1", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
        assert_eq!(result.unwrap(), "<p>done</p>");
    }
1198
    // Four hops exceed the redirect budget; /r5 is intentionally never mounted
    // because the chain must be aborted before reaching it.
    #[tokio::test]
    async fn fetch_html_four_redirects_rejected() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let hop2 = format!("{}/r2", server.uri());
        let hop3 = format!("{}/r3", server.uri());
        let hop4 = format!("{}/r4", server.uri());
        let hop5 = format!("{}/r5", server.uri());

        for (from, to) in [
            ("/r1", &hop2),
            ("/r2", &hop3),
            ("/r3", &hop4),
            ("/r4", &hop5),
        ] {
            Mock::given(method("GET"))
                .and(path(from))
                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
                .mount(&server)
                .await;
        }

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/r1", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_err(), "4 redirects should be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("redirect"),
            "expected redirect-related error: {msg}"
        );
    }
1233
    // Bodies over max_body_bytes are rejected; the executor is constructed
    // literally here (not via new()) to set a tiny 10-byte limit.
    #[tokio::test]
    async fn fetch_html_body_too_large_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let small_limit_executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 10,
            audit_logger: None,
        };
        let server = wiremock::MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/big"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("this body is definitely longer than ten bytes"),
            )
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/big", server.uri());
        let result = small_limit_executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "expected too-large error: {msg}");
    }
1261
    // An empty ```scrape fence still yields exactly one (empty) block.
    #[test]
    fn extract_scrape_blocks_empty_block_content() {
        let text = "```scrape\n\n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
        assert!(blocks[0].is_empty());
    }
1269
    // NOTE(review): this only checks the block count; consider also asserting
    // the block's content (raw vs trimmed whitespace) once that contract is
    // settled — as written, a change in trimming behavior would go unnoticed.
    #[test]
    fn extract_scrape_blocks_whitespace_only() {
        let text = "```scrape\n \n```";
        let blocks = extract_scrape_blocks(text);
        assert_eq!(blocks.len(), 1);
    }
1276
    // A comma-separated selector list matches elements from every branch.
    #[test]
    fn parse_and_extract_multiple_selectors() {
        let html = "<div><h1>Title</h1><p>Para</p></div>";
        let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
        assert!(result.contains("Title"));
        assert!(result.contains("Para"));
    }
1284
1285 #[test]
1286 fn webscrape_executor_new_with_custom_config() {
1287 let config = ScrapeConfig {
1288 timeout: 60,
1289 max_body_bytes: 512,
1290 };
1291 let executor = WebScrapeExecutor::new(&config);
1292 assert_eq!(executor.max_body_bytes, 512);
1293 }
1294
1295 #[test]
1296 fn webscrape_executor_debug() {
1297 let config = ScrapeConfig::default();
1298 let executor = WebScrapeExecutor::new(&config);
1299 let dbg = format!("{executor:?}");
1300 assert!(dbg.contains("WebScrapeExecutor"));
1301 }
1302
1303 #[test]
1304 fn extract_mode_attr_empty_name() {
1305 let mode = ExtractMode::parse("attr:");
1306 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1307 }
1308
1309 #[test]
1310 fn default_extract_returns_text() {
1311 assert_eq!(default_extract(), "text");
1312 }
1313
1314 #[test]
1315 fn scrape_instruction_debug() {
1316 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1317 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1318 let dbg = format!("{instr:?}");
1319 assert!(dbg.contains("ScrapeInstruction"));
1320 }
1321
1322 #[test]
1323 fn extract_mode_debug() {
1324 let mode = ExtractMode::Text;
1325 let dbg = format!("{mode:?}");
1326 assert!(dbg.contains("Text"));
1327 }
1328
    // NOTE(review): this re-declares MAX_REDIRECTS locally, so it documents the
    // intended limit but cannot catch drift in fetch_html itself; consider
    // exposing the real constant (pub(crate)) and asserting against that.
    #[test]
    fn max_redirects_constant_is_three() {
        const MAX_REDIRECTS: usize = 3;
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }
1341
1342 #[test]
1345 fn redirect_no_location_error_message() {
1346 let err = std::io::Error::other("redirect with no Location");
1347 assert!(err.to_string().contains("redirect with no Location"));
1348 }
1349
1350 #[test]
1352 fn too_many_redirects_error_message() {
1353 let err = std::io::Error::other("too many redirects");
1354 assert!(err.to_string().contains("too many redirects"));
1355 }
1356
1357 #[test]
1359 fn non_2xx_status_error_format() {
1360 let status = reqwest::StatusCode::FORBIDDEN;
1361 let msg = format!("HTTP {status}");
1362 assert!(msg.contains("403"));
1363 }
1364
1365 #[test]
1367 fn not_found_status_error_format() {
1368 let status = reqwest::StatusCode::NOT_FOUND;
1369 let msg = format!("HTTP {status}");
1370 assert!(msg.contains("404"));
1371 }
1372
1373 #[test]
1375 fn relative_redirect_same_host_path() {
1376 let base = Url::parse("https://example.com/current").unwrap();
1377 let resolved = base.join("/other").unwrap();
1378 assert_eq!(resolved.as_str(), "https://example.com/other");
1379 }
1380
1381 #[test]
1383 fn relative_redirect_relative_path() {
1384 let base = Url::parse("https://example.com/a/b").unwrap();
1385 let resolved = base.join("c").unwrap();
1386 assert_eq!(resolved.as_str(), "https://example.com/a/c");
1387 }
1388
1389 #[test]
1391 fn absolute_redirect_overrides_base() {
1392 let base = Url::parse("https://example.com/page").unwrap();
1393 let resolved = base.join("https://other.com/target").unwrap();
1394 assert_eq!(resolved.as_str(), "https://other.com/target");
1395 }
1396
1397 #[test]
1399 fn redirect_http_downgrade_rejected() {
1400 let location = "http://example.com/page";
1401 let base = Url::parse("https://example.com/start").unwrap();
1402 let next = base.join(location).unwrap();
1403 let err = validate_url(next.as_str()).unwrap_err();
1404 assert!(matches!(err, ToolError::Blocked { .. }));
1405 }
1406
1407 #[test]
1409 fn redirect_location_private_ip_blocked() {
1410 let location = "https://192.168.100.1/admin";
1411 let base = Url::parse("https://example.com/start").unwrap();
1412 let next = base.join(location).unwrap();
1413 let err = validate_url(next.as_str()).unwrap_err();
1414 assert!(matches!(err, ToolError::Blocked { .. }));
1415 let ToolError::Blocked { command: cmd } = err else {
1416 panic!("expected Blocked");
1417 };
1418 assert!(
1419 cmd.contains("private") || cmd.contains("scheme"),
1420 "error message should describe the block reason: {cmd}"
1421 );
1422 }
1423
1424 #[test]
1426 fn redirect_location_internal_domain_blocked() {
1427 let location = "https://metadata.internal/latest/meta-data/";
1428 let base = Url::parse("https://example.com/start").unwrap();
1429 let next = base.join(location).unwrap();
1430 let err = validate_url(next.as_str()).unwrap_err();
1431 assert!(matches!(err, ToolError::Blocked { .. }));
1432 }
1433
1434 #[test]
1436 fn redirect_chain_three_hops_all_public() {
1437 let hops = [
1438 "https://redirect1.example.com/hop1",
1439 "https://redirect2.example.com/hop2",
1440 "https://destination.example.com/final",
1441 ];
1442 for hop in hops {
1443 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
1444 }
1445 }
1446
1447 #[test]
1452 fn redirect_to_private_ip_rejected_by_validate_url() {
1453 let private_targets = [
1455 "https://127.0.0.1/secret",
1456 "https://10.0.0.1/internal",
1457 "https://192.168.1.1/admin",
1458 "https://172.16.0.1/data",
1459 "https://[::1]/path",
1460 "https://[fe80::1]/path",
1461 "https://localhost/path",
1462 "https://service.internal/api",
1463 ];
1464 for target in private_targets {
1465 let result = validate_url(target);
1466 assert!(result.is_err(), "expected error for {target}");
1467 assert!(
1468 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
1469 "expected Blocked for {target}"
1470 );
1471 }
1472 }
1473
1474 #[test]
1476 fn redirect_relative_url_resolves_correctly() {
1477 let base = Url::parse("https://example.com/page").unwrap();
1478 let relative = "/other";
1479 let resolved = base.join(relative).unwrap();
1480 assert_eq!(resolved.as_str(), "https://example.com/other");
1481 }
1482
1483 #[test]
1485 fn redirect_to_http_rejected() {
1486 let err = validate_url("http://example.com/page").unwrap_err();
1487 assert!(matches!(err, ToolError::Blocked { .. }));
1488 }
1489
1490 #[test]
1491 fn ipv4_mapped_ipv6_link_local_blocked() {
1492 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
1493 assert!(matches!(err, ToolError::Blocked { .. }));
1494 }
1495
1496 #[test]
1497 fn ipv4_mapped_ipv6_public_allowed() {
1498 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
1499 }
1500
1501 #[tokio::test]
1504 async fn fetch_http_scheme_blocked() {
1505 let config = ScrapeConfig::default();
1506 let executor = WebScrapeExecutor::new(&config);
1507 let call = crate::executor::ToolCall {
1508 tool_id: "fetch".to_owned(),
1509 params: {
1510 let mut m = serde_json::Map::new();
1511 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1512 m
1513 },
1514 };
1515 let result = executor.execute_tool_call(&call).await;
1516 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1517 }
1518
1519 #[tokio::test]
1520 async fn fetch_private_ip_blocked() {
1521 let config = ScrapeConfig::default();
1522 let executor = WebScrapeExecutor::new(&config);
1523 let call = crate::executor::ToolCall {
1524 tool_id: "fetch".to_owned(),
1525 params: {
1526 let mut m = serde_json::Map::new();
1527 m.insert(
1528 "url".to_owned(),
1529 serde_json::json!("https://192.168.1.1/secret"),
1530 );
1531 m
1532 },
1533 };
1534 let result = executor.execute_tool_call(&call).await;
1535 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1536 }
1537
1538 #[tokio::test]
1539 async fn fetch_localhost_blocked() {
1540 let config = ScrapeConfig::default();
1541 let executor = WebScrapeExecutor::new(&config);
1542 let call = crate::executor::ToolCall {
1543 tool_id: "fetch".to_owned(),
1544 params: {
1545 let mut m = serde_json::Map::new();
1546 m.insert(
1547 "url".to_owned(),
1548 serde_json::json!("https://localhost/page"),
1549 );
1550 m
1551 },
1552 };
1553 let result = executor.execute_tool_call(&call).await;
1554 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1555 }
1556
1557 #[tokio::test]
1558 async fn fetch_unknown_tool_returns_none() {
1559 let config = ScrapeConfig::default();
1560 let executor = WebScrapeExecutor::new(&config);
1561 let call = crate::executor::ToolCall {
1562 tool_id: "unknown_tool".to_owned(),
1563 params: serde_json::Map::new(),
1564 };
1565 let result = executor.execute_tool_call(&call).await;
1566 assert!(result.unwrap().is_none());
1567 }
1568
    // fetch_html passes non-HTML bodies through unchanged.
    #[tokio::test]
    async fn fetch_returns_body_via_mock() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/content"))
            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/content", server.uri());
        let result = executor.fetch_html(&url, &host, &addrs).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "plain text content");
    }
1587
1588 #[test]
1589 fn tool_definitions_returns_web_scrape_and_fetch() {
1590 let config = ScrapeConfig::default();
1591 let executor = WebScrapeExecutor::new(&config);
1592 let defs = executor.tool_definitions();
1593 assert_eq!(defs.len(), 2);
1594 assert_eq!(defs[0].id, "web_scrape");
1595 assert_eq!(
1596 defs[0].invocation,
1597 crate::registry::InvocationHint::FencedBlock("scrape")
1598 );
1599 assert_eq!(defs[1].id, "fetch");
1600 assert_eq!(
1601 defs[1].invocation,
1602 crate::registry::InvocationHint::ToolCall
1603 );
1604 }
1605
1606 #[test]
1607 fn tool_definitions_schema_has_all_params() {
1608 let config = ScrapeConfig::default();
1609 let executor = WebScrapeExecutor::new(&config);
1610 let defs = executor.tool_definitions();
1611 let obj = defs[0].schema.as_object().unwrap();
1612 let props = obj["properties"].as_object().unwrap();
1613 assert!(props.contains_key("url"));
1614 assert!(props.contains_key("select"));
1615 assert!(props.contains_key("extract"));
1616 assert!(props.contains_key("limit"));
1617 let req = obj["required"].as_array().unwrap();
1618 assert!(req.iter().any(|v| v.as_str() == Some("url")));
1619 assert!(req.iter().any(|v| v.as_str() == Some("select")));
1620 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
1621 }
1622
1623 #[test]
1626 fn subdomain_localhost_blocked() {
1627 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
1628 assert!(is_private_host(&host));
1629 }
1630
1631 #[test]
1632 fn internal_tld_blocked() {
1633 let host: url::Host<&str> = url::Host::Domain("service.internal");
1634 assert!(is_private_host(&host));
1635 }
1636
1637 #[test]
1638 fn local_tld_blocked() {
1639 let host: url::Host<&str> = url::Host::Domain("printer.local");
1640 assert!(is_private_host(&host));
1641 }
1642
1643 #[test]
1644 fn public_domain_not_blocked() {
1645 let host: url::Host<&str> = url::Host::Domain("example.com");
1646 assert!(!is_private_host(&host));
1647 }
1648
1649 #[tokio::test]
1652 async fn resolve_loopback_rejected() {
1653 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
1655 let result = resolve_and_validate(&url).await;
1657 assert!(
1658 result.is_err(),
1659 "loopback IP must be rejected by resolve_and_validate"
1660 );
1661 let err = result.unwrap_err();
1662 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
1663 }
1664
1665 #[tokio::test]
1666 async fn resolve_private_10_rejected() {
1667 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
1668 let result = resolve_and_validate(&url).await;
1669 assert!(result.is_err());
1670 assert!(matches!(
1671 result.unwrap_err(),
1672 crate::executor::ToolError::Blocked { .. }
1673 ));
1674 }
1675
1676 #[tokio::test]
1677 async fn resolve_private_192_rejected() {
1678 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
1679 let result = resolve_and_validate(&url).await;
1680 assert!(result.is_err());
1681 assert!(matches!(
1682 result.unwrap_err(),
1683 crate::executor::ToolError::Blocked { .. }
1684 ));
1685 }
1686
1687 #[tokio::test]
1688 async fn resolve_ipv6_loopback_rejected() {
1689 let url = url::Url::parse("https://[::1]/path").unwrap();
1690 let result = resolve_and_validate(&url).await;
1691 assert!(result.is_err());
1692 assert!(matches!(
1693 result.unwrap_err(),
1694 crate::executor::ToolError::Blocked { .. }
1695 ));
1696 }
1697
1698 #[tokio::test]
1699 async fn resolve_no_host_returns_ok() {
1700 let url = url::Url::parse("https://example.com/path").unwrap();
1702 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
1704 let result = resolve_and_validate(&url_no_host).await;
1706 assert!(result.is_ok());
1707 let (host, addrs) = result.unwrap();
1708 assert!(host.is_empty());
1709 assert!(addrs.is_empty());
1710 drop(url);
1711 drop(url_no_host);
1712 }
1713
    /// Test helper: builds a file-backed AuditLogger writing into `dir`,
    /// returning the logger and the path of the log file it writes to.
    async fn make_file_audit_logger(
        dir: &tempfile::TempDir,
    ) -> (
        std::sync::Arc<crate::audit::AuditLogger>,
        std::path::PathBuf,
    ) {
        use crate::audit::AuditLogger;
        use crate::config::AuditConfig;
        let path = dir.path().join("audit.log");
        let config = AuditConfig {
            enabled: true,
            destination: path.display().to_string(),
        };
        let logger = std::sync::Arc::new(AuditLogger::from_config(&config).await.unwrap());
        (logger, path)
    }
1733
1734 #[tokio::test]
1735 async fn with_audit_sets_logger() {
1736 let config = ScrapeConfig::default();
1737 let executor = WebScrapeExecutor::new(&config);
1738 assert!(executor.audit_logger.is_none());
1739
1740 let dir = tempfile::tempdir().unwrap();
1741 let (logger, _path) = make_file_audit_logger(&dir).await;
1742 let executor = executor.with_audit(logger);
1743 assert!(executor.audit_logger.is_some());
1744 }
1745
1746 #[test]
1747 fn tool_error_to_audit_result_blocked_maps_correctly() {
1748 let err = ToolError::Blocked {
1749 command: "scheme not allowed: http".into(),
1750 };
1751 let result = tool_error_to_audit_result(&err);
1752 assert!(
1753 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
1754 );
1755 }
1756
1757 #[test]
1758 fn tool_error_to_audit_result_timeout_maps_correctly() {
1759 let err = ToolError::Timeout { timeout_secs: 15 };
1760 let result = tool_error_to_audit_result(&err);
1761 assert!(matches!(result, AuditResult::Timeout));
1762 }
1763
1764 #[test]
1765 fn tool_error_to_audit_result_execution_error_maps_correctly() {
1766 let err = ToolError::Execution(std::io::Error::other("connection refused"));
1767 let result = tool_error_to_audit_result(&err);
1768 assert!(
1769 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
1770 );
1771 }
1772
    // A blocked fetch must be recorded in the audit file with the tool name,
    // the "blocked" result type, and the offending URL.
    #[tokio::test]
    async fn fetch_audit_blocked_url_logged() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        let call = crate::executor::ToolCall {
            tool_id: "fetch".to_owned(),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m
            },
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"fetch\""),
            "expected tool=fetch in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
        assert!(
            content.contains("http://example.com"),
            "expected URL in audit command field: {content}"
        );
    }
1806
    // log_audit with a Success result writes a JSON line carrying tool name,
    // result type, the command (URL), and the duration in ms.
    #[tokio::test]
    async fn log_audit_success_writes_to_file() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        executor
            .log_audit(
                "fetch",
                "https://example.com/page",
                AuditResult::Success,
                42,
            )
            .await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"fetch\""),
            "expected tool=fetch in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"success\""),
            "expected type=success in audit: {content}"
        );
        assert!(
            content.contains("\"command\":\"https://example.com/page\""),
            "expected command URL in audit: {content}"
        );
        assert!(
            content.contains("\"duration_ms\":42),
            "expected duration_ms in audit: {content}"
        );
    }
1842
    // log_audit with a Blocked result writes the tool name, "blocked" type,
    // and the block reason to the audit file.
    #[tokio::test]
    async fn log_audit_blocked_writes_to_file() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        executor
            .log_audit(
                "web_scrape",
                "http://evil.com/page",
                AuditResult::Blocked {
                    reason: "scheme not allowed: http".into(),
                },
                0,
            )
            .await;

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"web_scrape\""),
            "expected tool=web_scrape in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
        assert!(
            content.contains("scheme not allowed"),
            "expected block reason in audit: {content}"
        );
    }
1876
    // A blocked web_scrape call (http scheme) is also audited, with the
    // web_scrape tool name and blocked result type.
    #[tokio::test]
    async fn web_scrape_audit_blocked_url_logged() {
        let dir = tempfile::tempdir().unwrap();
        let (logger, log_path) = make_file_audit_logger(&dir).await;

        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config).with_audit(logger);

        let call = crate::executor::ToolCall {
            tool_id: "web_scrape".to_owned(),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m.insert("select".to_owned(), serde_json::json!("h1"));
                m
            },
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(
            content.contains("\"tool\":\"web_scrape\""),
            "expected tool=web_scrape in audit: {content}"
        );
        assert!(
            content.contains("\"type\":\"blocked\""),
            "expected type=blocked in audit: {content}"
        );
    }
1907
1908 #[tokio::test]
1909 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
1910 let config = ScrapeConfig::default();
1911 let executor = WebScrapeExecutor::new(&config);
1912 assert!(executor.audit_logger.is_none());
1913
1914 let call = crate::executor::ToolCall {
1915 tool_id: "fetch".to_owned(),
1916 params: {
1917 let mut m = serde_json::Map::new();
1918 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
1919 m
1920 },
1921 };
1922 let result = executor.execute_tool_call(&call).await;
1924 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1925 }
1926
    // NOTE(review): despite the name, this calls fetch_html directly rather
    // than execute_tool_call — presumably because execute_tool_call's URL
    // validation would block the mock server's loopback address. Confirm, and
    // consider renaming the test to match what it actually covers.
    #[tokio::test]
    async fn fetch_execute_tool_call_end_to_end() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/e2e"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let result = executor
            .fetch_html(&format!("{}/e2e", server.uri()), &host, &addrs)
            .await;
        assert!(result.is_ok());
        assert!(result.unwrap().contains("end-to-end"));
    }
1948}