1use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
24use schemars::JsonSchema;
25use serde::Deserialize;
26use url::Url;
27
28use zeph_common::ToolName;
29
30use zeph_sanitizer::IpiFilter;
31
32use crate::audit::{AuditEntry, AuditLogger, AuditResult, EgressEvent, chrono_now};
33use crate::config::{EgressConfig, ScrapeConfig};
34use crate::executor::{
35 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
36};
37use crate::net::is_private_ip;
38
39fn redact_url_for_log(url: &str) -> String {
43 let Ok(mut parsed) = Url::parse(url) else {
44 return url.to_owned();
45 };
46 let _ = parsed.set_username("");
48 let _ = parsed.set_password(None);
49 let sensitive = [
51 "token", "key", "secret", "password", "auth", "sig", "api_key", "apikey",
52 ];
53 let filtered: Vec<(String, String)> = parsed
54 .query_pairs()
55 .filter(|(k, _)| {
56 let lower = k.to_lowercase();
57 !sensitive.iter().any(|s| lower.contains(s))
58 })
59 .map(|(k, v)| (k.into_owned(), v.into_owned()))
60 .collect();
61 if filtered.is_empty() {
62 parsed.set_query(None);
63 } else {
64 let q: String = filtered
65 .iter()
66 .map(|(k, v)| format!("{k}={v}"))
67 .collect::<Vec<_>>()
68 .join("&");
69 parsed.set_query(Some(&q));
70 }
71 parsed.to_string()
72}
73
// Arguments of the `fetch` tool call.
// Notes here use `//` so they stay out of the JSON schema generated by `JsonSchema`.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // URL to GET; it must pass `validate_url` (HTTPS, no private/local host literal).
    url: String,
}
79
// Arguments of the `web_scrape` tool call; also the JSON body of a ```scrape fenced block.
// Notes here use `//` so they stay out of the JSON schema generated by `JsonSchema`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // Page to fetch; it must pass `validate_url`.
    url: String,
    // CSS selector evaluated against the fetched HTML.
    select: String,
    // "text" (default), "html", or "attr:<name>"; unrecognised values fall back to text
    // (see `ExtractMode::parse`).
    #[serde(default = "default_extract")]
    extract: String,
    // Maximum number of matched elements to return; `scrape_instruction` uses 10 when absent.
    limit: Option<usize>,
}
92
/// Serde default for the `extract` field of a scrape instruction: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
96
/// What to pull out of each element matched by a scrape selector.
#[derive(Debug)]
enum ExtractMode {
    /// The element's text content.
    Text,
    /// The element's inner HTML.
    Html,
    /// The value of the named attribute (from an `attr:<name>` spec).
    Attr(String),
}

impl ExtractMode {
    /// Interprets an `extract` spec.
    ///
    /// `"attr:<name>"` selects an attribute (the name may be empty), `"html"` selects inner
    /// HTML, and everything else — including `"text"` and unknown values — selects text.
    fn parse(s: &str) -> Self {
        if let Some(attr_name) = s.strip_prefix("attr:") {
            Self::Attr(attr_name.to_owned())
        } else if s == "html" {
            Self::Html
        } else {
            Self::Text
        }
    }
}
116
/// Executor for the `web_scrape` and `fetch` tools.
///
/// Issues HTTPS GET requests subject to domain allow/deny lists and private-address (SSRF)
/// checks, optionally records audit entries and egress events, and runs fetched content
/// through an `IpiFilter` before returning it.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    /// Timeout applied to every HTTP client built by this executor.
    timeout: Duration,
    /// Responses whose body exceeds this many bytes are rejected.
    max_body_bytes: usize,
    /// When non-empty, only hosts matching one of these patterns may be fetched.
    allowed_domains: Vec<String>,
    /// Hosts matching any of these patterns are always rejected.
    denied_domains: Vec<String>,
    /// Destination for per-call audit entries; `None` disables audit logging.
    audit_logger: Option<Arc<AuditLogger>>,
    /// Controls whether and which egress events are emitted.
    egress_config: EgressConfig,
    /// Optional channel that receives a copy of each egress event (sent with `try_send`).
    egress_tx: Option<tokio::sync::mpsc::Sender<EgressEvent>>,
    /// Number of egress events dropped because `egress_tx` was full.
    egress_dropped: Arc<AtomicU64>,
    /// Filter applied to fetched/extracted text before it is returned
    /// (IPI presumably = indirect prompt injection — confirm against `zeph_sanitizer`).
    ipi_filter: IpiFilter,
}
170
171impl WebScrapeExecutor {
172 #[must_use]
176 pub fn new(config: &ScrapeConfig) -> Self {
177 Self {
178 timeout: Duration::from_secs(config.timeout),
179 max_body_bytes: config.max_body_bytes,
180 allowed_domains: config.allowed_domains.clone(),
181 denied_domains: config.denied_domains.clone(),
182 audit_logger: None,
183 egress_config: EgressConfig::default(),
184 egress_tx: None,
185 egress_dropped: Arc::new(AtomicU64::new(0)),
186 ipi_filter: IpiFilter::new(config.ipi_filter_threshold),
187 }
188 }
189
190 #[must_use]
192 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
193 self.audit_logger = Some(logger);
194 self
195 }
196
197 #[must_use]
199 pub fn with_egress_config(mut self, config: EgressConfig) -> Self {
200 self.egress_config = config;
201 self
202 }
203
204 #[must_use]
209 pub fn with_egress_tx(
210 mut self,
211 tx: tokio::sync::mpsc::Sender<EgressEvent>,
212 dropped: Arc<AtomicU64>,
213 ) -> Self {
214 self.egress_tx = Some(tx);
215 self.egress_dropped = dropped;
216 self
217 }
218
219 #[must_use]
221 pub fn egress_dropped(&self) -> Arc<AtomicU64> {
222 Arc::clone(&self.egress_dropped)
223 }
224
225 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
226 let mut builder = reqwest::Client::builder()
227 .timeout(self.timeout)
228 .redirect(reqwest::redirect::Policy::none());
229 builder = builder.resolve_to_addrs(host, addrs);
230 builder.build().unwrap_or_default()
231 }
232}
233
234impl ToolExecutor for WebScrapeExecutor {
235 fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
236 use crate::registry::{InvocationHint, ToolDef};
237 vec![
238 ToolDef {
239 id: "web_scrape".into(),
240 description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
241 schema: schemars::schema_for!(ScrapeInstruction),
242 invocation: InvocationHint::FencedBlock("scrape"),
243 output_schema: None,
244 },
245 ToolDef {
246 id: "fetch".into(),
247 description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
248 schema: schemars::schema_for!(FetchParams),
249 invocation: InvocationHint::ToolCall,
250 output_schema: None,
251 },
252 ]
253 }
254
255 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
256 let blocks = extract_scrape_blocks(response);
257 if blocks.is_empty() {
258 return Ok(None);
259 }
260
261 let mut outputs = Vec::with_capacity(blocks.len());
262 #[allow(clippy::cast_possible_truncation)]
263 let blocks_executed = blocks.len() as u32;
264
265 for block in &blocks {
266 let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
267 ToolError::Execution(std::io::Error::new(
268 std::io::ErrorKind::InvalidData,
269 e.to_string(),
270 ))
271 })?;
272 let correlation_id = EgressEvent::new_correlation_id();
273 let start = Instant::now();
274 let scrape_result = self
275 .scrape_instruction(&instruction, &correlation_id, None)
276 .await;
277 #[allow(clippy::cast_possible_truncation)]
278 let duration_ms = start.elapsed().as_millis() as u64;
279 match scrape_result {
280 Ok(output) => {
281 self.log_audit(
282 "web_scrape",
283 &instruction.url,
284 AuditResult::Success,
285 duration_ms,
286 None,
287 None,
288 Some(correlation_id),
289 )
290 .await;
291 outputs.push(output);
292 }
293 Err(e) => {
294 let audit_result = tool_error_to_audit_result(&e);
295 self.log_audit(
296 "web_scrape",
297 &instruction.url,
298 audit_result,
299 duration_ms,
300 Some(&e),
301 None,
302 Some(correlation_id),
303 )
304 .await;
305 return Err(e);
306 }
307 }
308 }
309
310 Ok(Some(ToolOutput {
311 tool_name: ToolName::new("web-scrape"),
312 summary: outputs.join("\n\n"),
313 blocks_executed,
314 filter_stats: None,
315 diff: None,
316 streamed: false,
317 terminal_id: None,
318 locations: None,
319 raw_response: None,
320 claim_source: Some(ClaimSource::WebScrape),
321 }))
322 }
323
324 #[cfg_attr(
325 feature = "profiling",
326 tracing::instrument(name = "tool.web_scrape", skip_all)
327 )]
328 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
329 match call.tool_id.as_str() {
330 "web_scrape" => {
331 let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
332 let correlation_id = EgressEvent::new_correlation_id();
333 let start = Instant::now();
334 let result = self
335 .scrape_instruction(&instruction, &correlation_id, call.caller_id.clone())
336 .await;
337 #[allow(clippy::cast_possible_truncation)]
338 let duration_ms = start.elapsed().as_millis() as u64;
339 self.run_with_audit(
340 "web_scrape",
341 "web-scrape",
342 &instruction.url,
343 call.caller_id.clone(),
344 correlation_id,
345 duration_ms,
346 result,
347 )
348 .await
349 }
350 "fetch" => {
351 let p: FetchParams = deserialize_params(&call.params)?;
352 let correlation_id = EgressEvent::new_correlation_id();
353 let start = Instant::now();
354 let result = self
355 .handle_fetch(&p, &correlation_id, call.caller_id.clone())
356 .await;
357 #[allow(clippy::cast_possible_truncation)]
358 let duration_ms = start.elapsed().as_millis() as u64;
359 self.run_with_audit(
360 "fetch",
361 "fetch",
362 &p.url,
363 call.caller_id.clone(),
364 correlation_id,
365 duration_ms,
366 result,
367 )
368 .await
369 }
370 _ => Ok(None),
371 }
372 }
373
374 fn is_tool_retryable(&self, tool_id: &str) -> bool {
375 matches!(tool_id, "web_scrape" | "fetch")
376 }
377}
378
379fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
380 match e {
381 ToolError::Blocked { command } => AuditResult::Blocked {
382 reason: command.clone(),
383 },
384 ToolError::Timeout { .. } => AuditResult::Timeout,
385 _ => AuditResult::Error {
386 message: e.to_string(),
387 },
388 }
389}
390
391impl WebScrapeExecutor {
392 #[allow(clippy::too_many_arguments)]
393 async fn run_with_audit(
394 &self,
395 audit_tool_name: &str,
396 public_tool_name: &str,
397 audit_command: &str,
398 caller_id: Option<String>,
399 correlation_id: String,
400 duration_ms: u64,
401 result: Result<String, ToolError>,
402 ) -> Result<Option<ToolOutput>, ToolError> {
403 match result {
404 Ok(output) => {
405 self.log_audit(
406 audit_tool_name,
407 audit_command,
408 AuditResult::Success,
409 duration_ms,
410 None,
411 caller_id,
412 Some(correlation_id),
413 )
414 .await;
415 Ok(Some(ToolOutput {
416 tool_name: ToolName::new(public_tool_name),
417 summary: output,
418 blocks_executed: 1,
419 filter_stats: None,
420 diff: None,
421 streamed: false,
422 terminal_id: None,
423 locations: None,
424 raw_response: None,
425 claim_source: Some(ClaimSource::WebScrape),
426 }))
427 }
428 Err(e) => {
429 let audit_result = tool_error_to_audit_result(&e);
430 self.log_audit(
431 audit_tool_name,
432 audit_command,
433 audit_result,
434 duration_ms,
435 Some(&e),
436 caller_id,
437 Some(correlation_id),
438 )
439 .await;
440 Err(e)
441 }
442 }
443 }
444
445 #[allow(clippy::too_many_arguments)] async fn log_audit(
447 &self,
448 tool: &str,
449 command: &str,
450 result: AuditResult,
451 duration_ms: u64,
452 error: Option<&ToolError>,
453 caller_id: Option<String>,
454 correlation_id: Option<String>,
455 ) {
456 if let Some(ref logger) = self.audit_logger {
457 let (error_category, error_domain, error_phase) =
458 error.map_or((None, None, None), |e| {
459 let cat = e.category();
460 (
461 Some(cat.label().to_owned()),
462 Some(cat.domain().label().to_owned()),
463 Some(cat.phase().label().to_owned()),
464 )
465 });
466 let entry = AuditEntry {
467 timestamp: chrono_now(),
468 tool: tool.into(),
469 command: command.into(),
470 result,
471 duration_ms,
472 error_category,
473 error_domain,
474 error_phase,
475 claim_source: Some(ClaimSource::WebScrape),
476 mcp_server_id: None,
477 injection_flagged: false,
478 embedding_anomalous: false,
479 cross_boundary_mcp_to_acp: false,
480 adversarial_policy_decision: None,
481 exit_code: None,
482 truncated: false,
483 caller_id,
484 policy_match: None,
485 correlation_id,
486 vigil_risk: None,
487 execution_env: None,
488 resolved_cwd: None,
489 scope_at_definition: None,
490 scope_at_dispatch: None,
491 };
492 logger.log(&entry).await;
493 }
494 }
495
496 fn send_egress_event(&self, event: EgressEvent) {
497 if let Some(ref tx) = self.egress_tx {
498 match tx.try_send(event) {
499 Ok(()) => {}
500 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
501 self.egress_dropped.fetch_add(1, Ordering::Relaxed);
502 }
503 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
504 tracing::debug!("egress channel closed; executor continuing without telemetry");
505 }
506 }
507 }
508 }
509
510 async fn log_egress_event(&self, event: &EgressEvent) {
511 if let Some(ref logger) = self.audit_logger {
512 logger.log_egress(event).await;
513 }
514 self.send_egress_event(event.clone());
515 }
516
517 async fn handle_fetch(
518 &self,
519 params: &FetchParams,
520 correlation_id: &str,
521 caller_id: Option<String>,
522 ) -> Result<String, ToolError> {
523 let parsed = validate_url(¶ms.url);
524 let host_str = parsed
525 .as_ref()
526 .map(|u| u.host_str().unwrap_or("").to_owned())
527 .unwrap_or_default();
528
529 if let Err(ref _e) = parsed {
530 if self.egress_config.enabled && self.egress_config.log_blocked {
531 let event = Self::make_blocked_event(
532 "fetch",
533 ¶ms.url,
534 &host_str,
535 correlation_id,
536 caller_id.clone(),
537 "scheme",
538 );
539 self.log_egress_event(&event).await;
540 }
541 return Err(parsed.unwrap_err());
542 }
543 let parsed = parsed.unwrap();
544
545 if let Err(e) = check_domain_policy(
546 parsed.host_str().unwrap_or(""),
547 &self.allowed_domains,
548 &self.denied_domains,
549 ) {
550 if self.egress_config.enabled && self.egress_config.log_blocked {
551 let event = Self::make_blocked_event(
552 "fetch",
553 ¶ms.url,
554 parsed.host_str().unwrap_or(""),
555 correlation_id,
556 caller_id.clone(),
557 "blocklist",
558 );
559 self.log_egress_event(&event).await;
560 }
561 return Err(e);
562 }
563
564 let (host, addrs) = match resolve_and_validate(&parsed).await {
565 Ok(v) => v,
566 Err(e) => {
567 if self.egress_config.enabled && self.egress_config.log_blocked {
568 let event = Self::make_blocked_event(
569 "fetch",
570 ¶ms.url,
571 parsed.host_str().unwrap_or(""),
572 correlation_id,
573 caller_id.clone(),
574 "ssrf",
575 );
576 self.log_egress_event(&event).await;
577 }
578 return Err(e);
579 }
580 };
581
582 let body = self
583 .fetch_html(
584 ¶ms.url,
585 &host,
586 &addrs,
587 "fetch",
588 correlation_id,
589 caller_id,
590 )
591 .await?;
592 self.apply_ipi_filter(&body, ¶ms.url).await
593 }
594
595 async fn scrape_instruction(
596 &self,
597 instruction: &ScrapeInstruction,
598 correlation_id: &str,
599 caller_id: Option<String>,
600 ) -> Result<String, ToolError> {
601 let parsed = validate_url(&instruction.url);
602 let host_str = parsed
603 .as_ref()
604 .map(|u| u.host_str().unwrap_or("").to_owned())
605 .unwrap_or_default();
606
607 if let Err(ref _e) = parsed {
608 if self.egress_config.enabled && self.egress_config.log_blocked {
609 let event = Self::make_blocked_event(
610 "web_scrape",
611 &instruction.url,
612 &host_str,
613 correlation_id,
614 caller_id.clone(),
615 "scheme",
616 );
617 self.log_egress_event(&event).await;
618 }
619 return Err(parsed.unwrap_err());
620 }
621 let parsed = parsed.unwrap();
622
623 if let Err(e) = check_domain_policy(
624 parsed.host_str().unwrap_or(""),
625 &self.allowed_domains,
626 &self.denied_domains,
627 ) {
628 if self.egress_config.enabled && self.egress_config.log_blocked {
629 let event = Self::make_blocked_event(
630 "web_scrape",
631 &instruction.url,
632 parsed.host_str().unwrap_or(""),
633 correlation_id,
634 caller_id.clone(),
635 "blocklist",
636 );
637 self.log_egress_event(&event).await;
638 }
639 return Err(e);
640 }
641
642 let (host, addrs) = match resolve_and_validate(&parsed).await {
643 Ok(v) => v,
644 Err(e) => {
645 if self.egress_config.enabled && self.egress_config.log_blocked {
646 let event = Self::make_blocked_event(
647 "web_scrape",
648 &instruction.url,
649 parsed.host_str().unwrap_or(""),
650 correlation_id,
651 caller_id.clone(),
652 "ssrf",
653 );
654 self.log_egress_event(&event).await;
655 }
656 return Err(e);
657 }
658 };
659
660 let html = self
661 .fetch_html(
662 &instruction.url,
663 &host,
664 &addrs,
665 "web_scrape",
666 correlation_id,
667 caller_id,
668 )
669 .await?;
670 let selector = instruction.select.clone();
671 let extract = ExtractMode::parse(&instruction.extract);
672 let limit = instruction.limit.unwrap_or(10);
673 let extracted = tokio::task::spawn_blocking(move || {
674 parse_and_extract(&html, &selector, &extract, limit)
675 })
676 .await
677 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))??;
678 self.apply_ipi_filter(&extracted, &instruction.url).await
680 }
681
682 fn make_blocked_event(
683 tool: &str,
684 url: &str,
685 host: &str,
686 correlation_id: &str,
687 caller_id: Option<String>,
688 block_reason: &'static str,
689 ) -> EgressEvent {
690 EgressEvent {
691 timestamp: chrono_now(),
692 kind: "egress",
693 correlation_id: correlation_id.to_owned(),
694 tool: tool.into(),
695 url: redact_url_for_log(url),
696 host: host.to_owned(),
697 method: "GET".to_owned(),
698 status: None,
699 duration_ms: 0,
700 response_bytes: 0,
701 blocked: true,
702 block_reason: Some(block_reason),
703 caller_id,
704 hop: 0,
705 }
706 }
707
708 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
719 async fn fetch_html(
720 &self,
721 url: &str,
722 host: &str,
723 addrs: &[SocketAddr],
724 tool: &str,
725 correlation_id: &str,
726 caller_id: Option<String>,
727 ) -> Result<String, ToolError> {
728 const MAX_REDIRECTS: usize = 3;
729
730 let mut current_url = url.to_owned();
731 let mut current_host = host.to_owned();
732 let mut current_addrs = addrs.to_vec();
733
734 for hop in 0..=MAX_REDIRECTS {
735 let hop_start = Instant::now();
736 let client = self.build_client(¤t_host, ¤t_addrs);
738 let resp = client
739 .get(¤t_url)
740 .send()
741 .await
742 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
743
744 let resp = match resp {
745 Ok(r) => r,
746 Err(e) => {
747 if self.egress_config.enabled {
748 #[allow(clippy::cast_possible_truncation)]
749 let duration_ms = hop_start.elapsed().as_millis() as u64;
750 let event = EgressEvent {
751 timestamp: chrono_now(),
752 kind: "egress",
753 correlation_id: correlation_id.to_owned(),
754 tool: tool.into(),
755 url: redact_url_for_log(¤t_url),
756 host: current_host.clone(),
757 method: "GET".to_owned(),
758 status: None,
759 duration_ms,
760 response_bytes: 0,
761 blocked: false,
762 block_reason: None,
763 caller_id: caller_id.clone(),
764 #[allow(clippy::cast_possible_truncation)]
765 hop: hop as u8,
766 };
767 self.log_egress_event(&event).await;
768 }
769 return Err(e);
770 }
771 };
772
773 let status = resp.status();
774
775 if status.is_redirection() {
776 if hop == MAX_REDIRECTS {
777 return Err(ToolError::Execution(std::io::Error::other(
778 "too many redirects",
779 )));
780 }
781
782 let location = resp
783 .headers()
784 .get(reqwest::header::LOCATION)
785 .and_then(|v| v.to_str().ok())
786 .ok_or_else(|| {
787 ToolError::Execution(std::io::Error::other("redirect with no Location"))
788 })?;
789
790 let base = Url::parse(¤t_url)
792 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
793 let next_url = base
794 .join(location)
795 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
796
797 let validated = validate_url(next_url.as_str());
798 if let Err(ref _e) = validated {
799 if self.egress_config.enabled && self.egress_config.log_blocked {
800 #[allow(clippy::cast_possible_truncation)]
801 let duration_ms = hop_start.elapsed().as_millis() as u64;
802 let next_host = next_url.host_str().unwrap_or("").to_owned();
803 let event = EgressEvent {
804 timestamp: chrono_now(),
805 kind: "egress",
806 correlation_id: correlation_id.to_owned(),
807 tool: tool.into(),
808 url: redact_url_for_log(next_url.as_str()),
809 host: next_host,
810 method: "GET".to_owned(),
811 status: None,
812 duration_ms,
813 response_bytes: 0,
814 blocked: true,
815 block_reason: Some("ssrf"),
816 caller_id: caller_id.clone(),
817 #[allow(clippy::cast_possible_truncation)]
818 hop: (hop + 1) as u8,
819 };
820 self.log_egress_event(&event).await;
821 }
822 return Err(validated.unwrap_err());
823 }
824 let validated = validated.unwrap();
825 let resolve_result = resolve_and_validate(&validated).await;
826 if let Err(ref _e) = resolve_result {
827 if self.egress_config.enabled && self.egress_config.log_blocked {
828 #[allow(clippy::cast_possible_truncation)]
829 let duration_ms = hop_start.elapsed().as_millis() as u64;
830 let next_host = next_url.host_str().unwrap_or("").to_owned();
831 let event = EgressEvent {
832 timestamp: chrono_now(),
833 kind: "egress",
834 correlation_id: correlation_id.to_owned(),
835 tool: tool.into(),
836 url: redact_url_for_log(next_url.as_str()),
837 host: next_host,
838 method: "GET".to_owned(),
839 status: None,
840 duration_ms,
841 response_bytes: 0,
842 blocked: true,
843 block_reason: Some("ssrf"),
844 caller_id: caller_id.clone(),
845 #[allow(clippy::cast_possible_truncation)]
846 hop: (hop + 1) as u8,
847 };
848 self.log_egress_event(&event).await;
849 }
850 return Err(resolve_result.unwrap_err());
851 }
852 let (next_host, next_addrs) = resolve_result.unwrap();
853
854 current_url = next_url.to_string();
855 current_host = next_host;
856 current_addrs = next_addrs;
857 continue;
858 }
859
860 if !status.is_success() {
861 if self.egress_config.enabled {
862 #[allow(clippy::cast_possible_truncation)]
863 let duration_ms = hop_start.elapsed().as_millis() as u64;
864 let event = EgressEvent {
865 timestamp: chrono_now(),
866 kind: "egress",
867 correlation_id: correlation_id.to_owned(),
868 tool: tool.into(),
869 url: current_url.clone(),
870 host: current_host.clone(),
871 method: "GET".to_owned(),
872 status: Some(status.as_u16()),
873 duration_ms,
874 response_bytes: 0,
875 blocked: false,
876 block_reason: None,
877 caller_id: caller_id.clone(),
878 #[allow(clippy::cast_possible_truncation)]
879 hop: hop as u8,
880 };
881 self.log_egress_event(&event).await;
882 }
883 return Err(ToolError::Http {
884 status: status.as_u16(),
885 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
886 });
887 }
888
889 let bytes = resp
890 .bytes()
891 .await
892 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
893
894 if bytes.len() > self.max_body_bytes {
895 if self.egress_config.enabled {
896 #[allow(clippy::cast_possible_truncation)]
897 let duration_ms = hop_start.elapsed().as_millis() as u64;
898 let event = EgressEvent {
899 timestamp: chrono_now(),
900 kind: "egress",
901 correlation_id: correlation_id.to_owned(),
902 tool: tool.into(),
903 url: current_url.clone(),
904 host: current_host.clone(),
905 method: "GET".to_owned(),
906 status: Some(status.as_u16()),
907 duration_ms,
908 response_bytes: bytes.len(),
909 blocked: false,
910 block_reason: None,
911 caller_id: caller_id.clone(),
912 #[allow(clippy::cast_possible_truncation)]
913 hop: hop as u8,
914 };
915 self.log_egress_event(&event).await;
916 }
917 return Err(ToolError::Execution(std::io::Error::other(format!(
918 "response too large: {} bytes (max: {})",
919 bytes.len(),
920 self.max_body_bytes,
921 ))));
922 }
923
924 if self.egress_config.enabled {
926 #[allow(clippy::cast_possible_truncation)]
927 let duration_ms = hop_start.elapsed().as_millis() as u64;
928 let response_bytes = if self.egress_config.log_response_bytes {
929 bytes.len()
930 } else {
931 0
932 };
933 let event = EgressEvent {
934 timestamp: chrono_now(),
935 kind: "egress",
936 correlation_id: correlation_id.to_owned(),
937 tool: tool.into(),
938 url: current_url.clone(),
939 host: current_host.clone(),
940 method: "GET".to_owned(),
941 status: Some(status.as_u16()),
942 duration_ms,
943 response_bytes,
944 blocked: false,
945 block_reason: None,
946 caller_id: caller_id.clone(),
947 #[allow(clippy::cast_possible_truncation)]
948 hop: hop as u8,
949 };
950 self.log_egress_event(&event).await;
951 }
952
953 return String::from_utf8(bytes.to_vec())
954 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
955 }
956
957 Err(ToolError::Execution(std::io::Error::other(
958 "too many redirects",
959 )))
960 }
961
962 async fn apply_ipi_filter(&self, body: &str, url: &str) -> Result<String, ToolError> {
972 let span = tracing::info_span!("tools.scrape.apply_ipi_filter", url, body_len = body.len());
973 let verdict = self
974 .ipi_filter
975 .filter_async(body.to_owned())
976 .await
977 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
978 let _enter = span.enter();
979 if !verdict.patterns_found.is_empty() {
980 tracing::warn!(
981 url = url,
982 score = verdict.score,
983 patterns = ?verdict.patterns_found,
984 "IPI patterns detected in fetched content"
985 );
986 }
987 if verdict.sanitized == body {
989 Ok(verdict.sanitized)
990 } else {
991 Ok(format!(
992 "[IPI WARNING: score={:.2}, patterns={}] {}",
993 verdict.score,
994 verdict.patterns_found.join(", "),
995 verdict.sanitized,
996 ))
997 }
998 }
999}
1000
1001fn extract_scrape_blocks(text: &str) -> Vec<&str> {
1002 crate::executor::extract_fenced_blocks(text, "scrape")
1003}
1004
1005fn check_domain_policy(
1017 host: &str,
1018 allowed_domains: &[String],
1019 denied_domains: &[String],
1020) -> Result<(), ToolError> {
1021 if denied_domains.iter().any(|p| domain_matches(p, host)) {
1022 return Err(ToolError::Blocked {
1023 command: format!("domain blocked by denylist: {host}"),
1024 });
1025 }
1026 if !allowed_domains.is_empty() {
1027 let is_ip = host.parse::<std::net::IpAddr>().is_ok()
1029 || (host.starts_with('[') && host.ends_with(']'));
1030 if is_ip {
1031 return Err(ToolError::Blocked {
1032 command: format!(
1033 "bare IP address not allowed when domain allowlist is active: {host}"
1034 ),
1035 });
1036 }
1037 if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
1038 return Err(ToolError::Blocked {
1039 command: format!("domain not in allowlist: {host}"),
1040 });
1041 }
1042 }
1043 Ok(())
1044}
1045
1046use crate::domain_match::domain_matches;
1048
1049fn validate_url(raw: &str) -> Result<Url, ToolError> {
1050 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
1051 command: format!("invalid URL: {raw}"),
1052 })?;
1053
1054 if parsed.scheme() != "https" {
1055 return Err(ToolError::Blocked {
1056 command: format!("scheme not allowed: {}", parsed.scheme()),
1057 });
1058 }
1059
1060 if let Some(host) = parsed.host()
1061 && is_private_host(&host)
1062 {
1063 return Err(ToolError::Blocked {
1064 command: format!(
1065 "private/local host blocked: {}",
1066 parsed.host_str().unwrap_or("")
1067 ),
1068 });
1069 }
1070
1071 Ok(parsed)
1072}
1073
1074fn is_private_host(host: &url::Host<&str>) -> bool {
1075 match host {
1076 url::Host::Domain(d) => {
1077 #[allow(clippy::case_sensitive_file_extension_comparisons)]
1080 {
1081 *d == "localhost"
1082 || d.ends_with(".localhost")
1083 || d.ends_with(".internal")
1084 || d.ends_with(".local")
1085 }
1086 }
1087 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
1088 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
1089 }
1090}
1091
1092async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
1098 let Some(host) = url.host_str() else {
1099 return Ok((String::new(), vec![]));
1100 };
1101 let port = url.port_or_known_default().unwrap_or(443);
1102 let span = tracing::info_span!("scrape.dns.resolve", host = host);
1103 let dns_result = tokio::time::timeout(
1104 Duration::from_secs(10),
1105 tokio::net::lookup_host(format!("{host}:{port}")),
1106 )
1107 .await;
1108 let addrs: Vec<SocketAddr> = {
1109 let _enter = span.enter();
1110 dns_result
1111 .map_err(|_| ToolError::Timeout { timeout_secs: 10 })?
1112 .map_err(|e| ToolError::Blocked {
1113 command: format!("DNS resolution failed: {e}"),
1114 })?
1115 .collect()
1116 };
1117 for addr in &addrs {
1118 if is_private_ip(addr.ip()) {
1119 return Err(ToolError::Blocked {
1120 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
1121 });
1122 }
1123 }
1124 Ok((host.to_owned(), addrs))
1125}
1126
1127fn parse_and_extract(
1128 html: &str,
1129 selector: &str,
1130 extract: &ExtractMode,
1131 limit: usize,
1132) -> Result<String, ToolError> {
1133 let soup = scrape_core::Soup::parse(html);
1134
1135 let tags = soup.find_all(selector).map_err(|e| {
1136 ToolError::Execution(std::io::Error::new(
1137 std::io::ErrorKind::InvalidData,
1138 format!("invalid selector: {e}"),
1139 ))
1140 })?;
1141
1142 let mut results = Vec::new();
1143
1144 for tag in tags.into_iter().take(limit) {
1145 let value = match extract {
1146 ExtractMode::Text => tag.text(),
1147 ExtractMode::Html => tag.inner_html(),
1148 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
1149 };
1150 if !value.trim().is_empty() {
1151 results.push(value.trim().to_owned());
1152 }
1153 }
1154
1155 if results.is_empty() {
1156 Ok(format!("No results for selector: {selector}"))
1157 } else {
1158 Ok(results.join("\n"))
1159 }
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164 use super::*;
1165
1166 #[test]
1169 fn extract_single_block() {
1170 let text =
1171 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
1172 let blocks = extract_scrape_blocks(text);
1173 assert_eq!(blocks.len(), 1);
1174 assert!(blocks[0].contains("example.com"));
1175 }
1176
1177 #[test]
1178 fn extract_multiple_blocks() {
1179 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
1180 let blocks = extract_scrape_blocks(text);
1181 assert_eq!(blocks.len(), 2);
1182 }
1183
1184 #[test]
1185 fn no_blocks_returns_empty() {
1186 let blocks = extract_scrape_blocks("plain text, no code blocks");
1187 assert!(blocks.is_empty());
1188 }
1189
1190 #[test]
1191 fn unclosed_block_ignored() {
1192 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
1193 assert!(blocks.is_empty());
1194 }
1195
1196 #[test]
1197 fn non_scrape_block_ignored() {
1198 let text =
1199 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
1200 let blocks = extract_scrape_blocks(text);
1201 assert_eq!(blocks.len(), 1);
1202 assert!(blocks[0].contains("x.com"));
1203 }
1204
1205 #[test]
1206 fn multiline_json_block() {
1207 let text =
1208 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
1209 let blocks = extract_scrape_blocks(text);
1210 assert_eq!(blocks.len(), 1);
1211 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
1212 assert_eq!(instr.url, "https://example.com");
1213 }
1214
1215 #[test]
1218 fn parse_valid_instruction() {
1219 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
1220 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1221 assert_eq!(instr.url, "https://example.com");
1222 assert_eq!(instr.select, "h1");
1223 assert_eq!(instr.extract, "text");
1224 assert_eq!(instr.limit, Some(5));
1225 }
1226
1227 #[test]
1228 fn parse_minimal_instruction() {
1229 let json = r#"{"url":"https://example.com","select":"p"}"#;
1230 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1231 assert_eq!(instr.extract, "text");
1232 assert!(instr.limit.is_none());
1233 }
1234
1235 #[test]
1236 fn parse_attr_extract() {
1237 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
1238 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1239 assert_eq!(instr.extract, "attr:href");
1240 }
1241
1242 #[test]
1243 fn parse_invalid_json_errors() {
1244 let result = serde_json::from_str::<ScrapeInstruction>("not json");
1245 assert!(result.is_err());
1246 }
1247
1248 #[test]
1251 fn extract_mode_text() {
1252 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
1253 }
1254
1255 #[test]
1256 fn extract_mode_html() {
1257 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
1258 }
1259
1260 #[test]
1261 fn extract_mode_attr() {
1262 let mode = ExtractMode::parse("attr:href");
1263 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
1264 }
1265
1266 #[test]
1267 fn extract_mode_unknown_defaults_to_text() {
1268 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
1269 }
1270
1271 #[test]
1274 fn valid_https_url() {
1275 assert!(validate_url("https://example.com").is_ok());
1276 }
1277
1278 #[test]
1279 fn http_rejected() {
1280 let err = validate_url("http://example.com").unwrap_err();
1281 assert!(matches!(err, ToolError::Blocked { .. }));
1282 }
1283
1284 #[test]
1285 fn ftp_rejected() {
1286 let err = validate_url("ftp://files.example.com").unwrap_err();
1287 assert!(matches!(err, ToolError::Blocked { .. }));
1288 }
1289
1290 #[test]
1291 fn file_rejected() {
1292 let err = validate_url("file:///etc/passwd").unwrap_err();
1293 assert!(matches!(err, ToolError::Blocked { .. }));
1294 }
1295
1296 #[test]
1297 fn invalid_url_rejected() {
1298 let err = validate_url("not a url").unwrap_err();
1299 assert!(matches!(err, ToolError::Blocked { .. }));
1300 }
1301
1302 #[test]
1303 fn localhost_blocked() {
1304 let err = validate_url("https://localhost/path").unwrap_err();
1305 assert!(matches!(err, ToolError::Blocked { .. }));
1306 }
1307
1308 #[test]
1309 fn loopback_ip_blocked() {
1310 let err = validate_url("https://127.0.0.1/path").unwrap_err();
1311 assert!(matches!(err, ToolError::Blocked { .. }));
1312 }
1313
1314 #[test]
1315 fn private_10_blocked() {
1316 let err = validate_url("https://10.0.0.1/api").unwrap_err();
1317 assert!(matches!(err, ToolError::Blocked { .. }));
1318 }
1319
1320 #[test]
1321 fn private_172_blocked() {
1322 let err = validate_url("https://172.16.0.1/api").unwrap_err();
1323 assert!(matches!(err, ToolError::Blocked { .. }));
1324 }
1325
1326 #[test]
1327 fn private_192_blocked() {
1328 let err = validate_url("https://192.168.1.1/api").unwrap_err();
1329 assert!(matches!(err, ToolError::Blocked { .. }));
1330 }
1331
1332 #[test]
1333 fn ipv6_loopback_blocked() {
1334 let err = validate_url("https://[::1]/path").unwrap_err();
1335 assert!(matches!(err, ToolError::Blocked { .. }));
1336 }
1337
1338 #[test]
1339 fn public_ip_allowed() {
1340 assert!(validate_url("https://93.184.216.34/page").is_ok());
1341 }
1342
1343 #[test]
1346 fn extract_text_from_html() {
1347 let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
1348 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1349 assert_eq!(result, "Hello World");
1350 }
1351
1352 #[test]
1353 fn extract_multiple_elements() {
1354 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1355 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1356 assert_eq!(result, "A\nB\nC");
1357 }
1358
1359 #[test]
1360 fn extract_with_limit() {
1361 let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
1362 let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
1363 assert_eq!(result, "A\nB");
1364 }
1365
1366 #[test]
1367 fn extract_attr_href() {
1368 let html = r#"<a href="https://example.com">Link</a>"#;
1369 let result =
1370 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1371 assert_eq!(result, "https://example.com");
1372 }
1373
1374 #[test]
1375 fn extract_inner_html() {
1376 let html = "<div><span>inner</span></div>";
1377 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1378 assert!(result.contains("<span>inner</span>"));
1379 }
1380
1381 #[test]
1382 fn no_matches_returns_message() {
1383 let html = "<html><body><p>text</p></body></html>";
1384 let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
1385 assert!(result.starts_with("No results for selector:"));
1386 }
1387
1388 #[test]
1389 fn empty_text_skipped() {
1390 let html = "<ul><li> </li><li>A</li></ul>";
1391 let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
1392 assert_eq!(result, "A");
1393 }
1394
1395 #[test]
1396 fn invalid_selector_errors() {
1397 let html = "<html><body></body></html>";
1398 let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
1399 assert!(result.is_err());
1400 }
1401
1402 #[test]
1403 fn empty_html_returns_no_results() {
1404 let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
1405 assert!(result.starts_with("No results for selector:"));
1406 }
1407
1408 #[test]
1409 fn nested_selector() {
1410 let html = "<div><span>inner</span></div><span>outer</span>";
1411 let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
1412 assert_eq!(result, "inner");
1413 }
1414
1415 #[test]
1416 fn attr_missing_returns_empty() {
1417 let html = r"<a>No href</a>";
1418 let result =
1419 parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
1420 assert!(result.starts_with("No results for selector:"));
1421 }
1422
1423 #[test]
1424 fn extract_html_mode() {
1425 let html = "<div><b>bold</b> text</div>";
1426 let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
1427 assert!(result.contains("<b>bold</b>"));
1428 }
1429
1430 #[test]
1431 fn limit_zero_returns_no_results() {
1432 let html = "<ul><li>A</li><li>B</li></ul>";
1433 let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
1434 assert!(result.starts_with("No results for selector:"));
1435 }
1436
1437 #[test]
1440 fn url_with_port_allowed() {
1441 assert!(validate_url("https://example.com:8443/path").is_ok());
1442 }
1443
1444 #[test]
1445 fn link_local_ip_blocked() {
1446 let err = validate_url("https://169.254.1.1/path").unwrap_err();
1447 assert!(matches!(err, ToolError::Blocked { .. }));
1448 }
1449
1450 #[test]
1451 fn url_no_scheme_rejected() {
1452 let err = validate_url("example.com/path").unwrap_err();
1453 assert!(matches!(err, ToolError::Blocked { .. }));
1454 }
1455
1456 #[test]
1457 fn unspecified_ipv4_blocked() {
1458 let err = validate_url("https://0.0.0.0/path").unwrap_err();
1459 assert!(matches!(err, ToolError::Blocked { .. }));
1460 }
1461
1462 #[test]
1463 fn broadcast_ipv4_blocked() {
1464 let err = validate_url("https://255.255.255.255/path").unwrap_err();
1465 assert!(matches!(err, ToolError::Blocked { .. }));
1466 }
1467
1468 #[test]
1469 fn ipv6_link_local_blocked() {
1470 let err = validate_url("https://[fe80::1]/path").unwrap_err();
1471 assert!(matches!(err, ToolError::Blocked { .. }));
1472 }
1473
1474 #[test]
1475 fn ipv6_unique_local_blocked() {
1476 let err = validate_url("https://[fd12::1]/path").unwrap_err();
1477 assert!(matches!(err, ToolError::Blocked { .. }));
1478 }
1479
1480 #[test]
1481 fn ipv4_mapped_ipv6_loopback_blocked() {
1482 let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
1483 assert!(matches!(err, ToolError::Blocked { .. }));
1484 }
1485
1486 #[test]
1487 fn ipv4_mapped_ipv6_private_blocked() {
1488 let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
1489 assert!(matches!(err, ToolError::Blocked { .. }));
1490 }
1491
1492 #[tokio::test]
1495 async fn executor_no_blocks_returns_none() {
1496 let config = ScrapeConfig::default();
1497 let executor = WebScrapeExecutor::new(&config);
1498 let result = executor.execute("plain text").await;
1499 assert!(result.unwrap().is_none());
1500 }
1501
1502 #[tokio::test]
1503 async fn executor_invalid_json_errors() {
1504 let config = ScrapeConfig::default();
1505 let executor = WebScrapeExecutor::new(&config);
1506 let response = "```scrape\nnot json\n```";
1507 let result = executor.execute(response).await;
1508 assert!(matches!(result, Err(ToolError::Execution(_))));
1509 }
1510
1511 #[tokio::test]
1512 async fn executor_blocked_url_errors() {
1513 let config = ScrapeConfig::default();
1514 let executor = WebScrapeExecutor::new(&config);
1515 let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
1516 let result = executor.execute(response).await;
1517 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1518 }
1519
1520 #[tokio::test]
1521 async fn executor_private_ip_blocked() {
1522 let config = ScrapeConfig::default();
1523 let executor = WebScrapeExecutor::new(&config);
1524 let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
1525 let result = executor.execute(response).await;
1526 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1527 }
1528
1529 #[tokio::test]
1530 async fn executor_unreachable_host_returns_error() {
1531 let config = ScrapeConfig {
1532 timeout: 1,
1533 max_body_bytes: 1_048_576,
1534 ..Default::default()
1535 };
1536 let executor = WebScrapeExecutor::new(&config);
1537 let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
1538 let result = executor.execute(response).await;
1539 assert!(matches!(result, Err(ToolError::Execution(_))));
1540 }
1541
1542 #[tokio::test]
1543 async fn executor_localhost_url_blocked() {
1544 let config = ScrapeConfig::default();
1545 let executor = WebScrapeExecutor::new(&config);
1546 let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
1547 let result = executor.execute(response).await;
1548 assert!(matches!(result, Err(ToolError::Blocked { .. })));
1549 }
1550
1551 #[tokio::test]
1552 async fn executor_empty_text_returns_none() {
1553 let config = ScrapeConfig::default();
1554 let executor = WebScrapeExecutor::new(&config);
1555 let result = executor.execute("").await;
1556 assert!(result.unwrap().is_none());
1557 }
1558
1559 #[tokio::test]
1560 async fn executor_multiple_blocks_first_blocked() {
1561 let config = ScrapeConfig::default();
1562 let executor = WebScrapeExecutor::new(&config);
1563 let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
1564 ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
1565 let result = executor.execute(response).await;
1566 assert!(result.is_err());
1567 }
1568
1569 #[test]
1570 fn validate_url_empty_string() {
1571 let err = validate_url("").unwrap_err();
1572 assert!(matches!(err, ToolError::Blocked { .. }));
1573 }
1574
1575 #[test]
1576 fn validate_url_javascript_scheme_blocked() {
1577 let err = validate_url("javascript:alert(1)").unwrap_err();
1578 assert!(matches!(err, ToolError::Blocked { .. }));
1579 }
1580
1581 #[test]
1582 fn validate_url_data_scheme_blocked() {
1583 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1584 assert!(matches!(err, ToolError::Blocked { .. }));
1585 }
1586
1587 #[test]
1588 fn is_private_host_public_domain_is_false() {
1589 let host: url::Host<&str> = url::Host::Domain("example.com");
1590 assert!(!is_private_host(&host));
1591 }
1592
1593 #[test]
1594 fn is_private_host_localhost_is_true() {
1595 let host: url::Host<&str> = url::Host::Domain("localhost");
1596 assert!(is_private_host(&host));
1597 }
1598
1599 #[test]
1600 fn is_private_host_ipv6_unspecified_is_true() {
1601 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1602 assert!(is_private_host(&host));
1603 }
1604
1605 #[test]
1606 fn is_private_host_public_ipv6_is_false() {
1607 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1608 assert!(!is_private_host(&host));
1609 }
1610
1611 async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
1622 let server = wiremock::MockServer::start().await;
1623 let executor = WebScrapeExecutor {
1624 timeout: Duration::from_secs(5),
1625 max_body_bytes: 1_048_576,
1626 allowed_domains: vec![],
1627 denied_domains: vec![],
1628 audit_logger: None,
1629 egress_config: EgressConfig::default(),
1630 egress_tx: None,
1631 egress_dropped: Arc::new(AtomicU64::new(0)),
1632 ipi_filter: IpiFilter::new(0.6),
1633 };
1634 (executor, server)
1635 }
1636
1637 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1639 let uri = server.uri();
1640 let url = Url::parse(&uri).unwrap();
1641 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1642 let port = url.port().unwrap_or(80);
1643 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1644 (host, vec![addr])
1645 }
1646
1647 async fn follow_redirects_raw(
1651 executor: &WebScrapeExecutor,
1652 start_url: &str,
1653 host: &str,
1654 addrs: &[std::net::SocketAddr],
1655 ) -> Result<String, ToolError> {
1656 const MAX_REDIRECTS: usize = 3;
1657 let mut current_url = start_url.to_owned();
1658 let mut current_host = host.to_owned();
1659 let mut current_addrs = addrs.to_vec();
1660
1661 for hop in 0..=MAX_REDIRECTS {
1662 let client = executor.build_client(¤t_host, ¤t_addrs);
1663 let resp = client
1664 .get(¤t_url)
1665 .send()
1666 .await
1667 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1668
1669 let status = resp.status();
1670
1671 if status.is_redirection() {
1672 if hop == MAX_REDIRECTS {
1673 return Err(ToolError::Execution(std::io::Error::other(
1674 "too many redirects",
1675 )));
1676 }
1677
1678 let location = resp
1679 .headers()
1680 .get(reqwest::header::LOCATION)
1681 .and_then(|v| v.to_str().ok())
1682 .ok_or_else(|| {
1683 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1684 })?;
1685
1686 let base = Url::parse(¤t_url)
1687 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1688 let next_url = base
1689 .join(location)
1690 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1691
1692 current_url = next_url.to_string();
1694 let _ = &mut current_host;
1696 let _ = &mut current_addrs;
1697 continue;
1698 }
1699
1700 if !status.is_success() {
1701 return Err(ToolError::Execution(std::io::Error::other(format!(
1702 "HTTP {status}",
1703 ))));
1704 }
1705
1706 let bytes = resp
1707 .bytes()
1708 .await
1709 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1710
1711 if bytes.len() > executor.max_body_bytes {
1712 return Err(ToolError::Execution(std::io::Error::other(format!(
1713 "response too large: {} bytes (max: {})",
1714 bytes.len(),
1715 executor.max_body_bytes,
1716 ))));
1717 }
1718
1719 return String::from_utf8(bytes.to_vec())
1720 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1721 }
1722
1723 Err(ToolError::Execution(std::io::Error::other(
1724 "too many redirects",
1725 )))
1726 }
1727
1728 #[tokio::test]
1729 async fn fetch_html_success_returns_body() {
1730 use wiremock::matchers::{method, path};
1731 use wiremock::{Mock, ResponseTemplate};
1732
1733 let (executor, server) = mock_server_executor().await;
1734 Mock::given(method("GET"))
1735 .and(path("/page"))
1736 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
1737 .mount(&server)
1738 .await;
1739
1740 let (host, addrs) = server_host_and_addr(&server);
1741 let url = format!("{}/page", server.uri());
1742 let result = executor
1743 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1744 .await;
1745 assert!(result.is_ok(), "expected Ok, got: {result:?}");
1746 assert_eq!(result.unwrap(), "<h1>OK</h1>");
1747 }
1748
1749 #[tokio::test]
1750 async fn fetch_html_non_2xx_returns_error() {
1751 use wiremock::matchers::{method, path};
1752 use wiremock::{Mock, ResponseTemplate};
1753
1754 let (executor, server) = mock_server_executor().await;
1755 Mock::given(method("GET"))
1756 .and(path("/forbidden"))
1757 .respond_with(ResponseTemplate::new(403))
1758 .mount(&server)
1759 .await;
1760
1761 let (host, addrs) = server_host_and_addr(&server);
1762 let url = format!("{}/forbidden", server.uri());
1763 let result = executor
1764 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1765 .await;
1766 assert!(result.is_err());
1767 let msg = result.unwrap_err().to_string();
1768 assert!(msg.contains("403"), "expected 403 in error: {msg}");
1769 }
1770
1771 #[tokio::test]
1772 async fn fetch_html_404_returns_error() {
1773 use wiremock::matchers::{method, path};
1774 use wiremock::{Mock, ResponseTemplate};
1775
1776 let (executor, server) = mock_server_executor().await;
1777 Mock::given(method("GET"))
1778 .and(path("/missing"))
1779 .respond_with(ResponseTemplate::new(404))
1780 .mount(&server)
1781 .await;
1782
1783 let (host, addrs) = server_host_and_addr(&server);
1784 let url = format!("{}/missing", server.uri());
1785 let result = executor
1786 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1787 .await;
1788 assert!(result.is_err());
1789 let msg = result.unwrap_err().to_string();
1790 assert!(msg.contains("404"), "expected 404 in error: {msg}");
1791 }
1792
1793 #[tokio::test]
1794 async fn fetch_html_redirect_no_location_returns_error() {
1795 use wiremock::matchers::{method, path};
1796 use wiremock::{Mock, ResponseTemplate};
1797
1798 let (executor, server) = mock_server_executor().await;
1799 Mock::given(method("GET"))
1801 .and(path("/redirect-no-loc"))
1802 .respond_with(ResponseTemplate::new(302))
1803 .mount(&server)
1804 .await;
1805
1806 let (host, addrs) = server_host_and_addr(&server);
1807 let url = format!("{}/redirect-no-loc", server.uri());
1808 let result = executor
1809 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1810 .await;
1811 assert!(result.is_err());
1812 let msg = result.unwrap_err().to_string();
1813 assert!(
1814 msg.contains("Location") || msg.contains("location"),
1815 "expected Location-related error: {msg}"
1816 );
1817 }
1818
1819 #[tokio::test]
1820 async fn fetch_html_single_redirect_followed() {
1821 use wiremock::matchers::{method, path};
1822 use wiremock::{Mock, ResponseTemplate};
1823
1824 let (executor, server) = mock_server_executor().await;
1825 let final_url = format!("{}/final", server.uri());
1826
1827 Mock::given(method("GET"))
1828 .and(path("/start"))
1829 .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
1830 .mount(&server)
1831 .await;
1832
1833 Mock::given(method("GET"))
1834 .and(path("/final"))
1835 .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
1836 .mount(&server)
1837 .await;
1838
1839 let (host, addrs) = server_host_and_addr(&server);
1840 let url = format!("{}/start", server.uri());
1841 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1842 assert!(result.is_ok(), "single redirect should succeed: {result:?}");
1843 assert_eq!(result.unwrap(), "<p>final</p>");
1844 }
1845
1846 #[tokio::test]
1847 async fn fetch_html_three_redirects_allowed() {
1848 use wiremock::matchers::{method, path};
1849 use wiremock::{Mock, ResponseTemplate};
1850
1851 let (executor, server) = mock_server_executor().await;
1852 let hop2 = format!("{}/hop2", server.uri());
1853 let hop3 = format!("{}/hop3", server.uri());
1854 let final_dest = format!("{}/done", server.uri());
1855
1856 Mock::given(method("GET"))
1857 .and(path("/hop1"))
1858 .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
1859 .mount(&server)
1860 .await;
1861 Mock::given(method("GET"))
1862 .and(path("/hop2"))
1863 .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
1864 .mount(&server)
1865 .await;
1866 Mock::given(method("GET"))
1867 .and(path("/hop3"))
1868 .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
1869 .mount(&server)
1870 .await;
1871 Mock::given(method("GET"))
1872 .and(path("/done"))
1873 .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
1874 .mount(&server)
1875 .await;
1876
1877 let (host, addrs) = server_host_and_addr(&server);
1878 let url = format!("{}/hop1", server.uri());
1879 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1880 assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
1881 assert_eq!(result.unwrap(), "<p>done</p>");
1882 }
1883
1884 #[tokio::test]
1885 async fn fetch_html_four_redirects_rejected() {
1886 use wiremock::matchers::{method, path};
1887 use wiremock::{Mock, ResponseTemplate};
1888
1889 let (executor, server) = mock_server_executor().await;
1890 let hop2 = format!("{}/r2", server.uri());
1891 let hop3 = format!("{}/r3", server.uri());
1892 let hop4 = format!("{}/r4", server.uri());
1893 let hop5 = format!("{}/r5", server.uri());
1894
1895 for (from, to) in [
1896 ("/r1", &hop2),
1897 ("/r2", &hop3),
1898 ("/r3", &hop4),
1899 ("/r4", &hop5),
1900 ] {
1901 Mock::given(method("GET"))
1902 .and(path(from))
1903 .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
1904 .mount(&server)
1905 .await;
1906 }
1907
1908 let (host, addrs) = server_host_and_addr(&server);
1909 let url = format!("{}/r1", server.uri());
1910 let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
1911 assert!(result.is_err(), "4 redirects should be rejected");
1912 let msg = result.unwrap_err().to_string();
1913 assert!(
1914 msg.contains("redirect"),
1915 "expected redirect-related error: {msg}"
1916 );
1917 }
1918
1919 #[tokio::test]
1920 async fn fetch_html_body_too_large_returns_error() {
1921 use wiremock::matchers::{method, path};
1922 use wiremock::{Mock, ResponseTemplate};
1923
1924 let small_limit_executor = WebScrapeExecutor {
1925 timeout: Duration::from_secs(5),
1926 max_body_bytes: 10,
1927 allowed_domains: vec![],
1928 denied_domains: vec![],
1929 audit_logger: None,
1930 egress_config: EgressConfig::default(),
1931 egress_tx: None,
1932 egress_dropped: Arc::new(AtomicU64::new(0)),
1933 ipi_filter: IpiFilter::new(0.6),
1934 };
1935 let server = wiremock::MockServer::start().await;
1936 Mock::given(method("GET"))
1937 .and(path("/big"))
1938 .respond_with(
1939 ResponseTemplate::new(200)
1940 .set_body_string("this body is definitely longer than ten bytes"),
1941 )
1942 .mount(&server)
1943 .await;
1944
1945 let (host, addrs) = server_host_and_addr(&server);
1946 let url = format!("{}/big", server.uri());
1947 let result = small_limit_executor
1948 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
1949 .await;
1950 assert!(result.is_err());
1951 let msg = result.unwrap_err().to_string();
1952 assert!(msg.contains("too large"), "expected too-large error: {msg}");
1953 }
1954
1955 #[test]
1956 fn extract_scrape_blocks_empty_block_content() {
1957 let text = "```scrape\n\n```";
1958 let blocks = extract_scrape_blocks(text);
1959 assert_eq!(blocks.len(), 1);
1960 assert!(blocks[0].is_empty());
1961 }
1962
1963 #[test]
1964 fn extract_scrape_blocks_whitespace_only() {
1965 let text = "```scrape\n \n```";
1966 let blocks = extract_scrape_blocks(text);
1967 assert_eq!(blocks.len(), 1);
1968 }
1969
1970 #[test]
1971 fn parse_and_extract_multiple_selectors() {
1972 let html = "<div><h1>Title</h1><p>Para</p></div>";
1973 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1974 assert!(result.contains("Title"));
1975 assert!(result.contains("Para"));
1976 }
1977
1978 #[test]
1979 fn webscrape_executor_new_with_custom_config() {
1980 let config = ScrapeConfig {
1981 timeout: 60,
1982 max_body_bytes: 512,
1983 ..Default::default()
1984 };
1985 let executor = WebScrapeExecutor::new(&config);
1986 assert_eq!(executor.max_body_bytes, 512);
1987 }
1988
1989 #[test]
1990 fn webscrape_executor_debug() {
1991 let config = ScrapeConfig::default();
1992 let executor = WebScrapeExecutor::new(&config);
1993 let dbg = format!("{executor:?}");
1994 assert!(dbg.contains("WebScrapeExecutor"));
1995 }
1996
1997 #[test]
1998 fn extract_mode_attr_empty_name() {
1999 let mode = ExtractMode::parse("attr:");
2000 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
2001 }
2002
2003 #[test]
2004 fn default_extract_returns_text() {
2005 assert_eq!(default_extract(), "text");
2006 }
2007
2008 #[test]
2009 fn scrape_instruction_debug() {
2010 let json = r#"{"url":"https://example.com","select":"h1"}"#;
2011 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
2012 let dbg = format!("{instr:?}");
2013 assert!(dbg.contains("ScrapeInstruction"));
2014 }
2015
2016 #[test]
2017 fn extract_mode_debug() {
2018 let mode = ExtractMode::Text;
2019 let dbg = format!("{mode:?}");
2020 assert!(dbg.contains("Text"));
2021 }
2022
2023 #[test]
2028 fn max_redirects_constant_is_three() {
2029 const MAX_REDIRECTS: usize = 3;
2033 assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
2034 }
2035
2036 #[test]
2039 fn redirect_no_location_error_message() {
2040 let err = std::io::Error::other("redirect with no Location");
2041 assert!(err.to_string().contains("redirect with no Location"));
2042 }
2043
2044 #[test]
2046 fn too_many_redirects_error_message() {
2047 let err = std::io::Error::other("too many redirects");
2048 assert!(err.to_string().contains("too many redirects"));
2049 }
2050
2051 #[test]
2053 fn non_2xx_status_error_format() {
2054 let status = reqwest::StatusCode::FORBIDDEN;
2055 let msg = format!("HTTP {status}");
2056 assert!(msg.contains("403"));
2057 }
2058
2059 #[test]
2061 fn not_found_status_error_format() {
2062 let status = reqwest::StatusCode::NOT_FOUND;
2063 let msg = format!("HTTP {status}");
2064 assert!(msg.contains("404"));
2065 }
2066
2067 #[test]
2069 fn relative_redirect_same_host_path() {
2070 let base = Url::parse("https://example.com/current").unwrap();
2071 let resolved = base.join("/other").unwrap();
2072 assert_eq!(resolved.as_str(), "https://example.com/other");
2073 }
2074
2075 #[test]
2077 fn relative_redirect_relative_path() {
2078 let base = Url::parse("https://example.com/a/b").unwrap();
2079 let resolved = base.join("c").unwrap();
2080 assert_eq!(resolved.as_str(), "https://example.com/a/c");
2081 }
2082
2083 #[test]
2085 fn absolute_redirect_overrides_base() {
2086 let base = Url::parse("https://example.com/page").unwrap();
2087 let resolved = base.join("https://other.com/target").unwrap();
2088 assert_eq!(resolved.as_str(), "https://other.com/target");
2089 }
2090
2091 #[test]
2093 fn redirect_http_downgrade_rejected() {
2094 let location = "http://example.com/page";
2095 let base = Url::parse("https://example.com/start").unwrap();
2096 let next = base.join(location).unwrap();
2097 let err = validate_url(next.as_str()).unwrap_err();
2098 assert!(matches!(err, ToolError::Blocked { .. }));
2099 }
2100
2101 #[test]
2103 fn redirect_location_private_ip_blocked() {
2104 let location = "https://192.168.100.1/admin";
2105 let base = Url::parse("https://example.com/start").unwrap();
2106 let next = base.join(location).unwrap();
2107 let err = validate_url(next.as_str()).unwrap_err();
2108 assert!(matches!(err, ToolError::Blocked { .. }));
2109 let ToolError::Blocked { command: cmd } = err else {
2110 panic!("expected Blocked");
2111 };
2112 assert!(
2113 cmd.contains("private") || cmd.contains("scheme"),
2114 "error message should describe the block reason: {cmd}"
2115 );
2116 }
2117
2118 #[test]
2120 fn redirect_location_internal_domain_blocked() {
2121 let location = "https://metadata.internal/latest/meta-data/";
2122 let base = Url::parse("https://example.com/start").unwrap();
2123 let next = base.join(location).unwrap();
2124 let err = validate_url(next.as_str()).unwrap_err();
2125 assert!(matches!(err, ToolError::Blocked { .. }));
2126 }
2127
2128 #[test]
2130 fn redirect_chain_three_hops_all_public() {
2131 let hops = [
2132 "https://redirect1.example.com/hop1",
2133 "https://redirect2.example.com/hop2",
2134 "https://destination.example.com/final",
2135 ];
2136 for hop in hops {
2137 assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
2138 }
2139 }
2140
2141 #[test]
2146 fn redirect_to_private_ip_rejected_by_validate_url() {
2147 let private_targets = [
2149 "https://127.0.0.1/secret",
2150 "https://10.0.0.1/internal",
2151 "https://192.168.1.1/admin",
2152 "https://172.16.0.1/data",
2153 "https://[::1]/path",
2154 "https://[fe80::1]/path",
2155 "https://localhost/path",
2156 "https://service.internal/api",
2157 ];
2158 for target in private_targets {
2159 let result = validate_url(target);
2160 assert!(result.is_err(), "expected error for {target}");
2161 assert!(
2162 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
2163 "expected Blocked for {target}"
2164 );
2165 }
2166 }
2167
2168 #[test]
2170 fn redirect_relative_url_resolves_correctly() {
2171 let base = Url::parse("https://example.com/page").unwrap();
2172 let relative = "/other";
2173 let resolved = base.join(relative).unwrap();
2174 assert_eq!(resolved.as_str(), "https://example.com/other");
2175 }
2176
2177 #[test]
2179 fn redirect_to_http_rejected() {
2180 let err = validate_url("http://example.com/page").unwrap_err();
2181 assert!(matches!(err, ToolError::Blocked { .. }));
2182 }
2183
2184 #[test]
2185 fn ipv4_mapped_ipv6_link_local_blocked() {
2186 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
2187 assert!(matches!(err, ToolError::Blocked { .. }));
2188 }
2189
2190 #[test]
2191 fn ipv4_mapped_ipv6_public_allowed() {
2192 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
2193 }
2194
2195 #[tokio::test]
2198 async fn fetch_http_scheme_blocked() {
2199 let config = ScrapeConfig::default();
2200 let executor = WebScrapeExecutor::new(&config);
2201 let call = crate::executor::ToolCall {
2202 tool_id: ToolName::new("fetch"),
2203 params: {
2204 let mut m = serde_json::Map::new();
2205 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2206 m
2207 },
2208 caller_id: None,
2209 context: None,
2210
2211 tool_call_id: String::new(),
2212 };
2213 let result = executor.execute_tool_call(&call).await;
2214 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2215 }
2216
2217 #[tokio::test]
2218 async fn fetch_private_ip_blocked() {
2219 let config = ScrapeConfig::default();
2220 let executor = WebScrapeExecutor::new(&config);
2221 let call = crate::executor::ToolCall {
2222 tool_id: ToolName::new("fetch"),
2223 params: {
2224 let mut m = serde_json::Map::new();
2225 m.insert(
2226 "url".to_owned(),
2227 serde_json::json!("https://192.168.1.1/secret"),
2228 );
2229 m
2230 },
2231 caller_id: None,
2232 context: None,
2233
2234 tool_call_id: String::new(),
2235 };
2236 let result = executor.execute_tool_call(&call).await;
2237 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2238 }
2239
2240 #[tokio::test]
2241 async fn fetch_localhost_blocked() {
2242 let config = ScrapeConfig::default();
2243 let executor = WebScrapeExecutor::new(&config);
2244 let call = crate::executor::ToolCall {
2245 tool_id: ToolName::new("fetch"),
2246 params: {
2247 let mut m = serde_json::Map::new();
2248 m.insert(
2249 "url".to_owned(),
2250 serde_json::json!("https://localhost/page"),
2251 );
2252 m
2253 },
2254 caller_id: None,
2255 context: None,
2256
2257 tool_call_id: String::new(),
2258 };
2259 let result = executor.execute_tool_call(&call).await;
2260 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2261 }
2262
2263 #[tokio::test]
2264 async fn fetch_unknown_tool_returns_none() {
2265 let config = ScrapeConfig::default();
2266 let executor = WebScrapeExecutor::new(&config);
2267 let call = crate::executor::ToolCall {
2268 tool_id: ToolName::new("unknown_tool"),
2269 params: serde_json::Map::new(),
2270 caller_id: None,
2271 context: None,
2272
2273 tool_call_id: String::new(),
2274 };
2275 let result = executor.execute_tool_call(&call).await;
2276 assert!(result.unwrap().is_none());
2277 }
2278
2279 #[tokio::test]
2280 async fn fetch_returns_body_via_mock() {
2281 use wiremock::matchers::{method, path};
2282 use wiremock::{Mock, ResponseTemplate};
2283
2284 let (executor, server) = mock_server_executor().await;
2285 Mock::given(method("GET"))
2286 .and(path("/content"))
2287 .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
2288 .mount(&server)
2289 .await;
2290
2291 let (host, addrs) = server_host_and_addr(&server);
2292 let url = format!("{}/content", server.uri());
2293 let result = executor
2294 .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
2295 .await;
2296 assert!(result.is_ok());
2297 assert_eq!(result.unwrap(), "plain text content");
2298 }
2299
2300 #[test]
2301 fn tool_definitions_returns_web_scrape_and_fetch() {
2302 let config = ScrapeConfig::default();
2303 let executor = WebScrapeExecutor::new(&config);
2304 let defs = executor.tool_definitions();
2305 assert_eq!(defs.len(), 2);
2306 assert_eq!(defs[0].id, "web_scrape");
2307 assert_eq!(
2308 defs[0].invocation,
2309 crate::registry::InvocationHint::FencedBlock("scrape")
2310 );
2311 assert_eq!(defs[1].id, "fetch");
2312 assert_eq!(
2313 defs[1].invocation,
2314 crate::registry::InvocationHint::ToolCall
2315 );
2316 }
2317
2318 #[test]
2319 fn tool_definitions_schema_has_all_params() {
2320 let config = ScrapeConfig::default();
2321 let executor = WebScrapeExecutor::new(&config);
2322 let defs = executor.tool_definitions();
2323 let obj = defs[0].schema.as_object().unwrap();
2324 let props = obj["properties"].as_object().unwrap();
2325 assert!(props.contains_key("url"));
2326 assert!(props.contains_key("select"));
2327 assert!(props.contains_key("extract"));
2328 assert!(props.contains_key("limit"));
2329 let req = obj["required"].as_array().unwrap();
2330 assert!(req.iter().any(|v| v.as_str() == Some("url")));
2331 assert!(req.iter().any(|v| v.as_str() == Some("select")));
2332 assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
2333 }
2334
2335 #[test]
2338 fn subdomain_localhost_blocked() {
2339 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
2340 assert!(is_private_host(&host));
2341 }
2342
2343 #[test]
2344 fn internal_tld_blocked() {
2345 let host: url::Host<&str> = url::Host::Domain("service.internal");
2346 assert!(is_private_host(&host));
2347 }
2348
2349 #[test]
2350 fn local_tld_blocked() {
2351 let host: url::Host<&str> = url::Host::Domain("printer.local");
2352 assert!(is_private_host(&host));
2353 }
2354
2355 #[test]
2356 fn public_domain_not_blocked() {
2357 let host: url::Host<&str> = url::Host::Domain("example.com");
2358 assert!(!is_private_host(&host));
2359 }
2360
2361 #[tokio::test]
2364 async fn resolve_loopback_rejected() {
2365 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
2367 let result = resolve_and_validate(&url).await;
2369 assert!(
2370 result.is_err(),
2371 "loopback IP must be rejected by resolve_and_validate"
2372 );
2373 let err = result.unwrap_err();
2374 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
2375 }
2376
2377 #[tokio::test]
2378 async fn resolve_private_10_rejected() {
2379 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
2380 let result = resolve_and_validate(&url).await;
2381 assert!(result.is_err());
2382 assert!(matches!(
2383 result.unwrap_err(),
2384 crate::executor::ToolError::Blocked { .. }
2385 ));
2386 }
2387
2388 #[tokio::test]
2389 async fn resolve_private_192_rejected() {
2390 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
2391 let result = resolve_and_validate(&url).await;
2392 assert!(result.is_err());
2393 assert!(matches!(
2394 result.unwrap_err(),
2395 crate::executor::ToolError::Blocked { .. }
2396 ));
2397 }
2398
2399 #[tokio::test]
2400 async fn resolve_ipv6_loopback_rejected() {
2401 let url = url::Url::parse("https://[::1]/path").unwrap();
2402 let result = resolve_and_validate(&url).await;
2403 assert!(result.is_err());
2404 assert!(matches!(
2405 result.unwrap_err(),
2406 crate::executor::ToolError::Blocked { .. }
2407 ));
2408 }
2409
2410 #[tokio::test]
2411 async fn resolve_no_host_returns_ok() {
2412 let url = url::Url::parse("https://example.com/path").unwrap();
2414 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
2416 let result = resolve_and_validate(&url_no_host).await;
2418 assert!(result.is_ok());
2419 let (host, addrs) = result.unwrap();
2420 assert!(host.is_empty());
2421 assert!(addrs.is_empty());
2422 drop(url);
2423 drop(url_no_host);
2424 }
2425
2426 async fn make_file_audit_logger(
2430 dir: &tempfile::TempDir,
2431 ) -> (
2432 std::sync::Arc<crate::audit::AuditLogger>,
2433 std::path::PathBuf,
2434 ) {
2435 use crate::audit::AuditLogger;
2436 use crate::config::AuditConfig;
2437 let path = dir.path().join("audit.log");
2438 let config = AuditConfig {
2439 enabled: true,
2440 destination: crate::config::AuditDestination::File(path.clone()),
2441 ..Default::default()
2442 };
2443 let logger = std::sync::Arc::new(AuditLogger::from_config(&config, false).await.unwrap());
2444 (logger, path)
2445 }
2446
2447 #[tokio::test]
2448 async fn with_audit_sets_logger() {
2449 let config = ScrapeConfig::default();
2450 let executor = WebScrapeExecutor::new(&config);
2451 assert!(executor.audit_logger.is_none());
2452
2453 let dir = tempfile::tempdir().unwrap();
2454 let (logger, _path) = make_file_audit_logger(&dir).await;
2455 let executor = executor.with_audit(logger);
2456 assert!(executor.audit_logger.is_some());
2457 }
2458
2459 #[test]
2460 fn tool_error_to_audit_result_blocked_maps_correctly() {
2461 let err = ToolError::Blocked {
2462 command: "scheme not allowed: http".into(),
2463 };
2464 let result = tool_error_to_audit_result(&err);
2465 assert!(
2466 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
2467 );
2468 }
2469
2470 #[test]
2471 fn tool_error_to_audit_result_timeout_maps_correctly() {
2472 let err = ToolError::Timeout { timeout_secs: 15 };
2473 let result = tool_error_to_audit_result(&err);
2474 assert!(matches!(result, AuditResult::Timeout));
2475 }
2476
2477 #[test]
2478 fn tool_error_to_audit_result_execution_error_maps_correctly() {
2479 let err = ToolError::Execution(std::io::Error::other("connection refused"));
2480 let result = tool_error_to_audit_result(&err);
2481 assert!(
2482 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
2483 );
2484 }
2485
2486 #[tokio::test]
2487 async fn fetch_audit_blocked_url_logged() {
2488 let dir = tempfile::tempdir().unwrap();
2489 let (logger, log_path) = make_file_audit_logger(&dir).await;
2490
2491 let config = ScrapeConfig::default();
2492 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2493
2494 let call = crate::executor::ToolCall {
2495 tool_id: ToolName::new("fetch"),
2496 params: {
2497 let mut m = serde_json::Map::new();
2498 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2499 m
2500 },
2501 caller_id: None,
2502 context: None,
2503
2504 tool_call_id: String::new(),
2505 };
2506 let result = executor.execute_tool_call(&call).await;
2507 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2508
2509 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2510 assert!(
2511 content.contains("\"tool\":\"fetch\""),
2512 "expected tool=fetch in audit: {content}"
2513 );
2514 assert!(
2515 content.contains("\"type\":\"blocked\""),
2516 "expected type=blocked in audit: {content}"
2517 );
2518 assert!(
2519 content.contains("http://example.com"),
2520 "expected URL in audit command field: {content}"
2521 );
2522 }
2523
2524 #[tokio::test]
2525 async fn log_audit_success_writes_to_file() {
2526 let dir = tempfile::tempdir().unwrap();
2527 let (logger, log_path) = make_file_audit_logger(&dir).await;
2528
2529 let config = ScrapeConfig::default();
2530 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2531
2532 executor
2533 .log_audit(
2534 "fetch",
2535 "https://example.com/page",
2536 AuditResult::Success,
2537 42,
2538 None,
2539 None,
2540 None,
2541 )
2542 .await;
2543
2544 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2545 assert!(
2546 content.contains("\"tool\":\"fetch\""),
2547 "expected tool=fetch in audit: {content}"
2548 );
2549 assert!(
2550 content.contains("\"type\":\"success\""),
2551 "expected type=success in audit: {content}"
2552 );
2553 assert!(
2554 content.contains("\"command\":\"https://example.com/page\""),
2555 "expected command URL in audit: {content}"
2556 );
2557 assert!(
2558 content.contains("\"duration_ms\":42"),
2559 "expected duration_ms in audit: {content}"
2560 );
2561 }
2562
2563 #[tokio::test]
2564 async fn log_audit_blocked_writes_to_file() {
2565 let dir = tempfile::tempdir().unwrap();
2566 let (logger, log_path) = make_file_audit_logger(&dir).await;
2567
2568 let config = ScrapeConfig::default();
2569 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2570
2571 executor
2572 .log_audit(
2573 "web_scrape",
2574 "http://evil.com/page",
2575 AuditResult::Blocked {
2576 reason: "scheme not allowed: http".into(),
2577 },
2578 0,
2579 None,
2580 None,
2581 None,
2582 )
2583 .await;
2584
2585 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2586 assert!(
2587 content.contains("\"tool\":\"web_scrape\""),
2588 "expected tool=web_scrape in audit: {content}"
2589 );
2590 assert!(
2591 content.contains("\"type\":\"blocked\""),
2592 "expected type=blocked in audit: {content}"
2593 );
2594 assert!(
2595 content.contains("scheme not allowed"),
2596 "expected block reason in audit: {content}"
2597 );
2598 }
2599
2600 #[tokio::test]
2601 async fn web_scrape_audit_blocked_url_logged() {
2602 let dir = tempfile::tempdir().unwrap();
2603 let (logger, log_path) = make_file_audit_logger(&dir).await;
2604
2605 let config = ScrapeConfig::default();
2606 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2607
2608 let call = crate::executor::ToolCall {
2609 tool_id: ToolName::new("web_scrape"),
2610 params: {
2611 let mut m = serde_json::Map::new();
2612 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2613 m.insert("select".to_owned(), serde_json::json!("h1"));
2614 m
2615 },
2616 caller_id: None,
2617 context: None,
2618
2619 tool_call_id: String::new(),
2620 };
2621 let result = executor.execute_tool_call(&call).await;
2622 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2623
2624 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2625 assert!(
2626 content.contains("\"tool\":\"web_scrape\""),
2627 "expected tool=web_scrape in audit: {content}"
2628 );
2629 assert!(
2630 content.contains("\"type\":\"blocked\""),
2631 "expected type=blocked in audit: {content}"
2632 );
2633 }
2634
2635 #[tokio::test]
2636 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2637 let config = ScrapeConfig::default();
2638 let executor = WebScrapeExecutor::new(&config);
2639 assert!(executor.audit_logger.is_none());
2640
2641 let call = crate::executor::ToolCall {
2642 tool_id: ToolName::new("fetch"),
2643 params: {
2644 let mut m = serde_json::Map::new();
2645 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2646 m
2647 },
2648 caller_id: None,
2649 context: None,
2650
2651 tool_call_id: String::new(),
2652 };
2653 let result = executor.execute_tool_call(&call).await;
2655 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2656 }
2657
2658 #[tokio::test]
2660 async fn fetch_execute_tool_call_end_to_end() {
2661 use wiremock::matchers::{method, path};
2662 use wiremock::{Mock, ResponseTemplate};
2663
2664 let (executor, server) = mock_server_executor().await;
2665 Mock::given(method("GET"))
2666 .and(path("/e2e"))
2667 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2668 .mount(&server)
2669 .await;
2670
2671 let (host, addrs) = server_host_and_addr(&server);
2672 let result = executor
2674 .fetch_html(
2675 &format!("{}/e2e", server.uri()),
2676 &host,
2677 &addrs,
2678 "fetch",
2679 "test-cid",
2680 None,
2681 )
2682 .await;
2683 assert!(result.is_ok());
2684 assert!(result.unwrap().contains("end-to-end"));
2685 }
2686
2687 #[test]
2690 fn domain_matches_exact() {
2691 assert!(domain_matches("example.com", "example.com"));
2692 assert!(!domain_matches("example.com", "other.com"));
2693 assert!(!domain_matches("example.com", "sub.example.com"));
2694 }
2695
2696 #[test]
2697 fn domain_matches_wildcard_single_subdomain() {
2698 assert!(domain_matches("*.example.com", "sub.example.com"));
2699 assert!(!domain_matches("*.example.com", "example.com"));
2700 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2701 }
2702
2703 #[test]
2704 fn domain_matches_wildcard_does_not_match_empty_label() {
2705 assert!(!domain_matches("*.example.com", ".example.com"));
2707 }
2708
2709 #[test]
2710 fn domain_matches_multi_wildcard_treated_as_exact() {
2711 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2713 }
2714
2715 #[test]
2718 fn check_domain_policy_empty_lists_allow_all() {
2719 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2720 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2721 }
2722
2723 #[test]
2724 fn check_domain_policy_denylist_blocks() {
2725 let denied = vec!["evil.com".to_string()];
2726 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2727 assert!(matches!(err, ToolError::Blocked { .. }));
2728 }
2729
2730 #[test]
2731 fn check_domain_policy_denylist_does_not_block_other_domains() {
2732 let denied = vec!["evil.com".to_string()];
2733 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2734 }
2735
2736 #[test]
2737 fn check_domain_policy_allowlist_permits_matching() {
2738 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2739 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2740 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2741 }
2742
2743 #[test]
2744 fn check_domain_policy_allowlist_blocks_unknown() {
2745 let allowed = vec!["docs.rs".to_string()];
2746 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2747 assert!(matches!(err, ToolError::Blocked { .. }));
2748 }
2749
2750 #[test]
2751 fn check_domain_policy_deny_overrides_allow() {
2752 let allowed = vec!["example.com".to_string()];
2753 let denied = vec!["example.com".to_string()];
2754 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2755 assert!(matches!(err, ToolError::Blocked { .. }));
2756 }
2757
2758 #[test]
2759 fn check_domain_policy_wildcard_in_denylist() {
2760 let denied = vec!["*.evil.com".to_string()];
2761 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2762 assert!(matches!(err, ToolError::Blocked { .. }));
2763 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2765 }
2766}