1use std::net::{IpAddr, SocketAddr};
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::time::{Duration, Instant};
23
24use schemars::JsonSchema;
25use serde::Deserialize;
26use url::Url;
27
28use zeph_common::ToolName;
29
30use crate::audit::{AuditEntry, AuditLogger, AuditResult, EgressEvent, chrono_now};
31use crate::config::{EgressConfig, ScrapeConfig};
32use crate::executor::{
33 ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
34};
35use crate::net::is_private_ip;
36
37fn redact_url_for_log(url: &str) -> String {
41 let Ok(mut parsed) = Url::parse(url) else {
42 return url.to_owned();
43 };
44 let _ = parsed.set_username("");
46 let _ = parsed.set_password(None);
47 let sensitive = [
49 "token", "key", "secret", "password", "auth", "sig", "api_key", "apikey",
50 ];
51 let filtered: Vec<(String, String)> = parsed
52 .query_pairs()
53 .filter(|(k, _)| {
54 let lower = k.to_lowercase();
55 !sensitive.iter().any(|s| lower.contains(s))
56 })
57 .map(|(k, v)| (k.into_owned(), v.into_owned()))
58 .collect();
59 if filtered.is_empty() {
60 parsed.set_query(None);
61 } else {
62 let q: String = filtered
63 .iter()
64 .map(|(k, v)| format!("{k}={v}"))
65 .collect::<Vec<_>>()
66 .join("&");
67 parsed.set_query(Some(&q));
68 }
69 parsed.to_string()
70}
71
/// Parameters accepted by the `fetch` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    // URL to fetch; must be HTTPS (enforced later by `validate_url`).
    url: String,
}
77
/// A single scrape request, parsed either from `web_scrape` tool-call params
/// or from a ```scrape fenced block in a model response.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    // Page to fetch; must be HTTPS (enforced later by `validate_url`).
    url: String,
    // CSS selector choosing which elements to extract.
    select: String,
    // Extraction mode: "text", "html", or "attr:<name>"; defaults to "text".
    #[serde(default = "default_extract")]
    extract: String,
    // Maximum number of matches returned; `scrape_instruction` defaults to 10.
    limit: Option<usize>,
}
90
/// Serde default for `ScrapeInstruction::extract`.
fn default_extract() -> String {
    String::from("text")
}
94
/// How matched elements are converted into output strings.
#[derive(Debug)]
enum ExtractMode {
    /// Text content of the element.
    Text,
    /// Inner HTML of the element.
    Html,
    /// Value of the named attribute.
    Attr(String),
}

impl ExtractMode {
    /// Parses an extract spec: "text", "html", or "attr:<name>".
    /// Anything unrecognized falls back to `Text`.
    fn parse(s: &str) -> Self {
        if s == "html" {
            Self::Html
        } else if let Some(attr_name) = s.strip_prefix("attr:") {
            Self::Attr(attr_name.to_owned())
        } else {
            Self::Text
        }
    }
}
114
/// Tool executor for the `fetch` and `web_scrape` tools: HTTPS-only fetching
/// with SSRF protections (private host/IP blocking, DNS pinning, manual
/// redirect re-validation), a domain allow/deny policy, and optional audit
/// logging plus egress telemetry.
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request timeout applied to the HTTP client.
    timeout: Duration,
    // Maximum response body size accepted before the fetch is aborted.
    max_body_bytes: usize,
    // If non-empty, only hosts matching these patterns may be fetched.
    allowed_domains: Vec<String>,
    // Hosts matching these patterns are always refused (checked first).
    denied_domains: Vec<String>,
    // Optional audit sink for tool invocations and egress events.
    audit_logger: Option<Arc<AuditLogger>>,
    // Controls whether/which egress events are emitted.
    egress_config: EgressConfig,
    // Optional live telemetry channel for egress events.
    egress_tx: Option<tokio::sync::mpsc::Sender<EgressEvent>>,
    // Count of egress events dropped because the channel was full.
    egress_dropped: Arc<AtomicU64>,
}
164
165impl WebScrapeExecutor {
166 #[must_use]
170 pub fn new(config: &ScrapeConfig) -> Self {
171 Self {
172 timeout: Duration::from_secs(config.timeout),
173 max_body_bytes: config.max_body_bytes,
174 allowed_domains: config.allowed_domains.clone(),
175 denied_domains: config.denied_domains.clone(),
176 audit_logger: None,
177 egress_config: EgressConfig::default(),
178 egress_tx: None,
179 egress_dropped: Arc::new(AtomicU64::new(0)),
180 }
181 }
182
183 #[must_use]
185 pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
186 self.audit_logger = Some(logger);
187 self
188 }
189
190 #[must_use]
192 pub fn with_egress_config(mut self, config: EgressConfig) -> Self {
193 self.egress_config = config;
194 self
195 }
196
197 #[must_use]
202 pub fn with_egress_tx(
203 mut self,
204 tx: tokio::sync::mpsc::Sender<EgressEvent>,
205 dropped: Arc<AtomicU64>,
206 ) -> Self {
207 self.egress_tx = Some(tx);
208 self.egress_dropped = dropped;
209 self
210 }
211
212 #[must_use]
214 pub fn egress_dropped(&self) -> Arc<AtomicU64> {
215 Arc::clone(&self.egress_dropped)
216 }
217
218 fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
219 let mut builder = reqwest::Client::builder()
220 .timeout(self.timeout)
221 .redirect(reqwest::redirect::Policy::none());
222 builder = builder.resolve_to_addrs(host, addrs);
223 builder.build().unwrap_or_default()
224 }
225}
226
impl ToolExecutor for WebScrapeExecutor {
    /// Advertises the `web_scrape` (fenced-block) and `fetch` (tool-call)
    /// tools together with their JSON schemas.
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        use crate::registry::{InvocationHint, ToolDef};
        vec![
            ToolDef {
                id: "web_scrape".into(),
                description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
                schema: schemars::schema_for!(ScrapeInstruction),
                invocation: InvocationHint::FencedBlock("scrape"),
                output_schema: None,
            },
            ToolDef {
                id: "fetch".into(),
                description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
                schema: schemars::schema_for!(FetchParams),
                invocation: InvocationHint::ToolCall,
                output_schema: None,
            },
        ]
    }

    /// Executes every ```scrape fenced block found in a model response.
    ///
    /// Blocks run sequentially; each attempt is audit-logged with its own
    /// correlation id. The first failure aborts the batch and is returned
    /// (earlier successes remain audit-logged). Returns `Ok(None)` when the
    /// response contains no scrape blocks.
    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
        let blocks = extract_scrape_blocks(response);
        if blocks.is_empty() {
            return Ok(None);
        }

        let mut outputs = Vec::with_capacity(blocks.len());
        #[allow(clippy::cast_possible_truncation)]
        let blocks_executed = blocks.len() as u32;

        for block in &blocks {
            // A malformed block aborts the whole batch with InvalidData.
            let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
                ToolError::Execution(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    e.to_string(),
                ))
            })?;
            let correlation_id = EgressEvent::new_correlation_id();
            let start = Instant::now();
            let scrape_result = self
                .scrape_instruction(&instruction, &correlation_id, None)
                .await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match scrape_result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                        None,
                        Some(correlation_id),
                    )
                    .await;
                    outputs.push(output);
                }
                Err(e) => {
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                        None,
                        Some(correlation_id),
                    )
                    .await;
                    return Err(e);
                }
            }
        }

        Ok(Some(ToolOutput {
            tool_name: ToolName::new("web-scrape"),
            summary: outputs.join("\n\n"),
            blocks_executed,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: Some(ClaimSource::WebScrape),
        }))
    }

    /// Dispatches a structured tool call to `web_scrape` or `fetch`.
    ///
    /// Both arms follow the same shape: deserialize params, run with a fresh
    /// correlation id, then audit-log success or failure with the measured
    /// duration. Unknown tool ids return `Ok(None)` so other executors can
    /// handle them.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "tool.web_scrape", skip_all)
    )]
    #[allow(clippy::too_many_lines)]
    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
        match call.tool_id.as_str() {
            "web_scrape" => {
                let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
                let correlation_id = EgressEvent::new_correlation_id();
                let start = Instant::now();
                let result = self
                    .scrape_instruction(&instruction, &correlation_id, call.caller_id.clone())
                    .await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                            call.caller_id.clone(),
                            Some(correlation_id),
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: ToolName::new("web-scrape"),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "web_scrape",
                            &instruction.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                            call.caller_id.clone(),
                            Some(correlation_id),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            "fetch" => {
                let p: FetchParams = deserialize_params(&call.params)?;
                let correlation_id = EgressEvent::new_correlation_id();
                let start = Instant::now();
                let result = self
                    .handle_fetch(&p, &correlation_id, call.caller_id.clone())
                    .await;
                #[allow(clippy::cast_possible_truncation)]
                let duration_ms = start.elapsed().as_millis() as u64;
                match result {
                    Ok(output) => {
                        self.log_audit(
                            "fetch",
                            &p.url,
                            AuditResult::Success,
                            duration_ms,
                            None,
                            call.caller_id.clone(),
                            Some(correlation_id),
                        )
                        .await;
                        Ok(Some(ToolOutput {
                            tool_name: ToolName::new("fetch"),
                            summary: output,
                            blocks_executed: 1,
                            filter_stats: None,
                            diff: None,
                            streamed: false,
                            terminal_id: None,
                            locations: None,
                            raw_response: None,
                            claim_source: Some(ClaimSource::WebScrape),
                        }))
                    }
                    Err(e) => {
                        let audit_result = tool_error_to_audit_result(&e);
                        self.log_audit(
                            "fetch",
                            &p.url,
                            audit_result,
                            duration_ms,
                            Some(&e),
                            call.caller_id.clone(),
                            Some(correlation_id),
                        )
                        .await;
                        Err(e)
                    }
                }
            }
            _ => Ok(None),
        }
    }

    /// Both tools are idempotent GETs, so retrying them is safe.
    fn is_tool_retryable(&self, tool_id: &str) -> bool {
        matches!(tool_id, "web_scrape" | "fetch")
    }
}
432
433fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
434 match e {
435 ToolError::Blocked { command } => AuditResult::Blocked {
436 reason: command.clone(),
437 },
438 ToolError::Timeout { .. } => AuditResult::Timeout,
439 _ => AuditResult::Error {
440 message: e.to_string(),
441 },
442 }
443}
444
445impl WebScrapeExecutor {
446 #[allow(clippy::too_many_arguments)]
447 async fn log_audit(
448 &self,
449 tool: &str,
450 command: &str,
451 result: AuditResult,
452 duration_ms: u64,
453 error: Option<&ToolError>,
454 caller_id: Option<String>,
455 correlation_id: Option<String>,
456 ) {
457 if let Some(ref logger) = self.audit_logger {
458 let (error_category, error_domain, error_phase) =
459 error.map_or((None, None, None), |e| {
460 let cat = e.category();
461 (
462 Some(cat.label().to_owned()),
463 Some(cat.domain().label().to_owned()),
464 Some(cat.phase().label().to_owned()),
465 )
466 });
467 let entry = AuditEntry {
468 timestamp: chrono_now(),
469 tool: tool.into(),
470 command: command.into(),
471 result,
472 duration_ms,
473 error_category,
474 error_domain,
475 error_phase,
476 claim_source: Some(ClaimSource::WebScrape),
477 mcp_server_id: None,
478 injection_flagged: false,
479 embedding_anomalous: false,
480 cross_boundary_mcp_to_acp: false,
481 adversarial_policy_decision: None,
482 exit_code: None,
483 truncated: false,
484 caller_id,
485 policy_match: None,
486 correlation_id,
487 vigil_risk: None,
488 };
489 logger.log(&entry).await;
490 }
491 }
492
493 fn send_egress_event(&self, event: EgressEvent) {
494 if let Some(ref tx) = self.egress_tx {
495 match tx.try_send(event) {
496 Ok(()) => {}
497 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
498 self.egress_dropped.fetch_add(1, Ordering::Relaxed);
499 }
500 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
501 tracing::debug!("egress channel closed; executor continuing without telemetry");
502 }
503 }
504 }
505 }
506
507 async fn log_egress_event(&self, event: &EgressEvent) {
508 if let Some(ref logger) = self.audit_logger {
509 logger.log_egress(event).await;
510 }
511 self.send_egress_event(event.clone());
512 }
513
514 async fn handle_fetch(
515 &self,
516 params: &FetchParams,
517 correlation_id: &str,
518 caller_id: Option<String>,
519 ) -> Result<String, ToolError> {
520 let parsed = validate_url(¶ms.url);
521 let host_str = parsed
522 .as_ref()
523 .map(|u| u.host_str().unwrap_or("").to_owned())
524 .unwrap_or_default();
525
526 if let Err(ref _e) = parsed {
527 if self.egress_config.enabled && self.egress_config.log_blocked {
528 let event = Self::make_blocked_event(
529 "fetch",
530 ¶ms.url,
531 &host_str,
532 correlation_id,
533 caller_id.clone(),
534 "scheme",
535 );
536 self.log_egress_event(&event).await;
537 }
538 return Err(parsed.unwrap_err());
539 }
540 let parsed = parsed.unwrap();
541
542 if let Err(e) = check_domain_policy(
543 parsed.host_str().unwrap_or(""),
544 &self.allowed_domains,
545 &self.denied_domains,
546 ) {
547 if self.egress_config.enabled && self.egress_config.log_blocked {
548 let event = Self::make_blocked_event(
549 "fetch",
550 ¶ms.url,
551 parsed.host_str().unwrap_or(""),
552 correlation_id,
553 caller_id.clone(),
554 "blocklist",
555 );
556 self.log_egress_event(&event).await;
557 }
558 return Err(e);
559 }
560
561 let (host, addrs) = match resolve_and_validate(&parsed).await {
562 Ok(v) => v,
563 Err(e) => {
564 if self.egress_config.enabled && self.egress_config.log_blocked {
565 let event = Self::make_blocked_event(
566 "fetch",
567 ¶ms.url,
568 parsed.host_str().unwrap_or(""),
569 correlation_id,
570 caller_id.clone(),
571 "ssrf",
572 );
573 self.log_egress_event(&event).await;
574 }
575 return Err(e);
576 }
577 };
578
579 self.fetch_html(
580 ¶ms.url,
581 &host,
582 &addrs,
583 "fetch",
584 correlation_id,
585 caller_id,
586 )
587 .await
588 }
589
590 async fn scrape_instruction(
591 &self,
592 instruction: &ScrapeInstruction,
593 correlation_id: &str,
594 caller_id: Option<String>,
595 ) -> Result<String, ToolError> {
596 let parsed = validate_url(&instruction.url);
597 let host_str = parsed
598 .as_ref()
599 .map(|u| u.host_str().unwrap_or("").to_owned())
600 .unwrap_or_default();
601
602 if let Err(ref _e) = parsed {
603 if self.egress_config.enabled && self.egress_config.log_blocked {
604 let event = Self::make_blocked_event(
605 "web_scrape",
606 &instruction.url,
607 &host_str,
608 correlation_id,
609 caller_id.clone(),
610 "scheme",
611 );
612 self.log_egress_event(&event).await;
613 }
614 return Err(parsed.unwrap_err());
615 }
616 let parsed = parsed.unwrap();
617
618 if let Err(e) = check_domain_policy(
619 parsed.host_str().unwrap_or(""),
620 &self.allowed_domains,
621 &self.denied_domains,
622 ) {
623 if self.egress_config.enabled && self.egress_config.log_blocked {
624 let event = Self::make_blocked_event(
625 "web_scrape",
626 &instruction.url,
627 parsed.host_str().unwrap_or(""),
628 correlation_id,
629 caller_id.clone(),
630 "blocklist",
631 );
632 self.log_egress_event(&event).await;
633 }
634 return Err(e);
635 }
636
637 let (host, addrs) = match resolve_and_validate(&parsed).await {
638 Ok(v) => v,
639 Err(e) => {
640 if self.egress_config.enabled && self.egress_config.log_blocked {
641 let event = Self::make_blocked_event(
642 "web_scrape",
643 &instruction.url,
644 parsed.host_str().unwrap_or(""),
645 correlation_id,
646 caller_id.clone(),
647 "ssrf",
648 );
649 self.log_egress_event(&event).await;
650 }
651 return Err(e);
652 }
653 };
654
655 let html = self
656 .fetch_html(
657 &instruction.url,
658 &host,
659 &addrs,
660 "web_scrape",
661 correlation_id,
662 caller_id,
663 )
664 .await?;
665 let selector = instruction.select.clone();
666 let extract = ExtractMode::parse(&instruction.extract);
667 let limit = instruction.limit.unwrap_or(10);
668 tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
669 .await
670 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
671 }
672
673 fn make_blocked_event(
674 tool: &str,
675 url: &str,
676 host: &str,
677 correlation_id: &str,
678 caller_id: Option<String>,
679 block_reason: &'static str,
680 ) -> EgressEvent {
681 EgressEvent {
682 timestamp: chrono_now(),
683 kind: "egress",
684 correlation_id: correlation_id.to_owned(),
685 tool: tool.into(),
686 url: redact_url_for_log(url),
687 host: host.to_owned(),
688 method: "GET".to_owned(),
689 status: None,
690 duration_ms: 0,
691 response_bytes: 0,
692 blocked: true,
693 block_reason: Some(block_reason),
694 caller_id,
695 hop: 0,
696 }
697 }
698
699 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
710 async fn fetch_html(
711 &self,
712 url: &str,
713 host: &str,
714 addrs: &[SocketAddr],
715 tool: &str,
716 correlation_id: &str,
717 caller_id: Option<String>,
718 ) -> Result<String, ToolError> {
719 const MAX_REDIRECTS: usize = 3;
720
721 let mut current_url = url.to_owned();
722 let mut current_host = host.to_owned();
723 let mut current_addrs = addrs.to_vec();
724
725 for hop in 0..=MAX_REDIRECTS {
726 let hop_start = Instant::now();
727 let client = self.build_client(¤t_host, ¤t_addrs);
729 let resp = client
730 .get(¤t_url)
731 .send()
732 .await
733 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
734
735 let resp = match resp {
736 Ok(r) => r,
737 Err(e) => {
738 if self.egress_config.enabled {
739 #[allow(clippy::cast_possible_truncation)]
740 let duration_ms = hop_start.elapsed().as_millis() as u64;
741 let event = EgressEvent {
742 timestamp: chrono_now(),
743 kind: "egress",
744 correlation_id: correlation_id.to_owned(),
745 tool: tool.into(),
746 url: redact_url_for_log(¤t_url),
747 host: current_host.clone(),
748 method: "GET".to_owned(),
749 status: None,
750 duration_ms,
751 response_bytes: 0,
752 blocked: false,
753 block_reason: None,
754 caller_id: caller_id.clone(),
755 #[allow(clippy::cast_possible_truncation)]
756 hop: hop as u8,
757 };
758 self.log_egress_event(&event).await;
759 }
760 return Err(e);
761 }
762 };
763
764 let status = resp.status();
765
766 if status.is_redirection() {
767 if hop == MAX_REDIRECTS {
768 return Err(ToolError::Execution(std::io::Error::other(
769 "too many redirects",
770 )));
771 }
772
773 let location = resp
774 .headers()
775 .get(reqwest::header::LOCATION)
776 .and_then(|v| v.to_str().ok())
777 .ok_or_else(|| {
778 ToolError::Execution(std::io::Error::other("redirect with no Location"))
779 })?;
780
781 let base = Url::parse(¤t_url)
783 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
784 let next_url = base
785 .join(location)
786 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
787
788 let validated = validate_url(next_url.as_str());
789 if let Err(ref _e) = validated {
790 if self.egress_config.enabled && self.egress_config.log_blocked {
791 #[allow(clippy::cast_possible_truncation)]
792 let duration_ms = hop_start.elapsed().as_millis() as u64;
793 let next_host = next_url.host_str().unwrap_or("").to_owned();
794 let event = EgressEvent {
795 timestamp: chrono_now(),
796 kind: "egress",
797 correlation_id: correlation_id.to_owned(),
798 tool: tool.into(),
799 url: redact_url_for_log(next_url.as_str()),
800 host: next_host,
801 method: "GET".to_owned(),
802 status: None,
803 duration_ms,
804 response_bytes: 0,
805 blocked: true,
806 block_reason: Some("ssrf"),
807 caller_id: caller_id.clone(),
808 #[allow(clippy::cast_possible_truncation)]
809 hop: (hop + 1) as u8,
810 };
811 self.log_egress_event(&event).await;
812 }
813 return Err(validated.unwrap_err());
814 }
815 let validated = validated.unwrap();
816 let resolve_result = resolve_and_validate(&validated).await;
817 if let Err(ref _e) = resolve_result {
818 if self.egress_config.enabled && self.egress_config.log_blocked {
819 #[allow(clippy::cast_possible_truncation)]
820 let duration_ms = hop_start.elapsed().as_millis() as u64;
821 let next_host = next_url.host_str().unwrap_or("").to_owned();
822 let event = EgressEvent {
823 timestamp: chrono_now(),
824 kind: "egress",
825 correlation_id: correlation_id.to_owned(),
826 tool: tool.into(),
827 url: redact_url_for_log(next_url.as_str()),
828 host: next_host,
829 method: "GET".to_owned(),
830 status: None,
831 duration_ms,
832 response_bytes: 0,
833 blocked: true,
834 block_reason: Some("ssrf"),
835 caller_id: caller_id.clone(),
836 #[allow(clippy::cast_possible_truncation)]
837 hop: (hop + 1) as u8,
838 };
839 self.log_egress_event(&event).await;
840 }
841 return Err(resolve_result.unwrap_err());
842 }
843 let (next_host, next_addrs) = resolve_result.unwrap();
844
845 current_url = next_url.to_string();
846 current_host = next_host;
847 current_addrs = next_addrs;
848 continue;
849 }
850
851 if !status.is_success() {
852 if self.egress_config.enabled {
853 #[allow(clippy::cast_possible_truncation)]
854 let duration_ms = hop_start.elapsed().as_millis() as u64;
855 let event = EgressEvent {
856 timestamp: chrono_now(),
857 kind: "egress",
858 correlation_id: correlation_id.to_owned(),
859 tool: tool.into(),
860 url: current_url.clone(),
861 host: current_host.clone(),
862 method: "GET".to_owned(),
863 status: Some(status.as_u16()),
864 duration_ms,
865 response_bytes: 0,
866 blocked: false,
867 block_reason: None,
868 caller_id: caller_id.clone(),
869 #[allow(clippy::cast_possible_truncation)]
870 hop: hop as u8,
871 };
872 self.log_egress_event(&event).await;
873 }
874 return Err(ToolError::Http {
875 status: status.as_u16(),
876 message: status.canonical_reason().unwrap_or("unknown").to_owned(),
877 });
878 }
879
880 let bytes = resp
881 .bytes()
882 .await
883 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
884
885 if bytes.len() > self.max_body_bytes {
886 if self.egress_config.enabled {
887 #[allow(clippy::cast_possible_truncation)]
888 let duration_ms = hop_start.elapsed().as_millis() as u64;
889 let event = EgressEvent {
890 timestamp: chrono_now(),
891 kind: "egress",
892 correlation_id: correlation_id.to_owned(),
893 tool: tool.into(),
894 url: current_url.clone(),
895 host: current_host.clone(),
896 method: "GET".to_owned(),
897 status: Some(status.as_u16()),
898 duration_ms,
899 response_bytes: bytes.len(),
900 blocked: false,
901 block_reason: None,
902 caller_id: caller_id.clone(),
903 #[allow(clippy::cast_possible_truncation)]
904 hop: hop as u8,
905 };
906 self.log_egress_event(&event).await;
907 }
908 return Err(ToolError::Execution(std::io::Error::other(format!(
909 "response too large: {} bytes (max: {})",
910 bytes.len(),
911 self.max_body_bytes,
912 ))));
913 }
914
915 if self.egress_config.enabled {
917 #[allow(clippy::cast_possible_truncation)]
918 let duration_ms = hop_start.elapsed().as_millis() as u64;
919 let response_bytes = if self.egress_config.log_response_bytes {
920 bytes.len()
921 } else {
922 0
923 };
924 let event = EgressEvent {
925 timestamp: chrono_now(),
926 kind: "egress",
927 correlation_id: correlation_id.to_owned(),
928 tool: tool.into(),
929 url: current_url.clone(),
930 host: current_host.clone(),
931 method: "GET".to_owned(),
932 status: Some(status.as_u16()),
933 duration_ms,
934 response_bytes,
935 blocked: false,
936 block_reason: None,
937 caller_id: caller_id.clone(),
938 #[allow(clippy::cast_possible_truncation)]
939 hop: hop as u8,
940 };
941 self.log_egress_event(&event).await;
942 }
943
944 return String::from_utf8(bytes.to_vec())
945 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
946 }
947
948 Err(ToolError::Execution(std::io::Error::other(
949 "too many redirects",
950 )))
951 }
952}
953
/// Collects the bodies of all ```scrape fenced code blocks in `text`.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
957
958fn check_domain_policy(
970 host: &str,
971 allowed_domains: &[String],
972 denied_domains: &[String],
973) -> Result<(), ToolError> {
974 if denied_domains.iter().any(|p| domain_matches(p, host)) {
975 return Err(ToolError::Blocked {
976 command: format!("domain blocked by denylist: {host}"),
977 });
978 }
979 if !allowed_domains.is_empty() {
980 let is_ip = host.parse::<std::net::IpAddr>().is_ok()
982 || (host.starts_with('[') && host.ends_with(']'));
983 if is_ip {
984 return Err(ToolError::Blocked {
985 command: format!(
986 "bare IP address not allowed when domain allowlist is active: {host}"
987 ),
988 });
989 }
990 if !allowed_domains.iter().any(|p| domain_matches(p, host)) {
991 return Err(ToolError::Blocked {
992 command: format!("domain not in allowlist: {host}"),
993 });
994 }
995 }
996 Ok(())
997}
998
/// Tests `host` against a single policy pattern.
///
/// A pattern of the form `*.example.com` matches exactly one extra label
/// (`a.example.com`, but neither `example.com` nor `a.b.example.com`);
/// any other pattern must equal the host exactly.
fn domain_matches(pattern: &str, host: &str) -> bool {
    match pattern.strip_prefix("*.") {
        Some(base) => host
            .strip_suffix(base)
            .and_then(|rest| rest.strip_suffix('.'))
            .is_some_and(|label| !label.is_empty() && !label.contains('.')),
        None => pattern == host,
    }
}
1018
1019fn validate_url(raw: &str) -> Result<Url, ToolError> {
1020 let parsed = Url::parse(raw).map_err(|_| ToolError::Blocked {
1021 command: format!("invalid URL: {raw}"),
1022 })?;
1023
1024 if parsed.scheme() != "https" {
1025 return Err(ToolError::Blocked {
1026 command: format!("scheme not allowed: {}", parsed.scheme()),
1027 });
1028 }
1029
1030 if let Some(host) = parsed.host()
1031 && is_private_host(&host)
1032 {
1033 return Err(ToolError::Blocked {
1034 command: format!(
1035 "private/local host blocked: {}",
1036 parsed.host_str().unwrap_or("")
1037 ),
1038 });
1039 }
1040
1041 Ok(parsed)
1042}
1043
1044fn is_private_host(host: &url::Host<&str>) -> bool {
1045 match host {
1046 url::Host::Domain(d) => {
1047 #[allow(clippy::case_sensitive_file_extension_comparisons)]
1050 {
1051 *d == "localhost"
1052 || d.ends_with(".localhost")
1053 || d.ends_with(".internal")
1054 || d.ends_with(".local")
1055 }
1056 }
1057 url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
1058 url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
1059 }
1060}
1061
1062async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
1068 let Some(host) = url.host_str() else {
1069 return Ok((String::new(), vec![]));
1070 };
1071 let port = url.port_or_known_default().unwrap_or(443);
1072 let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
1073 .await
1074 .map_err(|e| ToolError::Blocked {
1075 command: format!("DNS resolution failed: {e}"),
1076 })?
1077 .collect();
1078 for addr in &addrs {
1079 if is_private_ip(addr.ip()) {
1080 return Err(ToolError::Blocked {
1081 command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
1082 });
1083 }
1084 }
1085 Ok((host.to_owned(), addrs))
1086}
1087
1088fn parse_and_extract(
1089 html: &str,
1090 selector: &str,
1091 extract: &ExtractMode,
1092 limit: usize,
1093) -> Result<String, ToolError> {
1094 let soup = scrape_core::Soup::parse(html);
1095
1096 let tags = soup.find_all(selector).map_err(|e| {
1097 ToolError::Execution(std::io::Error::new(
1098 std::io::ErrorKind::InvalidData,
1099 format!("invalid selector: {e}"),
1100 ))
1101 })?;
1102
1103 let mut results = Vec::new();
1104
1105 for tag in tags.into_iter().take(limit) {
1106 let value = match extract {
1107 ExtractMode::Text => tag.text(),
1108 ExtractMode::Html => tag.inner_html(),
1109 ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
1110 };
1111 if !value.trim().is_empty() {
1112 results.push(value.trim().to_owned());
1113 }
1114 }
1115
1116 if results.is_empty() {
1117 Ok(format!("No results for selector: {selector}"))
1118 } else {
1119 Ok(results.join("\n"))
1120 }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125 use super::*;
1126
1127 #[test]
1130 fn extract_single_block() {
1131 let text =
1132 "Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
1133 let blocks = extract_scrape_blocks(text);
1134 assert_eq!(blocks.len(), 1);
1135 assert!(blocks[0].contains("example.com"));
1136 }
1137
1138 #[test]
1139 fn extract_multiple_blocks() {
1140 let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
1141 let blocks = extract_scrape_blocks(text);
1142 assert_eq!(blocks.len(), 2);
1143 }
1144
1145 #[test]
1146 fn no_blocks_returns_empty() {
1147 let blocks = extract_scrape_blocks("plain text, no code blocks");
1148 assert!(blocks.is_empty());
1149 }
1150
1151 #[test]
1152 fn unclosed_block_ignored() {
1153 let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
1154 assert!(blocks.is_empty());
1155 }
1156
1157 #[test]
1158 fn non_scrape_block_ignored() {
1159 let text =
1160 "```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
1161 let blocks = extract_scrape_blocks(text);
1162 assert_eq!(blocks.len(), 1);
1163 assert!(blocks[0].contains("x.com"));
1164 }
1165
1166 #[test]
1167 fn multiline_json_block() {
1168 let text =
1169 "```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
1170 let blocks = extract_scrape_blocks(text);
1171 assert_eq!(blocks.len(), 1);
1172 let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
1173 assert_eq!(instr.url, "https://example.com");
1174 }
1175
1176 #[test]
1179 fn parse_valid_instruction() {
1180 let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
1181 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1182 assert_eq!(instr.url, "https://example.com");
1183 assert_eq!(instr.select, "h1");
1184 assert_eq!(instr.extract, "text");
1185 assert_eq!(instr.limit, Some(5));
1186 }
1187
1188 #[test]
1189 fn parse_minimal_instruction() {
1190 let json = r#"{"url":"https://example.com","select":"p"}"#;
1191 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1192 assert_eq!(instr.extract, "text");
1193 assert!(instr.limit.is_none());
1194 }
1195
1196 #[test]
1197 fn parse_attr_extract() {
1198 let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
1199 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1200 assert_eq!(instr.extract, "attr:href");
1201 }
1202
1203 #[test]
1204 fn parse_invalid_json_errors() {
1205 let result = serde_json::from_str::<ScrapeInstruction>("not json");
1206 assert!(result.is_err());
1207 }
1208
1209 #[test]
1212 fn extract_mode_text() {
1213 assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
1214 }
1215
1216 #[test]
1217 fn extract_mode_html() {
1218 assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
1219 }
1220
1221 #[test]
1222 fn extract_mode_attr() {
1223 let mode = ExtractMode::parse("attr:href");
1224 assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
1225 }
1226
1227 #[test]
1228 fn extract_mode_unknown_defaults_to_text() {
1229 assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
1230 }
1231
    // --- validate_url: scheme enforcement and private-address blocking ---

    /// HTTPS to a public registrable domain passes.
    #[test]
    fn valid_https_url() {
        assert!(validate_url("https://example.com").is_ok());
    }

    /// Cleartext HTTP is refused: only the https scheme is accepted.
    #[test]
    fn http_rejected() {
        let err = validate_url("http://example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// Non-HTTP(S) schemes such as ftp are refused.
    #[test]
    fn ftp_rejected() {
        let err = validate_url("ftp://files.example.com").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// file:// URLs (local filesystem access) are refused.
    #[test]
    fn file_rejected() {
        let err = validate_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// Unparseable input is reported as Blocked rather than panicking.
    #[test]
    fn invalid_url_rejected() {
        let err = validate_url("not a url").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// The literal hostname `localhost` is blocked (SSRF guard).
    #[test]
    fn localhost_blocked() {
        let err = validate_url("https://localhost/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// IPv4 loopback is blocked.
    #[test]
    fn loopback_ip_blocked() {
        let err = validate_url("https://127.0.0.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// RFC 1918 10.0.0.0/8 is blocked.
    #[test]
    fn private_10_blocked() {
        let err = validate_url("https://10.0.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// RFC 1918 172.16.0.0/12 is blocked.
    #[test]
    fn private_172_blocked() {
        let err = validate_url("https://172.16.0.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// RFC 1918 192.168.0.0/16 is blocked.
    #[test]
    fn private_192_blocked() {
        let err = validate_url("https://192.168.1.1/api").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// IPv6 loopback (::1) is blocked.
    #[test]
    fn ipv6_loopback_blocked() {
        let err = validate_url("https://[::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// A routable public IPv4 literal is allowed.
    #[test]
    fn public_ip_allowed() {
        assert!(validate_url("https://93.184.216.34/page").is_ok());
    }
1303
    // --- parse_and_extract: CSS-selector extraction over an HTML fragment ---

    /// Text mode returns the element's visible text.
    #[test]
    fn extract_text_from_html() {
        let html = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "Hello World");
    }

    /// Multiple matches are joined with newlines, in document order.
    #[test]
    fn extract_multiple_elements() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A\nB\nC");
    }

    /// The limit argument truncates the match list.
    #[test]
    fn extract_with_limit() {
        let html = "<ul><li>A</li><li>B</li><li>C</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 2).unwrap();
        assert_eq!(result, "A\nB");
    }

    /// Attr mode returns the named attribute's value.
    #[test]
    fn extract_attr_href() {
        let html = r#"<a href="https://example.com">Link</a>"#;
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert_eq!(result, "https://example.com");
    }

    /// Html mode returns the element's inner markup.
    #[test]
    fn extract_inner_html() {
        let html = "<div><span>inner</span></div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<span>inner</span>"));
    }

    /// No match produces a descriptive placeholder string, not an error.
    #[test]
    fn no_matches_returns_message() {
        let html = "<html><body><p>text</p></body></html>";
        let result = parse_and_extract(html, "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    /// Whitespace-only text nodes are dropped from the output.
    #[test]
    fn empty_text_skipped() {
        let html = "<ul><li> </li><li>A</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "A");
    }

    /// A malformed CSS selector is an error, not a silent no-match.
    #[test]
    fn invalid_selector_errors() {
        let html = "<html><body></body></html>";
        let result = parse_and_extract(html, "[[[invalid", &ExtractMode::Text, 10);
        assert!(result.is_err());
    }

    /// An empty input document behaves like a no-match.
    #[test]
    fn empty_html_returns_no_results() {
        let result = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    /// Combinator selectors (child `>`) are honored.
    #[test]
    fn nested_selector() {
        let html = "<div><span>inner</span></div><span>outer</span>";
        let result = parse_and_extract(html, "div > span", &ExtractMode::Text, 10).unwrap();
        assert_eq!(result, "inner");
    }

    /// Elements lacking the requested attribute contribute nothing,
    /// so the placeholder is returned.
    #[test]
    fn attr_missing_returns_empty() {
        let html = r"<a>No href</a>";
        let result =
            parse_and_extract(html, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }

    /// Html mode preserves nested tags.
    #[test]
    fn extract_html_mode() {
        let html = "<div><b>bold</b> text</div>";
        let result = parse_and_extract(html, "div", &ExtractMode::Html, 10).unwrap();
        assert!(result.contains("<b>bold</b>"));
    }

    /// A limit of zero yields the no-results placeholder.
    #[test]
    fn limit_zero_returns_no_results() {
        let html = "<ul><li>A</li><li>B</li></ul>";
        let result = parse_and_extract(html, "li", &ExtractMode::Text, 0).unwrap();
        assert!(result.starts_with("No results for selector:"));
    }
1397
    // --- validate_url: ports, schemeless input, and special IP ranges ---

    /// A non-default port on a public host is fine.
    #[test]
    fn url_with_port_allowed() {
        assert!(validate_url("https://example.com:8443/path").is_ok());
    }

    /// 169.254.0.0/16 link-local (includes cloud metadata endpoints) is blocked.
    #[test]
    fn link_local_ip_blocked() {
        let err = validate_url("https://169.254.1.1/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// Scheme-less input cannot be validated and is blocked.
    #[test]
    fn url_no_scheme_rejected() {
        let err = validate_url("example.com/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// The unspecified address 0.0.0.0 is blocked.
    #[test]
    fn unspecified_ipv4_blocked() {
        let err = validate_url("https://0.0.0.0/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// The limited-broadcast address is blocked.
    #[test]
    fn broadcast_ipv4_blocked() {
        let err = validate_url("https://255.255.255.255/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// IPv6 link-local (fe80::/10) is blocked.
    #[test]
    fn ipv6_link_local_blocked() {
        let err = validate_url("https://[fe80::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// IPv6 unique-local (fc00::/7) is blocked.
    #[test]
    fn ipv6_unique_local_blocked() {
        let err = validate_url("https://[fd12::1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// An IPv4-mapped IPv6 literal must not bypass the loopback check.
    #[test]
    fn ipv4_mapped_ipv6_loopback_blocked() {
        let err = validate_url("https://[::ffff:127.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// An IPv4-mapped IPv6 literal must not bypass the RFC 1918 check.
    #[test]
    fn ipv4_mapped_ipv6_private_blocked() {
        let err = validate_url("https://[::ffff:10.0.0.1]/path").unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }
1452
    // --- WebScrapeExecutor::execute on raw LLM response text ---

    /// Text without a ```scrape fence is not a tool invocation.
    #[tokio::test]
    async fn executor_no_blocks_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("plain text").await;
        assert!(result.unwrap().is_none());
    }

    /// A fenced block containing invalid JSON is an execution error.
    #[tokio::test]
    async fn executor_invalid_json_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\nnot json\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    /// URL policy applies to scrape instructions: http:// is blocked.
    #[tokio::test]
    async fn executor_blocked_url_errors() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    /// Private-range targets are blocked before any network activity.
    #[tokio::test]
    async fn executor_private_ip_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    /// A valid but unreachable target (TEST-NET-1 address, port 1) surfaces
    /// as an execution error once the short timeout elapses.
    #[tokio::test]
    async fn executor_unreachable_host_returns_error() {
        let config = ScrapeConfig {
            timeout: 1, // 1 s keeps the failure fast
            max_body_bytes: 1_048_576,
            ..Default::default()
        };
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Execution(_))));
    }

    /// localhost with an explicit port is still blocked.
    #[tokio::test]
    async fn executor_localhost_url_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    /// The empty response contains no instructions.
    #[tokio::test]
    async fn executor_empty_text_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let result = executor.execute("").await;
        assert!(result.unwrap().is_none());
    }

    /// A blocked instruction fails the whole call even when later blocks are fine.
    #[tokio::test]
    async fn executor_multiple_blocks_first_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n\
                        ```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
        let result = executor.execute(response).await;
        assert!(result.is_err());
    }
1529
1530 #[test]
1531 fn validate_url_empty_string() {
1532 let err = validate_url("").unwrap_err();
1533 assert!(matches!(err, ToolError::Blocked { .. }));
1534 }
1535
1536 #[test]
1537 fn validate_url_javascript_scheme_blocked() {
1538 let err = validate_url("javascript:alert(1)").unwrap_err();
1539 assert!(matches!(err, ToolError::Blocked { .. }));
1540 }
1541
1542 #[test]
1543 fn validate_url_data_scheme_blocked() {
1544 let err = validate_url("data:text/html,<h1>hi</h1>").unwrap_err();
1545 assert!(matches!(err, ToolError::Blocked { .. }));
1546 }
1547
1548 #[test]
1549 fn is_private_host_public_domain_is_false() {
1550 let host: url::Host<&str> = url::Host::Domain("example.com");
1551 assert!(!is_private_host(&host));
1552 }
1553
1554 #[test]
1555 fn is_private_host_localhost_is_true() {
1556 let host: url::Host<&str> = url::Host::Domain("localhost");
1557 assert!(is_private_host(&host));
1558 }
1559
1560 #[test]
1561 fn is_private_host_ipv6_unspecified_is_true() {
1562 let host = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
1563 assert!(is_private_host(&host));
1564 }
1565
1566 #[test]
1567 fn is_private_host_public_ipv6_is_false() {
1568 let host = url::Host::Ipv6("2001:db8::1".parse().unwrap());
1569 assert!(!is_private_host(&host));
1570 }
1571
    /// Builds a `WebScrapeExecutor` with permissive defaults (no domain
    /// allow/deny lists, no audit logging) plus a fresh wiremock server for
    /// tests that need real HTTP round-trips.
    async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
        let server = wiremock::MockServer::start().await;
        let executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 1_048_576, // 1 MiB — generous enough for all fixtures here
            allowed_domains: vec![],
            denied_domains: vec![],
            audit_logger: None,
            egress_config: EgressConfig::default(),
            egress_tx: None,
            egress_dropped: Arc::new(AtomicU64::new(0)),
        };
        (executor, server)
    }
1596
1597 fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
1599 let uri = server.uri();
1600 let url = Url::parse(&uri).unwrap();
1601 let host = url.host_str().unwrap_or("127.0.0.1").to_owned();
1602 let port = url.port().unwrap_or(80);
1603 let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
1604 (host, vec![addr])
1605 }
1606
1607 async fn follow_redirects_raw(
1611 executor: &WebScrapeExecutor,
1612 start_url: &str,
1613 host: &str,
1614 addrs: &[std::net::SocketAddr],
1615 ) -> Result<String, ToolError> {
1616 const MAX_REDIRECTS: usize = 3;
1617 let mut current_url = start_url.to_owned();
1618 let mut current_host = host.to_owned();
1619 let mut current_addrs = addrs.to_vec();
1620
1621 for hop in 0..=MAX_REDIRECTS {
1622 let client = executor.build_client(¤t_host, ¤t_addrs);
1623 let resp = client
1624 .get(¤t_url)
1625 .send()
1626 .await
1627 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1628
1629 let status = resp.status();
1630
1631 if status.is_redirection() {
1632 if hop == MAX_REDIRECTS {
1633 return Err(ToolError::Execution(std::io::Error::other(
1634 "too many redirects",
1635 )));
1636 }
1637
1638 let location = resp
1639 .headers()
1640 .get(reqwest::header::LOCATION)
1641 .and_then(|v| v.to_str().ok())
1642 .ok_or_else(|| {
1643 ToolError::Execution(std::io::Error::other("redirect with no Location"))
1644 })?;
1645
1646 let base = Url::parse(¤t_url)
1647 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1648 let next_url = base
1649 .join(location)
1650 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1651
1652 current_url = next_url.to_string();
1654 let _ = &mut current_host;
1656 let _ = &mut current_addrs;
1657 continue;
1658 }
1659
1660 if !status.is_success() {
1661 return Err(ToolError::Execution(std::io::Error::other(format!(
1662 "HTTP {status}",
1663 ))));
1664 }
1665
1666 let bytes = resp
1667 .bytes()
1668 .await
1669 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
1670
1671 if bytes.len() > executor.max_body_bytes {
1672 return Err(ToolError::Execution(std::io::Error::other(format!(
1673 "response too large: {} bytes (max: {})",
1674 bytes.len(),
1675 executor.max_body_bytes,
1676 ))));
1677 }
1678
1679 return String::from_utf8(bytes.to_vec())
1680 .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
1681 }
1682
1683 Err(ToolError::Execution(std::io::Error::other(
1684 "too many redirects",
1685 )))
1686 }
1687
    // --- fetch_html against a live wiremock server ---

    /// A 200 response yields the body verbatim.
    #[tokio::test]
    async fn fetch_html_success_returns_body() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/page"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/page", server.uri());
        let result = executor
            .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
            .await;
        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap(), "<h1>OK</h1>");
    }

    /// A 403 response becomes an error whose message carries the status code.
    #[tokio::test]
    async fn fetch_html_non_2xx_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/forbidden"))
            .respond_with(ResponseTemplate::new(403))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/forbidden", server.uri());
        let result = executor
            .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
            .await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("403"), "expected 403 in error: {msg}");
    }

    /// A 404 response becomes an error whose message carries the status code.
    #[tokio::test]
    async fn fetch_html_404_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/missing"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/missing", server.uri());
        let result = executor
            .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
            .await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("404"), "expected 404 in error: {msg}");
    }
1752
    /// A 302 with no Location header is an error (cannot continue the chain).
    #[tokio::test]
    async fn fetch_html_redirect_no_location_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/redirect-no-loc"))
            .respond_with(ResponseTemplate::new(302))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/redirect-no-loc", server.uri());
        let result = executor
            .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
            .await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Location") || msg.contains("location"),
            "expected Location-related error: {msg}"
        );
    }

    /// One redirect hop is followed to the final body.
    #[tokio::test]
    async fn fetch_html_single_redirect_followed() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let final_url = format!("{}/final", server.uri());

        Mock::given(method("GET"))
            .and(path("/start"))
            .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/final"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/start", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_ok(), "single redirect should succeed: {result:?}");
        assert_eq!(result.unwrap(), "<p>final</p>");
    }
1805
    /// Exactly three hops (the budget) still reaches the final body.
    #[tokio::test]
    async fn fetch_html_three_redirects_allowed() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let hop2 = format!("{}/hop2", server.uri());
        let hop3 = format!("{}/hop3", server.uri());
        let final_dest = format!("{}/done", server.uri());

        Mock::given(method("GET"))
            .and(path("/hop1"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/hop2"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/hop3"))
            .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path("/done"))
            .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/hop1", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
        assert_eq!(result.unwrap(), "<p>done</p>");
    }

    /// A fourth hop exhausts the budget and is rejected.
    #[tokio::test]
    async fn fetch_html_four_redirects_rejected() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        let hop2 = format!("{}/r2", server.uri());
        let hop3 = format!("{}/r3", server.uri());
        let hop4 = format!("{}/r4", server.uri());
        let hop5 = format!("{}/r5", server.uri());

        // Chain of four consecutive 301s: /r1 -> /r2 -> /r3 -> /r4 -> /r5.
        for (from, to) in [
            ("/r1", &hop2),
            ("/r2", &hop3),
            ("/r3", &hop4),
            ("/r4", &hop5),
        ] {
            Mock::given(method("GET"))
                .and(path(from))
                .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
                .mount(&server)
                .await;
        }

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/r1", server.uri());
        let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
        assert!(result.is_err(), "4 redirects should be rejected");
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("redirect"),
            "expected redirect-related error: {msg}"
        );
    }
1878
    /// A body exceeding `max_body_bytes` is rejected with a "too large" error.
    #[tokio::test]
    async fn fetch_html_body_too_large_returns_error() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        // Executor with a deliberately tiny 10-byte cap.
        let small_limit_executor = WebScrapeExecutor {
            timeout: Duration::from_secs(5),
            max_body_bytes: 10,
            allowed_domains: vec![],
            denied_domains: vec![],
            audit_logger: None,
            egress_config: EgressConfig::default(),
            egress_tx: None,
            egress_dropped: Arc::new(AtomicU64::new(0)),
        };
        let server = wiremock::MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/big"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("this body is definitely longer than ten bytes"),
            )
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/big", server.uri());
        let result = small_limit_executor
            .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
            .await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("too large"), "expected too-large error: {msg}");
    }
1913
1914 #[test]
1915 fn extract_scrape_blocks_empty_block_content() {
1916 let text = "```scrape\n\n```";
1917 let blocks = extract_scrape_blocks(text);
1918 assert_eq!(blocks.len(), 1);
1919 assert!(blocks[0].is_empty());
1920 }
1921
1922 #[test]
1923 fn extract_scrape_blocks_whitespace_only() {
1924 let text = "```scrape\n \n```";
1925 let blocks = extract_scrape_blocks(text);
1926 assert_eq!(blocks.len(), 1);
1927 }
1928
1929 #[test]
1930 fn parse_and_extract_multiple_selectors() {
1931 let html = "<div><h1>Title</h1><p>Para</p></div>";
1932 let result = parse_and_extract(html, "h1, p", &ExtractMode::Text, 10).unwrap();
1933 assert!(result.contains("Title"));
1934 assert!(result.contains("Para"));
1935 }
1936
1937 #[test]
1938 fn webscrape_executor_new_with_custom_config() {
1939 let config = ScrapeConfig {
1940 timeout: 60,
1941 max_body_bytes: 512,
1942 ..Default::default()
1943 };
1944 let executor = WebScrapeExecutor::new(&config);
1945 assert_eq!(executor.max_body_bytes, 512);
1946 }
1947
1948 #[test]
1949 fn webscrape_executor_debug() {
1950 let config = ScrapeConfig::default();
1951 let executor = WebScrapeExecutor::new(&config);
1952 let dbg = format!("{executor:?}");
1953 assert!(dbg.contains("WebScrapeExecutor"));
1954 }
1955
1956 #[test]
1957 fn extract_mode_attr_empty_name() {
1958 let mode = ExtractMode::parse("attr:");
1959 assert!(matches!(mode, ExtractMode::Attr(ref s) if s.is_empty()));
1960 }
1961
1962 #[test]
1963 fn default_extract_returns_text() {
1964 assert_eq!(default_extract(), "text");
1965 }
1966
1967 #[test]
1968 fn scrape_instruction_debug() {
1969 let json = r#"{"url":"https://example.com","select":"h1"}"#;
1970 let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
1971 let dbg = format!("{instr:?}");
1972 assert!(dbg.contains("ScrapeInstruction"));
1973 }
1974
1975 #[test]
1976 fn extract_mode_debug() {
1977 let mode = ExtractMode::Text;
1978 let dbg = format!("{mode:?}");
1979 assert!(dbg.contains("Text"));
1980 }
1981
    // NOTE(review): the five tests below assert on values constructed locally
    // inside each test rather than on the production code paths; they document
    // expected constants and message shapes but cannot catch regressions in
    // fetch_html itself. Consider exposing the real constants to test them
    // directly.

    /// Documents the redirect budget fetch_html is expected to enforce.
    #[test]
    fn max_redirects_constant_is_three() {
        const MAX_REDIRECTS: usize = 3; // local mirror of the private constant
        assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
    }

    /// Shape of the error used for a redirect missing a Location header.
    #[test]
    fn redirect_no_location_error_message() {
        let err = std::io::Error::other("redirect with no Location");
        assert!(err.to_string().contains("redirect with no Location"));
    }

    /// Shape of the error used when the redirect budget is exhausted.
    #[test]
    fn too_many_redirects_error_message() {
        let err = std::io::Error::other("too many redirects");
        assert!(err.to_string().contains("too many redirects"));
    }

    /// `HTTP {status}` formatting includes the numeric code for 403.
    #[test]
    fn non_2xx_status_error_format() {
        let status = reqwest::StatusCode::FORBIDDEN;
        let msg = format!("HTTP {status}");
        assert!(msg.contains("403"));
    }

    /// `HTTP {status}` formatting includes the numeric code for 404.
    #[test]
    fn not_found_status_error_format() {
        let status = reqwest::StatusCode::NOT_FOUND;
        let msg = format!("HTTP {status}");
        assert!(msg.contains("404"));
    }
2025
    // --- Location-header resolution semantics (url::Url::join) ---

    /// An absolute-path Location stays on the same host.
    #[test]
    fn relative_redirect_same_host_path() {
        let base = Url::parse("https://example.com/current").unwrap();
        let resolved = base.join("/other").unwrap();
        assert_eq!(resolved.as_str(), "https://example.com/other");
    }

    /// A bare relative Location resolves against the base's directory.
    #[test]
    fn relative_redirect_relative_path() {
        let base = Url::parse("https://example.com/a/b").unwrap();
        let resolved = base.join("c").unwrap();
        assert_eq!(resolved.as_str(), "https://example.com/a/c");
    }

    /// An absolute Location replaces the base entirely (cross-host redirect).
    #[test]
    fn absolute_redirect_overrides_base() {
        let base = Url::parse("https://example.com/page").unwrap();
        let resolved = base.join("https://other.com/target").unwrap();
        assert_eq!(resolved.as_str(), "https://other.com/target");
    }

    /// Re-validating a resolved redirect target catches HTTPS→HTTP downgrades.
    #[test]
    fn redirect_http_downgrade_rejected() {
        let location = "http://example.com/page";
        let base = Url::parse("https://example.com/start").unwrap();
        let next = base.join(location).unwrap();
        let err = validate_url(next.as_str()).unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// Re-validating a resolved redirect target catches private-IP hosts, and
    /// the Blocked payload names the reason.
    #[test]
    fn redirect_location_private_ip_blocked() {
        let location = "https://192.168.100.1/admin";
        let base = Url::parse("https://example.com/start").unwrap();
        let next = base.join(location).unwrap();
        let err = validate_url(next.as_str()).unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
        let ToolError::Blocked { command: cmd } = err else {
            panic!("expected Blocked");
        };
        assert!(
            cmd.contains("private") || cmd.contains("scheme"),
            "error message should describe the block reason: {cmd}"
        );
    }

    /// Re-validating a resolved redirect target catches internal-TLD hosts
    /// (e.g. cloud metadata endpoints).
    #[test]
    fn redirect_location_internal_domain_blocked() {
        let location = "https://metadata.internal/latest/meta-data/";
        let base = Url::parse("https://example.com/start").unwrap();
        let next = base.join(location).unwrap();
        let err = validate_url(next.as_str()).unwrap_err();
        assert!(matches!(err, ToolError::Blocked { .. }));
    }

    /// Every hop of an all-public redirect chain validates cleanly.
    #[test]
    fn redirect_chain_three_hops_all_public() {
        let hops = [
            "https://redirect1.example.com/hop1",
            "https://redirect2.example.com/hop2",
            "https://destination.example.com/final",
        ];
        for hop in hops {
            assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
        }
    }
2099
2100 #[test]
2105 fn redirect_to_private_ip_rejected_by_validate_url() {
2106 let private_targets = [
2108 "https://127.0.0.1/secret",
2109 "https://10.0.0.1/internal",
2110 "https://192.168.1.1/admin",
2111 "https://172.16.0.1/data",
2112 "https://[::1]/path",
2113 "https://[fe80::1]/path",
2114 "https://localhost/path",
2115 "https://service.internal/api",
2116 ];
2117 for target in private_targets {
2118 let result = validate_url(target);
2119 assert!(result.is_err(), "expected error for {target}");
2120 assert!(
2121 matches!(result.unwrap_err(), ToolError::Blocked { .. }),
2122 "expected Blocked for {target}"
2123 );
2124 }
2125 }
2126
2127 #[test]
2129 fn redirect_relative_url_resolves_correctly() {
2130 let base = Url::parse("https://example.com/page").unwrap();
2131 let relative = "/other";
2132 let resolved = base.join(relative).unwrap();
2133 assert_eq!(resolved.as_str(), "https://example.com/other");
2134 }
2135
2136 #[test]
2138 fn redirect_to_http_rejected() {
2139 let err = validate_url("http://example.com/page").unwrap_err();
2140 assert!(matches!(err, ToolError::Blocked { .. }));
2141 }
2142
2143 #[test]
2144 fn ipv4_mapped_ipv6_link_local_blocked() {
2145 let err = validate_url("https://[::ffff:169.254.0.1]/path").unwrap_err();
2146 assert!(matches!(err, ToolError::Blocked { .. }));
2147 }
2148
2149 #[test]
2150 fn ipv4_mapped_ipv6_public_allowed() {
2151 assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
2152 }
2153
    // --- fetch tool via execute_tool_call: the same URL policy applies ---

    /// The fetch tool refuses http:// URLs.
    #[tokio::test]
    async fn fetch_http_scheme_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let call = crate::executor::ToolCall {
            tool_id: ToolName::new("fetch"),
            params: {
                let mut m = serde_json::Map::new();
                m.insert("url".to_owned(), serde_json::json!("http://example.com"));
                m
            },
            caller_id: None,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    /// The fetch tool refuses private-range IP targets.
    #[tokio::test]
    async fn fetch_private_ip_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let call = crate::executor::ToolCall {
            tool_id: ToolName::new("fetch"),
            params: {
                let mut m = serde_json::Map::new();
                m.insert(
                    "url".to_owned(),
                    serde_json::json!("https://192.168.1.1/secret"),
                );
                m
            },
            caller_id: None,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    /// The fetch tool refuses localhost targets.
    #[tokio::test]
    async fn fetch_localhost_blocked() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let call = crate::executor::ToolCall {
            tool_id: ToolName::new("fetch"),
            params: {
                let mut m = serde_json::Map::new();
                m.insert(
                    "url".to_owned(),
                    serde_json::json!("https://localhost/page"),
                );
                m
            },
            caller_id: None,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    /// A tool id this executor does not own yields Ok(None), not an error,
    /// so other executors in the chain can claim the call.
    #[tokio::test]
    async fn fetch_unknown_tool_returns_none() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let call = crate::executor::ToolCall {
            tool_id: ToolName::new("unknown_tool"),
            params: serde_json::Map::new(),
            caller_id: None,
        };
        let result = executor.execute_tool_call(&call).await;
        assert!(result.unwrap().is_none());
    }
2225
    /// fetch_html returns non-HTML bodies verbatim too (no content-type gating
    /// visible at this layer).
    #[tokio::test]
    async fn fetch_returns_body_via_mock() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, ResponseTemplate};

        let (executor, server) = mock_server_executor().await;
        Mock::given(method("GET"))
            .and(path("/content"))
            .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
            .mount(&server)
            .await;

        let (host, addrs) = server_host_and_addr(&server);
        let url = format!("{}/content", server.uri());
        let result = executor
            .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "plain text content");
    }
2246
    /// The executor advertises exactly two tools: the fenced-block `web_scrape`
    /// and the tool-call-style `fetch`, in that order.
    #[test]
    fn tool_definitions_returns_web_scrape_and_fetch() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let defs = executor.tool_definitions();
        assert_eq!(defs.len(), 2);
        assert_eq!(defs[0].id, "web_scrape");
        assert_eq!(
            defs[0].invocation,
            crate::registry::InvocationHint::FencedBlock("scrape")
        );
        assert_eq!(defs[1].id, "fetch");
        assert_eq!(
            defs[1].invocation,
            crate::registry::InvocationHint::ToolCall
        );
    }

    /// The web_scrape JSON schema exposes all four instruction fields and
    /// marks only `url` and `select` as required (matching the serde derive:
    /// `extract` has a default, `limit` is Option).
    #[test]
    fn tool_definitions_schema_has_all_params() {
        let config = ScrapeConfig::default();
        let executor = WebScrapeExecutor::new(&config);
        let defs = executor.tool_definitions();
        let obj = defs[0].schema.as_object().unwrap();
        let props = obj["properties"].as_object().unwrap();
        assert!(props.contains_key("url"));
        assert!(props.contains_key("select"));
        assert!(props.contains_key("extract"));
        assert!(props.contains_key("limit"));
        let req = obj["required"].as_array().unwrap();
        assert!(req.iter().any(|v| v.as_str() == Some("url")));
        assert!(req.iter().any(|v| v.as_str() == Some("select")));
        assert!(!req.iter().any(|v| v.as_str() == Some("extract")));
    }
2281
2282 #[test]
2285 fn subdomain_localhost_blocked() {
2286 let host: url::Host<&str> = url::Host::Domain("foo.localhost");
2287 assert!(is_private_host(&host));
2288 }
2289
2290 #[test]
2291 fn internal_tld_blocked() {
2292 let host: url::Host<&str> = url::Host::Domain("service.internal");
2293 assert!(is_private_host(&host));
2294 }
2295
2296 #[test]
2297 fn local_tld_blocked() {
2298 let host: url::Host<&str> = url::Host::Domain("printer.local");
2299 assert!(is_private_host(&host));
2300 }
2301
2302 #[test]
2303 fn public_domain_not_blocked() {
2304 let host: url::Host<&str> = url::Host::Domain("example.com");
2305 assert!(!is_private_host(&host));
2306 }
2307
2308 #[tokio::test]
2311 async fn resolve_loopback_rejected() {
2312 let url = url::Url::parse("https://127.0.0.1/path").unwrap();
2314 let result = resolve_and_validate(&url).await;
2316 assert!(
2317 result.is_err(),
2318 "loopback IP must be rejected by resolve_and_validate"
2319 );
2320 let err = result.unwrap_err();
2321 assert!(matches!(err, crate::executor::ToolError::Blocked { .. }));
2322 }
2323
2324 #[tokio::test]
2325 async fn resolve_private_10_rejected() {
2326 let url = url::Url::parse("https://10.0.0.1/path").unwrap();
2327 let result = resolve_and_validate(&url).await;
2328 assert!(result.is_err());
2329 assert!(matches!(
2330 result.unwrap_err(),
2331 crate::executor::ToolError::Blocked { .. }
2332 ));
2333 }
2334
2335 #[tokio::test]
2336 async fn resolve_private_192_rejected() {
2337 let url = url::Url::parse("https://192.168.1.1/path").unwrap();
2338 let result = resolve_and_validate(&url).await;
2339 assert!(result.is_err());
2340 assert!(matches!(
2341 result.unwrap_err(),
2342 crate::executor::ToolError::Blocked { .. }
2343 ));
2344 }
2345
2346 #[tokio::test]
2347 async fn resolve_ipv6_loopback_rejected() {
2348 let url = url::Url::parse("https://[::1]/path").unwrap();
2349 let result = resolve_and_validate(&url).await;
2350 assert!(result.is_err());
2351 assert!(matches!(
2352 result.unwrap_err(),
2353 crate::executor::ToolError::Blocked { .. }
2354 ));
2355 }
2356
2357 #[tokio::test]
2358 async fn resolve_no_host_returns_ok() {
2359 let url = url::Url::parse("https://example.com/path").unwrap();
2361 let url_no_host = url::Url::parse("data:text/plain,hello").unwrap();
2363 let result = resolve_and_validate(&url_no_host).await;
2365 assert!(result.is_ok());
2366 let (host, addrs) = result.unwrap();
2367 assert!(host.is_empty());
2368 assert!(addrs.is_empty());
2369 drop(url);
2370 drop(url_no_host);
2371 }
2372
2373 async fn make_file_audit_logger(
2377 dir: &tempfile::TempDir,
2378 ) -> (
2379 std::sync::Arc<crate::audit::AuditLogger>,
2380 std::path::PathBuf,
2381 ) {
2382 use crate::audit::AuditLogger;
2383 use crate::config::AuditConfig;
2384 let path = dir.path().join("audit.log");
2385 let config = AuditConfig {
2386 enabled: true,
2387 destination: path.display().to_string(),
2388 ..Default::default()
2389 };
2390 let logger = std::sync::Arc::new(AuditLogger::from_config(&config, false).await.unwrap());
2391 (logger, path)
2392 }
2393
2394 #[tokio::test]
2395 async fn with_audit_sets_logger() {
2396 let config = ScrapeConfig::default();
2397 let executor = WebScrapeExecutor::new(&config);
2398 assert!(executor.audit_logger.is_none());
2399
2400 let dir = tempfile::tempdir().unwrap();
2401 let (logger, _path) = make_file_audit_logger(&dir).await;
2402 let executor = executor.with_audit(logger);
2403 assert!(executor.audit_logger.is_some());
2404 }
2405
2406 #[test]
2407 fn tool_error_to_audit_result_blocked_maps_correctly() {
2408 let err = ToolError::Blocked {
2409 command: "scheme not allowed: http".into(),
2410 };
2411 let result = tool_error_to_audit_result(&err);
2412 assert!(
2413 matches!(result, AuditResult::Blocked { reason } if reason == "scheme not allowed: http")
2414 );
2415 }
2416
2417 #[test]
2418 fn tool_error_to_audit_result_timeout_maps_correctly() {
2419 let err = ToolError::Timeout { timeout_secs: 15 };
2420 let result = tool_error_to_audit_result(&err);
2421 assert!(matches!(result, AuditResult::Timeout));
2422 }
2423
2424 #[test]
2425 fn tool_error_to_audit_result_execution_error_maps_correctly() {
2426 let err = ToolError::Execution(std::io::Error::other("connection refused"));
2427 let result = tool_error_to_audit_result(&err);
2428 assert!(
2429 matches!(result, AuditResult::Error { message } if message.contains("connection refused"))
2430 );
2431 }
2432
2433 #[tokio::test]
2434 async fn fetch_audit_blocked_url_logged() {
2435 let dir = tempfile::tempdir().unwrap();
2436 let (logger, log_path) = make_file_audit_logger(&dir).await;
2437
2438 let config = ScrapeConfig::default();
2439 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2440
2441 let call = crate::executor::ToolCall {
2442 tool_id: ToolName::new("fetch"),
2443 params: {
2444 let mut m = serde_json::Map::new();
2445 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2446 m
2447 },
2448 caller_id: None,
2449 };
2450 let result = executor.execute_tool_call(&call).await;
2451 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2452
2453 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2454 assert!(
2455 content.contains("\"tool\":\"fetch\""),
2456 "expected tool=fetch in audit: {content}"
2457 );
2458 assert!(
2459 content.contains("\"type\":\"blocked\""),
2460 "expected type=blocked in audit: {content}"
2461 );
2462 assert!(
2463 content.contains("http://example.com"),
2464 "expected URL in audit command field: {content}"
2465 );
2466 }
2467
2468 #[tokio::test]
2469 async fn log_audit_success_writes_to_file() {
2470 let dir = tempfile::tempdir().unwrap();
2471 let (logger, log_path) = make_file_audit_logger(&dir).await;
2472
2473 let config = ScrapeConfig::default();
2474 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2475
2476 executor
2477 .log_audit(
2478 "fetch",
2479 "https://example.com/page",
2480 AuditResult::Success,
2481 42,
2482 None,
2483 None,
2484 None,
2485 )
2486 .await;
2487
2488 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2489 assert!(
2490 content.contains("\"tool\":\"fetch\""),
2491 "expected tool=fetch in audit: {content}"
2492 );
2493 assert!(
2494 content.contains("\"type\":\"success\""),
2495 "expected type=success in audit: {content}"
2496 );
2497 assert!(
2498 content.contains("\"command\":\"https://example.com/page\""),
2499 "expected command URL in audit: {content}"
2500 );
2501 assert!(
2502 content.contains("\"duration_ms\":42"),
2503 "expected duration_ms in audit: {content}"
2504 );
2505 }
2506
2507 #[tokio::test]
2508 async fn log_audit_blocked_writes_to_file() {
2509 let dir = tempfile::tempdir().unwrap();
2510 let (logger, log_path) = make_file_audit_logger(&dir).await;
2511
2512 let config = ScrapeConfig::default();
2513 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2514
2515 executor
2516 .log_audit(
2517 "web_scrape",
2518 "http://evil.com/page",
2519 AuditResult::Blocked {
2520 reason: "scheme not allowed: http".into(),
2521 },
2522 0,
2523 None,
2524 None,
2525 None,
2526 )
2527 .await;
2528
2529 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2530 assert!(
2531 content.contains("\"tool\":\"web_scrape\""),
2532 "expected tool=web_scrape in audit: {content}"
2533 );
2534 assert!(
2535 content.contains("\"type\":\"blocked\""),
2536 "expected type=blocked in audit: {content}"
2537 );
2538 assert!(
2539 content.contains("scheme not allowed"),
2540 "expected block reason in audit: {content}"
2541 );
2542 }
2543
2544 #[tokio::test]
2545 async fn web_scrape_audit_blocked_url_logged() {
2546 let dir = tempfile::tempdir().unwrap();
2547 let (logger, log_path) = make_file_audit_logger(&dir).await;
2548
2549 let config = ScrapeConfig::default();
2550 let executor = WebScrapeExecutor::new(&config).with_audit(logger);
2551
2552 let call = crate::executor::ToolCall {
2553 tool_id: ToolName::new("web_scrape"),
2554 params: {
2555 let mut m = serde_json::Map::new();
2556 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2557 m.insert("select".to_owned(), serde_json::json!("h1"));
2558 m
2559 },
2560 caller_id: None,
2561 };
2562 let result = executor.execute_tool_call(&call).await;
2563 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2564
2565 let content = tokio::fs::read_to_string(&log_path).await.unwrap();
2566 assert!(
2567 content.contains("\"tool\":\"web_scrape\""),
2568 "expected tool=web_scrape in audit: {content}"
2569 );
2570 assert!(
2571 content.contains("\"type\":\"blocked\""),
2572 "expected type=blocked in audit: {content}"
2573 );
2574 }
2575
2576 #[tokio::test]
2577 async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
2578 let config = ScrapeConfig::default();
2579 let executor = WebScrapeExecutor::new(&config);
2580 assert!(executor.audit_logger.is_none());
2581
2582 let call = crate::executor::ToolCall {
2583 tool_id: ToolName::new("fetch"),
2584 params: {
2585 let mut m = serde_json::Map::new();
2586 m.insert("url".to_owned(), serde_json::json!("http://example.com"));
2587 m
2588 },
2589 caller_id: None,
2590 };
2591 let result = executor.execute_tool_call(&call).await;
2593 assert!(matches!(result, Err(ToolError::Blocked { .. })));
2594 }
2595
2596 #[tokio::test]
2598 async fn fetch_execute_tool_call_end_to_end() {
2599 use wiremock::matchers::{method, path};
2600 use wiremock::{Mock, ResponseTemplate};
2601
2602 let (executor, server) = mock_server_executor().await;
2603 Mock::given(method("GET"))
2604 .and(path("/e2e"))
2605 .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
2606 .mount(&server)
2607 .await;
2608
2609 let (host, addrs) = server_host_and_addr(&server);
2610 let result = executor
2612 .fetch_html(
2613 &format!("{}/e2e", server.uri()),
2614 &host,
2615 &addrs,
2616 "fetch",
2617 "test-cid",
2618 None,
2619 )
2620 .await;
2621 assert!(result.is_ok());
2622 assert!(result.unwrap().contains("end-to-end"));
2623 }
2624
2625 #[test]
2628 fn domain_matches_exact() {
2629 assert!(domain_matches("example.com", "example.com"));
2630 assert!(!domain_matches("example.com", "other.com"));
2631 assert!(!domain_matches("example.com", "sub.example.com"));
2632 }
2633
2634 #[test]
2635 fn domain_matches_wildcard_single_subdomain() {
2636 assert!(domain_matches("*.example.com", "sub.example.com"));
2637 assert!(!domain_matches("*.example.com", "example.com"));
2638 assert!(!domain_matches("*.example.com", "sub.sub.example.com"));
2639 }
2640
2641 #[test]
2642 fn domain_matches_wildcard_does_not_match_empty_label() {
2643 assert!(!domain_matches("*.example.com", ".example.com"));
2645 }
2646
2647 #[test]
2648 fn domain_matches_multi_wildcard_treated_as_exact() {
2649 assert!(!domain_matches("*.*.example.com", "a.b.example.com"));
2651 }
2652
2653 #[test]
2656 fn check_domain_policy_empty_lists_allow_all() {
2657 assert!(check_domain_policy("example.com", &[], &[]).is_ok());
2658 assert!(check_domain_policy("evil.com", &[], &[]).is_ok());
2659 }
2660
2661 #[test]
2662 fn check_domain_policy_denylist_blocks() {
2663 let denied = vec!["evil.com".to_string()];
2664 let err = check_domain_policy("evil.com", &[], &denied).unwrap_err();
2665 assert!(matches!(err, ToolError::Blocked { .. }));
2666 }
2667
2668 #[test]
2669 fn check_domain_policy_denylist_does_not_block_other_domains() {
2670 let denied = vec!["evil.com".to_string()];
2671 assert!(check_domain_policy("good.com", &[], &denied).is_ok());
2672 }
2673
2674 #[test]
2675 fn check_domain_policy_allowlist_permits_matching() {
2676 let allowed = vec!["docs.rs".to_string(), "*.rust-lang.org".to_string()];
2677 assert!(check_domain_policy("docs.rs", &allowed, &[]).is_ok());
2678 assert!(check_domain_policy("blog.rust-lang.org", &allowed, &[]).is_ok());
2679 }
2680
2681 #[test]
2682 fn check_domain_policy_allowlist_blocks_unknown() {
2683 let allowed = vec!["docs.rs".to_string()];
2684 let err = check_domain_policy("other.com", &allowed, &[]).unwrap_err();
2685 assert!(matches!(err, ToolError::Blocked { .. }));
2686 }
2687
2688 #[test]
2689 fn check_domain_policy_deny_overrides_allow() {
2690 let allowed = vec!["example.com".to_string()];
2691 let denied = vec!["example.com".to_string()];
2692 let err = check_domain_policy("example.com", &allowed, &denied).unwrap_err();
2693 assert!(matches!(err, ToolError::Blocked { .. }));
2694 }
2695
2696 #[test]
2697 fn check_domain_policy_wildcard_in_denylist() {
2698 let denied = vec!["*.evil.com".to_string()];
2699 let err = check_domain_policy("sub.evil.com", &[], &denied).unwrap_err();
2700 assert!(matches!(err, ToolError::Blocked { .. }));
2701 assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
2703 }
2704}