// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Web scraping executor with SSRF protection and domain policy enforcement.
//!
//! Exposes two tools to the LLM:
//!
//! - **`web_scrape`** — fetches a URL and extracts elements matching a CSS selector.
//! - **`fetch`** — fetches a URL and returns the raw response body as UTF-8 text.
//!
//! Both tools enforce:
//!
//! - HTTPS-only URLs (HTTP and other schemes are rejected).
//! - DNS resolution followed by a private-IP check to prevent SSRF.
//! - Optional domain allowlist and denylist from [`ScrapeConfig`].
//! - Configurable timeout and maximum response body size.
//! - Automatic redirect following is disabled; redirects are instead followed manually
//!   (up to 3 hops), with every target re-validated, to prevent open-redirect SSRF bypasses.
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use schemars::JsonSchema;
use serde::Deserialize;
use url::Url;
use zeph_common::ToolName;
use crate::audit::{AuditEntry, AuditLogger, AuditResult, EgressEvent, chrono_now};
use crate::config::{EgressConfig, ScrapeConfig};
use crate::executor::{
ClaimSource, ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params,
};
use crate::net::is_private_ip;
/// Strips userinfo (`user:pass@`) and sensitive query params from a URL for safe logging.
///
/// Returns the sanitized URL string; falls back to the original if parsing fails.
/// Strips userinfo (`user:pass@`) and sensitive query params from a URL for safe logging.
///
/// Query-pair names containing `token`, `key`, `secret`, `password`, `auth`, `sig`,
/// `api_key`, or `apikey` (case-insensitive substring match) are dropped.
///
/// Returns the sanitized URL string; falls back to the original if parsing fails.
fn redact_url_for_log(url: &str) -> String {
    let Ok(mut parsed) = Url::parse(url) else {
        return url.to_owned();
    };
    // Remove userinfo.
    let _ = parsed.set_username("");
    let _ = parsed.set_password(None);
    // Strip query params whose names suggest secrets (token, key, secret, password, auth, sig).
    let sensitive = [
        "token", "key", "secret", "password", "auth", "sig", "api_key", "apikey",
    ];
    let filtered: Vec<(String, String)> = parsed
        .query_pairs()
        .filter(|(k, _)| {
            let lower = k.to_lowercase();
            !sensitive.iter().any(|s| lower.contains(s))
        })
        .map(|(k, v)| (k.into_owned(), v.into_owned()))
        .collect();
    if filtered.is_empty() {
        parsed.set_query(None);
    } else {
        // `query_pairs()` percent-DECODES keys and values; rebuilding the query with
        // a plain `format!("{k}={v}")` would re-insert decoded `&`/`=`/spaces raw and
        // corrupt the logged URL. `append_pair` re-encodes each pair correctly.
        let mut serializer = parsed.query_pairs_mut();
        serializer.clear();
        for (k, v) in &filtered {
            serializer.append_pair(k, v);
        }
        serializer.finish();
    }
    parsed.to_string()
}
/// Parameters for the `fetch` tool call, deserialized from the model's JSON arguments.
#[derive(Debug, Deserialize, JsonSchema)]
struct FetchParams {
    /// HTTPS URL to fetch
    url: String,
}
/// A single scrape request: either one JSON object inside a ```scrape fenced block,
/// or the parameters of a structured `web_scrape` tool call.
#[derive(Debug, Deserialize, JsonSchema)]
struct ScrapeInstruction {
    /// HTTPS URL to scrape
    url: String,
    /// CSS selector
    select: String,
    /// Extract mode: text, html, or attr:<name>
    #[serde(default = "default_extract")]
    extract: String,
    /// Max results to return
    limit: Option<usize>,
}
/// Serde default for [`ScrapeInstruction::extract`]: plain-text extraction.
fn default_extract() -> String {
    String::from("text")
}
/// How a matched element's data is extracted.
#[derive(Debug)]
enum ExtractMode {
    /// Text content of the element.
    Text,
    /// Inner HTML of the element.
    Html,
    /// Value of the named attribute.
    Attr(String),
}
impl ExtractMode {
    /// Parse an extract-mode string; anything unrecognized falls back to `Text`.
    fn parse(s: &str) -> Self {
        match s {
            "html" => Self::Html,
            other => match other.strip_prefix("attr:") {
                Some(name) => Self::Attr(name.to_owned()),
                None => Self::Text,
            },
        }
    }
}
/// Extracts data from web pages via CSS selectors.
///
/// Handles two invocation paths:
///
/// 1. **Legacy fenced blocks** — detects ` ```scrape ` blocks in the LLM response, each
///    containing a JSON scrape instruction object. Dispatched via [`ToolExecutor::execute`].
/// 2. **Structured tool calls** — dispatched via [`ToolExecutor::execute_tool_call`] for
///    tool IDs `"web_scrape"` and `"fetch"`.
///
/// # Security
///
/// - Only HTTPS URLs are accepted. HTTP and other schemes return [`ToolError::InvalidParams`].
/// - DNS is resolved synchronously and each resolved address is checked against
///   [`is_private_ip`]. Private addresses are rejected to prevent SSRF.
/// - HTTP redirects are never followed automatically (`Policy::none()`); they are
///   followed manually (up to 3 hops) with each target re-validated, preventing
///   open-redirect bypasses.
/// - Domain allowlists and denylists from config are enforced before DNS resolution.
///
/// # Example
///
/// ```rust,no_run
/// use zeph_tools::{WebScrapeExecutor, ToolExecutor, ToolCall, config::ScrapeConfig};
/// use zeph_common::ToolName;
///
/// # async fn example() {
/// let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
///
/// let call = ToolCall {
///     tool_id: ToolName::new("fetch"),
///     params: {
///         let mut m = serde_json::Map::new();
///         m.insert("url".to_owned(), serde_json::json!("https://example.com"));
///         m
///     },
///     caller_id: None,
/// };
/// let _ = executor.execute_tool_call(&call).await;
/// # }
/// ```
#[derive(Debug)]
pub struct WebScrapeExecutor {
    // Per-request HTTP timeout applied to every client the executor builds.
    timeout: Duration,
    // Maximum accepted response body size in bytes; larger bodies are rejected.
    max_body_bytes: usize,
    // When non-empty, only hosts matching one of these patterns are permitted.
    allowed_domains: Vec<String>,
    // Hosts matching any of these patterns are always blocked.
    denied_domains: Vec<String>,
    // Optional audit logger; one entry is emitted per tool invocation when set.
    audit_logger: Option<Arc<AuditLogger>>,
    // Egress telemetry configuration (enable flags, blocked-event logging, etc.).
    egress_config: EgressConfig,
    // Optional non-blocking telemetry channel for egress events.
    egress_tx: Option<tokio::sync::mpsc::Sender<EgressEvent>>,
    // Count of egress events dropped because the channel was full.
    egress_dropped: Arc<AtomicU64>,
}
impl WebScrapeExecutor {
    /// Create a new `WebScrapeExecutor` from configuration.
    ///
    /// No network connections are made at construction time.
    #[must_use]
    pub fn new(config: &ScrapeConfig) -> Self {
        Self {
            timeout: Duration::from_secs(config.timeout),
            max_body_bytes: config.max_body_bytes,
            allowed_domains: config.allowed_domains.clone(),
            denied_domains: config.denied_domains.clone(),
            audit_logger: None,
            egress_config: EgressConfig::default(),
            egress_tx: None,
            egress_dropped: Arc::new(AtomicU64::new(0)),
        }
    }
    /// Attach an audit logger. Each tool invocation will emit an [`AuditEntry`].
    #[must_use]
    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
        self.audit_logger = Some(logger);
        self
    }
    /// Configure egress event logging.
    #[must_use]
    pub fn with_egress_config(mut self, config: EgressConfig) -> Self {
        self.egress_config = config;
        self
    }
    /// Attach the egress telemetry channel sender and drop counter.
    ///
    /// Events are sent via [`tokio::sync::mpsc::Sender::try_send`] — the executor
    /// never blocks waiting for capacity.
    #[must_use]
    pub fn with_egress_tx(
        mut self,
        tx: tokio::sync::mpsc::Sender<EgressEvent>,
        dropped: Arc<AtomicU64>,
    ) -> Self {
        self.egress_tx = Some(tx);
        self.egress_dropped = dropped;
        self
    }
    /// Returns a clone of the egress drop counter, for use in the drain task.
    #[must_use]
    pub fn egress_dropped(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.egress_dropped)
    }
    /// Build an HTTP client pinned to `addrs` for `host`, with redirects disabled
    /// and the configured timeout applied.
    ///
    /// # Panics
    ///
    /// Panics if the client cannot be constructed (e.g. TLS backend failure).
    /// The previous `unwrap_or_default()` fallback silently produced a *default*
    /// client — which follows redirects and performs its own DNS resolution —
    /// quietly disabling both SSRF protections. Builder failure is a broken
    /// environment invariant, so failing loudly is the safe choice.
    fn build_client(&self, host: &str, addrs: &[SocketAddr]) -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(self.timeout)
            .redirect(reqwest::redirect::Policy::none())
            .resolve_to_addrs(host, addrs)
            .build()
            .expect("HTTP client construction failed; refusing to fall back to a client without SSRF protections")
    }
}
impl ToolExecutor for WebScrapeExecutor {
/// Advertise the `web_scrape` and `fetch` tool definitions to the registry.
///
/// Both descriptions deliberately instruct the model to only fetch URLs that the
/// user provided or a prior tool returned — never fabricated ones.
fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
    use crate::registry::{InvocationHint, ToolDef};
    vec![
        ToolDef {
            id: "web_scrape".into(),
            description: "Extract structured data from a web page using CSS selectors.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns.\n\nParameters: url (string, required) - HTTPS URL; select (string, required) - CSS selector; extract (string, optional) - \"text\", \"html\", or \"attr:<name>\"; limit (integer, optional) - max results\nReturns: extracted text/HTML/attribute values, one per line\nErrors: InvalidParams if URL is not HTTPS or selector is empty; Timeout after configured seconds; connection/DNS failures".into(),
            schema: schemars::schema_for!(ScrapeInstruction),
            // Legacy invocation path: ```scrape fenced blocks in the raw response.
            invocation: InvocationHint::FencedBlock("scrape"),
            output_schema: None,
        },
        ToolDef {
            id: "fetch".into(),
            description: "Fetch a URL and return the response body as plain text.\n\nONLY call this tool when the user has explicitly provided a URL in their message, or when a prior tool call returned a URL to retrieve. NEVER construct, guess, or infer a URL from entity names, brand knowledge, or domain patterns. If no URL is present in the conversation, do not call this tool.\n\nParameters: url (string, required) - HTTPS URL to fetch\nReturns: response body as UTF-8 text, truncated if exceeding max body size\nErrors: InvalidParams if URL is not HTTPS; Timeout; SSRF-blocked private IPs; connection failures".into(),
            schema: schemars::schema_for!(FetchParams),
            invocation: InvocationHint::ToolCall,
            output_schema: None,
        },
    ]
}
/// Legacy dispatch path: scan `response` for ```scrape fenced blocks and run each.
///
/// Returns `Ok(None)` when no scrape blocks are present. Each block is audited
/// individually; the first failing block aborts the whole call with its error.
async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
    let blocks = extract_scrape_blocks(response);
    if blocks.is_empty() {
        return Ok(None);
    }
    let mut outputs = Vec::with_capacity(blocks.len());
    #[allow(clippy::cast_possible_truncation)]
    let blocks_executed = blocks.len() as u32;
    for block in &blocks {
        // Each fenced block must contain exactly one JSON ScrapeInstruction object.
        let instruction: ScrapeInstruction = serde_json::from_str(block).map_err(|e| {
            ToolError::Execution(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                e.to_string(),
            ))
        })?;
        let correlation_id = EgressEvent::new_correlation_id();
        let start = Instant::now();
        let scrape_result = self
            .scrape_instruction(&instruction, &correlation_id, None)
            .await;
        #[allow(clippy::cast_possible_truncation)]
        let duration_ms = start.elapsed().as_millis() as u64;
        match scrape_result {
            Ok(output) => {
                self.log_audit(
                    "web_scrape",
                    &instruction.url,
                    AuditResult::Success,
                    duration_ms,
                    None,
                    None,
                    Some(correlation_id),
                )
                .await;
                outputs.push(output);
            }
            Err(e) => {
                // Failures are audited before propagating; blocks after the
                // failing one are not executed.
                let audit_result = tool_error_to_audit_result(&e);
                self.log_audit(
                    "web_scrape",
                    &instruction.url,
                    audit_result,
                    duration_ms,
                    Some(&e),
                    None,
                    Some(correlation_id),
                )
                .await;
                return Err(e);
            }
        }
    }
    Ok(Some(ToolOutput {
        // NOTE(review): output name uses a hyphen ("web-scrape") while the tool id
        // uses an underscore ("web_scrape") — presumably legacy; confirm before unifying.
        tool_name: ToolName::new("web-scrape"),
        summary: outputs.join("\n\n"),
        blocks_executed,
        filter_stats: None,
        diff: None,
        streamed: false,
        terminal_id: None,
        locations: None,
        raw_response: None,
        claim_source: Some(ClaimSource::WebScrape),
    }))
}
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "tool.web_scrape", skip_all)
)]
#[allow(clippy::too_many_lines)]
/// Structured dispatch path for tool ids `"web_scrape"` and `"fetch"`.
///
/// Returns `Ok(None)` for any other tool id. Each invocation gets a fresh
/// correlation id, is timed, and is audited on both success and failure.
async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
    match call.tool_id.as_str() {
        "web_scrape" => {
            let instruction: ScrapeInstruction = deserialize_params(&call.params)?;
            let correlation_id = EgressEvent::new_correlation_id();
            let start = Instant::now();
            let result = self
                .scrape_instruction(&instruction, &correlation_id, call.caller_id.clone())
                .await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match result {
                Ok(output) => {
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                        call.caller_id.clone(),
                        Some(correlation_id),
                    )
                    .await;
                    Ok(Some(ToolOutput {
                        tool_name: ToolName::new("web-scrape"),
                        summary: output,
                        blocks_executed: 1,
                        filter_stats: None,
                        diff: None,
                        streamed: false,
                        terminal_id: None,
                        locations: None,
                        raw_response: None,
                        claim_source: Some(ClaimSource::WebScrape),
                    }))
                }
                Err(e) => {
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "web_scrape",
                        &instruction.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                        call.caller_id.clone(),
                        Some(correlation_id),
                    )
                    .await;
                    Err(e)
                }
            }
        }
        "fetch" => {
            let p: FetchParams = deserialize_params(&call.params)?;
            let correlation_id = EgressEvent::new_correlation_id();
            let start = Instant::now();
            let result = self
                .handle_fetch(&p, &correlation_id, call.caller_id.clone())
                .await;
            #[allow(clippy::cast_possible_truncation)]
            let duration_ms = start.elapsed().as_millis() as u64;
            match result {
                Ok(output) => {
                    self.log_audit(
                        "fetch",
                        &p.url,
                        AuditResult::Success,
                        duration_ms,
                        None,
                        call.caller_id.clone(),
                        Some(correlation_id),
                    )
                    .await;
                    Ok(Some(ToolOutput {
                        tool_name: ToolName::new("fetch"),
                        summary: output,
                        blocks_executed: 1,
                        filter_stats: None,
                        diff: None,
                        streamed: false,
                        terminal_id: None,
                        locations: None,
                        raw_response: None,
                        claim_source: Some(ClaimSource::WebScrape),
                    }))
                }
                Err(e) => {
                    let audit_result = tool_error_to_audit_result(&e);
                    self.log_audit(
                        "fetch",
                        &p.url,
                        audit_result,
                        duration_ms,
                        Some(&e),
                        call.caller_id.clone(),
                        Some(correlation_id),
                    )
                    .await;
                    Err(e)
                }
            }
        }
        // Unknown tool id: not ours — let another executor handle it.
        _ => Ok(None),
    }
}
/// Both tools issue only GET requests, so retrying them is safe.
fn is_tool_retryable(&self, tool_id: &str) -> bool {
    matches!(tool_id, "web_scrape" | "fetch")
}
}
fn tool_error_to_audit_result(e: &ToolError) -> AuditResult {
match e {
ToolError::Blocked { command } => AuditResult::Blocked {
reason: command.clone(),
},
ToolError::Timeout { .. } => AuditResult::Timeout,
_ => AuditResult::Error {
message: e.to_string(),
},
}
}
impl WebScrapeExecutor {
#[allow(clippy::too_many_arguments)]
async fn log_audit(
&self,
tool: &str,
command: &str,
result: AuditResult,
duration_ms: u64,
error: Option<&ToolError>,
caller_id: Option<String>,
correlation_id: Option<String>,
) {
if let Some(ref logger) = self.audit_logger {
let (error_category, error_domain, error_phase) =
error.map_or((None, None, None), |e| {
let cat = e.category();
(
Some(cat.label().to_owned()),
Some(cat.domain().label().to_owned()),
Some(cat.phase().label().to_owned()),
)
});
let entry = AuditEntry {
timestamp: chrono_now(),
tool: tool.into(),
command: command.into(),
result,
duration_ms,
error_category,
error_domain,
error_phase,
claim_source: Some(ClaimSource::WebScrape),
mcp_server_id: None,
injection_flagged: false,
embedding_anomalous: false,
cross_boundary_mcp_to_acp: false,
adversarial_policy_decision: None,
exit_code: None,
truncated: false,
caller_id,
policy_match: None,
correlation_id,
vigil_risk: None,
};
logger.log(&entry).await;
}
}
fn send_egress_event(&self, event: EgressEvent) {
if let Some(ref tx) = self.egress_tx {
match tx.try_send(event) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
self.egress_dropped.fetch_add(1, Ordering::Relaxed);
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("egress channel closed; executor continuing without telemetry");
}
}
}
}
/// Record one egress event to the audit log (if configured) and the telemetry
/// channel (if attached). Either sink may be absent independently.
async fn log_egress_event(&self, event: &EgressEvent) {
    if let Some(ref logger) = self.audit_logger {
        logger.log_egress(event).await;
    }
    self.send_egress_event(event.clone());
}
/// Validate, policy-check, resolve, and fetch a URL for the `fetch` tool.
///
/// Emits a blocked egress event (when enabled) for each rejection stage:
/// `"scheme"` for URL validation, `"blocklist"` for domain policy, and
/// `"ssrf"` for DNS/private-IP failures.
async fn handle_fetch(
    &self,
    params: &FetchParams,
    correlation_id: &str,
    caller_id: Option<String>,
) -> Result<String, ToolError> {
    // Stage 1: scheme/host validation. On failure the host is unknown, so the
    // blocked event carries an empty host string.
    let parsed = match validate_url(&params.url) {
        Ok(url) => url,
        Err(e) => {
            if self.egress_config.enabled && self.egress_config.log_blocked {
                let event = Self::make_blocked_event(
                    "fetch",
                    &params.url,
                    "",
                    correlation_id,
                    caller_id.clone(),
                    "scheme",
                );
                self.log_egress_event(&event).await;
            }
            return Err(e);
        }
    };
    let host = parsed.host_str().unwrap_or("");
    // Stage 2: domain allow/deny policy.
    if let Err(e) = check_domain_policy(host, &self.allowed_domains, &self.denied_domains) {
        if self.egress_config.enabled && self.egress_config.log_blocked {
            let event = Self::make_blocked_event(
                "fetch",
                &params.url,
                host,
                correlation_id,
                caller_id.clone(),
                "blocklist",
            );
            self.log_egress_event(&event).await;
        }
        return Err(e);
    }
    // Stage 3: DNS resolution and private-IP screening.
    let (resolved_host, addrs) = match resolve_and_validate(&parsed).await {
        Ok(pair) => pair,
        Err(e) => {
            if self.egress_config.enabled && self.egress_config.log_blocked {
                let event = Self::make_blocked_event(
                    "fetch",
                    &params.url,
                    host,
                    correlation_id,
                    caller_id.clone(),
                    "ssrf",
                );
                self.log_egress_event(&event).await;
            }
            return Err(e);
        }
    };
    self.fetch_html(
        &params.url,
        &resolved_host,
        &addrs,
        "fetch",
        correlation_id,
        caller_id,
    )
    .await
}
/// Validate, policy-check, resolve, fetch, and CSS-extract for one scrape instruction.
///
/// Mirrors `handle_fetch`'s three rejection stages ("scheme", "blocklist",
/// "ssrf"), then hands the body to `parse_and_extract` on a blocking thread.
async fn scrape_instruction(
    &self,
    instruction: &ScrapeInstruction,
    correlation_id: &str,
    caller_id: Option<String>,
) -> Result<String, ToolError> {
    let parsed = validate_url(&instruction.url);
    // On a validation error `parsed` is Err, so this falls back to "".
    let host_str = parsed
        .as_ref()
        .map(|u| u.host_str().unwrap_or("").to_owned())
        .unwrap_or_default();
    if let Err(ref _e) = parsed {
        if self.egress_config.enabled && self.egress_config.log_blocked {
            let event = Self::make_blocked_event(
                "web_scrape",
                &instruction.url,
                &host_str,
                correlation_id,
                caller_id.clone(),
                "scheme",
            );
            self.log_egress_event(&event).await;
        }
        return Err(parsed.unwrap_err());
    }
    let parsed = parsed.unwrap();
    if let Err(e) = check_domain_policy(
        parsed.host_str().unwrap_or(""),
        &self.allowed_domains,
        &self.denied_domains,
    ) {
        if self.egress_config.enabled && self.egress_config.log_blocked {
            let event = Self::make_blocked_event(
                "web_scrape",
                &instruction.url,
                parsed.host_str().unwrap_or(""),
                correlation_id,
                caller_id.clone(),
                "blocklist",
            );
            self.log_egress_event(&event).await;
        }
        return Err(e);
    }
    let (host, addrs) = match resolve_and_validate(&parsed).await {
        Ok(v) => v,
        Err(e) => {
            if self.egress_config.enabled && self.egress_config.log_blocked {
                let event = Self::make_blocked_event(
                    "web_scrape",
                    &instruction.url,
                    parsed.host_str().unwrap_or(""),
                    correlation_id,
                    caller_id.clone(),
                    "ssrf",
                );
                self.log_egress_event(&event).await;
            }
            return Err(e);
        }
    };
    let html = self
        .fetch_html(
            &instruction.url,
            &host,
            &addrs,
            "web_scrape",
            correlation_id,
            caller_id,
        )
        .await?;
    let selector = instruction.select.clone();
    let extract = ExtractMode::parse(&instruction.extract);
    let limit = instruction.limit.unwrap_or(10);
    // HTML parsing is CPU-bound; run it off the async runtime's worker threads.
    tokio::task::spawn_blocking(move || parse_and_extract(&html, &selector, &extract, limit))
        .await
        .map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?
}
/// Build an [`EgressEvent`] for a request rejected before any network I/O
/// (scheme validation, domain policy, or SSRF screening).
///
/// The logged URL is redacted. `hop` is 0 because no network hop ever occurred.
fn make_blocked_event(
    tool: &str,
    url: &str,
    host: &str,
    correlation_id: &str,
    caller_id: Option<String>,
    block_reason: &'static str,
) -> EgressEvent {
    EgressEvent {
        timestamp: chrono_now(),
        kind: "egress",
        correlation_id: correlation_id.to_owned(),
        tool: tool.into(),
        url: redact_url_for_log(url),
        host: host.to_owned(),
        method: "GET".to_owned(),
        status: None,
        duration_ms: 0,
        response_bytes: 0,
        blocked: true,
        block_reason: Some(block_reason),
        caller_id,
        hop: 0,
    }
}
/// Fetches the HTML at `url`, manually following up to 3 redirects.
///
/// Each redirect target is validated with `validate_url` and `resolve_and_validate`
/// before following, preventing SSRF via redirect chains. When egress logging is
/// enabled, one [`EgressEvent`] is emitted per hop.
///
/// # Errors
///
/// Returns `ToolError::Blocked` if any redirect target resolves to a private IP.
/// Returns `ToolError::Execution` on HTTP errors, too-large bodies, or too many redirects.
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
async fn fetch_html(
&self,
url: &str,
host: &str,
addrs: &[SocketAddr],
tool: &str,
correlation_id: &str,
caller_id: Option<String>,
) -> Result<String, ToolError> {
const MAX_REDIRECTS: usize = 3;
let mut current_url = url.to_owned();
let mut current_host = host.to_owned();
let mut current_addrs = addrs.to_vec();
for hop in 0..=MAX_REDIRECTS {
let hop_start = Instant::now();
// Build a per-hop client pinned to the current hop's validated addresses.
let client = self.build_client(¤t_host, ¤t_addrs);
let resp = client
.get(¤t_url)
.send()
.await
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
let resp = match resp {
Ok(r) => r,
Err(e) => {
if self.egress_config.enabled {
#[allow(clippy::cast_possible_truncation)]
let duration_ms = hop_start.elapsed().as_millis() as u64;
let event = EgressEvent {
timestamp: chrono_now(),
kind: "egress",
correlation_id: correlation_id.to_owned(),
tool: tool.into(),
url: redact_url_for_log(¤t_url),
host: current_host.clone(),
method: "GET".to_owned(),
status: None,
duration_ms,
response_bytes: 0,
blocked: false,
block_reason: None,
caller_id: caller_id.clone(),
#[allow(clippy::cast_possible_truncation)]
hop: hop as u8,
};
self.log_egress_event(&event).await;
}
return Err(e);
}
};
let status = resp.status();
if status.is_redirection() {
if hop == MAX_REDIRECTS {
return Err(ToolError::Execution(std::io::Error::other(
"too many redirects",
)));
}
let location = resp
.headers()
.get(reqwest::header::LOCATION)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
ToolError::Execution(std::io::Error::other("redirect with no Location"))
})?;
// Resolve relative redirect URLs against the current URL.
let base = Url::parse(¤t_url)
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
let next_url = base
.join(location)
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
let validated = validate_url(next_url.as_str());
if let Err(ref _e) = validated {
if self.egress_config.enabled && self.egress_config.log_blocked {
#[allow(clippy::cast_possible_truncation)]
let duration_ms = hop_start.elapsed().as_millis() as u64;
let next_host = next_url.host_str().unwrap_or("").to_owned();
let event = EgressEvent {
timestamp: chrono_now(),
kind: "egress",
correlation_id: correlation_id.to_owned(),
tool: tool.into(),
url: redact_url_for_log(next_url.as_str()),
host: next_host,
method: "GET".to_owned(),
status: None,
duration_ms,
response_bytes: 0,
blocked: true,
block_reason: Some("ssrf"),
caller_id: caller_id.clone(),
#[allow(clippy::cast_possible_truncation)]
hop: (hop + 1) as u8,
};
self.log_egress_event(&event).await;
}
return Err(validated.unwrap_err());
}
let validated = validated.unwrap();
let resolve_result = resolve_and_validate(&validated).await;
if let Err(ref _e) = resolve_result {
if self.egress_config.enabled && self.egress_config.log_blocked {
#[allow(clippy::cast_possible_truncation)]
let duration_ms = hop_start.elapsed().as_millis() as u64;
let next_host = next_url.host_str().unwrap_or("").to_owned();
let event = EgressEvent {
timestamp: chrono_now(),
kind: "egress",
correlation_id: correlation_id.to_owned(),
tool: tool.into(),
url: redact_url_for_log(next_url.as_str()),
host: next_host,
method: "GET".to_owned(),
status: None,
duration_ms,
response_bytes: 0,
blocked: true,
block_reason: Some("ssrf"),
caller_id: caller_id.clone(),
#[allow(clippy::cast_possible_truncation)]
hop: (hop + 1) as u8,
};
self.log_egress_event(&event).await;
}
return Err(resolve_result.unwrap_err());
}
let (next_host, next_addrs) = resolve_result.unwrap();
current_url = next_url.to_string();
current_host = next_host;
current_addrs = next_addrs;
continue;
}
if !status.is_success() {
if self.egress_config.enabled {
#[allow(clippy::cast_possible_truncation)]
let duration_ms = hop_start.elapsed().as_millis() as u64;
let event = EgressEvent {
timestamp: chrono_now(),
kind: "egress",
correlation_id: correlation_id.to_owned(),
tool: tool.into(),
url: current_url.clone(),
host: current_host.clone(),
method: "GET".to_owned(),
status: Some(status.as_u16()),
duration_ms,
response_bytes: 0,
blocked: false,
block_reason: None,
caller_id: caller_id.clone(),
#[allow(clippy::cast_possible_truncation)]
hop: hop as u8,
};
self.log_egress_event(&event).await;
}
return Err(ToolError::Http {
status: status.as_u16(),
message: status.canonical_reason().unwrap_or("unknown").to_owned(),
});
}
let bytes = resp
.bytes()
.await
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
if bytes.len() > self.max_body_bytes {
if self.egress_config.enabled {
#[allow(clippy::cast_possible_truncation)]
let duration_ms = hop_start.elapsed().as_millis() as u64;
let event = EgressEvent {
timestamp: chrono_now(),
kind: "egress",
correlation_id: correlation_id.to_owned(),
tool: tool.into(),
url: current_url.clone(),
host: current_host.clone(),
method: "GET".to_owned(),
status: Some(status.as_u16()),
duration_ms,
response_bytes: bytes.len(),
blocked: false,
block_reason: None,
caller_id: caller_id.clone(),
#[allow(clippy::cast_possible_truncation)]
hop: hop as u8,
};
self.log_egress_event(&event).await;
}
return Err(ToolError::Execution(std::io::Error::other(format!(
"response too large: {} bytes (max: {})",
bytes.len(),
self.max_body_bytes,
))));
}
// Success — emit egress event.
if self.egress_config.enabled {
#[allow(clippy::cast_possible_truncation)]
let duration_ms = hop_start.elapsed().as_millis() as u64;
let response_bytes = if self.egress_config.log_response_bytes {
bytes.len()
} else {
0
};
let event = EgressEvent {
timestamp: chrono_now(),
kind: "egress",
correlation_id: correlation_id.to_owned(),
tool: tool.into(),
url: current_url.clone(),
host: current_host.clone(),
method: "GET".to_owned(),
status: Some(status.as_u16()),
duration_ms,
response_bytes,
blocked: false,
block_reason: None,
caller_id: caller_id.clone(),
#[allow(clippy::cast_possible_truncation)]
hop: hop as u8,
};
self.log_egress_event(&event).await;
}
return String::from_utf8(bytes.to_vec())
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
}
Err(ToolError::Execution(std::io::Error::other(
"too many redirects",
)))
}
}
/// Extract the contents of all ```scrape fenced blocks from an LLM response.
fn extract_scrape_blocks(text: &str) -> Vec<&str> {
    crate::executor::extract_fenced_blocks(text, "scrape")
}
/// Check host against the domain allowlist/denylist from `ScrapeConfig`.
///
/// Logic:
/// 1. If `denied_domains` matches the host → block.
/// 2. If `allowed_domains` is non-empty:
///    a. IP address hosts are always rejected (no pattern can match a bare IP).
///    b. Hosts not matching any entry → block.
/// 3. Otherwise → allow.
///
/// Wildcard prefix matching: `*.example.com` matches `sub.example.com` but NOT `example.com`.
/// Multiple wildcards are not supported; patterns with more than one `*` are treated as exact.
fn check_domain_policy(
    host: &str,
    allowed_domains: &[String],
    denied_domains: &[String],
) -> Result<(), ToolError> {
    let matches_any = |patterns: &[String]| patterns.iter().any(|p| domain_matches(p, host));
    // Denylist wins unconditionally.
    if matches_any(denied_domains) {
        return Err(ToolError::Blocked {
            command: format!("domain blocked by denylist: {host}"),
        });
    }
    // Empty allowlist means "everything not denied is allowed".
    if allowed_domains.is_empty() {
        return Ok(());
    }
    // Bare IP addresses cannot match any domain pattern — reject when allowlist is active.
    let looks_like_ip = host.parse::<std::net::IpAddr>().is_ok()
        || (host.starts_with('[') && host.ends_with(']'));
    if looks_like_ip {
        return Err(ToolError::Blocked {
            command: format!(
                "bare IP address not allowed when domain allowlist is active: {host}"
            ),
        });
    }
    if matches_any(allowed_domains) {
        Ok(())
    } else {
        Err(ToolError::Blocked {
            command: format!("domain not in allowlist: {host}"),
        })
    }
}
/// Match a domain pattern against a hostname, ASCII case-insensitively.
///
/// DNS names are case-insensitive (RFC 4343), and the `url` crate lowercases
/// parsed hosts — so a config pattern like `Example.COM` would previously never
/// match. Both sides are now normalized to ASCII lowercase before comparison.
///
/// Pattern `*.example.com` matches `sub.example.com` but not `example.com` or
/// `sub.sub.example.com`. Exact patterns match only themselves.
/// Patterns with multiple `*` characters are not supported and are treated as exact strings.
fn domain_matches(pattern: &str, host: &str) -> bool {
    let pattern = pattern.to_ascii_lowercase();
    let host = host.to_ascii_lowercase();
    if let Some(suffix) = pattern.strip_prefix('*') {
        if !suffix.starts_with('.') {
            // Not a `*.`-style wildcard (e.g. `*foo`): treat as an exact string.
            return pattern == host;
        }
        // Allow only a single subdomain level: `*.example.com` → `<label>.example.com`;
        // the remainder before the suffix must be one non-empty DNS label (no dots).
        match host.strip_suffix(suffix) {
            Some(label) => !label.is_empty() && !label.contains('.'),
            None => false,
        }
    } else {
        pattern == host
    }
}
/// Parse `raw` and enforce the HTTPS-only and non-private-host rules.
///
/// Rejects unparseable URLs, any scheme other than `https`, and hosts that
/// [`is_private_host`] classifies as private/local.
fn validate_url(raw: &str) -> Result<Url, ToolError> {
    let parsed = match Url::parse(raw) {
        Ok(u) => u,
        Err(_) => {
            return Err(ToolError::Blocked {
                command: format!("invalid URL: {raw}"),
            });
        }
    };
    if parsed.scheme() != "https" {
        return Err(ToolError::Blocked {
            command: format!("scheme not allowed: {}", parsed.scheme()),
        });
    }
    let private = parsed.host().is_some_and(|h| is_private_host(&h));
    if private {
        return Err(ToolError::Blocked {
            command: format!(
                "private/local host blocked: {}",
                parsed.host_str().unwrap_or("")
            ),
        });
    }
    Ok(parsed)
}
/// Returns true for hosts that must never be fetched: `localhost` and its
/// subdomains, the `.internal`/`.local` pseudo-TLDs used in cloud/k8s
/// environments, and IP addresses in private ranges (per [`is_private_ip`]).
fn is_private_host(host: &url::Host<&str>) -> bool {
    match host {
        url::Host::Domain(name) => {
            #[allow(clippy::case_sensitive_file_extension_comparisons)]
            {
                *name == "localhost"
                    || [".localhost", ".internal", ".local"]
                        .iter()
                        .any(|suffix| name.ends_with(suffix))
            }
        }
        url::Host::Ipv4(v4) => is_private_ip(IpAddr::V4(*v4)),
        url::Host::Ipv6(v6) => is_private_ip(IpAddr::V6(*v6)),
    }
}
/// Resolves DNS for the URL host, validates all resolved IPs against private ranges,
/// and returns the hostname and validated socket addresses.
///
/// Returning the addresses allows the caller to pin the HTTP client to these exact
/// addresses, eliminating TOCTOU between DNS validation and the actual connection.
///
/// # Errors
///
/// Returns [`ToolError::Blocked`] when resolution fails, resolution yields no
/// addresses, or any resolved address is in a private range.
async fn resolve_and_validate(url: &Url) -> Result<(String, Vec<SocketAddr>), ToolError> {
    let Some(host) = url.host_str() else {
        // HTTPS URLs always carry a host; kept as a no-op fallback for completeness.
        return Ok((String::new(), vec![]));
    };
    let port = url.port_or_known_default().unwrap_or(443);
    let addrs: Vec<SocketAddr> = tokio::net::lookup_host(format!("{host}:{port}"))
        .await
        .map_err(|e| ToolError::Blocked {
            command: format!("DNS resolution failed: {e}"),
        })?
        .collect();
    // A successful lookup can still yield zero addresses; previously that was
    // returned as Ok, pinning the client to an empty address list. Surface it
    // as a resolution failure instead of a later opaque connection error.
    if addrs.is_empty() {
        return Err(ToolError::Blocked {
            command: format!("DNS resolution returned no addresses for host {host}"),
        });
    }
    for addr in &addrs {
        if is_private_ip(addr.ip()) {
            return Err(ToolError::Blocked {
                command: format!("SSRF protection: private IP {} for host {host}", addr.ip()),
            });
        }
    }
    Ok((host.to_owned(), addrs))
}
/// Run `selector` over `html` and collect up to `limit` extracted values, one per line.
///
/// Values are trimmed and empty ones skipped. Note that `take(limit)` caps how
/// many *matches are inspected*, so fewer than `limit` values may be returned
/// when early matches are empty. Returns a placeholder message if nothing matched.
fn parse_and_extract(
    html: &str,
    selector: &str,
    extract: &ExtractMode,
    limit: usize,
) -> Result<String, ToolError> {
    let soup = scrape_core::Soup::parse(html);
    let tags = soup.find_all(selector).map_err(|e| {
        ToolError::Execution(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("invalid selector: {e}"),
        ))
    })?;
    let results: Vec<String> = tags
        .into_iter()
        .take(limit)
        .filter_map(|tag| {
            let raw = match extract {
                ExtractMode::Text => tag.text(),
                ExtractMode::Html => tag.inner_html(),
                ExtractMode::Attr(name) => tag.get(name).unwrap_or_default().to_owned(),
            };
            let trimmed = raw.trim();
            if trimmed.is_empty() {
                None
            } else {
                Some(trimmed.to_owned())
            }
        })
        .collect();
    if results.is_empty() {
        Ok(format!("No results for selector: {selector}"))
    } else {
        Ok(results.join("\n"))
    }
}
#[cfg(test)]
mod tests {
use super::*;
// --- extract_scrape_blocks ---
#[test]
fn extract_single_block() {
let text =
"Here:\n```scrape\n{\"url\":\"https://example.com\",\"select\":\"h1\"}\n```\nDone.";
let blocks = extract_scrape_blocks(text);
assert_eq!(blocks.len(), 1);
assert!(blocks[0].contains("example.com"));
}
#[test]
fn extract_multiple_blocks() {
let text = "```scrape\n{\"url\":\"https://a.com\",\"select\":\"h1\"}\n```\ntext\n```scrape\n{\"url\":\"https://b.com\",\"select\":\"p\"}\n```";
let blocks = extract_scrape_blocks(text);
assert_eq!(blocks.len(), 2);
}
#[test]
fn no_blocks_returns_empty() {
let blocks = extract_scrape_blocks("plain text, no code blocks");
assert!(blocks.is_empty());
}
#[test]
fn unclosed_block_ignored() {
let blocks = extract_scrape_blocks("```scrape\n{\"url\":\"https://x.com\"}");
assert!(blocks.is_empty());
}
#[test]
fn non_scrape_block_ignored() {
let text =
"```bash\necho hi\n```\n```scrape\n{\"url\":\"https://x.com\",\"select\":\"h1\"}\n```";
let blocks = extract_scrape_blocks(text);
assert_eq!(blocks.len(), 1);
assert!(blocks[0].contains("x.com"));
}
#[test]
fn multiline_json_block() {
let text =
"```scrape\n{\n \"url\": \"https://example.com\",\n \"select\": \"h1\"\n}\n```";
let blocks = extract_scrape_blocks(text);
assert_eq!(blocks.len(), 1);
let instr: ScrapeInstruction = serde_json::from_str(blocks[0]).unwrap();
assert_eq!(instr.url, "https://example.com");
}
// --- ScrapeInstruction parsing ---
#[test]
fn parse_valid_instruction() {
let json = r#"{"url":"https://example.com","select":"h1","extract":"text","limit":5}"#;
let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
assert_eq!(instr.url, "https://example.com");
assert_eq!(instr.select, "h1");
assert_eq!(instr.extract, "text");
assert_eq!(instr.limit, Some(5));
}
#[test]
fn parse_minimal_instruction() {
let json = r#"{"url":"https://example.com","select":"p"}"#;
let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
assert_eq!(instr.extract, "text");
assert!(instr.limit.is_none());
}
#[test]
fn parse_attr_extract() {
let json = r#"{"url":"https://example.com","select":"a","extract":"attr:href"}"#;
let instr: ScrapeInstruction = serde_json::from_str(json).unwrap();
assert_eq!(instr.extract, "attr:href");
}
#[test]
fn parse_invalid_json_errors() {
let result = serde_json::from_str::<ScrapeInstruction>("not json");
assert!(result.is_err());
}
// --- ExtractMode ---
#[test]
fn extract_mode_text() {
assert!(matches!(ExtractMode::parse("text"), ExtractMode::Text));
}
#[test]
fn extract_mode_html() {
assert!(matches!(ExtractMode::parse("html"), ExtractMode::Html));
}
#[test]
fn extract_mode_attr() {
let mode = ExtractMode::parse("attr:href");
assert!(matches!(mode, ExtractMode::Attr(ref s) if s == "href"));
}
#[test]
fn extract_mode_unknown_defaults_to_text() {
assert!(matches!(ExtractMode::parse("unknown"), ExtractMode::Text));
}
// --- validate_url ---
#[test]
fn valid_https_url() {
    assert!(validate_url("https://example.com").is_ok());
}
#[test]
fn http_rejected() {
    // Plain HTTP is refused: only TLS-protected fetches are allowed.
    assert!(matches!(
        validate_url("http://example.com"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ftp_rejected() {
    assert!(matches!(
        validate_url("ftp://files.example.com"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn file_rejected() {
    // file:// would expose the local filesystem.
    assert!(matches!(
        validate_url("file:///etc/passwd"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn invalid_url_rejected() {
    assert!(matches!(
        validate_url("not a url"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn localhost_blocked() {
    assert!(matches!(
        validate_url("https://localhost/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn loopback_ip_blocked() {
    assert!(matches!(
        validate_url("https://127.0.0.1/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn private_10_blocked() {
    assert!(matches!(
        validate_url("https://10.0.0.1/api"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn private_172_blocked() {
    assert!(matches!(
        validate_url("https://172.16.0.1/api"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn private_192_blocked() {
    assert!(matches!(
        validate_url("https://192.168.1.1/api"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv6_loopback_blocked() {
    assert!(matches!(
        validate_url("https://[::1]/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn public_ip_allowed() {
    // A public literal IP carries no SSRF risk and passes validation.
    assert!(validate_url("https://93.184.216.34/page").is_ok());
}
// --- parse_and_extract ---
#[test]
fn extract_text_from_html() {
    let markup = "<html><body><h1>Hello World</h1><p>Content</p></body></html>";
    let out = parse_and_extract(markup, "h1", &ExtractMode::Text, 10).unwrap();
    assert_eq!(out, "Hello World");
}
#[test]
fn extract_multiple_elements() {
    // Multiple matches are joined with newlines, in document order.
    let markup = "<ul><li>A</li><li>B</li><li>C</li></ul>";
    let out = parse_and_extract(markup, "li", &ExtractMode::Text, 10).unwrap();
    assert_eq!(out, "A\nB\nC");
}
#[test]
fn extract_with_limit() {
    // A limit of 2 truncates the match list.
    let markup = "<ul><li>A</li><li>B</li><li>C</li></ul>";
    let out = parse_and_extract(markup, "li", &ExtractMode::Text, 2).unwrap();
    assert_eq!(out, "A\nB");
}
#[test]
fn extract_attr_href() {
    let markup = r#"<a href="https://example.com">Link</a>"#;
    let out =
        parse_and_extract(markup, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
    assert_eq!(out, "https://example.com");
}
#[test]
fn extract_inner_html() {
    let markup = "<div><span>inner</span></div>";
    let out = parse_and_extract(markup, "div", &ExtractMode::Html, 10).unwrap();
    assert!(out.contains("<span>inner</span>"));
}
#[test]
fn no_matches_returns_message() {
    let markup = "<html><body><p>text</p></body></html>";
    let out = parse_and_extract(markup, "h1", &ExtractMode::Text, 10).unwrap();
    assert!(out.starts_with("No results for selector:"));
}
#[test]
fn empty_text_skipped() {
    // Whitespace-only elements are dropped from the output.
    let markup = "<ul><li> </li><li>A</li></ul>";
    let out = parse_and_extract(markup, "li", &ExtractMode::Text, 10).unwrap();
    assert_eq!(out, "A");
}
#[test]
fn invalid_selector_errors() {
    // A malformed CSS selector is a hard error, not an empty result.
    let markup = "<html><body></body></html>";
    assert!(parse_and_extract(markup, "[[[invalid", &ExtractMode::Text, 10).is_err());
}
#[test]
fn empty_html_returns_no_results() {
    let out = parse_and_extract("", "h1", &ExtractMode::Text, 10).unwrap();
    assert!(out.starts_with("No results for selector:"));
}
#[test]
fn nested_selector() {
    // `div > span` must match only the child span, not the sibling one.
    let markup = "<div><span>inner</span></div><span>outer</span>";
    let out = parse_and_extract(markup, "div > span", &ExtractMode::Text, 10).unwrap();
    assert_eq!(out, "inner");
}
#[test]
fn attr_missing_returns_empty() {
    // An anchor without the requested attribute yields no results.
    let markup = r"<a>No href</a>";
    let out =
        parse_and_extract(markup, "a", &ExtractMode::Attr("href".to_owned()), 10).unwrap();
    assert!(out.starts_with("No results for selector:"));
}
#[test]
fn extract_html_mode() {
    let markup = "<div><b>bold</b> text</div>";
    let out = parse_and_extract(markup, "div", &ExtractMode::Html, 10).unwrap();
    assert!(out.contains("<b>bold</b>"));
}
#[test]
fn limit_zero_returns_no_results() {
    // A zero limit truncates everything, which reads as "no results".
    let markup = "<ul><li>A</li><li>B</li></ul>";
    let out = parse_and_extract(markup, "li", &ExtractMode::Text, 0).unwrap();
    assert!(out.starts_with("No results for selector:"));
}
// --- validate_url edge cases ---
#[test]
fn url_with_port_allowed() {
    // Non-default ports are fine as long as the host is public HTTPS.
    assert!(validate_url("https://example.com:8443/path").is_ok());
}
#[test]
fn link_local_ip_blocked() {
    // 169.254.0.0/16 covers cloud metadata endpoints.
    assert!(matches!(
        validate_url("https://169.254.1.1/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn url_no_scheme_rejected() {
    assert!(matches!(
        validate_url("example.com/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn unspecified_ipv4_blocked() {
    assert!(matches!(
        validate_url("https://0.0.0.0/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn broadcast_ipv4_blocked() {
    assert!(matches!(
        validate_url("https://255.255.255.255/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv6_link_local_blocked() {
    assert!(matches!(
        validate_url("https://[fe80::1]/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv6_unique_local_blocked() {
    assert!(matches!(
        validate_url("https://[fd12::1]/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv4_mapped_ipv6_loopback_blocked() {
    // ::ffff:a.b.c.d must be unwrapped to its IPv4 form before the check.
    assert!(matches!(
        validate_url("https://[::ffff:127.0.0.1]/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv4_mapped_ipv6_private_blocked() {
    assert!(matches!(
        validate_url("https://[::ffff:10.0.0.1]/path"),
        Err(ToolError::Blocked { .. })
    ));
}
// --- WebScrapeExecutor (no-network) ---
#[tokio::test]
async fn executor_no_blocks_returns_none() {
    // Text without any ```scrape fences is not a tool invocation.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    assert!(executor.execute("plain text").await.unwrap().is_none());
}
#[tokio::test]
async fn executor_invalid_json_errors() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let outcome = executor.execute("```scrape\nnot json\n```").await;
    assert!(matches!(outcome, Err(ToolError::Execution(_))));
}
#[tokio::test]
async fn executor_blocked_url_errors() {
    // http:// scheme is rejected before any network activity.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let outcome = executor
        .execute("```scrape\n{\"url\":\"http://example.com\",\"select\":\"h1\"}\n```")
        .await;
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}
#[tokio::test]
async fn executor_private_ip_blocked() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let outcome = executor
        .execute("```scrape\n{\"url\":\"https://192.168.1.1/api\",\"select\":\"h1\"}\n```")
        .await;
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}
#[tokio::test]
async fn executor_unreachable_host_returns_error() {
    // 192.0.2.1 (TEST-NET-1) passes the SSRF check as public but is
    // unroutable, so the request fails at connection time.
    let config = ScrapeConfig {
        timeout: 1,
        max_body_bytes: 1_048_576,
        ..Default::default()
    };
    let executor = WebScrapeExecutor::new(&config);
    let outcome = executor
        .execute("```scrape\n{\"url\":\"https://192.0.2.1:1/page\",\"select\":\"h1\"}\n```")
        .await;
    assert!(matches!(outcome, Err(ToolError::Execution(_))));
}
#[tokio::test]
async fn executor_localhost_url_blocked() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let outcome = executor
        .execute("```scrape\n{\"url\":\"https://localhost:9999/api\",\"select\":\"h1\"}\n```")
        .await;
    assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
}
#[tokio::test]
async fn executor_empty_text_returns_none() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    assert!(executor.execute("").await.unwrap().is_none());
}
#[tokio::test]
async fn executor_multiple_blocks_first_blocked() {
    // A blocked first block aborts the whole execution.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let response = "```scrape\n{\"url\":\"http://evil.com\",\"select\":\"h1\"}\n```\n```scrape\n{\"url\":\"https://ok.com\",\"select\":\"h1\"}\n```";
    assert!(executor.execute(response).await.is_err());
}
#[test]
fn validate_url_empty_string() {
    assert!(matches!(validate_url(""), Err(ToolError::Blocked { .. })));
}
#[test]
fn validate_url_javascript_scheme_blocked() {
    assert!(matches!(
        validate_url("javascript:alert(1)"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn validate_url_data_scheme_blocked() {
    assert!(matches!(
        validate_url("data:text/html,<h1>hi</h1>"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn is_private_host_public_domain_is_false() {
    let public: url::Host<&str> = url::Host::Domain("example.com");
    assert!(!is_private_host(&public));
}
#[test]
fn is_private_host_localhost_is_true() {
    let local: url::Host<&str> = url::Host::Domain("localhost");
    assert!(is_private_host(&local));
}
#[test]
fn is_private_host_ipv6_unspecified_is_true() {
    let unspecified = url::Host::Ipv6(std::net::Ipv6Addr::UNSPECIFIED);
    assert!(is_private_host(&unspecified));
}
#[test]
fn is_private_host_public_ipv6_is_false() {
    // 2001:db8::/32 is the IPv6 documentation prefix; treated as public here.
    let public = url::Host::Ipv6("2001:db8::1".parse().unwrap());
    assert!(!is_private_host(&public));
}
// --- fetch_html redirect logic: wiremock HTTP server tests ---
//
// These tests use a local wiremock server to exercise the redirect-following logic
// in `fetch_html` without requiring an external HTTPS connection. The server binds to
// 127.0.0.1, and tests call `fetch_html` directly (bypassing `validate_url`) to avoid
// the SSRF guard that would otherwise block loopback connections.
/// Helper: returns a `WebScrapeExecutor` (5 s timeout, 1 MiB body cap, no domain
/// allow/deny lists, no audit logger, no egress wiring) plus a running wiremock
/// mock server. Pair with `server_host_and_addr` to point `fetch_html` at the mock.
async fn mock_server_executor() -> (WebScrapeExecutor, wiremock::MockServer) {
    let server = wiremock::MockServer::start().await;
    // Built by struct literal (not `new`) so the fields can deviate from
    // `ScrapeConfig` defaults without a config round-trip.
    let executor = WebScrapeExecutor {
        timeout: Duration::from_secs(5),
        max_body_bytes: 1_048_576,
        allowed_domains: vec![],
        denied_domains: vec![],
        audit_logger: None,
        egress_config: EgressConfig::default(),
        egress_tx: None,
        egress_dropped: Arc::new(AtomicU64::new(0)),
    };
    (executor, server)
}
/// Parses the mock server's URI into (`host_str`, `socket_addr`) for use with `build_client`.
fn server_host_and_addr(server: &wiremock::MockServer) -> (String, Vec<std::net::SocketAddr>) {
    // wiremock URIs are always of the form http://127.0.0.1:<port>.
    let parsed = Url::parse(&server.uri()).unwrap();
    let host = parsed.host_str().unwrap_or("127.0.0.1").to_owned();
    let port = parsed.port().unwrap_or(80);
    let addr: std::net::SocketAddr = format!("{host}:{port}").parse().unwrap();
    (host, vec![addr])
}
/// Test-only redirect follower that mimics `fetch_html`'s loop but skips `validate_url` /
/// `resolve_and_validate`. This lets us exercise the redirect-counting and
/// missing-Location logic against a plain HTTP wiremock server.
async fn follow_redirects_raw(
executor: &WebScrapeExecutor,
start_url: &str,
host: &str,
addrs: &[std::net::SocketAddr],
) -> Result<String, ToolError> {
const MAX_REDIRECTS: usize = 3;
let mut current_url = start_url.to_owned();
let mut current_host = host.to_owned();
let mut current_addrs = addrs.to_vec();
for hop in 0..=MAX_REDIRECTS {
let client = executor.build_client(¤t_host, ¤t_addrs);
let resp = client
.get(¤t_url)
.send()
.await
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
let status = resp.status();
if status.is_redirection() {
if hop == MAX_REDIRECTS {
return Err(ToolError::Execution(std::io::Error::other(
"too many redirects",
)));
}
let location = resp
.headers()
.get(reqwest::header::LOCATION)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
ToolError::Execution(std::io::Error::other("redirect with no Location"))
})?;
let base = Url::parse(¤t_url)
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
let next_url = base
.join(location)
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
// Re-use same host/addrs (mock server is always the same endpoint).
current_url = next_url.to_string();
// Preserve host/addrs as-is since the mock server doesn't change.
let _ = &mut current_host;
let _ = &mut current_addrs;
continue;
}
if !status.is_success() {
return Err(ToolError::Execution(std::io::Error::other(format!(
"HTTP {status}",
))));
}
let bytes = resp
.bytes()
.await
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())))?;
if bytes.len() > executor.max_body_bytes {
return Err(ToolError::Execution(std::io::Error::other(format!(
"response too large: {} bytes (max: {})",
bytes.len(),
executor.max_body_bytes,
))));
}
return String::from_utf8(bytes.to_vec())
.map_err(|e| ToolError::Execution(std::io::Error::other(e.to_string())));
}
Err(ToolError::Execution(std::io::Error::other(
"too many redirects",
)))
}
#[tokio::test]
async fn fetch_html_success_returns_body() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // Happy path: a 200 response body is returned verbatim.
    Mock::given(method("GET"))
        .and(path("/page"))
        .respond_with(ResponseTemplate::new(200).set_body_string("<h1>OK</h1>"))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/page", server.uri());
    let result = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await;
    assert!(result.is_ok(), "expected Ok, got: {result:?}");
    assert_eq!(result.unwrap(), "<h1>OK</h1>");
}
#[tokio::test]
async fn fetch_html_non_2xx_returns_error() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // A 403 is not a success status: fetch_html must error, embedding the code.
    Mock::given(method("GET"))
        .and(path("/forbidden"))
        .respond_with(ResponseTemplate::new(403))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/forbidden", server.uri());
    let result = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await;
    assert!(result.is_err());
    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("403"), "expected 403 in error: {msg}");
}
#[tokio::test]
async fn fetch_html_404_returns_error() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // Same as the 403 case: the status code must surface in the error message.
    Mock::given(method("GET"))
        .and(path("/missing"))
        .respond_with(ResponseTemplate::new(404))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/missing", server.uri());
    let result = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await;
    assert!(result.is_err());
    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("404"), "expected 404 in error: {msg}");
}
#[tokio::test]
async fn fetch_html_redirect_no_location_returns_error() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // 302 with no Location header
    // A redirect without a target is malformed; fetch_html must not loop or hang.
    Mock::given(method("GET"))
        .and(path("/redirect-no-loc"))
        .respond_with(ResponseTemplate::new(302))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/redirect-no-loc", server.uri());
    let result = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await;
    assert!(result.is_err());
    let msg = result.unwrap_err().to_string();
    assert!(
        msg.contains("Location") || msg.contains("location"),
        "expected Location-related error: {msg}"
    );
}
#[tokio::test]
async fn fetch_html_single_redirect_followed() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // /start 302-redirects to /final, which serves the real body.
    let final_url = format!("{}/final", server.uri());
    Mock::given(method("GET"))
        .and(path("/start"))
        .respond_with(ResponseTemplate::new(302).insert_header("location", final_url.as_str()))
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path("/final"))
        .respond_with(ResponseTemplate::new(200).set_body_string("<p>final</p>"))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/start", server.uri());
    let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
    assert!(result.is_ok(), "single redirect should succeed: {result:?}");
    assert_eq!(result.unwrap(), "<p>final</p>");
}
#[tokio::test]
async fn fetch_html_three_redirects_allowed() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // Chain of exactly 3 redirects (the allowed maximum): hop1 -> hop2 -> hop3 -> done.
    let hop2 = format!("{}/hop2", server.uri());
    let hop3 = format!("{}/hop3", server.uri());
    let final_dest = format!("{}/done", server.uri());
    Mock::given(method("GET"))
        .and(path("/hop1"))
        .respond_with(ResponseTemplate::new(301).insert_header("location", hop2.as_str()))
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path("/hop2"))
        .respond_with(ResponseTemplate::new(301).insert_header("location", hop3.as_str()))
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path("/hop3"))
        .respond_with(ResponseTemplate::new(301).insert_header("location", final_dest.as_str()))
        .mount(&server)
        .await;
    Mock::given(method("GET"))
        .and(path("/done"))
        .respond_with(ResponseTemplate::new(200).set_body_string("<p>done</p>"))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/hop1", server.uri());
    let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
    assert!(result.is_ok(), "3 redirects should succeed: {result:?}");
    assert_eq!(result.unwrap(), "<p>done</p>");
}
#[tokio::test]
async fn fetch_html_four_redirects_rejected() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // Chain of 4 redirects: one past the limit, so the follower must bail out.
    let hop2 = format!("{}/r2", server.uri());
    let hop3 = format!("{}/r3", server.uri());
    let hop4 = format!("{}/r4", server.uri());
    let hop5 = format!("{}/r5", server.uri());
    for (from, to) in [
        ("/r1", &hop2),
        ("/r2", &hop3),
        ("/r3", &hop4),
        ("/r4", &hop5),
    ] {
        Mock::given(method("GET"))
            .and(path(from))
            .respond_with(ResponseTemplate::new(301).insert_header("location", to.as_str()))
            .mount(&server)
            .await;
    }
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/r1", server.uri());
    let result = follow_redirects_raw(&executor, &url, &host, &addrs).await;
    assert!(result.is_err(), "4 redirects should be rejected");
    let msg = result.unwrap_err().to_string();
    assert!(
        msg.contains("redirect"),
        "expected redirect-related error: {msg}"
    );
}
#[tokio::test]
async fn fetch_html_body_too_large_returns_error() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    // Executor with a deliberately tiny 10-byte body cap.
    let executor = WebScrapeExecutor {
        timeout: Duration::from_secs(5),
        max_body_bytes: 10,
        allowed_domains: vec![],
        denied_domains: vec![],
        audit_logger: None,
        egress_config: EgressConfig::default(),
        egress_tx: None,
        egress_dropped: Arc::new(AtomicU64::new(0)),
    };
    let server = wiremock::MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/big"))
        .respond_with(
            ResponseTemplate::new(200)
                .set_body_string("this body is definitely longer than ten bytes"),
        )
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/big", server.uri());
    let result = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await;
    assert!(result.is_err());
    let msg = result.unwrap_err().to_string();
    assert!(msg.contains("too large"), "expected too-large error: {msg}");
}
#[test]
fn extract_scrape_blocks_empty_block_content() {
    // An empty fenced block still counts as one (empty) block.
    let blocks = extract_scrape_blocks("```scrape\n\n```");
    assert_eq!(blocks.len(), 1);
    assert!(blocks[0].is_empty());
}
#[test]
fn extract_scrape_blocks_whitespace_only() {
    let blocks = extract_scrape_blocks("```scrape\n \n```");
    assert_eq!(blocks.len(), 1);
}
#[test]
fn parse_and_extract_multiple_selectors() {
    // A comma selector list matches elements of either kind.
    let markup = "<div><h1>Title</h1><p>Para</p></div>";
    let out = parse_and_extract(markup, "h1, p", &ExtractMode::Text, 10).unwrap();
    assert!(out.contains("Title"));
    assert!(out.contains("Para"));
}
#[test]
fn webscrape_executor_new_with_custom_config() {
    // Config values must flow through the constructor.
    let config = ScrapeConfig {
        timeout: 60,
        max_body_bytes: 512,
        ..Default::default()
    };
    assert_eq!(WebScrapeExecutor::new(&config).max_body_bytes, 512);
}
#[test]
fn webscrape_executor_debug() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    assert!(format!("{executor:?}").contains("WebScrapeExecutor"));
}
#[test]
fn extract_mode_attr_empty_name() {
    // "attr:" with no name yields an empty attribute payload.
    assert!(matches!(ExtractMode::parse("attr:"), ExtractMode::Attr(ref s) if s.is_empty()));
}
#[test]
fn default_extract_returns_text() {
    assert_eq!(default_extract(), "text");
}
#[test]
fn scrape_instruction_debug() {
    let instr: ScrapeInstruction =
        serde_json::from_str(r#"{"url":"https://example.com","select":"h1"}"#).unwrap();
    assert!(format!("{instr:?}").contains("ScrapeInstruction"));
}
#[test]
fn extract_mode_debug() {
    assert!(format!("{:?}", ExtractMode::Text).contains("Text"));
}
// --- fetch_html redirect logic: constant and validation unit tests ---
/// `MAX_REDIRECTS` is 3; the 4th redirect attempt must be rejected.
/// Verify the boundary is correct by inspecting the constant value.
#[test]
fn max_redirects_constant_is_three() {
    // fetch_html uses `for hop in 0..=MAX_REDIRECTS` and returns error when hop == MAX_REDIRECTS
    // while still in a redirect. That means hops 0,1,2 can redirect; hop 3 triggers the error.
    // This test documents the expected limit.
    const MAX_REDIRECTS: usize = 3;
    assert_eq!(MAX_REDIRECTS, 3, "fetch_html allows exactly 3 redirects");
}
/// Verifies that a Location-less redirect would produce an error string containing the
/// expected message, matching the error path in `fetch_html`.
#[test]
fn redirect_no_location_error_message() {
    let rendered = std::io::Error::other("redirect with no Location").to_string();
    assert!(rendered.contains("redirect with no Location"));
}
/// Verifies that a too-many-redirects condition produces the expected error string.
#[test]
fn too_many_redirects_error_message() {
    let rendered = std::io::Error::other("too many redirects").to_string();
    assert!(rendered.contains("too many redirects"));
}
/// Verifies that a non-2xx HTTP status produces an error message with the status code.
#[test]
fn non_2xx_status_error_format() {
    let status = reqwest::StatusCode::FORBIDDEN;
    assert!(format!("HTTP {status}").contains("403"));
}
/// Verifies that a 404 response status code formats into the expected error message.
#[test]
fn not_found_status_error_format() {
    let status = reqwest::StatusCode::NOT_FOUND;
    assert!(format!("HTTP {status}").contains("404"));
}
/// Verifies relative redirect resolution for same-host paths (simulates Location: /other).
#[test]
fn relative_redirect_same_host_path() {
    let next = Url::parse("https://example.com/current")
        .unwrap()
        .join("/other")
        .unwrap();
    assert_eq!(next.as_str(), "https://example.com/other");
}
/// Verifies relative redirect resolution preserves scheme and host.
#[test]
fn relative_redirect_relative_path() {
    // A bare segment replaces only the last path component.
    let next = Url::parse("https://example.com/a/b")
        .unwrap()
        .join("c")
        .unwrap();
    assert_eq!(next.as_str(), "https://example.com/a/c");
}
/// Verifies that an absolute redirect URL overrides base URL completely.
#[test]
fn absolute_redirect_overrides_base() {
    let next = Url::parse("https://example.com/page")
        .unwrap()
        .join("https://other.com/target")
        .unwrap();
    assert_eq!(next.as_str(), "https://other.com/target");
}
/// Verifies that a redirect Location of http:// (downgrade) is rejected.
#[test]
fn redirect_http_downgrade_rejected() {
    let next = Url::parse("https://example.com/start")
        .unwrap()
        .join("http://example.com/page")
        .unwrap();
    assert!(matches!(
        validate_url(next.as_str()),
        Err(ToolError::Blocked { .. })
    ));
}
/// Verifies that a redirect to a private IP literal is blocked.
#[test]
fn redirect_location_private_ip_blocked() {
    let next = Url::parse("https://example.com/start")
        .unwrap()
        .join("https://192.168.100.1/admin")
        .unwrap();
    let ToolError::Blocked { command: cmd } = validate_url(next.as_str()).unwrap_err() else {
        panic!("expected Blocked");
    };
    assert!(
        cmd.contains("private") || cmd.contains("scheme"),
        "error message should describe the block reason: {cmd}"
    );
}
/// Verifies that a redirect to a .internal domain is blocked.
#[test]
fn redirect_location_internal_domain_blocked() {
    let next = Url::parse("https://example.com/start")
        .unwrap()
        .join("https://metadata.internal/latest/meta-data/")
        .unwrap();
    assert!(matches!(
        validate_url(next.as_str()),
        Err(ToolError::Blocked { .. })
    ));
}
/// Verifies that a chain of 3 valid public redirects passes `validate_url` at every hop.
#[test]
fn redirect_chain_three_hops_all_public() {
    for hop in [
        "https://redirect1.example.com/hop1",
        "https://redirect2.example.com/hop2",
        "https://destination.example.com/final",
    ] {
        assert!(validate_url(hop).is_ok(), "expected ok for {hop}");
    }
}
// --- SSRF redirect chain defense ---
/// Verifies that a redirect Location pointing to a private IP is rejected by `validate_url`
/// before any connection attempt — simulating the validation step inside `fetch_html`.
#[test]
fn redirect_to_private_ip_rejected_by_validate_url() {
    // These would appear as Location headers in a redirect response.
    for target in [
        "https://127.0.0.1/secret",
        "https://10.0.0.1/internal",
        "https://192.168.1.1/admin",
        "https://172.16.0.1/data",
        "https://[::1]/path",
        "https://[fe80::1]/path",
        "https://localhost/path",
        "https://service.internal/api",
    ] {
        let result = validate_url(target);
        assert!(result.is_err(), "expected error for {target}");
        assert!(
            matches!(result.unwrap_err(), ToolError::Blocked { .. }),
            "expected Blocked for {target}"
        );
    }
}
/// Verifies that relative redirect URLs are resolved correctly before validation.
#[test]
fn redirect_relative_url_resolves_correctly() {
    let resolved = Url::parse("https://example.com/page")
        .unwrap()
        .join("/other")
        .unwrap();
    assert_eq!(resolved.as_str(), "https://example.com/other");
}
/// Verifies that a protocol-relative redirect to http:// is rejected (scheme check).
#[test]
fn redirect_to_http_rejected() {
    assert!(matches!(
        validate_url("http://example.com/page"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv4_mapped_ipv6_link_local_blocked() {
    assert!(matches!(
        validate_url("https://[::ffff:169.254.0.1]/path"),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn ipv4_mapped_ipv6_public_allowed() {
    // A mapped public IPv4 address is still public after unwrapping.
    assert!(validate_url("https://[::ffff:93.184.216.34]/path").is_ok());
}
// --- fetch tool ---
/// Test helper: builds a `fetch` tool call whose only parameter is `url`.
fn fetch_call(url: &str) -> crate::executor::ToolCall {
    let mut params = serde_json::Map::new();
    params.insert("url".to_owned(), serde_json::json!(url));
    crate::executor::ToolCall {
        tool_id: ToolName::new("fetch"),
        params,
        caller_id: None,
    }
}
#[tokio::test]
async fn fetch_http_scheme_blocked() {
    // The fetch tool enforces the same HTTPS-only rule as web_scrape.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let result = executor
        .execute_tool_call(&fetch_call("http://example.com"))
        .await;
    assert!(matches!(result, Err(ToolError::Blocked { .. })));
}
#[tokio::test]
async fn fetch_private_ip_blocked() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let result = executor
        .execute_tool_call(&fetch_call("https://192.168.1.1/secret"))
        .await;
    assert!(matches!(result, Err(ToolError::Blocked { .. })));
}
#[tokio::test]
async fn fetch_localhost_blocked() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let result = executor
        .execute_tool_call(&fetch_call("https://localhost/page"))
        .await;
    assert!(matches!(result, Err(ToolError::Blocked { .. })));
}
#[tokio::test]
async fn fetch_unknown_tool_returns_none() {
    // A tool id this executor doesn't own yields Ok(None), not an error.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let call = crate::executor::ToolCall {
        tool_id: ToolName::new("unknown_tool"),
        params: serde_json::Map::new(),
        caller_id: None,
    };
    let result = executor.execute_tool_call(&call).await;
    assert!(result.unwrap().is_none());
}
#[tokio::test]
async fn fetch_returns_body_via_mock() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};
    let (executor, server) = mock_server_executor().await;
    // The fetch path returns the raw body without selector filtering.
    Mock::given(method("GET"))
        .and(path("/content"))
        .respond_with(ResponseTemplate::new(200).set_body_string("plain text content"))
        .mount(&server)
        .await;
    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/content", server.uri());
    let result = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await;
    assert!(result.is_ok());
    assert_eq!(result.unwrap(), "plain text content");
}
#[test]
fn tool_definitions_returns_web_scrape_and_fetch() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let defs = executor.tool_definitions();
    assert_eq!(defs.len(), 2);
    // web_scrape is invoked via fenced ```scrape blocks; fetch via tool calls.
    assert_eq!(defs[0].id, "web_scrape");
    assert_eq!(
        defs[0].invocation,
        crate::registry::InvocationHint::FencedBlock("scrape")
    );
    assert_eq!(defs[1].id, "fetch");
    assert_eq!(defs[1].invocation, crate::registry::InvocationHint::ToolCall);
}
#[test]
fn tool_definitions_schema_has_all_params() {
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    let defs = executor.tool_definitions();
    let schema = defs[0].schema.as_object().unwrap();
    let props = schema["properties"].as_object().unwrap();
    for key in ["url", "select", "extract", "limit"] {
        assert!(props.contains_key(key));
    }
    // Only url and select are required; extract/limit have defaults.
    let required = schema["required"].as_array().unwrap();
    assert!(required.iter().any(|v| v.as_str() == Some("url")));
    assert!(required.iter().any(|v| v.as_str() == Some("select")));
    assert!(!required.iter().any(|v| v.as_str() == Some("extract")));
}
// --- is_private_host: new domain checks (AUD-02) ---
#[test]
fn subdomain_localhost_blocked() {
    // *.localhost subdomains resolve to loopback and must be treated as private.
    assert!(is_private_host(&url::Host::Domain("foo.localhost")));
}
#[test]
fn internal_tld_blocked() {
    assert!(is_private_host(&url::Host::Domain("service.internal")));
}
#[test]
fn local_tld_blocked() {
    assert!(is_private_host(&url::Host::Domain("printer.local")));
}
#[test]
fn public_domain_not_blocked() {
    assert!(!is_private_host(&url::Host::Domain("example.com")));
}
// --- resolve_and_validate: private IP rejection ---
#[tokio::test]
async fn resolve_loopback_rejected() {
    // A literal loopback IP resolves directly; resolve_and_validate must
    // still block it even though validate_url would normally catch it first.
    let url = url::Url::parse("https://127.0.0.1/path").unwrap();
    let outcome = resolve_and_validate(&url).await;
    assert!(
        matches!(outcome, Err(crate::executor::ToolError::Blocked { .. })),
        "loopback IP must be rejected by resolve_and_validate"
    );
}
#[tokio::test]
async fn resolve_private_10_rejected() {
    // RFC 1918 10.0.0.0/8 addresses must be blocked.
    let url = url::Url::parse("https://10.0.0.1/path").unwrap();
    assert!(matches!(
        resolve_and_validate(&url).await,
        Err(crate::executor::ToolError::Blocked { .. })
    ));
}
#[tokio::test]
async fn resolve_private_192_rejected() {
    // RFC 1918 192.168.0.0/16 addresses must be blocked.
    let url = url::Url::parse("https://192.168.1.1/path").unwrap();
    assert!(matches!(
        resolve_and_validate(&url).await,
        Err(crate::executor::ToolError::Blocked { .. })
    ));
}
#[tokio::test]
async fn resolve_ipv6_loopback_rejected() {
    // The IPv6 loopback address must be blocked just like 127.0.0.1.
    let url = url::Url::parse("https://[::1]/path").unwrap();
    assert!(matches!(
        resolve_and_validate(&url).await,
        Err(crate::executor::ToolError::Blocked { .. })
    ));
}
#[tokio::test]
async fn resolve_no_host_returns_ok() {
    // `data:` URLs have no host component, so resolve_and_validate has
    // nothing to resolve and must succeed with an empty host string and
    // no resolved addresses.
    let url = url::Url::parse("data:text/plain,hello").unwrap();
    let (host, addrs) = resolve_and_validate(&url)
        .await
        .expect("host-less URL must pass validation");
    assert!(host.is_empty());
    assert!(addrs.is_empty());
}
// --- audit logging ---
/// Helper: build an `AuditLogger` writing to a temp file, and return the logger + path.
async fn make_file_audit_logger(
    dir: &tempfile::TempDir,
) -> (
    std::sync::Arc<crate::audit::AuditLogger>,
    std::path::PathBuf,
) {
    use crate::audit::AuditLogger;
    use crate::config::AuditConfig;

    let log_path = dir.path().join("audit.log");
    let audit_config = AuditConfig {
        enabled: true,
        destination: log_path.display().to_string(),
        ..Default::default()
    };
    let logger = AuditLogger::from_config(&audit_config, false)
        .await
        .expect("audit logger construction must succeed");
    (std::sync::Arc::new(logger), log_path)
}
#[tokio::test]
async fn with_audit_sets_logger() {
    // with_audit must populate the previously-empty audit_logger slot.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    assert!(executor.audit_logger.is_none());

    let dir = tempfile::tempdir().unwrap();
    let (logger, _path) = make_file_audit_logger(&dir).await;
    assert!(executor.with_audit(logger).audit_logger.is_some());
}
#[test]
fn tool_error_to_audit_result_blocked_maps_correctly() {
    // A Blocked error carries its message through as the audit block reason.
    let err = ToolError::Blocked {
        command: "scheme not allowed: http".into(),
    };
    let AuditResult::Blocked { reason } = tool_error_to_audit_result(&err) else {
        panic!("expected AuditResult::Blocked");
    };
    assert_eq!(reason, "scheme not allowed: http");
}
#[test]
fn tool_error_to_audit_result_timeout_maps_correctly() {
    // Timeouts map to the dedicated AuditResult::Timeout variant.
    let mapped = tool_error_to_audit_result(&ToolError::Timeout { timeout_secs: 15 });
    assert!(matches!(mapped, AuditResult::Timeout));
}
#[test]
fn tool_error_to_audit_result_execution_error_maps_correctly() {
    // Execution errors surface their underlying I/O message in AuditResult::Error.
    let err = ToolError::Execution(std::io::Error::other("connection refused"));
    let AuditResult::Error { message } = tool_error_to_audit_result(&err) else {
        panic!("expected AuditResult::Error");
    };
    assert!(message.contains("connection refused"));
}
#[tokio::test]
async fn fetch_audit_blocked_url_logged() {
    // Fetching an http:// URL is blocked, and the block must be audited
    // with the tool name, result type, and offending URL.
    let dir = tempfile::tempdir().unwrap();
    let (logger, log_path) = make_file_audit_logger(&dir).await;
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default()).with_audit(logger);

    let mut params = serde_json::Map::new();
    params.insert("url".to_owned(), serde_json::json!("http://example.com"));
    let call = crate::executor::ToolCall {
        tool_id: ToolName::new("fetch"),
        params,
        caller_id: None,
    };

    let result = executor.execute_tool_call(&call).await;
    assert!(matches!(result, Err(ToolError::Blocked { .. })));

    let content = tokio::fs::read_to_string(&log_path).await.unwrap();
    for (needle, what) in [
        ("\"tool\":\"fetch\"", "tool=fetch"),
        ("\"type\":\"blocked\"", "type=blocked"),
        ("http://example.com", "URL in audit command field"),
    ] {
        assert!(content.contains(needle), "expected {what} in audit: {content}");
    }
}
#[tokio::test]
async fn log_audit_success_writes_to_file() {
    // A success entry must record the tool, result type, command URL,
    // and duration in the audit file.
    let dir = tempfile::tempdir().unwrap();
    let (logger, log_path) = make_file_audit_logger(&dir).await;
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default()).with_audit(logger);

    executor
        .log_audit(
            "fetch",
            "https://example.com/page",
            AuditResult::Success,
            42,
            None,
            None,
            None,
        )
        .await;

    let content = tokio::fs::read_to_string(&log_path).await.unwrap();
    for needle in [
        "\"tool\":\"fetch\"",
        "\"type\":\"success\"",
        "\"command\":\"https://example.com/page\"",
        "\"duration_ms\":42",
    ] {
        assert!(content.contains(needle), "expected {needle} in audit: {content}");
    }
}
#[tokio::test]
async fn log_audit_blocked_writes_to_file() {
    // A blocked entry must record the tool, result type, and block reason.
    let dir = tempfile::tempdir().unwrap();
    let (logger, log_path) = make_file_audit_logger(&dir).await;
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default()).with_audit(logger);

    executor
        .log_audit(
            "web_scrape",
            "http://evil.com/page",
            AuditResult::Blocked {
                reason: "scheme not allowed: http".into(),
            },
            0,
            None,
            None,
            None,
        )
        .await;

    let content = tokio::fs::read_to_string(&log_path).await.unwrap();
    for needle in [
        "\"tool\":\"web_scrape\"",
        "\"type\":\"blocked\"",
        "scheme not allowed",
    ] {
        assert!(content.contains(needle), "expected {needle} in audit: {content}");
    }
}
#[tokio::test]
async fn web_scrape_audit_blocked_url_logged() {
    // A blocked web_scrape call must land in the audit log with its tool
    // name and a blocked result type.
    let dir = tempfile::tempdir().unwrap();
    let (logger, log_path) = make_file_audit_logger(&dir).await;
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default()).with_audit(logger);

    let mut params = serde_json::Map::new();
    params.insert("url".to_owned(), serde_json::json!("http://example.com"));
    params.insert("select".to_owned(), serde_json::json!("h1"));
    let call = crate::executor::ToolCall {
        tool_id: ToolName::new("web_scrape"),
        params,
        caller_id: None,
    };

    assert!(matches!(
        executor.execute_tool_call(&call).await,
        Err(ToolError::Blocked { .. })
    ));

    let content = tokio::fs::read_to_string(&log_path).await.unwrap();
    for needle in ["\"tool\":\"web_scrape\"", "\"type\":\"blocked\""] {
        assert!(content.contains(needle), "expected {needle} in audit: {content}");
    }
}
#[tokio::test]
async fn no_audit_logger_does_not_panic_on_blocked_fetch() {
    // Blocking must still work — without panicking — when no audit
    // logger is attached to the executor.
    let executor = WebScrapeExecutor::new(&ScrapeConfig::default());
    assert!(executor.audit_logger.is_none());

    let mut params = serde_json::Map::new();
    params.insert("url".to_owned(), serde_json::json!("http://example.com"));
    let call = crate::executor::ToolCall {
        tool_id: ToolName::new("fetch"),
        params,
        caller_id: None,
    };

    assert!(matches!(
        executor.execute_tool_call(&call).await,
        Err(ToolError::Blocked { .. })
    ));
}
// CR-10: exercise the fetch path against a live mock server. The SSRF
// guard rejects loopback addresses, so the test calls fetch_html directly
// with pre-resolved host/addrs instead of going through execute_tool_call.
#[tokio::test]
async fn fetch_execute_tool_call_end_to_end() {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, ResponseTemplate};

    let (executor, server) = mock_server_executor().await;
    Mock::given(method("GET"))
        .and(path("/e2e"))
        .respond_with(ResponseTemplate::new(200).set_body_string("<h1>end-to-end</h1>"))
        .mount(&server)
        .await;

    let (host, addrs) = server_host_and_addr(&server);
    let url = format!("{}/e2e", server.uri());
    let body = executor
        .fetch_html(&url, &host, &addrs, "fetch", "test-cid", None)
        .await
        .expect("mock fetch must succeed");
    assert!(body.contains("end-to-end"));
}
// --- domain_matches ---
#[test]
fn domain_matches_exact() {
    // An exact pattern matches only the identical host — neither
    // subdomains nor unrelated domains.
    for (host, expected) in [
        ("example.com", true),
        ("other.com", false),
        ("sub.example.com", false),
    ] {
        assert_eq!(domain_matches("example.com", host), expected, "host={host}");
    }
}
#[test]
fn domain_matches_wildcard_single_subdomain() {
    // "*.example.com" matches exactly one extra label — not the apex
    // domain and not deeper nesting.
    for (host, expected) in [
        ("sub.example.com", true),
        ("example.com", false),
        ("sub.sub.example.com", false),
    ] {
        assert_eq!(domain_matches("*.example.com", host), expected, "host={host}");
    }
}
#[test]
fn domain_matches_wildcard_does_not_match_empty_label() {
    // The wildcard label must be non-empty: ".example.com" has nothing
    // before the dot and must not satisfy "*.example.com".
    let matched = domain_matches("*.example.com", ".example.com");
    assert!(!matched, "empty wildcard label must not match");
}
#[test]
fn domain_matches_multi_wildcard_treated_as_exact() {
    // Patterns with more than one wildcard are unsupported and fall back
    // to literal comparison, so nested labels do not match.
    let matched = domain_matches("*.*.example.com", "a.b.example.com");
    assert!(!matched, "multi-wildcard patterns must be treated literally");
}
// --- check_domain_policy ---
#[test]
fn check_domain_policy_empty_lists_allow_all() {
    // With neither an allowlist nor a denylist, every host is permitted.
    for host in ["example.com", "evil.com"] {
        assert!(check_domain_policy(host, &[], &[]).is_ok(), "host={host}");
    }
}
#[test]
fn check_domain_policy_denylist_blocks() {
    // A denylisted host must yield a Blocked error.
    let denied = ["evil.com".to_string()];
    assert!(matches!(
        check_domain_policy("evil.com", &[], &denied),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn check_domain_policy_denylist_does_not_block_other_domains() {
    // The denylist only affects hosts that actually match an entry.
    let denied = ["evil.com".to_string()];
    assert!(check_domain_policy("good.com", &[], &denied).is_ok());
}
#[test]
fn check_domain_policy_allowlist_permits_matching() {
    // Both exact and wildcard allowlist entries admit matching hosts.
    let allowed = ["docs.rs".to_string(), "*.rust-lang.org".to_string()];
    for host in ["docs.rs", "blog.rust-lang.org"] {
        assert!(check_domain_policy(host, &allowed, &[]).is_ok(), "host={host}");
    }
}
#[test]
fn check_domain_policy_allowlist_blocks_unknown() {
    // With a non-empty allowlist, hosts not on it are blocked.
    let allowed = ["docs.rs".to_string()];
    assert!(matches!(
        check_domain_policy("other.com", &allowed, &[]),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn check_domain_policy_deny_overrides_allow() {
    // A host present on both lists is blocked: the denylist wins.
    let listed = ["example.com".to_string()];
    assert!(matches!(
        check_domain_policy("example.com", &listed, &listed),
        Err(ToolError::Blocked { .. })
    ));
}
#[test]
fn check_domain_policy_wildcard_in_denylist() {
    // A wildcard denylist entry blocks subdomains but not the apex domain.
    let denied = ["*.evil.com".to_string()];
    assert!(matches!(
        check_domain_policy("sub.evil.com", &[], &denied),
        Err(ToolError::Blocked { .. })
    ));
    assert!(check_domain_policy("evil.com", &[], &denied).is_ok());
}
}